import os
from adari_core.plots.panel import Panel
from adari_core.plots.text import TextPlot
from adari_core.utils.clipping import clipping_percentile
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.utils.utils import format_kw_or_default
from adari_core.data_libs.master_ifu_spec_science import MasterIfuSpecScienceReport
from .kmos_utils import KmosReportMixin


class KmosScienceReport(KmosReportMixin, MasterIfuSpecScienceReport):
    """ADARI preview report for KMOS science products.

    One panel is generated per science target, grouped by the ``OBJECT``
    keyword of each input's primary header. Each panel combines:

    - the 1D spectrum and S/N curve (``IDP_QC_SPEC``);
    - the white-light image (``COMBINED_IMAGE``);
    - the exposure map (``EXP_MASK``);
    - a flux histogram and a header summary from the combined cube
      (``IDP_COMBINED_CUBE``).
    """

    def __init__(self):
        super().__init__("kmos_science")
        # Line width for the spectrum plot; presumably read by the base
        # class plotting helpers -- TODO confirm.
        self.spec_lw = 1.2

    def parse_sof(self):
        """Group the input files by product category.

        Duplicate filenames within a category are dropped.

        Returns
        -------
        list of dict
            One single-entry dict per file. The key is
            ``"<prefix>%03i"`` and the value is the filename. Prefixes are
            ``coadd``, ``spectrum``, ``mean`` and ``exposure``, emitted in
            that order.
        """
        # Input category -> key prefix. The dict's insertion order fixes the
        # order of the returned list.
        catg_to_prefix = {
            "IDP_COMBINED_CUBE": "coadd",
            "IDP_QC_SPEC": "spectrum",
            "COMBINED_IMAGE": "mean",
            "EXP_MASK": "exposure",
        }
        files_by_prefix = {prefix: [] for prefix in catg_to_prefix.values()}

        for filename, catg in self.inputs:
            prefix = catg_to_prefix.get(catg)
            if prefix is not None and filename not in files_by_prefix[prefix]:
                files_by_prefix[prefix].append(filename)

        return [
            {"%s%03i" % (prefix, index): filename}
            for prefix, filenames in files_by_prefix.items()
            for index, filename in enumerate(filenames)
        ]

    def generate_panels(self, **kwargs):
        """Create one science panel per target found in ``self.hdus``.

        Returns
        -------
        dict
            Maps each Panel to its report metadata dict.

        Raises
        ------
        KeyError
            If a target lacks one of the spectrum, exposure, coadd or mean
            products.
        """
        required = ("spectrum", "exposure", "coadd", "mean")
        panels = {}

        # Map the science object name to its products, keyed by file type
        # ("coadd", "exposure", "mean", "spectrum"). The object name is
        # expected to be unique for each tuple of these products.
        object_map = {}
        for item in self.hdus:
            key, hdu = next(iter(item.items()))
            sci_object = hdu[0].header["OBJECT"]
            # Drop the numeric suffix added by parse_sof ("%03i"). Stripping
            # digits also works past 999 files; the prefixes contain none.
            file_type = key.rstrip("0123456789")
            object_map.setdefault(sci_object, {})[file_type] = hdu

        # Generate a panel for each science object.
        for sci_object, products in object_map.items():
            missing = [name for name in required if name not in products]
            if missing:
                raise KeyError(
                    f"Science object {sci_object!r} is missing input "
                    f"product(s): {', '.join(missing)}"
                )
            metadata, panel = self.generate_single_science_panel(
                products["spectrum"],
                products["exposure"],
                products["coadd"],
                products["mean"],
            )
            panels[panel] = metadata

        return panels

    def generate_single_science_panel(self, science, exposure, coadd, mean):
        """Build the preview panel for a single science target.

        Parameters
        ----------
        science
            IDP_QC_SPEC product. Extension 1 is read as a table with
            ``wavelength``, ``flux`` and ``flux_error`` columns.
        exposure
            EXP_MASK product. The data of extension 1 is displayed.
        coadd
            IDP_COMBINED_CUBE product. Its ``DATA`` extension feeds the flux
            histogram and its ``PRIMARY`` header feeds the text summary.
        mean
            COMBINED_IMAGE product. The data of extension 1 is displayed.

        Returns
        -------
        tuple
            ``(metadata, panel)``: the report metadata dict and the
            populated Panel.
        """
        # Spectrum and histogram both show flux in units of 1e-20 erg/s/cm^2/A.
        flux_scale = 1.0e20
        flux_unit = "$F\\ /\\ \\mathrm{10^{-20}~erg/s/cm^2/\\AA}$"

        sci_data = science[1].data
        spec_wave = sci_data["wavelength"]
        spec_flux = sci_data["flux"] * flux_scale
        snr = sci_data["flux"] / sci_data["flux_error"]

        p = Panel(6, 5, height_ratios=[1, 5, 1, 5, 1], y_stretch=0.84)

        # Main spectrum plot: no binning and no low/high zoom windows.
        binning = 0
        low_wave = low_flux = high_wave = high_flux = None
        specplot, _lowhighplot = self.generate_spec_plot(
            spec_wave,
            spec_flux,
            flux_unit,
            binning,
            low_wave,
            low_flux,
            high_wave,
            high_flux,
        )
        specplot.x_label = "lambda / microns"

        # Derive the y-range from the central ~94% of the spectrum so edge
        # artefacts do not dominate it. The stop is written as
        # ``n_points - trim`` because ``x[0:-0]`` is empty, which previously
        # broke short spectra (fewer than 34 points).
        n_points = spec_flux.size
        trim = int(0.03 * n_points)
        core_flux = spec_flux[trim : n_points - trim]
        clip = clipping_percentile(core_flux, 98)
        y_max = clip[1] * 1.2
        y_min = min(clip[0] * 1.2, 0)
        specplot.set_ylim(y_min, y_max)
        p.assign_plot(specplot, 0, 1, xext=6)

        # S/N plot
        snrplot = self.generate_snr_plot(spec_wave, snr)
        snrplot.x_label = "lambda / microns"
        p.assign_plot(snrplot, 0, 2, xext=6)

        # White-light image
        full_image = self.generate_image_plot(mean[1].data)
        p.assign_plot(full_image, 0, 3, xext=2)

        # Exposure map
        exp_image = self.generate_exposure_plot(exposure[1].data)
        p.assign_plot(exp_image, 2, 3, xext=2)

        # Histogram. Scale a copy: an in-place ``*=`` would permanently
        # rescale the data array of the coadd HDU held by the report.
        flux = coadd["DATA"].data * flux_scale
        flux_hist = self.generate_flux_histogram(flux, flux_label=flux_unit)
        p.assign_plot(flux_hist, 4, 3, xext=2)

        # Text summaries, all taken from the coadd primary header.
        primary = coadd["PRIMARY"]
        vspace = 0.4

        def kw_str(keyword):
            """Return the header value as a string, or "N/A" if absent."""
            return str(fetch_kw_or_default(primary, keyword, default="N/A"))

        def kw_fmt(keyword, fmt):
            """Return the header value formatted with ``fmt``."""
            return format_kw_or_default(primary, keyword, fmt)

        def add_text(lines, x, y):
            """Place a one-column text block at panel cell (x, y)."""
            text_plot = TextPlot(columns=1, v_space=vspace)
            text_plot.add_data(lines, fontsize=13)
            p.assign_plot(text_plot, x, y, xext=1)

        # Upper text row
        add_text(
            (
                kw_str("INSTRUME") + " science product preview",
                "Product: " + kw_str("ESO PRO CATG"),
                "Raw file: " + kw_str("ESO PRO REC1 RAW1 NAME"),
                "MJD-OBS: " + kw_str("MJD-OBS"),
            ),
            0,
            0,
        )
        add_text(
            (
                "Target: " + kw_str("OBJECT"),
                "OB ID: " + kw_str("ESO OBS ID"),
                "OB NAME: " + kw_str("ESO OBS NAME"),
                "TPL ID: " + kw_str("ESO TPL ID"),
                "RUN ID: " + kw_str("ESO OBS PROG ID"),
            ),
            2,
            0,
        )
        add_text(("Spectral band: " + kw_str("ESO INS GRAT1 NAME"),), 4, 0)

        # Bottom text row
        add_text(
            (
                "Exp. time [s] on target: " + kw_fmt("TEXPTIME", "%.1f"),
                "N exposures on target: " + kw_fmt("NCOMBINE", "%i"),
                "ABMAG limit: " + kw_fmt("ABMAGLIM", "%.2f"),
                "Resolving power: " + kw_fmt("SPEC_RES", "%.2f"),
            ),
            0,
            4,
        )
        add_text(
            (
                "Lambda start [nm]: " + kw_fmt("WAVELMIN", "%.2f"),
                "Lambda end [nm]: " + kw_fmt("WAVELMAX", "%.2f"),
                "PWV tell [mm]: " + kw_fmt("ESO QC H2O AVG", "%.2f"),
                "PWV header [mm]: " + kw_fmt("ESO QC MEAS IWV", "%.2f"),
            ),
            2,
            4,
        )
        add_text(
            (
                "Delta t telluric [d]: "
                + kw_fmt("ESO QC DELTA TIME TELL", "%.2f"),
                "Delta airmass tell: " + kw_fmt("ESO QC AIRM DIFF", "%.2f"),
                "Delta t flat-field [d]: "
                + kw_fmt("ESO QC DELTA TIME FLAT", "%.2f"),
            ),
            4,
            4,
        )

        input_files = [
            science.filename(),
            mean.filename(),
            exposure.filename(),
            coadd.filename(),
        ]

        # The report is named after the spectrum file, without ".fits".
        science_fname = os.path.basename(str(input_files[0]))

        metadata = {
            "report_name": f"KMOS_{science_fname.removesuffix('.fits').lower()}",
            "report_description": "Science panel",
            "report_tags": [],
            "report_prodcatg": "ANCILLARY.PREVIEW",
            "input_files": input_files,
        }

        return metadata, p

# Module-level report instance, created at import time. Presumably the ADARI
# framework discovers it via the name ``rep`` -- confirm before renaming.
rep = KmosScienceReport()
