# SPDX-License-Identifier: BSD-3-Clause
import functools
import inspect

import numpy as np

from adari_core.plots.points import LinePlot
from adari_core.plots.images import ImagePlot

from adari_core.report import AdariReportBase


def check_empty(plot_function):
    """Make a plotting function return a placeholder plot when it gets no data.

    Description
    -----------
    If the input "data" is None, or is not supplied at all, the wrapped
    function is not called. An empty LinePlot is returned instead: it is
    crossed by two diagonal lines and titled "<title> (N/A)". Both "data" and
    "title" are recognised whether they are passed by keyword or positionally.

    Parameters
    ----------
    plot_function : callable
        Plotting function (or method) that accepts a ``data`` argument.

    Returns
    -------
    callable
        Wrapper preserving the name, docstring and signature of
        ``plot_function``.
    """
    signature = inspect.signature(plot_function)

    @functools.wraps(plot_function)
    def wrapper(*args, **kwargs):
        try:
            # Map positional arguments onto parameter names so that a call
            # such as ``report.plot_vs_wavelength(flux)`` is not mistaken for
            # a call without data. Defaults are intentionally NOT applied, so
            # an omitted title still falls back to "" below.
            arguments = signature.bind(*args, **kwargs).arguments
        except TypeError:
            # Invalid call: let the wrapped function raise its own error.
            return plot_function(*args, **kwargs)

        # Fall back to the raw kwargs in case the function collects its
        # options through ``**kwargs`` rather than a named parameter.
        data = arguments.get("data", kwargs.get("data"))
        if data is not None:
            return plot_function(*args, **kwargs)

        title = arguments.get("title", kwargs.get("title", ""))
        # Create an empty plot
        empty_plot = LinePlot()
        # Draw two crossed diagonal lines
        empty_plot.add_data(
            [np.array([0, 1]), np.array([0, 1])], label="1", color="k"
        )
        empty_plot.add_data(
            [np.array([0, 1]), np.array([1, 0])], label="2", color="k"
        )
        empty_plot.set_xlim(0, 1)
        empty_plot.set_ylim(0, 1)
        empty_plot.legend = False
        empty_plot.title = "{} (N/A)".format(title)
        return empty_plot

    return wrapper


class MasterSpecphotStdReport(AdariReportBase):
    """Generic report for Spectrophotometric Standard Stars reports.

    This class is a template: child reports must implement ``parse_sof``.
    It provides helpers to build a linear wavelength axis from a header
    (``get_wavelength``) and to plot quantities against it or display the
    standard-star image.
    """

    # Files required by the report; child reports are expected to define
    # their own. NOTE(review): this dict is shared by every subclass that does
    # not redefine it, so never mutate it in place.
    files_needed = {}
    # Wavelength vector and its unit; both are set by ``get_wavelength``.
    wavelength = None
    wavelength_unit = None

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        """Parse the set of frames; must be implemented by child reports.

        Raises
        ------
        NotImplementedError
            Always, since this class is only a template.
        """
        raise NotImplementedError(
            "MasterSpecphotStdReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    @check_empty
    def plot_vs_wavelength(self, data=None, label="", title=""):
        """Plot a given quantity versus wavelength.

        Parameters
        ----------
        data : array-like or None
            Values to plot against ``self.wavelength``. When None, a
            placeholder plot is returned by the ``check_empty`` decorator.
            NOTE(review): presumably ``data`` must have the same length as
            ``self.wavelength`` (set by ``get_wavelength``) -- confirm.
        label : str
            Y-axis label.
        title : str
            Plot title.

        Returns
        -------
        LinePlot
            Line plot of ``data`` versus wavelength, without a legend.
        """
        lineplot = LinePlot()
        lineplot.add_data(d=[self.wavelength, data], label="_n", color="black")
        lineplot.x_label = f"Wavelength ({self.wavelength_unit})"
        lineplot.y_label = label
        lineplot.legend = False
        lineplot.title = title
        return lineplot

    @check_empty
    def plot_star_image(self, data=None, title="STD STAR"):
        """Plot the image of a standard star.

        Parameters
        ----------
        data : array-like or None
            Image data. When None, a placeholder plot is returned by the
            ``check_empty`` decorator.
        title : str
            Plot title.

        Returns
        -------
        ImagePlot
            Image plot drawn without interpolation.
        """
        implot = ImagePlot(title=title, interpolation="none")
        implot.add_data(data)
        return implot

    def get_wavelength(self, hdu, axis: int = 1):
        """Compute the linear wavelength vector from a HDU header.

        The wavelength of 1-based pixel ``p`` is
        ``CRVAL + (p - CRPIX) * delta``. ``delta`` is read from ``CDELT{axis}``
        or, when that keyword is absent, from the CD-matrix diagonal term
        ``CD{axis}_{axis}``.

        Parameters
        ----------
        hdu : object
            Object exposing a mapping-like ``header`` attribute
            (presumably an astropy HDU -- confirm against callers).
        axis : int
            1-based index of the spectral axis in the header keywords.

        Returns
        -------
        numpy.ndarray
            The wavelength vector. It is also stored in ``self.wavelength``,
            and its unit (``CUNIT{axis}``) is stored in
            ``self.wavelength_unit``.

        Raises
        ------
        KeyError
            If a required keyword is missing, including when neither
            ``CDELT{axis}`` nor ``CD{axis}_{axis}`` is present.
        """
        header = hdu.header
        wl_c = header[f"CRVAL{axis}"]
        pix_c = header[f"CRPIX{axis}"]
        n_pix = header[f"NAXIS{axis}"]

        # CDELT is tried first to keep existing results unchanged; many
        # headers instead express the pixel scale via the CD matrix.
        cdelt_key = f"CDELT{axis}"
        cd_key = f"CD{axis}_{axis}"
        if cdelt_key in header:
            wl_del = header[cdelt_key]
        elif cd_key in header:
            wl_del = header[cd_key]
        else:
            raise KeyError(
                f"Neither {cdelt_key} nor {cd_key} found in header"
            )

        # FITS pixel coordinates are 1-based.
        pixels = np.arange(1, n_pix + 1)
        self.wavelength = wl_c + (pixels - pix_c) * wl_del
        self.wavelength_unit = header[f"CUNIT{axis}"]
        return self.wavelength

    def generate_panels(self, **kwargs):
        """Generate the report panels (placeholder returning no panels)."""
        # TODO: Implement a generic function compatible with e.g., KMOS, MUSE
        panels = {}
        return panels
