from .fors_utils import ForsSetupInfo

from adari_core.report import AdariReportBase
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default, round_arbitrary

import os
import logging
import numpy as np

logger = logging.getLogger(__name__)


class ForsWaveCalReport(AdariReportBase):
    """ADARI report for FORS wavelength-calibration (REDUCED_LAMP_*) products.

    Builds a single panel showing the reduced lamp frame, dispersion-direction
    cuts with the identified lines marked, a histogram of bright raw-frame
    pixels, the resolving power and the dispersion-fit residuals.
    """

    # Template of the SOF inputs; parse_sof stores a per-instance copy with
    # the selected file paths filled in.
    files_needed = {
        "master_im": None,
        "raw_im": None,
        "resolution_im": None,
        "residuals_im": None,
    }
    # Extension of the reduced lamp frame holding the image data and the
    # CRPIX1/CRVAL1/CD1_1 keywords (parse_sof switches this to 1 for PMOS).
    master_im_ext = "PRIMARY"
    # Detected setup suffix (LSS/MXU/MOS/LONG_MOS/PMOS) and the SOF
    # category of the matching raw lamp frame; both set by parse_sof.
    category_label = ""
    raw_label = ""
    # CentralImagePlot largest allowed center_size
    center_size = 200
    # range for cut plots with logarithmic y-axis
    cutymin = 10
    cutymax = 1e5
    # histogram parameters (only raw pixels above histmin are histogrammed)
    histmin = 50000
    histmax = 66000
    hist_bins = 20
    # Recognised REDUCED_LAMP_* setups, in order of precedence.
    _lamp_categories = ("LSS", "MXU", "MOS", "LONG_MOS", "PMOS")

    def __init__(self):
        super().__init__("fors_wave_cal")

    def parse_sof(self):
        """Returns a list of files selected from a set of frames (sof).

        If more than one file fulfils the criteria, the first file
        in the array will be selected.

        SOF should contain one of: REDUCED_LAMP_{LSS/MXU/MOS/LONG_MOS/PMOS}.
        If several are present, the first one in that order is used.

        Returns
        -------
        list of dict
            A one-element list holding the ``files_needed`` mapping
            (report key -> file path).

        Raises
        ------
        IOError
            If no REDUCED_LAMP_* product is present, or if any of the
            companion files for the detected setup (SPECTRAL_RESOLUTION_*,
            DISP_RESIDUALS_TABLE_*, raw lamp frame) is missing.
        """
        files_path = [elem[0] for elem in self.inputs]
        files_category = [elem[1] for elem in self.inputs]

        # A single ordered lookup, so exactly one setup is selected even if
        # several REDUCED_LAMP_* products are present.
        for label in self._lamp_categories:
            if "REDUCED_LAMP_{}".format(label) in files_category:
                self.category_label = label
                break
        else:
            raise IOError(
                "No REDUCED_LAMP_* file found (expected one of: {})".format(
                    ", ".join(self._lamp_categories)
                )
            )

        if self.category_label == "PMOS":
            # PMOS products keep the image in the first extension
            self.master_im_ext = 1

        categories_needed = {}
        categories_needed["master_im"] = "REDUCED_LAMP_{}".format(self.category_label)
        categories_needed["resolution_im"] = "SPECTRAL_RESOLUTION_{}".format(
            self.category_label
        )
        categories_needed["residuals_im"] = "DISP_RESIDUALS_TABLE_{}".format(
            self.category_label
        )
        if self.category_label == "LONG_MOS":
            # LONG_MOS products are reduced from LAMP_MOS raw frames
            self.raw_label = "LAMP_MOS"
        else:
            self.raw_label = "LAMP_{}".format(self.category_label)
        categories_needed["raw_im"] = self.raw_label

        # Fill a fresh dict (same key order as the template) instead of
        # mutating the class-level template shared by every instance.
        files_needed = dict.fromkeys(self.files_needed)
        for key, sof_procatg in categories_needed.items():
            # Check that category matches the requirement
            if sof_procatg not in files_category:
                raise IOError("{} file not found".format(sof_procatg))
            files_needed[key] = files_path[files_category.index(sof_procatg)]
        self.files_needed = files_needed

        return [self.files_needed]

    @staticmethod
    def _line_marker_plot(xloc, xlim=None):
        """Return a ScatterPlot of vertical lines at the pixel positions *xloc*.

        When *xloc* is None the plot is returned empty, so it can still be
        combined with a cut plot. *xlim*, if given, is a ``(xmin, xmax)``
        pair applied to the plot so it does not override the cut's limits.
        """
        marks = ScatterPlot(title="Line labels")
        if xloc is not None:
            marks.add_data(
                label="Lines",
                d=[xloc, np.nan * np.ones_like(xloc)],  # dummy y-values
                marker="circle",
                vline=True,
                color="gainsboro",
            )
            marks.legend = False
            if xlim is not None:
                marks.set_xlim(*xlim)
        return marks

    @staticmethod
    def _placeholder_plot(title, message):
        """Return a TextPlot titled *title* showing *message* centred."""
        placeholder = TextPlot(title=title, columns=1, v_space=0.4)
        placeholder.add_data(
            message,
            halign="center",
            xref=0.5,
            yref=0.5,
        )
        return placeholder

    def generate_panels(self, **kwargs):
        """Create the wave-calibration panel.

        Relies on ``parse_sof`` having set ``category_label`` and
        ``master_im_ext``, and on ``self.hdus[0]`` mapping the
        ``files_needed`` keys to opened HDU lists (presumably populated by
        ``AdariReportBase`` from the ``parse_sof`` result -- TODO confirm).

        Returns
        -------
        dict
            ``{Panel: report metadata}`` for the single 5x3 panel.
        """
        logger.info("Working on category %s report", self.category_label)

        hdus = self.hdus[0]

        # Retrieve appropriate metadata from the first input file
        if self.category_label == "LSS":
            setup_info = ForsSetupInfo.wave_cal_lss
        elif self.category_label == "PMOS":
            setup_info = ForsSetupInfo.wave_cal_pmos
        else:
            setup_info = ForsSetupInfo.wave_cal_mos
        self.metadata = setup_info(list(hdus.values())[0])

        master_im = hdus["master_im"]
        raw_im = hdus["raw_im"]
        resolution_im = hdus["resolution_im"]
        residuals_im = hdus["residuals_im"]

        master_primary = master_im["PRIMARY"]
        master_ext = master_im[self.master_im_ext]

        instru = fetch_kw_or_default(master_primary, "INSTRUME", "Missing INSTRUME")
        master_procatg = fetch_kw_or_default(
            master_primary, "HIERARCH ESO PRO CATG", "Missing PRO CATG"
        )
        bunit = fetch_kw_or_default(master_primary, "BUNIT", default="ADU")
        fname = os.path.basename(str(master_im.filename()))
        crpix1 = fetch_kw_or_default(master_ext, "CRPIX1", None)
        crval1 = fetch_kw_or_default(master_ext, "CRVAL1", None)
        cd1_1 = fetch_kw_or_default(master_ext, "CD1_1", None)

        res_table = resolution_im[1].data
        wavelength = res_table["wavelength"]
        resolution = res_table["resolution"]
        resolution_rms = res_table["resolution_rms"]
        nlines = res_table["nlines"]
        # Wavelength bins with at least one measured line
        has_lines = nlines > 0

        # Pixel positions of the identified lines via the linear WCS:
        # w = (x - crpix1) * cd1_1 + crval1  =>  x = (w - crval1) / cd1_1 + crpix1.
        # Compare with None explicitly: 0 is a valid CRPIX1/CRVAL1 value; only
        # a missing keyword or a zero CD1_1 makes the inversion impossible.
        xloc = None
        if (
            crpix1 is not None
            and crval1 is not None
            and cd1_1 is not None
            and cd1_1 != 0
        ):
            # subtract 1 to put into 0-indexed array mode
            xloc = (wavelength[has_lines] - crval1) / cd1_1 + crpix1 - 1

        panels = {}

        panel = Panel(5, 3, height_ratios=[1, 2, 2], width_ratios=[3, 3, 2, 2, 2])

        panels[panel] = {
            "report_name": "{}_{}".format(instru, master_procatg.lower()),
            "report_description": "FORS wave cal {} panel - {}".format(
                self.category_label, fname
            ),
            "report_tags": [],
        }

        # Text Plot
        vspace = 0.2
        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            instru,
            "EXTNAME: " + str(fetch_kw_or_default(master_primary, "EXTNAME", "N/A")),
            "PRO CATG: " + str(master_procatg),
            "FILE NAME: " + fname,
            "RAW1 NAME: "
            + str(
                fetch_kw_or_default(
                    master_primary,
                    "HIERARCH ESO PRO REC1 RAW1 NAME",
                    "Missing RAW1 NAME",
                )
            ),
        )
        t1.add_data(col1)
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace)
        t2.add_data(self.metadata)
        panel.assign_plot(t2, 3, 0, xext=2)

        # Full reduced image, with the row used for the cut marked
        fullfield = ImagePlot(
            title="REDUCED_LAMP",
            v_clip="val",
            v_clip_kwargs={"low": 0, "high": 16000},
        )
        fullfield.add_data(master_ext.data)
        mid_row = fullfield.data.shape[0] // 2
        fullfield.add_horiz_line(ypos=mid_row)
        panel.assign_plot(fullfield, 0, 1, xext=2, yext=2)

        # Dispersion-direction cut
        cutX = CutPlot("y", title="REDUCED_LAMP cut", x_label="x")
        cutX.y_label = bunit
        cutX.add_data(
            fullfield,
            fullfield.get_data_coord(mid_row, "y"),
            label="Dispersion",
            color="red",
        )
        cutX.set_ylim(self.cutymin, self.cutymax)
        cutX.legend = False

        # Mark identified lines (empty plot when the WCS is unusable)
        cutMark = self._line_marker_plot(xloc)

        # Plot combined cut plot with text-scatter labels
        combinedcutX = CombinedPlot(title="REDUCED_LAMP cut")
        combinedcutX.add_data(cutX, z_order=2)
        combinedcutX.add_data(cutMark, z_order=1)  # place vertical lines behind data
        combinedcutX.y_scale = "log"
        combinedcutX.set_xlim(0, fullfield.data.shape[1])
        panel.assign_plot(combinedcutX, 2, 1, xext=2, yext=1)

        # Create CentralImagePlot for CutPlot creation,
        # but CentralImagePlot not shown

        # Ensure the extent has not been rounded down to 0
        nx_master = master_ext.data.shape[1]
        center_size = min(
            self.center_size,
            max(
                min(50, nx_master),
                round_arbitrary(nx_master // 5, base=50),
            ),
        )

        centerfield = CentralImagePlot(
            fullfield,
            title="Reduced Lamp (central {} pixels)".format(center_size),
            extent=center_size,
        )
        (
            this_xmin,
            this_xmax,
            this_ymin,
            this_ymax,
        ) = centerfield.get_ref_image_window_coords()
        # Indicate the region of this cut on the plot
        fullfield.add_rect(
            anchor=(this_xmin, this_ymin),
            width=(this_xmax - this_xmin),
            height=(this_ymax - this_ymin),
            label="Center",
        )

        # Central image cut plot
        centercutX = CutPlot(
            "y",
            title="REDUCED_LAMP central {} pixels".format(center_size),
            x_label="x",
        )
        centercutX.y_label = bunit
        centercutX.add_data(
            centerfield,
            centerfield.get_data_coord(centerfield.data.shape[0] // 2, "y"),
            label="Central",
            color="red",
        )
        centercutX.set_ylim(self.cutymin, self.cutymax)
        centercutX.legend = False
        cutxmin, cutxmax = centerfield.get_xlim()

        # Mark identified lines; don't override the central image limits
        centercutMark = self._line_marker_plot(xloc, xlim=(cutxmin, cutxmax))

        # Plot combined central cut plot with text-scatter labels
        combinedcentercutX = CombinedPlot(
            title="REDUCED_LAMP central {} pixels".format(center_size)
        )
        combinedcentercutX.add_data(centercutX, z_order=2)
        # place vertical lines behind data
        combinedcentercutX.add_data(centercutMark, z_order=1)
        combinedcentercutX.y_scale = "log"
        panel.assign_plot(combinedcentercutX, 2, 2, xext=1, yext=1)

        # Raw histogram of the bright pixels only
        raw_data = raw_im[0].data
        nbright = np.count_nonzero(raw_data > self.histmin)
        if nbright > 0:
            rawhist = HistogramPlot(
                title="Raw counts histogram",
                v_clip="val",
                v_clip_kwargs={"low": self.histmin, "high": self.histmax},
            )
            rawhist.add_data(
                raw_data, color="black", label="Raw", bins=self.hist_bins
            )
            rawhist.set_ylim(
                0.1,
                max(1.0, 2.0 * nbright),
            )
            rawhist.legend = False
        else:
            rawhist = self._placeholder_plot(
                "Raw counts histogram", f"[No pixels above {self.histmin} ADU]"
            )
        panel.assign_plot(rawhist, 3, 2, xext=1, yext=1)

        # Resolving power vs wavelength, with errorbars.
        # Test the same mask that is plotted, so a table with only
        # non-positive nlines falls back to the placeholder.
        if np.any(has_lines):
            respower = ScatterPlot(title="Resolving Power")
            respower.add_data(
                [wavelength[has_lines], resolution[has_lines]],
                label="ResPow",
                yerror=resolution_rms[has_lines],
            )
            respower.x_label = r"Wavelength (${\rm \AA}$)"
            respower.y_label = "Resolving Power"
            respower.legend = False
        else:
            respower = self._placeholder_plot("Resolving Power", "[No data to plot]")
        panel.assign_plot(respower, 4, 1, xext=1, yext=1)

        # Fit residuals vs wavelength (NaN residuals are skipped)
        resid_table = residuals_im[1].data
        wavelength_resid = resid_table["wavelength"]
        r10 = resid_table["r10"]
        valid = ~np.isnan(r10)
        if np.any(valid):
            resid = ScatterPlot(title="Fit Residuals")
            resid.add_data(
                [wavelength_resid[valid], r10[valid]],
                label="Resid",
            )
            resid.x_label = r"Wavelength (${\rm \AA}$)"
            resid.y_label = "Residuals  (pix)"
            resid.legend = False
        else:
            resid = self._placeholder_plot("Fit Residuals", "[No data to plot]")

        panel.assign_plot(resid, 4, 2, xext=1, yext=1)

        return panels


# Module-level report instance, created at import time; presumably the ADARI
# runner looks it up under the name ``rep`` -- TODO confirm.
rep = ForsWaveCalReport()
