# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import LinePlot

from adari_core.report import AdariReportBase

import numpy as np


class MasterSkyFlatReport(AdariReportBase):
    """Template report that lays out HDU images on a grid panel.

    Description
    -----------
    This class is a template only: ``parse_sof`` is not implemented here and
    must be defined by the child report.
    """

    # Base keyword arguments handed to Panel. "x" and "y" are read by
    # img_grid_panel as the number of image columns and rows of the mosaic.
    panel_kwargs = {"x": 1, "y": 1, "height_ratios": [1, 5]}

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        """Not implemented in the template; the child report must define it."""
        raise NotImplementedError(
            "MasterSkyFlatReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def image_plot(
        self,
        hdu,
        img_kwargs=None,
        zoom_img_kwargs=None,
        zoom_in=False,
        zoom_in_extent=100,
    ):
        """Create a plot from a given HDU image.

        Description
        -----------
        Given an image HDU, this function will create an ImagePlot and,
        optionally, a zoom-in plot of the central region.

        Parameters
        ----------
        - hdu:
            Object exposing a ``data`` attribute. When ``hdu.data`` is None,
            placeholder LinePlot objects titled "(NO DATA)" are returned
            instead of image plots.
        - img_kwargs: dict, default=None
            Keyword arguments passed to ImagePlot. They are also the base of
            the zoom-in plot keywords. An optional "zoom_title" entry is used
            as the zoom-in plot title; otherwise "Central (NxN)" is appended
            to "title" (or used alone when there is no "title"). The input
            dictionary is never modified.
        - zoom_img_kwargs: dict, default=None
            Keyword arguments that override ``img_kwargs`` for the zoom-in
            plot only.
        - zoom_in: bool, default=False
            If true, a zoom-in plot is created as well.
        - zoom_in_extent: int, default=100
            Value passed as ``extent`` to CentralImagePlot, also used in the
            zoom-in plot title.

        Returns
        -------
        - img_plot: ImagePlot or LinePlot
            The image plot, or the placeholder when there is no data.
        - zoom_img_plot: CentralImagePlot, LinePlot or None
            The zoom-in plot, or None when ``zoom_in`` is false.
        """
        # None replaces the former mutable ``{}`` defaults, which were shared
        # between calls.
        img_kwargs = {} if img_kwargs is None else img_kwargs
        zoom_img_kwargs = {} if zoom_img_kwargs is None else zoom_img_kwargs
        central_label = f"Central ({zoom_in_extent}x{zoom_in_extent})"

        if hdu.data is None:
            # No data available: return empty placeholders.
            img_plot = LinePlot(
                legend=False,
                y_label="y",
                x_label="x",
                title="(NO DATA)",
            )
            zoom_img_plot = None
            if zoom_in:
                zoom_img_plot = LinePlot(
                    legend=False,
                    y_label="y",
                    x_label="x",
                    title=f"(NO DATA) {central_label}",
                )
            return img_plot, zoom_img_plot

        img_plot = ImagePlot(**img_kwargs)
        img_plot.add_data(hdu.data)
        if not zoom_in:
            return img_plot, None

        # Build the zoom-in keywords on a copy so the caller's dictionary
        # (which may be reused for other images) is left untouched.
        zoom_kwargs = dict(img_kwargs)
        if "zoom_title" in zoom_kwargs:
            zoom_kwargs["title"] = zoom_kwargs["zoom_title"]
        else:
            zoom_kwargs["title"] = zoom_kwargs.get("title", "") + central_label
        zoom_kwargs.update(zoom_img_kwargs)
        zoom_img_plot = CentralImagePlot(
            img_plot, **zoom_kwargs, extent=zoom_in_extent
        )
        return img_plot, zoom_img_plot

    def img_grid_panel(
        self,
        hdu_list,
        textrow=True,
        zoom_in=False,
        zoom_in_extent=100,
        img_kw_list=None,
        img_ext=(1, 2),
    ):
        """Create a panel of images from a given list of HDU.

        Description
        -----------
        Given a list of HDU, this function will create a panel that
        contains a mosaic of all images (including zoom-in plots if applicable).

        Parameters
        ----------
        - hdu_list: list
            List of HDU images that will be used to compute the plots. If the HDU
            does not contain data, an empty placeholder LinePlot will be created
            instead. The order of HDU in the list will be mapped to a rectangular
            grid row by row.
        - textrow: bool, default=True
            If true, the panel will include an extra row at the top of the panel to
            include metadata information.
        - zoom_in: bool, default=False
            If true, an extra plot for every HDU image will be created on the right
            of every image in the final panel. This will be considered when accounting
            for the number of rows and columns required.
        - zoom_in_extent: int, default=100
            Size of the zoom-in cutout image in pixels.
        - img_kw_list: list of dict, default=None
            A list of dictionaries (one per HDU) that will be used as parameters
            for the image (and zoom image) plots. If None, empty dictionaries
            are used.
        - img_ext: 2D tuple, default=(1,2)
            Plot image extension in the Panel grid. This might be useful to enlarge the
            size of the images. The number of rows and columns is increased accordingly
            by the function.

        Returns
        -------
        - panel: Panel
            A panel containing all the HDU images plots and zoom-in plots.

        Raises
        ------
        - ValueError
            If ``img_kw_list`` and ``hdu_list`` have different lengths.

        Notes
        -----
        This function will make use of the report attribute panel_kwargs, which
        corresponds to a dictionary containing the basic information to build the
        panel. The attribute is read but never modified, so repeated calls
        always start from the same base layout.

        """
        # Get the number of rows and columns containing data in the panel
        nrow, ncol = self.panel_kwargs["y"], self.panel_kwargs["x"]

        # Work on a copy: panel_kwargs is a class-level dictionary, and changing
        # it in place would enlarge the grid on every subsequent call.
        panel_kwargs = dict(self.panel_kwargs)

        # Add extra plots to include text and/or zoom-in plots
        if textrow:
            panel_kwargs["y"] += 1
            # TODO: It would be good if the height ratios were corrected
            # automatically here.
        if zoom_in:
            panel_kwargs["x"] *= 2

        # Account for image extensions
        # NOTE(review): the text row is scaled by img_ext[1] as well, so the
        # grid may hold more rows than the plots occupy - confirm intended.
        panel_kwargs["x"] *= img_ext[0]
        panel_kwargs["y"] *= img_ext[1]

        # Build the empty panel
        panel = Panel(**panel_kwargs)
        # Get the number of images to plot
        n_images = len(hdu_list)
        # Check image metadata; one independent dictionary per image
        if img_kw_list is None:
            img_kw_list = [{} for _ in range(n_images)]

        if n_images != len(img_kw_list):
            raise ValueError(
                "The length of the input HDU list must be the same as the extensions "
                "and image keywords lists"
            )
        # Fill the panel
        for idx, hdu in enumerate(hdu_list):
            # Position of the image in the (nrow, ncol) mosaic, row by row
            py, px = np.unravel_index(idx, (nrow, ncol))
            # Convert mosaic coordinates into panel grid coordinates: each
            # image spans img_ext cells, zoom-in plots double the columns and
            # the text row shifts everything down by one.
            px = px * (1 + zoom_in) * img_ext[0]
            py = py * img_ext[1] + int(textrow)
            img, zoom_img = self.image_plot(
                hdu,
                img_kwargs=img_kw_list[idx],
                zoom_in=zoom_in,
                zoom_in_extent=zoom_in_extent,
            )
            panel.assign_plot(img, px, py, xext=img_ext[0], yext=img_ext[1])
            if zoom_in:
                # The zoom-in plot sits immediately to the right of its image
                panel.assign_plot(
                    zoom_img, px + img_ext[0], py, xext=img_ext[0], yext=img_ext[1]
                )
        return panel

    def generate_panels(self, **kwargs):
        """Dummy version of the panel report.

        Returns
        -------
        - panels: dict
            Dictionary whose keys are the generated panels and whose values
            are empty dictionaries.

        Notes
        -----
        At the instrument specific level, the user must specify which are the
        HDUL to be used from the HDUs dictionary (typically the SOF tag) and
        also the extensions within the HDUL that are required.

        One panel is created per element of ``self.hdus``, using the default
        arguments of ``img_grid_panel``.
        """
        panels = {}
        for hdul_dict in self.hdus:
            # Prepare the list of data
            list_of_hdu = []
            # Keys are taken from the current element, not from self.hdus
            for key in hdul_dict.keys():
                hdul = hdul_dict[key]
                # NOTE(review): assumes iterating hdul yields values usable as
                # indices into hdul - confirm against the child reports.
                for ext in hdul:
                    list_of_hdu.append(hdul[ext])
            # Create the panel
            panel = self.img_grid_panel(list_of_hdu)
            panels[panel] = {}

        return panels
