from adari_core.data_libs.master_rawdisp import MasterRawdispReport
from .eris_utils import ErisSetupInfo, data_reader
import astropy.io.fits as fits
from adari_core.plots.text import TextPlot


import os
import numpy as np

from . import ErisReportMixin


class ErisRawdispReport(ErisReportMixin, MasterRawdispReport):
    """Raw-frame display ("rawdisp") report for ERIS.

    Raw input categories (tags) listed in ``self.tasks`` are mapped to a task
    name; ``parse_sof`` selects which matching input files get a panel, and
    ``generate_panels`` decorates each panel produced by
    ``MasterRawdispReport.generate_panels`` with a header text block, a
    95th-percentile display clip and an ERIS-specific report name/description.
    """

    # If any of these tags is present, only the first and last file are shown.
    _FIRST_LAST_TAGS = ("LINEARITY_LAMP", "FLAT_LAMP")
    # If none of these tags is present, only the first file is shown.
    _MULTI_FILE_TAGS = (
        "PERSISTENCE_CUBE",
        "WAVE_LAMP",
        "DARK_NS",
        "STD_FLUX",
        "OFF_RAW",
        "PSF_CALIBRATOR",
        "STD",
        "PERSISTENCE_IMA",
        "FLAT_LAMP_OFF",
        "STD_JITTER",
    )
    # If any of these tags is present, the first file of each distinct tag is shown.
    _PER_TAG_TAGS = (
        "DARK_NS",
        "STD_FLUX",
        "OFF_RAW",
        "PSF_CALIBRATOR",
        "STD",
        "FLAT_LAMP_OFF",
    )

    def __init__(self):
        super().__init__("eris_rawdisp")
        # One extension name per selected input file; filled by parse_sof.
        self.extensions = []
        # Raw input category (tag) -> task name used in the report name.
        self.tasks = {
            "PERSISTENCE_IMA": "persistence",
            "PUPIL_LAMP": "pupil",
            "PUPIL_SKY": "pupil_nix",
            "OFF_RAW": "linearity",
            "ON_RAW": "linearity",
            "FLAT_LAMP_OFF": "nix_flat_lamp",
            "FLAT_LAMP_ON": "nix_flat_lamp",
            "FLAT_TWILIGHT": "nix_flat_sky_twi",
            "FLAT_SKY": "nix_flat_sky_twi",
            "DARK": "dark",
            "PERSISTENCE_CUBE": "persistence",
            "LINEARITY_LAMP": "linearity",
            "DARK_NS": "distortion",
            "FIBRE_NS": "distortion",
            "WAVE_NS": "distortion",
            "FLAT_NS": "distortion",
            "FLAT_LAMP": "ifu_flat",
            "WAVE_LAMP": "wavelength",
            "STD_FLUX": "ifu_stdstar",
            "SKY_STD_FLUX": "ifu_stdstar",
            "SKY_PSF_CALIBRATOR": "ifu_stdstar",
            "PSF_CALIBRATOR": "ifu_stdstar",
            "STD": "ifu_stdstar",
            "SKY_STD": "ifu_stdstar",
            "STD_JITTER": "nix_on_sky",
        }
        self.setup_info = ErisSetupInfo
        # Register the ERIS-specific reader (3-D HDUs reduced to plane 0).
        self.data_readers["filename"] = ErisRawdispReport.data_reader

    def parse_sof(self):
        """Select the input files to display and record their tags/extensions.

        Every input whose category is a key of ``self.tasks`` and whose
        filename is not ``None`` is a candidate: its category is appended to
        ``self.sof_tag`` and ``"PRIMARY"`` to ``self.extensions``. The
        candidates are then reduced according to the tags present:

        * a ``_FIRST_LAST_TAGS`` tag: keep the first and the last entry
          (a single entry is kept only once);
        * no ``_MULTI_FILE_TAGS`` tag: keep only the first entry;
        * a ``_PER_TAG_TAGS`` tag: keep the first entry of each distinct tag;
        * otherwise: keep every candidate.

        Returns
        -------
        list of dict
            One ``{"filename": ...}`` dict per selected file, aligned with
            ``self.sof_tag`` and ``self.extensions``. Empty when no input
            matched.
        """
        ext = "PRIMARY"
        file_lists = []
        for filename, catg in self.inputs:
            if catg in self.tasks and filename is not None:
                file_lists.append({"filename": filename})
                self.sof_tag.append(catg)
                self.extensions.append(ext)

        if not file_lists:
            # Nothing matched: the selections below would index empty lists.
            return file_lists

        present = set(self.sof_tag)
        if not present.isdisjoint(self._FIRST_LAST_TAGS):
            # First and last; avoid showing the same single file twice.
            keep = [0, -1] if len(file_lists) > 1 else [0]
        elif present.isdisjoint(self._MULTI_FILE_TAGS):
            keep = [0]
        elif not present.isdisjoint(self._PER_TAG_TAGS):
            # Index of the first occurrence of each tag, in SOF order
            # (dicts preserve insertion order).
            first_index = {}
            for i, tag in enumerate(self.sof_tag):
                first_index.setdefault(tag, i)
            keep = list(first_index.values())
        else:
            return file_lists

        self.extensions = [self.extensions[i] for i in keep]
        self.sof_tag = [self.sof_tag[i] for i in keep]
        return [file_lists[i] for i in keep]

    @staticmethod
    def data_reader(filename):
        """Open a FITS file, reducing 3-D HDUs to their first plane.

        If the file contains an extension named ``DATA`` it is returned
        unmodified. Otherwise every HDU whose summarised dimensions have
        length 3 gets its data replaced by ``data[0]``.

        Parameters
        ----------
        filename : str
            Path of the FITS file to open.

        Returns
        -------
        astropy.io.fits.HDUList
            The opened file; it is left open for the caller.
        """
        hdul = fits.open(filename, mode="readonly")
        if "DATA" in [h.name for h in hdul]:
            return hdul
        for summary in hdul.info(output=False):
            # summary = (index, name, ver, type, ncards, dims, ...)
            index, dims = summary[0], summary[5]
            if len(dims) == 3:
                # Address the HDU by position: EXTNAMEs need not be unique and
                # a name lookup would always hit (and re-slice) the first match.
                hdul[index].data = hdul[index].data[0]
        return hdul

    def generate_panels(self, **kwargs):
        """Build the base rawdisp panels and add the ERIS-specific decorations.

        For each panel: a text block with INSTRUME/ORIGFILE/ARCFILE and the
        displayed extension is placed at (0, 0), the image at (0, 1) gets a
        95th-percentile clip, and ``report_name`` / ``report_description``
        are set in the panel description.

        Parameters
        ----------
        **kwargs
            Forwarded to ``MasterRawdispReport.generate_panels``.

        Returns
        -------
        dict
            Mapping of panel -> panel description.

        Raises
        ------
        RuntimeError
            If a panel description has no ``task_name``.
        """
        # Show the DATA extension when PUPIL_LAMP_OPEN frames are present.
        ext = "DATA" if "PUPIL_LAMP_OPEN" in self.sof_tag else "PRIMARY"
        vspace = 0.3

        panels = super().generate_panels(ext=ext, **kwargs)

        for i, (panel, panel_descr) in enumerate(panels.items()):
            # NOTE(review): assumes panels come back in the order of self.hdus.
            hdul = self.hdus[i]["filename"]
            # Raises KeyError if the file lacks the extension being displayed.
            ext_index = hdul.index_of(ext)
            try:
                task_name = panel_descr["task_name"]
            except KeyError as err:
                raise RuntimeError(
                    "A report has been created by "
                    "MasterRawdispReport that did "
                    "not come back with a task name "
                    "attached!"
                ) from err

            header = hdul["PRIMARY"].header
            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                str(header.get("INSTRUME")),
                "ORIGFILE: " + str(header.get("ORIGFILE")),
                "ARCFILE: " + str(header.get("ARCFILE")),
                "EXTENSION: " + str(ext_index),
                "EXTNAME: " + str(ext),
            )
            t1.add_data(col1)
            panel.assign_plot(t1, 0, 0, xext=2)

            raw_im = panel.retrieve(0, 1)
            raw_kwargs = {"percentile": 95.0}
            raw_im.set_v_clip_method("percentile", **raw_kwargs)

            basename = os.path.basename(panel_descr["filename"])
            panel_descr["report_name"] = "eris_rawdisp_{}_{}_{}_{}".format(
                task_name,
                ext,
                self.sof_tag[i].lower(),
                basename,
            )
            panel_descr["report_description"] = (
                f"ERIS rawdisp panel - "
                f"{task_name}, "
                f"{panel_descr['tag']}, "
                f"{basename}, "
                f"{panel_descr['ext']}"
            )

        return panels


# Module-level report instance; presumably looked up by name ("rep") by the
# adari_core report runner — TODO confirm.
rep = ErisRawdispReport()
