# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.collapse import CollapsePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.utils.utils import fetch_kw_or_default
import numpy as np


class MasterIfuFlatReport(AdariReportBase):
    """Master report class for Lamp Flat reports of 3D instruments.
    Instruments:
     - KMOS
     - MUSE

     Description
     -----------
     Generate multi-extension reports containing `n` individual panels corresponding to each detector
     (e.g., for MUSE n=24) and single-extension reports containing information about a specific extension.
    """

    files_needed = {"MASTER_FLAT": None}
    detectors = None
    n_panels = None

    def __init__(self, name: str):
        """Initialise the report by forwarding ``name`` to the base class.

        Parameters
        ----------
        - name: (str) Report name; it is later read back as ``self.name``
        when building the panel report names.
        """
        super().__init__(name)

    def parse_sof(self):
        raise NotImplementedError(
            "MasterFlatMultiReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def single_panel(
        self,
        hdu,
        master_im_clipping=None,
        master_im_clipping_kwargs=None,
        master_im_limits=(0.5, 1.5),
        cutout_clipping=None,
        cutout_clipping_kwargs=None,
        cutout_limits=(0.5, 1.5),
        hist_clipping="sigma",
        hist_clipping_kwargs=dict(nsigma=4),
        panel_args=dict(x=4, y=3, height_ratios=[1, 4, 4]),
    ):
        """Create a single-extension report panel from a given HDU.

        Description
        -----------
        This method constructs a panel from an input HDU consisting of:
        - Empty text plot at (0, 0) (expected to be filled by the instrument report).
        - ImagePlot for the whole extent of the 2D image at (0, 1).
        - CentralImagePlot for the central region of the image at (0, 2).
        - CutPlot along the central column of the 2D image at (1, 1).
        - CutPlot along the central row of the 2D image at (2, 1).
        - CutPlot along the central column of the cutout image at (1, 2).
        - CutPlot along the central row of the cutout image at (2, 2).
        - Histogram of the 2D image at (3, 1).
        - CollapsePlot of the 2D image (both directions) at (3, 2).

        The size of the central cutout is read from ``self.center_size``, which
        is not defined in this template class (presumably provided by the base
        class or the child report -- TODO confirm).

        Parameters
        ----------
        - hdu: HDU providing ``data`` and the header keywords
        "HIERARCH ESO PRO CATG" (default "MASTER_IMAGE") and "BUNIT" (default "ADU").
        - master_im_clipping: (str, default=None) Clipping method of the 2D image plot.
        Only applied when master_im_clipping_kwargs is also given.
        - master_im_clipping_kwargs: (dict, default=None) Dictionary containing the clipping parameters.
        - master_im_limits: (tuple, default=(0.5, 1.5)) Minimum and maximum values of the ImagePlot.
        If None, no limits are set.
        - cutout_clipping: (str, default=None) Same as master_im_clipping for the cutout plot.
        - cutout_clipping_kwargs: (dict, default=None) Same as master_im_clipping_kwargs for the cutout plot.
        - cutout_limits: (tuple, default=(0.5, 1.5)) Same as master_im_limits for the cutout plot.
        - hist_clipping: (str, default='sigma') Clipping method for the Histogram plot.
        If None, no clipping arguments are passed to the histogram.
        - hist_clipping_kwargs: (dict, default=dict(nsigma=4)) Clipping arguments for the Histogram plot.
        - panel_args: (dict, default=dict(x=4, y=3, height_ratios=[1, 4, 4])) Dictionary containing
        the panel arguments.

        Returns
        -------
        - panel: (Panel) The populated panel.
        """
        master_procatg = fetch_kw_or_default(
            hdu, "HIERARCH ESO PRO CATG", "MASTER_IMAGE"
        )
        # The flux unit labels every cut/collapse plot; look it up only once.
        flux_unit = fetch_kw_or_default(hdu, "BUNIT", default="ADU")
        data = hdu.data

        # Copy the panel arguments (and any list values) so that the shared
        # default dict from the signature is never handed out directly.
        panel = Panel(
            **{
                key: list(value) if isinstance(value, list) else value
                for key, value in panel_args.items()
            }
        )

        def _central_cut(axis, image, title, x_label):
            # Build a CutPlot through the middle of `image` along `axis`.
            # An "x" cut is positioned via the column count (shape[1]), a "y"
            # cut via the row count (shape[0]); the position is always
            # resolved on the same axis as the cut itself.
            cut = CutPlot(axis, title=title, x_label=x_label)
            cut.y_label = flux_unit
            n_pixels = image.data.shape[1 if axis == "x" else 0]
            cut.add_data(
                image,
                cut_pos=image.get_data_coord(n_pixels // 2, axis),
                label="master",
                color="red",
            )
            return cut

        # ----- Empty text plot -----
        text = TextPlot(columns=2, v_space=0.3)
        panel.assign_plot(text, 0, 0, xext=2)

        # ----- 2D image  -----
        if master_im_clipping is not None and master_im_clipping_kwargs is not None:
            image_scaling = {
                "v_clip": master_im_clipping,
                "v_clip_kwargs": master_im_clipping_kwargs,
            }
        else:
            image_scaling = {}

        master_plot = ImagePlot(
            data, title=f"{master_procatg}", interpolation="none", **image_scaling
        )
        if master_im_limits is not None:
            master_plot.set_vlim(*master_im_limits)
        panel.assign_plot(master_plot, 0, 1)

        # ----- 2D image cutout -------
        if cutout_clipping is not None and cutout_clipping_kwargs is not None:
            cutout_image_scaling = {
                "v_clip": cutout_clipping,
                "v_clip_kwargs": cutout_clipping_kwargs,
            }
        else:
            cutout_image_scaling = {}
        cutout_plot = CentralImagePlot(
            master_plot,
            extent=self.center_size,
            # Single f-string: braces inside the category value can no longer
            # break a trailing str.format call.
            title=f"{master_procatg} \ncentral {self.center_size}-px",
            interpolation="none",
            **cutout_image_scaling,
        )
        if cutout_limits is not None:
            cutout_plot.set_vlim(*cutout_limits)
        panel.assign_plot(cutout_plot, 0, 2)

        # ----- 2D image: cut through the central column -----
        panel.assign_plot(
            _central_cut("x", master_plot, "Central column", "y"), 1, 1
        )
        # ----- 2D image: cut through the central row -----
        panel.assign_plot(_central_cut("y", master_plot, "Central row", "x"), 2, 1)
        # ----- 2D image cutout: cut through the central column -----
        panel.assign_plot(
            _central_cut(
                "x", cutout_plot, "Central region - central column", "y"
            ),
            1,
            2,
        )
        # ----- 2D image cutout: cut through the central row -----
        panel.assign_plot(
            _central_cut("y", cutout_plot, "Central region - central row", "x"),
            2,
            2,
        )

        # ----- Image histogram -----
        if hist_clipping is not None:
            hist_scaling = {
                "v_clip": hist_clipping,
                # Pass a copy so the histogram never holds the shared default.
                "v_clip_kwargs": (
                    None
                    if hist_clipping_kwargs is None
                    else dict(hist_clipping_kwargs)
                ),
            }
        else:
            hist_scaling = {}

        histogram = HistogramPlot(
            master_data=data,
            raw_data=None,
            bins=None,
            title="Counts histogram",
            **hist_scaling,
        )
        panel.assign_plot(histogram, 3, 1)

        # ----- Collapse plot -----
        collapse_x = CollapsePlot(
            data,
            "x",
            title=f"{master_procatg} - collapse",
            color="black",
            x_label="pixels",
        )
        collapse_x.y_label = flux_unit
        collapse_x.add_data(data, "y", color="red")
        panel.assign_plot(collapse_x, 3, 2)

        return panel

    def multipanel_plot(self, data, **kwargs):
        """Build one cell of the multi-extension report.

        Parameters
        ----------
        - data: Image data to display, or None when the extension is empty.
        - kwargs: Optional "title" (default "") and "cbar_kwargs"
        (default {"pad": 0.1}) entries.

        Returns
        -------
        An ImagePlot of the data with NaNs replaced through np.nan_to_num, or
        a LinePlot placeholder titled "<title> (NO DATA)" when data is None.
        """
        title = kwargs.get("title", "")

        if data is None:
            # if no image data, setup a placeholder plot
            # TODO: implement empty plot
            return LinePlot(
                legend=False,
                y_label="y",
                x_label="x",
                title=title + " (NO DATA)",
            )

        image = ImagePlot(
            title=title,
            cbar_kwargs=kwargs.get("cbar_kwargs", {"pad": 0.1}),
            interpolation="none",
        )
        image.add_data(np.nan_to_num(data))
        return image

    def multi_panel(self, panel, hdu_info, plot_kwargs):
        """This function generates a panel containing identical multiple plots.

        Parameters
        ----------
        panel: (Panel) Panel that contains all the report plots.
        hdu_info: (list) List containing the HDUL position index, SOF keyword,
                HDU index or extension and x, y position within the plot
                to select from self.hdul
        plot_kwargs: (list) List of arguments to be passed to multipanel_plot for
                     each individual panel.
        """
        for panel_data, panel_kwargs in zip(hdu_info, plot_kwargs):
            idx, hdul_name, hdu_idx, xpos, ypos = panel_data
            plot = self.multipanel_plot(
                self.hdus[idx][hdul_name][hdu_idx].data, **panel_kwargs
            )
            panel.assign_plot(plot, xpos, ypos)
        return panel

    def generate_single_ext_panels(
        self, default_category="FLAT", default_extension="DATA"
    ):
        panels = {}
        for hdul_dict in self.hdus:
            data_idx = 0
            for key, val in hdul_dict.items():
                if default_category in key:
                    master_flat_hdul = val
                else:
                    continue
                for i in range(len(master_flat_hdul)):
                    if (
                        master_flat_hdul[i].data is None
                        or default_extension not in master_flat_hdul[i].name
                    ):
                        continue
                    channel = master_flat_hdul[i].name
                    panel = self.single_panel(master_flat_hdul[i])
                    setup = str(
                        master_flat_hdul[0].header.get("HIERARCH ESO INS MODE", "N/A")
                    )
                    panels[panel] = {
                        "report_name": "{}_single_{}_{}_{}".format(
                            self.name, channel, setup.lower(), data_idx
                        ),
                        "report_description": "Flat single panel",
                        "report_tags": [],
                    }
                    data_idx += 1
        return panels

    def generate_multi_ext_panels(
        self, ncols=6, nrows=5, default_category="FLAT", default_extension="DATA"
    ):
        """Create one multi-extension overview panel per entry of self.hdus.

        Description
        -----------
        Each panel has a first row with header metadata (instrument,
        PRO CATG, first raw file name and INS MODE, taken from the first HDU
        list whose key contains `default_category`) followed by a grid with
        one image per extension that has data and whose name contains
        `default_extension`.

        Parameters
        ----------
        - ncols: (int, default=6) Number of columns of the panel grid.
        - nrows: (int, default=5) Kept for backward compatibility only; the
        number of rows is always recomputed from the number of flats and `ncols`.
        - default_category: (str, default="FLAT") Substring that a key of the
        HDU dictionary must contain to be used.
        - default_extension: (str, default="DATA") Substring that an extension
        name must contain to be plotted.

        Returns
        -------
        Dictionary mapping each panel to its metadata ("report_name",
        "report_description", "report_tags", "input_files").
        """
        panels = {}

        for idx, hdu_dict in enumerate(self.hdus):
            n_flats = len(hdu_dict)
            if n_flats == 1:
                # Multi-extension file: count the extensions to be displayed.
                n_flats = sum(
                    default_extension in extension.name
                    for hdul in hdu_dict.values()
                    for extension in hdul
                )
            # Rows needed for the image grid plus one leading metadata row.
            # NOTE(review): when the dictionary holds several files, n_flats is
            # the number of files, not of plotted extensions -- confirm that
            # the grid is large enough for such inputs.
            nrows = int(n_flats / ncols) + (n_flats % ncols > 0) + 1
            h = [nrows] * nrows
            h[0] = 1
            p = Panel(ncols, nrows, height_ratios=h)
            metadata_plot = False
            # Fallback so the report name can always be built, even when no
            # key of this dictionary matches default_category.
            setup = "N/A"
            vspace = 0.3
            # Data plot index
            data_plot_idx = 0
            hdu_info = []
            hdu_plot_args = []
            input_files = []

            for key, master_flat_hdul in hdu_dict.items():
                if default_category not in key:
                    continue

                if not metadata_plot:
                    primary_header = master_flat_hdul[0].header
                    setup = str(primary_header.get("HIERARCH ESO INS MODE", "N/A"))

                    t1 = TextPlot(columns=1, v_space=vspace)
                    col1 = (
                        str(primary_header.get("INSTRUME")),
                        "PRO CATG: "
                        + str(primary_header.get("HIERARCH ESO PRO CATG")),
                        "RAW1 NAME: "
                        + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                    )
                    t1.add_data(col1)
                    p.assign_plot(t1, 0, 0, xext=2)

                    t2 = TextPlot(columns=1, v_space=vspace, xext=1)
                    col2 = "INS.MODE: " + setup
                    t2.add_data(col2)
                    p.assign_plot(t2, 2, 0, xext=1)

                    # The metadata row is built once per panel; without this
                    # the text plots were re-created for every matching file.
                    metadata_plot = True

                # Loop over all extensions
                for extension in master_flat_hdul:
                    if (
                        extension.data is None
                        or default_extension not in extension.name
                    ):
                        continue
                    channel = extension.name
                    py, px = np.unravel_index(data_plot_idx, (nrows, ncols))
                    # skip the extra first row containing the metadata
                    py += 1
                    data_plot_idx += 1

                    hdu_info.append((idx, key, channel, px, py))
                    # Each image is titled with its extension name
                    hdu_plot_args.append(dict(title=channel))

                input_files.append(master_flat_hdul.filename())

            p = self.multi_panel(panel=p, hdu_info=hdu_info, plot_kwargs=hdu_plot_args)

            panels[p] = {
                "report_name": f"{self.name}_multi_{setup.lower()}",
                "report_description": f"{self.name} multi panel",
                "report_tags": [],
                "input_files": input_files,
            }
        return panels

    def generate_panels(self, **kwargs):
        panels = {
            **self.generate_single_ext_panels(),
            **self.generate_multi_ext_panels(),
        }
        return panels
