# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

import numpy as np


class MasterFormatCheckReport(AdariReportBase):
    """Template ADARI report for format-check products (line guess + arc lamp).

    Child reports must implement :meth:`parse_sof` and populate ``self.hdus``
    with one dict per product set; each dict must provide the ``"line_guess"``
    and ``"arc_lamp"`` HDU lists consumed by :meth:`generate_panels`.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Extent (in pixels) of the central region shown in the "raw center"
        # image, its cut plot and its line overlay.
        self.central_region_size = 200

    def parse_sof(self):
        """Abstract hook: the child report must override this."""
        raise NotImplementedError(
            "MasterFormatCheckReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    @staticmethod
    def _used_lines_scatter(title, table, x_col, y_col, markersize):
        """Return a legend-less black ScatterPlot of ``table[y_col]`` vs ``table[x_col]``."""
        plot = ScatterPlot(
            title=title,
            x_label=x_col,
            y_label=y_col,
            markersize=markersize,
            legend=False,
        )
        plot.add_data(
            (table[x_col], table[y_col]),
            label="Used lines",
            color="black",
        )
        return plot

    @staticmethod
    def _inverted_limits(values):
        """Return ``(max + 1, min - 1)`` of *values*, ignoring NaN entries.

        The larger bound is deliberately returned first, as the report's
        axis convention requires. Presumably this makes the backend draw the
        axis inverted (matplotlib does so); confirm against adari_core.
        """
        return np.nanmax(values) + 1, np.nanmin(values) - 1

    @staticmethod
    def _line_overlay(x, y, markersize):
        """Return a legend-less cyan ring-marker ScatterPlot of the used lines."""
        scatter = ScatterPlot()
        scatter.add_data(
            (x, y),
            color="c",
            label="Used lines",
            marker="ring",
            markersize=markersize,
        )
        scatter.legend = False
        return scatter

    def generate_panels(
        self, arc_lamp_ext=0, data_ext=0, ext=0, plot_params=None, **kwargs
    ):
        """Build one format-check panel per entry in ``self.hdus``.

        Parameters
        ----------
        arc_lamp_ext : sequence
            ``arc_lamp_ext[i]`` is stored verbatim as ``"arc_lamp_ext"`` in
            the metadata of panel ``i``. The integer default cannot be
            indexed, so callers must pass a sequence.
        data_ext : sequence of sequences
            ``data_ext[i]`` lists the arc-lamp extensions whose data arrays
            are concatenated along axis 1 to form the raw image of file
            ``i``. The header of ``data_ext[i][0]`` supplies the overlay y
            offset when requested. The integer default cannot be indexed.
        ext : object
            Identifier embedded in the report name and description.
        plot_params : dict
            Plot configuration. Keys read:
            ``ext``, ``mask_data``, ``mask``, ``arc_yoffset``,
            ``pos_plot_x``, ``pos_plot_y``, ``pos_plot_limits``,
            ``xres_plot_x``, ``xres_plot_y``, ``yres_plot_x``,
            ``yres_plot_y``, ``fres_plot_x``, ``fres_plot_y``,
            ``raw_cutX``, ``raw_cutX_rot``, ``raw_cutX_flip``,
            ``raw_cutX_vclip``, ``raw_cutX_vclip_kwargs``,
            ``overlay_x``, ``overlay_y``.
        **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        dict
            Maps each ``Panel`` to its report-metadata dict.

        Raises
        ------
        KeyError
            If a required ``plot_params`` key or header keyword is missing.
        """
        if plot_params is None:
            plot_params = {}

        panels = {}
        # Marker sizes: residual/position scatters and image overlays.
        rplot_size = 0.2
        msize_overlay = 4

        for i, filedict in enumerate(self.hdus):
            line_guess_hdul = filedict["line_guess"]
            line_guess_procatg = fetch_kw_or_default(
                line_guess_hdul["PRIMARY"], "HIERARCH ESO PRO CATG"
            )
            line_table = line_guess_hdul[plot_params["ext"]].data

            arc_lamp_hdul = filedict["arc_lamp"]
            # Place the selected extensions side by side (axis 1) to form
            # a single raw image.
            arc_lamp_image_data = np.concatenate(
                [arc_lamp_hdul[aext].data for aext in data_ext[i]], axis=1
            )
            bunit = fetch_kw_or_default(
                arc_lamp_hdul["PRIMARY"], "BUNIT", default="ADU"
            )

            if plot_params["arc_yoffset"]:
                # Offset added to the overlay y positions. PRSCX is presumably
                # the X prescan width (used here as a y shift, perhaps due to
                # image orientation); confirm with the instrument team.
                arc_yoffset = arc_lamp_hdul[data_ext[i][0]].header[
                    "HIERARCH ESO DET OUT1 PRSCX"
                ]
            else:
                arc_yoffset = 0.0

            p = Panel(5, 3, height_ratios=[1, 4, 4])

            # Keep only the rows whose mask column equals the configured value
            # (plotted as "Used lines").
            mask = line_table[plot_params["mask_data"]] == plot_params["mask"]
            selplot = line_table[mask]

            # 1. Position of the arc lines used for the wavelength solution.
            # X axis being dispersion direction (in pixels), y axis being order number.
            pos_plot = self._used_lines_scatter(
                "Line positions",
                selplot,
                plot_params["pos_plot_x"],
                plot_params["pos_plot_y"],
                rplot_size,
            )
            # Limits follow the required convention and span ALL table rows,
            # not only the used lines.
            if plot_params["pos_plot_limits"] == "y":
                pos_plot.set_ylim(
                    *self._inverted_limits(line_table[plot_params["pos_plot_y"]])
                )
            elif plot_params["pos_plot_limits"] == "x":
                pos_plot.set_xlim(
                    *self._inverted_limits(line_table[plot_params["pos_plot_x"]])
                )
            p.assign_plot(pos_plot, 0, 1)

            # 2. Fit residuals in x direction of the lines vs. x (used lines only).
            xres_plot = self._used_lines_scatter(
                "Fit residuals in x direction",
                selplot,
                plot_params["xres_plot_x"],
                plot_params["xres_plot_y"],
                rplot_size,
            )
            p.assign_plot(xres_plot, 0, 2)

            # 3. Fit residuals in y direction of the lines vs. x (used lines only).
            yres_plot = self._used_lines_scatter(
                "Fit residuals in y direction",
                selplot,
                plot_params["yres_plot_x"],
                plot_params["yres_plot_y"],
                rplot_size,
            )
            p.assign_plot(yres_plot, 1, 1)

            # 4. Fit residuals of the lines: x vs. y (used lines only).
            fres_plot = self._used_lines_scatter(
                "Fit residuals",
                selplot,
                plot_params["fres_plot_x"],
                plot_params["fres_plot_y"],
                rplot_size,
            )
            p.assign_plot(fres_plot, 1, 2)

            # 5. Cut in cross-dispersion direction (at centre of dispersion
            # direction) through the full raw image.
            raw_plot = ImagePlot(
                arc_lamp_image_data,
                title="raw plot",
                rotate=plot_params["raw_cutX_rot"],
                flip=plot_params["raw_cutX_flip"],
                v_clip=plot_params["raw_cutX_vclip"],
                v_clip_kwargs=plot_params["raw_cutX_vclip_kwargs"],
            )
            raw_center = CentralImagePlot(
                raw_plot,
                extent=self.central_region_size,
                title="raw center",
                v_clip=plot_params["raw_cutX_vclip"],
                v_clip_kwargs=plot_params["raw_cutX_vclip_kwargs"],
            )

            cut_dir = plot_params["raw_cutX"]
            # ishape: array axis whose midpoint fixes the cut position;
            # x_label: the axis the cut profile runs along.
            if cut_dir == "x":
                ishape, x_label = 1, "y"
            else:
                ishape, x_label = 0, "x"

            cutpos = raw_plot.get_data_coord(
                raw_plot.data.shape[ishape] // 2, cut_dir
            )
            raw_cutX = CutPlot(
                cut_dir,
                title=f"Raw col @ {cut_dir} {cutpos}",
                x_label=x_label,
                y_label=bunit,
            )
            raw_cutX.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cutX, 2, 1)

            # 6. Same cut through the central region.
            # NOTE(review): np.floor is redundant after `//`, but it turns the
            # index into a float; kept because get_data_coord's output (and
            # thus the title text) may depend on that type.
            cutpos = raw_center.get_data_coord(
                np.floor(raw_center.data.shape[ishape] // 2), cut_dir
            )
            raw_cen_cutX = CutPlot(
                cut_dir,
                title=f"Raw central Region: col @ {cut_dir} {cutpos}",
                x_label=x_label,
                y_label=bunit,
            )
            raw_cen_cutX.add_data(raw_center, cutpos, label="raw", color="black")
            p.assign_plot(raw_cen_cutX, 2, 2)

            # 7. Histogram of the raw image values (fixed value range).
            raw_hist = HistogramPlot(
                raw_data=raw_plot.data,
                title="Raw value counts",
                bins=50,
                v_min=-5000,
                v_max=70000,
            )
            p.assign_plot(raw_hist, 3, 1)

            # 8. Position of used lines plotted on the detector image.
            # This combination works for now, before any image clipping.
            overlay_x = selplot[plot_params["overlay_x"]]
            overlay_y = selplot[plot_params["overlay_y"]] + arc_yoffset

            combined_raw_plot = CombinedPlot(title="Raw plot")
            combined_raw_plot.add_data(raw_plot, z_order=1)
            # Default zorder for CentralImagePlot's box is 10.
            combined_raw_plot.add_data(
                self._line_overlay(overlay_x, overlay_y, msize_overlay),
                z_order=2,
            )
            p.assign_plot(combined_raw_plot, 3, 2, xext=2)

            # 9. Same overlay for the central region. A separate ScatterPlot
            # instance is used, so each CombinedPlot holds its own instance.
            combined_raw_center = CombinedPlot(title="Raw center")
            combined_raw_center.add_data(raw_center, z_order=1)
            combined_raw_center.add_data(
                self._line_overlay(overlay_x, overlay_y, msize_overlay),
                z_order=2,
            )
            p.assign_plot(combined_raw_center, 4, 1)

            line_guess_file = line_guess_hdul.filename()
            arc_lamp_file = arc_lamp_hdul.filename()

            panels[p] = {
                "ext": ext,
                "line_guess": line_guess_file,
                "arc_lamp": arc_lamp_file,
                "arc_lamp_ext": arc_lamp_ext[i],
                "report_name": f"{line_guess_procatg.lower()}_{str(ext).lower()}",
                "report_description": f"Formatcheck panel ({ext})",
                "report_tags": [],
                "input_files": [line_guess_file, arc_lamp_file],
            }
        return panels
