import os
import numpy as np
from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.data_libs.master_std_star_echelle import MasterSpecphotStdReport
from adari_core.utils.utils import (
    fetch_kw_or_default,
    fetch_kw_or_error,
    get_wavelength_from_header,
)

from .xshooter_utils import XshooterSetupInfo, XshooterReportMixin, get_arm

class XshooterSpecphotStdReport(XshooterReportMixin, MasterSpecphotStdReport):
    """ADARI report for X-shooter spectrophotometric standard-star products.

    For every arm (UVB, VIS, NIR) whose raw frame and 1D standard-star product
    are both in the SOF, one "std" panel is produced. It is built on the
    parent class's generic panels, with header text, setup info, calibrated
    flux and flux-error plots added. One extra panel shows the merged 2D
    spectrum of the first arm that provides one.
    """

    def __init__(self):
        super().__init__("xshooter_specphot_std")

    def parse_sof(self):
        """Group the SOF inputs into one file set per X-shooter arm.

        Returns:
            A list of dicts, ordered UVB, VIS, NIR. An arm is included only
            when both its raw standard-star frame (``"raw"``) and its 1D
            merged standard-star product (``"std"``) are present. The optional
            keys ``"flux_slit_flux_merge"`` and ``"merge2d"`` are added only
            when the corresponding product is in the SOF; they are never
            stored as ``None``. If a category appears more than once, the last
            file listed wins.
        """
        arms = ("UVB", "VIS", "NIR")

        # Map every recognised category to the (arm, role) slot it fills.
        role_of_catg = {}
        for arm in arms:
            role_of_catg[f"FLUX_SLIT_MERGE1D_{arm}"] = (arm, "std")
            role_of_catg[f"SCI_SLIT_FLUX_IDP_{arm}"] = (arm, "std")
            role_of_catg[f"FLUX_SLIT_MERGE2D_{arm}"] = (arm, "merge2d")
            role_of_catg[f"FLUX_SLIT_FLUX_MERGE1D_{arm}"] = (
                arm,
                "flux_slit_flux_merge",
            )
            # Raw frames come in plain, NOD, OFFSET and STARE flavours.
            for mode in ("", "_NOD", "_OFFSET", "_STARE"):
                role_of_catg[f"STD_FLUX_SLIT{mode}_{arm}"] = (arm, "raw")

        found = {arm: {} for arm in arms}
        for filename, catg in self.inputs:
            slot = role_of_catg.get(catg)
            if slot is not None:
                arm, role = slot
                found[arm][role] = filename

        file_lists = []
        for arm in arms:
            files = found[arm]
            if "raw" not in files or "std" not in files:
                continue
            file_dict = {"raw": files["raw"], "std": files["std"]}
            # Optional products are added only when present, so the panel
            # builders below can test for them with ``in``.
            for optional in ("flux_slit_flux_merge", "merge2d"):
                if optional in files:
                    file_dict[optional] = files[optional]
            file_lists.append(file_dict)

        return file_lists

    def generate_panels_std(self, **kwargs):
        """Build the per-arm standard-star panels.

        Starts from the parent class's ``generate_panels`` output and
        decorates each panel with header text, X-shooter setup info, a
        calibrated-flux plot (when that product exists) and a flux-error plot.
        The loop assumes the parent returns one panel per ``self.hdus`` entry,
        in the same order.

        ``kwargs`` is accepted for interface compatibility and is unused.

        Returns:
            dict mapping ``Panel`` -> panel description dict.
        """
        vspace = 0.5

        arms = [get_arm(hdus["raw"][0]) for hdus in self.hdus]
        # NIR raw frames are cut along "x" (index 1), UVB/VIS along "y" (0).
        raw_cut = ["x" if arm == "NIR" else "y" for arm in arms]
        raw_cut_i = [1 if arm == "NIR" else 0 for arm in arms]

        panels = super().generate_panels(
            plot_eff=False,
            raw_cut=raw_cut,
            raw_cut_i=raw_cut_i,
            flux_ylim="percentile",
            flux_ext="PRIMARY",
        )

        for i, (panel, panel_descr) in enumerate(panels.items()):
            hdus = self.hdus[i]
            arm = arms[i]
            flux_slit = hdus["std"]

            # Optional flux-calibrated product; guard against both a missing
            # key and a None value.
            flux_slit_flux = (
                hdus["flux_slit_flux_merge"]
                if "flux_slit_flux_merge" in hdus
                else None
            )
            if flux_slit_flux is not None:
                panel_descr["input_files"].append(flux_slit_flux.filename())

            # Text Plot
            hdr = flux_slit["PRIMARY"].header
            fname = os.path.basename(str(flux_slit.filename()))
            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                str(hdr.get("INSTRUME")),
                "EXTNAME: " + str(hdr.get("EXTNAME", "N/A")),
                "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + fname,
                "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
            )
            t1.add_data(col1)
            panel.assign_plot(t1, 0, 0, xext=2)

            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            # Stored on the instance as well; the last arm's value persists.
            self.metadata = XshooterSetupInfo.specphot_star(flux_slit)
            t2.add_data(self.metadata)
            panel.assign_plot(t2, 2, 0, xext=1)

            # XSHOOTER specific panel
            wavelength, unit_w = get_wavelength_from_header(flux_slit[0])

            # Calibrated flux vs. wavelength. nanpercentile keeps the y-limit
            # finite if the spectrum contains NaN pixels.
            if flux_slit_flux is not None:
                flux_data = flux_slit_flux[0].data
                ymax1 = 1.1 * np.nanpercentile(flux_data, 99)
                lineplot = LinePlot(
                    title="Calibrated flux vs. wavelength",
                    legend=True,
                    y_min=0,
                    y_max=ymax1,
                )
                lineplot.add_data(
                    d=[wavelength, flux_data],
                    color="red",
                    label="Calibrated flux",
                )
                lineplot.x_label = f"Wavelength ({unit_w})"
                lineplot.y_label = "Calibrated flux (erg/s/cm2/A)"
                panel.assign_plot(lineplot, 0, 1, xext=2)

            # Associated error vs. wavelength. The y-limit is a multiple of the
            # (NaN-ignoring) median error: 50x for NIR, 5x otherwise.
            errs = flux_slit["ERRS"]
            err_scale = 50 if arm == "NIR" else 5
            ymax3 = err_scale * np.nanmedian(errs.data)
            lineplot2 = LinePlot(
                title="Associated error vs. wavelength",
                y_min=0,
                y_max=ymax3,
                legend=True,
            )
            lineplot2.add_data(
                d=[wavelength, errs.data],
                color="red",
                label="Flux error",
            )
            lineplot2.x_label = f"Wavelength ({unit_w})"
            lineplot2.y_label = "Flux error ({})".format(
                fetch_kw_or_default(errs, "BUNIT", "Unknown BUNIT")
            )
            panel.assign_plot(lineplot2, 0, 3, xext=2)

        return panels

    def generate_panels_merge2d(self):
        """Build the panel showing a merged 2D spectrum.

        Uses the first ``self.hdus`` entry (UVB, VIS, NIR order) that carries a
        ``"merge2d"`` product.

        Returns:
            dict mapping ``Panel`` -> panel description dict, or an empty dict
            when no arm provides a merged 2D spectrum.
        """
        m2d = None
        for hdus in self.hdus:
            if "merge2d" in hdus and hdus["merge2d"] is not None:
                m2d = hdus["merge2d"]
                break
        if m2d is None:
            return {}

        panels = {}
        vspace = 0.3
        panel = Panel(4, 2, height_ratios=[1, 4])

        # Text Plot
        primary_hdr = m2d["PRIMARY"].header
        procatg = fetch_kw_or_error(m2d["PRIMARY"], "HIERARCH ESO PRO CATG")
        fname = os.path.basename(str(m2d.filename()))
        rawname = str(
            primary_hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME", "Missing RAW1 NAME")
        )
        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            str(primary_hdr.get("INSTRUME", "Missing INSTRUME")),
            "EXTNAME: PRIMARY",
            "PRO CATG: " + procatg,
            "FILE NAME: " + fname,
            "RAW1 NAME: " + rawname,
        )
        t1.add_data(col1)
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(XshooterSetupInfo.specphot_star(m2d))
        panel.assign_plot(t2, 2, 0, xext=1)

        m2d_plot = ImagePlot(m2d[0].data, title="Merged 2D spectrum", aspect="auto")
        m2d_plot.tick_visibility = {
            "top": False,
            "labeltop": False,
            "bottom": True,
            "labelbottom": True,
            "right": True,
            "labelright": True,
            "left": True,
            "labelleft": True,
        }
        m2d_plot.cbar_kwargs = {"pad": 0.05}
        panel.assign_plot(m2d_plot, 0, 1, xext=4, yext=1)

        # Strip the file extension and instrument prefix from the raw name so
        # the report name stays short.
        rawname = rawname.removesuffix(".fits").removeprefix("XSHOO.")
        panels[panel] = {
            "report_name": f"XSHOOTER_{procatg.lower()}_{rawname}",
            "report_description": "StandardStar_merge2d",
            "report_tags": [],
            "input_files": [m2d.filename()],
        }
        return panels

    def generate_panels(self, **kwargs):
        """Return all report panels: per-arm std panels, then the merge2d panel.

        ``kwargs`` is accepted for interface compatibility and is unused.
        """
        panels = {
            **self.generate_panels_std(),
            **self.generate_panels_merge2d(),
        }
        return panels


# Module-level report instance. The name ``rep`` is presumably what the
# adari_core runner looks up when it loads this module — TODO confirm before
# renaming.
rep = XshooterSpecphotStdReport()
