# SPDX-License-Identifier: BSD-3-Clause
import os
import numpy as np
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.points import LinePlot
from adari_core.plots.combined import CombinedPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default


class MasterInterferCalibReport(AdariReportBase):
    """Template report plotting interferometric OI table columns vs wavelength.

    One ``Panel`` is produced per (input file, OI data extension) pair. Each
    panel stacks ``n_panels`` plots, one per table row, of a chosen column
    (``VIS2DATA`` by default) against the extension's ``EFF_WAVE``.

    Child reports must implement :meth:`parse_sof` and populate ``self.hdus``
    with dicts whose ``"oi_data"`` entry is an opened HDU list (presumably an
    ``astropy.io.fits.HDUList`` -- confirm against the child classes).
    """

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        """Abstract hook: child reports must override this.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "MasterInterferCalibReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def get_labels(self, sta, ib, hdul=None):
        """Build the legend label for one table row.

        Args:
            sta: ``OI_ARRAY`` table data. Used only when *hdul* is None, to map
                ``STA_INDEX`` values to ``TEL_NAME`` / ``STA_NAME``.
            ib: a single station index or an array of station indices (one
                row's ``STA_INDEX`` entry).
            hdul: if given, names come from the primary header keywords
                ``HIERARCH ESO ISS CONF T<i>NAME`` (default ``"<i>"``) and
                ``HIERARCH ESO ISS CONF STATION<i>`` (default ``""``) instead
                of the station table.

        Returns:
            All telescope names followed by all station names, separated by
            single spaces.

        Raises:
            IndexError: in table mode, if an index is missing from ``sta``.
        """
        # Treat a scalar index and an array of indices the same way, so the
        # header branch also works for single-station rows.
        indices = np.atleast_1d(ib)

        if hdul is not None:
            header = hdul[0]
            tel_names = [
                str(
                    fetch_kw_or_default(
                        header,
                        f"HIERARCH ESO ISS CONF T{ii}NAME",
                        default=str(ii),
                    )
                )
                for ii in indices
            ]
            sta_names = [
                str(
                    fetch_kw_or_default(
                        header, f"HIERARCH ESO ISS CONF STATION{ii}", default=""
                    )
                )
                for ii in indices
            ]
        else:
            station_index = sta["STA_INDEX"]
            tel_names = [
                str(sta[station_index == ii]["TEL_NAME"][0]) for ii in indices
            ]
            sta_names = [
                str(sta[station_index == ii]["STA_NAME"][0]) for ii in indices
            ]

        return " ".join(tel_names + sta_names)

    def generate_panels(
        self,
        oi_ext_list=None,
        wave_ext_list=None,
        n_panels=6,
        y_label="Calibration",
        y_column="VIS2DATA",
        y_err=None,
        y_lim=None,
        legend_from="header",
        stretch=0.7,
        **kwargs,
    ):
        """Create one panel per (input file, OI data extension).

        Args:
            oi_ext_list: OI data extension key, or a list of keys. A non-list
                value is treated as a single extension. Required.
            wave_ext_list: wavelength extension key(s) paired with
                *oi_ext_list*. Defaults to ``["OI_WAVELENGTH"]``. A single
                entry (or a bare string) is reused for every data extension.
            n_panels: number of table rows plotted, one sub-plot each. The
                data table must have at least this many rows.
            y_label: y-axis label.
            y_column: table column plotted against wavelength.
            y_err: optional column holding the y error bars.
            y_lim: ``None`` or one of ``"lim"``, ``"lim2"``, ``"clip"``,
                ``"clip0"``, ``"xclip"`` (see :meth:`_resolve_ylim`).
            legend_from: ``"header"`` reads labels from the primary header.
                Any other value uses the ``OI_ARRAY`` station table.
            stretch: x/y stretch factor passed to ``Panel``.
            **kwargs: accepted and ignored.

        Returns:
            dict mapping each ``Panel`` to its metadata dict (``ext``,
            ``report_name``, ``report_description``, ``report_tags``,
            ``input_files``).

        Raises:
            ValueError: if *oi_ext_list* is missing, or the wavelength
                extensions cannot be paired with the data extensions.
        """
        # Explicit checks instead of ``assert``: they must survive ``python -O``.
        if oi_ext_list is None:
            raise ValueError("List of oi data extensions not provided")
        if not isinstance(oi_ext_list, list):
            # One extension, processed once per input file. (Replicating it
            # per file would emit duplicate panels for every file.)
            oi_ext_list = [oi_ext_list]

        if wave_ext_list is None:
            wave_ext_list = ["OI_WAVELENGTH"]
        elif isinstance(wave_ext_list, str):
            wave_ext_list = [wave_ext_list]
        else:
            wave_ext_list = list(wave_ext_list)

        if len(wave_ext_list) != len(oi_ext_list):
            # Broadcast a single wavelength extension over all data extensions.
            wave_ext_list = wave_ext_list * len(oi_ext_list)
        if len(wave_ext_list) != len(oi_ext_list):
            raise ValueError("List of oi wavelength extensions not provided")

        panels = {}
        for filedict in self.hdus:
            hdul = filedict["oi_data"]
            sta = hdul["OI_ARRAY"].data

            for i_ext, wave_ext in zip(oi_ext_list, wave_ext_list):
                # x axis is labelled "um"; EFF_WAVE is presumably in metres
                # (OIFITS convention), hence the 1e6 factor.
                wave = hdul[wave_ext].data["EFF_WAVE"] * 1.0e6
                data = hdul[i_ext].data

                panel = self._make_panel(
                    hdul,
                    sta,
                    data,
                    wave,
                    n_panels,
                    y_label,
                    y_column,
                    y_err,
                    y_lim,
                    legend_from,
                    stretch,
                )
                panels[panel] = {
                    "ext": i_ext,
                    "report_name": self._make_report_name(hdul, i_ext),
                    "report_description": "Interferometric panel",
                    "report_tags": [],
                    "input_files": [hdul.filename()],
                }

        return panels

    def _make_panel(
        self,
        hdul,
        sta,
        data,
        wave,
        n_panels,
        y_label,
        y_column,
        y_err,
        y_lim,
        legend_from,
        stretch,
    ):
        """Build one Panel holding a scatter+line plot per table row."""
        panel = Panel(
            6,
            n_panels + 1,
            # Row 0 is left free (plots start at row 1) and kept short.
            height_ratios=[1] + [4] * n_panels,
            x_stretch=stretch,
            y_stretch=stretch,
        )

        sta_index = data["STA_INDEX"]
        for b in range(n_panels):
            row = data[b]
            values = row[y_column]
            if legend_from == "header":
                label = self.get_labels(sta, sta_index[b], hdul=hdul)
            else:  # legend from the OI_ARRAY station table
                label = self.get_labels(sta, sta_index[b])

            if not isinstance(values, np.ndarray):
                values = np.array([values])
            yerror = None if y_err is None else row[y_err]

            scatter = ScatterPlot(
                title="",
                x_label="Wavelength, um",
                y_label=y_label,
                # Smaller markers for dense (>100 sample) spectra.
                markersize=1.0 if values.size > 100 else 5.0,
                legend=True,
            )
            scatter.add_data(
                (wave, values),
                label=label,
                color="black",
                yerror=yerror,
            )
            line = LinePlot(title="", legend=False)
            line.add_data((wave, values), label="", color="black")

            combined = CombinedPlot(
                title="",
                x_label="Wavelength, um",
                y_label=y_label,
            )
            limits = self._resolve_ylim(values, y_lim)
            if limits is not None:
                combined.set_ylim(*limits)

            combined.add_data(scatter)
            combined.add_data(line)
            panel.assign_plot(combined, 0, b + 1, xext=6)

        return panel

    @staticmethod
    def _resolve_ylim(values, mode):
        """Return the ``(low, high)`` y-limits for *mode*, or None for none.

        Modes:
            ``"lim"``: fixed ``(0.0, 1.2)``.
            ``"lim2"``: fixed ``(-0.2, 1.5)``.
            ``"clip"``: data range padded by 5% of the span on each side.
            ``"clip0"``: ``-5%`` of the maximum up to the maximum plus 5%.
            ``"xclip"``: like ``"clip"`` but ignoring the first and last samples.

        Any other mode (including None) returns None: no limits are set.
        Non-finite samples are ignored. If no finite sample remains (e.g.
        ``"xclip"`` on <= 2 samples), None is returned instead of raising or
        yielding NaN limits.
        """
        if mode == "lim":
            return 0.0, 1.2
        if mode == "lim2":
            return -0.2, 1.5
        if mode not in ("clip", "clip0", "xclip"):
            return None

        arr = np.asarray(values)
        if mode == "xclip":
            arr = arr[1:-1]
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return None

        vmax = np.max(arr)
        if mode == "clip0":
            margin = vmax * 0.05
            return -margin, vmax + margin
        vmin = np.min(arr)
        margin = (vmax - vmin) * 0.05
        return vmin - margin, vmax + margin

    @staticmethod
    def _make_report_name(hdul, i_ext):
        """Compose ``<INSTRUME>_<arcfile>_<product>_<extname>[_<insname part>]``."""
        header = hdul[0]
        arcfile = str(fetch_kw_or_default(header, "ARCFILE", default="N/A"))
        # Drop everything up to and including the first "." and a ".fits" suffix.
        arcfile = arcfile[arcfile.find(".") + 1 :].removesuffix(".fits")
        instrument = str(fetch_kw_or_default(header, "INSTRUME", default="N/A"))
        fname_prod = os.path.basename(str(hdul.filename())).removesuffix(".fits")
        extname = str(
            fetch_kw_or_default(hdul[i_ext], "EXTNAME", default="N/A")
        ).lower()

        # Avoid repeating the archive name when the product file embeds it as
        # "<arcfile>_<suffix>". Matching the full marker guarantees split()
        # yields a second element.
        marker = f"{arcfile}_"
        if arcfile and marker in fname_prod:
            fname_prod = fname_prod.split(marker)[1]
        report_name = f"{instrument}_{arcfile}_{fname_prod}_{extname}"

        insname = str(fetch_kw_or_default(hdul[i_ext], "INSNAME", default="N/A"))
        if "(" in insname:
            # "<A>_<B>...(<...>)" -> "<B>"; skipped if no "_" precedes "(".
            head = insname.split("(")[0]
            if "_" in head:
                report_name = f"{report_name}_{head.split('_')[1]}"
        elif "_" in insname:
            # "<A>_<rest>" -> "<rest>"
            report_name = f"{report_name}_{insname.split('_', 1)[1]}"

        return report_name
