from .fors_utils import ForsSetupInfo
from adari_core.data_libs.master_spec_flat import MasterSpecFlatReport
from adari_core.plots.cut import CutPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default, round_arbitrary
import adari_core.utils.clipping as clipping
import os
import logging

logger = logging.getLogger(__name__)


class ForsFlatSpecMosReport(MasterSpecFlatReport):
    """Report on FORS normalised spectroscopic flats (MXU/MOS/LONG_MOS/PMOS).

    The panels built by ``MasterSpecFlatReport.generate_panels`` are extended
    with two header text blocks, a spatial-map image (not for LONG_MOS) and
    two cut plots through the MAPPED_NORM_FLAT_* image.
    """

    # Template of the file tags common to every flavour. ``parse_sof`` copies
    # it into a per-instance dict (adding "spatial_map_im" where applicable)
    # rather than mutating this shared class-level dict.
    files_needed = {
        "master_im": None,
        "raw_im": None,
        "wavecal_master_im": None,
    }  # Further elaborated in parse_sof
    # Flavour suffix ("MXU", "MOS", "LONG_MOS" or "PMOS"); set by parse_sof.
    category_label = ""
    # Upper bound applied to the extent of the central cut-out.
    center_size = 200

    def __init__(self):
        """Create the report under the name "fors_flat_spec_mos"."""
        super().__init__("fors_flat_spec_mos")

    def parse_sof(self):
        """Returns a list of files selected from a set of frames (sof).

        If more than one file fulfils the criteria, the first file
        in the array will be selected.

        SOF should contain one of: MASTER_NORM_FLAT_{MXU/MOS/LONG_MOS/PMOS}
        (checked in that order); the matching suffix is stored in
        ``self.category_label`` and determines the other required categories.

        The selection is stored in a dict owned by this instance
        (``self.files_needed``), so repeated calls or other instances never
        see entries left over from a previous run.

        Returns:
            A one-element list holding the dict that maps each file tag
            ("master_im", "raw_im", "wavecal_master_im" and, except for
            LONG_MOS, "spatial_map_im") to the selected file path.

        Raises:
            IOError: if no MASTER_NORM_FLAT_* frame is present, or if any
                category required for the detected flavour is missing.
        """
        files_path = [elem[0] for elem in self.inputs]
        files_category = [elem[1] for elem in self.inputs]

        # The first flavour (in this fixed order) with a master flat wins.
        for label in ("MXU", "MOS", "LONG_MOS", "PMOS"):
            if "MASTER_NORM_FLAT_{}".format(label) in files_category:
                self.category_label = label
                break
        else:
            raise IOError("MASTER_NORM_FLAT_{MXU/MOS/LONG_MOS/PMOS} file not found")

        categories_needed = {
            "master_im": "MASTER_NORM_FLAT_{}".format(label),
            "wavecal_master_im": "MAPPED_NORM_FLAT_{}".format(label),
        }
        if label == "LONG_MOS":
            # LONG_MOS uses the MOS screen flat and requires no spatial map.
            categories_needed["raw_im"] = "SCREEN_FLAT_MOS"
        else:
            categories_needed["raw_im"] = "SCREEN_FLAT_{}".format(label)
            categories_needed["spatial_map_im"] = "SPATIAL_MAP_{}".format(label)

        # Start from the class-level template so the key order (master_im
        # first) is unchanged, but never write into the shared dict itself.
        files_needed = dict.fromkeys(type(self).files_needed)
        for file_tag, sof_procatg in categories_needed.items():
            # Check that category matches the requirement
            if sof_procatg not in files_category:
                raise IOError("{} file not found".format(sof_procatg))
            files_needed[file_tag] = files_path[files_category.index(sof_procatg)]

        self.files_needed = files_needed
        return [files_needed]

    def generate_panels(self, **kwargs):
        """Build the report panels and add the FORS-specific plots.

        The parent ``generate_panels`` is called with fixed scaling options;
        each panel it returns is then completed using the HDUs in
        ``self.hdus`` at the same index.

        Args:
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            A new dict with the entries returned by the parent
            ``generate_panels`` (the panel objects are modified in place).
        """
        vspace = 0.3
        is_long_mos = self.category_label == "LONG_MOS"

        logger.info("Working on category %s report", self.category_label)
        new_panels = super().generate_panels(
            raw_map=True,
            master_im_scale=True,
            master_im_scale_val=(0.5, 1.5),
            raw_hist_scale=True,
            raw_hist_scale_val=(-5000, 70000),
            master_hist_scale=True,
            master_hist_scale_val=(0.5, 1.5),
        )

        for i, panel in enumerate(new_panels):
            hdus = self.hdus[i]

            # Retrieve appropriate metadata (from the first HDU list of the set)
            first_hdul = list(hdus.values())[0]
            if self.category_label == "PMOS":
                self.metadata = ForsSetupInfo.flat_spec_pmos(first_hdul)
            else:
                self.metadata = ForsSetupInfo.flat_spec(first_hdul)

            # Text block 1: identification of the master product
            master_im = hdus["master_im"]
            master_header = master_im["PRIMARY"]
            instru = fetch_kw_or_default(master_header, "INSTRUME", "Missing INSTRUME")
            master_procatg = fetch_kw_or_default(
                master_header, "HIERARCH ESO PRO CATG", "Missing PRO CATG"
            )
            fname = os.path.basename(str(master_im.filename()))
            extname = fetch_kw_or_default(master_header, "EXTNAME", "N/A")
            raw1_name = fetch_kw_or_default(
                master_header,
                "HIERARCH ESO PRO REC1 RAW1 NAME",
                "Missing RAW1 NAME",
            )
            t1 = TextPlot(columns=1, v_space=vspace)
            t1.add_data(
                (
                    instru,
                    "EXTNAME: " + str(extname),
                    "PRO CATG: " + str(master_procatg),
                    "FILE NAME: " + fname,
                    "RAW1 NAME: " + str(raw1_name),
                )
            )
            panel.assign_plot(t1, 0, 0, xext=2)

            # Text block 2: instrument setup metadata
            t2 = TextPlot(columns=1, v_space=vspace)
            t2.add_data(self.metadata)
            panel.assign_plot(t2, 2, 0, xext=2)

            # MOS specific plots
            wavecal_hdu = hdus["wavecal_master_im"]["PRIMARY"]

            # Full image of spatial map (not available for LONG_MOS)
            if is_long_mos:
                panel.pop(1, 2)
            else:
                fullspatial = ImagePlot(
                    hdus["spatial_map_im"]["PRIMARY"].data,
                    title="Spatial Map",
                    v_clip="val",
                    v_clip_kwargs={"low": 0, "high": 150},
                )
                panel.assign_plot(fullspatial, 1, 2)

            # Cross-dispersion cuts
            # Create ImagePlot for CutPlot creation, but ImagePlot not shown
            fullmapped = ImagePlot(title="Mapped Norm Flat")
            fullmapped.add_data(wavecal_hdu.data)

            # Both cut plots share the same y-axis unit label.
            flux_unit = fetch_kw_or_default(wavecal_hdu, "BUNIT", default="ADU")

            # Full image cut plot
            cutY = CutPlot("x", title="MAPPED_NORM cut", x_label="y")
            cutY.y_label = flux_unit
            cutY.add_data(
                fullmapped,
                fullmapped.get_data_coord(fullmapped.data.shape[1] // 2, "x"),
                label="Cross-dispersion",
                color="red",
            )
            # y-limits come from clipping_mad (nmad=7) applied to the cut
            # taken at the position registered for this data set.
            this_cut = fullmapped.get_axis_cut(
                cutY.cut_ax, cutY._data["Cross-dispersion"]["cut_pos"]
            )
            cutY.y_min, cutY.y_max = clipping.clipping_mad(this_cut, nmad=7)
            panel.assign_plot(cutY, 2, 1, xext=1, yext=1)

            # Create CentralImagePlot for CutPlot creation,
            # but CentralImagePlot not shown

            # Ensure the extent has not been rounded down to 0: it is at
            # least min(50, n_rows) and never larger than self.center_size.
            n_rows = wavecal_hdu.data.shape[0]
            center_cut_extent = min(
                self.center_size,
                max(
                    min(50, n_rows),
                    round_arbitrary(n_rows // 5, base=50),
                ),
            )
            centermapped = CentralImagePlot(
                fullmapped,
                title="Master Flat (central {} pixels)".format(center_cut_extent),
                extent=center_cut_extent,
            )

            # Central image cut plot
            centercutY = CutPlot(
                "x",
                title="MAPPED_NORM central {} pixels".format(center_cut_extent),
                x_label="y",
            )
            centercutY.y_label = flux_unit
            centercutY.add_data(
                centermapped,
                centermapped.get_data_coord(centermapped.data.shape[1] // 2, "x"),
                label="Central",
                color="red",
            )
            this_centercut = centermapped.get_axis_cut(
                centercutY.cut_ax,
                center_cut_extent // 2,
            )
            centercutY.y_min, centercutY.y_max = clipping.clipping_mad(
                this_centercut, nmad=7
            )
            panel.assign_plot(centercutY, 2, 2, xext=1, yext=1)

        # Hand back a fresh mapping, as before, so callers do not share the
        # dict object returned by the parent class.
        return dict(new_panels)

# Module-level report instance, created at import time.
# NOTE(review): presumably this is the object the ADARI runner looks up by
# the name ``rep`` when it loads this module - confirm against the runner.
rep = ForsFlatSpecMosReport()
