# SPDX-License-Identifier: BSD-3-Clause
from adari_core.report import AdariReportBase
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.utils.utils import (
    fetch_kw_or_default,
)

import os


class MasterSpecFlatReport(AdariReportBase):
    """Template report for spectroscopic master flat products.

    Child reports must implement :meth:`parse_sof` and populate
    ``self.hdus`` with dicts holding ``"raw_im"`` and ``"master_im"`` HDU
    lists; :meth:`generate_panels` then builds one panel per such entry.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Extent passed to CentralImagePlot for the zoomed-in central views.
        self.center_size = 200

    def parse_sof(self):
        raise NotImplementedError(
            "MasterSpecFlatReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    @staticmethod
    def _value_clip_kwargs(enabled, limits, arg_name):
        """Return plot kwargs for a fixed ``(low, high)`` value clip.

        Returns an empty dict when *enabled* is false so the plot keeps its
        default scaling.

        Raises
        ------
        TypeError
            If *enabled* is true but *limits* is ``None``.
        """
        if not enabled:
            return {}
        if limits is None:
            raise TypeError(
                f"{arg_name} must be a (low, high) pair when the "
                "corresponding scaling flag is enabled"
            )
        return {
            "v_clip": "val",
            "v_clip_kwargs": {"low": limits[0], "high": limits[1]},
        }

    def generate_panels(
        self,
        raw_im_ext=0,
        raw_map=False,
        master_im_ext=0,
        master_im_scale=False,
        master_im_scale_val=None,
        raw_hist_scale=False,
        raw_hist_scale_val=None,
        master_hist_scale=False,
        master_hist_scale_val=None,
        **kwargs,
    ):
        """Build one panel per raw/master pair in ``self.hdus``.

        Parameters
        ----------
        raw_im_ext, master_im_ext : int
            Index of the raw / master image HDU within each HDU list.
        raw_map : bool
            If true, add a column with the full raw image and its central
            region; the cut and histogram columns then shift right by one.
        master_im_scale, raw_hist_scale, master_hist_scale : bool
            Enable fixed value clipping for the master image (full and
            central views), the raw histogram and the master histogram.
        master_im_scale_val, raw_hist_scale_val, master_hist_scale_val
            ``(low, high)`` limits used when the matching flag is set.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        dict
            Maps each Panel to its metadata: shortened raw file name, raw
            extension, report name/description/tags and input file names.

        Raises
        ------
        TypeError
            If a scaling flag is set but its limits are ``None``.
        """
        panels = {}

        for filedict in self.hdus:
            raw_hdul = filedict["raw_im"]
            master_hdul = filedict["master_im"]
            master_im = master_hdul[master_im_ext]
            raw_im = raw_hdul[raw_im_ext]

            # With raw_map an extra column holds the raw image, so the cut
            # and histogram columns move one position to the right.
            shift = 1 if raw_map else 0
            p = Panel(3 + shift, 3, height_ratios=[1, 4, 4])

            master_procatg = fetch_kw_or_default(
                master_im, "HIERARCH ESO PRO CATG", "Missing PRO CATG"
            )
            master_bunit = fetch_kw_or_default(master_im, "BUNIT", default="ADU")

            # Master plots: full image and central region share one scaling.
            master_scaling = self._value_clip_kwargs(
                master_im_scale, master_im_scale_val, "master_im_scale_val"
            )
            master_plot = ImagePlot(
                master_im.data,
                title=master_procatg,
                **master_scaling,
            )
            p.assign_plot(master_plot, 0, 1)

            master_center = CentralImagePlot(
                master_plot,
                title=f"{master_procatg} center",
                extent=self.center_size,
                **master_scaling,
            )
            p.assign_plot(master_center, 0, 2)

            # Optional full raw flat and its central region.
            if raw_map:
                full_raw = ImagePlot(
                    title="Raw Flat",
                    v_clip="minmax",
                )
                full_raw.add_data(raw_im.data)
                p.assign_plot(full_raw, 1, 1)
                raw_center = CentralImagePlot(
                    full_raw,
                    title="Raw Flat center",
                    extent=self.center_size,
                )
                p.assign_plot(raw_center, 1, 2)

            # Column cut through the middle row of the full master image.
            cutpos = master_plot.get_data_coord(master_plot.data.shape[0] // 2, "y")
            master_cut = CutPlot(
                "y",
                title=f"master column @ Y {cutpos}",
                y_label=master_bunit,
            )
            master_cut.add_data(master_plot, cutpos, color="red", label="master")
            p.assign_plot(master_cut, 1 + shift, 1)

            # Same cut, through the middle row of the central region.
            cutpos = master_center.get_data_coord(master_center.data.shape[0] // 2, "y")
            master_cen_cut = CutPlot(
                "y",
                title=f"Central Region: column @ Y {cutpos}",
                y_label=master_bunit,
            )
            master_cen_cut.add_data(master_center, cutpos, label="master", color="red")
            p.assign_plot(master_cen_cut, 1 + shift, 2)

            # Value histograms; each gets its own independent scaling dict.
            raw_hist = HistogramPlot(
                raw_data=raw_im.data,
                title="raw value counts",
                bins=50,
                x_label=fetch_kw_or_default(raw_im, "BUNIT", "counts"),
                **self._value_clip_kwargs(
                    raw_hist_scale, raw_hist_scale_val, "raw_hist_scale_val"
                ),
            )
            p.assign_plot(raw_hist, 2 + shift, 1)

            master_hist = HistogramPlot(
                master_data=master_plot.data,
                title="master value counts",
                bins=50,
                x_label=fetch_kw_or_default(master_im, "BUNIT", "counts"),
                **self._value_clip_kwargs(
                    master_hist_scale, master_hist_scale_val, "master_hist_scale_val"
                ),
            )
            p.assign_plot(master_hist, 2 + shift, 2)

            instru = raw_im.header.get("INSTRUME")
            if instru is None:
                # NOTE(review): when raw_im_ext selects an extension the
                # keyword may only live in the primary header; fall back to
                # it instead of producing a "None_..." report name.
                instru = raw_hdul[0].header.get("INSTRUME")

            rawfile = os.path.basename(str(raw_hdul.filename()))
            # Drop everything up to and including the first "." (if any)
            # and a trailing ".fits" suffix.
            rawfile = rawfile[rawfile.find(".") + 1 :].removesuffix(".fits")
            input_files = [
                raw_hdul.filename(),
                master_hdul.filename(),
            ]
            panels[p] = {
                "raw": rawfile,
                "raw_im_ext": raw_im_ext,
                "report_name": f"{instru}_{master_procatg.lower()}_{rawfile}",
                "report_description": f"Spec Flat panel - ({rawfile})",
                "report_tags": [],
                "input_files": input_files,
            }

        return panels
