from adari_core.data_libs.master_wave_echelle import MasterWaveReport
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default

from .uves_util import UvesSetupInfo

import numpy as np
import os

from . import UvesReportMixin


class UvesWavelengthCalibReport(UvesReportMixin, MasterWaveReport):
    """Wavelength-calibration QC report for UVES / UVES-FLAMES products.

    One panel is produced per (line table, arc lamp) pair found in the SOF.
    Each panel shows header metadata, the positions of the lines used for the
    wavelength solution, resolving power versus detector position, cuts and a
    histogram of the raw arc-lamp frame, and the fit residuals.
    """

    # Arc-lamp extension holding the raw data for each arm key returned by
    # _get_uves_arm(). "redu_legacy" is selected when the arc-lamp file has a
    # "CCID-20" extension rather than "CCD-20".
    raw_extensions = {
        "blue": "PRIMARY",
        "redu": "CCD-20",
        "redu_legacy": "CCID-20",
        "redl": "CCD-44",
    }
    # Set to True by check_uves_mode() when any input category contains "FIB".
    is_flames = False

    def __init__(self):
        super().__init__("uves_cal_wavecal")

    def check_uves_mode(self):
        """Check the instrumental setup of UVES based on the input SOF.

        Sets ``self.is_flames`` to True if any input category contains
        "FIB". The flag is only ever set here, never cleared.
        """
        if any("FIB" in catg for _, catg in self.inputs):
            self.is_flames = True

    def parse_sof(self):
        """Pair each arm's line table with the matching arc-lamp frame.

        Returns
        -------
        list of dict
            One ``{"line_table": ..., "arc_lamp": ...}`` entry per arm for
            which both files are present, in the order blue, red-lower,
            red-upper. Both red chips are paired with the same ARC_LAMP_RED
            file. Category names carry a "FIB_" prefix in FLAMES mode.
        """
        # Check which instrumental setup is used
        self.check_uves_mode()
        prefix = "FIB_" if self.is_flames else ""

        # Map category -> filename; if a category repeats, the last file wins.
        files_by_catg = {}
        for filename, catg in self.inputs:
            files_by_catg[catg] = filename

        # (line table category, arc lamp category) per arm, in output order.
        arm_categories = (
            ("LINE_TABLE_BLUE", "ARC_LAMP_BLUE"),
            ("LINE_TABLE_REDL", "ARC_LAMP_RED"),
            ("LINE_TABLE_REDU", "ARC_LAMP_RED"),
        )
        file_lists = []
        for table_catg, arc_catg in arm_categories:
            line_table = files_by_catg.get(prefix + table_catg)
            arc_lamp = files_by_catg.get(prefix + arc_catg)
            if line_table is not None and arc_lamp is not None:
                file_lists.append({"line_table": line_table, "arc_lamp": arc_lamp})
        return file_lists

    def get_line_tables(self, line_table_hdul):
        """Select the line-table extensions used for the plots.

        Parameters
        ----------
        line_table_hdul : HDUList
            Opened line-table product.

        Returns
        -------
        line_table : HDU
            Table of lines used for the wavelength solution and resolving
            power: extension "LINE_T0_W2_X4" for UVES, HDU index 10 for FLAMES.
        line_table_res : list of HDU
            Additional tables whose residuals are overplotted ("LINE_T0_W1_X1"
            and "LINE_T0_W3_X7" for UVES; empty for FLAMES).
        """
        if not self.is_flames:
            line_table = line_table_hdul["LINE_T0_W2_X4"]
            line_table_res = [
                line_table_hdul["LINE_T0_W1_X1"],
                line_table_hdul["LINE_T0_W3_X7"],
            ]
        else:
            line_table = line_table_hdul[10]
            line_table_res = []
        return line_table, line_table_res

    def _get_uves_arm(self, line_table_hdul, arc_lamp_hdul):
        """Return ``(arm, uves_arm)`` for a line table / arc lamp pair.

        ``arm`` is the value returned by ``UvesSetupInfo.get_arm_info`` and
        ``uves_arm`` is the matching key of ``raw_extensions``.

        Raises
        ------
        ValueError
            If the arm cannot be determined from the line table.
        """
        arm = UvesSetupInfo.get_arm_info(line_table_hdul)
        if arm is None:
            raise ValueError("UVES spectrograph arm info not found")
        if "RED" not in arm:
            return arm, "blue"
        # Without DET CHIP1 INDEX the product is treated as the lower red chip.
        chip_1_index = line_table_hdul["PRIMARY"].header.get(
            "HIERARCH ESO DET CHIP1 INDEX"
        )
        if chip_1_index is None:
            return arm, "redl"
        all_hdu_names = [hdu.name for hdu in arc_lamp_hdul]
        if "CCID-20" in all_hdu_names:
            return arm, "redu_legacy"
        return arm, "redu"

    def _get_arc_lamp_image(self, arm, uves_arm, arc_lamp_hdul):
        """Return the raw arc-lamp image data for the given arm.

        Blue frames whose DET READ SPEED contains "625kHz" are stored in the
        "WIN1" and "WIN2" extensions; these are concatenated along axis 1.
        Otherwise the extension ``raw_extensions[uves_arm]`` is used.
        """
        if "BLUE" in arm:
            # Default to "" so a missing keyword does not make the substring
            # test raise TypeError on None.
            read_speed = str(
                arc_lamp_hdul["PRIMARY"].header.get("HIERARCH ESO DET READ SPEED", "")
            )
            if "625kHz" in read_speed:
                return np.concatenate(
                    [arc_lamp_hdul[aext].data for aext in ("WIN1", "WIN2")], axis=1
                )
        return arc_lamp_hdul[self.raw_extensions[uves_arm]].data

    def _get_setup_text(self, hdr):
        """Return the instrument-setup lines shown in the second text box."""
        lines = [
            "DET.WIN1.BINX: " + str(hdr.get("HIERARCH ESO DET WIN1 BINX")),
            "DET.WIN1.BINY: " + str(hdr.get("HIERARCH ESO DET WIN1 BINY")),
            "DET.READ.SPEED: " + str(hdr.get("HIERARCH ESO DET READ SPEED")),
        ]
        if self.is_flames:
            lines += [
                "INS.SLIT3.PLATE: " + str(hdr.get("HIERARCH ESO INS SLIT3 PLATE")),
                "INS.SLIT3.MODE: " + str(hdr.get("HIERARCH ESO INS SLIT3 MODE")),
                "INS.GRAT2.WLEN: " + str(hdr.get("HIERARCH ESO INS GRAT2 WLEN")),
            ]
            return tuple(lines)

        g1 = "HIERARCH ESO INS GRAT1 NAME"
        g2 = "HIERARCH ESO INS GRAT2 NAME"
        w1 = "HIERARCH ESO INS GRAT1 WLEN"
        w2 = "HIERARCH ESO INS GRAT2 WLEN"
        s2 = "HIERARCH ESO INS SLIT2 WID"
        s3 = "HIERARCH ESO INS SLIT3 WID"
        lines.append("INS.MODE: " + str(hdr.get("HIERARCH ESO INS MODE")))
        # Grating 1 keywords are shown when GRAT1 NAME exists, else grating 2.
        if g1 in hdr:
            lines.append("INS.GRAT1.NAME: " + str(hdr.get(g1)))
            lines.append("INS.GRAT1.WLEN: " + str(hdr.get(w1)))
        else:
            lines.append("INS.GRAT2.NAME: " + str(hdr.get(g2)))
            lines.append("INS.GRAT2.WLEN: " + str(hdr.get(w2)))
        lines.append("INS.SLIT1.NAME: " + str(hdr.get("HIERARCH ESO INS SLIT1 NAME")))
        if s2 in hdr:
            lines.append("INS.SLIT2.WID: " + str(hdr.get(s2)))
        else:
            lines.append("INS.SLIT3.WID: " + str(hdr.get(s3)))
        return tuple(lines)

    @staticmethod
    def _get_setup_string(hdr):
        """Return the ``bin<BINX>x<BINY>_<WLEN>`` tag used in the report name.

        WLEN is INS GRAT2 WLEN if present, else INS GRAT1 WLEN, truncated to
        an int; "none" if neither keyword exists.
        """
        wlen = "none"
        # GRAT2 takes precedence; GRAT1 is only converted when it is used.
        for key in ("HIERARCH ESO INS GRAT2 WLEN", "HIERARCH ESO INS GRAT1 WLEN"):
            if key in hdr:
                wlen = str(int(hdr.get(key)))
                break
        binning = (
            str(hdr.get("HIERARCH ESO DET WIN1 BINX"))
            + "x"
            + str(hdr.get("HIERARCH ESO DET WIN1 BINY"))
        )
        return "bin" + binning + "_" + wlen

    def generate_panels(self, lplot_msize=0.05, ext=0, **kwargs):
        """Build one QC panel per (line table, arc lamp) pair in ``self.hdus``.

        Parameters
        ----------
        lplot_msize : float, optional
            Unused; kept for interface compatibility.
        ext : optional
            Ignored; the panels always describe the "LINE_TABLE" extension.
        **kwargs
            Unused.

        Returns
        -------
        dict
            Maps each ``Panel`` to its metadata dict (``ext``,
            ``report_name``, ``report_description``, ``report_tags``).

        Raises
        ------
        ValueError
            If the spectrograph arm cannot be determined from a line table.
        """
        panels = {}
        # The panels always describe the line table, whatever ``ext`` was passed.
        ext = "LINE_TABLE"
        # Marker size shared by all scatter plots.
        rplot_size = 0.2
        # Colours of the additional residual tables ("sky1", "sky2").
        res_colours = ["green", "blue"]

        for hdus in self.hdus:
            line_table_hdul = hdus["line_table"]
            arc_lamp_hdul = hdus["arc_lamp"]

            # Find out which arm (and, for red, which chip) is used
            arm, uves_arm = self._get_uves_arm(line_table_hdul, arc_lamp_hdul)

            hdr = line_table_hdul["PRIMARY"].header
            procatg = hdr["HIERARCH ESO PRO CATG"]

            # Table containing the lines used for computing the wavelength
            # solution and resolving power (R), plus extra residual tables.
            line_table, line_table_res = self.get_line_tables(line_table_hdul)
            arc_lamp_image_data = self._get_arc_lamp_image(
                arm, uves_arm, arc_lamp_hdul
            )

            p = Panel(4, 3, height_ratios=[1, 4, 4])

            # Row 0: header metadata text boxes
            vspace = 0.3
            # str() guards against filename() returning None (in-memory HDUList).
            fname = os.path.basename(str(line_table_hdul.filename()))
            t1 = TextPlot(columns=1, v_space=vspace)
            t1.add_data(
                (
                    str(hdr.get("INSTRUME")),
                    "EXTNAME: " + ext,
                    "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
                    "FILE NAME: " + fname,
                    "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                )
            )
            p.assign_plot(t1, 0, 0, xext=2)

            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            t2.add_data(self._get_setup_text(hdr))
            p.assign_plot(t2, 2, 0, xext=1)

            # Lines flagged as used (Select == 1) in the wavelength solution.
            selplot = line_table.data[line_table.data["Select"] == 1]

            # 1. Position of the arc lines that are used for the wavelength
            # solution. X axis: dispersion direction (pixels); y axis: order.
            pos_plot = ScatterPlot(
                title="Line positions",
                x_label="X",
                y_label="Order",
                markersize=rplot_size,
                legend=False,
            )
            pos_plot.add_data(
                (selplot["X"], selplot["Order"]), label="Used lines", color="black"
            )
            # Flip the y-axis limits as per required convention: the first
            # limit is (max order + 1), the second (min order - 1).
            orders = line_table.data["Order"]
            pos_plot.set_ylim(max(orders) + 1, min(orders) - 1)
            p.assign_plot(pos_plot, 0, 1)

            # 2. Resolving power R vs pixel position in dispersion direction (X)
            res_x_plot = ScatterPlot(
                title="Resolution",
                x_label="X",
                y_label="Resolving Power",
                markersize=rplot_size,
                legend=False,
            )
            res_x_plot.add_data(
                (line_table.data["X"], line_table.data["Resol"]),
                label="Used lines",
                color="black",
            )
            p.assign_plot(res_x_plot, 1, 1)

            # 3. Resolving power R vs pixel position in cross-dispersion direction
            res_y_plot = ScatterPlot(
                title="Resolution",
                x_label="Y",
                y_label="Resolving Power",
                markersize=rplot_size,
                legend=False,
            )
            res_y_plot.add_data(
                (line_table.data["Ynew"], line_table.data["Resol"]),
                label="Used lines",
                color="black",
            )
            p.assign_plot(res_y_plot, 1, 2)

            # Cut in cross-dispersion direction through the raw file; red
            # frames use a different rotation plus a y flip.
            rotation_kwargs = {"rotate": -90}
            if "RED" in arm:
                rotation_kwargs = {"rotate": 270, "flip": "y"}

            raw_plot = ImagePlot(
                arc_lamp_image_data, title="raw plot", **rotation_kwargs
            )
            raw_center = CentralImagePlot(
                raw_plot, extent=self.central_region_size, title="raw center"
            )
            bunit = fetch_kw_or_default(
                arc_lamp_hdul["PRIMARY"], "BUNIT", default="ADU"
            )

            cutpos = raw_plot.get_data_coord(raw_plot.data.shape[1] // 2, "x")
            raw_cutX = CutPlot(
                "x",
                title="raw col @ X {}".format(cutpos),
                x_label="Y",
                y_label=bunit,
            )
            raw_cutX.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cutX, 2, 1)

            # Same for the central region
            cen_cutpos = raw_center.get_data_coord(
                np.floor(raw_center.data.shape[1] // 2), "x"
            )
            raw_cen_cutX = CutPlot(
                "x",
                title="Raw central Region: col @ X {}".format(cen_cutpos),
                x_label="Y",
                y_label=bunit,
            )
            raw_cen_cutX.add_data(raw_center, cen_cutpos, label="raw", color="black")
            p.assign_plot(raw_cen_cutX, 2, 2)

            # Histogram of raw pixel values over a fixed value range
            raw_hist = HistogramPlot(
                raw_data=raw_plot.data,
                title="raw value counts",
                bins=50,
                v_min=-5000,
                v_max=70000,
            )
            p.assign_plot(raw_hist, 0, 2)

            # Fit residuals of lines vs order number and vs position on detector
            resorder_plot = ScatterPlot(
                title="Fit residuals",
                y_label="Fit res Lines",
                x_label="ORDER",
                markersize=rplot_size,
                legend=True,
            )
            respos_plot = ScatterPlot(
                title="Fit residuals",
                y_label="Fit Res Lines",
                x_label="X pos",
                markersize=rplot_size,
                legend=True,
            )
            for i, table in enumerate(line_table_res):
                resorder_plot.add_data(
                    (table.data["Order"], table.data["Residual"]),
                    label=f"sky{i+1}",
                    color=res_colours[i],
                )
                respos_plot.add_data(
                    (table.data["X"], table.data["Residual"]),
                    label=f"sky{i+1}",
                    color=res_colours[i],
                )
            resorder_plot.add_data(
                (selplot["Order"], selplot["Residual"]), label="obj", color="red"
            )
            respos_plot.add_data(
                (selplot["X"], selplot["Residual"]), label="obj", color="red"
            )
            p.assign_plot(resorder_plot, 3, 1)
            p.assign_plot(respos_plot, 3, 2)

            setup = self._get_setup_string(hdr)
            panels[p] = {
                "ext": ext,
                "report_name": f"uves_{procatg.lower()}_{setup}_{str(ext).lower()}",
                "report_description": f"UVES wavelength calibration panel - {ext}",
                "report_tags": [],
            }
        return panels


# Module-level report instance. NOTE(review): presumably the adari_core report
# runner discovers the report through this name — confirm before renaming.
rep = UvesWavelengthCalibReport()
