# SPDX-License-Identifier: BSD-3-Clause
from astropy.nddata import block_reduce

from adari_core.plots.points import LinePlot
from adari_core.report import AdariReportBase
from adari_core.utils.clipping import clipping_percentile

import numpy as np


class MasterSpecScienceReport(AdariReportBase):
    """Master report class for display of a single science spectrum.

    Description
    -----------
    Template report providing helpers that build line plots of a 1D
    science spectrum: flux (optionally with a binned overlay), S/N,
    quality, zoomed regions and an optional lowest/highest flux
    comparison. Child reports must implement ``parse_sof``;
    ``generate_panels`` here returns an empty dict and is presumably
    overridden by child reports.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Reference line positions keyed by line label. Values appear to be
        # wavelengths in nm, matching the "Wavelength (nm)" axis labels used
        # below -- TODO confirm. Not referenced by the methods of this
        # template; presumably consumed by child reports.
        self.spectral_lines = {
            "Hα": 656.280,
            "Hβ": 486.132,
            "Hγ": 434.046,
            "Hδ": 410.173,
            "Hε": 397.007,
            "H10": 379.790,
            "H17": 369.715,
            "Na D": 589.29,
            "P7": 1004.94,
            "P8": 954.598,
            "P9": 922.902,
            "P10": 901.491,
            "P17": 846.72,
        }
        # Line widths for the full-resolution and binned spectra.
        self.spec_lw = 0.4
        self.binned_lw = 1.2
        # Full width of each zoom window in wavelength units. While it is
        # None, generate_spec_zooms raises and no zoom markers are drawn.
        self.zoom_range = None
        # Zoom centres as fractions of the wavelength span, measured from
        # wave[0].
        self.zoom_centre = [0.1, 0.4, 0.9]
        # When True (and zoom_range is set), mark the zoom windows on the
        # main spectrum plot.
        self.zoom_show = False

    def parse_sof(self):
        """Not implemented in the template; child reports must override it.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "MasterSpecScienceReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def calculate_zoom_centers(self, wave):
        """
        Calculate where the centres of the zoom plots are.

        Each entry of ``self.zoom_centre`` is interpreted as a fraction of
        the span ``wave[-1] - wave[0]``, offset from ``wave[0]``.

        Parameters
        ----------
        wave : 1D indexable sequence of numeric values
            The wavelength array (X axis).

        Returns
        -------
        plot_centers : numpy.ndarray
            One zoom centre per entry of ``self.zoom_centre``.
        """
        wspan = wave[-1] - wave[0]
        # Use an array so that shifting by wave[0] is element-wise for any
        # numeric sequence; ``list += scalar`` is not a valid operation for
        # plain Python lists/floats.
        fractions = np.asarray(self.zoom_centre, dtype=float)
        return wspan * fractions + wave[0]

    def generate_snr_plot(self, wave, snr, ylim=False):
        """
        Create a plot with the S/N.

        Parameters
        ----------
        wave : 1D iterable of numeric values
            The wavelength array (X axis).
        snr : 1D iterable of numeric values
            The S/N.
        ylim : bool
            If True, set the y-range to [0, 1.2 x upper clipping level at
            the 98th percentile of ``snr``].

        Returns
        -------
        snrplot : LinePlot
            A plot that can be added to a panel.
        """
        snrplot = LinePlot(
            title="",
            legend=False,
        )
        snrplot.add_data(
            d=[wave, snr], color="cyan", label="ignore", linewidth=self.spec_lw
        )
        snrplot.x_minor_ticks = 0
        snrplot.x_label = "Wavelength (nm)"
        snrplot.y_label = "S/N"
        if ylim:
            snrplot.set_ylim(0, clipping_percentile(snr, 98)[1] * 1.2)
        snrplot.set_xlim(wave[0], wave[-1])
        return snrplot

    @staticmethod
    def _binning_pixels(wave, binning):
        """Return how many pixels cover ``binning`` wavelength units.

        The pixel scale is taken as ``len(wave) / |wave[-1] - wave[0]|``.
        The result is clamped to at least 1 so that ``block_reduce`` always
        receives a valid block size. Without the clamp, a bin narrower than
        one pixel gives 0, and a descending axis gives a negative size.
        """
        wspan = abs(wave[-1] - wave[0])
        if not wspan:
            # Degenerate axis (single value): nothing sensible to average.
            return 1
        return max(1, int(len(wave) / (wspan / binning)))

    def generate_spec_plot(
        self,
        wave,
        flux,
        flux_unit,
        binning=-1,
        low_wave=None,
        low_flux=None,
        high_wave=None,
        high_flux=None,
    ):
        """
        Create a plot with the main spectrum.

        If ``binning > 0``, a block-averaged version of the spectrum is
        overplotted. If ``self.zoom_range`` is set and ``self.zoom_show`` is
        True, a horizontal marker is drawn over each zoom window.

        Optionally, an additional plot is created with the lowest and highest
        flux spectra (usually coming from a combined stack). This plot is
        created only if ``low_wave`` is not None.

        Parameters
        ----------
        wave : 1D iterable of numeric values
            The wavelength array (X axis).
        flux : 1D iterable of numeric values
            The flux (Y axis).
        flux_unit : str
            The unit label displayed on the flux axis.
        binning : int or float
            Width of each averaging bin, in the same units as ``wave``.
            Values <= 0 disable binning.
        low_wave : 1D iterable of numeric values
            The wave for the lowest flux spectrum.
        low_flux : 1D iterable of numeric values
            The flux for the lowest flux spectrum.
        high_wave : 1D iterable of numeric values
            The wave for the highest flux spectrum.
        high_flux : 1D iterable of numeric values
            The flux for the highest flux spectrum.

        Returns
        -------
        [specplot, lowhighplot] : list
            ``specplot`` is a LinePlot that can be added to a panel.
            ``lowhighplot`` is a LinePlot if ``low_wave`` is not None,
            otherwise None.

        Raises
        ------
        ValueError
            If ``low_wave`` is given but any of ``low_flux``, ``high_wave``
            or ``high_flux`` is None.
        """
        if binning > 0:
            binning_pix = self._binning_pixels(wave, binning)
            wave_binned = block_reduce(wave, binning_pix, func=np.mean)
            flux_binned = block_reduce(flux, binning_pix, func=np.mean)

        # Upper clipping level of the flux; shared by the y-limit and the
        # zoom markers, so compute it only once.
        flux_top = clipping_percentile(flux, 98)[1]

        specplot = LinePlot(
            title="",
        )
        specplot.add_data(
            d=[wave, flux],
            color="royalblue",
            label="Flux",
            linewidth=self.spec_lw,
            zorder=100,
        )
        if binning > 0:
            specplot.add_data(
                d=[wave_binned, flux_binned],
                color="darkorange",
                label="Averaged flux",
                linewidth=self.binned_lw,
                zorder=200,
            )
        specplot.x_minor_ticks = 0
        specplot.x_label = "Wavelength (nm)"
        specplot.y_label = f"{flux_unit}"
        specplot.set_xlim(wave[0], wave[-1])
        specplot.set_ylim(0, flux_top * 1.2)

        # Only if zoom plots are plotted: draw a horizontal marker spanning
        # each zoom window, just below the top of the y-range.
        if self.zoom_range is not None and self.zoom_show:
            half_range = self.zoom_range / 2.0
            marker_y = flux_top * 1.1
            for wave_plot in self.calculate_zoom_centers(wave):
                specplot.add_data(
                    d=[
                        [wave_plot - half_range, wave_plot + half_range],
                        [marker_y, marker_y],
                    ],
                    color="green",
                    # Leading underscore: presumably hides the marker from
                    # the legend if LinePlot forwards labels to matplotlib.
                    label="_zoom-" + str(wave_plot),
                )

        # Low and high flux plot
        lowhighplot = None
        if low_wave is not None:
            if low_flux is None or high_wave is None or high_flux is None:
                raise ValueError("Some low/high spectra arrays is None")

            lowhighplot = LinePlot(title="")
            if binning > 0:
                low_wave_lh = block_reduce(low_wave, binning_pix, func=np.mean)
                low_flux_lh = block_reduce(low_flux, binning_pix, func=np.mean)
                high_wave_lh = block_reduce(high_wave, binning_pix, func=np.mean)
                high_flux_lh = block_reduce(high_flux, binning_pix, func=np.mean)
                wave_lh = wave_binned
                flux_lh = flux_binned
            else:
                low_wave_lh = low_wave
                low_flux_lh = low_flux
                high_wave_lh = high_wave
                high_flux_lh = high_flux
                wave_lh = wave
                flux_lh = flux

            lowhighplot.add_data(
                d=[low_wave_lh, low_flux_lh],
                color="darkblue",
                label="Lowest flux",
                linewidth=self.binned_lw / 1.2,
            )
            lowhighplot.add_data(
                d=[high_wave_lh, high_flux_lh],
                color="darkgreen",
                label="Highest flux",
                linewidth=self.binned_lw / 1.2,
            )
            lowhighplot.add_data(
                d=[wave_lh, flux_lh],
                color="darkorange",
                label="Averaged flux",
                linewidth=self.binned_lw,
            )
            lowhighplot.x_minor_ticks = 0
            lowhighplot.set_xlim(wave_lh[0], wave_lh[-1])
            # The y-range is driven by the highest-flux spectrum only.
            lowhighplot.set_ylim(0, clipping_percentile(high_flux_lh, 98)[1] * 1.2)
            lowhighplot.x_label = "Wavelength (nm)"
            lowhighplot.y_label = f"{flux_unit}"

        return [specplot, lowhighplot]

    def generate_qual_plot(self, wave, qual):
        """
        Create a plot with the quality.

        Parameters
        ----------
        wave : 1D iterable of numeric values
            The wavelength array (X axis).
        qual : 1D iterable of numeric values
            The quality.

        Returns
        -------
        qualplot : LinePlot
            A plot that can be added to a panel.
        """

        qualplot = LinePlot(
            title="",
            legend=False,
        )
        qualplot.add_data(d=[wave, qual], color="lightgreen", label="Q")
        qualplot.x_label = "Wavelength (nm)"
        qualplot.y_label = "Quality"
        qualplot.x_minor_ticks = 0
        qualplot.set_xlim(wave[0], wave[-1])
        qualplot.set_ylim(0, clipping_percentile(qual, 98)[1] * 1.2)
        return qualplot

    def generate_spec_zooms(self, wave, flux, flux_unit):
        """
        Create one zoomed flux plot per zoom centre.

        Each plot shows ``[centre - zoom_range/2, centre + zoom_range/2]``.
        Its y-range is set from the flux points that fall inside that window.
        If the window contains no points, the whole spectrum is used instead.

        Parameters
        ----------
        wave : 1D iterable of numeric values
            The wavelength array (X axis).
        flux : 1D iterable of numeric values
            The flux (Y axis).
        flux_unit : str
            The unit label displayed on the flux axis.

        Returns
        -------
        plots : list of LinePlot
            One plot per entry of ``self.zoom_centre``.

        Raises
        ------
        ValueError
            If ``self.zoom_range`` is None.
        """
        if self.zoom_range is None:
            raise ValueError("Value of zoom_range is not set")
        half_range = self.zoom_range / 2.0
        wave_arr = np.asarray(wave)
        flux_arr = np.asarray(flux)
        plots = []
        for wave_plot in self.calculate_zoom_centers(wave):
            linezoomplot = LinePlot(
                legend=False,
            )
            linezoomplot.add_data(
                d=[wave, flux],
                color="blue",
                label="flux",
                linewidth=self.spec_lw,
            )
            linezoomplot.x_label = "Wavelength (nm)"
            linezoomplot.y_label = f"{flux_unit}"
            linezoomplot.set_xlim(wave_plot - half_range, wave_plot + half_range)
            # Base the y-range on the flux that is actually visible in the
            # window, so it matches the x-limits set above.
            in_window = np.abs(wave_arr - wave_plot) < half_range
            flux_selected = flux_arr[in_window] if in_window.any() else flux_arr
            linezoomplot.set_ylim(0, 1.2 * clipping_percentile(flux_selected, 99)[1])
            linezoomplot.x_major_ticks = 0.5

            plots.append(linezoomplot)

        return plots

    def generate_panels(self, **kwargs):
        """Return the report panels; the template produces none.

        Returns
        -------
        panels : dict
            Always empty here; presumably overridden by child reports.
        """
        panels = {}
        return panels
