from adari_core.plots.points import ScatterPlot
from adari_core.data_libs.echelle_flatfield import MasterEchelleFlatfieldReport
from adari_core.plots.text import TextPlot
import numpy as np
import os
import logging

# from adari_core.report import master_report_logger as logger

from .espresso_utils import EspressoReportMixin

# Module-level logger for this report module.
logger = logging.getLogger(__name__)

# Value copied onto each report instance's ``center_size`` attribute.
center_size = 200

# Primary-header QC keyword templates; ``{}`` is filled with the order number.
QC_SNR_KW = "HIERARCH ESO QC ORDER{} SNR"
QC_RMS_KW = "HIERARCH ESO QC ORDER{} FLAT RMS"
# Instrument-mode keyword (checked for SINGLEHR / SINGLEUHR / MULTIMR).
MODE_KW = "HIERARCH ESO INS MODE"


class EspressoEchelleFlatfieldReport(EspressoReportMixin, MasterEchelleFlatfieldReport):
    """ESPRESSO echelle flat-field report.

    Pairs each fibre's master ORDER_PROFILE product with its first raw FLAT
    and, for both detector extensions (``CCD290red`` / ``CCD290blue``),
    customises the generic panels produced by
    ``MasterEchelleFlatfieldReport.generate_panels`` with:

    * ESPRESSO-specific axis labels and histogram ranges,
    * header-metadata text boxes,
    * per-order S/N and RMS scatter plots read from the master product's
      ``QC ORDERn`` keywords (only when ``INS MODE`` is recognised).
    """

    def __init__(self):
        super().__init__("espresso_echelle_flatfield")
        # Copied from the module-level ``center_size`` default.
        self.center_size = center_size
        self._version = "Espresso-Echelle-0.1"

    def parse_sof(self):
        """Pair each fibre's master ORDER_PROFILE with its first raw FLAT.

        Scans ``self.inputs`` (``(filename, category)`` pairs) and keeps the
        first file seen for each of ORDER_PROFILE_A, ORDER_PROFILE_B,
        FLAT_A and FLAT_B. ORDER_TABLE_A/B inputs are currently not used.

        Returns
        -------
        list of dict
            Up to two dicts with keys ``"master_product"`` and ``"raw"``,
            fibre A first, then fibre B. A fibre is omitted unless both its
            ORDER_PROFILE and FLAT files were found.
        """
        wanted = ("ORDER_PROFILE_A", "ORDER_PROFILE_B", "FLAT_A", "FLAT_B")
        first_of = {}
        for filename, catg in self.inputs:
            # Keep only the first file of each wanted category.
            if catg in wanted and first_of.get(catg) is None:
                first_of[catg] = filename

        file_lists = []
        for fibre in ("A", "B"):
            master = first_of.get(f"ORDER_PROFILE_{fibre}")
            raw = first_of.get(f"FLAT_{fibre}")
            if master is not None and raw is not None:
                file_lists.append({"master_product": master, "raw": raw})
        return file_lists

    def generate_panels(self, **kwargs):
        """Build the report panels for both ESPRESSO detector extensions.

        For each extension the base class's ``generate_panels`` returns a
        panel -> description mapping; every panel is then customised with
        ESPRESSO labels, histogram ranges, header text and (when the
        instrument mode is recognised) per-order S/N and RMS plots.

        Parameters
        ----------
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        dict
            Panel -> panel description, merged over both extensions.
        """
        panels = {}
        for ext in ("CCD290red", "CCD290blue"):
            new_panels = super().generate_panels(
                master_product_ext=ext, raw_ext=ext, direction="x"
            )
            # NOTE(review): assumes the base class returns panels in the
            # same order as ``self.hdus`` (one per file list) — confirm.
            for i, (panel, panel_descr) in enumerate(new_panels.items()):
                order_profile = self.hdus[i]["master_product"]
                self._customise_panel(panel, panel_descr, order_profile, ext)
            panels = {**panels, **new_panels}
        return panels

    def _customise_panel(self, panel, panel_descr, order_profile, ext):
        """Apply every ESPRESSO-specific change to one base-class panel."""
        panel_descr["panel_description"] = (
            f"ESPRESSO echelle flatfield - "
            f"{os.path.basename(panel_descr['master_product'])}, "
            f"{panel_descr['master_product_ext']}"
        )
        self._set_axes(panel)
        # Widen the panel to make room for the per-order plots in column 3.
        panel.x = 4
        # Header metadata is shown regardless of whether INS MODE is usable.
        self._add_header_text(panel, order_profile, ext)
        self._add_order_stat_plots(panel, order_profile, ext)

    @staticmethod
    def _set_axes(panel):
        """Set ESPRESSO axis labels and histogram ranges on a base panel."""
        # Manually set the unit ('BUNIT') of the plots at (1, 1)/(1, 2) to e-.
        for pos in ((1, 1), (1, 2)):
            panel.retrieve(*pos).y_label = "e-"

        raw_hist = panel.retrieve(2, 1)
        raw_hist.x_label = "counts"
        raw_hist.v_min = -5000.0
        raw_hist.v_max = 70000.0
        raw_hist.bins = 25

        master_hist = panel.retrieve(2, 2)
        master_hist.x_label = "e-"
        master_hist.x_min = -50000.0
        master_hist.x_max = 700000.0
        master_hist.v_min = -50000
        master_hist.v_max = 850000
        # Only scan the full master data for min/max when it will be logged.
        if logger.isEnabledFor(logging.DEBUG):
            master_data = master_hist._data["master"]["data"]
            logger.debug(
                "Master hist data bounds: %s, %s",
                np.min(master_data),
                np.max(master_data),
            )

    @staticmethod
    def _order_range(ext, insmode):
        """Return the echelle orders shown for ``ext`` in mode ``insmode``.

        Returns ``None`` when ``insmode`` is none of SINGLEHR, SINGLEUHR
        or MULTIMR. The blue CCD covers the low order numbers, the red CCD
        the high ones.
        """
        blue = ext == "CCD290blue"
        if "SINGLEHR" in insmode or "SINGLEUHR" in insmode:
            return range(1, 90 + 1) if blue else range(91, 170 + 1)
        if "MULTIMR" in insmode:
            return range(1, 45 + 1) if blue else range(46, 85 + 1)
        return None

    def _add_order_stat_plots(self, panel, order_profile, ext):
        """Add per-order S/N (3, 1) and RMS (3, 2) scatter plots.

        Values come from the ``QC ORDERn SNR`` / ``QC ORDERn FLAT RMS``
        keywords of the master product's primary header; missing keywords
        are plotted as NaN. Nothing is added when ``INS MODE`` is missing,
        not a string, or not a recognised mode.
        """
        primary = order_profile[0].header
        insmode = primary.get(MODE_KW)
        logger.debug("INSMODE = %s", insmode)
        if not isinstance(insmode, str):
            return
        order_range = self._order_range(ext, insmode)
        if order_range is None:
            return

        for kw_template, title, y_label, pos in (
            (QC_SNR_KW, "S/N vs order no", "SNR", (3, 1)),
            (QC_RMS_KW, "RMS vs order no", "RMS", (3, 2)),
        ):
            plot = ScatterPlot(
                title=title,
                x_label="Order No",
                y_label=y_label,
                markersize=2,
            )
            values = [primary.get(kw_template.format(n), np.nan) for n in order_range]
            plot.add_data((order_range, values), label=y_label)
            panel.assign_plot(plot, *pos)

    @staticmethod
    def _add_header_text(panel, order_profile, ext):
        """Add two header-metadata text boxes at (0, 0) and (2, 0)."""
        primary = order_profile["PRIMARY"].header
        vspace = 0.3
        fname = os.path.basename(str(order_profile.filename()))

        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            str(primary.get("INSTRUME")),
            "EXTNAME: " + str(order_profile[ext].header.get("EXTNAME", "N/A")),
            "PRO CATG: " + str(primary.get("HIERARCH ESO PRO CATG")),
            "FILE NAME: " + fname,
            "RAW1 NAME: " + str(primary.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
        )
        t1.add_data(col1)
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        col2 = (
            "INS.MODE: " + str(primary.get(MODE_KW)),
            "DET.BINX: " + str(primary.get("HIERARCH ESO DET BINX")),
            "DET.BINY: " + str(primary.get("HIERARCH ESO DET BINY")),
        )
        t2.add_data(col2)
        panel.assign_plot(t2, 2, 0, xext=1)


# Module-level report instance created at import time; presumably this is
# what the adari_core framework picks up when loading the module — confirm.
rep = EspressoEchelleFlatfieldReport()
