from .fors_utils import ForsSetupInfo

from adari_core.plots.combined import CombinedPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot, ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default, fetch_kw_or_error

import os
import numpy as np

class ForsSpecStdStarReport(AdariReportBase):
    """ADARI report for FORS spectro-photometric standard-star reductions.

    The observing mode (``LSS``, ``MOS`` or ``PMOS``) is detected from the
    ``REDUCED_STD_*`` category present in the SOF (see :meth:`parse_sof`).
    The mode decides which products are required and which plots are drawn
    by :meth:`generate_panels`.
    """

    # Every product slot this report can use, all initially unset. This is a
    # read-only template: parse_sof builds a fresh mapping on each call and
    # stores only the slots needed for the detected mode on the instance, so
    # no state leaks between calls or instances through this class attribute.
    files_needed = {
        "reduced_std": None,
        "reduced_sky": None,
        "spec_table": None,
        "reduced_flux": None,
        "raw": None,
        "reduced_v": None,
        "reduced_q": None,
        "reduced_u": None,
        "reduced_l": None,
        "reduced_v_err": None,
        "reduced_q_err": None,
        "reduced_u_err": None,
        "reduced_l_err": None,
    }
    ext = "PRIMARY"  # HDU read from the reduced single-spectrum products
    category_label = ""  # "LSS", "MOS" or "PMOS"; set by parse_sof
    pmos_v = False  # True when a PMOS SOF provides a V product; set by parse_sof

    # (mode label, PRO.CATG of the reduced standard spectrum), in the order
    # they are tested; the first category present in the SOF sets the mode.
    _MODE_PROCATG = (
        ("LSS", "REDUCED_STD_LSS"),
        ("MOS", "REDUCED_STD_MOS"),
        ("PMOS", "REDUCED_STD_PMOS"),
    )

    def __init__(self):
        super().__init__("fors_spec_std_star")

    def get_wavelength(self, h):
        """Return the linear wavelength axis described by the header of ``h``.

        The axis is ``CRVAL1 + (pixel - CRPIX1) * CD1_1`` for
        ``pixel = 1 .. NAXIS1``.

        Parameters
        ----------
        h
            An HDU accepted by ``fetch_kw_or_error`` whose header provides
            CRPIX1, CRVAL1, CD1_1 and NAXIS1.

        Returns
        -------
        list
            ``[wavelengths, "Angstrom"]``. The unit is hard-coded; CUNIT1 is
            not read.

        Raises
        ------
        Whatever ``fetch_kw_or_error`` raises when a keyword is missing.
        """
        crpix1 = fetch_kw_or_error(h, "CRPIX1")
        crval1 = fetch_kw_or_error(h, "CRVAL1")
        cd1_1 = fetch_kw_or_error(h, "CD1_1")
        n_pix = fetch_kw_or_error(h, "NAXIS1")
        # Pixel indices start at 1, matching the CRPIX1 reference convention.
        pixels = np.arange(1, n_pix + 1)
        w = crval1 + (pixels - crpix1) * cd1_1
        unit_w = "Angstrom"
        return [w, unit_w]

    def parse_sof(self):
        """Match SOF inputs to the product slots required by the detected mode.

        Sets ``self.category_label`` from the ``REDUCED_STD_*`` category found
        in the SOF. For PMOS, sets ``self.pmos_v`` depending on whether a V
        product is present. For each required slot, the first SOF entry with
        the matching category is used.

        Returns
        -------
        list of dict
            A one-element list holding the mapping ``slot -> file path``. The
            same mapping is also stored as ``self.files_needed``.

        Raises
        ------
        IOError
            If no ``REDUCED_STD_*`` product is present, or if a product
            required for the detected mode is missing.
        """
        files_path = [elem[0] for elem in self.inputs]
        files_category = [elem[1] for elem in self.inputs]

        for label, procatg in self._MODE_PROCATG:
            if procatg in files_category:
                self.category_label = label
                break
        else:
            raise IOError("REDUCED_STD_LSS/MOS/PMOS file not found")

        label = self.category_label
        categories_needed = {
            "reduced_std": "REDUCED_STD_{}".format(label),
            "reduced_sky": "REDUCED_SKY_STD_{}".format(label),
        }
        if label != "PMOS":
            categories_needed["reduced_flux"] = "REDUCED_FLUX_STD_{}".format(label)
            categories_needed["spec_table"] = "SPECPHOT_TABLE"
            categories_needed["raw"] = "STANDARD_{}".format(label)
        else:
            for stokes in ("V", "Q", "U", "L"):
                key = stokes.lower()
                categories_needed["reduced_" + key] = (
                    "REDUCED_{}_STD_PMOS".format(stokes)
                )
                categories_needed["reduced_{}_err".format(key)] = (
                    "REDUCED_ERROR_{}_STD_PMOS".format(stokes)
                )

        # Fresh per-call mapping: the class-level template is never written to.
        found = {}
        for slot, procatg in categories_needed.items():
            if procatg in files_category:
                # list.index returns the first SOF entry with this category.
                found[slot] = files_path[files_category.index(procatg)]

        if label != "PMOS":
            self.pmos_v = False
            keys = ["reduced_std", "reduced_sky", "reduced_flux", "spec_table", "raw"]
        else:
            # With a V product only V is reported; otherwise Q, U and L are expected.
            self.pmos_v = found.get("reduced_v") is not None
            if self.pmos_v:
                keys = ["reduced_std", "reduced_sky", "reduced_v", "reduced_v_err"]
            else:
                keys = [
                    "reduced_std",
                    "reduced_sky",
                    "reduced_q",
                    "reduced_q_err",
                    "reduced_u",
                    "reduced_u_err",
                    "reduced_l",
                    "reduced_l_err",
                ]
        self.files_needed = {k: found.get(k) for k in keys}

        # Check that all required files were found.
        for slot, path in self.files_needed.items():
            if path is None:
                raise IOError("{} file not found".format(categories_needed[slot]))

        return [self.files_needed]

    def generate_panels(self, **kwargs):
        """Build the report panel for the parsed standard-star products.

        Row 0 holds the header text (instrument, product and file names) and
        the setup metadata. The remaining rows are filled by
        ``_add_pmos_plots`` (PMOS) or ``_add_lss_mos_plots`` (LSS/MOS).

        Returns
        -------
        dict
            ``{panel: {"report_name", "report_description", "report_tags"}}``
            with a single panel.
        """
        products = self.hdus[0]

        # Retrieve appropriate metadata from the first product.
        first_product = list(products.values())[0]
        if self.category_label == "PMOS":
            self.metadata = ForsSetupInfo.spec_std_star_pmos(first_product)
        else:
            self.metadata = ForsSetupInfo.spec_std_star(first_product)

        reduced_std = products["reduced_std"]
        reduced_sky = products["reduced_sky"]

        instru = fetch_kw_or_default(
            reduced_std[self.ext], "INSTRUME", "Missing INSTRUME"
        )
        master_procatg = fetch_kw_or_default(
            reduced_std[self.ext], "HIERARCH ESO PRO CATG", "Missing PRO CATG"
        )
        fname = os.path.basename(str(reduced_std.filename()))

        # PMOS fills 4 plot rows below the header; LSS/MOS fills 6.
        if self.category_label == "PMOS":
            panel = Panel(5, 5, height_ratios=[1, 2, 2, 2, 2])
        else:
            panel = Panel(5, 7, height_ratios=[1, 2, 2, 2, 2, 2, 2])

        panels = {
            panel: {
                "report_name": "{}_{}".format(instru, master_procatg.lower()),
                "report_description": "FORS spec std star {} panel - {}".format(
                    self.category_label, fname
                ),
                "report_tags": [],
            }
        }

        # Header row: product identification (left) and setup metadata (right).
        vspace = 0.2
        raw1_name = fetch_kw_or_default(
            reduced_std["PRIMARY"],
            "HIERARCH ESO PRO REC1 RAW1 NAME",
            "Missing RAW1 NAME",
        )
        t1 = TextPlot(columns=1, v_space=vspace)
        t1.add_data(
            (
                instru,
                "EXTNAME: " + self.ext,
                "PRO CATG: " + str(master_procatg),
                "FILE NAME: " + fname,
                "RAW1 NAME: " + str(raw1_name),
            )
        )
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=2, v_space=vspace)
        t2.add_data(self.metadata)
        panel.assign_plot(t2, 2, 0, xext=3)

        if self.category_label == "PMOS":
            self._add_pmos_plots(panel, products, reduced_std, reduced_sky)
        else:
            self._add_lss_mos_plots(panel, products, reduced_std, reduced_sky)

        return panels

    def _spectra_plot(self, title, spectra, legend):
        """Return a LinePlot of row 0 of each ``(hdu, label)`` pair vs wavelength."""
        plot = LinePlot(
            title=title,
            x_label="Wavelength, Angstrom",
            y_label="ADU/s",
            legend=legend,
        )
        for hdu, label in spectra:
            w, _unit_w = self.get_wavelength(hdu)
            plot.add_data((w, hdu.data[0]), label=label)
        return plot

    def _add_pmos_plots(self, panel, products, reduced_std, reduced_sky):
        """Fill rows 1-4 of a PMOS panel with polarization and extracted spectra."""
        if self.pmos_v:
            stokes = [("v", "Flux")]
        else:
            stokes = [("q", "Flux_q"), ("u", "Flux_u"), ("l", "Flux_l")]

        pol = [
            (products["reduced_{}".format(s)][self.ext], label)
            for s, label in stokes
        ]
        pol_err = [
            (products["reduced_{}_err".format(s)][self.ext], label)
            for s, label in stokes
        ]

        # Polarization spectra and their errors
        panel.assign_plot(
            self._spectra_plot("Polarization spectra", pol, legend=True),
            0, 1, xext=5,
        )
        panel.assign_plot(
            self._spectra_plot("Polarization spectra (errors)", pol_err, legend=True),
            0, 2, xext=5,
        )

        # Extracted, sky-subtracted spectra: one curve per extension after the first.
        std_spectra = [
            (reduced_std[i], "Flux " + str(i)) for i in range(1, len(reduced_std))
        ]
        panel.assign_plot(
            self._spectra_plot(
                "Extracted, sky-subtracted spectra", std_spectra, legend=False
            ),
            0, 3, xext=5,
        )

        # Extracted sky spectra. Iterate over the sky product's own extension
        # count so a sky file with fewer extensions cannot raise IndexError.
        sky_spectra = [
            (reduced_sky[i], "Flux " + str(i)) for i in range(1, len(reduced_sky))
        ]
        panel.assign_plot(
            self._spectra_plot("Extracted sky spectra", sky_spectra, legend=False),
            0, 4, xext=5,
        )

    @staticmethod
    def _column_unit(tab, name):
        """Return the unit declared for column ``name`` of table data ``tab``."""
        return tab.columns[tab.names.index(name)].unit

    def _add_lss_mos_plots(self, panel, products, reduced_std, reduced_sky):
        """Fill rows 1-6 of an LSS/MOS panel.

        Rows, top to bottom: extracted flux, extracted sky, efficiency,
        response, flux-calibrated spectrum, and a raw-data histogram.
        """
        reduced_flux = products["reduced_flux"]
        spec_table = products["spec_table"]
        raw_data = products["raw"]

        # Extracted flux
        w, unit_w = self.get_wavelength(reduced_std[self.ext])
        flux_plot = LinePlot(
            title="Extracted flux",
            x_label="Wavelength, " + unit_w,
            y_label="ADU/s",
            legend=False,
        )
        flux_plot.add_data((w, reduced_std[self.ext].data[0]), label="Flux")
        panel.assign_plot(flux_plot, 0, 1, xext=5)

        # Shared x range for the remaining spectral plots: the extracted
        # spectrum's wavelength span padded by 5% on each side.
        span = w[-1] - w[0]
        xmin = w[0] - span * 0.05
        xmax = w[-1] + span * 0.05

        # Extracted sky
        w, unit_w = self.get_wavelength(reduced_sky[self.ext])
        sky_plot = LinePlot(
            title="Extracted sky",
            x_label="Wavelength, " + unit_w,
            y_label="ADU/s",
            legend=False,
        )
        sky_plot.add_data((w, reduced_sky[self.ext].data[0]), label="Flux")
        sky_plot.set_xlim(xmin, xmax)
        panel.assign_plot(sky_plot, 0, 2, xext=5)

        # Table rows are restricted to the sky spectrum's wavelength coverage.
        tab = spec_table[1].data
        wave = tab["WAVE"]
        w_mask = (wave >= w[0]) & (wave <= w[-1])

        # Efficiency
        eff = tab["EFFICIENCY"]
        eff_plot = LinePlot(
            title="Efficiency",
            x_label="Wavelength, " + unit_w,
            y_label=self._column_unit(tab, "EFFICIENCY"),
            legend=False,
        )
        eff_plot.add_data((wave[w_mask], eff[w_mask]), label="Eff")
        eff_plot.set_xlim(xmin, xmax)
        panel.assign_plot(eff_plot, 0, 3, xext=5)

        # Response: use the *_FFSED columns when the table has them (exact
        # name match, so a lookalike column cannot trigger a KeyError).
        if "RAW_RESPONSE_FFSED" in tab.columns.names:
            raw_col, resp_col = "RAW_RESPONSE_FFSED", "RESPONSE_FFSED"
        else:
            raw_col, resp_col = "RAW_RESPONSE", "RESPONSE"
        raw_response = tab[raw_col]
        response = tab[resp_col]
        unit_r = self._column_unit(tab, resp_col)
        used = tab["USED_FIT"] == 1

        res_plot = LinePlot(
            title="Response",
            x_label="Wavelength, " + unit_w,
            y_label=unit_r,
            legend=True,
        )
        res_plot.add_data(
            (wave[w_mask], response[w_mask]), label="Response", color="black"
        )
        raw_plot = ScatterPlot(
            title="Raw",
            legend=True,
        )
        used_in = used & w_mask
        unused_in = (~used) & w_mask
        raw_plot.add_data(
            (wave[used_in], raw_response[used_in]), label="Raw response (used)"
        )
        raw_plot.add_data(
            (wave[unused_in], raw_response[unused_in]),
            label="Raw response (not used)",
            color="blue",
        )
        combined = CombinedPlot(title="Response")
        combined.add_data(res_plot)
        combined.add_data(raw_plot)
        combined.set_xlim(xmin, xmax)
        panel.assign_plot(combined, 0, 4, xext=5)

        # Flux-calibrated extracted spectrum vs tabulated standard flux
        w, unit_w = self.get_wavelength(reduced_flux[self.ext])
        std = tab["STD_FLUX"]
        calib_plot = LinePlot(
            title="Flux-calibrated extracted spectrum",
            x_label="Wavelength, " + unit_w,
            y_label=self._column_unit(tab, "STD_FLUX"),
            legend=True,
        )
        calib_plot.add_data(
            (w, reduced_flux[self.ext].data[0]), label="Extracted spectrum"
        )
        calib_plot.add_data(
            (wave[w_mask], std[w_mask]), label="Tabulated flux", color="black"
        )
        calib_plot.set_xlim(xmin, xmax)
        panel.assign_plot(calib_plot, 0, 5, xext=5)

        # Histogram of the (first matching) raw input file
        hist_first = HistogramPlot(
            title="Raw data histogram", bins=50, v_min=-5000, v_max=70000
        )
        hist_first.add_data(raw_data[self.ext].data, label="Raw data counts")
        panel.assign_plot(hist_first, 0, 6, xext=2)


# Module-level report instance created at import time; presumably the
# adari_core report runner looks it up by the name ``rep`` — TODO confirm.
rep = ForsSpecStdStarReport()

