# SPDX-License-Identifier: BSD-3-Clause
from adari_core.report import AdariReportBase
from adari_core.plots.text import TextPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.panel import Panel
from adari_core.utils.utils import fetch_kw_or_error
import os


class MasterEchelleFlatfieldReport(AdariReportBase):
    """Master ADARI Echelle Flatfield report.

    Description
    -----------
    Template report for building the echelle flatfield report for VLT echelle
    spectrograph instruments. Child reports must implement ``parse_sof``.
    For every master/raw pair in ``self.hdus`` one panel is built, containing
    a 2D image of the master product, a 2D image of its central
    ``center_size`` pixels, cut plots through both images, a histogram of the
    raw pixel data, and a histogram of the master product values.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Extent (pixels) of the central region shown in the zoomed image.
        self.center_size = 200
        # Keyword arguments passed verbatim to Panel(): grid dimensions and
        # relative row heights (short header row, then two plot rows).
        self.panel_kwargs = dict(x=4, y=3, height_ratios=[1, 4, 4])

    def parse_sof(self):
        """Not implemented: child reports must define how to parse the SOF.

        Raises
        ------
        NotImplementedError
            Always; this class is only a template.
        """
        raise NotImplementedError(
            "MasterEchelleFlatfieldReport is a template "
            "only, the child Report is responsible for "
            "defining parse_sof"
        )

    def generate_panels(
        self,
        raw_ext=0,
        master_product_ext="PRIMARY",
        direction="y",
        interpolation=None,
        **kwargs,
    ):
        """Create one panel per master/raw file pair in ``self.hdus``.

        Description
        -----------
        Each entry of ``self.hdus`` is a dict with ``"master_product"`` and
        ``"raw"`` HDU lists. For each entry a panel is built with a header
        text block, the full and central master images, a cut plot through
        each of them, a histogram of the raw values and a histogram of the
        master values.

        Parameters
        ----------
        - raw_ext: int, str or list
        Extension of the raw file used for the raw histogram. A list supplies
        one extension per entry of ``self.hdus``; any other value (including
        a tuple extension spec) is applied to every entry.
        - master_product_ext: str
        Extension of the master product used for the images, cut plots and
        master histogram.
        - direction: str
        Direction of the cut plots, "x" or "y". The cut is taken at the middle
        pixel index along image axis 0 for "y" and axis 1 otherwise.
        - interpolation: str or None
        Interpolation setting assigned to the master image plot.
        - kwargs:
        Accepted for interface compatibility; not used.

        Returns
        -------
        - panels: dict
        Maps each Panel to a metadata dict (input file names, extensions,
        report name, description and tags).
        """
        panels = {}
        if not isinstance(raw_ext, list):
            # A single extension spec is shared by every raw file.
            raw_ext = [raw_ext] * len(self.hdus)

        # Image axis whose midpoint defines the cut position.
        axis = 0 if direction == "y" else 1

        # NOTE(review): zip() stops at the shorter sequence, so a raw_ext list
        # shorter than self.hdus silently drops the trailing panels.
        for r_ext, filedict in zip(raw_ext, self.hdus):
            master_product = filedict["master_product"]
            raw = filedict["raw"]
            master_hdu = master_product[master_product_ext]
            hdr = master_hdu.header
            master_fname = master_product.filename()
            raw_fname = raw.filename()

            # NOTE(review): PRO CATG for titles comes from the primary HDU,
            # while the text block below reads the master_product_ext header;
            # these differ when master_product_ext is not the primary HDU.
            master_product_procatg = fetch_kw_or_error(
                master_product[0], "HIERARCH ESO PRO CATG"
            )

            p = Panel(**self.panel_kwargs)
            # Grid origin; every plot is placed relative to it.
            px, py = 0, 0

            # Header row: identification text spanning two columns.
            t1 = TextPlot(columns=1, v_space=0.3)
            t1.add_data(
                (
                    str(hdr.get("INSTRUME")),
                    "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
                    "FILE NAME: " + os.path.basename(master_fname),
                    "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                )
            )
            p.assign_plot(t1, px, py, xext=2)

            py += 1

            # Full master image.
            # TODO overplot order solution
            master_plot = ImagePlot(master_hdu.data, title=master_product_procatg)
            master_plot.interp = interpolation
            p.assign_plot(master_plot, px, py)

            # Zoom on the central region of the master image.
            # TODO overplot order solution
            master_center = CentralImagePlot(
                master_plot,
                title=f"{master_product_procatg} center",
                extent=self.center_size,
            )
            p.assign_plot(master_center, px, py + 1)

            # Cut through the middle of the full master image.
            cutpos1 = master_plot.get_data_coord(
                master_plot.data.shape[axis] // 2, direction
            )
            master_cut = CutPlot(
                direction,
                title="master row @ {} {}".format(direction, cutpos1),
                y_label="normalised counts",
            )
            master_cut.add_data(master_plot, cutpos1, color="red", label="master")
            p.assign_plot(master_cut, px + 1, py)

            # Cut through the middle of the central region; its title must
            # report this cut's own position (cutpos2), not the full-image one.
            cutpos2 = master_center.get_data_coord(
                master_center.data.shape[axis] // 2, direction
            )
            master_cen_cut = CutPlot(
                direction,
                title="Central Region: row @ {} {}".format(direction, cutpos2),
                y_label="normalised counts",
            )
            master_cen_cut.add_data(master_center, cutpos2, label="master", color="red")
            p.assign_plot(master_cen_cut, px + 1, py + 1)

            # Histogram of the raw frame values within a fixed count range.
            raw_hist = HistogramPlot(
                raw_data=raw[r_ext].data,
                title="raw value counts",
                x_label="counts",
                v_min=-5000,
                v_max=70000,
            )
            p.assign_plot(raw_hist, px + 2, py)

            # Histogram of the master product values.
            master_hist = HistogramPlot(
                master_data=master_plot.data,
                title="master value counts",
                bins=50,
                x_label="normalised counts",
            )
            p.assign_plot(master_hist, px + 2, py + 1)

            panels[p] = {
                "master_product": master_fname,
                "master_product_ext": master_product_ext,
                "raw": raw_fname,
                "raw_ext": r_ext,
                "report_name": f"{self.name}_{master_product_procatg}_{master_product_ext}",
                "report_description": f"{self.name} - {master_product_procatg}, "
                f"{master_product_ext}",
                "report_tags": [],
                "input_files": [master_fname, raw_fname],
            }

        return panels
