from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.report import AdariReportBase
import numpy as np


class MasterImgScienceReport(AdariReportBase):
    """Master report class for Science Imaging reports.

    This is a template: child reports must implement ``parse_sof`` and provide
    the reduced image in ``self.image`` and its confidence map in ``self.conf``
    before calling :meth:`generate_single_panel`.
    """

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        raise NotImplementedError(
            "MasterImgScienceReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    @staticmethod
    def _make_cut_plot(image_plot, axis, cut_pos, label):
        """Build a CutPlot of ``image_plot`` along ``axis`` at ``cut_pos``.

        A cut along "y" is drawn against the x coordinate and vice versa,
        which is why the x-axis label is the opposite axis name.
        """
        cut = CutPlot(
            axis,
            title="Cut @ " + axis + "=" + str(cut_pos),
            x_label="x" if axis == "y" else "y",
            y_label=label,
            legend=False,
        )
        cut.add_data(image_plot, cut_pos=cut_pos)
        return cut

    @staticmethod
    def _histogram_peak(values, bins, hist_range=None):
        """Histogram ``values`` and locate the most populated bin.

        Parameters
        ----------
        values : array_like
            Data to histogram (flattened by ``np.histogram``).
        bins : int
            Number of equal-width bins.
        hist_range : tuple of float, optional
            ``(lower, upper)`` range forwarded to ``np.histogram``; ``None``
            lets numpy use the data min/max.

        Returns
        -------
        bin_edges : numpy.ndarray
            The ``bins + 1`` bin edges.
        peak : int
            Index of the bin with the highest count (first one on ties).
        mode : float
            Centre of that bin.
        """
        hist, bin_edges = np.histogram(values, bins=bins, range=hist_range)
        peak = int(np.argmax(hist))
        mode = 0.5 * (bin_edges[peak] + bin_edges[peak + 1])
        return bin_edges, peak, mode

    def generate_single_panel(
        self,
        img_kwargs=None,
        conf_kwargs=None,
        combined=False,
        title="Image",
        label="Flux",
    ):
        """
        Create a report for a given HDU.
        It requires image product and confidence map provided in self.image and
        self.conf, respectively.

        Description
        -----------
        This method constructs a report from an input HDU consisting of:
        - ImagePlot of the reduced field.
        - CutPlot along x direction of the 2D image.
        - CutPlot(s) along y direction of the 2D image.
        - ImagePlot of the confidence map.
        - Histograms of the 2D reduced image.

        Parameters
        ----------
        img_kwargs : dict, optional, default=None
            Extra keyword arguments (e.g. clipping options) forwarded to the
            ImagePlot of the reduced image. ``None`` means no extra arguments.
        conf_kwargs : dict, optional, default=None
            Extra keyword arguments forwarded to the ImagePlot of the
            confidence map. ``None`` means no extra arguments.
        combined : bool, optional, default=False
            If true, the products come from the combination of the reduced
            images: cuts are placed at 1/4 (x and y) and 3/4 (x) of the image
            instead of through its centre.
        title : str, optional, default="Image"
            The title for main image plot.
        label : str, optional, default="Flux"
            The flux label.

        Returns
        -------
        - panel: Panel
            A panel containing all the images, and histogram plots.

        """
        # Avoid shared mutable defaults; callers may still pass their own dict.
        img_kwargs = {} if img_kwargs is None else img_kwargs
        conf_kwargs = {} if conf_kwargs is None else conf_kwargs

        n_bins = 50  # bins used by every histogram in the panel
        zoom_half_width = 5  # zoom window half-width, in bins, around the peak

        p = Panel(
            6, 6, height_ratios=[1, 3, 3, 3, 3, 1], y_stretch=0.67, right_subplot=1.5
        )

        image = self.image
        conf = self.conf

        # 1. Main image plot (NaN/inf replaced so the image and its cuts render)
        full_image = ImagePlot(title=title, **img_kwargs)
        full_image.add_data(np.nan_to_num(image))
        p.assign_plot(full_image, 0, 1, xext=4, yext=4)

        # 2. Cut plots
        ny = full_image.data.shape[0]
        nx = full_image.data.shape[1]
        divisor = 4 if combined else 2
        cut_pos_y = full_image.get_data_coord(ny // divisor, "y")
        cut_pos_x = full_image.get_data_coord(nx // divisor, "x")
        p.assign_plot(self._make_cut_plot(full_image, "y", cut_pos_y, label), 4, 1)
        p.assign_plot(self._make_cut_plot(full_image, "x", cut_pos_x, label), 4, 2)

        if combined:
            # Convert the 3/4 pixel index, as for the other cuts; scaling an
            # already-converted coordinate by 3 is only correct when the
            # pixel->data mapping is linear with no offset.
            cut_pos_x2 = full_image.get_data_coord(3 * nx // 4, "x")
            p.assign_plot(
                self._make_cut_plot(full_image, "x", cut_pos_x2, label), 5, 2
            )

        # 3. Confidence map
        conf_image = ImagePlot(title="Confidence", **conf_kwargs)
        conf_image.add_data(np.nan_to_num(conf))
        p.assign_plot(conf_image, 5, 1)

        # 4. Product histograms
        comb_hist = CombinedPlot(
            title="Product histograms",
        )
        # np.histogram cannot auto-range data containing NaN/inf, so the zoom
        # window and mode are derived from the finite pixels only.
        pixels = np.asarray(image)
        finite = pixels[np.isfinite(pixels)]

        # Zoom window: +/- zoom_half_width bins around the most populated bin
        # of the full-range histogram, clamped to the available bin edges.
        bin_edges, peak, _ = self._histogram_peak(finite, n_bins)
        zoom0 = bin_edges[max(0, peak - zoom_half_width)]
        zoom1 = bin_edges[min(n_bins, peak + zoom_half_width)]
        data_hist = HistogramPlot(
            title="Product histograms",
            v_min=zoom0,
            v_max=zoom1,
            bins=n_bins,
            x_label=label,
        )
        data_hist.add_data(image, label="data counts")

        # mode: centre of the most populated bin, recomputed at the finer
        # resolution of the zoomed range
        _, _, mode = self._histogram_peak(finite, n_bins, (zoom0, zoom1))

        stat_line0 = ScatterPlot(title="Line zero")
        stat_line0.add_data(
            label="_zero",
            d=[[0], [np.nan]],
            vline=True,
            color="black",
        )
        stat_line = ScatterPlot(title="Line mode")
        stat_line.add_data(
            label="_mode",
            d=[[mode], [np.nan]],
            vline=True,
            color="black",
            linestyle="dotted",
        )
        data_hist.legend = False
        stat_line.legend = False
        stat_line0.legend = False
        comb_hist.add_data(data_hist)
        comb_hist.add_data(stat_line0)
        comb_hist.add_data(stat_line)
        comb_hist.y_scale = "log"
        comb_hist.legend = False

        # Full-range histogram (min/max clipping) below the zoomed one.
        scaling = {"v_clip": "minmax"}
        data_fullhist = HistogramPlot(
            title="",
            x_label=label,
            bins=n_bins,
            legend=False,
            **scaling,
        )
        data_fullhist.add_data(image, label="data counts")
        data_fullhist.y_scale = "log"
        p.assign_plot(comb_hist, 4, 3, xext=2)
        p.assign_plot(data_fullhist, 4, 4, xext=2)

        return p
