# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot
from adari_core.report import AdariReportBase
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.text import TextPlot
from adari_core.plots.collapse import CollapsePlot
from adari_core.utils.utils import fetch_kw_or_default

import numpy as np


class MasterWaveCalReport(AdariReportBase):
    """Master report class for single-extension wavelength calibration reports of 3D instruments.

    Instruments:
     - KMOS
     - MUSE

    Description
    -----------
    Template used to generate one report per detector (e.g., n=24 reports
    for MUSE). Child classes must implement ``parse_sof``; the
    ``generate_panels`` defined here returns no panels, so child classes
    presumably override it as well.
    """

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        """Placeholder for SOF parsing; child classes must override it.

        Raises
        ------
        NotImplementedError
            Always, because this class is only a template.
        """
        raise NotImplementedError(
            "MasterWaveCalReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def generate_single_panel(
        self,
        hdu,
        hist_clipping="sigma",
        hist_n_clipping=4,
        cut_pos=None,
    ):
        """Create a 3x3 report panel from a given HDU.

        The panel contains a text area (created empty here), the full image,
        its central ``self.center_size``-pixel region, a column cut through
        each of those two images, a histogram of the pixel values and a
        collapse plot along both the "x" and "y" directions.

        Parameters
        ----------
        hdu
            HDU whose ``data`` is plotted. Its header supplies
            ``HIERARCH ESO PRO CATG`` (used in plot titles, default
            ``"MASTER_IMAGE"``) and ``BUNIT`` (y-axis labels, default
            ``"ADU"``).
        hist_clipping : str, optional
            Clipping method passed as ``v_clip`` to the histogram.
        hist_n_clipping : int or float, optional
            Passed as ``nsigma`` in the histogram's ``v_clip_kwargs``.
        cut_pos : int, optional
            x pixel index (column) in the full image at which the column cuts
            are taken. Defaults to the central column. The cut through the
            central region is taken at the same offset from its centre.

        Returns
        -------
        Panel
            The assembled panel.
        """
        panel = Panel(3, 3, height_ratios=[1, 4, 4])
        image_v_clip = "percentile"
        image_v_clip_kwargs = {"percentile": 96}
        master_procatg = fetch_kw_or_default(
            hdu, "HIERARCH ESO PRO CATG", "MASTER_IMAGE"
        )
        bunit = fetch_kw_or_default(hdu, "BUNIT", default="ADU")

        # NaNs are left in place here (multipanel_plot, by contrast, applies
        # np.nan_to_num before plotting).
        data = hdu.data

        # Full image and its central region.
        master_plot = ImagePlot(
            data,
            title=f"{master_procatg}",
            interpolation="none",
            v_clip=image_v_clip,
            v_clip_kwargs=image_v_clip_kwargs,
        )
        panel.assign_plot(master_plot, 0, 1)

        master_center = CentralImagePlot(
            master_plot,
            extent=self.center_size,
            # Built as one f-string so that braces in the header value are
            # never re-interpreted by str.format.
            title=f"{master_procatg} \ncentral {self.center_size}-px",
            interpolation="none",
            v_clip=image_v_clip,
            v_clip_kwargs=image_v_clip_kwargs,
        )
        panel.assign_plot(master_center, 0, 2)

        # Column cut through the full image (central column by default).
        if cut_pos is None:
            cut_pos = master_plot.data.shape[1] // 2
        cutX = CutPlot("x", title="Central column", x_label="y")
        cutX.y_label = bunit
        cutX.add_data(
            master_plot,
            cut_pos=master_plot.get_data_coord(cut_pos, "x"),
            label="master",
            color="red",
        )
        panel.assign_plot(cutX, 1, 1)

        # Matching column cut through the central region: keep the same
        # offset from the image centre as the full-image cut.
        # NOTE(review): a cut_pos outside the central region gives an
        # out-of-range cut_pos_cen; behaviour then depends on
        # get_data_coord/CutPlot — confirm.
        cut_shift = master_plot.data.shape[1] // 2 - cut_pos
        cut_pos_cen = master_center.data.shape[1] // 2 - cut_shift
        cutX_cent = CutPlot("x", title="Central region - central column", x_label="y")
        cutX_cent.y_label = bunit
        cutX_cent.add_data(
            master_center,
            master_center.get_data_coord(cut_pos_cen, "x"),
            label="master",
            color="red",
        )
        panel.assign_plot(cutX_cent, 1, 2)

        # Histogram of pixel values, with its own (sigma-style) clipping.
        histogram = HistogramPlot(
            master_data=data,
            raw_data=None,
            bins=None,
            title="Counts histogram",
            v_clip=hist_clipping,
            v_clip_kwargs={"nsigma": hist_n_clipping},
        )
        panel.assign_plot(histogram, 2, 1)

        # Text area at (0, 0) spanning two x positions; nothing is written
        # to it in this method.
        text = TextPlot(columns=2, v_space=0.3)
        panel.assign_plot(text, 0, 0, xext=2)

        # Collapse plots along "x" (black) and "y" (red) on the same axes.
        collapse_x = CollapsePlot(
            data,
            "x",
            title=f"{master_procatg} - collapse",
            color="black",
            x_label="pixels",
        )
        collapse_x.y_label = bunit
        collapse_x.add_data(data, "y", color="red")
        panel.assign_plot(collapse_x, 2, 2)

        return panel

    def multipanel_plot(self, data, **kwargs):
        """Build an individual plot to be included in a multipanel.

        Child classes may override this. The default implementation draws
        ``data`` as an ``ImagePlot`` with NaN/inf values replaced by
        ``np.nan_to_num``. When ``data`` is None, it returns a placeholder
        ``LinePlot`` whose title is suffixed with " (NO DATA)".

        Parameters
        ----------
        data
            Image data to plot, or None.
        **kwargs
            ``title`` (default ``""``) and ``cbar_kwargs``
            (default ``{"pad": 0.1}``) are used; other keys are ignored.

        Returns
        -------
        ImagePlot or LinePlot
        """
        if data is not None:
            mplot = ImagePlot(
                title=kwargs.get("title", ""),
                cbar_kwargs=kwargs.get("cbar_kwargs", {"pad": 0.1}),
                interpolation="none",
            )
            data = np.nan_to_num(data)
            mplot.add_data(data)
        else:
            # if no image data, setup a placeholder plot
            # TODO: implement empty plot
            mplot = LinePlot(
                legend=False,
                y_label="y",
                x_label="x",
                title=kwargs.get("title", "") + " (NO DATA)",
            )
        return mplot

    def generate_multipanel(self, panel, hdu_info, plot_kwargs):
        """Fill ``panel`` with one ``multipanel_plot`` per entry of ``hdu_info``.

        Parameters
        ----------
        panel : Panel
            Panel to populate; it is modified in place and also returned.
        hdu_info : iterable of tuple
            Each entry is ``(idx, hdul_name, hdu_idx, xpos, ypos)``. The data
            plotted is ``self.hdus[idx][hdul_name][hdu_idx].data``, and the
            plot is assigned at position ``(xpos, ypos)`` in ``panel``.
        plot_kwargs : iterable of dict
            Keyword arguments passed to ``multipanel_plot`` for the matching
            entry of ``hdu_info``. The two iterables are paired with
            ``zip``, so surplus entries in the longer one are ignored.

        Returns
        -------
        Panel
            The same ``panel`` object.
        """
        for panel_data, panel_kwargs in zip(hdu_info, plot_kwargs):
            idx, hdul_name, hdu_idx, xpos, ypos = panel_data
            plot = self.multipanel_plot(
                self.hdus[idx][hdul_name][hdu_idx].data, **panel_kwargs
            )
            panel.assign_plot(plot, xpos, ypos)
        return panel

    def generate_panels(self, **kwargs):
        """Return the report panels; this template defines none (empty dict)."""
        panels = {}
        return panels
