from .hawki_utils import HawkiSetupInfo, HawkiReportMixin
from adari_core.data_libs.master_raw_cuts import MasterRawCutsReport
from adari_core.data_libs.master_sky_flat import MasterSkyFlatReport
from adari_core.plots.points import ScatterPlot
from adari_core.plots.panel import Panel
from adari_core.plots.text import TextPlot

import os

class HawkiSkyFlatReport(HawkiReportMixin, MasterRawCutsReport, MasterSkyFlatReport):
    """Report on a HAWK-I master twilight (sky) flat.

    ``generate_panels`` combines three groups of panels: panels built per
    detector extension by the parent report classes (decorated here with
    header metadata), one mosaic panel showing every detector, and one panel
    plotting the per-exposure median flux stored in the QC header keywords.
    """

    # Extension name -> typical (x, y) position of the chip in the mosaic.
    detectors = {
        "CHIP1.INT1": [1, 1],
        "CHIP2.INT1": [2, 1],
        "CHIP3.INT1": [2, 2],
        "CHIP4.INT1": [1, 2],
    }
    # Logical file name -> frame category searched for in the set of frames.
    files_needed = {
        "master_im": "MASTER_TWILIGHT_FLAT",
        # For interactive only:
        #  RATIOIMG_TWILIGHT_FLAT,
        #  RATIOIMG_STATS_TWILIGHT_FLAT
    }
    # Passed as ``zoom_in_extent`` for the central zoom plots of the mosaic panel.
    center_size = 200

    def __init__(self):
        """Initialise the parent report with the identifier ``hawki_sky_flat``."""
        super().__init__("hawki_sky_flat")

    def parse_sof(self):
        """Select the files this report needs from the set of frames (sof).

        Each entry of ``self.inputs`` is read as ``(path, category, ...)``.
        If more than one input carries a required category, the first one
        in ``self.inputs`` is selected.

        Returns:
            A one-element list holding a dict that maps every key of
            ``files_needed`` to the path of the matching input file.

        Raises:
            IOError: if a required category is absent from the inputs.
        """
        # Keep only the first path seen for each category.
        first_path = {}
        for elem in self.inputs:
            first_path.setdefault(elem[1], elem[0])

        file_lists = {}
        for name, category in self.files_needed.items():
            if category not in first_path:
                raise IOError("[WARNING] {} file not found".format(category))
            file_lists[name] = first_path[category]
        return [file_lists]

    def generate_stats_panel(self):
        """Generate a panel with the median flux stats for the input flats.

        For every detector extension the ``ESO QC RAW<n> DET MED`` keywords
        are read for n = 1, 2, ... until the first missing one, and plotted
        against the exposure number. Two text plots with header metadata are
        placed in the top row.

        Side effects:
            Sets ``self.n_flat_sky`` to the number of consecutive QC keywords
            found in the last detector processed.

        Returns:
            A dict mapping the panel to its report description dict.
        """
        panel = Panel(4, 3, height_ratios=[1, 5, 5])
        master_im = self.hdus[0]["master_im"]

        # Always defined, even if no detector is processed.
        self.n_flat_sky = 0

        # Detector plots
        for ext, (mosaic_x, mosaic_y) in self.detectors.items():
            ext_header = master_im[ext].header
            plot = ScatterPlot(title=ext, legend=False)

            # Walk the consecutive QC keywords; stop at the first gap.
            n_raw = 0
            while "ESO QC RAW{} DET MED".format(n_raw + 1) in ext_header:
                n_raw += 1
                median_flux = ext_header["ESO QC RAW{} DET MED".format(n_raw)]
                plot.add_data(
                    [[int(n_raw)], [median_flux]],
                    label="{}".format(n_raw),
                    color="black",
                )
            self.n_flat_sky = n_raw

            plot.x_label = "Exposure number"
            plot.y_label = "Median flux / ADU"
            # Each detector plot spans two panel columns; mosaic row 1 is
            # drawn in the bottom panel row.
            panel.assign_plot(plot, 2 * mosaic_x - 2, 3 - mosaic_y, xext=2)

        # Metadata
        header = master_im["PRIMARY"].header
        text = TextPlot(columns=1, v_space=0.3)
        text_col = (
            header["INSTRUME"] + ": FLAT_SKY",
            "FIRST RAW FILE: " + header["HIERARCH ESO PRO REC1 RAW1 NAME"],
            "LAST RAW FILE: "
            + str(
                header.get(
                    "HIERARCH ESO PRO REC1 RAW{} NAME".format(self.n_flat_sky)
                )
            ),
        )
        text.add_data(text_col)
        panel.assign_plot(text, 0, 0, xext=1)

        text = TextPlot(columns=1, v_space=0.3)
        text_col = (
            "DET.NCORRS.NAME: " + str(header["HIERARCH ESO DET NCORRS NAME"]),
            "DET.DIT: " + str(header["HIERARCH ESO DET DIT"]),
            "DET.NDIT: " + str(header["HIERARCH ESO DET NDIT"]),
            "DET.RSPEED: " + str(header["HIERARCH ESO DET RSPEED"]),
            "INS.FILT1.NAME: " + str(header["HIERARCH ESO INS FILT1 NAME"]),
            "INS.FILT2.NAME: " + str(header["HIERARCH ESO INS FILT2 NAME"]),
        )
        text.add_data(text_col)
        panel.assign_plot(text, 2, 0, xext=1)

        addme = {
            "report_name": "hawki_raw_flat_stats",
            "report_description": "HAWK-I input sky flat stats",
            "report_tags": [],
            "input_files": [master_im.filename()],
        }
        return {panel: addme}

    def generate_single_ext_panels(self):
        """Build the parent-class panels for each detector extension.

        The parent ``generate_panels`` is called once per key of
        ``detectors``. Every panel it returns gets a new description and two
        text plots in its top row: header values of the master file and the
        setup metadata stored in ``self.metadata``.

        Returns:
            A dict mapping each panel to its description dict.
        """
        panels = {}
        vspace = 0.3

        for ext in self.detectors:
            new_panels = super().generate_panels(
                master_im_ext=ext,
                master_title="Master Sky Flat",
                master_im_clipping="mad",
                master_im_n_clipping=4,
                master_im_zoom_clipping="mad",
                master_im_zoom_n_clipping=3,
                cut_clipping="percentile",
                cut_n_clipping=99.5,
                cut_highest_min=0.9,
                cut_lowest_max=1.1,
                cut_cent_clipping="percentile",
                cut_cent_n_clipping=99.0,
                cut_cent_highest_min=0.9,
                cut_cent_lowest_max=1.1,
            )
            for i, (panel, panel_descr) in enumerate(new_panels.items()):
                panel_descr["report_description"] = (
                    f"HAWK-I flat panel - "
                    f"{os.path.basename(panel_descr['master_im'])}, "
                    f"{panel_descr['master_im_ext']}"
                )
                # NOTE(review): assumes the parent returns its panels in the
                # same order as ``self.hdus`` - confirm.
                master_im = self.hdus[i]["master_im"]
                primary = master_im["PRIMARY"].header

                # Text plot with values from the master file headers
                fname = os.path.basename(str(master_im.filename()))
                t1 = TextPlot(columns=1, v_space=vspace)
                col1 = (
                    str(primary.get("INSTRUME", "Missing INSTRUME")),
                    "EXTNAME: " + str(master_im[ext].header.get("EXTNAME", "N/A")),
                    "PRO CATG: "
                    + str(primary.get("HIERARCH ESO PRO CATG", "Missing PRO CATG")),
                    "FILE NAME: " + fname,
                    "RAW1 NAME: "
                    + str(
                        primary.get(
                            "HIERARCH ESO PRO REC1 RAW1 NAME", "Missing RAW1 NAME"
                        )
                    ),
                )
                t1.add_data(col1)
                panel.assign_plot(t1, 0, 0, xext=2)

                # Text plot with the instrument setup metadata
                t2 = TextPlot(columns=1, v_space=vspace, xext=1)
                t2.add_data(self.metadata)
                panel.assign_plot(t2, 2, 0, xext=1)

            panels.update(new_panels)

        return panels

    def generate_multi_ext_panels(self):
        """Build one panel showing all detectors of the master flat.

        The top row holds two text plots (master file header values and
        ``self.metadata``). Below, every detector gets a full image and a
        central zoom. The panel position is taken from the
        ``ESO DET CHIP X`` / ``ESO DET CHIP Y`` keywords, falling back to the
        positions in ``detectors`` when either keyword is missing or 0.

        Returns:
            A dict mapping the panel to its report description dict.
        """
        p = Panel(x=4, y=3, height_ratios=[1, 4, 4])
        master_im = self.hdus[0]["master_im"]
        primary = master_im["PRIMARY"].header

        # Metadata in Text Plot
        vspace = 0.3
        fname = os.path.basename(str(master_im.filename()))
        t1 = TextPlot(columns=1, v_space=vspace)
        col1 = (
            str(primary.get("INSTRUME", "Missing INSTRUME")),
            "PRO CATG: "
            + str(primary.get("HIERARCH ESO PRO CATG", "Missing PRO CATG")),
            "MASTER FILE NAME: " + fname,
            "RAW1 NAME: "
            + str(
                primary.get("HIERARCH ESO PRO REC1 RAW1 NAME", "Missing RAW1 NAME")
            ),
        )
        t1.add_data(col1)
        p.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(self.metadata)
        p.assign_plot(t2, 2, 0, xext=1)

        for extname, default_position in self.detectors.items():
            ext_hdu = master_im[extname]
            mosaic_x = ext_hdu.header.get("HIERARCH ESO DET CHIP X", 0)
            mosaic_y = ext_hdu.header.get("HIERARCH ESO DET CHIP Y", 0)
            # Apply defaults if missing header keywords
            if mosaic_x == 0 or mosaic_y == 0:
                mosaic_x, mosaic_y = default_position
            px = mosaic_x - 1
            py = 3 - mosaic_y

            full_plot, zoom_plot = super().image_plot(
                ext_hdu,
                zoom_in=True,
                zoom_in_extent=self.center_size,
                img_kwargs={
                    "title": extname,
                    "zoom_title": "{0} Central".format(extname),
                    "v_clip": "mad",
                    "v_clip_kwargs": {"nmad": 4},
                },
                zoom_img_kwargs={"v_clip": "mad", "v_clip_kwargs": {"nmad": 3}},
            )
            # Full images fill the two left columns, zooms the two right ones.
            p.assign_plot(full_plot, px, py, xext=1)
            p.assign_plot(zoom_plot, px + 2, py, xext=1)

        # The description does not depend on the detector, so the panel is
        # registered once, after all plots have been assigned.
        addme = {
            "report_name": "hawki_sky_flat_multi",
            "report_description": "HAWK-I sky flat multi panel",
            "report_tags": [],
            "input_files": [master_im.filename()],
        }
        return {p: addme}

    def generate_panels(self, **kwargs):
        """Create the single-extension, multi-extension and stats panels.

        ``self.metadata`` is filled from the first file of the first input
        set before any panel is built, because the single- and
        multi-extension panels display it. ``kwargs`` is accepted but not
        used.

        Returns:
            A dict mapping every generated panel to its description dict.
        """
        self.metadata = HawkiSetupInfo.flat(list(self.hdus[0].values())[0])

        panels = {
            **self.generate_single_ext_panels(),
            **self.generate_multi_ext_panels(),
            **self.generate_stats_panel(),
        }
        return panels


# Module-level instance of the report, created at import time.
# NOTE(review): presumably the object the ADARI framework looks up by the
# name `rep` when it loads this module - confirm.
rep = HawkiSkyFlatReport()
