# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import (
    fetch_kw_or_default,
    expand_ext_argument,
)

import os


class MasterOrderdefReport(AdariReportBase):
    """
    Template report showing an order-definition product next to its raw frame.

    This class is a template only: ``parse_sof`` must be provided by the
    child report, and ``remove_raw_scan`` may be overridden when the raw
    frames carry pre/overscan regions.
    """

    def __init__(self, name: str):
        """
        Create the report.

        Parameters
        ----------
        name : str
            Report name, forwarded to the base class constructor.
        """
        super().__init__(name)
        # Passed as ``extent`` to every CentralImagePlot built by
        # generate_panels.
        self.center_size = 200

    def remove_raw_scan(self, im_hdu):
        """
        Remove the pre/overscan regions from a raw image as required.

        Parameters
        ----------
        im_hdu : ImageHDU
            The input ImageHDU.

        Returns
        -------
        stripped_hdu : ImageHDU
            The image data from im_hdu, with the pre/overscan regions removed.

        """
        # Default option: no-op
        return im_hdu

    def parse_sof(self):
        """
        Placeholder that must be overridden by the child report.

        Raises
        ------
        NotImplementedError
            Always; this template does not define how to parse the SOF.
        """
        raise NotImplementedError(
            "MasterOrderdefReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def generate_panels(
        self,
        master_im_ext=0,
        raw_im_ext=0,
        master_im_rotation_kwargs=None,
        raw_im_rotation_kwargs=None,
        im_cut_direction="x",
        main_image="master",
        product_title="Master plot",
        raw_title="Raw image",
        im_clipping="sigma",
        im_n_clipping=2,
        **kwargs,
    ):
        """
        Build one panel per entry of ``self.hdus``.

        Each entry of ``self.hdus`` is read as a mapping with the keys
        ``"order_table"`` and ``"order_def"``. Both values must support
        indexing by extension, a ``"PRIMARY"`` extension and ``filename()``
        (presumably astropy HDULists - confirm against the child report's
        ``parse_sof``).

        Every panel is a ``Panel(4, 3, height_ratios=[1, 4, 4])`` holding:

        * two text plots with header information at (0, 0) and (2, 0);
        * the main image (master or raw) at (0, 1);
        * the central region of the main image at (0, 2);
        * a cut through the raw image at (1, 1) and through its central
          region at (1, 2);
        * a histogram of the raw image values at (2, 1).

        Parameters
        ----------
        master_im_ext
            Extension(s) of the ``order_table`` file to use; expanded with
            ``expand_ext_argument`` to one value per entry of ``self.hdus``.
        raw_im_ext
            Extension(s) of the ``order_def`` file to use; expanded the same
            way as ``master_im_ext``.
        master_im_rotation_kwargs : dict, optional
            Extra keyword arguments for the master ``ImagePlot``. ``None``
            (the default) means no extra arguments.
        raw_im_rotation_kwargs : dict, optional
            Extra keyword arguments for the raw ``ImagePlot``. ``None``
            (the default) means no extra arguments.
        im_cut_direction : {"x", "y"}
            Direction handed to the ``CutPlot`` objects; also selects which
            image axis the cut position is taken from.
        main_image : {"master", "raw"}
            Which image is shown as the main image and central region.
        product_title, raw_title : str
            Accepted for interface compatibility.
            NOTE(review): neither value is used here; the plot titles are
            fixed strings. Confirm whether they should be honoured.
        im_clipping
            Passed as ``v_clip`` to the master image plot and to the central
            plot of the master image.
        im_n_clipping
            Passed as ``v_clip_kwargs={"nsigma": im_n_clipping}`` together
            with ``im_clipping``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        dict
            Maps each created ``Panel`` to a dict of metadata (file names,
            extensions, report name/description/tags and input files).

        Raises
        ------
        ValueError
            If ``main_image`` or ``im_cut_direction`` has an invalid value.
        """
        # Input checking. Raise explicitly instead of using ``assert`` so
        # that the checks are still performed when Python runs with -O.
        if main_image not in ("master", "raw"):
            raise ValueError(f"Invalid value for main_image ({main_image})")
        if im_cut_direction not in ("x", "y"):
            raise ValueError(
                f"Invalid value for im_cut_direction ({im_cut_direction})"
            )

        # ``None`` stands in for "no extra arguments" so that no mutable
        # object is shared between calls through the signature defaults.
        if master_im_rotation_kwargs is None:
            master_im_rotation_kwargs = {}
        if raw_im_rotation_kwargs is None:
            raw_im_rotation_kwargs = {}

        panels = {}

        # Extensions - allow for single or multiple values
        raw_im_ext = expand_ext_argument(raw_im_ext, len(self.hdus), "raw_im_ext")
        master_im_ext = expand_ext_argument(
            master_im_ext, len(self.hdus), "master_im_ext"
        )

        for i, filedict in enumerate(self.hdus):
            order_table_hdul = filedict["order_table"]
            order_def_hdul = filedict["order_def"]
            master_ext = master_im_ext[i]
            raw_ext = raw_im_ext[i]

            primary_hdu = order_table_hdul["PRIMARY"]
            primary_header = primary_hdu.header
            order_table_procatg = fetch_kw_or_default(
                primary_hdu,
                "HIERARCH ESO PRO CATG",
                default="MASTER-PRO-CATG-UNKNOWN",
            )
            # The extensions to read are chosen by the caller through
            # master_im_ext / raw_im_ext; they are not detected here.
            order_table = order_table_hdul[master_ext]
            order_def = order_def_hdul[raw_ext]

            order_table_fname = order_table_hdul.filename()
            order_def_fname = order_def_hdul.filename()

            scaling = {
                "v_clip": im_clipping,
                "v_clip_kwargs": {"nsigma": im_n_clipping},
            }
            if main_image == "master":  # Only load image data if required
                master_plot = ImagePlot(
                    order_table.data,
                    title="master plot",
                    **scaling,
                    **master_im_rotation_kwargs,
                )
            # NOTE(review): the clipping settings are not applied to the raw
            # image plot - confirm this is intended.
            raw_plot = ImagePlot(
                self.remove_raw_scan(order_def).data,
                title="raw plot file",
                **raw_im_rotation_kwargs,
            )

            p = Panel(4, 3, height_ratios=[1, 4, 4])

            # Text plots
            px = 0
            py = 0
            vspace = 0.3
            t1 = TextPlot(columns=1, v_space=vspace)
            fname = os.path.basename(str(order_table_fname))
            col1 = (
                str(primary_header.get("INSTRUME")),
                "EXTNAME: " + str(order_table.header.get("EXTNAME", "N/A")),
                "PRO CATG: " + str(order_table_procatg),
                "FILE NAME: " + fname,
                "RAW1 NAME: "
                + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
            )
            t1.add_data(col1)
            p.assign_plot(t1, px, py, xext=2)

            # Instrument name used to build the report name below.
            instru = primary_header.get("INSTRUME", "Name not found")

            px = px + 2
            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            col2 = (
                "INS.MODE: " + str(primary_header.get("HIERARCH ESO INS MODE")),
                "DET.BINX: " + str(primary_header.get("HIERARCH ESO DET BINX")),
                "DET.BINY: " + str(primary_header.get("HIERARCH ESO DET BINY")),
            )
            t2.add_data(col2)
            p.assign_plot(t2, px, py, xext=1)

            # Image plots
            if main_image == "master":
                to_plot = master_plot
            else:
                to_plot = raw_plot
            p.assign_plot(to_plot, 0, 1)

            # The raw central plot is always needed: it feeds the central
            # cut even when the master image is the main image.
            raw_center = CentralImagePlot(
                raw_plot, extent=self.center_size, title="raw center"
            )
            if main_image == "raw":
                to_plot_center = raw_center
            else:
                to_plot_center = CentralImagePlot(
                    to_plot,
                    title="{} center".format(main_image),
                    extent=self.center_size,
                    **scaling,
                )
            p.assign_plot(to_plot_center, 0, 2)

            # Plot cuts of raw image, positioned at the middle of the image
            if im_cut_direction == "x":
                cutpos = raw_plot.get_data_coord(raw_plot.data.shape[1] // 2, "x")
                cut_title = "raw column @ X={}".format(cutpos)
            else:
                cutpos = raw_plot.get_data_coord(raw_plot.data.shape[0] // 2, "y")
                cut_title = "raw row @ Y={}".format(cutpos)

            # NOTE(review): the unit label is read from the master extension
            # although the cuts show raw data - confirm this is intended.
            cut_y_label = fetch_kw_or_default(order_table, "BUNIT", default="ADU")

            raw_cut = CutPlot(
                im_cut_direction,
                title=cut_title,
                y_label=cut_y_label,
            )
            raw_cut.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cut, 1, 1)

            raw_center_cut = CutPlot(
                im_cut_direction,
                title="Central region " + cut_title,
                y_label=cut_y_label,
            )
            raw_center_cut.add_data(raw_center, cutpos, color="black", label="raw")
            p.assign_plot(raw_center_cut, 1, 2)

            raw_hist = HistogramPlot(
                raw_data=raw_plot.data,
                title="raw value counts",
                bins=50,
                v_min=-5000,
                v_max=70000,
            )
            p.assign_plot(raw_hist, 2, 1)

            panels[p] = {
                "order_table": order_table_fname,
                "order_table_ext": order_table,
                "order_def": order_def_fname,
                "hdus_i": i,
                "ext": master_ext,
                "raw_im_ext": raw_ext,
                "report_name": f"{instru}_{order_table_procatg.lower()}_{master_ext}",
                "report_description": f"Order definition panel - ({order_table_fname}, "
                f"{order_def_fname}, "
                f"{master_ext})",
                "report_tags": [],
                "input_files": [order_table_fname, order_def_fname],
            }

        return panels
