import os
import math
from adari_core.plots.panel import Panel
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.utils.utils import read_idp_spectrum
from adari_core.data_libs.master_spec_science import MasterSpecScienceReport

from .espresso_utils import EspressoReportMixin


_NOT_AVAILABLE = "N/A"  # Placeholder shown for missing or unusable header values


def _kw_text(hdu, keyword):
    """Fetch ``keyword`` from ``hdu`` with an "N/A" default and return it as str."""
    return str(fetch_kw_or_default(hdu, keyword, default=_NOT_AVAILABLE))


def _format_number(fmt, *values):
    """Apply the printf-style ``fmt`` to ``values``, or return "N/A".

    The "N/A" fallback is used when the values cannot be rendered with
    ``fmt`` (e.g. the "N/A" default string was fetched instead of a number),
    so that a missing keyword shows up as "N/A" instead of raising.
    """
    try:
        return fmt % values
    except (TypeError, ValueError, OverflowError):
        return _NOT_AVAILABLE


def _kw_number(hdu, keyword, fmt):
    """Fetch ``keyword`` from ``hdu`` and render it with ``fmt`` (or "N/A")."""
    return _format_number(
        fmt, fetch_kw_or_default(hdu, keyword, default=_NOT_AVAILABLE)
    )


class EspressoScienceReport(EspressoReportMixin, MasterSpecScienceReport):
    """Science product preview panel for ESPRESSO spectra.

    Handles either a single ``S1D_FINAL_A`` spectrum or a stacked
    ``ESOTK_SPECTRUM_IDP_FORMAT`` product together with the single spectra
    provided alongside it.
    """

    def __init__(self):
        super().__init__("espresso_science")
        # Switched on by parse_sof() when a stacked product is among the inputs.
        self.stacked = False

    def parse_sof(self):
        """Select the input files used by the report.

        Iterates over the ``(filename, category)`` pairs in ``self.inputs``.

        Returns:
            A list of single-entry dicts. With a stacked product the first
            dict is ``{"science": <stacked file>}`` followed by one
            ``{"singleNNN": <file>}`` dict per ``S1D_FINAL_A`` input, and
            ``self.stacked`` is set to True. Otherwise the list holds only
            ``{"science": <first S1D_FINAL_A file>}``, or is empty when no
            matching input exists.
        """
        sci_spec = []
        stacked_spec = None

        for filename, catg in self.inputs:
            if catg == "S1D_FINAL_A":
                sci_spec.append(filename)
            elif catg == "ESOTK_SPECTRUM_IDP_FORMAT":
                # If several stacked products are listed, the last one wins.
                stacked_spec = filename

        if stacked_spec is not None:
            self.stacked = True
            file_lists = [{"science": stacked_spec}]
            file_lists.extend(
                {"single%03i" % index: single}
                for index, single in enumerate(sci_spec)
            )
            return file_lists
        if sci_spec:
            return [{"science": sci_spec[0]}]
        return []

    def generate_panels(self, **kwargs):
        """Build the science preview panel.

        Reads the science product from ``self.hdus[0]["science"]`` and, for a
        stacked product, the entries of ``self.hdus`` whose key contains
        "single". Sets ``self.snr``, ``self.qual`` and ``self.zoom_range`` as
        a side effect. ``kwargs`` is accepted but not used here.

        Header values that are missing or cannot be rendered with their
        numeric format are displayed as "N/A".

        Returns:
            A dict mapping the created ``Panel`` to its metadata dict
            (report name, description, tags, product category, input files).

        Raises:
            IndexError: if the report is stacked but ``self.hdus`` contains
                no "single" entries.
        """
        panels = {}
        stacked = self.stacked
        # The stacked layout has one extra row (low/high flux comparison),
        # which shifts every plot below the main spectrum down by one.
        row_shift = int(stacked)

        science = self.hdus[0]["science"]
        primary = science["PRIMARY"]
        first_single = science
        lowest_flux_science = highest_flux_science = None
        if stacked:
            # First value of every entry that has a key containing "single".
            singles = [
                next(iter(entry.values()))
                for entry in self.hdus
                if any("single" in key for key in entry)
            ]
            # Input with the first raw file name in alphabetical order
            by_raw_name = sorted(
                singles,
                key=lambda hdul: hdul[0].header.get("ESO PRO REC1 RAW1 NAME", ""),
            )
            first_single = by_raw_name[0]
            # Locate the singles with the lowest and highest average flux
            by_avg_flux = sorted(
                singles,
                key=lambda hdul: hdul[0].header.get("ESO QC SPEC AVG", -999.0),
            )
            lowest_flux_science = by_avg_flux[0]
            highest_flux_science = by_avg_flux[-1]
        first_primary = first_single["PRIMARY"]
        science_fname = os.path.basename(str(science.filename()))

        spec_wave, spec_flux, spec_flux_unit = read_idp_spectrum(science["SPECTRUM"])
        self.snr = science["SPECTRUM"].data["SNR"][0]
        self.qual = science["SPECTRUM"].data["QUAL"][0]

        if stacked:
            p = Panel(6, 6, height_ratios=[1.5, 6, 4, 1, 3, 1.5], y_stretch=0.7)
        else:
            p = Panel(6, 5, height_ratios=[1, 6, 1, 3, 1], y_stretch=0.84)

        ins_mode = fetch_kw_or_default(
            primary, "ESO INS MODE", default=_NOT_AVAILABLE
        )
        binning = 0.4 if ins_mode == "MULTIMR" else 0.2  # in nm

        # Create the main spectrum plot
        low_wave = low_flux = high_wave = high_flux = None
        if stacked:
            low_wave, low_flux, _ = read_idp_spectrum(lowest_flux_science["SPECTRUM"])
            high_wave, high_flux, _ = read_idp_spectrum(
                highest_flux_science["SPECTRUM"]
            )

        specplot, lowhighplot = self.generate_spec_plot(
            spec_wave,
            spec_flux,
            spec_flux_unit,
            binning,
            low_wave,
            low_flux,
            high_wave,
            high_flux,
        )
        p.assign_plot(specplot, 0, 1, xext=6)
        if stacked:
            p.assign_plot(lowhighplot, 0, 2, xext=6)

        # Generate S/N plot
        snrplot = self.generate_snr_plot(spec_wave, self.snr)
        p.assign_plot(snrplot, 0, 2 + row_shift, xext=6)

        # Add spectral lines plots, each two grid columns wide
        self.zoom_range = 2.0  # 1 nm to each side of the line
        zoomplots = self.generate_spec_zooms(spec_wave, spec_flux, spec_flux_unit)
        for index, zoomplot in enumerate(zoomplots):
            p.assign_plot(zoomplot, 2 * index, 3 + row_shift, xext=2)

        # Upper Text Plot
        vspace = 0.4
        t0 = TextPlot(columns=1, v_space=vspace)
        col0 = (
            _kw_text(primary, "INSTRUME") + " science product preview",
            "Product: " + _kw_text(primary, "ESO PRO CATG"),
            "Raw file: " + _kw_text(first_primary, "ESO PRO REC1 RAW1 NAME"),
            "MJD-OBS: " + _kw_text(primary, "MJD-OBS"),
        )
        t0.add_data(col0, fontsize=13)
        p.assign_plot(t0, 0, 0, xext=1)

        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            "Target: " + _kw_text(primary, "OBJECT"),
            "OB ID: " + _kw_text(primary, "ESO OBS ID"),
            "OB NAME: " + _kw_text(primary, "ESO OBS NAME"),
            "TPL ID: " + _kw_text(primary, "ESO TPL ID"),
            "RUN ID: " + _kw_text(primary, "ESO OBS PROG ID"),
        )
        t1.add_data(col1, fontsize=13)
        p.assign_plot(t1, 2, 0, xext=1)

        t2 = TextPlot(columns=1, v_space=vspace)
        # str() guards the concatenation and prefix test below against
        # non-string header values.
        spec_type = _kw_text(primary, "ESO OCS OBJ SP TYPE")
        fibre_b = _kw_text(
            primary, "ESO STACK QC FIB B1" if stacked else "ESO QC FIB B"
        )
        telescope = (
            _kw_text(primary, "TELESCOP").replace("ESO-VLT-", "").replace("U", "UT")
        )
        col2 = (
            "INS MODE: " + str(ins_mode),
            "Binning: "
            + _kw_text(primary, "ESO DET BINX")
            + " x "
            + _kw_text(primary, "ESO DET BINY"),
            "Telescope: " + telescope,
            "Fibre B: " + fibre_b,
            "Spec type: " + spec_type,
        )
        t2.add_data(col2, fontsize=13)
        p.assign_plot(t2, 4, 0, xext=1)

        # Bottom Text Plot
        t4 = TextPlot(columns=1, v_space=vspace)
        if stacked:
            ncombine = fetch_kw_or_default(
                primary, "NCOMBINE", default=_NOT_AVAILABLE
            )
            try:
                n_exposures = "%i (%0.1f)" % (ncombine, math.sqrt(ncombine))
            except (TypeError, ValueError, OverflowError):
                # Missing, non-numeric or negative NCOMBINE
                n_exposures = _NOT_AVAILABLE
            col4 = (
                "Total exp. time [s]: " + _kw_number(primary, "TEXPTIME", "%.1f"),
                "N exposures (sqrt N): " + n_exposures,
                "Avg. contribution: "
                + _kw_number(primary, "ESO QC CONTRIB AVG", "%.1f"),
                "Avg. S/N stack: " + _kw_number(primary, "ESO QC SNR AVG", "%.1f"),
            )
        else:
            col4 = (
                "Exp. time [s]: " + _kw_number(primary, "TEXPTIME", "%.1f"),
                "N exposures: " + _kw_number(primary, "NCOMBINE", "%i"),
                "S/N average: " + _kw_number(primary, "ESO QC SNR AVG", "%.1f"),
                "Lambda bin [nm]: " + _kw_number(primary, "SPEC_BIN", "%.5f"),
            )
        t4.add_data(col4, fontsize=13)
        p.assign_plot(t4, 0, 4 + row_shift, xext=1)

        t5 = TextPlot(columns=1, v_space=vspace)
        seeing = "Seeing: " + _kw_number(primary, "ESO QC IA FWHM AVG", "%.2f")
        n_pix_sat = _kw_number(first_primary, "ESO QC SAT NB", "%i")
        if stacked:
            # Average single-frame S/N = stacked S/N divided by the S/N gain.
            # The 0 and 1 defaults make a missing pair of keywords render as 0.0.
            snr_avg = fetch_kw_or_default(primary, "ESO QC SNR AVG", default=0)
            snr_ratio = fetch_kw_or_default(primary, "ESO QC SNR RATIO", default=1)
            try:
                snr_single = "%.1f" % (snr_avg / snr_ratio)
            except (TypeError, ValueError, ZeroDivisionError):
                snr_single = _NOT_AVAILABLE
            col5 = (
                seeing,
                "First N pix sat: " + n_pix_sat,
                "Avg. S/N all single: " + snr_single,
                "Avg. S/N improvement: "
                + _kw_number(primary, "ESO QC SNR RATIO", "%.1f"),
            )
        else:
            col5 = (
                seeing,
                "N pix sat: " + n_pix_sat,
                "Delta time wave A [d]: "
                + _kw_number(first_primary, "ESO QC DELTA TIME WAVE A", "%.2f"),
            )
        t5.add_data(col5, fontsize=13)
        p.assign_plot(t5, 2, 4 + row_shift, xext=1)

        t6 = TextPlot(columns=1, v_space=vspace)
        # CCF values are only shown for spectral types F, G, K and M.
        # startswith() is also safe for an empty spectral type string.
        if spec_type.startswith(("F", "G", "K", "M")):
            ccf_rv = _kw_number(first_primary, "ESO QC CCF RV", "%.6f")
            ccf_rv_err = _kw_number(first_primary, "ESO QC CCF RV ERROR", "%.6f")
            ccf_mask = _kw_text(first_primary, "ESO QC CCF MASK")
        else:
            ccf_rv = ccf_rv_err = ccf_mask = _NOT_AVAILABLE
        if stacked:
            col6 = (
                "First Delta time wave A [d]: "
                + _kw_number(first_primary, "ESO QC DELTA TIME WAVE A", "%.2f"),
                "First CCF RV [km/s]: " + ccf_rv,
                "First CCF RV ERROR [km/s]: " + ccf_rv_err,
                "First CCF MASK: " + ccf_mask,
            )
        else:
            col6 = (
                "CCF RV [km/s]: " + ccf_rv,
                "CCF RV ERROR [km/s]: " + ccf_rv_err,
                "CCF MASK: " + ccf_mask,
            )
        t6.add_data(col6, fontsize=13)
        p.assign_plot(t6, 4, 4 + row_shift, xext=1)

        input_files = [science.filename()]
        if stacked:
            input_files.append(first_single.filename())
            # The extreme-flux frames may coincide with files already listed.
            for extreme in (lowest_flux_science, highest_flux_science):
                if extreme.filename() not in input_files:
                    input_files.append(extreme.filename())

        addme = {
            "report_name": f"ESPRESSO_{str(science_fname).removesuffix('.fits').lower()}",
            "report_description": "Science panel",
            "report_tags": [],
            "report_prodcatg": "ANCILLARY.PREVIEW",
            "input_files": input_files,
        }

        panels[p] = addme

        return panels


# Module-level report instance, created when this module is imported.
# NOTE(review): presumably looked up by name by the ADARI report runner — confirm.
rep = EspressoScienceReport()
