# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import ScatterPlot
from adari_core.report import AdariReportBase

import numpy as np

class MasterWaveReport(AdariReportBase):
    """Template report providing plot builders shared by its child reports.

    This class is not usable on its own: child reports must implement
    :meth:`parse_sof`. The ``generate_*`` helpers return plot objects that
    the child report can place on its panels.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Passed as ``extent`` to CentralImagePlot.add_data for the zoomed view.
        self.central_region_size = 200
        # Divisor applied to the image dimensions to pick the cut pixel
        # (2 -> the cut goes through the middle row/column).
        self.central_cut = 2

    def parse_sof(self):
        raise NotImplementedError(
            "MasterWaveReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def generate_panels(self, lplot_msize=0.05, ext=0, **kwargs):
        """Return an empty panel mapping.

        Kept for backwards compatibility; all arguments are accepted but
        unused. Child reports build their own panels.
        """
        panels = {}
        # Backwards compatibility
        return panels

    @staticmethod
    def _broadcast_to_list(value, n):
        """Repeat a single string (or ``None``) ``n`` times; return anything else as-is.

        Lets one label/colour be shared by every dataset when a list of
        datasets is plotted. Without this, ``zip`` would iterate over the
        characters of the string (``"black"`` -> ``"b", "l", ...``) and
        silently drop every dataset beyond the string's length.
        """
        if value is None or isinstance(value, str):
            return [value] * n
        return value

    @staticmethod
    def _build_cut_plot(image, axis, pixel, title_prefix, label, color):
        """Build a CutPlot through ``image`` along ``axis`` at pixel index ``pixel``.

        The pixel index is converted with ``image.get_data_coord`` and the
        resulting coordinate is shown in the title as
        ``"<title_prefix> @ <AXIS>=<coord>"``.
        """
        cutpos = image.get_data_coord(pixel, axis)
        cut = CutPlot(
            axis,
            title="{} @ {}={}".format(title_prefix, axis.upper(), cutpos),
            # "y" cuts are plotted against x, "x" cuts against y.
            x_label="x" if axis == "y" else "y",
            y_label="ADU",
            legend=False,
        )
        cut.add_data(image, cutpos, label=label, color=color)
        return cut

    def generate_raw_plots(self, rawdata, title_cut="Cross-dispersion raw", label="raw", color="black", **kwargs):
        """
        Creates plots from raw data (full and zoom):
        - raw image
        - raw cuts X
        - raw cuts Y
        - raw hist

        Parameters
        ----------
        rawdata : 2D iterable of numeric values
            The raw array
        title_cut : str, default="Cross-dispersion raw"
            Prefix of the cut plot titles; " cut" / " cut centre" and the
            cut position are appended to it.
        label : str, default="raw"
            The label used for the data
        color : str, default="black"
            The line colour
        **kwargs
            Forwarded to the full-frame ``ImagePlot`` constructor only.

        Returns
        -------
        list
            ``[img_raw, img_cent, cut_raw_y, cut_raw_y_cent, cut_raw_x,
            cut_raw_x_cent, hist_raw]``: two image plots (full frame and
            central region), four CutPlots and one HistogramPlot that can be
            added to a panel.
        """
        img_raw = ImagePlot(title="Raw input", **kwargs)
        img_raw.add_data(rawdata)

        img_cent = CentralImagePlot(title="Raw input - central region")
        img_cent.add_data(img_raw, extent=self.central_region_size)

        # Separate the caller's prefix from the appended text with exactly one
        # space (previously "Cross-dispersion raw" + "cut" -> "...rawcut").
        title_cut = title_cut.rstrip()

        cut_raw_y = self._build_cut_plot(
            img_raw,
            "y",
            img_raw.data.shape[0] // self.central_cut,
            title_cut + " cut",
            label,
            color,
        )
        cut_raw_y_cent = self._build_cut_plot(
            img_cent,
            "y",
            np.floor(img_cent.data.shape[0] // self.central_cut),
            title_cut + " cut centre",
            label,
            color,
        )
        cut_raw_x = self._build_cut_plot(
            img_raw,
            "x",
            img_raw.data.shape[1] // self.central_cut,
            title_cut + " cut",
            label,
            color,
        )
        cut_raw_x_cent = self._build_cut_plot(
            img_cent,
            "x",
            np.floor(img_cent.data.shape[1] // self.central_cut),
            title_cut + " cut centre",
            label,
            color,
        )

        hist_raw = HistogramPlot(
            title="Raw file", bins=50, v_min=-5000, v_max=70000, color=color, legend=False
        )
        hist_raw.add_raw_data(img_raw.data)

        return [img_raw, img_cent, cut_raw_y, cut_raw_y_cent, cut_raw_x, cut_raw_x_cent, hist_raw]

    def generate_line_order_plot(self, order, disp, label, color="black", lplot_msize=0.05, legend=False, axis="X"):
        """
        Creates a plot with line positions.

        Parameters
        ----------
        order : 1D iterable of numeric values or a list of 1D iterables
            The array(s) with line orders. Passing a ``list`` means "several
            datasets"; ``disp`` must then be a matching list.
        disp : 1D iterable of numeric values or a list of 1D iterables
            The array(s) with dispersion
        label : str or a list of strs
            The label(s) used for the data. A single str is reused for every
            dataset when ``order`` is a list.
        color : str or a list of strs, default="black"
            The point colour(s). A single str is reused for every dataset
            when ``order`` is a list.
        lplot_msize : float, default=0.05
            The point size
        legend : bool, default=False
            If true, the legend is plotted
        axis : str, default="X"
            "X" puts the dispersion on the x axis; any other value puts the
            order number on the x axis.

        Returns
        -------
        line_pos_plot : ScatterPlot
            Plot that can be added to a panel
        """
        line_pos_plot = ScatterPlot(markersize=lplot_msize, legend=legend)
        if isinstance(order, list):
            n_sets = len(order)
            datasets = zip(
                order,
                disp,
                self._broadcast_to_list(label, n_sets),
                self._broadcast_to_list(color, n_sets),
            )
        else:
            datasets = [(order, disp, label, color)]

        for iorder, idisp, ilabel, icolor in datasets:
            points = (idisp, iorder) if axis == "X" else (iorder, idisp)
            line_pos_plot.add_data(points, label=ilabel, color=icolor)

        if axis == "X":
            line_pos_plot.x_label = "Cross-dispersion [pix]"
            line_pos_plot.y_label = "Order number"
        else:
            line_pos_plot.x_label = "Order number"
            line_pos_plot.y_label = "Cross-dispersion [pix]"
        return line_pos_plot

    def generate_resolution_plot(self, power, disp, label, color="black", rplot_msize=3., legend=False):
        """
        Creates a plot with resolution.

        Parameters
        ----------
        power : 1D iterable of numeric values or a list of 1D iterables
            The array(s) with resolutions. Passing a ``list`` means "several
            datasets"; ``disp`` must then be a matching list.
        disp : 1D iterable of numeric values or a list of 1D iterables
            The array(s) with dispersion
        label : str or a list of strs
            The label(s) used for the data. A single str is reused for every
            dataset when ``power`` is a list.
        color : str or a list of strs, default="black"
            The point colour(s). A single str is reused for every dataset
            when ``power`` is a list.
        rplot_msize : float, default=3.
            The point size
        legend : bool, default=False
            If true, the legend is plotted

        Returns
        -------
        res_plot : ScatterPlot
            Resolution plot (dispersion on x, resolving power on y) that can
            be added to a panel
        """
        res_plot = ScatterPlot(markersize=rplot_msize, legend=legend)
        if isinstance(power, list):
            n_sets = len(power)
            datasets = zip(
                power,
                disp,
                self._broadcast_to_list(label, n_sets),
                self._broadcast_to_list(color, n_sets),
            )
        else:
            datasets = [(power, disp, label, color)]

        for ipower, idisp, ilabel, icolor in datasets:
            res_plot.add_data((idisp, ipower), label=ilabel, color=icolor)

        res_plot.x_label = "Cross-dispersion [pix]"
        res_plot.y_label = "Resolving power"
        return res_plot
