import os
import numpy as np
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.plots.points import ScatterPlot, LinePlot
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot
from adari_core.utils.utils import fetch_kw_or_error

from .pionier_utils import PionierSetupInfo, PionierReportMixin

class PionierKappaReport(PionierReportMixin, AdariReportBase):
    """ADARI report for a PIONIER kappa matrix (``KAPPA_MATRIX``) product.

    Two panels are built from the data in extension 1 of the first
    ``KAPPA_MATRIX`` input (labelled ``PNDRS_MATRIX`` in the report):

    * a per-telescope image panel (one extra image row per polarisation in
      ``GRI+WOL`` mode), and
    * a per-spectral-channel panel of scatter points over dashed lines.
    """

    # Number of plot columns used by the per-channel (second) panel.
    _CHANNEL_COLS = 3

    def __init__(self):
        super().__init__("pionier_kappa")

    def parse_sof(self):
        """Select the input file for this report from ``self.inputs``.

        ``self.inputs`` is iterated as ``(filename, category)`` pairs.

        Returns:
            ``[{"master_kappa": filename}]`` for the first input whose
            category is ``KAPPA_MATRIX``, or ``[]`` if there is none.
        """
        master_kappa = next(
            (name for name, catg in self.inputs if catg == "KAPPA_MATRIX"),
            None,
        )
        if master_kappa is None:
            return []
        return [{"master_kappa": master_kappa}]

    def plot_panel_1(self, data, ncols, nrows, hr):
        """Build a panel with one kappa-matrix image per telescope.

        Row 0 is left free for the header text plots; images fill rows
        ``1 .. nrows - 1``.

        Args:
            data: Indexable collection of 2-D images; image
                ``col + row * ncols`` is drawn at column ``col`` of image
                row ``row``.
            ncols: Number of panel columns (one per telescope).
            nrows: Total number of panel rows, including the text row.
            hr: Height ratios of the panel rows.

        Returns:
            The populated ``Panel``.
        """
        panel = Panel(ncols, nrows, height_ratios=hr, y_stretch=2)
        last_image_row = nrows - 2
        last_col = ncols - 1

        for row in range(nrows - 1):
            for col in range(ncols):
                image = data[col + row * ncols]

                # Axis labels only on the outer plots to reduce clutter.
                ylabel = "output" if (row == 0 and col == 0) else ""
                is_corner = row == last_image_row and col == last_col
                xlabel = "spectral channel" if is_corner else ""

                title = f"telescope {col + 1}"
                # With several image rows, each row is labelled as a polarisation.
                if nrows > 2:
                    title += f", pol{row + 1}"

                plot = ImagePlot(
                    image,
                    x_label=xlabel,
                    y_label=ylabel,
                    title=title,
                    show_colorbar=True,
                )
                # A single spectral channel: only show the tick at x = 0.
                if image.shape[1] == 1:
                    plot.x_major_ticks = (0.0,)
                plot.set_vlim(0, 1.6)
                panel.assign_plot(plot, col, row + 1)

        return panel

    def plot_panel_2(self, p, data, xlabel, rplot_size):
        """Fill panel ``p`` with one plot per spectral channel.

        Channels are laid out row-major on a grid of ``_CHANNEL_COLS`` (3)
        columns starting at row 1 (row 0 is left for the header text
        plots): channel ``k`` goes to column ``k % 3`` of row ``1 + k // 3``.
        For 6 channels this is the historical 2x3 layout; other channel
        counts no longer drop channels.

        Channel 0 uses the caller's ``xlabel`` and is the only plot with a
        legend; other plots get the ``"output"`` x label only when they sit
        in the last used row.

        Args:
            p: Panel to fill. It needs ``_CHANNEL_COLS`` columns and at least
                ``1 + ceil(nchan / _CHANNEL_COLS)`` rows.
            data: Array indexed as ``data[series, x, channel]`` (assumed
                3-D, as required by the indexing below).
            xlabel: X-axis label for the channel-0 plot.
            rplot_size: Marker size of the scatter points.

        Returns:
            The same panel ``p``.
        """
        ncols = self._CHANNEL_COLS
        n_chan = np.shape(data)[2]
        last_row = 1 + (n_chan - 1) // ncols

        first = self.plot_kappa_vs_wave(
            data, 0, None, xlabel, rplot_size, legend=True
        )
        p.assign_plot(first, 0, 1)

        for chan in range(1, n_chan):
            row = 1 + chan // ncols
            label = "output" if row == last_row else ""
            plot = self.plot_kappa_vs_wave(data, chan, None, label, rplot_size)
            p.assign_plot(plot, chan % ncols, row)

        return p

    def plot_kappa_vs_wave(self, data, i, j, x_label, rplot_size, legend=False):
        """Plot channel ``i`` of ``data`` as scatter points over dashed lines.

        For each series ``k`` along axis 0, the strictly positive values of
        ``data[k, :, i]`` are drawn as scatter points and the full
        ``data[k, :, i]`` as a dashed line, both against the axis-1 index.

        Args:
            data: Array indexed as ``data[series, x, channel]`` (assumed 3-D).
            i: Channel index (axis 2) to plot.
            j: Unused; kept only so existing callers keep working.
            x_label: X-axis label of the scatter plot.
            rplot_size: Marker size of the scatter points.
            legend: Whether the scatter plot shows a legend.

        Returns:
            A ``CombinedPlot`` titled ``"channel <i>"``.
        """
        n_series = np.shape(data)[0]
        n_x = np.shape(data)[1]
        x_all = np.array(range(n_x))

        scatter = ScatterPlot(
            x_label=x_label,
            y_label="",
            legend=legend,
            markersize=rplot_size,
        )
        for k in range(n_series):
            values = data[k, :, i]
            # Non-positive (and NaN) entries are left out of the scatter plot.
            mask = values > 0
            scatter.add_data(
                d=[x_all[mask.flatten()], values[mask].flatten()],
                label="NAXIS1: " + str(k + 1),
            )

        line = LinePlot(legend=False)
        line.linestyle = "dashed"
        for k in range(n_series):
            line.add_data(
                d=[range(n_x), data[k, :, i].flatten()],
                label="NAXIS1: " + str(k + 1),
            )

        combined = CombinedPlot(title=f"channel {i}")
        combined.add_data(scatter, z_order=0)
        combined.add_data(line, z_order=1)
        return combined

    def first_panel(self, hdr, data):
        """Build the per-telescope image panel.

        With ``INS OPTI2 NAME == "GRI+WOL"`` two image rows (pol1, pol2) are
        drawn below the text row; otherwise a single image row.
        """
        if str(hdr.get("HIERARCH ESO INS OPTI2 NAME")) == "GRI+WOL":
            nrows, hr = 3, [1, 4, 4]
        else:
            nrows, hr = 2, [1, 4]
        return self.plot_panel_1(data, 4, nrows, hr)

    def second_panel(self, hdr, data):
        """Build the per-spectral-channel panel.

        The panel has ``_CHANNEL_COLS`` columns and enough rows for every
        channel of ``data`` (axis 2), but never fewer than the original 3
        (one text row plus two plot rows). The channel-0 plot is labelled
        ``"output"`` only when ``INS OPTI2 NAME == "FREE"``.
        """
        ncols = self._CHANNEL_COLS
        n_chan = np.shape(data)[2]
        nrows = max(3, 1 + (n_chan + ncols - 1) // ncols)
        p = Panel(ncols, nrows, height_ratios=[1] + [4] * (nrows - 1))

        is_free = str(hdr.get("HIERARCH ESO INS OPTI2 NAME")) == "FREE"
        xlabel = "output" if is_free else ""
        return self.plot_panel_2(p, data, xlabel, 2)

    def generate_panels(self, **kwargs):
        """Create the two report panels and their metadata.

        Returns:
            Dict mapping each ``Panel`` to its metadata dict (``panel``,
            ``ext``, ``report_name``, ``report_description``,
            ``report_tags``), first panel first.
        """
        # NOTE(review): presumably an astropy HDUList opened by the base class
        # from the parse_sof() file list — confirm.
        master_kappa = self.hdus[0]["master_kappa"]
        hdr = master_kappa[0].header
        ext = "PNDRS_MATRIX"
        fname = master_kappa.filename()

        master_kappa_procatg = fetch_kw_or_error(
            master_kappa[0], "HIERARCH ESO PRO CATG"
        )

        vspace = 0.2
        col1 = (
            "INSTRUME: " + str(hdr.get("INSTRUME")),
            "EXTNAME: " + ext,
            "PRO.CATG: " + str(master_kappa_procatg),
            "FILE NAME: " + os.path.basename(fname),
            "RAW1.NAME: " + str(hdr.get("HIERARCH ESO PRO REC1 RAW1 NAME")),
        )
        col2 = PionierSetupInfo.kappa(master_kappa[0])

        t1 = TextPlot(columns=1, v_space=vspace, xext=1)
        t1.add_data(col1)
        t2 = TextPlot(columns=1, v_space=vspace, xext=1)
        t2.add_data(col2)

        # The plotted kappa matrix is taken from extension 1.
        kappa_data = master_kappa[1].data

        # Both panels share the same two header text plots in row 0.
        p1 = self.first_panel(hdr, kappa_data)
        p1.assign_plot(t1, 0, 0)
        p1.assign_plot(t2, 1, 0)

        p2 = self.second_panel(hdr, kappa_data)
        p2.assign_plot(t1, 0, 0)
        p2.assign_plot(t2, 1, 0)

        base_name = f"pionier_{master_kappa_procatg}_{ext}"
        base_description = f"PIONIER kappa report panel - {ext}"

        return {
            p1: {
                "panel": "first panel",
                "ext": ext,
                "report_name": base_name,
                "report_description": base_description,
                "report_tags": [],
            },
            p2: {
                "panel": "second panel",
                "ext": ext,
                "report_name": base_name + "_kappa_per_channel",
                "report_description": base_description + "_kappa_per_channel",
                "report_tags": [],
            },
        }


# Module-level instance of the report.
# NOTE(review): presumably the ADARI runner looks up `rep` by name when it
# executes this module — confirm before renaming.
rep = PionierKappaReport()
