import os
from adari_core.data_libs.master_orderdef import MasterOrderdefReport
from adari_core.plots.cut import CutPlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default
from .xshooter_utils import strip_prescan_overscan
from .xshooter_utils import XshooterSetupInfo, XshooterReportMixin


class XshooterOrderdefFlatFieldReport(XshooterReportMixin, MasterOrderdefReport):
    """Order-definition (flat-field) report for the XSHOOTER UVB, VIS and NIR arms.

    Builds on ``MasterOrderdefReport`` with XSHOOTER-specific SOF parsing,
    prescan/overscan removal, and extra panel content: master-image cuts,
    scatter plots from the order-position residual table, and a block of
    instrument setup information.
    """

    def __init__(self):
        super().__init__("xshooter_orderdef")

    def remove_raw_scan(self, im_hdu, **kwargs):
        """Return ``im_hdu`` with its prescan/overscan regions stripped.

        Thin wrapper around ``strip_prescan_overscan``; ``kwargs`` are passed
        through unchanged.
        """
        return strip_prescan_overscan(im_hdu, **kwargs)

    def parse_sof(self):
        """Group the SOF inputs into one file set per XSHOOTER arm.

        Scans ``self.inputs`` (pairs of ``(filename, category)``) and, for each
        of the UVB, VIS and NIR arms, records the order-definition frame, the
        ``ORDERDEF_ON_*`` frame and the order-position residual table. If a
        category appears more than once, the last occurrence wins; unknown
        categories are ignored.

        Returns
        -------
        list of dict
            One dict per arm whose order-definition frame was found, in
            UVB, VIS, NIR order, with keys ``"order_def"``, ``"order_table"``
            and ``"order_resid"``. The latter two are ``None`` when their
            category is absent from the SOF.
        """
        # SOF category -> (arm, key in the returned dict).
        category_slots = {
            "ORDERDEF_D2_UVB": ("UVB", "order_def"),
            "ORDERDEF_ON_UVB": ("UVB", "order_table"),
            "ORDERPOS_RESID_TAB_UVB": ("UVB", "order_resid"),
            "ORDERDEF_VIS": ("VIS", "order_def"),
            "ORDERDEF_ON_VIS": ("VIS", "order_table"),
            "ORDERPOS_RESID_TAB_VIS": ("VIS", "order_resid"),
            "ORDERDEF_NIR_ON": ("NIR", "order_def"),
            "ORDERDEF_ON_NIR": ("NIR", "order_table"),
            "ORDERPOS_RESID_TAB_NIR": ("NIR", "order_resid"),
        }
        arms = ("UVB", "VIS", "NIR")
        roles = ("order_def", "order_table", "order_resid")
        found = {arm: dict.fromkeys(roles) for arm in arms}

        for filename, catg in self.inputs:
            slot = category_slots.get(catg)
            if slot is not None:
                arm, role = slot
                found[arm][role] = filename

        # An arm only yields a file set when its order-definition frame exists.
        return [found[arm] for arm in arms if found[arm]["order_def"] is not None]

    def generate_panels(self, **kwargs):
        """Build the order-definition panels and add XSHOOTER-specific plots.

        The base layout comes from ``MasterOrderdefReport.generate_panels``,
        using the PRIMARY extension for both master and raw images and
        x-direction image cuts. Each panel then receives:

        * cuts through the master image plots at (0, 1) and (0, 2),
          placed at (1, 1) and (1, 2);
        * scatter plots built from extension 1 of the residual table:
          residual vs X at (2, 2), residual vs Y at (3, 2), Y vs X at (3, 1);
        * a setup-information text block at (2, 0).

        Parameters
        ----------
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        The panel -> description mapping returned by the base class. Each
        description gains a ``"report_description"`` entry.
        """
        ext = "PRIMARY"
        marker_size = 0.2
        text_v_space = 0.3

        def black_scatter(title, x_label, y_label, x_data, y_data, label):
            """Return a legend-less black scatter plot of ``y_data`` vs ``x_data``."""
            plot = ScatterPlot(
                title=title,
                x_label=x_label,
                y_label=y_label,
                markersize=marker_size,
                legend=False,
            )
            plot.add_data((x_data, y_data), label=label, color="black")
            return plot

        new_panels = super().generate_panels(
            master_im_ext=ext, raw_im_ext=ext, im_cut_direction="x"
        )

        for panel, panel_descr in new_panels.items():
            # Index of this panel's file set in self.hdus; every lookup below
            # must use it so that each arm's panel shows that arm's data.
            i = panel_descr["hdus_i"]
            arm_hdus = self.hdus[i]
            panel_descr["report_description"] = (
                f"XSHOOTER orderdef panel - "
                f"{os.path.basename(panel_descr['order_table'])}, "
                f"{os.path.basename(panel_descr['order_def'])}, "
                f"{panel_descr['ext']}"
            )
            master_hdul = arm_hdus["order_table"]
            resid = arm_hdus["order_resid"][1].data
            bunit = fetch_kw_or_default(master_hdul[ext], "BUNIT", "ADU")

            # Master image plots from the base layout (full and central region).
            master_plot = panel.retrieve(0, 1)
            master_plot_cent = panel.retrieve(0, 2)

            # Cut position: data coordinate of the middle row of the master image.
            cut_pos = master_plot.get_data_coord(master_plot.data.shape[0] // 2, "y")

            master_cut = CutPlot(
                "y",
                title=f"Master column @ Y={cut_pos}",
                y_label=bunit,
            )
            master_cut.add_data(master_plot, cut_pos, color="red", label="master")
            panel.assign_plot(master_cut, 1, 1)

            master_cent_cut = CutPlot(
                "y",
                title=f"Central region master column @ Y={cut_pos}",
                y_label=bunit,
            )
            master_cent_cut.add_data(
                master_plot_cent, cut_pos, color="red", label="master"
            )
            panel.assign_plot(master_cent_cut, 1, 2)

            # Residuals of the order-position fit against X.
            panel.assign_plot(
                black_scatter(
                    "Residuals of fit (X)", "X", "Residual",
                    resid["X"], resid["RESX"], "X",
                ),
                2,
                2,
            )
            # Residuals against Y. NOTE(review): this plots the RESX column;
            # presumably the residual table only carries X residuals (orders
            # traced as X(Y)) -- confirm there is no RESY column to use here.
            panel.assign_plot(
                black_scatter(
                    "Residuals of fit (Y)", "Y", "Residual",
                    resid["Y"], resid["RESX"], "Y",
                ),
                3,
                2,
            )
            # Positions of the fitted points, Y vs X.
            panel.assign_plot(
                black_scatter(
                    "Y vs. X", "X", "Y",
                    resid["X"], resid["Y"], "Y",
                ),
                3,
                1,
            )

            # Setup information for this panel's own file set (not always the
            # first one), placed at column 2, row 0.
            setup_text = TextPlot(columns=1, v_space=text_v_space, xext=1)
            setup_text.add_data(
                XshooterSetupInfo.order_definition(arm_hdus["order_def"])
            )
            panel.assign_plot(setup_text, 2, 0, xext=1)

        return new_panels


# Module-level report instance. NOTE(review): presumably the ADARI report
# framework discovers the report through this name ``rep`` -- confirm before
# renaming or removing it.
rep = XshooterOrderdefFlatFieldReport()
