import os
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.utils.utils import format_kw_or_default
from adari_core.data_libs.master_spec_science import MasterSpecScienceReport

from .eris_utils import ErisReportMixin

class ErisIfuScienceReport(ErisReportMixin, MasterSpecScienceReport):
    """Science preview report for ERIS IFU (SPIFFIER) products.

    ``parse_sof`` selects the coadded cube, its mean image, the extracted
    spectrum and the exposure map from the SOF, preferring the
    flux-calibrated products when a flux-calibrated coadd is present.
    ``generate_panels`` lays them out on a single preview panel.
    """

    # Flux-calibrated values are multiplied by this factor before plotting so
    # that axes can be labelled in units of 1e-16 erg/s/cm^2/A.
    _FLUX_SCALE = 1.e16
    _FLUX_SCALED_LABEL = "$F\\ /\\ \\mathrm{10^{-16}~erg/s/cm^2/\\AA}$"

    # PRO.CATG -> slot it fills in parse_sof. DAR-corrected and plain cubes
    # share the same slots; if several map to one slot, the last one listed
    # in the SOF wins.
    _SOF_CATEGORIES = {
        "OBJECT_CUBE_COADD_FLUXCAL": "coadd",
        "OBJECT_CUBE_COADD_FLUXCAL_MEAN": "mean",
        "OBJECT_CUBE_COADD": "coadd_nofluxcal",
        "OBJECT_CUBE_MEAN": "mean_nofluxcal",
        "DAR_CORRECTED_CUBE_COADD_FLUXCAL": "coadd",
        "DAR_CORRECTED_CUBE_COADD_FLUXCAL_MEAN": "mean",
        "DAR_CORRECTED_CUBE_COADD": "coadd_nofluxcal",
        "DAR_CORRECTED_CUBE_MEAN": "mean_nofluxcal",
        "SPECTRUM_FLUXCAL": "spectrum",
        "SPECTRUM": "spectrum_nofluxcal",
        "EXPOSURE_MAP": "exposure",
    }

    def __init__(self):
        super().__init__("eris_ifu_science")
        # Spectrum line width; not read in this class, presumably consumed by
        # the base-class plotting helpers -- TODO confirm.
        self.spec_lw = 1.2

    def parse_sof(self):
        """Select the input files for the single report built from this SOF.

        Four input sets are recognised:

        - OBJECT_CUBE_COADD_FLUXCAL, OBJECT_CUBE_COADD_FLUXCAL_MEAN,
          SPECTRUM_FLUXCAL -> flux calibrated, no DAR correction
        - OBJECT_CUBE_COADD, OBJECT_CUBE_MEAN, SPECTRUM
          -> not flux calibrated, no DAR correction
        - DAR_CORRECTED_CUBE_COADD_FLUXCAL,
          DAR_CORRECTED_CUBE_COADD_FLUXCAL_MEAN, SPECTRUM_FLUXCAL
          -> flux calibrated, DAR corrected
        - DAR_CORRECTED_CUBE_COADD, DAR_CORRECTED_CUBE_MEAN, SPECTRUM
          -> not flux calibrated, DAR corrected

        plus an EXPOSURE_MAP. The flux-calibrated set is used whenever a
        flux-calibrated coadd was found; otherwise the uncalibrated set is
        used. ``self.fluxcal`` records which one was chosen.

        Returns:
            A one-element list holding a dict with the ``coadd``,
            ``spectrum``, ``mean`` and ``exposure`` filenames (``None`` for
            any category absent from the SOF).
        """
        found = {slot: None for slot in self._SOF_CATEGORIES.values()}
        for filename, catg in self.inputs:
            slot = self._SOF_CATEGORIES.get(catg)
            if slot is not None:
                found[slot] = filename

        # The calibrated coadd decides which whole set (coadd, mean, spectrum)
        # is used, so the three products always come from the same set.
        self.fluxcal = found["coadd"] is not None
        suffix = "" if self.fluxcal else "_nofluxcal"

        return [
            {
                "coadd": found["coadd" + suffix],
                "spectrum": found["spectrum" + suffix],
                "mean": found["mean" + suffix],
                "exposure": found["exposure"],
            }
        ]

    def generate_panels(self, **kwargs):
        """Build the science preview panel.

        Grid placement used below (``assign_plot(plot, column, row)``):
        header text on row 0, the extracted spectrum on row 1, its S/N on
        row 2, the white-light image, exposure map and coadd flux histogram
        on row 3, and exposure statistics on row 4.

        Returns:
            Dict mapping the single ``Panel`` to its report metadata.
        """
        panels = {}
        hdus = self.hdus[0]

        # --- Extracted spectrum --------------------------------------------
        science = hdus["spectrum"]
        sci_data = science[1].data
        if self.fluxcal:
            wave_col, flux_col, err_col = "wavelength", "flux", "flux_error"
        else:
            wave_col, flux_col, err_col = "WAVE", "FLUX", "ERR"

        spec_wave = sci_data[wave_col].flatten()
        # Flatten in both branches so wavelength, flux, error and S/N share
        # one 1-D shape even when the table stores them as vector columns.
        spec_flux = sci_data[flux_col].flatten()
        if self.fluxcal:
            spec_flux = spec_flux * self._FLUX_SCALE
            spec_flux_unit = self._FLUX_SCALED_LABEL
        else:
            spec_flux_unit = "$F$"
        snr = spec_flux / sci_data[err_col].flatten()

        p = Panel(6, 5, height_ratios=[1, 5, 1, 5, 1], y_stretch=0.84)

        # Main spectrum plot: no binning and no low/high zoom windows.
        low_wave = low_flux = high_wave = high_flux = None
        binning = 0
        specplot, _lowhighplot = self.generate_spec_plot(
            spec_wave,
            spec_flux,
            spec_flux_unit,
            binning,
            low_wave,
            low_flux,
            high_wave,
            high_flux,
        )
        specplot.x_label = "lambda / microns"
        p.assign_plot(specplot, 0, 1, xext=6)

        # S/N plot
        snrplot = self.generate_snr_plot(spec_wave, snr)
        snrplot.x_label = "lambda / microns"
        p.assign_plot(snrplot, 0, 2, xext=6)

        # --- Images ----------------------------------------------------------
        # White-light image, clipped at the 95th percentile.
        scaling = {
            "v_clip": "percentile",
            "v_clip_kwargs": {"percentile": 95.},
        }
        full_image = ImagePlot(title="White-light image", **scaling)
        full_image.add_data(hdus["mean"][1].data)
        p.assign_plot(full_image, 0, 3, xext=2)

        # Exposure map
        exp_image = ImagePlot(title="Exposure map")
        exp_image.add_data(hdus["exposure"][0].data)
        p.assign_plot(exp_image, 2, 3, xext=2)

        # --- Flux histogram of the coadded cube ---------------------------
        flux = hdus["coadd"]["DATA"].data
        if self.fluxcal:
            # Scale into a new array: an in-place ``*=`` would rewrite the
            # HDU's data array, so later readers (or a second call) would
            # see already-scaled values.
            flux = flux * self._FLUX_SCALE
            flux_label = self._FLUX_SCALED_LABEL
        else:
            flux_label = "$F$"

        flux_hist = HistogramPlot(
            title="Flux histogram",
            bins=50,
            x_label=flux_label,
            legend=False,
            v_clip_kwargs={"nsigma": 3},
        )
        flux_hist.add_data(flux, label="data counts")
        p.assign_plot(flux_hist, 4, 3, xext=2)

        # --- Header text (row 0) ------------------------------------------
        vspace = 0.4
        primary = hdus["coadd"]["PRIMARY"]

        def kw(key):
            """Header value of *key* in the coadd primary HDU, or "N/A"."""
            return str(fetch_kw_or_default(primary, key, default="N/A"))

        t0 = TextPlot(columns=1, v_space=vspace)
        col0 = (
            kw("INSTRUME") + " science product preview",
            "Product: " + kw("ESO PRO CATG"),
            "Raw file: " + kw("ESO PRO REC1 RAW1 NAME"),
            "MJD-OBS: " + kw("MJD-OBS"),
        )
        t0.add_data(col0, fontsize=13)
        p.assign_plot(t0, 0, 0, xext=1)

        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            "Target: " + kw("OBJECT"),
            "OB ID: " + kw("ESO OBS ID"),
            "OB NAME: " + kw("ESO OBS NAME"),
            "TPL ID: " + kw("ESO TPL ID"),
            "RUN ID: " + kw("ESO OBS PROG ID"),
        )
        t1.add_data(col1, fontsize=13)
        p.assign_plot(t1, 2, 0, xext=1)

        t2 = TextPlot(columns=1, v_space=vspace)
        col2 = (
            "Spectral band: " + kw("ESO INS3 SPGW NAME"),
            "Spaxel size [mas]: " + kw("ESO INS3 SPXW NAME"),
            "AO mode: " + kw("ESO OCS AOMODE"),
        )
        t2.add_data(col2, fontsize=13)
        p.assign_plot(t2, 4, 0, xext=1)

        # --- Exposure statistics (row 4) -----------------------------------
        # Calibrated and uncalibrated coadds carry these values under
        # different header keywords.
        if self.fluxcal:
            exptime_kw, ncombine_kw = "TEXPTIME", "NCOMBINE"
        else:
            exptime_kw, ncombine_kw = "EXPTIME", "ESO PRO DATANCOM"
        t4 = TextPlot(columns=1, v_space=vspace)
        col4 = (
            "Exp. time [s] on target: "
            + format_kw_or_default(primary, exptime_kw, "%.1f"),
            "N exposures on target: "
            + format_kw_or_default(primary, ncombine_kw, "%i"),
        )
        t4.add_data(col4, fontsize=13)
        p.assign_plot(t4, 0, 4, xext=1)

        # --- Report metadata -----------------------------------------------
        input_files = [
            science.filename(),
            hdus["mean"].filename(),
            hdus["exposure"].filename(),
            hdus["coadd"].filename(),
        ]
        science_fname = os.path.basename(str(science.filename()))

        panels[p] = {
            "report_name": f"ERIS-SPIFFIER_{science_fname.removesuffix('.fits').lower()}",
            "report_description": "Science panel",
            "report_tags": [],
            "report_prodcatg": "ANCILLARY.PREVIEW",
            "input_files": input_files,
        }

        return panels

# Module-level report instance.
# NOTE(review): presumably the adari_core runner discovers the report through
# this ``rep`` name -- confirm before renaming.
rep = ErisIfuScienceReport()
