# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import os

from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default, fetch_kw_or_error
from .xshooter_utils import XshooterSetupInfo, XshooterReportMixin

class XshooterTelluricStdReport(XshooterReportMixin, AdariReportBase):
    """ADARI report for X-shooter telluric standard star products.

    For every arm (UVB, VIS, NIR) whose TELL_SLIT_MERGE1D_<ARM> product is in
    the SOF, this report produces:

    * a panel with the merged 1D spectrum, split over four consecutive
      wavelength ranges;
    * a panel with the merged 2D spectrum, when TELL_SLIT_MERGE2D_<ARM> is
      also supplied.
    """

    # Arms handled by this report, in the order their file lists are built.
    _ARMS = ("UVB", "VIS", "NIR")
    # Number of stacked rows the 1D spectrum is split over.
    _N_WAVE_ROWS = 4

    def __init__(self):
        super().__init__("xshooter_telluric_std")
        # NOTE(review): not read inside this class; it may be used by a base
        # class or mixin - verify before removing.
        self.central_region_size = 200

    def parse_sof(self):
        """Group the SOF inputs into one file-list dict per arm.

        Returns
        -------
        list of dict
            One dict per arm (in UVB, VIS, NIR order) for which a
            TELL_SLIT_MERGE1D_<ARM> product was supplied. Each dict maps
            ``"telluric"`` to that 1D product. The ``"merge2d"`` key is present
            only when a TELL_SLIT_MERGE2D_<ARM> product was supplied as well.
            If a category is listed more than once, the last occurrence wins.
        """
        # Later entries overwrite earlier ones, so the last occurrence wins.
        by_catg = {catg: filename for filename, catg in self.inputs}

        file_lists = []
        for arm in self._ARMS:
            merge1d = by_catg.get(f"TELL_SLIT_MERGE1D_{arm}")
            if merge1d is None:
                continue
            file_list = {"telluric": merge1d}
            merge2d = by_catg.get(f"TELL_SLIT_MERGE2D_{arm}")
            # Omit the key rather than passing None as a file name; the 2D
            # panel is skipped when it is absent.
            if merge2d is not None:
                file_list["merge2d"] = merge2d
            file_lists.append(file_list)
        return file_lists

    @staticmethod
    def _spectral_axis(hdr, npix):
        """Return ``(axis_values, x_label)`` for a 1D spectrum of ``npix`` pixels.

        When CTYPE1 contains "LINEAR", the axis is the linear FITS WCS
        ``(pixel - CRPIX1) * CDELT1 + CRVAL1`` evaluated on 1-based pixel
        numbers. Missing keywords take their FITS-standard defaults
        (CRPIX1=0, CRVAL1=0, CDELT1=1). In every other case, including a
        missing CTYPE1, the 1-based pixel numbers are returned so the panel
        can still be drawn instead of failing.
        """
        pix = np.arange(npix) + 1.0
        if "LINEAR" not in (hdr.get("CTYPE1") or ""):
            return pix, "Pixel"
        crpix = hdr.get("CRPIX1", 0.0)
        crval = hdr.get("CRVAL1", 0.0)
        cdelt = hdr.get("CDELT1", 1.0)
        cunit = hdr.get("CUNIT1")
        x_label = f"Wavelength, {cunit}" if cunit else "Wavelength"
        return (pix - crpix) * cdelt + crval, x_label

    def generate_panels_std(self, index=0, **kwargs):
        """Build the merged 1D telluric-standard spectrum panel.

        Parameters
        ----------
        index : int, optional
            Entry of ``self.hdus`` to report on. There is one entry per arm,
            as built by :meth:`parse_sof`. Defaults to the first entry.
        **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        dict
            ``{Panel: report metadata}`` with exactly one entry.
        """
        tell_hdul = self.hdus[index]["telluric"]
        ext = "PRIMARY"
        tell_spec = tell_hdul[ext].data
        hdr = tell_hdul[ext].header

        procatg = fetch_kw_or_default(tell_hdul[ext], "HIERARCH ESO PRO CATG")
        wave, x_label = self._spectral_axis(hdr, tell_spec.size)

        n_rows = self._N_WAVE_ROWS
        # One short header row followed by n_rows spectrum rows.
        p = Panel(7, n_rows + 1, height_ratios=[1] + [2] * n_rows)

        # Header row: product identification (left) and setup info (right).
        vspace = 0.2
        t1 = TextPlot(columns=1, v_space=vspace, xext=4)
        fname = os.path.basename(str(tell_hdul.filename()))
        col1 = (
            str(hdr.get("INSTRUME")),
            "EXTNAME: " + ext,
            "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
            "FILE NAME: " + fname,
            "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
        )
        t1.add_data(col1)
        p.assign_plot(t1, 0, 0, xext=4)

        # NOTE(review): t1 spans columns 0-3 (xext=4) but t2 is placed at
        # column 3, so the two text blocks share a column; confirm intended.
        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(XshooterSetupInfo.telluric_standard_slit_nod(tell_hdul))
        p.assign_plot(t2, 3, 0, xext=1)

        # Small negative floor (5% of the peak) so the zero level is visible.
        # nanmax keeps NaN pixels from turning the axis limit into NaN.
        y_min = -0.05 * np.nanmax(tell_spec)

        # Split the full range into n_rows equal sub-ranges. Every row plots
        # the whole spectrum but zooms (with a 1% margin) onto its sub-range.
        step = (wave[-1] - wave[0]) / n_rows
        for row in range(n_rows):
            first = row == 0
            last = row == n_rows - 1
            lo = wave[0] + row * step
            hi = wave[-1] if last else wave[0] + (row + 1) * step
            plot = LinePlot(
                title="Telluric standard star" if first else "",
                x_label=x_label if last else "",
                y_label="ADU",
                x_min=lo * 0.99,
                x_max=hi * 1.01,
                y_min=y_min,
                legend=False,
            )
            plot.add_data((wave, tell_spec), label=" ", color="black")
            p.assign_plot(plot, 0, row + 1, xext=7, yext=1)

        input_files = [tell_hdul.filename()]
        return {
            p: {
                "report_name": f"XSHOOTER_telluric_std_{procatg.lower()}_{ext.lower()}",
                "report_description": f"Telluric std panel ({ext})",
                "report_tags": [],
                "input_files": input_files,
            }
        }

    def generate_panels_merge2d(self, index=0):
        """Build the merged 2D spectrum panel.

        Parameters
        ----------
        index : int, optional
            Entry of ``self.hdus`` to report on. Defaults to the first entry.

        Returns
        -------
        dict
            ``{Panel: report metadata}`` with one entry, or an empty dict
            when no merged 2D product was supplied for this entry.
        """
        m2d = self.hdus[index].get("merge2d")
        if m2d is None:
            return {}

        vspace = 0.3
        panel = Panel(4, 2, height_ratios=[1, 4])
        primary = m2d["PRIMARY"]

        # Header row: product identification text.
        procatg = fetch_kw_or_error(primary, "HIERARCH ESO PRO CATG")
        fname = os.path.basename(str(m2d.filename()))
        rawname = str(
            primary.header.get("HIERARCH ESO PRO REC1 RAW1 NAME", "Missing RAW1 NAME")
        )
        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            str(primary.header.get("INSTRUME", "Missing INSTRUME")),
            "EXTNAME: PRIMARY",
            "PRO CATG: " + procatg,
            "FILE NAME: " + fname,
            "RAW1 NAME: " + rawname,
        )
        t1.add_data(col1)
        panel.assign_plot(t1, 0, 0, xext=2)

        # NOTE(review): uses the spectro-photometric-star setup summary,
        # whereas the 1D panel uses telluric_standard_slit_nod; confirm this
        # difference is intended.
        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(XshooterSetupInfo.specphot_star(m2d))
        panel.assign_plot(t2, 2, 0, xext=1)

        m2d_plot = ImagePlot(primary.data, title="Merged 2D spectrum", aspect="auto")
        m2d_plot.tick_visibility = {
            "top": False,
            "labeltop": False,
            "bottom": True,
            "labelbottom": True,
            "right": True,
            "labelright": True,
            "left": True,
            "labelleft": True,
        }
        m2d_plot.cbar_kwargs = {"pad": 0.05}
        panel.assign_plot(m2d_plot, 0, 1, xext=4, yext=1)

        input_files = [m2d.filename()]
        return {
            panel: {
                "report_name": f"XSHOOTER_telluric_std_{procatg.lower()}",
                "report_description": "StandardStar_merge2d",
                "report_tags": [],
                "input_files": input_files,
            }
        }

    def generate_panels(self, **kwargs):
        """Return the panels for every file list (arm) found in the SOF.

        Each entry of ``self.hdus`` contributes one 1D-spectrum panel. It also
        contributes a 2D panel when its merged 2D product is available.
        """
        panels = {}
        for index in range(len(self.hdus)):
            panels.update(self.generate_panels_std(index))
            panels.update(self.generate_panels_merge2d(index))
        return panels


# Module-level report instance, created at import time.
# NOTE(review): presumably the ADARI runner looks this object up by the name
# `rep`; confirm before renaming or removing.
rep = XshooterTelluricStdReport()
