from adari_core.plots.combined import CombinedPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.panel import Panel
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import ScatterPlot
from adari_core.plots.text import TextPlot
from adari_core.report import AdariReportBase
from adari_core.utils.utils import fetch_kw_or_default

import os
import numpy as np


class MasterStdStarImg(AdariReportBase):
    """Master ADARI Imaging Standard Star report.

    Description
    -----------
    Master report for building the Imaging Standard Star report for all ESO instruments.
    The report consists of a full image of the reduced field with found standard stars marked,
    a central region image of the reduced field with found standard stars marked,
    cut plots along the x and y directions, respectively,
    cut plots along the x and y directions in the central region, respectively,
    and a histogram of the raw and/or reduced data.

    Child reports must implement :meth:`parse_sof` and, when standard-star
    marking is requested, :meth:`get_standards`.
    """

    # Passed as ``extent`` to CentralImagePlot; shown in the plot title as
    # "central <center_size> pixels".
    center_size = 200
    # Bin count of the reduced-only histogram, and the upper limit on the
    # bin count when raw data are included.
    hist_bins_max = 50

    def __init__(self, name: str):
        super().__init__(name)

    def parse_sof(self):
        """
        Convert the input SOF into a list of filenames to open.

        See the documentation of :any:`AdariReportBase.parse_sof` for full
        details.

        Returns
        -------
        list of dicts
            Each dict in the list specifies the inputs required to create
            a single :any:`Panel`. Children reports of MasterStdStarImg
            require the following inputs to be defined for each panel:

            - ``"calib_std"`` - a processed input image
            - ``"raw_im"`` *(optional)* - the raw image that "calib_std" was
              derived from
            - ``"matched_phot"`` *(optional)* - a table with found standard stars
        """

        raise NotImplementedError(
            "MasterStdStarImgReport is a template only, "
            "the child Report is responsible for "
            "defining parse_sof"
        )

    def get_standards(self, phot_data=None):
        """
        Extract the pixel coordinates of the found standard stars.

        Parameters
        ----------
        phot_data : optional
            The data of the ``"matched_phot"`` extension being plotted;
            :meth:`generate_panels` passes ``phot[ext].data``.

        Returns
        -------
        sequence of two array-likes
            The (x, y) coordinates of the standard stars, e.g. ``[std_x, std_y]``;
            :meth:`generate_panels` unpacks them as ``std_xvals, std_yvals``.

        Raises
        ------
        NotImplementedError
            Always, in this template; child reports must override it.
        """

        raise NotImplementedError(
            "MasterStdStarImgReport is a template only, "
            "the child Report is responsible for "
            "defining get_standards"
        )

    @staticmethod
    def _text_plot(calib_std, ext, v_space=0.3):
        """Build the header TextPlot (instrument, extension, PRO CATG, file names)."""
        header = calib_std["PRIMARY"].header
        fname = os.path.basename(str(calib_std.filename()))
        text = TextPlot(columns=1, v_space=v_space)
        text.add_data(
            (
                str(header.get("INSTRUME", "Missing INSTRUME")),
                "EXTNAME: " + str(ext),
                "PRO CATG: "
                + str(header.get("HIERARCH ESO PRO CATG", "Missing PRO CATG")),
                "FILE NAME: " + fname,
                "RAW1 NAME: "
                + str(
                    header.get(
                        "HIERARCH ESO PRO REC1 RAW1 NAME", "Missing RAW1 NAME"
                    )
                ),
            )
        )
        return text

    @staticmethod
    def _cut_plot(image, axis, title, x_label, y_label, label):
        """Build a red CutPlot through the middle of ``image`` along ``axis``."""
        cut = CutPlot(axis, title=title, x_label=x_label)
        cut.y_label = y_label
        # An "x" cut is positioned along data.shape[1] (a column), a "y" cut
        # along data.shape[0] (a row); both go through the middle.
        middle = image.data.shape[1 if axis == "x" else 0] // 2
        cut.add_data(
            image,
            image.get_data_coord(middle, axis),
            label=label,
            color="red",
        )
        return cut

    def _raw_hist_bins(self, rmin, rmax):
        """Return the bin count for a histogram spanning ``[rmin, rmax]``.

        The span is treated as ``rmax - rmin + 1`` discrete values; each bin
        holds the smallest whole number of values that keeps the number of
        bins at or below ``hist_bins_max``. The result is always an ``int``
        of at least 1, even when the limits are floats.
        """
        max_bins = self.hist_bins_max
        n_values = rmax - rmin + 1
        if not np.isfinite(n_values):
            # Degenerate limits: fall back to the default bin count.
            return max_bins
        # The smallest integer width w >= 1 with n_values // w <= max_bins is
        # floor(n_values / (max_bins + 1)) + 1; the loop only guards against
        # floating-point rounding.
        width = max(1, int(n_values // (max_bins + 1)) + 1)
        while n_values // width > max_bins:
            width += 1
        return max(1, int(n_values // width))

    def generate_panels(
        self,
        im_ext=0,
        raw_ext=None,
        marking=False,
        marking_centre=False,
        im_clipping="percentile",
        im_n_clipping=None,
        hist_raw=True,
        hist_red=True,
        hist_clipping="sigma",
        hist_n_clipping=None,
        **kwargs
    ):
        """
        Create the report panels from the input HDUs.

        Description
        -----------
        For every input file set in ``self.hdus`` and every extension in
        ``im_ext``, one panel is built consisting of:
        - TextPlot with instrument, extension, PRO CATG, file and RAW1 names.
        - ImagePlot of the reduced field, optionally with standard stars marked.
        - CentralImagePlot of the central region of the reduced field,
          optionally with standard stars marked.
        - CutPlot along the x direction of the 2D image.
        - CutPlot along the y direction of the 2D image.
        - CutPlot along the x direction of the central cutout.
        - CutPlot along the y direction of the central cutout.
        - Histogram of the reduced image and, if available, the raw image.

        Parameters
        ----------
        im_ext : int or str or list, optional
            The extension(s) of the ``"calib_std"`` file to examine. A single
            value applies to every file; a list produces one panel per entry
            for every file.
        raw_ext : int or str or list, optional
            The extension(s) of the ``"raw_im"`` file to examine, matched
            position-by-position with ``im_ext``. A single value is used for
            every entry of ``im_ext``; None (default) reuses ``im_ext``.
        marking : bool, default=False
            If true, standard stars are marked in the full image of the reduced field.
        marking_centre : bool, default=False
            If true, standard stars are marked in the central region of the reduced field.
        im_clipping : str, or None, optional, default="percentile"
            The image clipping mode to use for the image.
        im_n_clipping : optional, default=None
            The argument(s) to be supplied to the clipping function selected
            by ``im_clipping``. None selects ``{"percentile": 95}``.
        hist_raw : bool, default=True
            If true and the input provides ``"raw_im"``, a histogram of the
            raw data is plotted.
        hist_red : bool, default=True
            Currently unused: the reduced data are always histogrammed.
        hist_clipping, hist_n_clipping : optional, default="sigma", None
            As for ``im_clipping`` and ``im_n_clipping``, but applied to the
            histogram display of the image. None for ``hist_n_clipping``
            selects ``{"nsigma": 4}``.
        **kwargs
            Accepted for compatibility with child reports; ignored.

        Returns
        -------
        dict
            Maps each :any:`Panel` to its metadata dict (``report_name``,
            ``report_description``, ``report_tags``, ``input_files``).
        """
        # Build the default clipping arguments fresh on every call, so that
        # no shared default object can leak state between calls.
        if im_n_clipping is None:
            im_n_clipping = {"percentile": 95}
        if hist_n_clipping is None:
            hist_n_clipping = {"nsigma": 4}

        # Every file is looped over all of ``im_ext``. A single extension must
        # therefore become a one-element list (not one copy per file, which
        # would duplicate panels). ``raw_ext`` is indexed in step with it.
        if isinstance(im_ext, (str, int, np.integer)):
            im_ext = [im_ext]
        if raw_ext is None:
            raw_ext = list(im_ext)
        elif isinstance(raw_ext, (str, int, np.integer)):
            raw_ext = [raw_ext] * len(im_ext)

        im_scaling = {"v_clip": im_clipping, "v_clip_kwargs": im_n_clipping}
        hist_scaling = {"v_clip": hist_clipping, "v_clip_kwargs": hist_n_clipping}
        centre_title = "Reduced field (central {} pixels)".format(self.center_size)

        panels = {}
        for filedict in self.hdus:
            calib_std = filedict["calib_std"]
            instru = calib_std["PRIMARY"].header.get("INSTRUME")
            # "raw_im" and "matched_phot" are optional inputs (see parse_sof).
            raw_im = filedict.get("raw_im") if hist_raw else None
            phot = (
                filedict.get("matched_phot") if (marking or marking_centre) else None
            )

            input_files = [calib_std.filename()]
            if raw_im is not None:
                input_files.append(raw_im.filename())
            if phot is not None:
                input_files.append(phot.filename())

            for iext, ext in enumerate(im_ext):
                panel = Panel(4, 3, height_ratios=[1, 4, 4])
                y_label = fetch_kw_or_default(calib_std[ext], "BUNIT", default="ADU")

                # Text Plot
                panel.assign_plot(self._text_plot(calib_std, ext), 0, 0, xext=2)

                # Matched standards of this extension (only if marking was requested)
                scatter = None
                if phot is not None and phot[ext].size != 0:
                    std_xvals, std_yvals = self.get_standards(phot[ext].data)
                    scatter = ScatterPlot()
                    scatter.add_data(
                        (std_xvals, std_yvals),
                        color="g",
                        label="STD",
                        marker="ring",
                        markersize=5,
                    )
                    scatter.legend = False

                # Full reduced image, overlaid with the standards if requested
                fullfield = ImagePlot(title="Reduced field (full)", **im_scaling)
                fullfield.add_data(np.array(calib_std[ext].data, dtype=np.float64))
                combined = CombinedPlot(title="Reduced field (full)")
                combined.add_data(fullfield, z_order=1)
                if marking and scatter is not None:
                    combined.add_data(scatter, z_order=2)
                panel.assign_plot(combined, 0, 1, xext=1, yext=1)

                # Central region of the reduced image, with standards if requested
                centerfield = CentralImagePlot(
                    fullfield,
                    extent=self.center_size,
                    title=centre_title,
                    **im_scaling,
                )
                combined = CombinedPlot(title=centre_title)
                combined.add_data(centerfield, z_order=1)
                if marking_centre and scatter is not None:
                    combined.add_data(scatter, z_order=2)
                panel.assign_plot(combined, 0, 2)

                # Cuts through the full image and through the central region
                panel.assign_plot(
                    self._cut_plot(fullfield, "x", "Central column", "y", y_label, "Full"),
                    1, 1, xext=1, yext=1,
                )
                panel.assign_plot(
                    self._cut_plot(fullfield, "y", "Central row", "x", y_label, "Full"),
                    2, 1, xext=1, yext=1,
                )
                panel.assign_plot(
                    self._cut_plot(
                        centerfield, "x", "Central region (Y)", "y", y_label, "Central"
                    ),
                    1, 2, xext=1, yext=1,
                )
                panel.assign_plot(
                    self._cut_plot(
                        centerfield, "y", "Central region (X)", "x", y_label, "Central"
                    ),
                    2, 2, xext=1, yext=1,
                )

                # Histogram of the reduced data, plus the raw data if available
                raw_data = raw_im[raw_ext[iext]].data if raw_im is not None else None
                histogram = HistogramPlot(
                    master_data=calib_std[ext].data,
                    raw_data=raw_data,
                    bins=self.hist_bins_max,
                    title="Counts histogram",
                    master_label="reduced",
                    **hist_scaling,
                )
                if raw_im is not None:
                    rmin, rmax = histogram.get_vlim()
                    histogram.bins = self._raw_hist_bins(rmin, rmax)
                panel.assign_plot(histogram, 3, 1, xext=1, yext=1)

                # Metadata (each panel gets its own copy of the input-file list)
                panels[panel] = {
                    "report_name": "{}_std_star_ext_{}".format(instru, ext),
                    "report_description": "StandardStar_{}".format(ext),
                    "report_tags": [],
                    "input_files": list(input_files),
                }

        return panels
