# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase

import numpy as np
import os


class MasterWaveReport(AdariReportBase):
    """
    Template report for wavelength-calibration products (ESPRESSO orders).

    Child reports must implement ``parse_sof`` and may override
    ``remove_raw_scan``. ``generate_panels`` builds one Panel per entry of
    ``self.hdus``; each entry maps (a subset of) the keys in ``_INPUT_KEYS``
    to HDUList-like objects (indexed by extension, with ``filename()``).
    """

    # ESO INS MODE value -> indices into ``b_orders``/``r_orders`` of the
    # (first, last) order plotted in that mode. Other modes are not filtered.
    _MODE_ORDER_INDEX = {
        "SINGLEHR": (0, 1),
        "SINGLEUHR": (0, 1),
        "MULTIMR": (2, 3),
    }

    # Optional inputs looked up in each ``self.hdus`` entry; this order is
    # also the order of the report's ``input_files`` list.
    _INPUT_KEYS = (
        "line_table_a",
        "line_table_b",
        "res_table_a",
        "res_table_b",
        "raw",
    )

    def __init__(self, name: str):
        super().__init__(name)
        # Extent passed to CentralImagePlot for the central raw cut-out.
        self.central_region_size = 200
        # Blue and red order numbers for ESPRESSO:
        # [first, last] for SINGLEHR/SINGLEUHR, then [first, last] for MULTIMR.
        self.b_orders = [1, 90, 1, 45]
        self.r_orders = [91, 170, 46, 85]

    def remove_raw_scan(self, im_hdu):
        """
        Remove the pre/overscan regions from a raw image as required.

        Parameters
        ----------
        im_hdu : ImageHDU
            The input ImageHDU.

        Returns
        -------
        stripped_hdu : ImageHDU
            The image data from im_hdu, with the pre/overscan regions removed.

        """
        # Default option: no-op
        return im_hdu

    def parse_sof(self):
        raise NotImplementedError(
            "MasterWaveReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def get_line_table_subset(self, plot, tab_data, hdr, is_red, set_ylim):
        """
        Select the table rows to plot for the current instrument mode.

        Parameters
        ----------
        plot : ScatterPlot
            Plot whose y-limits are set to the order range when ``set_ylim``.
        tab_data : table data
            Rows with at least ``QC`` and ``ORDER`` columns.
        hdr : Header
            Header searched for the ``ESO INS MODE`` keyword.
        is_red : bool
            Use ``r_orders`` instead of ``b_orders``.
        set_ylim : bool
            Whether to set ``plot``'s y-limits to ``[nmin - 1, nmax + 1]``.

        Returns
        -------
        (subset, nmin, nmax) : tuple
            For a mode in ``_MODE_ORDER_INDEX``: rows with ``QC == 1`` and
            ``nmin <= ORDER <= nmax``, plus that order range. Otherwise (or
            if ``ESO INS MODE`` is missing) ``tab_data`` unfiltered and
            ``nmin``/``nmax`` set to None.
        """
        orders = self.r_orders if is_red else self.b_orders
        index = self._MODE_ORDER_INDEX.get(hdr.get("ESO INS MODE"))
        if index is None:
            return (tab_data, None, None)
        nmin, nmax = orders[index[0]], orders[index[1]]
        if set_ylim:
            # buffer of +-1 here to give some visual space
            plot.set_ylim(nmin - 1, nmax + 1)
        mask = (
            (tab_data["QC"] == 1)
            & (tab_data["ORDER"] >= nmin)
            & (tab_data["ORDER"] <= nmax)
        )
        return (tab_data[mask], nmin, nmax)

    def generate_panels(
        self, lplot_msize=0.05, splot_msize=4, rplot_msize=3, ext=0, **kwargs
    ):
        """
        Build one wavelength-calibration Panel per entry of ``self.hdus``.

        Parameters
        ----------
        lplot_msize : float
            Marker size of the line-position scatter plots.
        splot_msize : float
            Marker size of the per-order QC scatter plots.
        rplot_msize : float
            Marker size of the resolving-power scatter plot.
        ext : int or str
            Extension of the raw file to display; also embedded in the report
            name and description. A string containing "red" selects
            ``r_orders``, anything else ``b_orders``.
        **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        panels : dict
            Maps each Panel to its report metadata dict.
        """
        panels = {}
        for filedict in self.hdus:
            # HDUList of each optional input, or None when not supplied.
            hduls = {
                key: filedict[key] if key in filedict else None
                for key in self._INPUT_KEYS
            }
            line_hdul_a = hduls["line_table_a"]
            line_hdul_b = hduls["line_table_b"]
            res_hdul_a = hduls["res_table_a"]
            res_hdul_b = hduls["res_table_b"]
            raw_hdul = hduls["raw"]
            raw = raw_hdul[ext] if raw_hdul is not None else None
            has_line_a = line_hdul_a is not None
            has_line_b = line_hdul_b is not None
            # Both line tables given => FP_FP case, which has no per-order QC
            # plots and therefore one column fewer.
            is_fp_fp = has_line_a and has_line_b

            p = Panel(4 if is_fp_fp else 5, 3, height_ratios=[1, 4, 4])

            # Row 0: header metadata of the first available line table.
            self._add_metadata_plots(p, line_hdul_a if has_line_a else line_hdul_b)

            # ``ext`` defaults to the int 0; only a string can name the arm.
            is_red = isinstance(ext, str) and "red" in ext

            # Column 0: line positions per fibre, then resolving power.
            px, py = 0, 1
            nmin_a = nmax_a = nmin_b = nmax_b = None
            if has_line_a:
                plot, nmin_a, nmax_a = self._line_position_plot(
                    "A", line_hdul_a, lplot_msize, is_red
                )
                p.assign_plot(plot, px, py)
            if has_line_b:
                if has_line_a:
                    # Fibre A already occupies this row.
                    py += 1
                plot, nmin_b, nmax_b = self._line_position_plot(
                    "B", line_hdul_b, lplot_msize, is_red
                )
                p.assign_plot(plot, px, py)
            if res_hdul_a is not None or res_hdul_b is not None:
                if has_line_a or has_line_b:
                    py += 1
                # NOTE(review): with both line tables AND a resolution table
                # this lands on row 3 of a 3-row panel; confirm that input
                # combination cannot occur.
                if res_hdul_a is not None:
                    plot = self._res_power_plot("A", res_hdul_a, rplot_msize, is_red)
                else:
                    plot = self._res_power_plot("B", res_hdul_b, rplot_msize, is_red)
                p.assign_plot(plot, px, py)

            # Columns 1-2 and top of column 3: views of the raw frame.
            self._add_raw_plots(p, raw)

            # Bottom of column 3 and column 4: per-order QC values read from
            # the line-table header (not produced in the FP_FP case).
            is_lfc = False
            if not is_fp_fp:
                # Prefer fibre A's order range, fall back to fibre B's.
                nmin = nmin_a if nmin_a is not None else nmin_b
                nmax = nmax_a if nmax_a is not None else nmax_b
                # No order range is known for modes outside _MODE_ORDER_INDEX;
                # skip the QC plots instead of failing on a None range.
                if nmin is not None and nmax is not None:
                    line_hdul = line_hdul_a if has_line_a else line_hdul_b
                    is_lfc, lines_nb, chi2, rms = self._collect_order_qc(
                        line_hdul[0].header, nmin, nmax
                    )
                    nb_label = "LFC LINES NB" if is_lfc else "FP LINES NB"
                    for xy, label, qx, qy in (
                        (lines_nb, nb_label, 3, 2),
                        (chi2, "CHI2", 4, 1),
                        (rms, "RMS", 4, 2),
                    ):
                        p.assign_plot(
                            self._order_qc_plot(xy, label, nmin, nmax, splot_msize),
                            qx,
                            qy,
                        )

            if is_fp_fp:
                rname = "FP"
            else:
                rname = "LFC" if is_lfc else "THAR"

            addme = {
                "ext": ext,
                "report_name": f"espresso_{rname.lower()}_{ext}",
                "report_description": f"Wavelength calibration panel ({ext})",
                "report_tags": [],
                "input_files": [
                    hdul.filename() for hdul in hduls.values() if hdul is not None
                ],
            }
            for key in ("line_table_a", "line_table_b", "res_table_a", "res_table_b"):
                if hduls[key] is not None:
                    addme[key] = hduls[key].filename()

            panels[p] = addme
        return panels

    @staticmethod
    def _add_metadata_plots(panel, mdata_hdul):
        """Fill row 0 of ``panel`` with header keywords of ``mdata_hdul``."""
        vspace = 0.3
        fname = os.path.basename(str(mdata_hdul.filename()))
        hdr = mdata_hdul[0].header
        # Open question kept from the original author: should the raw file's
        # metadata be shown here, since ``ext`` does not exist in the line
        # table? (flagged as an ERROR in the 2nd table of Sect 1.4.4 of the
        # QC refs)
        t1 = TextPlot(columns=1, v_space=vspace)
        t1.add_data(
            (
                str(hdr.get("INSTRUME")),
                "EXTNAME: " + str(hdr.get("EXTNAME", "N/A")),
                "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + fname,
                "RAW1 NAME: "
                + str(
                    mdata_hdul["PRIMARY"].header.get(
                        "HIERARCH ESO PRO REC1 RAW1 NAME"
                    )
                ),
            )
        )
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(
            (
                "INS.MODE: " + str(hdr.get("HIERARCH ESO INS MODE")),
                "DET.BINX: " + str(hdr.get("HIERARCH ESO DET BINX")),
                "DET.BINY: " + str(hdr.get("HIERARCH ESO DET BINY")),
            )
        )
        panel.assign_plot(t2, 2, 0, xext=1)

    def _line_position_plot(self, fibre, hdul, msize, is_red):
        """Return ``(plot, nmin, nmax)``: X0 vs ORDER of a fibre's line table."""
        plot = ScatterPlot(markersize=msize, legend=False)
        subset, nmin, nmax = self.get_line_table_subset(
            plot, hdul[1].data, hdul[0].header, is_red, True
        )
        plot.add_data(
            np.asarray([subset["X0"], subset["ORDER"]]),
            label="Line positions",
            color="black",
        )
        plot.x_label = "Cross-dispersion [pix]"
        plot.y_label = "Order number"
        plot.title = "Line positions fibre {}".format(fibre)
        return plot, nmin, nmax

    def _res_power_plot(self, fibre, hdul, msize, is_red):
        """Return a RESOLUTION vs X0 plot of a fibre's resolution table."""
        plot = ScatterPlot(markersize=msize, legend=False)
        subset, _, _ = self.get_line_table_subset(
            plot, hdul[1].data, hdul[0].header, is_red, False
        )
        plot.add_data(
            np.asarray([subset["X0"], subset["RESOLUTION"]]),
            label="Res. power",
            color="black",
        )
        plot.x_label = "Cross-dispersion [pix]"
        plot.y_label = "Resolving power"
        plot.title = "Resolving power fibre " + fibre
        return plot

    def _add_raw_plots(self, panel, raw):
        """Add image, cut and histogram plots of the raw HDU ``raw``."""
        img_raw = ImagePlot(title="Raw input")
        img_raw.add_data(self.remove_raw_scan(raw).data)
        panel.assign_plot(img_raw, 1, 1)

        img_cent = CentralImagePlot(title="Raw input - central region")
        img_cent.add_data(img_raw, extent=self.central_region_size)
        panel.assign_plot(img_cent, 1, 2)

        # NOTE(review): the full-frame cut uses data axis 0 with "x", while
        # the central cut uses axis 1 with "y"; confirm against the
        # get_data_coord convention.
        cutpos = img_raw.get_data_coord(img_raw.data.shape[0] // 4, "x")
        cut_raw = CutPlot(
            "y",
            title="Cross-dispersion raw cut @ Y={}".format(cutpos),
            x_label="x",
            y_label="ADU",
        )
        cut_raw.add_data(img_raw, cut_pos=cutpos, label="raw", color="black")
        panel.assign_plot(cut_raw, 2, 1)

        cutpos = img_cent.get_data_coord(
            np.floor(img_cent.data.shape[1] * 0.25), "y"
        )
        cut_raw_cent = CutPlot(
            "y",
            title="Cross-dispersion raw cut centre @ Y={}".format(cutpos),
            x_label="x",
            y_label="ADU",
        )
        cut_raw_cent.add_data(img_cent, cutpos, label="raw", color="black")
        panel.assign_plot(cut_raw_cent, 2, 2)

        hist_raw = HistogramPlot(title="Raw file", bins=50, v_min=-5000, v_max=70000)
        hist_raw.add_raw_data(img_raw.data)
        panel.assign_plot(hist_raw, 3, 1)

    @staticmethod
    def _collect_order_qc(hdr, nmin, nmax):
        """
        Gather per-order QC keyword values from a line-table header.

        For each order n in ``[nmin, nmax]``: if ``ESO QC ORDER{n} LFC FIT
        WAVE`` equals 1 the LFC line count keyword is read; if that keyword is
        absent the FP line count keyword is read; if it exists with any other
        value the order is skipped. CHI2 and RMS are read in both read cases.

        Returns
        -------
        (is_lfc, lines_nb, chi2, rms) : tuple
            ``is_lfc`` is True if any order has the LFC FIT WAVE keyword; the
            others are ``(orders, values)`` pairs of lists.
        """
        is_lfc = False
        series = {"nb": ([], []), "chi2": ([], []), "rms": ([], [])}
        for n in np.arange(nmin, nmax + 1):
            prefix = "ESO QC ORDER{}".format(n)
            fitkw = prefix + " LFC FIT WAVE"
            if fitkw in hdr:
                is_lfc = True
                if hdr[fitkw] != 1:
                    continue
                lnbkw = prefix + " LFC LINES NB"
            else:
                # THAR case (will be missing LFC keywords)
                lnbkw = prefix + " FP LINES NB"
            for name, kw in (
                ("nb", lnbkw),
                ("chi2", prefix + " CHI2"),
                ("rms", prefix + " RMS"),
            ):
                if kw in hdr:
                    series[name][0].append(n)
                    series[name][1].append(hdr[kw])
        return (is_lfc, series["nb"], series["chi2"], series["rms"])

    @staticmethod
    def _order_qc_plot(xy, label, nmin, nmax, msize):
        """Return a scatter plot of a per-order QC ``(orders, values)`` pair."""
        plot = ScatterPlot(markersize=msize, legend=False)
        plot.add_data(np.asarray(xy), label=label, color="black")
        plot.x_label = "Order number"
        plot.y_label = label
        plot.title = label
        # buffer of +-1 here to give some visual space
        plot.set_xlim(nmin - 1, nmax + 1)
        return plot
