from adari_core.data_libs.master_std_star_ifu import MasterSpecphotStdReport
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot
from adari_core.plots.points import LinePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from .giraffe_utils import GiraffeSetupInfo

import os
import numpy as np

from .giraffe_utils import GiraffeReportMixin


class GiraffeStdStarReport(GiraffeReportMixin, MasterSpecphotStdReport):
    """Report on GIRAFFE standard-star products.

    A single panel is produced containing: two header text blocks, the
    averaged rebinned spectrum, a cut and a histogram of the raw frame,
    the efficiency curve and, when the header keyword
    ``HIERARCH ESO INS CALS ID`` equals ``"Argus"``, a reconstructed
    field-of-view image.
    """

    def __init__(self):
        """Initialise the parent report under the name ``giraffe_std_star``."""
        super().__init__("giraffe_std_star")

    def parse_sof(self):
        """Select the input files needed for the report.

        ``self.inputs`` is iterated as ``(filename, category)`` pairs and
        the first file of each of the categories ``STD_RBNSPECTRA``,
        ``STD_RCSPECTRA``, ``EFFICIENCY_CURVE`` and ``STD`` is kept.

        Returns:
            A list holding one dict with the keys ``std_rbnspectra``,
            ``std_rcspectra``, ``eff_curve`` and ``raw`` when all four
            categories were found; an empty list otherwise.
        """
        rbnspec = None
        rcspec = None
        eff = None
        raw_std = None

        for filename, catg in self.inputs:
            if catg == "STD_RBNSPECTRA" and rbnspec is None:
                rbnspec = filename
            elif catg == "STD_RCSPECTRA" and rcspec is None:
                rcspec = filename
            elif catg == "EFFICIENCY_CURVE" and eff is None:
                eff = filename
            elif catg == "STD" and raw_std is None:
                # Guard on raw_std itself (not on eff): the first STD frame
                # must be kept regardless of where the EFFICIENCY_CURVE
                # entry sits in the input list.
                raw_std = filename

        # Build and return the file name list; a partial set yields nothing.
        file_lists = []
        if (
            rbnspec is not None
            and rcspec is not None
            and eff is not None
            and raw_std is not None
        ):
            file_lists.append(
                {
                    "std_rbnspectra": rbnspec,
                    "std_rcspectra": rcspec,
                    "eff_curve": eff,
                    "raw": raw_std,
                }
            )
        return file_lists

    def get_wavelength(self, hdu, axis=2):
        """Compute the wavelength vector from a HDU header.

        The vector is ``CRVAL + (pixel - CRPIX) * CDELT`` for the 1-based
        pixels ``1 .. NAXIS`` of the requested axis. Nothing is returned;
        the result is stored on the instance.

        Args:
            hdu: Object whose ``header`` provides ``CRVAL<axis>``,
                ``CRPIX<axis>``, ``NAXIS<axis>``, ``CDELT<axis>`` and
                ``CUNIT<axis>``.
            axis: Number of the header axis to read (default 2).

        Side effects:
            Sets ``self.wavelength`` (numpy array) and
            ``self.wavelength_unit`` (value of ``CUNIT<axis>``).

        A missing keyword is not handled here; the header's own lookup
        error propagates to the caller.
        """
        wl_c = hdu.header[f"CRVAL{axis}"]
        pix_c = hdu.header[f"CRPIX{axis}"]
        n_pix = hdu.header[f"NAXIS{axis}"]
        wl_del = hdu.header[f"CDELT{axis}"]
        # Pixel indices are 1-based, matching the CRPIX reference pixel.
        pixels = np.arange(1, n_pix + 1)
        self.wavelength = wl_c + (pixels - pix_c) * wl_del
        self.wavelength_unit = hdu.header[f"CUNIT{axis}"]

    def generate_panels(self, **kwargs):
        """Build the standard-star panel from the first entry of ``self.hdus``.

        Args:
            **kwargs: Accepted but not used by this method.

        Returns:
            A dict with a single item mapping the ``Panel`` to a metadata
            dict (file names, extension, report name, description, tags
            and input file list).
        """
        raw = self.hdus[0]["raw"]
        rbnspec = self.hdus[0]["std_rbnspectra"]
        rcspec = self.hdus[0]["std_rcspectra"]
        eff = self.hdus[0]["eff_curve"]
        wave = eff["EFFICIENCY_CURVE"].data["WLEN"]
        eff_data = eff["EFFICIENCY_CURVE"].data["EFFICIENCY"]
        # Populates self.wavelength / self.wavelength_unit from header axis 2.
        self.get_wavelength(rbnspec[0])
        ext = "PRIMARY"

        panels = {}
        # Grid of 3 x 4 cells; the first row (text) is a quarter as tall.
        p = Panel(3, 4, height_ratios=[1, 4, 4, 4])

        hdr = rbnspec[ext].header
        procatg = hdr["HIERARCH ESO PRO CATG"]
        instru = hdr.get("INSTRUME")
        rc_hdr = rcspec[0].header
        # Evaluated once: drives both the y-range scaling and the FOV plot.
        is_argus = rc_hdr.get("HIERARCH ESO INS CALS ID") == "Argus"

        # Text Plot
        px = 0
        py = 0
        vspace = 0.5
        t1 = TextPlot(columns=1, v_space=vspace)
        fname = os.path.basename(str(rbnspec.filename()))
        col1 = (
            str(instru),
            "EXTNAME: " + str(hdr.get("EXTNAME", "N/A")),
            "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
            "FILE NAME: " + fname,
            "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
        )
        t1.add_data(col1)
        p.assign_plot(t1, px, py, xext=2)

        # Second text block: instrument setup, placed right of the first.
        px = px + 2
        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        self.metadata = GiraffeSetupInfo.response(rbnspec)
        t2.add_data(self.metadata)
        p.assign_plot(t2, px, py, xext=1)

        # Averaged spectrum: mean over array axis 1 of the rebinned data.
        # NOTE(review): presumably axis 1 runs over fibres/spectra - confirm.
        rbnspec_mean = np.mean(rbnspec[0].data, axis=1)

        # Upper y limit is a multiple of the median: 5x for Argus, 2x otherwise.
        ymax_factor = 5 if is_argus else 2
        ymax1 = ymax_factor * np.median(rbnspec_mean)
        ylabel = hdr.get("BUNIT")
        lineplot = LinePlot(y_min=0, y_max=ymax1)
        lineplot.add_data(
            d=[self.wavelength, rbnspec_mean],
            label="Averaged spectrum",
            v_clip="sigma",
            v_clip_kwargs={"nsigma": 2},
        )
        lineplot.x_label = f"Wavelength ({self.wavelength_unit})"
        lineplot.y_label = ylabel
        lineplot.legend = False
        lineplot.title = "Averaged spectrum"
        p.assign_plot(lineplot, 0, 1, xext=2)

        # Cross-dispersion cut of full raw file
        raw_data = raw[ext].data
        raw_full = ImagePlot(
            raw_data,
            title="Raw file (full)",
        )
        # NOTE(review): the cut is taken along "y" but positioned from
        # shape[1] // 2 - confirm that this is the intended axis length.
        cutpos = raw_full.get_data_coord(raw_full.data.shape[1] // 2, "y")
        cut_full = CutPlot(
            "y",
            title="Raw file (full)",
            y_label="ADU",
        )

        cut_full.add_data(
            raw_full,
            cut_pos=cutpos,
            color="black",
            label=f"row @ y={cutpos}",
        )
        p.assign_plot(cut_full, 2, 1)

        # Eff vs. wavelength
        lineplot = LinePlot(
            title="Efficiency vs. wavelength",
            legend=True,
        )
        lineplot.add_data(
            d=[wave, eff_data], color="red", label="Efficiency vs. wavelength"
        )
        # NOTE(review): the unit shown here comes from the rebinned-spectrum
        # header, not from the efficiency table - confirm they agree.
        lineplot.x_label = f"Wavelength ({self.wavelength_unit})"
        lineplot.y_label = "Efficiency"
        p.assign_plot(lineplot, 0, 2, xext=2)

        # Histogram of full raw file
        hist_full = HistogramPlot(
            title="Raw data histogram", bins=50, v_min=-5000, v_max=20000
        )
        hist_full.add_data(raw_data, label="Raw data counts")
        p.assign_plot(hist_full, 2, 2)

        if is_argus:
            # Star map (only drawn for Argus; the cell stays empty otherwise)
            star_map = self.plot_star_image(data=rcspec[0].data)
            star_map.title = "Reconstructed FOV"
            p.assign_plot(star_map, 0, 3, xext=2)

        input_files = [
            rbnspec.filename(),
            rcspec.filename(),
            eff.filename(),
            raw.filename(),
        ]
        panels[p] = {
            "raw": raw.filename(),
            "flux_slit": rbnspec.filename(),
            "ext": ext,
            "report_name": f"{instru}_{procatg.lower()}_{ext}",
            "report_description": f"Specphotometric Std panel - ({rbnspec.filename()}, "
            f"{raw.filename()}, "
            f"{ext})",
            "report_tags": [],
            "input_files": input_files,
        }
        return panels


# Module-level report instance, created at import time.
# NOTE(review): presumably looked up by name by the report framework - confirm.
rep = GiraffeStdStarReport()
