# SPDX-License-Identifier: BSD-3-Clause
"""This module renders an ADARI Core panel using matplotlib."""

# Note - docstrings only need to be included for abstract functions if there
# are special notes to add for this backend, or the input/output changes
# between backends (e.g., in render_panel). Otherwise, valid docstrings
# will be inherited from the parent BaseRenderer class.

from ..plots.combined import CombinedPlot
from .base import BaseRenderer
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import LinePlot, ScatterPlot
from adari_core.plots.collapse import CollapsePlot
from adari_core.plots.histogram import HistogramPlot, HISTOGRAM_DEFAULT_BINS
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.plots.colormaps import ColorbarMixin
from adari_core.plots.colormaps import __colormaps__ as cmaps

import numpy as np
import os
import json


# This encoder is necessary to convert numpy datatypes into values that can be
# converted to JSON format using json.dumps
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy scalars and arrays.

    ``json.dumps`` cannot serialise NumPy types natively; this encoder maps
    them onto the equivalent built-in Python values before encoding.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Parameters
        ----------
        obj : object
            An object the base encoder could not serialise by itself.

        Returns
        -------
        int, float, bool or list
            The native Python equivalent of a NumPy integer, floating,
            boolean or ``ndarray`` value.

        Raises
        ------
        TypeError
            If ``obj`` is neither one of the NumPy types handled here nor
            serialisable by :class:`json.JSONEncoder`. The original error is
            chained as ``__cause__``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of bool (nor of np.integer), so without
        # this branch json.dumps raises TypeError on NumPy boolean flags.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        try:
            return super().default(obj)
        except TypeError as e:
            raise TypeError(
                f"Unable to super encode {obj} "
                f"of type {obj.__class__} "
                f"(original error: {str(e)})"
            ) from e


class JSONRenderer(BaseRenderer):
    """
    Render Panels using a dummy JSON backend for testing purposes.

    Rather than drawing anything, each supported Plot type is summarised as a
    small JSON-compatible structure (limits, labels, point counts, ...).
    :meth:`render_panel` writes one ``<pnl_name>.json`` file per Panel into
    the current working directory and records its metadata in
    ``self.artifacts``.

    The ``ax`` arguments accepted by the ``plot_*`` and helper methods only
    exist for interface parity with the other renderers; this backend passes
    ``None`` and never reads them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._backend = "json"
        # Maps Plot class names to the method that summarises that Plot type.
        self._plot_render_functions = {
            ImagePlot.__name__: self.plot_image,
            CentralImagePlot.__name__: self.plot_image,
            LinePlot.__name__: self.plot_line,
            HistogramPlot.__name__: self.plot_histogram,
            ScatterPlot.__name__: self.plot_scatter,
            CollapsePlot.__name__: self.plot_collapse,
            CombinedPlot.__name__: self.plot_combined,
            CutPlot.__name__: self.plot_cut,
            TextPlot.__name__: self.plot_text,
            ColorbarMixin.__name__: self.plot_colorbarMixin,
        }
        self.interactive = False
        self.artifacts = []

    def _aspect_fnc(self, ax, plotobj, im):
        """
        Summarise the aspect ratio and colorbar visibility of a Plot.

        Parameters
        ----------
        ax : None
            Unused; kept for interface parity with the other renderers.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.
        im : object or None
            Stand-in for the artist a colorbar would be attached to. Pass
            ``None`` for plots that never carry a colorbar.

        Returns
        -------
        dict
            ``{"aspect": ..., "plot_colorbar": bool}``. ``plot_colorbar`` is
            True only when ``im`` is not None and ``plotobj.colorbar_visible``
            is truthy.
        """
        data = {
            "aspect": plotobj.aspect,
            # colorbar_visible is only consulted when an image is present.
            "plot_colorbar": bool(im is not None and plotobj.colorbar_visible),
        }
        return json.loads(json.dumps(data, cls=NpEncoder))

    def _set_axes(self, ax, plotobj):
        """
        Summarise the axis labels and scales of a Plot object.

        Parameters
        ----------
        ax : None
            Unused; kept for interface parity with the other renderers.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.

        Returns
        -------
        dict
            Keys ``x_label``, ``y_label``, ``x_scale`` and ``y_scale``, with
            NumPy values converted to native Python types.
        """
        data = {}
        data["x_label"] = plotobj.x_label
        data["y_label"] = plotobj.y_label
        data["x_scale"] = plotobj.x_scale
        data["y_scale"] = plotobj.y_scale

        return json.loads(json.dumps(data, cls=NpEncoder))

    def _set_ticks(self, ax, plotobj):
        """
        Summarise the tick settings of a Plot object.

        Parameters
        ----------
        ax : None
            Unused; kept for interface parity with the other renderers.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.

        Returns
        -------
        dict
            Keys ``tick_visibility``, ``x_major_ticks``, ``x_minor_ticks``,
            ``y_major_ticks`` and ``y_minor_ticks``. Callers only include
            this dict in their output when at least one value is truthy.
        """
        data = {}
        data["tick_visibility"] = plotobj.tick_visibility
        data["x_major_ticks"] = plotobj.x_major_ticks
        data["x_minor_ticks"] = plotobj.x_minor_ticks
        data["y_major_ticks"] = plotobj.y_major_ticks
        data["y_minor_ticks"] = plotobj.y_minor_ticks
        return json.loads(json.dumps(data, cls=NpEncoder))

    def _get_rectangles_from_coords(self, coords):
        """
        Generate a list of rectangular areas from a list of coords.

        Used to select ranges for multi-cell plots. Coordinates are grouped
        greedily: starting from the smallest remaining coordinate, a
        rectangle is grown right along its row while the columns stay
        consecutive, then down row by row while every cell of its column
        span is present.

        Parameters
        ----------
        coords : iterable of two-tuples
            Grid coordinates to group, as ``[(col, row), ...]``. Lists are
            accepted as well as tuples.

        Returns
        -------
        list of two-element lists of two-element lists
            One entry per rectangle, laid out as
            ``[[first_col, last_row], [first_row, last_col]]`` (tuples
            become lists through the JSON round-trip).
        """
        # Normalise to tuples so the membership tests below also work for
        # list-valued coordinates, and sort so growth only looks right/down.
        sorted_coords = sorted(tuple(c) for c in coords)
        # Bound the vertical search by the lowest row across *all*
        # coordinates; the last input coordinate is not necessarily there.
        max_row = max((c[1] for c in sorted_coords), default=-1)
        rects = []
        while sorted_coords:
            # MCW: cols are index 0, rows are index 1.
            first_col, first_row = sorted_coords.pop(0)
            last_col = first_col

            # Expand horizontally along the top row until the columns are no
            # longer consecutive. Sorting guarantees these come in column order.
            same_row_positions = [c for c in sorted_coords if c[1] == first_row]
            for pos in same_row_positions:
                if pos[0] != last_col + 1:
                    break  # This is the furthest right this rectangle can go
                sorted_coords.remove(pos)
                last_col = pos[0]

            # Expand vertically while the full column span of the rectangle
            # is present in the next row.
            last_row = first_row
            for row_pos in range(first_row + 1, max_row + 1):
                required_seq = {
                    (col_pos, row_pos) for col_pos in range(first_col, last_col + 1)
                }
                if not all(item in sorted_coords for item in required_seq):
                    break  # The rectangle cannot be expanded any further
                last_row = row_pos
                for pos in required_seq:
                    sorted_coords.remove(pos)

            # NOTE(review): this pairing mixes column and row bounds; it is
            # kept unchanged in case consumers depend on the exact layout.
            rects.append(((first_col, last_row), (first_row, last_col)))

        return json.loads(json.dumps(rects, cls=NpEncoder))

    @staticmethod
    def check_datadict(datadict, x_check=True, y_check=True, x_pos=False):
        """Check if the input dictionary contains data and summarise it.

        Parameters
        ----------
        datadict: dict
            Input dictionary that contains the plot information; must have
            ``"data"`` and ``"color"`` keys.
        x_check: bool
            For 2-D data, if True, compute the plot limits along the x axis.
        y_check: bool
            For 2-D data, if True, compute the plot limits along the y axis.
        x_pos: bool
            If True, the x range is based on the array size rather
            than the values.

        Returns
        -------
        metadata: dict
            Dictionary containing the plot metadata: ``color`` plus, as
            applicable, ``xmin``/``xmax``/``ymin``/``ymax`` (``None`` when
            the corresponding data is empty) and ``npoints``.

        Raises
        ------
        ValueError
            If ``datadict`` has no ``"data"`` entry.
        """
        if "data" not in datadict:
            raise ValueError("Input Plot does not contain data")
        # Initialise the dictionary containing the plot metadata
        metadata = {"color": datadict["color"]}
        data = np.asarray(datadict["data"])
        if data.ndim == 1:
            # 1-D data: values are y, x is the sample index.
            metadata["xmin"] = 0
            metadata["xmax"] = data.size
            if data.size > 0:
                metadata["ymin"] = data.min()
                metadata["ymax"] = data.max()
            else:
                # Report missing limits explicitly, like the 2-D branch does.
                metadata["ymin"] = None
                metadata["ymax"] = None
        elif data.ndim == 2:
            # 2-D data: row 0 holds x values, row 1 holds y values.
            if x_check:
                x_data = data[0]
                if x_pos:
                    metadata["xmin"] = 0
                    metadata["xmax"] = x_data.size
                elif x_data.size > 0:
                    metadata["xmin"] = x_data.min()
                    metadata["xmax"] = x_data.max()
                else:
                    metadata["xmin"] = None
                    metadata["xmax"] = None
            if y_check:
                y_data = data[1]
                if y_data.size > 0:
                    metadata["ymin"] = y_data.min()
                    metadata["ymax"] = y_data.max()
                else:
                    metadata["ymin"] = None
                    metadata["ymax"] = None
            # Number of (x, y) pairs; note this is a float (true division).
            metadata["npoints"] = data.size / 2
        else:
            if x_check:
                metadata["xmin"] = None
                metadata["xmax"] = None
            if y_check:
                metadata["ymin"] = None
                metadata["ymax"] = None
            metadata["npoints"] = 0
        return metadata

    def render_panels(self, panels_dict):
        """Render every Panel in ``panels_dict``.

        ``self.artifacts`` is reset first, so afterwards it only describes
        the files written by this call.

        Parameters
        ----------
        panels_dict : dict
            Maps each Panel to its metadata dict. ``report_name``,
            ``report_description`` and ``report_tags`` are required;
            ``report_prodcatg`` (default ``""``) and ``input_files``
            (default ``None``) are optional.

        Returns
        -------
        list
            The return value of :meth:`render_panel` for each Panel.
        """
        results = []
        self.artifacts = []

        for p, meta in panels_dict.items():
            res = self.render_panel(
                p,
                meta["report_name"],
                meta["report_description"],
                meta["report_tags"],
                meta.get("report_prodcatg", ""),
                meta.get("input_files"),
            )
            results.append(res)

        return results

    def render_panel(
        self,
        pnl,
        pnl_name,
        pnl_description,
        pnl_tags,
        pnl_prodcatg,
        pnl_input_files=None,
    ):
        """Render the given Panel to ``<cwd>/<pnl_name>.json``.

        Before rendering, cross-plot annotations are added to the Plot
        objects themselves: each CutPlot marks its cut position on its
        source image, and each CentralImagePlot outlines its cutout window
        on its reference image. These Plot objects are mutated in place.

        Parameters
        ----------
        pnl : :any:`adari_core.plots.panel.Panel`
            The Panel to be rendered.
        pnl_name : str
            The Panel name; also used as the output file stem.
        pnl_description: str
            The Panel description that will be added to the metadata
        pnl_tags : iterable of str
            A list of tags that are added to the metadata
        pnl_prodcatg : str
            Product category as defined by ESO Phase III
        pnl_input_files : optional
            If not None, recorded as ``input_files`` in the artifact metadata.

        Returns
        -------
        None
            The JSON summary is written to disk and its metadata is appended
            to ``self.artifacts``.
        """
        data = {}
        # TODO: Return a JSON with some numbers from a "rendered" JSON panel

        self.check_plot_obj_class(pnl, Panel)
        data["metadata"] = {
            "title": pnl.title,
            "task_name": pnl_name,
            "description": pnl_description,
            "tags": pnl_tags,
            "width_ratios": pnl.width_ratios,
            "height_ratios": pnl.height_ratios,
            "x_stretch": pnl.x_stretch,
            "y_stretch": pnl.y_stretch,
            "right_subplot": pnl.right_subplot,
            "grid_x": pnl._x,
            "grid_y": pnl._y,
        }

        # Pre-plotting loop: make any connections required between Plot
        # objects, rather than axis-specific connections.
        for loc in pnl.get_assigned_locations():
            toPlot = pnl._plots[loc]
            if isinstance(toPlot, CutPlot):
                # Mark the cut location on the source ImagePlot.
                for label, data_array in toPlot.data.items():
                    if toPlot.cut_ax == "x":
                        data_array["data"].add_vert_line(
                            data_array["cut_pos"], label=None
                        )
                    elif toPlot.cut_ax == "y":
                        data_array["data"].add_horiz_line(
                            data_array["cut_pos"], label=None
                        )
            elif isinstance(toPlot, CentralImagePlot):
                # Outline the cutout window on the parent ImagePlot, using
                # the lower-left corner plus width and height of the window.
                x_min, x_max, y_min, y_max = toPlot.get_ref_image_window_coords()
                toPlot.ref_image.add_rect(
                    (x_min, y_min),  # Ref point
                    x_max - x_min,  # Width
                    y_max - y_min,  # Height
                )

        data["plots"] = []
        # Summarise each plot and tag it with its grid position and title.
        for loc in pnl.get_assigned_locations():
            toPlot = pnl._plots[loc]
            ax = None
            pltfig = self.get_render_function(toPlot)
            json_res = json.loads(pltfig(toPlot, ax))
            # Every plot_* method keys its output by its own method name.
            plot_type = pltfig.__name__
            json_res[plot_type]["x"] = loc[1]
            json_res[plot_type]["y"] = loc[0]
            json_res[plot_type]["title"] = toPlot.title
            data["plots"].append(json_res)

        # Save the JSON to disk.
        filename = os.path.join(os.getcwd(), pnl_name + ".json")
        mime_type = "application/json"
        with open(filename, "w", encoding="utf-8") as outfile:
            json.dump({"panel": data}, outfile, indent=4, cls=NpEncoder)
        artifact_metadata = {
            "filename": filename,
            "mime_type": mime_type,
            "description": pnl_description,
            "tags": pnl_tags,
            "prodcatg": pnl_prodcatg,
        }
        if pnl_input_files is not None:
            artifact_metadata["input_files"] = pnl_input_files

        self.artifacts.append(artifact_metadata)

    def plot_image(self, plotobj, ax):
        """
        Render a representation of an ImagePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.images.ImagePlot`
            The ImagePlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_image": {...}}``.
        """
        data = {}
        data["v_min"], data["v_max"] = plotobj.get_vlim()
        data["v_clip"] = plotobj.v_clip
        data["v_clip_kwargs"] = plotobj.v_clip_kwargs
        # Extent of the image in data coordinates (pixel edges, hence 0.5).
        data["plot_left"] = plotobj.x_origin - plotobj.x_step // 2 - 0.5
        data["plot_right"] = data["plot_left"] + plotobj.data.shape[1] * plotobj.x_step
        data["plot_bottom"] = plotobj.y_origin - plotobj.y_step // 2 - 0.5
        data["plot_top"] = data["plot_bottom"] + plotobj.data.shape[0] * plotobj.y_step
        data["origin"] = "lower"
        data["vert_lines"] = [
            {"x": xpos, "label": label, "color": "blue"}
            for xpos, label in plotobj.vert_lines.items()
        ]
        data["horiz_lines"] = [
            {"y": ypos, "label": label, "color": "blue"}
            for ypos, label in plotobj.horiz_lines.items()
        ]

        # Images can carry a colorbar, so pass a non-None placeholder.
        data["aspect"] = self._aspect_fnc(ax, plotobj, "dummy_image")
        # Each rect key is (lower-left corner, width, height).
        data["rectangles"] = [
            {
                "xy": rect[0],
                "width": rect[1],
                "height": rect[2],
                "ec": "green",
                "fc": "none",
                "ls": "-",
                "lw": 1.5,
            }
            for rect, _label in plotobj.rects.items()
        ]

        data["colormap"] = self._get_colormap_name(plotobj.colormap)

        data["axes"] = self._set_axes(ax, plotobj)
        ticks = self._set_ticks(ax, plotobj)
        if any(ticks.values()):
            data["ticks"] = ticks
        return json.dumps({"plot_image": data}, cls=NpEncoder)

    def plot_line(self, plotobj, ax):
        """
        Render a representation of a LinePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.LinePlot`
            The LinePlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_line": {...}}``.
        """
        self.check_plot_obj_class(plotobj, LinePlot)
        data = {}
        data["data"] = []
        for label, datadict in plotobj.data.items():
            plot_meta = self.check_datadict(datadict=datadict)
            plot_meta["label"] = label
            data["data"].append(plot_meta)
        data["axes"] = self._set_axes(ax, plotobj)
        ticks = self._set_ticks(ax, plotobj)
        if any(ticks.values()):
            data["ticks"] = ticks
        data["legend"] = bool(plotobj.legend)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)

        return json.dumps({"plot_line": data}, cls=NpEncoder)

    def plot_scatter(self, plotobj, ax):
        """
        Render a representation of a ScatterPlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.ScatterPlot`
            The ScatterPlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_scatter": {...}}``.
        """
        data = {}
        self.check_plot_obj_class(plotobj, ScatterPlot)
        data["data"] = []
        for label, datadict in plotobj.data.items():
            markersize, marker = self._compute_marker_properties(plotobj, datadict)
            plot_meta = self.check_datadict(datadict=datadict)
            plot_meta["label"] = label
            if plotobj.markersize is not None:
                # Marker area, i.e. the squared marker size.
                plot_meta["s"] = plotobj.markersize**2
            plot_meta["marker"] = marker
            plot_meta["markersize"] = markersize
            # Record which optional decorations are present in the dataset.
            if "xerror" in datadict:
                plot_meta["x_errorbars"] = True
            if "yerror" in datadict:
                plot_meta["y_errorbars"] = True
            if "vline" in datadict:
                plot_meta["vlines"] = True
            if "annotate" in datadict:
                plot_meta["annotate"] = True
            data["data"].append(plot_meta)
        data["axes"] = self._set_axes(ax, plotobj)
        ticks = self._set_ticks(ax, plotobj)
        if any(ticks.values()):
            data["ticks"] = ticks
        data["legend"] = bool(plotobj.legend)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)
        return json.dumps({"plot_scatter": data}, cls=NpEncoder)

    def plot_histogram(self, plotobj, ax):
        """
        Render a representation of a HistogramPlot object.

        Explicit ``v_min``/``v_max`` take precedence for the histogram
        range; only if either is unset are the clipping settings reported
        and the range taken from ``get_vlim()``.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.histogram.HistogramPlot`
            The HistogramPlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_histogram": {...}}``.
        """
        data = {}
        self.check_plot_obj_class(plotobj, HistogramPlot)
        rmin = plotobj.v_min
        rmax = plotobj.v_max
        data["v_clip"] = None
        data["v_clip_kwargs"] = None
        if rmin is None or rmax is None:
            data["v_clip"] = plotobj.v_clip
            data["v_clip_kwargs"] = plotobj.v_clip_kwargs
            rmin, rmax = plotobj.get_vlim()
        data["v_min"], data["v_max"] = plotobj.get_vlim()

        bins_set = plotobj.bins if plotobj.bins is not None else HISTOGRAM_DEFAULT_BINS

        data["data"] = [
            {
                "label": label,
                "bins": bins_set,
                "rmin": rmin,
                "rmax": rmax,
                "npoints": len(datadict["data"]),
                "color": datadict["color"],
                "histtype": "step",
                "log": plotobj.is_logscale(),
            }
            for label, datadict in plotobj.data.items()
        ]
        data["axes"] = self._set_axes(ax, plotobj)
        data["legend"] = bool(plotobj.legend)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)
        return json.dumps({"plot_histogram": data}, cls=NpEncoder)

    def plot_collapse(self, plotobj, ax):
        """
        Render a representation of a CollapsePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.collapse.CollapsePlot`
            The CollapsePlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_collapse": {...}}``.
        """
        data = {}
        self.check_plot_obj_class(plotobj, CollapsePlot)
        data["data"] = []
        for label, datadict in plotobj.data.items():
            # x range comes from the array size, not the x values.
            plot_meta = self.check_datadict(datadict=datadict, x_pos=True)
            plot_meta["label"] = label
            plot_meta["where"] = "post"
            data["data"].append(plot_meta)
        data["axes"] = self._set_axes(ax, plotobj)
        data["legend"] = bool(plotobj.legend)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)
        return json.dumps({"plot_collapse": data}, cls=NpEncoder)

    def plot_combined(self, plotobj, ax):
        """
        Render a representation of a CombinedPlot object.

        Each constituent plot is rendered with its own ``plot_*`` method, in
        increasing order of the value stored against it (its z-order).

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.combined.CombinedPlot`
            The CombinedPlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_combined": {...}}``.
        """
        self.check_plot_obj_class(plotobj, CombinedPlot)
        data = {}

        sorted_plot_objs = sorted(plotobj.data.items(), key=lambda x: x[1])
        data["data"] = []
        for plot, _zorder in sorted_plot_objs:
            # Assumes str(plot) yields the class name used as dispatch key.
            json_res = self._plot_render_functions[plot.__str__()](plot, ax)
            data["data"].append(json.loads(json_res))
        data["axes"] = self._set_axes(ax, plotobj)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)

        return json.dumps({"plot_combined": data}, cls=NpEncoder)

    def plot_cut(self, plotobj, ax):
        """
        Render a representation of a CutPlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.cut.CutPlot`
            The CutPlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_cut": {...}}``.

        Raises
        ------
        ValueError
            If an image cannot be cut at the requested position (the
            underlying IndexError is chained as ``__cause__``).
        """
        data = {}
        self.check_plot_obj_class(plotobj, CutPlot)

        data["data"] = []
        for label, datadict in plotobj.data.items():
            image = datadict["data"]
            try:
                cut_values = image.get_axis_cut(
                    plotobj._cut_ax,
                    round(image.get_ind_coord(datadict["cut_pos"], plotobj._cut_ax)),
                )
            except IndexError as e:
                raise ValueError(
                    "Unable to cut image {} at pos {} ({})".format(
                        plotobj, datadict["cut_pos"], str(e)
                    )
                ) from e
            # For a "y" cut the image's x origin/step describe the abscissa
            # of the cut, and vice versa.
            if plotobj.cut_ax == "y":
                origin, step = image.x_origin, image.x_step
            else:
                origin, step = image.y_origin, image.y_step
            data["data"].append(
                {
                    "label": label,
                    "npoints": len(cut_values),
                    "xmin": origin,
                    # NOTE(review): xmax ignores step; confirm whether
                    # origin + len(cut_values) * step is intended.
                    "xmax": origin + len(cut_values),
                    "ymin": min(cut_values),
                    "ymax": max(cut_values),
                    "step": step,
                    "cut_pos": datadict["cut_pos"],
                    "cut_ax": plotobj._cut_ax,
                    "color": datadict["color"],
                }
            )
        data["axes"] = self._set_axes(ax, plotobj)
        data["legend"] = bool(plotobj.legend)
        data["aspect"] = self._aspect_fnc(ax, plotobj, None)

        return json.dumps({"plot_cut": data}, cls=NpEncoder)

    def plot_text(self, plotobj, ax):
        """
        Render a representation of a TextPlot object.

        Text lines are spread over ``plotobj.columns`` columns in axes
        fraction coordinates, each line ``v_space`` below the previous one.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.text.TextPlot`
            The TextPlot object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_text": {...}}``.
        """
        data = {}
        self.check_plot_obj_class(plotobj, TextPlot)
        # Number of text lines per column (may be fractional).
        column_div = float(len(plotobj.data)) / float(plotobj.columns)

        data["data"] = []
        for i, s in enumerate(plotobj.data):
            col_this = int(float(i) / column_div)
            text_x = col_this * (1.0 / float(plotobj.columns))
            text_y = 1.0 - plotobj.v_space * int(i % (column_div))
            data["data"].append(
                {
                    "x": text_x,
                    "y": text_y,
                    "text": s,
                    # Not meaningful if the TextPlot has a single column.
                    "column": col_this,
                    "v_space": plotobj.v_space,
                }
            )

        return json.dumps({"plot_text": data}, cls=NpEncoder)

    def plot_colorbarMixin(self, plotobj, ax):
        """
        Render a representation of a ColorbarMixin object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.colormaps.ColorbarMixin`
            The ColorbarMixin object to be rendered.
        ax : None
            Unused; kept for interface parity with the other renderers.

        Returns
        -------
        str
            JSON string of the form ``{"plot_colorbarMixin": {...}}``.
        """
        data = {}
        self.check_plot_obj_class(plotobj, ColorbarMixin)

        data["v_min"], data["v_max"] = plotobj.get_vlim()
        data["v_clip"] = plotobj.v_clip
        data["v_clip_kwargs"] = plotobj.v_clip_kwargs
        data["colormap"] = plotobj.get_colormap()
        data["colorbar_visible"] = plotobj.colorbar_visible

        return json.dumps({"plot_colorbarMixin": data}, cls=NpEncoder)

    def _get_colormap_name(self, colormap):
        """
        Use the colormap dictionary to obtain the name of a colormap.

        Returns the name of the registered colormap whose value equals
        ``tuple(colormap)``. If there is no match (e.g. a user-defined map
        given as a list of hex strings) or ``colormap`` is not iterable,
        ``colormap`` itself is returned unchanged.
        """
        try:
            return list(cmaps.keys())[list(cmaps.values()).index(tuple(colormap))]
        except (TypeError, ValueError):
            # TypeError: colormap is not iterable; ValueError: no match.
            return colormap
