import os
import math
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot, ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.utils.utils import read_idp_spectrum
from adari_core.utils.clipping import clipping_percentile
from adari_core.data_libs.master_spec_science import MasterSpecScienceReport

from .giraffe_utils import GiraffeReportMixin


class GiraffeScienceReport(GiraffeReportMixin, MasterSpecScienceReport):
    """ADARI report producing preview panels for GIRAFFE science products.

    Two modes are supported, selected by the contents of the SOF:

    * single mode: one panel per ``SCIENCE_RBNSPEC_IDP`` spectrum, shown with
      the sky (``ANCILLARY.MOSSKY``), raw (``SCIENCE``) and rebinned
      (``SCIENCE_RBNSPECTRA``) frames;
    * stacked mode (any ``ESOTK_SPECTRUM_IDP_FORMAT`` input present): one
      panel per combined spectrum, shown together with the single input
      spectra of lowest and highest ``ESO QC MEAN RED``.
    """

    def __init__(self):
        super().__init__("giraffe_science")
        # Switched to True by parse_sof() when combined spectra are present.
        self.stacked = False
        # Line width used for the spectrum / sky curves.
        self.spec_lw = 1.2

    @staticmethod
    def _primary_kw(hdul, key, default="N/A"):
        """Return header keyword ``key`` of the PRIMARY HDU of ``hdul``.

        Thin wrapper around ``fetch_kw_or_default``; ``default`` is returned
        when the keyword cannot be fetched.
        """
        return fetch_kw_or_default(hdul["PRIMARY"], key, default=default)

    @staticmethod
    def _format_kw_value(spec, value):
        """Return ``spec % value``, or ``str(value)`` if ``value`` is not numeric.

        Keywords are fetched with the ``"N/A"`` placeholder as default;
        applying a numeric format such as ``"%.1f"`` to it would raise
        ``TypeError`` and abort the whole report, so it is passed through.
        """
        try:
            return spec % value
        except TypeError:
            return str(value)

    @staticmethod
    def _is_numeric_kw(value):
        """Return True if ``value`` is an int/float (not bool) usable on a plot axis."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _locate_stacked_inputs(self, science, hdus_single):
        """Find the single spectra that were combined into ``science``.

        The combined product lists its inputs in ``ESO PRO REC2 RAW<n> NAME``
        (n = 1 .. ``ESO PRO DATANCOM``). These names are matched (substring
        match) against the file names of the single spectra in ``hdus_single``.

        Parameters
        ----------
        science : HDUList of the combined spectrum.
        hdus_single : list of ``{"scienceNNN": HDUList}`` entries, where NNN
            equals the entry's position in the list.

        Returns
        -------
        (lowest, highest, first_raw) : the matched inputs with the lowest and
            highest ``ESO QC MEAN RED``, and the one whose
            ``ESO PRO REC1 RAW1 NAME`` sorts first.

        Raises
        ------
        ValueError
            If the number or names of the combined frames are missing, or the
            number of matched files differs from ``ESO PRO DATANCOM``.
        """
        header = science[0].header
        ncom = header.get("ESO PRO DATANCOM", 0)
        if ncom == 0:
            raise ValueError("Number of combined frames not provided.")

        input_names = []
        for n in range(1, ncom + 1):
            input_name = header.get("ESO PRO REC2 RAW%s NAME" % n)
            if input_name is None:
                raise ValueError("Input name not provided in ESO PRO REC2 RAW? NAME.")
            input_names.append(input_name)

        # Deduplicate while keeping header order: iterating a plain set would
        # make the order (and thus tie-breaking in the sorts below) vary
        # between interpreter runs.
        hdu_inputs = []
        for single_name in dict.fromkeys(input_names):
            for isingle, entry in enumerate(hdus_single):
                if single_name in entry["science%03i" % isingle].filename():
                    hdu_inputs.append(entry)
        if len(hdu_inputs) != ncom:
            raise ValueError("Number of combined frames and number of files not equal.")

        def first_hdul(entry):
            # Each entry holds exactly one {"scienceNNN": HDUList} item.
            return next(iter(entry.values()))

        by_flux = sorted(
            hdu_inputs,
            key=lambda e: first_hdul(e)[0].header.get("ESO QC MEAN RED", -999.0),
        )
        by_raw_name = sorted(
            hdu_inputs,
            key=lambda e: first_hdul(e)[0].header.get("ESO PRO REC1 RAW1 NAME", ""),
        )
        return (
            first_hdul(by_flux[0]),
            first_hdul(by_flux[-1]),
            first_hdul(by_raw_name[0]),
        )

    def parse_sof(self):
        """Group the SOF inputs into per-panel file dictionaries.

        In stacked mode the result contains one ``{"combinedNNN": file}``
        entry per combined spectrum followed by one ``{"scienceNNN": file}``
        entry per single spectrum, and ``self.stacked`` is set to True.
        Otherwise each single spectrum is bundled with the sky, raw and
        rebinned frames as ``{"singleNNN": file, "sky", "raw", "sci"}``.
        """
        sci_spec = []
        stacked_spec = []
        sky = None
        raw = None
        sci = None

        for filename, catg in self.inputs:
            if catg == "SCIENCE_RBNSPEC_IDP" and filename not in sci_spec:
                sci_spec.append(filename)
            if catg == "ESOTK_SPECTRUM_IDP_FORMAT" and filename not in stacked_spec:
                stacked_spec.append(filename)
            # For the following categories only the last occurrence is kept.
            if catg == "ANCILLARY.MOSSKY":
                sky = filename
            if catg == "SCIENCE":
                raw = filename
            if catg == "SCIENCE_RBNSPECTRA":
                sci = filename

        if stacked_spec:
            file_lists = [
                {"combined%03i" % n: item} for n, item in enumerate(stacked_spec)
            ]
            file_lists += [
                {"science%03i" % n: single} for n, single in enumerate(sci_spec)
            ]
            self.stacked = True
        else:
            file_lists = [
                {
                    "single%03i" % n: single,
                    "sky": sky,
                    "raw": raw,
                    "sci": sci,
                }
                for n, single in enumerate(sci_spec)
            ]

        return file_lists

    def generate_panels(self, **kwargs):
        """Build one preview panel per (single or combined) science spectrum.

        Returns a dict mapping each ``Panel`` to its report metadata
        (name, description, tags, product category and input files).

        Raises ValueError (stacked mode) when the inputs of a combined
        spectrum cannot be identified; see ``_locate_stacked_inputs``.
        """
        panels = {}
        # Short local aliases for the many header look-ups below.
        kw = self._primary_kw
        fmt = self._format_kw_value

        if self.stacked:
            hdus = [v for v in self.hdus if any("combined" in s for s in v.keys())]
            hdus_single = [v for v in self.hdus if any("science" in s for s in v.keys())]
        else:
            hdus = self.hdus
            # QC values of every single spectrum, for the SNR vs. counts plot.
            snr_all = []
            mean_all = []
            for n, entry in enumerate(self.hdus):
                single = entry["single%03i" % n]
                snr_all.append(kw(single, "ESO QC SNR"))
                mean_all.append(kw(single, "ESO QC MEAN RED"))

        # In stacked mode the S/N and quality plots sit one row lower.
        shift = int(self.stacked)

        for i, entry in enumerate(hdus):
            if self.stacked:
                science = entry["combined%03i" % i]
            else:
                science = entry["single%03i" % i]
            science_fname = os.path.basename(str(science.filename()))

            if self.stacked:
                (
                    lowest_flux_science,
                    highest_flux_science,
                    first_raw,
                ) = self._locate_stacked_inputs(science, hdus_single)

            spec_wave, spec_flux, spec_flux_unit = read_idp_spectrum(
                science["SPECTRUM"]
            )

            # S/N (and contribution) arrays are kept on the instance, as before.
            if self.stacked:
                self.snr = science["SPECTRUM"].data["SNR_REDUCED"][0]
                self.qual = science["SPECTRUM"].data["CONTRIB_REDUCED"][0]
                p = Panel(6, 6, height_ratios=[1.5, 5, 4, 1, 1, 1.5], y_stretch=0.6)
            else:
                self.snr = science["SPECTRUM"].data["SNR"][0]
                sky = entry["sky"][0]
                skydata = sky.data.T
                n_sky = sky.data[0].size
                p = Panel(6, 6, height_ratios=[1, 5, 1, 1, 3, 1], y_stretch=0.7)

            # Main spectrum; in stacked mode also the lowest/highest-flux inputs.
            low_wave = low_flux = high_wave = high_flux = None
            if self.stacked:
                low_wave, low_flux, _ = read_idp_spectrum(
                    lowest_flux_science["SPECTRUM"]
                )
                high_wave, high_flux, _ = read_idp_spectrum(
                    highest_flux_science["SPECTRUM"]
                )

            binning = -1.0  # no average spectrum
            specplot, lowhighplot = self.generate_spec_plot(
                spec_wave,
                spec_flux,
                spec_flux_unit,
                binning,
                low_wave,
                low_flux,
                high_wave,
                high_flux,
            )

            if self.stacked:
                p.assign_plot(specplot, 0, 1, xext=6)
                p.assign_plot(lowhighplot, 0, 2, xext=6)
            else:
                # Sky curve drawn behind the science spectrum: last row of the
                # transposed sky data (presumably the last sky fibre — confirm).
                skyplot = LinePlot(title="", legend=False)
                skyplot.add_data(
                    d=[spec_wave, skydata[n_sky - 1]],
                    color="lightgreen",
                    linewidth=self.spec_lw,
                    label="Sky",
                )
                combined = CombinedPlot(title="", legend=True)
                combined.add_data(skyplot, zorder=0)
                combined.add_data(specplot, zorder=1)
                p.assign_plot(combined, 0, 1, xext=6)

            # S/N plot
            snrplot = self.generate_snr_plot(spec_wave, self.snr)
            p.assign_plot(snrplot, 0, 2 + shift, xext=6)

            if self.stacked:
                # Contribution (quality) plot
                qualplot = self.generate_qual_plot(spec_wave, self.qual)
                qualplot.y_label = "Contrib"
                p.assign_plot(qualplot, 0, 3 + shift, xext=6)
            else:
                # All sky spectra
                skyplot = LinePlot(title="", legend=False, linewidth=self.spec_lw)
                for i_sky in range(n_sky):
                    skyplot.add_data(
                        d=[spec_wave, skydata[i_sky]],
                        color="lightgreen",
                        label="sky" + str(i_sky),
                    )
                skyplot.x_label = "Wavelength (nm)"
                skyplot.y_label = "Sky"
                skyplot.set_xlim(spec_wave[0], spec_wave[-1])
                sky_clip = clipping_percentile(skydata, 98)
                skyplot.set_ylim(sky_clip[0], sky_clip[1] * 2.5)
                p.assign_plot(skyplot, 0, 3 + shift, xext=6)

            # Small diagnostic plots (single mode only)
            if not self.stacked:
                # Central-column cut through the raw frame
                rawdata = entry["raw"][0].data
                fullfield = ImagePlot(title="Raw")
                fullfield.add_data(rawdata)
                cutX = CutPlot(
                    "y",
                    title="Central column",
                    x_label="x",
                    y_label="ADU",
                )
                cutX.add_data(
                    fullfield,
                    fullfield.get_data_coord(fullfield.data.shape[0] // 2, "y"),
                    label="Raw",
                    color="red",
                )
                p.assign_plot(cutX, 0, 4, xext=2)

                # SNR vs. counts: all single spectra plus the current one.
                snr = kw(science, "ESO QC SNR")
                mean = kw(science, "ESO QC MEAN RED")
                scatter_snr = ScatterPlot(
                    title="SNR vs. counts",
                    x_label="MEAN RED",
                    y_label="SNR",
                    legend=False,
                )
                # Spectra lacking numeric QC values (fetched as "N/A") cannot
                # be placed on a numeric axis, so they are left out.
                valid = [
                    (m, s)
                    for m, s in zip(mean_all, snr_all)
                    if self._is_numeric_kw(m) and self._is_numeric_kw(s)
                ]
                if valid:
                    scatter_snr.add_data(
                        ([m for m, _ in valid], [s for _, s in valid]),
                        label="snr_all",
                        color="red",
                        markersize=1.0,
                    )
                if self._is_numeric_kw(mean) and self._is_numeric_kw(snr):
                    scatter_snr.add_data(
                        ([mean], [snr]),
                        label="snr",
                        color="black",
                        markersize=5.0,
                    )
                scatter_snr.x_min = 0
                scatter_snr.y_min = 0
                p.assign_plot(scatter_snr, 2, 4, xext=2)

                # Histogram of the SCIENCE_RBNSPECTRA product
                scidata = entry["sci"][0].data
                hist = HistogramPlot(
                    master_data=scidata,
                    title="Product science RBNSPECTRA",
                    master_label="",
                    v_min=-50,
                    v_max=200,
                    legend=False,
                )
                p.assign_plot(hist, 4, 4, xext=1)

                # Histogram of the raw frame near saturation
                hist = HistogramPlot(
                    master_data=rawdata,
                    title="Saturation limit",
                    master_label="raw",
                    v_min=55000,
                    v_max=70000,
                )
                p.assign_plot(hist, 5, 4, xext=1)

            # Upper text plots
            vspace = 0.4

            if self.stacked:
                # Input whose first raw file name sorts first alphabetically.
                raw_name = kw(first_raw, "ESO PRO REC1 RAW1 NAME")
                raw_label = "First raw file"
            else:
                raw_name = kw(science, "ESO PRO REC1 RAW1 NAME")
                raw_label = "Raw file"

            t0 = TextPlot(columns=1, v_space=vspace)
            col0 = (
                str(kw(science, "INSTRUME")) + " science product preview",
                "Product: " + str(kw(science, "ESO PRO CATG")),
                raw_label + ": " + str(raw_name),
                "MJD-OBS: " + str(kw(science, "MJD-OBS")),
            )
            t0.add_data(col0, fontsize=13)
            p.assign_plot(t0, 0, 0, xext=1)

            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                "Target: " + str(kw(science, "OBJECT")),
                "OB ID: " + str(kw(science, "ESO OBS ID")),
                "OB NAME: " + str(kw(science, "ESO OBS NAME")),
                "TPL ID: " + str(kw(science, "ESO TPL ID")),
                "RUN ID: " + str(kw(science, "ESO OBS PROG ID")),
            )
            t1.add_data(col1, fontsize=13)
            p.assign_plot(t1, 2, 0, xext=1)

            fibre_key = "ESO STACK FPS1" if self.stacked else "FPS"
            fibre = str(kw(science, fibre_key))
            t2 = TextPlot(columns=1, v_space=vspace)
            col2 = (
                "Fibre number (FPS): " + fibre,
                "Slit name: " + str(kw(science, "ESO INS SLIT NAME")),
                "Exposure mode: " + str(kw(science, "ESO INS EXP MODE")),
                "Detector speed: " + str(kw(science, "ESO DET READ SPEED")),
            )
            t2.add_data(col2, fontsize=13)
            p.assign_plot(t2, 4, 0, xext=1)

            # Bottom text plots
            t4 = TextPlot(columns=1, v_space=vspace)
            if self.stacked:
                ncombine = kw(science, "NCOMBINE")
                try:
                    nexp = "%i (%0.1f)" % (ncombine, math.sqrt(ncombine))
                except (TypeError, ValueError):
                    # Missing ("N/A") or invalid NCOMBINE: show it unformatted.
                    nexp = str(ncombine)
                col4 = (
                    "Total exp. time [s]: " + fmt("%.1f", kw(science, "TEXPTIME")),
                    "N exposures (sqrt N): " + nexp,
                    "Avg. contribution: "
                    + fmt("%.1f", kw(science, "ESO QC CONTRIB AVG")),
                    "Avg. S/N: " + fmt("%.1f", kw(science, "ESO QC SNR AVG")),
                )
            else:
                col4 = (
                    "Exp. time [s]: " + fmt("%.1f", kw(science, "TEXPTIME")),
                    "N exposures: " + fmt("%i", kw(science, "NCOMBINE")),
                    "S/N average: " + fmt("%.1f", kw(science, "ESO QC SNR")),
                    "Flux average: " + fmt("%.1f", kw(science, "ESO QC MEAN RED")),
                )
            t4.add_data(col4, fontsize=13)
            p.assign_plot(t4, 0, 5, xext=1)

            t5 = TextPlot(columns=1, v_space=vspace)
            if self.stacked:
                # Mean single-spectrum S/N = combined S/N / improvement ratio
                # (ratio defaults to 1 when the keyword is missing).
                snr_avg = kw(science, "ESO QC SNR AVG", default=0)
                snr_ratio = kw(science, "ESO QC SNR RATIO", default=1)
                try:
                    snr_single = "%.1f" % (snr_avg / snr_ratio)
                except (TypeError, ZeroDivisionError):
                    snr_single = "N/A"
                col5 = (
                    "Avg. S/N all single: " + snr_single,
                    "Avg. S/N improvement: "
                    + fmt("%.1f", kw(science, "ESO QC SNR RATIO")),
                    "Spectral resolution: " + fmt("%i", kw(science, "SPEC_RES")),
                    "Seeing: " + str(kw(science, "ESO QC FWHM")),
                )
            else:
                col5 = (
                    "Mag (user defined): " + fmt("%.1f", kw(science, "ESO QC MAG")),
                    "Lambda bin [nm]: " + fmt("%.5f", kw(science, "SPEC_BIN")),
                    "Spectral resolution: " + fmt("%i", kw(science, "SPEC_RES")),
                    "Seeing: " + fmt("%.2f", kw(science, "ESO QC FWHM")),
                )
            t5.add_data(col5, fontsize=13)
            p.assign_plot(t5, 2, 5, xext=1)

            t6 = TextPlot(columns=1, v_space=vspace)
            if self.stacked:
                col6 = (
                    "Lambda start [nm]: " + fmt("%.5f", kw(science, "WAVELMIN")),
                    "Lambda end [nm]: " + fmt("%.5f", kw(science, "WAVELMAX")),
                    "Lambda bin [nm]: " + fmt("%.5f", kw(science, "SPEC_BIN")),
                )
            else:
                col6 = (
                    "Airmass: " + fmt("%.1f", kw(science, "ESO QC AIRM")),
                    "Delta t calib. [d]: "
                    + fmt("%.1f", kw(science, "ESO QC DELTA TIME")),
                    "Delta T calib [C]: "
                    + fmt("%.1f", kw(science, "ESO QC DELTA TEMP")),
                )
            t6.add_data(col6, fontsize=13)
            p.assign_plot(t6, 4, 5, xext=1)

            if self.stacked:
                input_files = [
                    science.filename(),
                    lowest_flux_science.filename(),
                    highest_flux_science.filename(),
                ]
                if first_raw.filename() not in input_files:
                    input_files.append(first_raw.filename())
            else:
                input_files = [
                    science.filename(),
                    entry["sci"].filename(),
                    entry["sky"].filename(),
                    entry["raw"].filename(),
                ]
            addme = {
                "report_name": f"GIRAFFE_{str(science_fname).removesuffix('.fits').lower()}_{str(i+1)}",
                "report_description": "Science panel, fibre "+fibre,
                "report_tags": [],
                "report_prodcatg": "ANCILLARY.PREVIEW",
                "input_files": input_files,
            }

            panels[p] = addme

        return panels


# Module-level report instance; presumably looked up by name ("rep") by the
# ADARI report runner — confirm before renaming.
rep = GiraffeScienceReport()
