from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.data_libs.master_std_star_echelle import MasterSpecphotStdReport
from adari_core.utils.utils import fetch_kw_or_default, fetch_kw_or_error

from .visir_utils import VisirSetupInfo
import os
import numpy as np

from . import VisirReportMixin


class VisirSpecphotStdReport(VisirReportMixin, MasterSpecphotStdReport):
    """Spectrophotometric standard-star report for VISIR.

    Builds one panel from a ``SPC_PHOT_TAB`` product: two header text
    blocks, followed by extracted flux, sky flux and sensitivity plotted
    against wavelength.
    """

    def __init__(self):
        super().__init__("visir_specphot_std")

    def parse_sof(self):
        """Pick the input product for the report from ``self.inputs``.

        ``self.inputs`` is iterated as ``(filename, category)`` pairs; the
        frame tagged ``SPC_PHOT_TAB`` is selected. If several such frames
        are present, the last one wins.

        Returns:
            A list holding one ``{"spc_phot": filename}`` dict, or an empty
            list when no ``SPC_PHOT_TAB`` frame was supplied.
        """
        spc_phot = None
        for filename, catg in self.inputs:
            if catg == "SPC_PHOT_TAB":
                spc_phot = filename

        file_lists = []
        if spc_phot is not None:
            file_lists.append({"spc_phot": spc_phot})
        return file_lists

    def generate_panels(self, **kwargs):
        """Build the report panel for the loaded ``spc_phot`` product.

        Layout is a 3x4 grid. Row 0 holds two text blocks (file/product
        identification in columns 0-1, instrument setup in column 2).
        Rows 1-3 hold the extracted-flux, sky-flux and sensitivity plots,
        each spanning all three columns.

        Side effects:
            Sets ``self.metadata`` to the value returned by
            ``VisirSetupInfo.specphot_std``.

        Returns:
            A dict mapping the single ``Panel`` to its report metadata
            (source file, extension, report name/description/tags).

        Raises:
            Whatever ``fetch_kw_or_error`` raises when the primary header
            lacks ``HIERARCH ESO PRO CATG``.
        """
        ext = "TAB_SPECTRUM_0"
        panels = {}

        # Short first row for the text blocks, three equal rows for plots.
        p = Panel(3, 4, height_ratios=[1, 4, 4, 4])
        spc_phot = self.hdus[0]["spc_phot"]
        primary_hdr = spc_phot["PRIMARY"].header
        spec_hdu = spc_phot[ext]
        data = spec_hdu.data
        spc_phot_procatg = fetch_kw_or_error(spc_phot[0], "HIERARCH ESO PRO CATG")

        # WLEN is divided by 1e-6 so the x axis is in micrometres
        # (presumably the table stores metres -- confirm against the recipe).
        wavelength = data["WLEN"] / 1E-6
        extracted_flux = data["SPC_EXTRACTED"]
        sky_flux = data["SPC_SKY"]
        sensitivity = data["SENSITIVITY"]
        wavelength_unit = "um"
        x_label = f"Wavelength ({wavelength_unit})"

        # Row 0, columns 0-1: file / product identification.
        vspace = 0.5
        t1 = TextPlot(columns=1, v_space=vspace)
        fname = os.path.basename(str(spc_phot.filename()))
        instru = primary_hdr.get("INSTRUME")
        col1 = (
            str(instru),
            "EXTNAME: " + str(spec_hdu.header.get("EXTNAME", "N/A")),
            "PRO CATG: " + str(primary_hdr.get("HIERARCH ESO PRO CATG")),
            "FILE NAME: " + fname,
            "RAW1 NAME: "
            + str(primary_hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
        )
        t1.add_data(col1)
        p.assign_plot(t1, 0, 0, xext=2)

        # Row 0, column 2: instrument setup summary (also kept on self).
        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        self.metadata = VisirSetupInfo.specphot_std(spc_phot)
        t2.add_data(self.metadata)
        p.assign_plot(t2, 2, 0, xext=1)

        # NOTE(review): the axis units are read from TUNIT5 / TUNIT4 / TUNIT8,
        # i.e. this assumes SPC_EXTRACTED, SPC_SKY and SENSITIVITY sit in
        # table columns 5, 4 and 8. Confirm if the table layout changes.

        # Row 1: extracted flux, no legend.
        lineplot1 = LinePlot()
        lineplot1.add_data(d=[wavelength, extracted_flux], label="_n", color="black")
        lineplot1.x_label = x_label
        lineplot1.y_label = "Extracted flux ({})".format(
            fetch_kw_or_default(spec_hdu, "TUNIT5", "Unknown TUNIT"))
        lineplot1.legend = False
        lineplot1.title = "Extracted flux vs. wavelength"
        p.assign_plot(lineplot1, 0, 1, xext=3)

        # Row 2: sky flux.
        lineplot2 = LinePlot(title="Sky flux vs. wavelength", legend=True)
        lineplot2.add_data(d=[wavelength, sky_flux], color="red", label="Sky flux")
        lineplot2.x_label = x_label
        lineplot2.y_label = "Sky flux ({})".format(
            fetch_kw_or_default(spec_hdu, "TUNIT4", "Unknown TUNIT"))
        p.assign_plot(lineplot2, 0, 2, xext=3)

        # Row 3: sensitivity. The y axis is capped at half the peak value.
        # NOTE(review): presumably to keep edge spikes from flattening the
        # curve -- confirm intent. nanmax is used so that NaN samples in the
        # column do not turn the axis limit itself into NaN.
        eflux_v_wavl = LinePlot(
            title="Sensitivity vs. wavelength",
            legend=True,
            y_max=0.5 * np.nanmax(sensitivity),
        )
        eflux_v_wavl.add_data(
            d=[wavelength, sensitivity],
            color="red",
            label="Sensitivity",
        )
        eflux_v_wavl.x_label = x_label
        # Fallback text matches the TUNIT keyword actually queried.
        eflux_v_wavl.y_label = "Sensitivity ({})".format(
            fetch_kw_or_default(spec_hdu, "TUNIT8", "Unknown TUNIT"))
        p.assign_plot(eflux_v_wavl, 0, 3, xext=3)

        panels[p] = {
            "spc_phot": spc_phot.filename(),
            "ext": ext,
            "report_name": f"{instru}_{spc_phot_procatg.lower()}_{ext}",
            "report_description": f"Specphotometric Std panel - ({spc_phot.filename()}, {ext})",
            "report_tags": [],
        }
        return panels


# Module-level instance of the VISIR spectrophotometric-standard report.
# NOTE(review): presumably discovered by the adari_core report runner when it
# imports this module -- confirm before renaming or removing.
rep = VisirSpecphotStdReport()
