# SPDX-License-Identifier: BSD-3-Clause
"""This module provides a renderer class for the Bokeh plotting system."""

# Note: docstrings only need to be included for abstract methods if there
# are special notes to add for this backend, or if the inputs/outputs change
# between backends (e.g., in render_panel). Otherwise, valid docstrings are
# inherited from the parent BaseRenderer class.

from .base import BaseRenderer
from adari_core.plots.panel import Panel
from adari_core.plots.combined import CombinedPlot
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import (
    LinePlot,
    ScatterPlot,
)
from adari_core.plots.collapse import CollapsePlot
from adari_core.plots.histogram import HistogramPlot
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot

import numpy as np
import logging

import bokeh.layouts
from bokeh.layouts import Spacer
from bokeh.models import LinearColorMapper, Label
import bokeh.plotting
from bokeh.plotting import figure
from bokeh.models import ColorBar, Range1d

logger = logging.getLogger(__name__)


class BokehRenderer(BaseRenderer):
    """
    Render Panels using the bokeh plotting package.
    """

    # Maps the generic marker names used by the plot objects onto keyword
    # arguments for bokeh's ``figure.scatter``. ``hollow`` is not a bokeh
    # argument: it is an internal flag consumed by
    # ``_compute_bokeh_marker_args`` (hollow markers get no fill colour).
    MARKER_RENDER_COMMANDS = {
        "circle": {"marker": "circle"},  # DEFAULT
        "ring": {"marker": "circle", "hollow": True},
        "point": {"marker": "dot"},
        "plus": {"marker": "cross"},
        "x": {"marker": "x"},
        "square": {"marker": "square", "hollow": True},
        "block": {"marker": "square"},
    }

    def _compute_bokeh_marker_args(self, marker: str, color: str):
        """
        Compute the necessary kwargs for bokeh scatter plotting.

        Parameters
        ----------
        marker : str
            The marker to be used; must be a key of
            ``MARKER_RENDER_COMMANDS``.
        color : str
            The color to be applied to the marker.

        Returns
        -------
        kw : dict
            The dictionary of keyword arguments that forms this marker:
            ``marker``, ``line_color`` and ``fill_color`` (``None`` for
            hollow markers).

        Raises
        ------
        ValueError
            If ``marker`` is not a supported marker name.
        """
        try:
            # Copy, so the class-level table is never modified.
            kw = dict(self.MARKER_RENDER_COMMANDS[marker])
        except KeyError:
            raise ValueError(f"{marker} is not a supported marker") from None

        # Strip the internal flag before it can reach bokeh.
        hollow = kw.pop("hollow", False)
        kw["line_color"] = color
        kw["fill_color"] = None if hollow else color
        return kw

    def __init__(self, *args, **kwargs):
        # NOTE(review): BaseRenderer.__init__ is not called here and
        # *args/**kwargs are ignored; confirm this is intended.
        self._backend = "bokeh"
        # Dispatch table: plot class name -> bound render method.
        self._plot_render_functions = {
            ImagePlot.__name__: self.plot_image,
            CentralImagePlot.__name__: self.plot_image,
            LinePlot.__name__: self.plot_line,
            HistogramPlot.__name__: self.plot_histogram,
            ScatterPlot.__name__: self.plot_scatter,
            CollapsePlot.__name__: self.plot_collapse,
            CombinedPlot.__name__: self.plot_combined,
            CutPlot.__name__: self.plot_cut,
            TextPlot.__name__: self.plot_text,
        }

    def render_panels(self, panels_dict):
        """
        Render every panel in ``panels_dict``.

        Parameters
        ----------
        panels_dict : dict
            Maps each Panel to a metadata dict. The metadata dict must provide
            the keys ``"report_name"``, ``"report_description"`` and
            ``"report_tags"``.

        Returns
        -------
        list
            One rendered layout per panel (see :meth:`render_panel`), in
            the iteration order of ``panels_dict``.
        """
        return [
            self.render_panel(
                pnl,
                meta["report_name"],
                meta["report_description"],
                meta["report_tags"],
            )
            for pnl, meta in panels_dict.items()
        ]

    def render_panel(self, pnl, pnl_name, pnl_description, pnl_tags):
        """
        Render the given Panel object as a bokeh grid layout.

        Parameters
        ----------
        pnl : :any:`adari_core.plots.panel.Panel`
            The Panel object to be rendered.
        pnl_name, pnl_description, pnl_tags
            Report metadata passed through by :meth:`render_panels`; not
            currently used by this backend.

        Returns
        -------
        bokeh layout
            The result of ``bokeh.layouts.gridplot`` over the rendered
            plots. Empty panel cells are filled with ``Spacer`` objects.
        """
        self.check_plot_obj_class(pnl, Panel)

        grid = []
        for row in pnl._plots:
            rowlist = []
            for plotobj in row:
                if plotobj is None:
                    # Keep the grid position occupied with a blank cell.
                    rowlist.append(Spacer())
                    continue
                # ``_data`` is not always an array (line/scatter plots iterate
                # ``.data.items()``), so don't assume it has a ``shape``.
                logger.debug(getattr(plotobj._data, "shape", None))
                logger.debug(plotobj)
                plot_func = self.get_render_function(plotobj)
                # the results must not be None
                rowlist.append(plot_func(plotobj))
            grid.append(rowlist)

        # TODO: axis linkage (pnl._plots_linked_in_x) is not yet applied to
        # the bokeh figures; a linked plot may appear several times in the
        # grid, so every occurrence would need to share the same x_range.
        return bokeh.layouts.gridplot(children=grid)

    def plot_image(self, plotobj):
        """
        Render a representation of a ImagePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.images.ImagePlot`
            The ImagePlot object to be rendered.

        Returns
        -------
        bokeh.plotting.figure
            A Bokeh figure object containing the rendered Plot.
        """
        self.check_plot_obj_class(plotobj, ImagePlot)
        # https://docs.bokeh.org/en/2.4.0/docs/reference/palettes.html#bokeh-palettes
        v_min, v_max = plotobj.get_vlim()
        cmapper = LinearColorMapper(low=v_min, high=v_max, palette=plotobj.colormap)
        plot = figure(
            title=plotobj.title,
            x_axis_label=plotobj.x_label,
            y_axis_label=plotobj.y_label,
        )

        data = plotobj.data
        # Image arrays are indexed (row, column), i.e. (y, x): the drawn width
        # comes from the column count and the height from the row count.
        n_rows, n_cols = data.shape[0], data.shape[1]
        plot.image(
            image=[data],
            x=0,
            y=0,
            dw=n_cols,
            dh=n_rows,
            color_mapper=cmapper,
        )
        if plotobj.colorbar_visible:
            color_bar = ColorBar(color_mapper=cmapper, label_standoff=12)
            plot.add_layout(color_bar, "right")
        return plot

    def plot_line(self, plotobj):
        """
        Render a representation of a LinePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.LinePlot`
            The LinePlot object to be rendered.

        Returns
        -------
        bokeh.plotting.figure
            A Bokeh figure object containing the rendered Plot, with one
            line (and legend entry) per data set.
        """
        self.check_plot_obj_class(plotobj, LinePlot)

        plot = figure(
            title=plotobj.title,
            x_axis_label=plotobj.x_label,
            y_axis_label=plotobj.y_label,
        )
        for label, datadict in plotobj.data.items():
            # TODO: Set the linestyle
            plot.line(
                x=datadict["data"][0],
                y=datadict["data"][1],
                legend_label=label,
            )
        return plot

    def plot_scatter(self, plotobj):
        """
        Render a representation of a ScatterPlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.ScatterPlot`
            The ScatterPlot object to be rendered.

        Returns
        -------
        bokeh.plotting.figure
            A Bokeh figure object with one scatter glyph (and legend entry)
            per data set.
        """
        self.check_plot_obj_class(plotobj, ScatterPlot)
        plot = figure(
            title=plotobj.title,
            x_axis_label=plotobj.x_label,
            y_axis_label=plotobj.y_label,
        )
        for label, datadict in plotobj.data.items():
            # Work out the marker and markersize for these data
            # (helper presumably inherited from BaseRenderer).
            markersize, marker = self._compute_marker_properties(plotobj, datadict)
            # NOTE(review): colour is read per data set with a black
            # fallback; confirm the key name against ScatterPlot.
            marker_kw = self._compute_bokeh_marker_args(
                marker, datadict.get("color", "black")
            )
            plot.scatter(
                x=datadict["data"][0],
                y=datadict["data"][1],
                legend_label=label,
                # NOTE(review): the factor of 4 rescales the generic marker
                # size to bokeh screen units; confirm the intended scale.
                size=markersize * 4,
                **marker_kw,
            )
        return plot

    def plot_histogram(self, plotobj):
        """Placeholder: histograms are not yet rendered; returns an empty Spacer."""
        return Spacer()

    def plot_collapse(self, plotobj):
        """Placeholder: collapse plots are not yet rendered; returns an empty Spacer."""
        return Spacer()

    def plot_combined(self, plotobj):
        """Placeholder: combined plots are not yet rendered; returns an empty Spacer."""
        return Spacer()

    def plot_cut(self, plotobj):
        """
        Render a representation of a CutPlot object as step lines.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.cut.CutPlot`
            The CutPlot object to be rendered.

        Returns
        -------
        bokeh.plotting.figure
            A Bokeh figure with one step line per cut data set.

        Raises
        ------
        ValueError
            If an image cannot be cut at the requested position.
        """
        self.check_plot_obj_class(plotobj, CutPlot)
        plot = figure(
            title=plotobj.title,
            x_axis_label=plotobj.x_label,
            y_axis_label=plotobj.y_label,
        )

        for label, datadict in plotobj.data.items():
            image = datadict["data"]
            cut_pos = datadict["cut_pos"]
            try:
                cut_values = image.get_axis_cut(
                    plotobj._cut_ax,
                    round(image.get_ind_coord(cut_pos, plotobj._cut_ax)),
                )
            except IndexError as e:
                raise ValueError(
                    "Unable to cut image {} at pos {} ({})".format(
                        plotobj, cut_pos, str(e)
                    )
                ) from e
            # A "y" cut is paired with the image's x origin/step (the cut
            # values presumably run along x), and vice versa.
            if plotobj.cut_ax == "y":
                origin, step = image.x_origin, image.x_step
            else:
                origin, step = image.y_origin, image.y_step
            # Exactly one abscissa value per cut sample, spaced by ``step``.
            x_values = origin + step * np.arange(len(cut_values))
            # TODO: honour a per-dataset colour (e.g. datadict['color']).
            plot.step(
                x=x_values,
                y=cut_values,
                name=label,
                mode="center",
                color="black",
            )
        return plot

    def plot_text(self, plotobj):
        """
        Render a TextPlot as a figure of text labels with hidden axes.

        Entries of ``plotobj.data`` are laid out in ``plotobj.columns``
        columns, filling each column top-to-bottom before moving right.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.text.TextPlot`
            The TextPlot object to be rendered.

        Returns
        -------
        bokeh.plotting.figure
            A Bokeh figure containing one ``Label`` per text entry.
        """
        # It may make more sense to put the TextPlot into the html template;
        # this is not as nice as the matplotlib approach.
        self.check_plot_obj_class(plotobj, TextPlot)
        # Line spacing in y data units. Kept local so that rendering does not
        # overwrite the caller's TextPlot configuration.
        v_space = 0.05
        n_columns = float(plotobj.columns)
        # Number of entries per column (may be fractional).
        column_div = float(len(plotobj.data)) / n_columns

        plot = figure()
        plot.axis.visible = False
        plot.xgrid.visible = False
        plot.ygrid.visible = False

        yvals = []
        for i, s in enumerate(plotobj.data):
            col_this = int(float(i) / column_div)
            text_x = col_this * (1.0 / n_columns)
            text_y = 1.0 - v_space * int(i % column_div)
            yvals.append(text_y)
            plot.add_layout(
                Label(
                    x=text_x,
                    y=text_y,
                    text=s,
                    background_fill_color="white",
                    background_fill_alpha=1.0,
                )
            )

        # ``default`` keeps an empty TextPlot from raising in min().
        plot.y_range = Range1d(min(yvals, default=1.0), 1 + v_space)
        # NOTE(review): text x positions lie in [0, 1) but the x range spans
        # 0-10, so multiple columns sit close together; confirm intent.
        plot.x_range = Range1d(0, 10)

        return plot
