from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import fetch_kw_or_default
from adari_core.report import AdariReportBase
import numpy as np
import os
from .hawki_utils import HawkiReportMixin

class HawkiScienceReport(HawkiReportMixin, AdariReportBase):
    """ADARI preview report for HAWK-I science products.

    A tiled (combined) product yields a single panel built from its PRIMARY
    extension; otherwise a jittered product yields one panel per chip
    extension listed in ``_CHIP_EXTENSIONS``.
    """

    # SOF category -> key of the per-product file dict, jittered products.
    _JITTERED_CATGS = {
        "JITTERED_IMAGE_SCI": "image",
        "OBJECT_CATALOGUE_JITTERED": "cat",
        "CONFIDENCE_MAP_JITTERED": "conf",
    }
    # SOF category -> key of the per-product file dict, tiled products.
    _TILED_CATGS = {
        "TILED_IMAGE": "image",
        "TILED_OBJECT_CATALOGUE": "cat",
        "TILED_CONFIDENCE_MAP": "conf",
    }
    # Image extensions rendered (one panel each) for a jittered product.
    _CHIP_EXTENSIONS = ("CHIP1.INT1", "CHIP2.INT1", "CHIP3.INT1", "CHIP4.INT1")
    # Number of bins of the product histograms.
    _HIST_BINS = 50
    # Half-width, in full-range histogram bins, of the zoom window around the peak.
    _ZOOM_HALF_WIDTH = 5
    # Vertical spacing between lines of the text boxes.
    _TEXT_VSPACE = 0.4

    def __init__(self):
        super().__init__("hawki_science")
        # Set to True by parse_sof() when a tiled product is selected.
        self.combined = False

    def parse_sof(self):
        """Select the files to report on from ``self.inputs``.

        A tiled product takes precedence over a jittered one; when it is
        selected ``self.combined`` is set to True. If a category appears more
        than once, the last file listed wins.

        Returns:
            list[dict]: empty if no science image is present, otherwise a
            single dict with keys "image", "cat" and "conf" ("cat"/"conf"
            are None when missing from the SOF).
        """
        jittered = dict.fromkeys(("image", "cat", "conf"))
        tiled = dict.fromkeys(("image", "cat", "conf"))
        for filename, catg in self.inputs:
            if catg in self._JITTERED_CATGS:
                jittered[self._JITTERED_CATGS[catg]] = filename
            elif catg in self._TILED_CATGS:
                tiled[self._TILED_CATGS[catg]] = filename

        if tiled["image"] is not None:
            self.combined = True
            return [tiled]
        if jittered["image"] is not None:
            return [jittered]
        return []

    def fetch_and_print_format_or_default(self, hdu, key, format, default="N/A"):
        """Return header keyword ``key`` of ``hdu`` rendered with ``format``.

        Args:
            hdu: header-bearing object passed to ``fetch_kw_or_default``.
            key: keyword name.
            format: printf-style format string, e.g. "%.2f".
            default: value used when ``fetch_kw_or_default`` returns it
                (i.e. the keyword was not found).

        Returns:
            str: ``str(default)`` for a missing keyword, ``format % value``
            otherwise, or ``str(value)`` if the value cannot be rendered with
            ``format`` (e.g. a string value with a numeric format).
        """
        val = fetch_kw_or_default(hdu, key, default=default)
        if val == default:
            return str(default)
        try:
            return format % val
        except (TypeError, ValueError):
            # An unexpected keyword type should not abort the whole preview.
            return str(val)

    @staticmethod
    def _kw_str(hdu, key):
        """Return ``str`` of header keyword ``key`` of ``hdu``, "N/A" if missing."""
        return str(fetch_kw_or_default(hdu, key, default="N/A"))

    @classmethod
    def _zoom_range_and_mode(cls, image):
        """Compute the histogram zoom window and the mode of ``image``.

        Non-finite pixels (NaN/inf) are ignored: ``np.histogram`` cannot
        auto-detect its range when they are present.

        Returns:
            tuple: ``(zoom0, zoom1, mode)``. ``[zoom0, zoom1]`` spans
            ``_ZOOM_HALF_WIDTH`` full-range bins either side of the most
            populated bin (clipped to the data range); ``mode`` is the centre
            of the most populated bin of a re-binned histogram of that window.
        """
        data = np.asarray(image)
        finite = data[np.isfinite(data)]
        nbins = cls._HIST_BINS

        hist, edges = np.histogram(finite, bins=nbins)
        peak = int(np.argmax(hist))
        # ``edges`` has nbins + 1 entries, so valid indices are 0..nbins.
        zoom0 = edges[max(0, peak - cls._ZOOM_HALF_WIDTH)]
        zoom1 = edges[min(nbins, peak + cls._ZOOM_HALF_WIDTH)]

        # Re-bin inside the zoom window for a finer estimate of the mode.
        hist, edges = np.histogram(finite, bins=nbins, range=(zoom0, zoom1))
        peak = int(np.argmax(hist))
        mode = 0.5 * (edges[peak] + edges[peak + 1])
        return zoom0, zoom1, mode

    def _add_image_plots(self, p, image, conf, ext, label):
        """Add the image, its line cuts and the confidence map to panel ``p``."""
        full_image = ImagePlot(
            title="Image " + str(ext),
            colormap="hot",
            v_clip="percentile",
            v_clip_kwargs={"percentile": 95.0},
        )
        full_image.add_data(np.nan_to_num(image))
        p.assign_plot(full_image, 0, 1, xext=4, yext=4)

        shape = full_image.data.shape
        # Combined products are cut at the first quartile (presumably to avoid
        # the inter-chip gap at the mosaic centre); single chips at the centre.
        divisor = 4 if self.combined else 2
        cut_pos_y = full_image.get_data_coord(shape[0] // divisor, "y")
        cut_pos_x = full_image.get_data_coord(shape[1] // divisor, "x")

        cut_y = CutPlot(
            "y", title="Cut @ y=" + str(cut_pos_y), x_label="x", y_label=label, legend=False
        )
        cut_y.add_data(full_image, cut_pos=cut_pos_y)
        p.assign_plot(cut_y, 4, 1)

        cut_x = CutPlot(
            "x", title="Cut @ x=" + str(cut_pos_x), x_label="y", y_label=label, legend=False
        )
        cut_x.add_data(full_image, cut_pos=cut_pos_x)
        p.assign_plot(cut_x, 4, 2)

        if self.combined:
            # Convert the 3/4 pixel position itself; multiplying the converted
            # 1/4 coordinate by 3 is only right for a pure-scaling mapping.
            cut_pos_x2 = full_image.get_data_coord(3 * shape[1] // 4, "x")
            cut_x2 = CutPlot(
                "x", title="Cut @ x=" + str(cut_pos_x2), x_label="y", y_label=label, legend=False
            )
            cut_x2.add_data(full_image, cut_pos=cut_pos_x2)
            p.assign_plot(cut_x2, 5, 2)

        conf_image = ImagePlot(title="Confidence", colormap="hot")
        conf_image.add_data(np.nan_to_num(conf))
        p.assign_plot(conf_image, 5, 1)

    def _add_histograms(self, p, image, label):
        """Add the zoomed (log-scale, with zero and mode markers) and the
        full-range histograms of ``image`` to panel ``p``."""
        zoom0, zoom1, mode = self._zoom_range_and_mode(image)

        data_hist = HistogramPlot(
            title="Product histograms",
            v_min=zoom0,
            v_max=zoom1,
            bins=self._HIST_BINS,
            x_label=label,
        )
        data_hist.add_data(image, label="data counts")
        data_hist.legend = False

        # Vertical marker lines: solid at zero, dotted at the mode.
        zero_line = ScatterPlot(title="Line zero")
        zero_line.add_data(label="_zero", d=[[0], [np.nan]], vline=True, color="black")
        zero_line.legend = False
        mode_line = ScatterPlot(title="Line mode")
        mode_line.add_data(
            label="_mode",
            d=[[mode], [np.nan]],
            vline=True,
            color="black",
            linestyle="dotted",
        )
        mode_line.legend = False

        comb_hist = CombinedPlot(title="Product histograms")
        for plot in (data_hist, zero_line, mode_line):
            comb_hist.add_data(plot)
        comb_hist.y_scale = "log"
        comb_hist.legend = False

        data_fullhist = HistogramPlot(
            title="",
            x_label=label,
            bins=self._HIST_BINS,
            legend=False,
            v_clip="minmax",
        )
        data_fullhist.add_data(image, label="data counts")
        data_fullhist.y_scale = "log"

        p.assign_plot(comb_hist, 4, 3, xext=2)
        p.assign_plot(data_fullhist, 4, 4, xext=2)

    def _add_header_text(self, p, science):
        """Add the product / observation / instrument summary boxes (row 0)."""
        primary = science["PRIMARY"]
        kw = self._kw_str
        columns = (
            (
                0,
                (
                    kw(primary, "INSTRUME") + " science product preview",
                    "Product: " + kw(primary, "ESO PRO CATG"),
                    "Raw file: " + kw(primary, "ESO PRO REC1 RAW1 NAME"),
                    "MJD-OBS: " + kw(primary, "MJD-OBS"),
                ),
            ),
            (
                2,
                (
                    "Target: " + kw(primary, "OBJECT"),
                    "OB ID: " + kw(primary, "ESO OBS ID"),
                    "OB NAME: " + kw(primary, "ESO OBS NAME"),
                    "TPL ID: " + kw(primary, "ESO TPL ID"),
                    "RUN ID: " + kw(primary, "ESO OBS PROG ID"),
                ),
            ),
            (
                4,
                (
                    "Filter 1: " + kw(primary, "ESO INS FILT1 NAME"),
                    "Filter 2: " + kw(primary, "ESO INS FILT2 NAME"),
                    "INS MODE: " + kw(primary, "ESO INS MODE"),
                ),
            ),
        )
        for x, lines in columns:
            text = TextPlot(columns=1, v_space=self._TEXT_VSPACE)
            text.add_data(lines, fontsize=13)
            p.assign_plot(text, x, 0, xext=1)

    def _add_qc_text(self, p, science, cat, ext):
        """Add the QC summary boxes (bottom row) for image extension ``ext``."""
        # A tiled product's QC keywords are read from catalogue extension 1;
        # otherwise the catalogue extension with the image's extension name.
        extcat = 1 if self.combined else ext
        fmt = self.fetch_and_print_format_or_default
        primary = science["PRIMARY"]
        qc = cat[extcat]

        col4 = (
            "Exp. time [s]: " + fmt(primary, "EXPTIME", "%.1f"),
            "N exposures: " + fmt(primary, "NCOMBINE", "%i"),
            "Image quality [arcsec]: " + fmt(qc, "ESO QC IMAGE_SIZE", "%.2f"),
            "Average ellipticity: " + fmt(qc, "ESO QC ELLIPTICITY", "%.2f"),
            "N sources: " + fmt(qc, "NAXIS2", "%i"),
        )
        col5 = (
            "Limiting magnitude: " + fmt(qc, "ESO QC LIMITING_MAG", "%.2f"),
            "Zeropoint: " + fmt(qc, "ESO QC MAGZPT", "%.2f"),
            "N stars ZP: " + fmt(qc, "ESO QC MAGNZPT", "%i"),
            "WCS error [arcsec]: " + fmt(qc, "ESO DRS STDCRMS", "%.3f"),
            "N stars WCS: " + fmt(qc, "ESO DRS NUMBRMS", "%i"),
        )
        # NOTE(review): labelled "Median" but reads ESO QC MEAN_SKY — confirm.
        col6 = ("Median sky [ADU]: " + fmt(qc, "ESO QC MEAN_SKY", "%.2f"),)
        if self.combined:
            col6 += (
                "Airmass: " + fmt(primary, "ESO QC AIRM MEAN", "%.2f"),
                "Delta t flat field [d]: "
                + fmt(primary, "ESO QC DELTA TIME TWFLAT", "%.2f"),
                "N pix non-linearity: " + fmt(primary, "ESO QC SAT NB", "%i"),
            )

        for x, lines in ((0, col4), (2, col5), (4, col6)):
            text = TextPlot(columns=1, v_space=self._TEXT_VSPACE)
            text.add_data(lines, fontsize=13)
            p.assign_plot(text, x, 5, xext=1)

    def generate_single_panel(self, ext, **kwargs):
        """Build the preview panel for one image extension.

        Args:
            ext: extension of the image and confidence map to display
                (e.g. "PRIMARY" or "CHIP1.INT1").
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict: a single ``{Panel: metadata}`` entry.
        """
        science = self.hdus[0]["image"]
        cat = self.hdus[0]["cat"]
        image = science[ext].data
        conf = self.hdus[0]["conf"][ext].data
        label = str(fetch_kw_or_default(science[ext], "BUNIT", default="Flux"))

        p = Panel(6, 6, height_ratios=[1, 3, 3, 3, 3, 1], y_stretch=0.7, right_subplot=1.5)
        self._add_image_plots(p, image, conf, ext, label)
        self._add_histograms(p, image, label)
        self._add_header_text(p, science)
        self._add_qc_text(p, science, cat, ext)

        input_files = [
            self.hdus[0]["image"].filename(),
            self.hdus[0]["cat"].filename(),
            self.hdus[0]["conf"].filename(),
        ]
        data_fname = os.path.basename(str(science.filename()))
        metadata = {
            "report_name": f"HAWKI_{data_fname.removesuffix('.fits').lower()}_{str(ext)}",
            "report_description": "Science panel",
            "report_tags": [],
            "report_prodcatg": "ANCILLARY.PREVIEW",
            "input_files": input_files,
        }
        return {p: metadata}

    def generate_panels(self, **kwargs):
        """Return all report panels.

        One panel for a tiled product (PRIMARY extension), otherwise one per
        chip extension in ``_CHIP_EXTENSIONS``.
        """
        extensions = ("PRIMARY",) if self.combined else self._CHIP_EXTENSIONS
        panels = {}
        for ext in extensions:
            panels.update(self.generate_single_panel(ext=ext))
        return panels

# Module-level report instance.
# NOTE(review): presumably the ADARI framework looks this object up by the
# name ``rep`` — confirm before renaming or removing it.
rep = HawkiScienceReport()
