from .fors_utils import ForsSetupInfo
from adari_core.data_libs.master_std_star_img import MasterStdStarImg
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default

from . import ForsReportMixin


class ForsImgStdReport(ForsReportMixin, MasterStdStarImg):
    """Fors fors_std_star recipe report class.

    Attributes
    ----------
    - raw_extensions: candidate HDU names that may hold the raw chip data.
    - extensions: raw-image HDU names actually present, filled by
      ``get_extensions``.
    - files_needed: dictionary containing the categories needed and the
      file names.
    - min_std_weight: rows whose absolute weight is not above this value are
      excluded from the standard-star selections.
    - phot_weight_col / phot_zeropoint_col / phot_color_col: column positions
      read from the ALIGNED_PHOT table (extension 1).
    """

    raw_extensions = {
        "CHIP1": "CHIP1",
        "CHIP2": "CHIP2",
    }
    # Class-level default only. get_extensions() always binds a fresh list on
    # the instance, so this shared list is never mutated (appending to it would
    # leak entries across instances and duplicate them on repeated calls).
    extensions = []

    files_needed = {
        "raw_im": "STANDARD_IMG",
        "calib_std": "STANDARD_REDUCED_IMG",
        "matched_phot": "ALIGNED_PHOT",
    }
    center_size = 400

    # Weight cut and column positions used on the ALIGNED_PHOT table.
    # NOTE(review): the indices reproduce the previously hard-coded values
    # (weight = last column, zeropoint = third from last, colour = column 19);
    # confirm against the table schema if it changes.
    min_std_weight = 0.05
    phot_weight_col = -1
    phot_zeropoint_col = -3
    phot_color_col = 19

    def __init__(self):
        super().__init__("fors_img_std")

    def parse_sof(self):
        """
        Returns a list of files selected from a set of frames (sof).

        If more than one file fulfills the criteria, the first file
        in the array will be selected.

        Returns
        -------
        list
            A single-element list holding a dict that maps each key of
            ``files_needed`` to the selected file path.

        Raises
        ------
        IOError
            If a required category is absent from ``self.inputs``.
        """
        # Path of the FIRST input seen for each category (first match wins).
        first_path_by_category = {}
        for elem in self.inputs:
            first_path_by_category.setdefault(elem[1], elem[0])

        file_lists = {}
        for required_file, category in self.files_needed.items():
            if category not in first_path_by_category:
                raise IOError("[WARNING] {} file not found".format(category))
            file_lists[required_file] = first_path_by_category[category]
        return [file_lists]

    def get_extensions(self):
        """Find the data extensions required for each FORS files.

        Description
        -----------
        After the SOF has been parsed, this method looks through the HDUs of
        the raw image to find which of the ``raw_extensions`` names are
        present. It stores them, in ``raw_extensions`` order, as a new
        per-instance ``self.extensions`` list. Calling it again replaces the
        previous result instead of appending to it.
        """
        all_hdu_names = {hdu.name for hdu in self.hdus[0]["raw_im"]}
        self.extensions = [
            chip_name
            for chip_name in self.raw_extensions
            if chip_name in all_hdu_names
        ]

    @staticmethod
    def _select_weighted_columns(data, columns, weight_col=-1, min_weight=0.05):
        """Return one list per entry of ``columns`` for rows passing the weight cut.

        A row is kept when ``abs(row[weight_col]) > min_weight``. Each row is
        tested on its own weight; looking rows up by weight value would alias
        rows that share the same weight.
        """
        selected = [[] for _ in columns]
        for row in data:
            if abs(row[weight_col]) > min_weight:
                for out, col in zip(selected, columns):
                    out.append(row[col])
        return selected

    def get_standards(self, data, min_weight=None):
        """Return ``[x_values, y_values]`` for rows passing the weight cut.

        Parameters
        ----------
        data : sequence of rows
            Column 0 is taken as x, column 1 as y and the last column as the
            weight.
        min_weight : float, optional
            Threshold on the absolute weight; defaults to ``min_std_weight``.
        """
        if min_weight is None:
            min_weight = self.min_std_weight
        return self._select_weighted_columns(
            data, (0, 1), weight_col=-1, min_weight=min_weight
        )

    def generate_panels(self, **kwargs):
        """Build the standard-star report panels.

        This builds the base-class panels for the first raw chip extension
        found. On each panel it then adds a setup-info text box at (2, 0) and
        a colour-vs-zeropoint scatter plot at (3, 2). The scatter uses the
        ALIGNED_PHOT rows that pass the weight cut.

        Raises
        ------
        IOError
            If none of ``raw_extensions`` is present in the raw image.
        """
        vspace = 0.3
        self.get_extensions()
        if not self.extensions:
            raise IOError(
                "[WARNING] none of {} found in {} file".format(
                    list(self.raw_extensions), self.files_needed["raw_im"]
                )
            )

        new_panels = super().generate_panels(
            im_ext=1,
            raw_ext=self.extensions[0],
            marking=True,
            marking_centre=True,
            im_clipping="mad",
            im_n_clipping={"nmad": 3},
            hist_clipping="val",
            hist_n_clipping={"low": -5000.0, "high": 70000.0},
        )

        # Loop-invariant inputs: read the files and apply the weight cut once.
        calib_std = self.hdus[0]["calib_std"]
        phot_table = self.hdus[0]["matched_phot"][1].data
        color_s, zp_s = self._select_weighted_columns(
            phot_table,
            (self.phot_color_col, self.phot_zeropoint_col),
            weight_col=self.phot_weight_col,
            min_weight=self.min_std_weight,
        )

        for panel in new_panels:
            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            t2.add_data(ForsSetupInfo.std_img(calib_std))
            panel.assign_plot(t2, 2, 0, xext=1)

            zp_c = ScatterPlot(title="Color vs. Zeropoint")
            # Give each plot its own lists so plots never share mutable data.
            zp_c.add_data(
                [list(color_s), list(zp_s)],
                label="zpcolor",
            )
            zp_c.x_label = "Color"
            zp_c.y_label = "Zeropoint"
            zp_c.legend = False
            panel.assign_plot(zp_c, 3, 2, xext=1, yext=1)

        return new_panels


# Module-level report instance, created at import time.
# NOTE(review): presumably the adari_core report runner looks up this `rep`
# name when loading the module — confirm before renaming or removing it.
rep = ForsImgStdReport()
