# SPDX-License-Identifier: BSD-3-Clause
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot

from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

import logging

# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)


class MasterRawdispReport(AdariReportBase):
    """Master ADARI Rawdisp report.

    Description
    -----------
    Master report for building the Rawdisp report for all ESO instruments. The report consists of
    a 2D image of the pixel data, a histogram of the pixel data, detailed histogram around the distribution mode,
    and cut plots along x and y directions, respectively.

    Attributes
    ----------
    - tasks: (dict) Dictionary that maps the SOF categories with EDPS recipes. Each instrument shall define its own
    set of recipes and SOF categories.
    - sof_tag: (list) List containing the SOF category of each raw file that will be used in the report.
    - task_scaling: (dict) Dictionary containing the data scaling requirements (in the form of a dictionary with
    optional "image" and "hist" entries) for plotting the data using any given EDPS recipe.
    - scaling_default_image: (dict) Default scaling configuration for image plots. This is used when a given recipe has no counterpart at task_scaling.
    - scaling_default_hist: (dict) Default scaling configuration for histogram plots. This is used when a given recipe has no counterpart at task_scaling.
    - x_n_plots: (int) Number of plots along the horizontal direction of the panel.
    - y_n_plots: (int) Number of plots along the vertical direction of the panel.
    """

    # NOTE: these are class-level defaults; child reports are expected to
    # override them rather than mutate them in place.
    tasks = {}
    sof_tag = []
    task_scaling = {}
    scaling_default_image = {"v_clip": "sigma", "v_clip_kwargs": {"nsigma": 1}}
    scaling_default_hist = {"bins": 50, "v_clip": "mad", "v_clip_kwargs": {"nmad": 4}}
    x_n_plots = 9
    y_n_plots = 5

    def __init__(self, name: str):
        """Initialise the report with its name.

        Parameters
        ----------
        - name: (str) Report name, forwarded to the base class.
        """
        super().__init__(name)
        self.center_size = 200

    def parse_sof(self):
        """Template hook; child reports must provide their own implementation.

        Raises
        ------
        - NotImplementedError: always, since this class is a template only.
        """
        raise NotImplementedError(
            "MasterRawdispReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def _get_scaling(self, task_name, plot_kind, default):
        """Return the scaling kwargs for a task and plot kind ("image"/"hist").

        Falls back to `default` when `task_scaling` has no entry for the task
        or the task entry has no configuration for `plot_kind`.
        """
        return self.task_scaling.get(task_name, {}).get(plot_kind, default)

    def get_metadata_plots(self, hdul, extension_name, category_tag):
        """Get the required metadata for the generic Rawdisp report.

        Parameters
        ----------
        - hdul: astropy.io.HDUList to be used for generating the report.
        - extension_name: (str or int) Name or index of the extension containing the data.
        - category_tag: (str) SOF category of the file; used as key into `tasks`
        to select the instrument setup method.

        Returns
        -------
        - metadata_plots: (list) List containing the metadata TextPlots.

        Raises
        ------
        - KeyError: if `category_tag` is not a key of `tasks`.
        - NameError: if the setup method for the task cannot be resolved or
        fails; the original exception is chained as the cause.
        """
        # Get the position value of the extension
        if isinstance(extension_name, str):
            ext_index = hdul.index_of(extension_name)
        else:
            ext_index = extension_name

        primary_header = hdul["PRIMARY"].header
        col1 = (
            str(primary_header.get("INSTRUME")),
            "ORIGFILE: " + str(primary_header.get("ORIGFILE")),
            "ARCFILE: " + str(primary_header.get("ARCFILE")),
            "EXTENSION: " + str(ext_index),
            "EXTNAME: " + str(hdul[extension_name].header.get("EXTNAME", "N/A")),
        )
        col2 = (
            "DPR.CATG: " + str(primary_header.get("HIERARCH ESO DPR CATG")),
            "DPR.TYPE: " + str(primary_header.get("HIERARCH ESO DPR TYPE")),
            "DPR.TECH: " + str(primary_header.get("HIERARCH ESO DPR TECH")),
            "TPL.ID: " + str(primary_header.get("HIERARCH ESO TPL ID")),
            "CATEGORY: " + category_tag,
        )
        # Instrument set up (this should be overridden by each instrument report).
        # `setup_info` is not defined in this class; presumably the child report
        # provides it - verify against the instrument reports.
        task_name = self.tasks[category_tag]
        try:
            setup_metadata = getattr(self.setup_info, task_name)(hdul)
        except Exception as err:
            # Keep the historical NameError, but chain the root cause so a
            # failure inside the setup method itself is not hidden.
            raise NameError(
                "No setup method found for task: {}".format(task_name)
            ) from err

        t1 = TextPlot(columns=1, v_space=0.3)
        t1.add_data(col1)

        t2 = TextPlot(columns=1, v_space=0.3)
        t2.add_data(col2)

        t3 = TextPlot(columns=1, v_space=0.3)
        t3.add_data(setup_metadata)
        return [t1, t2, t3]

    def generate_panels(self, ext=0, bin_size=1.0, **kwargs):
        """Build one Rawdisp panel per input file.

        Each panel holds a row of metadata text plots, the raw image, a full
        and a detailed histogram of the pixel values, and a row and a column
        cut through the middle of the image.

        Parameters
        ----------
        - ext: (str, int or sequence) Extension to display. A single name or
        index is applied to every file; a sequence is paired element-wise
        with the files.
        - bin_size: (float) Minimum bin size passed to the detailed histogram.
        - kwargs: accepted for interface compatibility; not used here.

        Returns
        -------
        - panels: (dict) Maps each Panel to a dictionary with its metadata
        (task name, filename, tag, raw HDU, extension, report name,
        description, tags and input files).
        """
        panels = {}
        if isinstance(ext, (str, int)):
            ext = [ext] * len(self.hdus)

        for tag, file_ext, filedict in zip(self.sof_tag, ext, self.hdus):
            # The "filename" entry is used as an opened HDU list below.
            hdul = filedict["filename"]
            filename = hdul.filename()
            raw = hdul[file_ext]
            task_name = self.tasks[tag]
            y_label = fetch_kw_or_default(raw, "BUNIT", default="ADU")

            px, py, ix, iy = 0, 0, self.x_n_plots, self.y_n_plots
            p = Panel(ix + 1, iy, height_ratios=[1, 2, 2, 2, 2])
            # Histograms and cuts share the last two columns of the panel.
            side_col = ix - 1

            # Metadata text plots along the first row
            text_plots = self.get_metadata_plots(hdul, file_ext, tag)
            for t in text_plots:
                p.assign_plot(t, px, py, xext=2)
                px = px + 2
            # Reset to start of row for next panels
            py = py + 1
            px = 0

            # Image Plot
            scaling = self._get_scaling(
                task_name, "image", self.scaling_default_image
            )
            raw_plot = ImagePlot(raw.data, title="raw plot", aspect="auto", **scaling)
            # The side plots start at column ix - 1, so the image spans only
            # the ix - 1 columns before them (assuming xext is a column span).
            p.assign_plot(raw_plot, px, py, xext=ix - 1, yext=iy - 1)  # 0,1

            # Histogram (Full)
            # The histogram range is -5000,70000 for all full plots
            full_hist = HistogramPlot(
                raw_data=raw_plot.data,
                title="raw value counts (full)",
                bins=50,
                v_min=-5000,
                v_max=70000,
            )
            p.assign_plot(full_hist, side_col, py, xext=2)  # 1,1

            py = py + 1
            # Histogram (Detail)
            scaling = self._get_scaling(task_name, "hist", self.scaling_default_hist)
            det_hist = HistogramPlot(
                raw_data=raw_plot.data, title="raw value counts (detailed)", **scaling
            )
            # Set the number of bins according to a minimum bin size
            det_hist.set_int_bins(bin_size=bin_size)
            logger.debug(
                "Have set the number of bins for the detail histogram to %s",
                det_hist.bins,
            )
            p.assign_plot(det_hist, side_col, py, xext=2)  # 1,2

            py = py + 1
            # Plot cut in x-direction (i.e. a row at y = NAXIS2/2).
            # shape[0] counts rows, so the index is converted along "y".
            cutpos = raw_plot.get_data_coord(raw_plot.data.shape[0] // 2, "y")
            raw_cutY = CutPlot(
                "y",
                title="raw row @ Y {}".format(cutpos),
                y_label=y_label,
            )
            raw_cutY.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cutY, side_col, py, xext=2)  # 1,3

            py = py + 1
            # Plot cut in y-direction (i.e. a column at x = NAXIS1/2).
            # shape[1] counts columns, so the index is converted along "x".
            cutpos = raw_plot.get_data_coord(raw_plot.data.shape[1] // 2, "x")
            raw_cutX = CutPlot(
                "x",
                title="raw column @ X {}".format(cutpos),
                y_label=y_label,
            )
            raw_cutX.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cutX, side_col, py, xext=2)  # 1,4

            panels[p] = {
                "task_name": task_name,
                "filename": filename,
                "tag": tag,
                "raw": raw,
                "ext": file_ext,
                "report_name": f"{self.name}_{task_name.lower()}_{tag.lower()}_{file_ext}",
                "report_description": f"Rawdisp panel - ({filename}, "
                f"{task_name}, "
                f"{tag}, "
                f"{file_ext})",
                "report_tags": [],
                "input_files": [filename],
            }
        return panels
