from adari_core.utils.esodet import strip_prescan_overscan_espresso
import numpy as np
from adari_core.plots.points import ScatterPlot
from adari_core.plots.cut import CutPlot
from adari_core.utils.utils import fetch_kw_or_default

from adari_core.data_libs.master_orderdef import MasterOrderdefReport
import os

from .espresso_utils import EspressoReportMixin


class EspressoOrderdefFlatFieldReport(EspressoReportMixin, MasterOrderdefReport):
    """Order-definition report for ESPRESSO.

    Extends the panels produced by ``MasterOrderdefReport.generate_panels``
    for the ``CCD290red`` and ``CCD290blue`` extensions: the raw histogram is
    moved, the raw cuts are re-positioned, cuts through the master image are
    added, and a scatter plot of the per-extension statistics read from the
    order-table header is added.
    """

    def __init__(self):
        """Initialise the parent report with the name ``"espresso_orderdef"``."""
        super().__init__("espresso_orderdef")

    @staticmethod
    def _qc_ext_label(ext):
        """Return the header label (``"EXT0"``/``"EXT1"``) used for *ext*.

        Any extension name containing ``"blue"`` maps to ``"EXT0"``; every
        other name maps to ``"EXT1"``. Kept in one place so the statistics
        lookup and the plot title cannot disagree.
        """
        return "EXT0" if "blue" in ext else "EXT1"

    def remove_raw_scan(self, im_hdu, **kwargs):
        """Return ``strip_prescan_overscan_espresso(im_hdu, **kwargs)``."""
        return strip_prescan_overscan_espresso(im_hdu, **kwargs)

    def parse_sof(self):
        """Group the input files into one entry per complete A/B pair.

        Scans ``self.inputs`` (an iterable of ``(filename, category)`` pairs)
        for the categories ``ORDER_TABLE_A``/``ORDERDEF_A`` and
        ``ORDER_TABLE_B``/``ORDERDEF_B``. If a category occurs more than once,
        the last filename wins.

        Returns:
            A list of dicts with keys ``"order_table"`` and ``"order_def"``.
            The "A" entry (if complete) precedes the "B" entry (if complete);
            a pair with either file missing is left out.
        """
        wanted = ("ORDER_TABLE_A", "ORDERDEF_A", "ORDER_TABLE_B", "ORDERDEF_B")
        found = {}
        for filename, catg in self.inputs:
            if catg in wanted:
                found[catg] = filename

        file_lists = []
        for suffix in ("A", "B"):
            order_table = found.get(f"ORDER_TABLE_{suffix}")
            order_def = found.get(f"ORDERDEF_{suffix}")
            # Only a complete table + raw pair yields a report set.
            if order_table is not None and order_def is not None:
                file_lists.append(
                    {
                        "order_table": order_table,
                        "order_def": order_def,
                    }
                )
        return file_lists

    def get_orderdef_stats(self, hdulist, ext):
        """Collect the STDEV/MIN/MAX values for *ext* from the first header.

        Args:
            hdulist: Sequence whose first element exposes ``.header``
                (presumably an astropy ``HDUList`` - confirm against caller).
            ext: Extension name; selects ``EXT0`` when it contains ``"blue"``
                and ``EXT1`` otherwise.

        Returns:
            A dict keyed ``"STDEV"``, ``"MIN"``, ``"MAX"`` (in that order).
            Each value is the list of header values for every key matched by
            the wildcard pattern ``*<EXTn> RES <stat>*``, in the order the
            header lookup yields them.
        """
        header = hdulist[0].header
        ext_to_use = self._qc_ext_label(ext)
        return {
            stat: [header[key] for key in header[f"*{ext_to_use} RES {stat}*"]]
            for stat in ("STDEV", "MIN", "MAX")
        }

    @staticmethod
    def _reposition_raw_plots(panel):
        """Move the raw histogram and re-aim both raw cuts on *panel*."""
        # Shift the raw histogram to make room for the master cuts
        panel.assign_plot(panel.pop(2, 1), 3, 1)

        # Change the cut position in the raws to 1/4 of the way across,
        # not the default 1/2
        titled_cuts = (
            (panel.retrieve(1, 1), "raw row @ Y={}"),
            (panel.retrieve(1, 2), "Central region raw row @ Y={}"),
        )
        for cut, _ in titled_cuts:
            for datadict in cut.data.values():
                image = datadict["data"]
                datadict["cut_pos"] = image.get_data_coord(
                    image.data.shape[0] // 4, "y"
                )
        # Titles are set only after every cut has been updated; each title
        # reports the position stored on the cut's first data entry.
        for cut, title in titled_cuts:
            first_label = next(iter(cut.data))
            cut.title = title.format(cut.data[first_label]["cut_pos"])

    @staticmethod
    def _add_master_cuts(panel, master_hdu):
        """Add cuts through the full and central master images to *panel*.

        Args:
            panel: Panel holding the master image plots at (0, 1) and (0, 2).
            master_hdu: Object handed to ``fetch_kw_or_default`` to read
                ``BUNIT`` for the y-axis label (falls back to ``"ADU"``).
        """
        master_plot = panel.retrieve(0, 1)
        master_plot_cent = panel.retrieve(0, 2)

        # Calculate cut position (x cut for ESPRESSO); the position derived
        # from the full master image is reused for the central-region cut.
        cut_pos = master_plot.get_data_coord(master_plot.data.shape[1] // 2, "x")
        y_label = fetch_kw_or_default(master_hdu, "BUNIT", "ADU")

        placements = (
            (master_plot, f"Master column @ X={cut_pos}", 1),
            (master_plot_cent, f"Central region master column @ X={cut_pos}", 2),
        )
        for image_plot, title, pos in placements:
            cut = CutPlot("x", title=title, y_label=y_label)
            cut.add_data(image_plot, cut_pos, color="red", label="master")
            panel.assign_plot(cut, 2, pos)

    def _add_ext_stats_plot(self, panel, order_table, ext):
        """Add the scatter plot of the EXTn statistics to *panel* at (3, 2)."""
        ext_stats_plot = ScatterPlot(
            title=f"QC ORDER {self._qc_ext_label(ext)} Statistics",
            x_label="Order No",
            y_label="pixel",
            markersize=3,
        )
        ext_stats = self.get_orderdef_stats(order_table, ext)
        # One colour per statistic, paired in the dict's iteration order;
        # values are plotted against their 1-based position in the list.
        for color, (stat, values) in zip(
            ("black", "red", "blue"), ext_stats.items()
        ):
            ext_stats_plot.add_data(
                (np.arange(1, len(values) + 1), values),
                label=stat,
                color=color,
            )
        panel.assign_plot(ext_stats_plot, 3, 2)

    def generate_panels(self, **kwargs):
        """Build the report panels for both detector extensions.

        For each of ``CCD290red`` and ``CCD290blue`` the parent's
        ``generate_panels`` is called with that extension as both master and
        raw image extension and ``im_cut_direction="y"``; every returned panel
        is then given a description and the ESPRESSO-specific plots.

        Args:
            **kwargs: Accepted for interface compatibility; not forwarded to
                the parent implementation.

        Returns:
            A dict merging the parent's panel -> description mappings for both
            extensions.
        """
        panels = {}
        for ext in ("CCD290red", "CCD290blue"):
            new_panels = super().generate_panels(
                master_im_ext=ext, raw_im_ext=ext, im_cut_direction="y"
            )
            for panel, panel_descr in new_panels.items():
                order_table = self.hdus[panel_descr["hdus_i"]]["order_table"]
                panel_descr["report_description"] = (
                    f"ESPRESSO orderdef panel - "
                    f"{os.path.basename(panel_descr['order_table'])}, "
                    f"{os.path.basename(panel_descr['order_def'])}, "
                    f"{panel_descr['ext']}"
                )

                self._reposition_raw_plots(panel)
                # Generate & display the master cuts
                self._add_master_cuts(panel, order_table[ext])
                # Display the EXTn statistics
                self._add_ext_stats_plot(panel, order_table, ext)

            panels.update(new_panels)

        return panels


# Module-level report instance, created when this module is imported.
# NOTE(review): presumably looked up by the name ``rep`` by whatever loads
# this module - confirm against the report runner.
rep = EspressoOrderdefFlatFieldReport()
