from adari_core.plots.panel import Panel
from adari_core.plots.points import LinePlot
from adari_core.plots.text import TextPlot
from adari_core.plots.histogram import HistogramPlot
from .kmos_utils import KmosSetupInfo

from adari_core.data_libs.master_std_star_ifu import MasterSpecphotStdReport

import os
import numpy as np

from . import KmosReportMixin


class KmosSpecphotStdReport(KmosReportMixin, MasterSpecphotStdReport):
    """KMOS kmos_std_star recipe report class.

    Attributes
    ----------
    - files_needed: dictionary containing the categories needed and the file
      names. The class-level dictionary is only a template of the required
      categories and is never mutated; ``parse_sof`` stores the resolved file
      names in a fresh per-instance dictionary.
    - n_extensions: number of IFU extensions (``IFU.<i>.DATA``) inspected."""

    files_needed = {"STD_IMAGE": None, "STAR_SPEC": None}
    n_extensions = 24

    def __init__(self):
        super().__init__("KMOS_specphot_std")

    def parse_sof(self):
        """Returns a list of files selected from a set of frames (sof).

        If more than one file fulfills the criteria, the first file
        in the array will be selected. Every input of category ``STD`` is
        additionally recorded, in input order, under the keys ``STD_RAW0``,
        ``STD_RAW1``, ...

        The selection is built in a new dictionary on every call, so repeated
        calls (or several instances) neither share nor accumulate entries
        through the class-level ``files_needed`` template.

        Returns
        -------
        A list holding one dictionary that maps each key to a file path.

        Raises
        ------
        IOError: if a required category is missing from ``self.inputs``.
        """
        file_path = [elem[0] for elem in self.inputs]
        files_category = [elem[1] for elem in self.inputs]
        selected = {}
        # Iterate the class template (required categories only); it is never
        # modified, so leftover STD_RAW keys from earlier calls cannot appear.
        for required_file in type(self).files_needed:
            # Check that category matches the requirement
            if required_file not in files_category:
                raise IOError("ERROR: {} file not found".format(required_file))
            # In case of multiple files sharing the same category only select
            # the first (list.index returns the first match)
            selected[required_file] = file_path[files_category.index(required_file)]
        # Include raw files
        raw_files = [fn for fn, ctg in zip(file_path, files_category) if ctg == "STD"]
        for raw_counter, fn in enumerate(raw_files):
            selected["STD_RAW{}".format(raw_counter)] = fn
        self.files_needed = selected
        return [self.files_needed]

    def get_wavelength(self, files):
        """Get the wavelength array from KMOS data.

        Delegates to the parent ``get_wavelength`` with the first
        ``IFU.<i>.DATA`` extension of STAR_SPEC whose data is not None. If no
        such extension exists, the parent method is not called.
        """
        for i in range(1, self.n_extensions + 1):
            ifu_hdu = files["STAR_SPEC"]["IFU.{}.DATA".format(i)]
            if ifu_hdu.data is not None:
                super().get_wavelength(ifu_hdu)
                break

    def get_std_raw_files(self, std_image_file):
        """Return the location of the RAW files for each detector.

        Description
        -----------
        The standard stars provided in STD_IMAGE result from a combination
        of raw files (STD) from which it is necessary to determine which
        channel each raw file is contributing to. The n-th raw frame listed
        in the header that is also one of the selected inputs is assigned to
        chip n (this assumes the header lists frames in chip order).

        Parameters
        ----------
        std_image_file: Reference STD_IMAGE HDU list. Files will be searched
        on the Primary extension variable HIERARCH ESO PRO REC1 RAW*** NAME.

        Returns
        -------
        raw_chip: dictionary mapping ``STD_CHIP<n>`` to the ``files_needed``
        key (e.g. ``STD_RAW0``) of the matching input. Empty when the header
        lacks ``HIERARCH ESO PRO DATANCOM`` or no listed frame matches.
        """
        hdr = std_image_file["PRIMARY"].header
        # Number of combined frames (treated as 0 when the keyword is absent)
        n_frames = hdr.get("HIERARCH ESO PRO DATANCOM") or 0
        # Input basename -> files_needed key. setdefault keeps the first key
        # when several inputs share a basename; unresolved (None) paths are
        # skipped instead of crashing os.path.basename.
        key_by_name = {}
        for key, path in self.files_needed.items():
            if path is not None:
                key_by_name.setdefault(os.path.basename(path), key)
        raw_chip = {}
        chip = 1
        for frame in range(1, n_frames + 1):
            fname = hdr.get("HIERARCH ESO PRO REC1 RAW{} NAME".format(frame))
            if fname in key_by_name:
                raw_chip["STD_CHIP{}".format(chip)] = key_by_name[fname]
                chip += 1
        return raw_chip

    @staticmethod
    def _select_raw(ifu, use_raw, raw_per_chip):
        """Return ``(raw, raw_extension, raw_file)`` for IFU number ``ifu``.

        IFU ``ifu`` is mapped to chip ceil(ifu / 8) (IFUs 1-8 -> CHIP1,
        9-16 -> CHIP2, 17-24 -> CHIP3). When raw plots are disabled or no raw
        file was matched to that chip, ``(False, "", "")`` is returned so the
        caller draws the "N/A" placeholder instead of failing.
        """
        if use_raw:
            raw_channel = (ifu - 1) // 8 + 1
            raw_file = raw_per_chip.get("STD_CHIP{}".format(raw_channel))
            if raw_file is not None:
                return True, "CHIP{}.INT1".format(raw_channel), raw_file
        return False, "", ""

    @staticmethod
    def _file_info_lines(files, ifu, raw, raw_file, tpl_id):
        """Build the text column listing the extension, input files and TPL ID.

        Of the ``STD_RAW*`` inputs, only the one feeding this IFU's chip is
        listed (labelled ``STD``), and only when raw plots are enabled.
        """
        lines = ["EXTENSION: IFU.{}.DATA".format(ifu)]
        for key, hdul in files.items():
            if "STD_RAW" in key:
                if raw and key == raw_file:
                    lines.append("STD: {}".format(os.path.basename(hdul.filename())))
            else:
                lines.append("{}: {}".format(key, os.path.basename(hdul.filename())))
        lines.append("HIERARCH ESO TPL ID = {}".format(tpl_id))
        return tuple(lines)

    def _raw_plots(self, files, raw, raw_file, raw_extension):
        """Return ``(raw_map, raw_hist)`` for the raw-frame row of a panel.

        Without raw data, the histogram slot shows a black cross placeholder
        titled "RAW Histogram (N/A)", and the map is drawn from ``None`` data.
        """
        if raw:
            data = files[raw_file][raw_extension].data
            raw_hist = HistogramPlot(
                raw_data=data.flatten(), bins=50, v_min=-5000, v_max=70000
            )
            raw_hist.title = "RAW Histogram"
            raw_hist.legend = False
        else:
            data = None
            raw_hist = LinePlot()
            # Two diagonals forming a cross over the unit square
            raw_hist.add_data(
                [np.array([0, 1]), np.array([0, 1])], label="1", color="k"
            )
            raw_hist.add_data(
                [np.array([0, 1]), np.array([1, 0])], label="2", color="k"
            )
            raw_hist.set_xlim(0, 1)
            raw_hist.set_ylim(0, 1)
            raw_hist.legend = False
            raw_hist.title = "RAW Histogram (N/A)"
        raw_map = self.plot_star_image(
            data=data, title="RAW: {}".format(raw_extension)
        )
        return raw_map, raw_hist

    def generate_panels(self, **kwargs):
        """Build one 3x3 report panel per IFU flagged as a standard star.

        An IFU ``i`` gets a panel only if the STD_IMAGE primary header has a
        ``HIERARCH ESO PRO STDSTAR<i>`` keyword. Each panel shows file info,
        setup info, the STD_IMAGE map, the STAR_SPEC spectrum and, for the
        ``KMOS_spec_cal_stdstar`` template, the matching raw frame.

        Returns
        -------
        dict mapping each Panel to its report metadata.
        """
        panels = {}
        # Get files
        files = self.hdus[0]
        # Get the wavelength array
        self.get_wavelength(files)
        raw_per_chip = self.get_std_raw_files(files["STD_IMAGE"])
        # Loop-invariant header lookups, hoisted out of the per-IFU loop
        std_header = files["STD_IMAGE"]["PRIMARY"].header
        tpl_id = files["STAR_SPEC"]["PRIMARY"].header.get("HIERARCH ESO TPL ID", "N/A")
        # Raw-frame plots are only produced for this template
        use_raw = tpl_id == "KMOS_spec_cal_stdstar"
        # Generate a panel for each IFU (up to n_extensions)
        for i in range(1, self.n_extensions + 1):
            if std_header.get("HIERARCH ESO PRO STDSTAR{}".format(i)) is None:
                continue
            ifu_data_extension = "IFU.{}.DATA".format(i)
            raw, raw_extension, raw_file = self._select_raw(i, use_raw, raw_per_chip)

            # Create report panel
            p = Panel(3, 3, width_ratios=[1, 1, 1], height_ratios=[1, 4, 4])

            t1 = TextPlot(columns=1, v_space=0.3)
            t1.add_data(self._file_info_lines(files, i, raw, raw_file, tpl_id))
            p.assign_plot(t1, 0, 0, xext=2)

            col2 = KmosSetupInfo.standard(files["STD_IMAGE"])
            col2.append(
                "OCS.ROT.NAANGLE: "
                + str(std_header.get("HIERARCH ESO OCS ROT NAANGLE", "N/A"))
            )
            t2 = TextPlot(columns=1, v_space=0.3)
            t2.add_data(col2)
            p.assign_plot(t2, 2, 0, xext=2)

            # Star map
            star_map = self.plot_star_image(
                data=files["STD_IMAGE"][ifu_data_extension].data
            )
            star_map.title = "STD IMAGE"
            p.assign_plot(star_map, 0, 1)

            # Star spectra
            lineplot = self.plot_vs_wavelength(
                data=files["STAR_SPEC"][ifu_data_extension].data,
                label="Counts (ADUs)",
                title="STAR SPEC",
            )
            p.assign_plot(lineplot, 0, 2, xext=3)

            # Raw
            raw_map, raw_hist = self._raw_plots(files, raw, raw_file, raw_extension)
            p.assign_plot(raw_map, 1, 1)
            p.assign_plot(raw_hist, 2, 1)

            # Metadata
            panels[p] = {
                "report_name": "KMOS_IFU_ext_{}".format(i),
                "report_description": "Spectrophotometric Standard Stars_{}".format(i),
                "report_tags": [],
            }
        return panels


# Module-level report instance created at import time; presumably this is the
# object the ADARI report framework looks up when loading this module
# (NOTE(review): confirm against the report runner).
rep = KmosSpecphotStdReport()
