from adari_core.plots.text import TextPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.panel import Panel

from adari_core.data_libs.echelle_flatfield import MasterEchelleFlatfieldReport

import os

from . import UvesReportMixin


class UvesFibreFlatFieldReport(UvesReportMixin, MasterEchelleFlatfieldReport):
    """Report class for the UVES-Flames fibre flat recipes.

    Description
    -----------
    This class produces the QC reports for flat fields taken in UVES Fibre
    mode. The recipes involved are flames_cal_prep_sff_ofpos and
    flames_cal_mkmaster; each has an independent report. ``parse_sof``
    decides which of the two is built from the categories found in the
    inputs and records the choice in ``sflat`` / ``fibflat``.
    """

    # One raw extension per entry returned by ``parse_sof`` (same order);
    # filled in by ``parse_sof_sflat`` / ``parse_sof_fibflat``.
    raw_extension_default = None
    # Mode flags set by ``parse_sof``; exactly one is True afterwards.
    sflat = False
    fibflat = False

    # Colours of the overplotted slit-flat cuts. The length of this tuple is
    # also the maximum number of master slit flats handled per file set.
    _sflat_colors = ("red", "green", "blue")

    def __init__(self):
        super().__init__("uves_flames_fibre_flatfield")
        # Passed as ``extent`` to every CentralImagePlot built by this report.
        self.center_size = 200

    def parse_sof(self):
        """Select the report mode from the input categories and parse them.

        Returns
        -------
        list of dict
            The file sets produced by ``parse_sof_sflat`` when an
            ``SFLAT_RED`` input is present, otherwise those produced by
            ``parse_sof_fibflat``.
        """
        categories = {catg for _, catg in self.inputs}
        # Set both flags explicitly so a repeated call can never leave the
        # instance with both modes enabled.
        self.sflat = "SFLAT_RED" in categories
        self.fibflat = not self.sflat
        if self.sflat:
            return self.parse_sof_sflat()
        return self.parse_sof_fibflat()

    def parse_sof_sflat(self):
        """Build the file sets for the slit-flat (mkmaster) report.

        One file set is created per master category (``MASTER_SFLAT_REDL``
        first, then ``MASTER_SFLAT_REDU``) that has at least one master file,
        provided a raw ``SFLAT_RED`` file exists. The first raw file is
        shared by both sets. ``raw_extension_default`` receives 2 for the
        REDL set and 1 for the REDU set.

        Returns
        -------
        list of dict
            Each dict maps ``"raw"`` and ``"master_sflat_1"`` ..
            ``"master_sflat_N"`` (N <= 3, in input order) to file names.
        """
        master_s_rl = []
        master_s_ru = []
        first_raw_flat = None
        for filename, catg in self.inputs:
            if catg == "MASTER_SFLAT_REDL":
                master_s_rl.append(filename)
            elif catg == "MASTER_SFLAT_REDU":
                master_s_ru.append(filename)
            elif catg == "SFLAT_RED" and first_raw_flat is None:
                first_raw_flat = filename

        file_lists = []
        self.raw_extension_default = []
        max_masters = len(self._sflat_colors)

        for masters, raw_ext in ((master_s_rl, 2), (master_s_ru, 1)):
            if not masters or first_raw_flat is None:
                continue
            files = {"raw": first_raw_flat}
            # Register only the masters that exist (capped at max_masters)
            # instead of indexing a fixed range, which raised IndexError
            # when fewer than three were supplied.
            for idx, name in enumerate(masters[:max_masters], start=1):
                files[f"master_sflat_{idx}"] = name
            file_lists.append(files)
            self.raw_extension_default.append(raw_ext)
        return file_lists

    def parse_sof_fibflat(self):
        """Build the file sets for the fibre-flat (prep_sff_ofpos) report.

        A file set is created for REDL and/or REDU when the order table
        (``FIB_ORDEF_TABLE_*``), the master product (``FIB_FF_DTC_*``) and a
        raw ``FIB_FF_ALL_RED`` file are all present. When a category occurs
        several times the last file wins; for the raw category the first
        file is used. ``raw_extension_default`` receives 2 for REDL and 1
        for REDU.

        Returns
        -------
        list of dict
            Each dict maps ``"order_table"``, ``"master_product"`` and
            ``"raw"`` to file names.
        """
        product_catgs = (
            "FIB_ORDEF_TABLE_REDL",
            "FIB_ORDEF_TABLE_REDU",
            "FIB_FF_DTC_REDL",
            "FIB_FF_DTC_REDU",
        )
        by_catg = {}
        first_raw_flat = None
        for filename, catg in self.inputs:
            if catg in product_catgs:
                by_catg[catg] = filename
            elif catg == "FIB_FF_ALL_RED" and first_raw_flat is None:
                first_raw_flat = filename

        file_lists = []
        self.raw_extension_default = []

        for arm, raw_ext in (("REDL", 2), ("REDU", 1)):
            order_table = by_catg.get(f"FIB_ORDEF_TABLE_{arm}")
            master = by_catg.get(f"FIB_FF_DTC_{arm}")
            # The master is required too: generate_panels reads its data,
            # so a set without it could never be rendered.
            if order_table is None or master is None or first_raw_flat is None:
                continue
            file_lists.append(
                {
                    "order_table": order_table,
                    "master_product": master,
                    "raw": first_raw_flat,
                }
            )
            self.raw_extension_default.append(raw_ext)
        return file_lists

    def echelle_basic_panels(self, master, master_ext, raw, raw_ext, master_data=None):
        """Produce the basic plot requirements for echelle flat fields.

        Parameters
        ----------
        master
            Opened master product; its ``PRIMARY`` header supplies the
            PRO CATG used in titles and ``master[master_ext]`` supplies the
            ``BUNIT`` label.
        master_ext
            Extension of ``master`` holding the image.
        raw
            Opened raw file; ``EXPTIME`` is read from its ``PRIMARY`` header.
        raw_ext
            Extension of ``raw`` whose data feed the raw histogram.
        master_data : optional
            Image to plot instead of ``master[master_ext].data``. Lets a
            caller show a single plane of a multi-plane product without
            modifying the HDU.

        Returns
        -------
        Panel
            A 5x3 panel with the master image, its central region, a cut
            through each, and histograms of the raw and master values.
        """
        master_procatg = master["PRIMARY"].header.get("HIERARCH ESO PRO CATG")
        if master_data is None:
            master_data = master[master_ext].data

        p = Panel(5, 3, height_ratios=[1, 4, 4])

        # Master Plots
        # TODO overplot order solution
        master_plot = ImagePlot(master_data, title=master_procatg)
        p.assign_plot(master_plot, 0, 1)

        # Get central image
        # TODO overplot order solution
        master_center = CentralImagePlot(
            master_plot, title=f"{master_procatg} center", extent=self.center_size
        )
        p.assign_plot(master_center, 0, 2)

        # Plot cut of master
        bunit = master[master_ext].header.get("BUNIT", "Unknown BUNIT")
        raw_exptime = raw["PRIMARY"].header.get("EXPTIME", None)

        cutpos = master_plot.get_data_coord(master_plot.data.shape[1] // 2, "x")
        master_cut = CutPlot(
            "x", title="master row @ X {}".format(cutpos), y_label=bunit
        )
        master_cut.add_data(master_plot, cutpos, color="red", label="master")
        p.assign_plot(master_cut, 1, 1)

        # center cut plot
        cutpos = master_center.get_data_coord(master_center.data.shape[1] // 2, "x")
        master_cen_cut = CutPlot(
            "x", title="Central Region: row @ X {}".format(cutpos), y_label=bunit
        )
        master_cen_cut.add_data(master_center, cutpos, label="master", color="red")
        p.assign_plot(master_cen_cut, 1, 2)

        raw_hist = HistogramPlot(
            raw_data=raw[raw_ext].data, title="raw value counts", x_label="counts"
        )
        raw_hist.v_min, raw_hist.v_max = -5e3, 7e4
        p.assign_plot(raw_hist, 2, 1)

        master_hist = HistogramPlot(
            master_data=master_plot.data,
            title="master value counts",
            bins=50,
            x_label="normalized counts",
        )
        # The master limits are the raw limits divided by the exposure time.
        # Without a usable EXPTIME (missing, non-numeric, zero or negative)
        # the division is skipped and the plot keeps its own limits.
        try:
            exptime = float(raw_exptime)
        except (TypeError, ValueError):
            exptime = 0.0
        if exptime > 0:
            master_hist.v_min, master_hist.v_max = -5e3 / exptime, 7e4 / exptime
        p.assign_plot(master_hist, 2, 2)

        return p

    def _add_fibflat_plots(self, p, hdr, second_plane):
        """Add the fibre transmission scatter and the second-plane cut to ``p``."""
        p.x = 4
        fibres = list(range(1, 10))
        # A missing QC keyword raises KeyError here.
        fibre_rel_trans = [hdr[f"ESO QC FIB{i} RELTRANS"] for i in fibres]
        scatter_p = ScatterPlot()
        scatter_p.add_data(
            [fibres, fibre_rel_trans],
            label="transmission",
            color="black",
        )
        scatter_p.x_label = "Fibre number"
        scatter_p.y_label = "Relative Transmission"
        scatter_p.legend = None
        p.assign_plot(scatter_p, 3, 1, yext=1, xext=1)

        # Overplot the central cut of the second master plane on the
        # central-region cut already placed at (1, 2).
        cutplot = p.retrieve(1, 2)
        master_second = ImagePlot(second_plane, title="")
        master_center_second = CentralImagePlot(
            master_second, title="", extent=self.center_size
        )
        cutpos = master_center_second.get_data_coord(
            master_center_second.data.shape[1] // 2, "x"
        )
        cutplot.add_data(
            master_center_second, cutpos, label="master_2", color="blue"
        )

    def _add_sflat_cuts(self, p, file_dict, master_procatg):
        """Add cuts through every available master slit flat to ``p``."""
        # NOTE(review): the titles quote a fixed position, while each cut
        # below is taken at the centre column of its own image - confirm
        # that 2048 is right for all detector setups.
        title_cutpos = 2048
        sflat_cut = CutPlot(
            "x",
            title="SFLATs row @ X {}".format(title_cutpos),
            y_label="normalized counts",
        )
        sflat_cut_center = CutPlot(
            "x",
            title="Central Region SFLATs row @ X {}".format(title_cutpos),
            y_label="normalized counts",
        )
        for s, color in enumerate(self._sflat_colors, start=1):
            key = f"master_sflat_{s}"
            # Masters are numbered consecutively by parse_sof_sflat, so the
            # first missing key marks the end of the available ones.
            if key not in file_dict:
                break
            sflat_image = ImagePlot(
                file_dict[key]["PRIMARY"].data,
                title=f"{master_procatg}_{s}",
            )
            cutpos = sflat_image.get_data_coord(sflat_image.data.shape[1] // 2, "x")
            sflat_image_center = CentralImagePlot(
                sflat_image,
                title=f"{master_procatg} center",
                extent=self.center_size,
            )
            cutpos_center = sflat_image_center.get_data_coord(
                sflat_image_center.data.shape[1] // 2, "x"
            )
            sflat_cut.add_data(sflat_image, cutpos, color=color, label=f"sflat_{s}")
            sflat_cut_center.add_data(
                sflat_image_center,
                cutpos_center,
                color=color,
                label=f"sflat_{s}",
            )
        p.assign_plot(sflat_cut, 3, 1, yext=1, xext=2)
        p.assign_plot(sflat_cut_center, 3, 2, yext=1, xext=2)

    def _setup_info(self, hdr):
        """Return ``(recipe, lines)`` for the instrument-setup text block."""
        if self.fibflat:
            return "prep_sff_ofpos", [
                "DET.BINX: " + str(hdr.get("HIERARCH ESO DET WIN1 BINX")),
                "DET.BINY: " + str(hdr.get("HIERARCH ESO DET WIN1 BINY")),
                "DET.READ.SPEED: " + str(hdr.get("HIERARCH ESO DET READ SPEED")),
                "INS.SLIT3.PLATE: " + str(hdr.get("HIERARCH ESO INS SLIT3 PLATE")),
                "INS.SLIT3.MODE: " + str(hdr.get("HIERARCH ESO INS SLIT3 MODE")),
                "INS.GRAT2.WLEN: " + str(hdr.get("HIERARCH ESO INS GRAT2 WLEN")),
            ]
        return "mkmaster", [
            "DET.WIN1.BINX: " + str(hdr.get("HIERARCH ESO DET WIN1 BINX")),
            "DET.WIN1.BINY: " + str(hdr.get("HIERARCH ESO DET WIN1 BINY")),
            "DET.READ.SPEED: " + str(hdr.get("HIERARCH ESO DET READ SPEED")),
            "INS.GRAT2.WLEN: " + str(hdr.get("HIERARCH ESO INS GRAT2 WLEN")),
        ]

    def generate_panels(self, **kwargs):
        """Build one report panel per parsed file set.

        Iterates ``raw_extension_default`` and ``self.hdus`` in lockstep. In
        slit-flat mode the first master slit flat drives the basic panels
        and cuts through all available masters are added. In fibre-flat mode
        plane 0 of the master product drives the basic panels, plane 1 is
        overplotted on the central cut, and the header of the order table
        supplies the metadata and the fibre transmission values.

        Returns
        -------
        dict
            Maps each Panel to its metadata dict (file names, extensions,
            report name, description and tags).
        """
        panels = {}
        master_product_ext = "PRIMARY"
        ext_kw = {1: "REDU", 2: "REDL"}

        for r_ext, file_dict in zip(self.raw_extension_default, self.hdus):
            master_data = None
            master_second_plane = None
            if self.sflat:
                # Select the first SFLAT
                master_product = file_dict["master_sflat_1"]
                raw = file_dict["raw"]
                hdr = master_product["PRIMARY"].header
                fname = master_product.filename()
            elif self.fibflat:
                master_product = file_dict["master_product"]
                # Plane 0 is handed to the basic panels explicitly rather
                # than by overwriting PRIMARY.data, so the loaded HDU keeps
                # both planes and master_product stays the opened file
                # (its filename() is reported below).
                master_cube = master_product["PRIMARY"].data
                master_data = master_cube[0]
                master_second_plane = master_cube[1]
                order_profile = file_dict["order_table"]
                hdr = order_profile["PRIMARY"].header
                fname = order_profile.filename()
                raw = file_dict["raw"]

            # Generic panels
            master_procatg = master_product["PRIMARY"].header.get(
                "HIERARCH ESO PRO CATG"
            )
            p = self.echelle_basic_panels(
                master_product,
                master_product_ext,
                raw,
                r_ext,
                master_data=master_data,
            )

            if self.fibflat:
                self._add_fibflat_plots(p, hdr, master_second_plane)
            elif self.sflat:
                self._add_sflat_cuts(p, file_dict, master_procatg)

            # Metadata and setup info
            vspace = 0.3
            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                str(hdr.get("INSTRUME")),
                "PRO CATG: " + str(hdr.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + os.path.basename(fname),
                "RAW1 NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
            )
            t1.add_data(col1)
            p.assign_plot(t1, 0, 0, xext=1)

            recipe, col2 = self._setup_info(hdr)
            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            t2.add_data(col2)
            p.assign_plot(t2, 2, 0, xext=1)

            panels[p] = {
                "master_product": master_product.filename(),
                "master_product_ext": master_product_ext,
                "raw": raw.filename(),
                "raw_ext": r_ext,
                "report_name": f"uves_fibre_flatfield_{master_product_ext}_{ext_kw[r_ext]}_{recipe}",
                "report_description": f"UVES-Flames fibre flat field - {master_procatg}, "
                f"{master_product_ext}",
                "report_tags": [],
            }
        return panels


# Module-level report instance, created on import.
# NOTE(review): presumably looked up by the report framework under the name
# ``rep`` - confirm before renaming or removing.
rep = UvesFibreFlatFieldReport()
