from .giraffe_utils import GiraffeSetupInfo

from adari_core.report import AdariReportBase
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.points import LinePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default, get_wavelength_from_header

import os
import logging
import numpy as np

logger = logging.getLogger(__name__)


class GiraffeWaveCalReport(AdariReportBase):
    """ADARI report for the GIRAFFE arc-lamp wavelength calibration.

    Produces one panel showing the reduced arc spectra (``ARC_RBNSPECTRA``),
    a dispersion-direction cut through one fiber, a zoom on the central
    region overlaid with the line-catalogue wavelengths (``LINE_CATALOG``),
    a histogram of the raw arc frame counts (``ARC_SPECTRUM``) and a
    histogram of the ``FWHM`` values stored in ``LINE_DATA``.
    """

    # Extension of the master frame whose data is displayed and cut
    master_im_ext = "PRIMARY"
    # Width (in pixels along the dispersion axis) of the central zoom window
    center_size = 200
    # Raw-count histogram clipping range and number of bins
    histmin = 50000
    histmax = 66000
    hist_bins = 20
    # FWHM histogram clipping range
    fwhm_min = 0
    fwhm_max = 10

    def __init__(self):
        super().__init__("giraffe_wave_cal")

    def parse_sof(self):
        """Select the report inputs from ``self.inputs``.

        Only the first file of each category ``ARC_RBNSPECTRA``,
        ``ARC_SPECTRUM``, ``LINE_DATA`` and ``LINE_CATALOG`` is used.

        Returns
        -------
        list of dict
            A single dict with keys ``master_im``, ``raw_im``, ``linedata``
            and ``linecat`` (file names) when all four categories are
            present; otherwise an empty list.
        """
        master_im = None
        raw_im = None
        line_data = None
        line_catalog = None

        for filename, catg in self.inputs:
            if catg == "ARC_RBNSPECTRA" and master_im is None:
                master_im = filename
            elif catg == "ARC_SPECTRUM" and raw_im is None:
                raw_im = filename
            elif catg == "LINE_DATA" and line_data is None:
                line_data = filename
            elif catg == "LINE_CATALOG" and line_catalog is None:
                line_catalog = filename

        file_lists = []
        # generate_panels reads every one of these files unconditionally, so
        # a report is only produced when all four are available.
        if all(
            f is not None for f in (master_im, raw_im, line_data, line_catalog)
        ):
            file_lists.append(
                {
                    "master_im": master_im,
                    "raw_im": raw_im,
                    "linedata": line_data,
                    "linecat": line_catalog,
                }
            )
        return file_lists

    @staticmethod
    def _central_wave_limits(wave, npix, center_size):
        """Return the wavelengths bounding a window centred on pixel ``npix // 2``.

        The window spans ``center_size // 2`` pixels on either side of the
        centre. Indices are clamped to ``[0, len(wave) - 1]`` so that short
        frames neither raise ``IndexError`` nor silently wrap around through
        negative indexing.
        """
        last = len(wave) - 1
        mid = npix // 2
        half = center_size // 2
        lo = min(max(mid - half, 0), last)
        hi = min(max(mid + half, 0), last)
        return wave[lo], wave[hi]

    def generate_panels(self, **kwargs):
        """Build the wavelength-calibration panel.

        Returns
        -------
        dict
            Maps the single ``Panel`` to its report metadata
            (``report_name``, ``report_description``, ``report_tags``,
            ``input_files``).
        """
        hdus = self.hdus[0]
        master_im = hdus["master_im"]
        raw_im = hdus["raw_im"]

        # HDU used for keyword lookups, and the HDU whose data is displayed
        # (the same HDU while master_im_ext == "PRIMARY").
        master_primary = master_im["PRIMARY"]
        master_hdu = master_im[self.master_im_ext]
        master_data = master_hdu.data

        fwhm = hdus["linedata"]["FWHM"].data
        linecat = hdus["linecat"][1].data["WLEN"]

        instru = fetch_kw_or_default(master_primary, "INSTRUME", "Missing INSTRUME")
        master_procatg = fetch_kw_or_default(
            master_primary, "HIERARCH ESO PRO CATG", "Missing PRO CATG"
        )
        bunit = fetch_kw_or_default(master_primary, "BUNIT", default="ADU")
        fname = os.path.basename(str(master_im.filename()))
        rawname = os.path.basename(str(raw_im.filename())).removesuffix(".fits")
        # Zero-based column of the data array used for the cuts (fiber 3)
        cut_pos = 2

        # Wavelength array paired with the first axis of the data array in
        # the cuts below (header axis 2).
        wave, wave_unit = get_wavelength_from_header(master_hdu, axis=2)

        input_files = [
            hdus["master_im"].filename(),
            hdus["raw_im"].filename(),
            hdus["linedata"].filename(),
            hdus["linecat"].filename(),
        ]

        panels = {}

        panel = Panel(5, 3, height_ratios=[1, 2, 2], width_ratios=[3, 3, 2, 2, 2])

        panels[panel] = {
            "report_name": "{}_{}_{}".format(instru, master_procatg.lower(), rawname),
            "report_description": "GIRAFFE wave cal panel - {}".format(fname),
            "report_tags": [],
            "input_files": input_files,
        }

        # Text plots: file/product identification and instrument setup
        vspace = 0.2
        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            instru,
            "EXTNAME: "
            + str(fetch_kw_or_default(master_primary, "EXTNAME", "N/A")),
            "PRO CATG: " + str(master_procatg),
            "FILE NAME: " + fname,
            "RAW1 NAME: "
            + str(
                fetch_kw_or_default(
                    master_primary,
                    "HIERARCH ESO PRO REC1 RAW1 NAME",
                    "Missing RAW1 NAME",
                )
            ),
        )
        t1.add_data(col1)
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace)
        col2 = GiraffeSetupInfo.wavecalibration(master_im)
        t2.add_data(col2)
        panel.assign_plot(t2, 3, 0, xext=2)

        # Full reduced image
        fullfield = ImagePlot(
            title="REDUCED_LAMP",
            aspect=0.1,
            v_clip="val",
            v_clip_kwargs={
                "low": 0,
                # Clip on the displayed data; nanpercentile keeps a single
                # NaN pixel from turning the upper limit into NaN.
                "high": np.nanpercentile(master_data, 98),
            },
        )
        fullfield.add_data(master_data)
        panel.assign_plot(fullfield, 0, 1, xext=2, yext=2)

        # Dispersion-direction cut through the selected fiber
        fiber_label = f"Fiber = {str(cut_pos + 1)}"
        cutX = LinePlot(
            title="REDUCED_LAMP cut", x_label="Wavelength, " + str(wave_unit)
        )
        cutX.y_label = bunit
        cutX.add_data((wave, master_data[:, cut_pos]), label=fiber_label)
        cutX.legend = True
        cutX.y_scale = "log"
        panel.assign_plot(cutX, 2, 1, xext=2, yext=1)

        # Dispersion-direction cut - central region
        zoom_title = f"REDUCED_LAMP central {self.center_size} pixels"
        zoom_xlim = self._central_wave_limits(
            wave, fullfield.data.shape[0], self.center_size
        )

        cutXzoom = LinePlot(
            title=zoom_title,
            x_label="Wavelength, " + str(wave_unit),
        )
        cutXzoom.y_label = bunit
        cutXzoom.add_data((wave, master_data[:, cut_pos]), label=fiber_label)
        cutXzoom.legend = False
        cutXzoom.set_xlim(*zoom_xlim)

        # Catalogue line positions drawn as vertical lines; the NaN y-values
        # are placeholders so that only the vlines are visible.
        centercutMark = ScatterPlot(title="Line labels")
        centercutMark.add_data(
            label="Lines",
            d=[linecat, np.nan * np.ones_like(linecat)],
            marker="circle",
            vline=True,
            color="gainsboro",
        )
        centercutMark.legend = False
        centercutMark.set_xlim(*zoom_xlim)

        # Combined zoomed cut with the catalogue lines behind the data
        combinedcutX = CombinedPlot(title=zoom_title)
        combinedcutX.add_data(cutXzoom, z_order=2)
        combinedcutX.add_data(centercutMark, z_order=1)
        combinedcutX.y_scale = "log"
        combinedcutX.set_xlim(*zoom_xlim)
        combinedcutX.legend = False
        panel.assign_plot(combinedcutX, 2, 2, xext=1, yext=1)

        # Raw histogram
        rawhist = HistogramPlot(
            title="Raw counts histogram",
            v_clip="val",
            v_clip_kwargs={"low": self.histmin, "high": self.histmax},
        )
        rawhist.add_data(
            raw_im[0].data, color="black", label="Raw", bins=self.hist_bins
        )
        rawhist.legend = False
        panel.assign_plot(rawhist, 3, 2, xext=1, yext=1)

        # Histogram of the FWHM values of the found arc lines
        fwhmhist = HistogramPlot(
            title="FWHM values histogram",
            v_clip="val",
            v_clip_kwargs={"low": self.fwhm_min, "high": self.fwhm_max},
        )
        fwhmhist.add_data(
            fwhm,
            color="black",
            label="FWHM",
        )
        fwhmhist.legend = False
        fwhmhist.x_label = "FWHM values"
        panel.assign_plot(fwhmhist, 4, 1, xext=1, yext=1)

        return panels


# Module-level report instance, created at import time. Presumably this is the
# object the ADARI runner looks up in this module — TODO confirm the convention.
rep = GiraffeWaveCalReport()
