# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import LinePlot
import numpy as np
import logging

# Module-level logger, named after this module so its output can be filtered.
logger = logging.getLogger(__name__)


class TemplatePanelMixin(object):
    """Mixin that stores the keyword arguments used to build a ``Panel``.

    The layout is read from the ``panel_kwargs`` keyword argument given to
    the constructor; other positional/keyword arguments are ignored here.
    """

    def __init__(self, *args, **kwargs):
        # Look the layout up by its keyword name; an absent key yields an
        # empty layout (``Panel`` is then built with its own defaults).
        # NOTE(review): "panel_kwargs" is the assumed keyword name -- confirm
        # against the classes that use this mixin.
        self.panel_kwargs = kwargs.get("panel_kwargs", {})

    @property
    def panel_kwargs(self):
        """Panel keyword arguments (a dict unpacked into ``Panel``)."""
        return self._panel_kwargs

    @panel_kwargs.setter
    def panel_kwargs(self, panel_dict):
        """Validate and store the panel keyword arguments.

        Raises
        ------
        TypeError
            If ``panel_dict`` is not a dictionary.
        """
        if not isinstance(panel_dict, dict):
            raise TypeError("Input panel kwargs must be a dictionary")
        # Keep a shallow copy so in-place edits made through this attribute
        # never leak back into (possibly shared) caller-owned dicts.
        self._panel_kwargs = dict(panel_dict)

    def build_panel(self):
        """Create an empty ``Panel`` from :attr:`panel_kwargs`.

        Returns
        -------
        Panel
            A new panel built as ``Panel(**self.panel_kwargs)``.
        """
        return Panel(**self.panel_kwargs)


class ImageGridPanelMixin(TemplatePanelMixin):
    """Panel mixin that lays out a grid of image plots.

    Each image occupies ``img_ext`` panel cells and may be followed, on the
    same row, by a zoomed-in view of its centre. Optionally, row index 0 is
    left free (``textrow``) -- presumably for text plots added by the
    caller; this class does not fill it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def fill_panel(
        self,
        hdul_list,
        extensions,
        textrow=True,
        zoom_in=False,
        zoom_in_extent=100,
        img_kw_list=None,
        img_ext=(1, 2),
    ):
        """Build a panel and fill it with one plot per input image.

        The data grid size is read from ``self.panel_kwargs["y"]`` (rows) and
        ``self.panel_kwargs["x"]`` (columns); images are placed in row-major
        order. ``self.panel_kwargs`` is left unchanged, so the method can be
        called repeatedly.

        Parameters
        ----------
        hdul_list : sequence
            One HDU list per image (presumably astropy ``HDUList`` objects).
        extensions : sequence
            For each image, the key selecting its HDU:
            ``hdul_list[i][extensions[i]]``.
        textrow : bool
            If true, shift all plots down by one row to leave row 0 free.
        zoom_in : bool
            If true, place a ``CentralImagePlot`` to the right of each image.
        zoom_in_extent : int
            ``extent`` passed to ``CentralImagePlot``; also shown in its title.
        img_kw_list : sequence of dict, optional
            Keyword arguments for each ``ImagePlot``. The dicts are not
            modified. Defaults to an independent empty dict per image.
        img_ext : tuple of int
            ``(width, height)`` in panel cells occupied by each plot.

        Returns
        -------
        Panel
            The filled panel.

        Raises
        ------
        ValueError
            If ``hdul_list``, ``extensions`` and ``img_kw_list`` differ in
            length, or there are more images than grid cells.
        """
        nrow, ncol = self.panel_kwargs["y"], self.panel_kwargs["x"]
        n_images = len(hdul_list)
        if img_kw_list is None:
            # One independent dict per image (``[{}] * n`` would alias them).
            img_kw_list = [{} for _ in range(n_images)]

        # Validate before building anything, so failures leave no side effects.
        if n_images != len(extensions) or n_images != len(img_kw_list):
            raise ValueError(
                "The length of the input HDU list must be the same as the extensions and image keywords lists"
            )
        if n_images > nrow * ncol:
            raise ValueError(
                f"Cannot place {n_images} images in a {nrow}x{ncol} panel grid"
            )

        # Each grid column holds the image plus, optionally, its zoom plot.
        col_stride = (2 if zoom_in else 1) * img_ext[0]
        row_offset = 1 if textrow else 0
        panel = self._build_sized_panel(
            x=ncol * col_stride,
            y=(nrow + row_offset) * img_ext[1],
        )

        for idx, (hdul, ext, img_kw) in enumerate(
            zip(hdul_list, extensions, img_kw_list)
        ):
            row, col = np.unravel_index(idx, (nrow, ncol))
            px = col * col_stride
            py = row * img_ext[1] + row_offset
            logger.debug("Placing image %s at px=%s, py=%s", idx, px, py)

            main_plot, zoom_plot = self._make_image_plots(
                hdul[ext], img_kw, zoom_in, zoom_in_extent
            )
            panel.assign_plot(main_plot, px, py, xext=img_ext[0], yext=img_ext[1])
            if zoom_plot is not None:
                panel.assign_plot(
                    zoom_plot, px + img_ext[0], py, xext=img_ext[0], yext=img_ext[1]
                )

        return panel

    def _build_sized_panel(self, x, y):
        """Build a panel of ``x`` by ``y`` cells via :meth:`build_panel`.

        ``self.panel_kwargs`` is swapped only for the duration of the call
        (so subclass overrides of ``build_panel`` still apply) and restored
        afterwards, even if building fails.
        """
        base_kwargs = self.panel_kwargs
        self.panel_kwargs = {**base_kwargs, "x": x, "y": y}
        try:
            return self.build_panel()
        finally:
            self.panel_kwargs = base_kwargs

    @staticmethod
    def _make_image_plots(hdu, img_kw, zoom_in, zoom_in_extent):
        """Return ``(main_plot, zoom_plot)`` for one HDU.

        ``zoom_plot`` is ``None`` unless ``zoom_in`` is true. An HDU whose
        ``data`` is ``None`` yields empty placeholder ``LinePlot`` objects.
        ``img_kw`` is never modified.
        """
        zoom_title = f"Central ({zoom_in_extent}x{zoom_in_extent})"

        if hdu.data is None:
            main_plot = LinePlot(
                legend=False,
                y_label="y",
                x_label="x",
                title="(NO DATA)",
            )
            zoom_plot = None
            if zoom_in:
                zoom_plot = LinePlot(
                    legend=False,
                    y_label="y",
                    x_label="x",
                    title=f"(NO DATA) {zoom_title}",
                )
            return main_plot, zoom_plot

        main_plot = ImagePlot(**img_kw)
        main_plot.add_data(hdu.data)
        if not zoom_in:
            return main_plot, None

        # Copy so the caller's keyword dict keeps its original title.
        zoom_kw = dict(img_kw)
        base_title = zoom_kw.get("title")
        zoom_kw["title"] = f"{base_title} {zoom_title}" if base_title else zoom_title
        zoom_plot = CentralImagePlot(main_plot, **zoom_kw, extent=zoom_in_extent)
        return main_plot, zoom_plot
