from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

from .espresso_utils import EspressoReportMixin

import os
import numpy as np
import re

# Matches header keywords of the form "ESO QC ORDER<n> MAX FLUX" (a 1-3 digit
# order number, optional trailing whitespace) and captures <n> in the named
# group "order". The pattern carries no "HIERARCH" prefix because it is matched
# against keys from ``header.items()`` -- presumably the FITS library strips
# that prefix from HIERARCH keys; confirm.
QC_ORDER_MAX_FLUX_REGEX = r"^ESO\sQC\sORDER(?P<order>[0-9]{1,3})\sMAX\sFLUX\s*$"


class EspressoContaminationReport(EspressoReportMixin, AdariReportBase):
    """ADARI report for ESPRESSO contamination frames (``CONTAM_S2D_A``).

    For the selected input file a single 4x4 panel is built showing header
    metadata, the full and central-region images with a cut through each,
    and the ``QC ORDER<i> MAX FLUX`` header values plotted against order
    number.
    """

    def __init__(self):
        super().__init__("espresso_contamination")

    # FIXME update to use new report inheritance schema correctly

    def parse_sof(self):
        """Select the contamination frame from the report inputs.

        ``self.inputs`` is iterated as ``(filename, category)`` pairs. Only
        the ``CONTAM_S2D_A`` category is used; if several such files are
        present, the last one wins.

        Returns
        -------
        list of dict
            ``[{"contam_s2d": filename}]`` when a ``CONTAM_S2D_A`` file was
            found, otherwise an empty list.
        """
        contam = None
        for filename, catg in self.inputs:
            if catg == "CONTAM_S2D_A":
                contam = filename

        if contam is None:
            return []
        return [{"contam_s2d": contam}]

    @staticmethod
    def _collect_order_max_flux(header):
        """Extract the ``QC ORDER<i> MAX FLUX`` values from a FITS header.

        Parameters
        ----------
        header
            Header whose ``items()`` yields ``(keyword, value)`` pairs.

        Returns
        -------
        orders, max_flux : list, list
            Order numbers in ascending order and the matching max-flux
            values. Gaps in the order numbering (or an order 0) are
            tolerated; if several keywords parse to the same order number,
            the last one encountered wins.
        """
        pattern = re.compile(QC_ORDER_MAX_FLUX_REGEX)
        flux_by_order = {}
        for key, value in header.items():
            match = pattern.match(key)
            if match is not None:
                flux_by_order[int(match["order"])] = value
        orders = sorted(flux_by_order)
        max_flux = [flux_by_order[order] for order in orders]
        return orders, max_flux

    def generate_panels(self, **kwargs):
        """Build one contamination panel per entry in ``self.hdus``.

        Each entry of ``self.hdus`` is indexed with the ``"contam_s2d"`` key
        produced by :meth:`parse_sof`. The value is used as an opened FITS
        file (``filename()``, ``["PRIMARY"]``, ``["SCIDATA"]``, ``[0]``,
        ``[1]``) -- presumably an astropy ``HDUList`` opened by the base
        class; confirm.

        Returns
        -------
        dict
            Maps each :class:`Panel` to its report metadata dict.
        """
        panels = {}

        for filedict in self.hdus:
            contam_s2d = filedict["contam_s2d"]
            primary_header = contam_s2d["PRIMARY"].header

            p = Panel(4, 4, height_ratios=[1, 4, 4, 4])

            ylabel = fetch_kw_or_default(contam_s2d["SCIDATA"], "BUNIT", default="ADU")
            # Sigma-clipped colour scaling shared by both image plots.
            scaling = {"v_clip": "sigma", "v_clip_kwargs": {"nsigma": 2.5}}

            # Top row: header metadata text.
            vspace = 0.3
            fname = os.path.basename(str(contam_s2d.filename()))
            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                str(primary_header.get("INSTRUME")),
                "EXTNAME: " + str(contam_s2d[1].header.get("EXTNAME", "N/A")),
                "PRO CATG: " + str(primary_header.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + fname,
                "RAW1 NAME: "
                + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
            )
            t1.add_data(col1)
            p.assign_plot(t1, 0, 0, xext=2)

            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            col2 = (
                "INS.MODE: " + str(primary_header.get("HIERARCH ESO INS MODE")),
                "DET.BINX: " + str(primary_header.get("HIERARCH ESO DET BINX")),
                "DET.BINY: " + str(primary_header.get("HIERARCH ESO DET BINY")),
            )
            t2.add_data(col2)
            p.assign_plot(t2, 2, 0, xext=1)

            # Full image and a cut at its central y position.
            full_img = ImagePlot(title="Full image", y_label="Order No", **scaling)
            full_img.add_data(contam_s2d["SCIDATA"].data)
            full_img.aspect = 50
            p.assign_plot(full_img, 0, 1, xext=2)

            full_cut = CutPlot("y", title="Full image", x_label="x", y_label=ylabel)
            full_cut.add_data(
                full_img,
                cut_pos=full_img.get_data_coord(full_img.data.shape[0] // 2, "y"),
            )
            p.assign_plot(full_cut, 0, 2, xext=2)

            # Central 200px region of the same image and its central cut.
            cent_img = CentralImagePlot(
                full_img,
                title="Central 200px",
                extent=200,
                y_label="Order No",
                **scaling,
            )
            p.assign_plot(cent_img, 2, 1, xext=2)

            cent_cut = CutPlot("y", title="Central region", x_label="x", y_label=ylabel)
            cent_cut.add_data(
                cent_img,
                cut_pos=cent_img.get_data_coord(cent_img.data.shape[0] // 2, "y"),
            )
            p.assign_plot(cent_cut, 2, 2, xext=2)

            # QC ORDER<i> MAX FLUX header values versus order number. The
            # lists are built from the parsed order numbers, so gaps in the
            # numbering no longer raise IndexError or misplace points.
            orders, max_flux = self._collect_order_max_flux(contam_s2d[0].header)
            order_max_flux_plot = ScatterPlot(
                title="QC ORDER<i> MAX FLUX",
                markersize=2,
                x_label="Order No",
                y_label=ylabel,
            )
            order_max_flux_plot.add_data(
                [orders, max_flux], label="QC ORDER<i> MAX FLUX"
            )
            order_max_flux_plot.x_origin = 1
            # Pad the x-range by one order on each side: orders 1..N give
            # (0, N + 1); no QC keywords at all give (0, 1).
            order_max_flux_plot.set_xlim(
                min(orders, default=1) - 1, max(orders, default=0) + 1
            )
            p.assign_plot(order_max_flux_plot, 1, 3, xext=2)

            panels[p] = {
                "report_name": "espresso_contamination",
                "report_description": f"ESPRESSO Contamination: "
                f"{contam_s2d.filename()}",
                "report_tags": [],
                "input_files": [contam_s2d.filename()],
            }

        return panels


# Module-level report instance; presumably discovered by the ADARI framework
# when this module is loaded -- confirm against the adari_core loader.
rep = EspressoContaminationReport()
