import numpy as np
import os
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot
from adari_core.plots.points import LinePlot, ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from .kmos_utils import KmosReportMixin


class KmosSkyFlatReport(KmosReportMixin, AdariReportBase):
    """ADARI report for KMOS sky flats ("kmos_sky_flat").

    Produces, for every input set, one panel of IFU illumination-correction
    images per rotator position and one panel of the median flux of each
    detector across the FLAT_SKY raw exposures.
    """

    # Not used in this class; NOTE(review): presumably kept for the mixin or
    # base class - verify before removing.
    rot_angles = {}
    # Template of required SOF categories. parse_sof works on a per-instance
    # copy, so this class-level dict is never mutated.
    files_needed = {"ILLUM_CORR": None}
    n_detectors = 3
    n_ifu = 24

    def __init__(self):
        super().__init__("kmos_sky_flat")

    def parse_sof(self):
        """Select the files needed by the report from the set of frames (SOF).

        For every category in ``files_needed`` the first matching file in
        ``self.inputs`` is selected. Every ``FLAT_SKY`` frame is also added,
        in SOF order, under the keys ``RAW1`` ... ``RAW<n>``. The number of
        such frames is stored in ``self.n_flat_sky``.

        Returns
        -------
        list of dict
            A one-element list with the mapping from key to file path.

        Raises
        ------
        IOError
            If a required category is missing, or if there is no FLAT_SKY
            frame.
        """
        # Fresh per-instance dict: mutating the class attribute would leak
        # RAW entries across instances and repeated calls.
        self.files_needed = dict.fromkeys(type(self).files_needed)
        files_path = [elem[0] for elem in self.inputs]
        files_category = [elem[1] for elem in self.inputs]
        for required_file in self.files_needed:
            if required_file not in files_category:
                raise IOError("[WARNING] {} file not found".format(required_file))
            self.files_needed[required_file] = files_path[
                files_category.index(required_file)
            ]
        # Include the sky flat raw files
        raw_files = [
            path
            for path, category in zip(files_path, files_category)
            if category == "FLAT_SKY"
        ]
        self.n_flat_sky = len(raw_files)
        if not raw_files:
            raise IOError("[WARNING] RAW files not present in SOF")
        for number, path in enumerate(raw_files, start=1):
            self.files_needed["RAW{}".format(number)] = path
        return [self.files_needed]

    def compute_median(self, data):
        """Return the median of all elements of ``data``, ignoring NaNs."""
        return np.nanmedian(data)

    def get_rotations(self):
        """Set the number of rotator positions to report on.

        Currently this is always 1; no angles are read from the data.
        """
        self.number_of_rotations = 1

    @staticmethod
    def _setup_text(primary_header, ifu_header):
        """Return the instrument-setup text lines shown at the top of a panel.

        Parameters
        ----------
        primary_header : FITS header
            Header of the ILLUM_CORR primary HDU (grating and DIT keywords).
        ifu_header : FITS header
            Header of the IFU extension carrying the rotator NAANGLE.

        Returns
        -------
        tuple of str
        """
        naangle = ifu_header.get("HIERARCH ESO PRO ROT NAANGLE")
        return (
            "INS.GRAT1.NAME: "
            + str(primary_header.get("HIERARCH ESO INS GRAT1 NAME")),
            "INS.GRAT1.WLEN: "
            + str(primary_header.get("HIERARCH ESO INS GRAT1 WLEN")),
            "DET.SEQ1.DIT: " + str(primary_header.get("HIERARCH ESO DET SEQ1 DIT")),
            # The float format would raise on a missing keyword; show "None"
            # like the other fields instead.
            "PRO.ROT.NAANGLE: "
            + ("None" if naangle is None else "{:03.1f}".format(naangle)),
        )

    def create_rotation_panel(self, hdul):
        """Build a panel with the image of every IFU for one rotator position.

        IFUs are laid out on a 6-column grid starting at row 1. Row 0 is left
        for the metadata text that the caller adds. An IFU whose data array
        is ``None`` gets an empty placeholder plot titled "(NO DATA)".

        Parameters
        ----------
        hdul : HDU list
            Must provide the extensions ``IFU.1.DATA`` ... ``IFU.<n_ifu>.DATA``.

        Returns
        -------
        Panel
        """
        n_cols = 6
        panel = Panel(n_cols, 5, height_ratios=[1, 4, 4, 4, 4])
        # IFU plots
        for ifu in range(1, self.n_ifu + 1):
            row, col = divmod(ifu - 1, n_cols)
            ifu_hdu = hdul["IFU.{}.DATA".format(ifu)]
            if ifu_hdu.data is not None:
                mplot = ImagePlot(
                    title="IFU.{}".format(ifu),
                    cbar_kwargs={"pad": 0.1},
                    interpolation="none",
                )
                mplot.add_data(ifu_hdu.data)
            else:
                mplot = LinePlot(
                    legend=False,
                    y_label="y",
                    x_label="x",
                    title="IFU.{} (NO DATA)".format(ifu),
                )
            panel.assign_plot(mplot, col, row + 1)
        return panel

    def create_detectors_panel(self, hdul):
        """Build the panel of per-detector median flux versus exposure number.

        Parameters
        ----------
        hdul : mapping
            Maps ``"ILLUM_CORR"`` and ``"RAW1"`` ... ``"RAW<n_flat_sky>"`` to
            opened HDU lists. NOTE(review): presumably built by the framework
            from the dict returned by parse_sof - verify.

        Returns
        -------
        (Panel, list of str)
            The panel and the file names of the HDU lists it was built from.
        """
        panel = Panel(6, 2, height_ratios=[1, 5])
        illum_corr = hdul["ILLUM_CORR"]
        primary_header = illum_corr["PRIMARY"].header

        # Metadata: instrument and first/last raw file used
        text = TextPlot(columns=1, v_space=0.3)
        text.add_data(
            (
                str(primary_header.get("INSTRUME")) + ": FLAT_SKY",
                "FIRST RAW FILE: "
                + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                "LAST RAW FILE: "
                + str(
                    primary_header.get(
                        "HIERARCH ESO PRO REC1 RAW{} NAME".format(self.n_flat_sky)
                    )
                ),
            )
        )
        panel.assign_plot(text, 0, 0, xext=1)

        # Instrument setup
        text = TextPlot(columns=1, v_space=0.3)
        text.add_data(
            self._setup_text(primary_header, illum_corr["IFU.1.DATA"].header)
        )
        panel.assign_plot(text, 2, 0, xext=1)

        input_files = [illum_corr.filename()]

        # One scatter plot per detector, one point per sky flat exposure
        detector_plots = [
            ScatterPlot(title="SKY FLAT MEDIAN (Det {})".format(det), legend=False)
            for det in range(1, self.n_detectors + 1)
        ]
        for exp_number in range(1, self.n_flat_sky + 1):
            raw = hdul["RAW{}".format(exp_number)]
            input_files.append(raw.filename())
            for detector, det_plot in enumerate(detector_plots):
                median_flux = self.compute_median(
                    raw["CHIP{}.INT1".format(detector + 1)].data
                )
                det_plot.add_data(
                    [[exp_number], [median_flux]],
                    label="{}-{}".format(exp_number, detector),
                    color="black",
                )
        for plot_pos, sc_plot in enumerate(detector_plots):
            sc_plot.x_label = "Exposure number"
            sc_plot.y_label = "Median flux / ADU"
            panel.assign_plot(sc_plot, 2 * plot_pos, 1, xext=2)
        return panel, input_files

    def generate_panels(self, **kwargs):
        """Create the report panels for every input set in ``self.hdus``.

        For each set, one IFU-image panel is built per rotator position,
        plus one detector median-flux panel.

        Returns
        -------
        dict
            Maps each Panel to its report metadata (name, description, tags
            and input files).
        """
        panels = {}
        self.get_rotations()
        # Each rotator position spans 2 * n_ifu extensions. The IFU DATA
        # extensions are every second one starting at index 1.
        # NOTE(review): presumably interleaved with a NOISE extension per IFU.
        ext_per_rotation = 2 * self.n_ifu
        for hdu_list in self.hdus:
            illum_corr = hdu_list["ILLUM_CORR"]
            primary_header = illum_corr["PRIMARY"].header
            # A missing grating keyword contributes "" instead of crashing.
            setup = "".join(
                str(primary_header.get("HIERARCH ESO INS GRAT{} NAME".format(n), ""))
                for n in (1, 2, 3)
            )

            for rot in range(self.number_of_rotations):
                # Start and stop both scale with ext_per_rotation, so every
                # rotation gets its own contiguous group of extensions.
                first_ext = rot * ext_per_rotation + 1
                p = self.create_rotation_panel(
                    illum_corr[first_ext : (rot + 1) * ext_per_rotation : 2]
                )
                # Metadata
                text = TextPlot(columns=1, v_space=0.3)
                text.add_data(
                    (
                        str(primary_header.get("INSTRUME")),
                        "PRO CATG: "
                        + str(primary_header.get("HIERARCH ESO PRO CATG")),
                        "FILE NAME: " + os.path.basename(illum_corr.filename()),
                        "RAW1 NAME: "
                        + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                    )
                )
                p.assign_plot(text, 0, 0, xext=3)
                # Instrument setup metadata
                text = TextPlot(columns=1, v_space=0.3)
                text.add_data(
                    self._setup_text(primary_header, illum_corr[first_ext].header)
                )
                p.assign_plot(text, 2, 0, xext=3)
                panels[p] = {
                    "report_name": "kmos_sky_flat_{}".format(setup.lower()),
                    "report_description": "KMOS Illumination correction panel",
                    "report_tags": [],
                    "input_files": [illum_corr.filename()],
                }

            # Create detectors panel
            p_det, input_files = self.create_detectors_panel(hdu_list)
            panels[p_det] = {
                "report_name": "kmos_sky_flat_detectors",
                "report_description": "KMOS sky flat detectors panel",
                "report_tags": [],
                "input_files": input_files,
            }

        return panels


# Module-level report instance.
# NOTE(review): presumably the object the ADARI framework looks up and runs;
# verify against the loader before renaming.
rep = KmosSkyFlatReport()
