# SPDX-License-Identifier: BSD-3-Clause
import os
from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot, ScatterPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.text import TextPlot

from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

import numpy as np
import warnings


class MasterDetmonReport(AdariReportBase):
    """
    Base report for detector monitoring (gain and linearity) panels.

    The report collects lamp-on/lamp-off raw frames and the product files
    named in ``files_needed`` (see ``parse_sof``), and builds one panel per
    entry of ``self.hdus`` (see ``generate_panels``). Instrument-specific
    child reports must implement ``access_port``.
    """

    # Vertical spacing passed as ``v_space`` to the metadata TextPlot
    METADATA_V_SPACING = 0.4

    def __init__(self, name: str, files_needed) -> None:
        """
        Create the report.

        Parameters
        ----------
            name : report name, forwarded to the AdariReportBase constructor
            files_needed : mapping of file key -> file category; parse_sof
                           reads its "on" and "off" entries and matches the
                           categories of the input files against its values
        """
        super().__init__(name)
        self.files_needed = files_needed

    def access_port(self, im_hdu, port):
        raise NotImplementedError(
            "Each instrument child Report must " "implement access_port"
        )

    def generate_file_key(self, prefix, n):
        if not (prefix == "off" or prefix == "on"):
            raise ValueError(f"Invalid file key prefix '{prefix}'")
        try:
            if int(n) < 0:
                raise ValueError("n must be >= 0")
        except TypeError:
            raise ValueError(f"Unable to cast {n} to int")

        return f"{prefix}{n:03d}"

    def validate_info(self, data_info):
        expected1 = set(["exptime_key", "gain_nominal_key", "lintime_key"])
        expected2 = set(["exptime_key", "gain_nominal_value", "lintime_key"])
        expected3 = set(
            [
                "exptime_key",
                "gain_nominal_key",
                "lintime_key",
                "gain_qc_key",
                "lin_qc_key",
            ]
        )

        keys = set(data_info.keys())

        return (expected1 == keys) or (expected2 == keys) or (expected3 == keys)

    def parse_sof(self):
        file_lists = {}
        count_on = 0
        count_off = 0

        for filename, catg in self.inputs:
            if catg in self.files_needed.values():
                if catg == self.files_needed["on"]:
                    file_lists[self.generate_file_key("on", count_on)] = filename
                    count_on += 1
                elif catg == self.files_needed["off"]:
                    file_lists[self.generate_file_key("off", count_off)] = filename
                    count_off += 1
                else:
                    gen = (
                        elem
                        for elem in self.files_needed
                        if self.files_needed[elem] == catg
                    )
                    file_lists[next(gen)] = filename

        return [
            file_lists,
        ]

    def generate_panels(
        self, data_info, ext=0, port=1, raw_im_rotation_kwargs=None, **kwargs
    ):
        """
        Generate panels.

        One panel is built per entry of ``self.hdus``. Each panel holds a
        metadata block, X/Y cuts and a histogram of the longest-exposure
        lamp-on frame, the gain measurement and gain deviation plots, and
        the linearity and linearity deviation plots.

        Example for data_info (nominal gain in header):
            data_info = {"exptime_key" : "EXPTIME", "gain_nominal_key" : "HIERARCH ESO DET OUT1 CONAD",
            "lintime_key" : "EXPTIME"}
        Example for data_info (value of nominal gain provided directly):
            data_info = {"exptime_key" : "EXPTIME", "gain_nominal_value" : 1.5,
            "lintime_key" : "DIT"}
        Example for data_info (nominal gain in header, non-trivial qc gain keyword):
            data_info = {"exptime_key" : "EXPTIME", "gain_nominal_key" : "HIERARCH ESO DET OUT1 CONAD",
            "lintime_key" : "EXPTIME", "gain_qc_key": "HIERARCH ESO QC GAIN AVG"}
        Example for data_info (nominal gain in header, non-trivial qc lin keyword):
            data_info = {"exptime_key" : "EXPTIME", "gain_nominal_key" : "HIERARCH ESO DET OUT1 CONAD",
            "lintime_key" : "EXPTIME", "lin_qc_key": "HIERARCH ESO QC REG1 LIN"}

        Parameters
        ----------
            data_info : dictionary with additional info on instrument:
                        keywords for exp.time, nominal gain or nominal gain value
            ext : fits file extension
            port : readout port
            raw_im_rotation_kwargs : optional dictionary of keyword arguments
                        forwarded to the ImagePlot of the raw lamp-on frame;
                        None (the default) means no extra arguments
            **kwargs : accepted but not used by this method

        Returns
        -------
            panels : dictionary mapping each Panel to its metadata dictionary

        Raises
        ------
            ValueError : if data_info does not pass validate_info
            RuntimeError : if the gain or linearity values cannot be read
                           from the gain/linearity files

        """
        # A None default avoids sharing one mutable dict between calls
        if raw_im_rotation_kwargs is None:
            raw_im_rotation_kwargs = {}

        if not self.validate_info(data_info):
            raise ValueError("Incomplete data info")

        exptime_key = data_info["exptime_key"]
        lin_exptime_key = data_info["lintime_key"]
        # Stem of the QC linearity keywords (does not depend on the file set)
        lin_qc_key = data_info.get("lin_qc_key", "HIERARCH ESO QC LIN")
        panels = {}

        for j, filedict in enumerate(self.hdus):

            p = Panel(2, 4, height_ratios=[1, 4, 4, 4], x_stretch=1.4)

            # Locate the longest-exptime ON frame
            longest_on_frame = sorted(
                [v for k, v in filedict.items() if "on" in k],  # Get the ON frames
                key=lambda x: x[0].header.get(exptime_key, -999.0),  # Order by EXPTIME
                reverse=True,  # But order with largest value first
            )[
                0
            ]  # Get the first frame HDU

            # DETMON outputs file with extensions matching the input
            # data extension. The exception is when the incoming data
            # is from the "PRIMARY" or 0 extension - in that case, the
            # returned files have a header-only PHU, and the results
            # table is in extension 1 (with no extname)

            if ext == "PRIMARY" or ext == 0:
                gain = filedict["gain_info"][1]
                lin = filedict["detlin_info"][1]
                gain_hdr = filedict["gain_info"][0]
            else:
                gain = filedict["gain_info"][ext]
                lin = filedict["detlin_info"][ext]
                gain_hdr = gain

            coeffs = None
            if filedict.get("coeffs_cube") is not None:
                coeffs = filedict["coeffs_cube"][ext]  # coeffs doesn't have this issue

            # Metadata
            text1 = TextPlot(v_space=self.METADATA_V_SPACING)
            text1.add_data(
                [
                    f"INSTRUME: {fetch_kw_or_default(filedict['detlin_info'][0], 'INSTRUME', 'UNKNOWN')}",
                    f"EXTNAME: {fetch_kw_or_default(lin, 'EXTNAME', 'UNKNOWN')}",
                    f"PRO.CATG: {fetch_kw_or_default(lin, 'HIERARCH ESO PRO CATG', 'UNKNOWN')}",
                    f"Filename: {os.path.basename(filedict['detlin_info'].filename())}",
                    f"PRO.REC1.RAW1.NAME: {fetch_kw_or_default(lin, 'HIERARCH ESO PRO REC1 RAW1 NAME', 'UNKNOWN')}",
                ]
            )
            p.assign_plot(text1, 0, 0)

            # Form cuts from the top ON frame at the required port
            # Need to get the pixel range of this port
            port_raw_data = self.access_port(longest_on_frame[ext], port).data

            raw_im = ImagePlot(port_raw_data, **raw_im_rotation_kwargs)
            raw_cutX = CutPlot("x", title="Cut in X")
            raw_x_cut = raw_im.get_data_coord(raw_im.data.shape[1] // 2, axis="x")
            raw_cutX.add_data(raw_im, raw_x_cut, color="blue")

            raw_cutY = CutPlot("y", title="Cut in Y")
            raw_y_cut = raw_im.get_data_coord(raw_im.data.shape[0] // 2, axis="y")
            raw_cutY.add_data(raw_im, raw_y_cut, color="black")

            # Form a combined cuts plot
            raw_cut_combined = CombinedPlot(
                title="Cuts in X & Y",
                x_label="x, y [pixel]",
                y_label=fetch_kw_or_default(longest_on_frame[ext], "BUNIT", "ADU"),
            )
            raw_cut_combined.add_data(raw_cutY, z_order=0)
            raw_cut_combined.add_data(raw_cutX, z_order=1)
            p.assign_plot(raw_cut_combined, 0, 1)

            raw_hist = HistogramPlot()
            raw_hist.add_raw_data(raw_im.data)
            raw_hist.x_label = "Counts"
            raw_hist.y_label = "Frequency"
            raw_hist.set_xlim(
                -5000, 70000
            )  # fixed x-range, can be added to data_info for more flexibility
            raw_hist.title = (
                f"Histogram (lamp-on frame, "
                f"{fetch_kw_or_default(longest_on_frame[0], exptime_key, 'unknown')} s)"
            )
            p.assign_plot(raw_hist, 1, 1)

            # Gain measurement
            try:
                if "gain_nominal_key" in data_info:
                    gain_nominal = gain_hdr.header[data_info["gain_nominal_key"]]
                else:
                    gain_nominal = data_info["gain_nominal_value"]
                if "gain_qc_key" in data_info:
                    gain_qc = gain_hdr.header[data_info["gain_qc_key"]]
                else:
                    gain_qc = gain_hdr.header["HIERARCH ESO QC GAIN"]
            except KeyError as e:
                raise RuntimeError(
                    f"Unable to extract bulk gain values "
                    f"from {filedict['gain_info'].filename()} "
                    f"({str(e)})"
                ) from e

            try:
                gain_meas_x = gain.data["X_FIT_CORR"]
                gain_meas_y = gain.data["Y_FIT"]
                gain_meas_labels = [
                    "X_FIT_CORR = <ON1> + <ON2> - 2*<OFF>",
                    "Y_FIT = NDIT*sigma(ON)^2",
                ]
                gain_dev_x = gain.data["ADU"]

                # Use the GAIN column when the table has one, otherwise
                # derive the gain from the fit columns
                if "GAIN" in gain.columns.names:
                    gain_dev_y = gain.data["GAIN"]
                    gain_dev_labels = ["ADU", "gain [e-/ADU]"]
                else:
                    gain_dev_y = gain.data["X_FIT"] / gain.data["Y_FIT"]
                    gain_dev_labels = ["ADU", "gain [e-/ADU] = X_FIT/Y_FIT"]

            except KeyError as e:
                raise RuntimeError(
                    f"Unable to extract gain measurement values "
                    f"from {filedict['gain_info'].filename()} "
                    f"({str(e)})"
                ) from e
            gain_meas = ScatterPlot(
                markersize=6,
                x_label=gain_meas_labels[0],
                y_label=gain_meas_labels[1],
            )
            gain_meas.add_data(
                (gain_meas_x, gain_meas_y), color="red", label="Gain measurement"
            )
            # Reference lines through the origin with slope 1/gain
            gain_meas_x_max = np.max(gain_meas_x)
            gain_meas_lines = LinePlot()
            gain_meas_lines.add_data(
                ([0, gain_meas_x_max], [0, (1.0 / gain_qc) * gain_meas_x_max]),
                color="blue",
                label="dy/dx = 1/gain(QC)",
            )
            gain_meas_lines.add_data(
                (
                    [0, gain_meas_x_max],
                    [0, (1.0 / gain_nominal) * gain_meas_x_max],
                ),
                color="green",
                label="dy/dx = 1/gain(nominal)",
            )
            gain_meas_combined = CombinedPlot(title="Gain measurement")
            gain_meas_combined.add_data(gain_meas, z_order=0)
            gain_meas_combined.add_data(gain_meas_lines, z_order=1)
            p.assign_plot(gain_meas_combined, 0, 2)

            # Gain deviation

            gain_dev = ScatterPlot(
                markersize=6,
                x_label=gain_dev_labels[0],
                y_label=gain_dev_labels[1],
            )
            gain_dev.add_data(
                (gain_dev_x, gain_dev_y), color="red", label="Gain computation"
            )
            # Horizontal reference lines at the QC and nominal gain
            gain_dev_x_max = np.max(gain_dev_x)
            gain_dev_lines = LinePlot()
            gain_dev_lines.add_data(
                (
                    [0, gain_dev_x_max],
                    [gain_qc, gain_qc],
                ),
                label=f"QC gain = {gain_qc:.3f}",
                color="blue",
            )
            gain_dev_lines.add_data(
                (
                    [0, gain_dev_x_max],
                    [gain_nominal, gain_nominal],
                ),
                label=f"Nominal gain = {gain_nominal:.3f}",
                color="green",
            )
            gain_dev_combined = CombinedPlot(title="Gain deviation")
            gain_dev_combined.add_data(gain_dev, z_order=0)
            gain_dev_combined.add_data(gain_dev_lines, z_order=1)
            p.assign_plot(gain_dev_combined, 1, 2)

            # Linearity
            try:
                lin_x = lin.data[lin_exptime_key]
                lin_y = lin.data["MED"]
            except KeyError as e:
                raise RuntimeError(
                    f"Unable to extract linearity values "
                    f"from {filedict['detlin_info'].filename()} "
                    f"({str(e)})"
                ) from e
            lin_pts = ScatterPlot(
                markersize=6, x_label="Exp. time [s]", y_label="Median ON [ADU]"
            )
            lin_pts.add_data((lin_x, lin_y), color="red", label="Linearity")
            # Fitting coeffs: look in the linearity header first, then in
            # the coefficients cube header; missing ones become NaN
            lin_coeffs = []
            lin_coeffs_size = 0  # number of coefficients actually found
            for i in range(0, 4):
                try:
                    lin_coeffs.append(lin.header[f"{lin_qc_key} COEF{i}"])
                    lin_coeffs_size += 1
                except KeyError:
                    try:
                        lin_coeffs.append(coeffs.header[f"{lin_qc_key} COEF{i}"])
                        lin_coeffs_size += 1
                    except (
                        KeyError,
                        AttributeError,
                    ):  # Covers both missing LIN COEF, and no coeffs
                        warnings.warn(f"Unable to find LIN COEF{i}, set to NaN")
                        lin_coeffs.append(np.nan)
            lin_pts_fits = LinePlot()
            if not np.any(np.isnan(lin_coeffs[:2])):
                poly1d = np.polynomial.polynomial.Polynomial(lin_coeffs[:2])
                lin_pts_fits.add_data(
                    (lin_x, poly1d(lin_x)),
                    color="blue",
                    label=f"f(t) = {lin_coeffs[0]:.2f} + " f"{lin_coeffs[1]:.2f} * t",
                )
            # Skip the full fit when no coefficient was found at all: the
            # NaN check on an empty slice passes, and Polynomial() raises
            # "Coefficient array is empty" for an empty coefficient list.
            if lin_coeffs_size > 0 and not np.any(
                np.isnan(lin_coeffs[:lin_coeffs_size])
            ):
                poly3d = np.polynomial.polynomial.Polynomial(
                    lin_coeffs[:lin_coeffs_size]
                )
                lab_it = [
                    f"{lin_coeffs[0]:.2f}",
                    f" + {lin_coeffs[1]:.2f} * t\n",
                    f" + {lin_coeffs[2]:.2f} * t^2",
                    f" + {lin_coeffs[3]:.2f} * t^3",
                ]
                # Only the terms of the coefficients in use go in the label
                lab = "f(t) = " + "".join(lab_it[:lin_coeffs_size])
                lin_pts_fits.add_data(
                    (lin_x, poly3d(lin_x)),
                    color="green",
                    label=lab,
                )
            lin_pts_combined = CombinedPlot(title="Linearity")
            lin_pts_combined.add_data(lin_pts, z_order=0)
            lin_pts_combined.add_data(lin_pts_fits, z_order=1)
            p.assign_plot(lin_pts_combined, 0, 3)

            # Linearity deviation (note data already loaded)
            lin_dev = ScatterPlot(
                markersize=6,
                x_label="Median ON [ADU]",
                y_label="Median ON / Exp. time [ADU/s]",
            )
            lin_dev.title = (
                f"Non-linearity corr.:\n"
                f"{lin.header.get(lin_qc_key+' EFF', np.nan):.4f} @ "
                f"{lin.header.get(lin_qc_key+' EFF FLUX', np.nan):.0f} ADU"
            )
            lin_dev.add_data(
                (lin_y, lin_y / lin_x),  # median ON  # median ON / EXPTIME
                color="red",
                label="Linearity deviations",
            )
            p.assign_plot(lin_dev, 1, 3)

            input_files = [
                filedict["detlin_info"].filename(),
                filedict["gain_info"].filename(),
            ]
            if filedict.get("coeffs_cube") is not None:
                input_files.append(filedict["coeffs_cube"].filename())
            input_files.append(longest_on_frame.filename())

            # Values shared by the report name and description
            detlin_primary_hdr = filedict["detlin_info"][0].header
            instrume = detlin_primary_hdr.get("INSTRUME", "INSTRUMENT")
            tpl_start = detlin_primary_hdr.get(
                "HIERARCH ESO TPL START", "TPL_START_TIME"
            )

            panels[p] = {
                "longest_on_frame": longest_on_frame,
                "gain": gain,
                "lin": lin,
                "ext": ext,
                "port": port,
                "hdus_i": j,
                "report_name": f"{instrume}_{tpl_start}_ext{ext}_port{port}",
                "report_description": f"Detector monitoring - "
                f"{instrume}"
                f", template "
                f"{tpl_start}",
                "report_tags": {},
                "input_files": input_files,
            }

        return panels
