import numpy as np
import os
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default
from .pionier_utils import PionierSetupInfo, PionierReportMixin
from .pionier_utils import find_ext_index


class PionierWavelengthReport(PionierReportMixin, AdariReportBase):
    """ADARI report for the PIONIER spectral (wavelength) calibration product.

    Reads the ``SPECTRAL_CALIBRATION`` frame from the SOF and renders a single
    2x2 panel: header/setup text on the top row and, on the bottom row, the
    effective wavelength of every spectral channel, one series per
    ``OI_WAVELENGTH`` extension.
    """

    def __init__(self):
        super().__init__("pionier_wavelength")

    def parse_sof(self):
        """Select the spectral calibration file from the input frames.

        Returns
        -------
        list of dict
            ``[{"calib": filename}]`` when a ``SPECTRAL_CALIBRATION`` frame is
            present (if several are listed, the last one wins), otherwise an
            empty list.
        """
        calib_file = None
        for filename, catg in self.inputs:
            if catg == "SPECTRAL_CALIBRATION":
                calib_file = filename

        if calib_file is None:
            return []
        return [{"calib": calib_file}]

    def generate_panels(self, **kwargs):
        """Build the wavelength-calibration panel.

        Parameters
        ----------
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        dict
            Maps the single generated ``Panel`` to its report metadata
            (``ext``, ``report_name``, ``report_description``,
            ``report_tags``).
        """
        ext = "OI_WAVELENGTH"
        calib = self.hdus[0]["calib"]
        ext_list = find_ext_index(calib, ext)

        p = Panel(2, 2, height_ratios=[1, 4], x_stretch=1.0, y_stretch=0.75)

        rawname = str(
            fetch_kw_or_default(
                calib["PRIMARY"], "HIERARCH ESO PRO REC1 RAW1 NAME", default="N/A"
            )
        )

        # Top row: product identification (left) and setup info (right).
        header_plot, setup_plot = self._text_plots(calib, ext, rawname)
        p.assign_plot(header_plot, 0, 0, xext=1)
        p.assign_plot(setup_plot, 1, 0, xext=1)

        # Bottom row: wavelength solution spanning both columns.
        p.assign_plot(self._wavelength_plot(calib, ext_list), 0, 1, xext=2)

        # NOTE(review): when RAW1 NAME is absent the default "N/A" yields a
        # report_name containing "/" -- confirm consumers tolerate that.
        raw_stem = rawname.removeprefix("PIONI.").removesuffix(".fits")
        addme = {
            "ext": ext,
            "report_name": f"PIONIER_{raw_stem}_{ext.lower()}",
            "report_description": f"Wavelength panel ({ext})",
            "report_tags": [],
        }
        return {p: addme}

    @staticmethod
    def _text_plots(calib, ext, rawname):
        """Return the ``(header_plot, setup_plot)`` text plots for the top row.

        ``calib`` is the opened calibration product, ``ext`` the extension
        name shown as EXTNAME, and ``rawname`` the RAW1 file name (or "N/A").
        """
        vspace = 0.3
        primary = calib["PRIMARY"]

        pro_catg = fetch_kw_or_default(primary, "HIERARCH ESO PRO CATG", default="N/A")
        header_plot = TextPlot(columns=1, v_space=vspace)
        header_plot.add_data(
            (
                "INSTRUME: " + str(primary.header.get("INSTRUME")),
                "EXTNAME: " + ext,
                "PRO.CATG: " + str(pro_catg),
                "FILE NAME: " + os.path.basename(str(calib.filename())),
                "RAW1 NAME: " + rawname,
            )
        )

        setup_plot = TextPlot(columns=1, v_space=vspace, xext=1)
        setup_plot.add_data(PionierSetupInfo.spectral_calibration(calib))
        return header_plot, setup_plot

    @staticmethod
    def _wavelength_plot(calib, ext_list):
        """Return a scatter plot of EFF_WAVE (scaled by 1e6) versus channel index.

        One series is added per extension index in ``ext_list``. Each series is
        labelled with the extension's INSNAME, truncated at the first "(".
        """
        colors = ["black", "red", "blue"]
        eff_plot = ScatterPlot(
            title="",
            x_label="Spectral channel",
            y_label="Wavelength, um",
            markersize=5.0,
            legend=True,
        )
        for i, i_ext in enumerate(ext_list):
            hdu = calib[i_ext]
            insname = str(
                fetch_kw_or_default(hdu, "INSNAME", default="N/A")
            ).split("(")[0]
            # Scale into a new array. An in-place ``*=`` would mutate the array
            # returned by ``data["EFF_WAVE"]``, which for FITS tables is
            # typically a view of the HDU's own data.
            wave = hdu.data["EFF_WAVE"] * 1.0e6
            ch = np.arange(wave.size)
            # Cycle colours so that more than three extensions do not raise
            # IndexError.
            eff_plot.add_data(
                (ch, wave), label=insname, color=colors[i % len(colors)]
            )
        return eff_plot


# Module-level report instance, created at import time.
# NOTE(review): presumably the ADARI framework discovers the report through
# this ``rep`` name -- confirm before renaming.
rep = PionierWavelengthReport()
