# SPDX-License-Identifier: BSD-3-Clause
from adari_core.report import AdariReportBase
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.points import LinePlot
from adari_core.utils.utils import fetch_kw_or_error, get_wavelength_from_header

import numpy as np
import os


class MasterSpecphotStdReport(AdariReportBase):
    """Template report for spectrophotometric standard star products.

    For every entry of ``self.hdus`` a panel is built containing an optional
    efficiency curve, the extracted flux versus wavelength, and cuts and a
    histogram of the associated raw frame. ``parse_sof`` must be provided by
    the child report.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Extent handed to CentralImagePlot.add_data for the central cut-out.
        self.center_size = 200

    def parse_sof(self):
        """Not implemented here; child reports must define ``parse_sof``."""
        raise NotImplementedError(
            "MasterSpecphotStdReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def get_wave(self, h):
        """Return the wavelength values and their unit for HDU ``h``.

        The first column of ``h.data`` whose name contains "wave"
        (case-insensitive) is used, with the unit taken from the ``TUNIT1``
        header keyword ("[angstrom]" is rewritten as "Angstrom"). If there is
        no such column, or the table access fails for any reason (e.g. ``h``
        has no table columns), the result of ``get_wavelength_from_header(h)``
        is returned instead.

        Parameters
        ----------
        h :
            HDU-like object exposing ``data`` and ``header``.

        Returns
        -------
        tuple
            ``(wavelength values, unit string)`` for the table case; otherwise
            whatever ``get_wavelength_from_header`` returns.
        """
        try:
            names = h.data.columns.names
            wave_col = next((n for n in names if "wave" in n.lower()), None)
            if wave_col is not None:
                # NOTE(review): TUNIT1 describes the first table column, so
                # this assumes the wavelength column comes first - confirm.
                unit = str(h.header["TUNIT1"])
                if unit == "[angstrom]":
                    unit = "Angstrom"
                return h.data[wave_col], unit
        except Exception:
            # Deliberate best-effort: any failure reading the table falls
            # through to the header-based wavelength solution below.
            pass
        # Reached when the table access failed *or* when no "wave" column
        # exists (previously the latter silently returned None).
        return get_wavelength_from_header(h)

    def generate_panels(
        self,
        raw_im_ext=0,
        plot_eff=True,
        eff_im_ext=0,
        eff_label="Eff",
        flux_ext=1,
        flux_ylim=None,
        raw_cut="y",
        raw_cut_i=1,
        raw_cut_centre=True,
        n_raw_ext=1,
        **kwargs,
    ):
        """Build one panel per entry of ``self.hdus``.

        Parameters
        ----------
        raw_im_ext : str, int or sequence of tuples
            Raw-file extension(s) to display. A single str/int is applied to
            every entry of ``self.hdus``; otherwise one tuple of extensions
            per entry is expected.
        plot_eff : bool
            If True, plot the ``eff_label`` column of extension 1 of the
            ``"eff"`` file against wavelength.
        eff_im_ext : int
            Accepted for interface compatibility; currently unused (the
            efficiency table is always read from extension 1).
        eff_label : str
            Column name of the efficiency values.
        flux_ext : int or str
            Extension of the ``"std"`` file holding the extracted flux.
        flux_ylim : None, "percentile" or "median"
            If set, the flux plot y-range is ``0`` to ``1.1 x`` the 99th
            percentile or ``2 x`` the median of the flux, respectively.
        raw_cut : str or sequence of str
            Cut axis per raw extension; a single string is used for all.
        raw_cut_i : int or sequence of int
            Index into the image ``shape`` used to locate the cut position,
            per raw extension; a single int is used for all.
        raw_cut_centre : bool
            If True the central-region cut is taken at half of the selected
            shape dimension, otherwise at a quarter of it.
        n_raw_ext : int
            Number of raw extensions displayed per entry of ``self.hdus``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        dict
            Maps each created ``Panel`` to its metadata dictionary.

        Raises
        ------
        ValueError
            If the number of ``raw_im_ext`` values does not equal
            ``len(self.hdus) * n_raw_ext``, or if ``flux_ylim`` is not one of
            the supported values.
        """
        panels = {}

        # Raw im ext - allow for single or multiple values
        if isinstance(raw_im_ext, (str, int)):
            raw_im_ext = [(raw_im_ext,)] * len(self.hdus)

        # Normalise the cut axis and cut index independently, so a scalar
        # for one may be combined with a per-extension list for the other.
        if isinstance(raw_cut, str):
            raw_cut = [raw_cut] * n_raw_ext
        if isinstance(raw_cut_i, int):
            raw_cut_i = [raw_cut_i] * n_raw_ext

        # Explicit check rather than `assert`, which is removed under -O.
        if sum(len(t) for t in raw_im_ext) != len(self.hdus) * n_raw_ext:
            raise ValueError(
                "Not enough raw_im_exts values to match to all file lists"
            )

        for i, filedict in enumerate(self.hdus):

            std = filedict["std"]
            raw = filedict["raw"]

            p = Panel(2 + n_raw_ext, 4, height_ratios=[1, 4, 4, 4])

            # Eff vs. wavelength
            if plot_eff:
                eff = filedict["eff"]

                w, unit_w = self.get_wave(eff[1])
                eff_v_wavl = LinePlot(
                    title="Efficiency vs. wavelength",
                )
                eff_v_wavl.add_data(
                    [
                        w,
                        eff[1].data[eff_label],
                    ],
                    label="Efficiency",
                )
                eff_v_wavl.x_label = f"Wavelength ({unit_w})"
                eff_v_wavl.y_label = "Efficiency"
                p.assign_plot(eff_v_wavl, 0, 1, xext=2)

            # Extracted flux vs wavelength
            try:
                # Table-style product: "flux" column, unit from TUNIT3.
                flux = std[flux_ext].data["flux"]
                w, unit_w = self.get_wave(std[flux_ext])
                unit_flux = str(std[flux_ext].header["TUNIT3"])
            except Exception:
                # Otherwise use the extension data directly, with the
                # wavelength taken from the primary HDU and unit from BUNIT.
                flux = std[flux_ext].data
                w, unit_w = self.get_wave(std[0])
                unit_flux = str(std[flux_ext].header["BUNIT"])

            if "e-" in unit_flux:
                unit_flux = "electrons"

            eflux_v_wavl = LinePlot(
                title="Extracted flux vs. wavelength",
            )

            ymax = None
            if flux_ylim is not None:
                if flux_ylim == "percentile":
                    ymax = 1.1 * np.percentile(flux, 99)
                elif flux_ylim == "median":
                    ymax = 2 * np.median(flux)
                else:
                    # Previously this surfaced later as an UnboundLocalError.
                    raise ValueError(
                        "flux_ylim must be None, 'percentile' or 'median', "
                        f"got {flux_ylim!r}"
                    )

            eflux_v_wavl.add_data(
                [
                    w,
                    flux,
                ],
                label="Extracted flux",
            )
            eflux_v_wavl.x_label = f"Wavelength ({unit_w})"
            eflux_v_wavl.y_label = f"Extracted flux ({unit_flux})"
            if flux_ylim is not None:
                eflux_v_wavl.y_min = 0.0
                eflux_v_wavl.y_max = ymax
            p.assign_plot(eflux_v_wavl, 0, 2, xext=2)

            # Raw cuts and histograms
            for j, jext in enumerate(raw_im_ext[i]):
                # Only label the extension when several are shown.
                exttitle = ""
                if n_raw_ext > 1:
                    exttitle = str(jext) + " "

                # Cross-dispersion cut of full raw file
                raw_data = raw[jext]
                raw_full = ImagePlot(
                    raw_data.data,
                    title=f"Raw file {exttitle}(full)",
                )

                cutpos = raw_full.get_data_coord(
                    raw_full.data.shape[raw_cut_i[j]] // 2, raw_cut[j]
                )
                cut_full = CutPlot(
                    raw_cut[j], title=f"Raw file {exttitle}(full)", y_label="ADU"
                )
                cut_full.add_data(
                    raw_full,
                    cut_pos=cutpos,
                    color="black",
                    label=f"row @ {raw_cut[j]}={cutpos}",
                )
                p.assign_plot(cut_full, 2 + j, 1)

                # Cross-dispersion cut of central region raw file
                raw_cent = CentralImagePlot(title=f"Raw file {exttitle}(cent.)")
                # Use the configurable instance attribute (default 200)
                # instead of a duplicated literal.
                raw_cent.add_data(raw_full, extent=self.center_size)

                if raw_cut_centre:
                    cutpos = raw_cent.get_data_coord(
                        np.floor(raw_cent.data.shape[raw_cut_i[j]] // 2), raw_cut[j]
                    )
                else:
                    cutpos = raw_cent.get_data_coord(
                        np.floor(raw_cent.data.shape[raw_cut_i[j]] * 0.25), raw_cut[j]
                    )

                cut_cent = CutPlot(
                    raw_cut[j], title=f"Raw file {exttitle}(cent.)", y_label="ADU"
                )
                cut_cent.add_data(
                    raw_cent,
                    cut_pos=cutpos,
                    color="black",
                    label=f"row @ {raw_cut[j]}={cutpos}",
                )
                p.assign_plot(cut_cent, 2 + j, 2)

                # Histogram of full raw file
                hist_full = HistogramPlot(
                    title=f"Raw data {exttitle}histogram",
                    bins=50,
                    v_min=-5000,
                    v_max=70000,
                )
                hist_full.add_data(raw_data.data, label="Raw data counts")
                p.assign_plot(hist_full, 2 + j, 3)

            procatg = fetch_kw_or_error(std[0], "HIERARCH ESO PRO CATG")
            stdfile = os.path.basename(str(std.filename()))
            instru = std[0].header.get("INSTRUME")
            rawfile = os.path.basename(str(raw.filename()))
            # Drop everything up to and including the first "." and a
            # trailing ".fits" from the raw file name.
            rawfile = rawfile[rawfile.find(".") + 1 :].removesuffix(".fits")
            raw_im_ext_str = "_".join(map(str, raw_im_ext[i]))

            input_files = [std.filename(), raw.filename()]
            if plot_eff:
                input_files.append(eff.filename())

            panels[p] = {
                "raw": raw.filename(),
                "std": std.filename(),
                "raw_im_ext": raw_im_ext_str,
                "report_name": f"{instru}_{procatg.lower()}_{rawfile}_{raw_im_ext_str}",
                "report_description": f"Specphotometric Std panel - ( "
                f"{rawfile}, {stdfile})",
                "report_tags": [],
                "input_files": input_files,
            }
            if plot_eff:
                panels[p].update(
                    {
                        "eff": eff.filename(),
                    }
                )
        return panels
