import os
from adari_core.data_libs.master_interfer_calib import MasterInterferCalibReport
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.plots.text import TextPlot

from .matisse_utils import MatisseSetupInfo, MatisseReportMixin


class MatisseCalibratorReport(MatisseReportMixin, MasterInterferCalibReport):
    """Report registered under the name ``"matisse_calibrator"``.

    Builds one group of panels per table extension (``OI_VIS2``, ``OI_T3``,
    ``OI_VIS`` and, when present, ``OI_FLUX`` and ``TF2``) by delegating the
    plotting to the parent ``generate_panels`` and then adding two text
    blocks with header metadata to every returned panel.
    """

    def __init__(self):
        super().__init__("matisse_calibrator")

    def parse_sof(self):
        """Select the input files this report works on.

        Returns:
            A list with one ``{"oi_data": filename}`` dict for every
            ``(filename, category)`` pair in ``self.inputs`` whose category
            is ``"CALIB_RAW_INT"``, in the order given by
            ``sorted(self.inputs)``.
        """
        return [
            {"oi_data": filename}
            for filename, catg in sorted(self.inputs)
            if catg == "CALIB_RAW_INT"
        ]

    def find_ext_index(self, hdu, ext):
        """Return the indices of all extensions of ``hdu`` named ``ext``.

        Args:
            hdu: Object whose ``info(output=False)`` returns a sequence of
                entries with the extension index at position 0 and the
                extension name at position 1 (presumably an astropy
                ``HDUList`` - confirm against the caller).
            ext: Extension name to look for (exact match).

        Returns:
            List of matching indices; empty when no extension matches.
        """
        return [entry[0] for entry in hdu.info(output=False) if entry[1] == ext]

    @staticmethod
    def _calibrator_header_text(hdul, ext):
        """Build the header-summary text lines shown on a panel for ``ext``.

        All keywords are read from ``hdul[0]``; keywords fetched through
        ``fetch_kw_or_default`` fall back to ``"N/A"``.
        """
        primary = hdul[0]
        return (
            "INSTRUME: " + str(primary.header.get("INSTRUME")),
            "EXTNAME: " + ext,
            "PRO.CATG: "
            + str(
                fetch_kw_or_default(
                    primary,
                    "HIERARCH ESO PRO CATG",
                    default="N/A",
                )
            ),
            "FILE NAME: " + os.path.basename(hdul.filename()),
            "RAW1 NAME: "
            + str(
                fetch_kw_or_default(
                    primary,
                    "HIERARCH ESO PRO REC1 RAW1 NAME",
                    default="N/A",
                )
            ),
        )

    @staticmethod
    def _attach_metadata_text(panels, header_text, setup_text, v_space):
        """Add the two metadata text plots to every panel in ``panels``.

        The i-th panel (in iteration order) receives ``header_text[i]`` at
        position (0, 0) and ``setup_text[i]`` at position (2, 0).
        """
        # NOTE(review): indexing by position relies on the parent returning
        # panels in the same order as self.hdus, and no more panels than
        # there are inputs - confirm against MasterInterferCalibReport.
        for i, panel in enumerate(panels):
            t0 = TextPlot(columns=1, v_space=v_space)
            t0.add_data(header_text[i])
            panel.assign_plot(t0, 0, 0, xext=1)

            t1 = TextPlot(columns=1, v_space=v_space, xext=1)
            t1.add_data(setup_text[i])
            panel.assign_plot(t1, 2, 0, xext=1)

    def generate_panels(self, **kwargs):
        """Create the panels for all supported extensions.

        The extension indices and the correlated-flux switch are read from
        the first input only (``self.hdus[0]["oi_data"]``).

        Args:
            **kwargs: Forwarded unchanged to every call of the parent
                ``generate_panels``.

        Returns:
            A dict merging the panels returned by the parent for OI_VIS2,
            OI_T3, OI_VIS, OI_FLUX and TF2 (in that order). OI_FLUX and TF2
            are skipped when the first input has no such extension.
        """
        extensions = ("OI_VIS2", "OI_T3", "OI_VIS", "OI_FLUX", "TF2")
        vspace = 0.4

        # Metadata text keyed by extension name, one entry per input.
        header_text = {}
        setup_text = {}
        for ext in extensions:
            header_text[ext] = []
            setup_text[ext] = []
            for hdu in self.hdus:
                hdul = hdu["oi_data"]
                header_text[ext].append(self._calibrator_header_text(hdul, ext))
                setup_text[ext].append(MatisseSetupInfo.calib(hdul))

        first_hdul = self.hdus[0]["oi_data"]
        first_primary = first_hdul[0]

        # The OI_VIS2 (and TF2) axis is labelled/limited differently when the
        # product header records the recipe parameter corrFlux = true.
        corr_flux = (
            fetch_kw_or_default(first_primary, "HIERARCH ESO PRO REC3 PARAM1 NAME")
            == "corrFlux"
            and fetch_kw_or_default(
                first_primary, "HIERARCH ESO PRO REC3 PARAM1 VALUE"
            )
            == "true"
        )
        if corr_flux:
            vis_label = "Correlated flux"
            vis_lim = "clip0"
        else:
            vis_label = "Visibility"
            vis_lim = "lim2"

        # (extension, optional?, keyword arguments for the parent call)
        sections = (
            # 1. Visibility / Correlated flux
            (
                "OI_VIS2",
                False,
                {
                    "y_label": vis_label,
                    "y_column": "VIS2DATA",
                    "y_err": "VIS2ERR",
                    "y_lim": vis_lim,
                },
            ),
            # 2. Closure phase
            (
                "OI_T3",
                False,
                {
                    "n_panels": 4,
                    "y_label": "T3PHI",
                    "y_column": "T3PHI",
                    "y_err": "T3PHIERR",
                    "y_lim": "xclip",
                },
            ),
            # 3. Differential phase
            (
                "OI_VIS",
                False,
                {
                    "y_label": "VISPHI",
                    "y_column": "VISPHI",
                    "y_err": "VISPHIERR",
                    "y_lim": "xclip",
                },
            ),
            # 4. Flux (only if OI_FLUX exists)
            (
                "OI_FLUX",
                True,
                {
                    "n_panels": 4,
                    "y_label": "Flux",
                    "y_column": "FLUXDATA",
                    "y_err": "FLUXERR",
                    "y_lim": "clip0",
                },
            ),
            # 5. Transfer function (only if TF2 exists)
            (
                "TF2",
                True,
                {
                    "y_label": "TF2",
                    "y_column": "TF2",
                    "y_err": "TF2ERR",
                    "y_lim": vis_lim,
                },
            ),
        )

        panels = {}
        for ext, optional, plot_args in sections:
            ext_indices = self.find_ext_index(first_hdul, ext)
            if optional and not ext_indices:
                continue
            ext_panels = super().generate_panels(
                oi_ext_list=ext_indices,
                legend_from=True,
                **plot_args,
                **kwargs
            )
            self._attach_metadata_text(
                ext_panels, header_text[ext], setup_text[ext], vspace
            )
            panels.update(ext_panels)

        return panels


# Module-level report instance, created at import time.
# NOTE(review): presumably looked up by the ADARI framework under the name
# `rep` - confirm before renaming or removing.
rep = MatisseCalibratorReport()
