from adari_core.plots.panel import Panel
from adari_core.plots.points import ScatterPlot
from adari_core.plots.images import ImagePlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

from .uves_util import UvesSetupInfo

import os

from . import UvesReportMixin


class UvesCdReport(UvesReportMixin, AdariReportBase):
    """UVES CD report: one panel per CD alignment table / raw frame pair.

    ``parse_sof`` picks the ``CD_ALIGN_TABLE_*`` product and the
    ``CD_ALIGN_*`` raw frame out of the inputs; ``generate_panels`` plots
    columns of that table together with a cut through the raw frame.
    """

    # Extension of the raw file to read, keyed by the arm/chip key that
    # ``generate_panels`` derives from the table's headers.
    raw_extensions = {"blue": "PRIMARY", "redu": "CCD-20", "redl": "CCD-44"}

    def __init__(self):
        """Pass ``"uves_cd"`` to the base-class initialiser."""
        super().__init__("uves_cd")

    def parse_sof(self):
        """Select the CD alignment table and raw frame from ``self.inputs``.

        ``self.inputs`` is iterated as ``(filename, category)`` pairs. The
        blue and red categories fill the same two slots, and when a category
        occurs more than once the last matching file wins, so only a single
        table/raw pair is ever returned.

        Returns:
            A one-element list holding a dict with the keys
            ``"cd_align_table"`` and ``"cd_align_raw"`` mapped to the
            selected filenames.

        Raises:
            ValueError: If the table or the raw frame is missing from the
                inputs.
        """
        table_catgs = ("CD_ALIGN_TABLE_BLUE", "CD_ALIGN_TABLE_RED")
        raw_catgs = ("CD_ALIGN_BLUE", "CD_ALIGN_RED")

        # Start from None so that a missing input is reported explicitly
        # instead of surfacing as an UnboundLocalError further down.
        cd_align_table = None
        cd_align_raw_table = None
        for filename, catg in self.inputs:
            if catg in table_catgs:
                cd_align_table = filename
            elif catg in raw_catgs:
                cd_align_raw_table = filename

        if cd_align_table is None or cd_align_raw_table is None:
            raise ValueError(
                "uves_cd report needs a CD_ALIGN_TABLE_BLUE/RED product and a "
                "CD_ALIGN_BLUE/RED raw frame in its inputs"
            )

        return [
            {"cd_align_table": cd_align_table, "cd_align_raw": cd_align_raw_table}
        ]

    def generate_panels(self, **kwargs):
        """Build one panel per entry of ``self.hdus``.

        Each entry is read through the ``"cd_align_table"`` and
        ``"cd_align_raw"`` keys produced by ``parse_sof``; both values are
        used as HDU lists (presumably opened by the base class -- confirm).

        Every panel is a 2x3 grid holding two header-text blocks in the top
        row, the BACK1/BACK2 and YCEN1/YCEN2 scatter plots in the middle row,
        and the raw-frame cut and YDIFF scatter plot in the bottom row.

        Args:
            **kwargs: Accepted but not used by this method.

        Returns:
            A dict mapping each created ``Panel`` to a metadata dict with the
            keys ``"report_name"``, ``"report_description"`` and
            ``"report_tags"``.
        """
        panels = {}
        for hdus in self.hdus:
            cd_align_table = hdus["cd_align_table"]
            fname = os.path.basename(cd_align_table.filename())
            cd_align = cd_align_table[1]
            arm = UvesSetupInfo.get_arm_info(cd_align_table)
            is_red = "RED" in arm
            primary_header = cd_align_table["PRIMARY"].header

            # Pick the raw extension: red data use "redu" when the primary
            # header carries DET CHIP1 INDEX and "redl" when it does not.
            if is_red:
                chip_1_index = primary_header.get("HIERARCH ESO DET CHIP1 INDEX")
                ext_key = "redl" if chip_1_index is None else "redu"
            else:
                ext_key = "blue"

            cd_align_raw_table = hdus["cd_align_raw"]
            cd_align_raw = cd_align_raw_table[self.raw_extensions[ext_key]]

            # Create the panel: 2 columns x 3 rows, with a short top row
            # for the text blocks.
            p = Panel(2, 3, height_ratios=[1, 4, 4])

            scaling = {"v_clip": "sigma", "v_clip_kwargs": {"nsigma": 2.5}}
            rplot_size = 2
            vspace = 0.3

            # Text block 1: file identification, spanning both columns.
            t1 = TextPlot(columns=1, v_space=vspace)
            col1 = (
                str(primary_header.get("INSTRUME")),
                "EXTNAME: " + str(cd_align_table[0].header.get("EXTNAME", "N/A")),
                "PRO CATG: " + str(primary_header.get("HIERARCH ESO PRO CATG")),
                "FILE NAME: " + fname,
                "RAW1 NAME: "
                + str(primary_header.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
            )
            t1.add_data(col1)
            p.assign_plot(t1, 0, 0, xext=2)

            # Text block 2: read-out speed and the arm's grating name,
            # placed in the second column of the top row.
            t2 = TextPlot(columns=1, v_space=vspace, xext=1)
            grat_key = "INS.GRAT2.NAME" if is_red else "INS.GRAT1.NAME"
            col2 = (
                "DET.READ.SPEED: "
                + str(primary_header.get("HIERARCH ESO DET READ SPEED")),
                grat_key
                + ": "
                + str(
                    primary_header.get(
                        "HIERARCH ESO " + grat_key.replace(".", " "), "N/A"
                    )
                ),
            )
            t2.add_data(col2)
            p.assign_plot(t2, 1, 0, xext=1)

            # 1 Cut through the raw frame (log-scaled ordinate). The image
            # plot is only used as the data source for the cut; it is not
            # assigned to the panel itself.
            raw_plot = ImagePlot(
                cd_align_raw.data,
                title="raw plot",
                aspect="auto",
                y_scale="log",
                **scaling,
            )
            # NOTE(review): the position comes from shape[0] (first array
            # axis) but is converted along "x", while the cut itself is a
            # "y" cut -- confirm this axis mix is intended.
            cutpos = raw_plot.get_data_coord(raw_plot.data.shape[0] // 4, "x")
            raw_cutY = CutPlot(
                "y",
                title="raw row @ Y {}".format(cutpos),
                y_scale="log",
                y_label=fetch_kw_or_default(
                    cd_align_table["PRIMARY"], "BUNIT", default="ADU"
                ),
            )
            raw_cutY.add_data(raw_plot, cutpos, color="black", label="raw")
            p.assign_plot(raw_cutY, 0, 2)

            table = cd_align.data
            x_values = table["X"]

            # 2 BACK1 and BACK2 columns (divided by 1000) vs X
            intensity_plot = ScatterPlot(
                title="Cent Intensity",
                x_label="X",
                y_label="BACK",
                markersize=rplot_size,
                legend=True,
            )
            back1 = table["BACK1"] / 1000
            intensity_plot.add_data((x_values, back1), label="Raw 1", color="black")
            back2 = table["BACK2"] / 1000
            intensity_plot.add_data((x_values, back2), label="Raw 2", color="red")
            p.assign_plot(intensity_plot, 0, 1)

            # 3 Y difference vs X
            diff_plot = ScatterPlot(
                title="Fit residuals",
                x_label="X",
                y_label="YDIFF",
                markersize=rplot_size,
                legend=False,
            )
            diff_plot.add_data(
                (x_values, table["YDIFF"]),
                label="Used lines",
                color="black",
            )
            p.assign_plot(diff_plot, 1, 2)

            # 4 YCEN1 and YCEN2 columns vs X
            cent_plot = ScatterPlot(
                title="Cent Position",
                x_label="X",
                y_label="YCEN",
                markersize=rplot_size,
                legend=True,
            )
            cent_plot.add_data(
                (x_values, table["YCEN1"]),
                label="Raw 1",
                color="black",
            )
            cent_plot.add_data((x_values, table["YCEN2"]), label="Raw 2", color="red")
            p.assign_plot(cent_plot, 1, 1)

            panels[p] = {
                "report_name": "uves_cd_panel",
                "report_description": "UVES CD panel",
                "report_tags": [],
            }
        return panels


# Module-level report instance, created on import; presumably this is the
# object the ADARI runner looks up when it loads the module -- confirm.
rep = UvesCdReport()
