import numpy as np
import os
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.data_libs.master_wave_cal import MasterWaveCalReport
from . import KmosReportMixin


class KmosWaveCalReport(KmosReportMixin, MasterWaveCalReport):
    """Single- and multi-panel reports for KMOS DET_IMG_WAVE products."""

    # Class-level templates: only their keys are relied upon. parse_sof and
    # get_rotations build fresh per-instance dicts from them, so these class
    # attributes are never mutated (mutating them would leak state between
    # instances and duplicate entries on repeated calls).
    detectors = {"DET.1.DATA": None, "DET.2.DATA": None, "DET.3.DATA": None}
    files_needed = {"DET_IMG_WAVE": None}
    # Index into each detector's angle/entry lists used for single panels.
    default_angle_entry = 0
    # Not referenced in this file; presumably consumed by the base report.
    center_size = 500

    def __init__(self, name="kmos_lamp_flat_multi"):
        # NOTE(review): the default name looks copied from the lamp-flat
        # report; kept unchanged in case callers depend on it.
        super().__init__(name)

    def parse_sof(self):
        """Select the input files required by this report from the sof.

        ``self.inputs`` is read as a sequence whose items hold the file path
        at index 0 and the category at index 1. If more than one file
        fulfils a category, the first one is selected, so by construction
        the result describes a single HDUL set.

        Returns
        -------
        list of dict
            One-element list with the mapping ``{category: path}``.

        Raises
        ------
        IOError
            If a category listed in ``files_needed`` is not among the inputs.
        """
        file_paths = [elem[0] for elem in self.inputs]
        categories = [elem[1] for elem in self.inputs]
        # Fresh dict: do not write into the class-level template.
        files_needed = dict.fromkeys(self.files_needed)
        for required_file in files_needed:
            if required_file not in categories:
                raise IOError("[WARNING] {} file not found".format(required_file))
            files_needed[required_file] = file_paths[categories.index(required_file)]
        self.files_needed = files_needed
        return [self.files_needed]

    def get_rotations(self):
        """Collect the rotator angles and HDU indices of each KMOS detector.

        There are three detectors (DET.1, DET.2, DET.3). Every extension
        named ``DET.<n>.DATA`` in the DET_IMG_WAVE HDU list of the first
        input holds data for one detector at one rotator angle.

        ``self.detectors`` is rebuilt on every call (repeated calls do not
        duplicate entries). It maps each extension name to ``None`` when no
        such extension was found, or to a dict with keys ``angle`` (list of
        ``ESO PRO ROT NAANGLE`` values), ``entry`` (matching indices in the
        HDU list) and ``nangle`` (number of angles found).
        """
        detectors = dict.fromkeys(self.detectors)
        for i, hdu in enumerate(self.hdus[0]["DET_IMG_WAVE"]):
            if hdu.name not in detectors:
                continue
            angle = hdu.header.get("HIERARCH ESO PRO ROT NAANGLE")
            info = detectors[hdu.name]
            if info is None:
                detectors[hdu.name] = dict(angle=[angle], entry=[i], nangle=1)
            else:
                info["angle"].append(angle)
                info["entry"].append(i)
                info["nangle"] += 1
        self.detectors = detectors

    @staticmethod
    def _kmos_grating_setup(primary_header):
        """Return the concatenated GRAT1/GRAT2/GRAT3 names (``str`` of each)."""
        return "".join(
            str(primary_header.get(f"HIERARCH ESO INS GRAT{n} NAME"))
            for n in (1, 2, 3)
        )

    @staticmethod
    def _kmos_setup_column(primary_header):
        """Return the grating/DIT lines shown in the second text column."""
        return (
            "INS.GRAT1.NAME: "
            + str(primary_header.get("HIERARCH ESO INS GRAT1 NAME")),
            "INS.GRAT1.WLEN: "
            + str(primary_header.get("HIERARCH ESO INS GRAT1 WLEN")),
            "DET.SEQ1.DIT: " + str(primary_header.get("HIERARCH ESO DET SEQ1 DIT")),
        )

    def add_text_plot(self, panel, master_img, extension_idx):
        """Add to the report panel the text plots required for one extension.

        Parameters
        ----------
        panel : Panel
            Panel whose row 0 receives the text (columns 0-1, then column 2).
        master_img : HDUList
            The DET_IMG_WAVE product.
        extension_idx : int
            Index of the ``DET.<n>.DATA`` extension being reported.
        """
        vspace = 0.3
        primary_hdr = master_img["PRIMARY"].header
        ext_hdr = master_img[extension_idx].header
        fname = os.path.basename(str(master_img.filename()))

        extname = ext_hdr.get("EXTNAME", "N/A")
        n_combined_images = int(primary_hdr.get("HIERARCH ESO PRO DATANCOM", "0"))
        rotangle = ext_hdr.get("HIERARCH ESO PRO ROT NAANGLE")
        channel_number = int(extname.replace(".DATA", "").replace("DET.", ""))

        # If there are fewer RAW files than channels, select the last raw
        # file. The lower bound avoids asking for a non-existent RAW0 when
        # DATANCOM is missing from the header.
        if 0 < n_combined_images < channel_number:
            channel_number = n_combined_images
        raw_name = primary_hdr.get(f"HIERARCH ESO PRO REC1 RAW{channel_number} NAME")

        if rotangle is None:
            angle_text = "ROT. ANGLE: N/A"
        else:
            angle_text = "ROT. ANGLE: {:.1f}".format(rotangle)

        t1 = TextPlot(columns=1, v_space=vspace)
        t1.add_data(
            (
                str(primary_hdr.get("INSTRUME")),
                "EXTNAME: " + extname,
                angle_text,
                "PRO CATG: " + str(primary_hdr.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + fname,
                "RAW NAME: " + str(raw_name),
            )
        )
        panel.assign_plot(t1, 0, 0, xext=2)

        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(self._kmos_setup_column(primary_hdr))
        panel.assign_plot(t2, 2, 0, xext=1)

    def single_report_panels(self):
        """Generate the single panel report panels.

        For each detector only the rotation at ``default_angle_entry`` is
        used. If those detectors disagree on the rotator angle, only the
        detectors at the smallest angle are kept. Extensions without data,
        or without any finite pixel, are skipped.

        Returns
        -------
        dict
            Maps each generated panel to its report metadata.
        """
        panels = {}
        for hdul_dict in self.hdus:
            master_img_hdul = hdul_dict["DET_IMG_WAVE"]
            setup = self._kmos_grating_setup(master_img_hdul["PRIMARY"].header)

            entries = []
            angles = []
            for n in range(1, 4):
                info = self.detectors[f"DET.{n}.DATA"]
                if info is not None:
                    entries.append(info["entry"][self.default_angle_entry])
                    angles.append(info["angle"][self.default_angle_entry])

            # Keep only the detectors sharing one (the smallest) rotator
            # angle. Removing the n_unique-1 largest positions is not enough:
            # angles [0, 60, 60] would still leave [0, 60].
            if np.unique(angles).size > 1:
                reference = min(angles)
                entries = [e for e, a in zip(entries, angles) if a == reference]

            # TODO: In the interactive case, the user should specify a given
            # rotation angle or instead use all the entries available.
            for i in entries:
                hdu = master_img_hdul[i]
                if "EXTNAME" not in hdu.header:
                    continue
                extname = hdu.header.get("EXTNAME")
                if "DATA" not in extname:
                    continue
                if hdu.data is None or not np.isfinite(hdu.data).any():
                    continue

                panel = self.generate_single_panel(hdu)
                self.add_text_plot(
                    panel=panel, master_img=master_img_hdul, extension_idx=i
                )
                panels[panel] = {
                    "report_name": "kmos_master_wave_cal_single_{}_{:03.0f}_{}".format(
                        extname.replace(".DATA", ""),
                        hdu.header.get("HIERARCH ESO PRO ROT NAANGLE"),
                        setup.lower(),
                    ),
                    "report_description": "KMOS wavelength calibration single panel",
                    "report_tags": [],
                }
        return panels

    def multipanel_plot(self, data, title=""):
        """Return an image plot of ``data`` (NaNs replaced by
        ``np.nan_to_num``), or a titled placeholder line plot if ``data`` is
        None."""
        if data is None:
            # If no image data, set up a placeholder plot.
            # TODO: implement empty plot class
            return LinePlot(
                legend=False, y_label="y", x_label="x", title=title + " (NO DATA)"
            )
        mplot = ImagePlot(title=title, cbar_kwargs={"pad": 0.1}, interpolation="none")
        mplot.add_data(np.nan_to_num(data))
        return mplot

    def multi_report_panels(self):
        """Generate one multi panel per input, one row per detector.

        Row 0 holds metadata text. Rows 1-3 correspond to the detectors in
        ``self.detectors`` order, with one column per rotator angle. The
        panel has at least three columns (the width of the text row).

        Returns
        -------
        dict
            Maps each generated panel to its report metadata.
        """
        panels = {}
        for hdu_list in self.hdus:
            master_img = hdu_list["DET_IMG_WAVE"]
            primary_hdr = master_img["PRIMARY"].header

            # default=0 avoids a ValueError when no detector extension exists.
            number_of_rotations = max(
                (d["nangle"] for d in self.detectors.values() if d is not None),
                default=0,
            )
            p = Panel(max(number_of_rotations, 3), 4, height_ratios=[1, 4, 4, 4])

            vspace = 0.3
            fname = os.path.basename(str(master_img.filename()))
            t1 = TextPlot(columns=1, v_space=vspace)
            t1.add_data(
                (
                    str(primary_hdr.get("INSTRUME")),
                    "EXTNAME: " + str(primary_hdr.get("EXTNAME", "N/A")),
                    "PRO CATG: " + str(primary_hdr.get("HIERARCH ESO PRO CATG")),
                    "FILE NAME: " + fname,
                    "RAW1 NAME: "
                    + str(primary_hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
                )
            )
            p.assign_plot(t1, 0, 0, xext=2)

            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            t2.add_data(self._kmos_setup_column(primary_hdr))
            p.assign_plot(t2, 2, 0, xext=1)

            setup = self._kmos_grating_setup(primary_hdr)

            # Row 0 is for metadata; a missing detector leaves its row empty.
            hdu_info = []
            hdu_plot_args = []
            for py, (extname, detector_info) in enumerate(
                self.detectors.items(), start=1
            ):
                if detector_info is None:
                    continue
                for px, (rotation_angle, entry) in enumerate(
                    zip(detector_info["angle"], detector_info["entry"])
                ):
                    hdu_info.append((0, "DET_IMG_WAVE", entry, px, py))
                    hdu_plot_args.append(dict(title=f"{extname} {rotation_angle}"))

            p = self.generate_multipanel(
                panel=p, hdu_info=hdu_info, plot_kwargs=hdu_plot_args
            )
            panels[p] = {
                "report_name": f"kmos_wave_cal_multi_{setup.lower()}",
                "report_description": "KMOS wavelength calibration multi panel",
                "report_tags": [],
            }
        return panels

    def generate_panels(self, **kwargs):
        """Index the detector rotations, then build single and multi panels.

        ``kwargs`` are accepted for interface compatibility and ignored.
        """
        self.get_rotations()
        panels = {**self.single_report_panels(), **self.multi_report_panels()}
        return panels


# Module-level report instance; presumably the object the adari_core runner
# looks up when executing this module — TODO confirm.
rep = KmosWaveCalReport()
