from adari_core.data_libs.echelle_flatfield import MasterEchelleFlatfieldReport
from adari_core.plots.text import TextPlot
from adari_core.plots.points import LinePlot
from .xshooter_utils import XshooterSetupInfo
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.panel import Panel
from adari_core.utils.utils import fetch_kw_or_default, fetch_kw_or_error
import numpy as np
import matplotlib.pyplot as plt

import os

from . import XshooterReportMixin

# Default for ``self.center_size`` (assigned in
# ``XshooterEchelleFlatfieldReport.__init__``). Presumably the size in pixels
# of the central region used by the parent report -- TODO confirm.
center_size = 200


class XshooterEchelleFlatfieldReport(XshooterReportMixin, MasterEchelleFlatfieldReport):
    """ADARI report for X-shooter echelle master flat fields.

    For each arm (UVB, VIS, NIR) with both a master flat and a raw flat in
    the SOF, the panels built by ``MasterEchelleFlatfieldReport`` are
    extended with:
    - an X-shooter setup text box;
    - a rescaled cell (2, 2);
    - a "slit function" plot: the central master row overlaid on the
      per-column median of the 21 central rows.
    """

    raw_extension_default = None
    flat_type = "FLAT"

    # (arm, default raw extension) in the order panels are produced. The
    # extension numbers are the ones this report has always used.
    _ARM_RAW_EXT = (("UVB", 0), ("VIS", 2), ("NIR", 1))

    # Rows on each side of the central row that enter the median stripe:
    # 2 * 10 + 1 = 21 rows in total.
    _slit_half_width = 10

    # Unbinned x-range of the slit-function plot for UVB/VIS. These limits
    # are divided by the X binning factor (DET WIN1 BINX).
    _slit_xlim_unbinned = {
        "UVB": (1350, 1430),
        "VIS": (1365, 1450),
    }
    # NIR x-range, used regardless of binning.
    _slit_xlim_nir = (384, 435)

    def __init__(self):
        """Create the report under the name "xshooter_echelle_flatfield"."""
        super().__init__("xshooter_echelle_flatfield")
        self.center_size = center_size

    def set_flat_type(self):
        """Set the FLAT type category to use for producing the report.

        The first input category containing "SLIT" or "IFU" decides the
        type. If no category contains either, ``flat_type`` keeps its
        current value (class default "FLAT").
        """
        for _, catg in self.inputs:
            if "SLIT" in catg:
                self.flat_type = "SLIT"
                return
            elif "IFU" in catg:
                self.flat_type = "IFU"
                return

    def parse_sof(self):
        """Pair master and raw flats per arm from ``self.inputs``.

        Returns
        -------
        list of dict
            One ``{"master_product": ..., "raw": ...}`` entry per arm, in
            UVB, VIS, NIR order, for which both a master flat and a raw
            flat were found.

        Side effects
        ------------
        - ``self.raw_extension_default`` and ``self._arm_labels`` are filled
          in the same order as the returned list.
        - ``self.category_label`` is set to the arm of the last matching
          master flat, in input order.
        """
        # First decide whether this is a SLIT, IFU or plain FLAT data set.
        self.set_flat_type()
        ft = self.flat_type

        masters = {}
        raws = {}
        for filename, catg in self.inputs:
            for arm, _ in self._ARM_RAW_EXT:
                if catg == f"MASTER_FLAT_{ft}_{arm}":
                    masters[arm] = filename
                    self.category_label = arm
                if catg in (
                    f"FLAT_D2_{ft}_{arm}",
                    f"FLAT_{ft}_{arm}",
                    f"FLAT_{ft}_{arm}_ON",
                ):
                    raws[arm] = filename

        file_lists = []
        self.raw_extension_default = []
        self._arm_labels = []
        for arm, raw_ext in self._ARM_RAW_EXT:
            master = masters.get(arm)
            raw = raws.get(arm)
            if master is not None and raw is not None:
                file_lists.append({"master_product": master, "raw": raw})
                self.raw_extension_default.append(raw_ext)
                self._arm_labels.append(arm)
        return file_lists

    def _arm_label(self, index):
        """Return the arm label of ``self.hdus[index]``.

        Uses the per-product labels recorded by ``parse_sof``. Falls back to
        ``self.category_label``, or None if that was never set.
        """
        labels = getattr(self, "_arm_labels", None)
        if labels and 0 <= index < len(labels):
            return labels[index]
        return getattr(self, "category_label", None)

    @classmethod
    def _slit_xlim(cls, arm, binx, nx):
        """Return the (x1, x2) slit-function plot range.

        Parameters
        ----------
        arm : str or None
            "UVB", "VIS" or "NIR".
        binx : int or None
            X binning factor.
        nx : int
            Row length, used for the fallback.
        """
        if arm == "NIR":
            return cls._slit_xlim_nir
        if arm in cls._slit_xlim_unbinned:
            x1, x2 = cls._slit_xlim_unbinned[arm]
            # Pixel positions scale with the X binning; an unbinned or
            # missing value keeps the nominal range.
            if binx and binx != 1:
                return x1 / binx, x2 / binx
            return x1, x2
        # Unknown arm: show the whole row rather than leaving x1/x2 unset.
        return 0, nx - 1

    def prepare_slit_function(self, index=0):
        """Compute the slit-function data for one master flat.

        Sets on ``self``:
        - ``row``: column indices ``0 .. nx-1``;
        - ``cen_ny``: index of the central image row;
        - ``procenrow``: the central row as a ``(1, nx)`` array;
        - ``promedrow``: per-column median of the 21 rows centred on
          ``cen_ny`` (fewer if the image is shorter);
        - ``x1``, ``x2``: x-range for the slit-function plot.

        Parameters
        ----------
        index : int, optional
            Index into ``self.hdus`` of the product to use. The default of 0
            matches the previous behaviour.
        """
        master_product = self.hdus[index]["master_product"]
        mst_image = master_product["PRIMARY"].data
        ny1, nx1 = mst_image.shape
        self.row = np.arange(nx1)
        self.cen_ny = ny1 // 2

        self.procenrow = mst_image[self.cen_ny:self.cen_ny + 1, :]

        # Stripe of 21 rows centred on cen_ny. The end bound is exclusive,
        # hence +1. Clamp the start so a short image cannot produce a
        # negative, wrap-around index.
        hw = self._slit_half_width
        start = max(self.cen_ny - hw, 0)
        pro_stripe = mst_image[start:self.cen_ny + hw + 1, :]
        self.promedrow = np.median(pro_stripe, axis=0)

        binx = master_product["PRIMARY"].header.get("HIERARCH ESO DET WIN1 BINX")
        self.x1, self.x2 = self._slit_xlim(self._arm_label(index), binx, nx1)

    def generate_panels(self, **kwargs):
        """Build the report panels, one per product pair from ``parse_sof``.

        Starts from the parent class panels and, on each one:
        - adds the X-shooter setup text in cell (2, 0);
        - sets ``v_min``/``v_max`` of the plot in cell (2, 2) from the
          central master row;
        - adds the slit-function plot in cell (3, 1);
        - fills in the report name and description.

        ``kwargs`` are accepted for interface compatibility and are not used.
        """
        master_product_ext = "PRIMARY"
        direction = "y"
        vspace = 0.3
        new_panels = super().generate_panels(
            master_product_ext=master_product_ext, raw_ext=0, direction=direction
        )
        for i, (panel, panel_descr) in enumerate(new_panels.items()):
            master_product = self.hdus[i]["master_product"]
            master_data = master_product[master_product_ext].data
            hdr = master_product[master_product_ext].header
            master_product_procatg = fetch_kw_or_error(
                master_product[0], "HIERARCH ESO PRO CATG"
            )

            # Instrument setup summary.
            setup_text = TextPlot(columns=1, v_space=vspace, xext=1)
            setup_text.add_data(XshooterSetupInfo.flatfield(master_product))
            panel.assign_plot(setup_text, 2, 0, xext=1)

            # Rescale cell (2, 2) to the value range of the row at the
            # coordinate returned by get_data_coord for the image centre.
            master_plot = ImagePlot(master_data, title=master_product_procatg)
            cutpos1 = master_plot.get_data_coord(
                master_plot.data.shape[0] // 2, direction
            )
            central_row = master_plot.data[cutpos1 - 1:cutpos1, :][0]
            # Named to avoid shadowing the module-level ``plt`` import.
            cell_plot = panel.retrieve(2, 2)
            cell_plot.v_min = np.min(central_row)
            cell_plot.v_max = np.max(central_row)

            # Slit-function data for *this* product, not always hdus[0].
            self.prepare_slit_function(i)

            slit_function_plot = ImagePlot(
                master_data[self.cen_ny:self.cen_ny + 1, :], title="Slit Function"
            )
            cutpos3 = slit_function_plot.get_data_coord(
                slit_function_plot.data.shape[0] // 2, direction
            )
            central_cut = CutPlot(
                direction,
                title="master row @ {} {}".format(direction, cutpos3),
                y_label="normalised counts",
            )
            central_cut.add_data(
                slit_function_plot,
                cutpos3,
                color="red",
                label="central row =" + str(self.cen_ny),
            )

            median_line = LinePlot(
                title="Slit Function",
                x_label="x",
                y_label="normalised counts",
                legend=True,
            )
            median_line.add_data(
                d=[self.row, self.promedrow], color="blue", label="median of 21 rows"
            )

            combined = CombinedPlot(
                title="Slit function combined",
                x_label="x",
                y_label="normalised counts",
                legend=True,
            )
            combined.add_data(central_cut, z_order=0)
            combined.add_data(median_line, z_order=1)
            combined.set_ylim(ymin=-0.1, ymax=1.3)
            combined.set_xlim(xmin=self.x1, xmax=self.x2)
            panel.assign_plot(combined, 3, 1)

            # Update the panel description.
            panel_descr["report_name"] = "xshooter_echelle_flatfield_{}".format(
                hdr.get("HIERARCH ESO PRO CATG").lower()
            )
            panel_descr["report_description"] = (
                f"XSHOOTER echelle flatfield - "
                f"{os.path.basename(panel_descr['master_product'])}, "
                f"{panel_descr['master_product_ext']}"
            )

        return dict(new_panels)


# Module-level report instance, created at import time. Presumably the ADARI
# framework looks up ``rep`` when loading this module -- confirm against the
# loader.
rep = XshooterEchelleFlatfieldReport()
