# SPDX-License-Identifier: BSD-3-Clause
"""This module renders an ADARI Core panel using matplotlib."""

# Note - docstrings only need to be included for abstract functions if there
# are special notes to add for this backend, or the input/output changes
# between backends (e.g., in render_panel). Otherwise, valid docstrings
# will be inherited from the parent BaseRenderer class.

from ..plots.combined import CombinedPlot
from .base import BaseRenderer
from adari_core.plots.panel import Panel
from adari_core.plots.images import ImagePlot, CentralImagePlot
from adari_core.plots.points import (
    LinePlot,
    ScatterPlot,
)
from adari_core.plots.collapse import CollapsePlot
from adari_core.plots.histogram import HistogramPlot, HISTOGRAM_DEFAULT_BINS
from adari_core.plots.cut import CutPlot
from adari_core.plots.text import TextPlot
from adari_core.utils.utils import is_iterable

import inspect
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
import matplotlib.patches as patches
import matplotlib as mpl
from matplotlib.backends.backend_pdf import PdfPages
import os

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MultipleLocator, AutoMinorLocator, FixedLocator


class MatplotlibRenderer(BaseRenderer):
    """
    Render Panels using the matplotlib plotting package.

    Each Plot inside a Panel is dispatched, by class name, to one of the
    ``plot_*`` methods of this class, which draw it onto a matplotlib Axes.
    If an ``mpl_backend`` keyword argument is supplied at construction, the
    renderer works non-interactively: every Panel is written to disk in the
    format given by the ``mpl_savefig_format`` entry of the renderer
    configuration ("png" or "pdf") and recorded in ``self.artifacts``.
    Otherwise each Panel figure is shown.
    """

    # ADARI marker name -> keyword arguments for ``Axes.scatter``.
    # A "facecolors" entry denotes a hollow marker, whose colour therefore
    # has to be applied to the marker edges rather than the fill.
    MARKER_RENDER_COMMANDS = {
        "circle": {"marker": "o"},  # DEFAULT
        "ring": {"marker": "o", "facecolors": "none"},
        "point": {"marker": ","},
        "plus": {"marker": "+"},
        "x": {"marker": "x"},
        "square": {"marker": "s", "facecolors": "none"},
        "block": {"marker": "s"},
        "none": {"marker": "none"},
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._backend = "matplotlib"
        # Dispatch table: Plot class name -> rendering method.
        self._plot_render_functions = {
            ImagePlot.__name__: self.plot_image,
            CentralImagePlot.__name__: self.plot_image,
            LinePlot.__name__: self.plot_line,
            HistogramPlot.__name__: self.plot_histogram,
            ScatterPlot.__name__: self.plot_scatter,
            CollapsePlot.__name__: self.plot_collapse,
            CombinedPlot.__name__: self.plot_combined,
            CutPlot.__name__: self.plot_cut,
            TextPlot.__name__: self.plot_text,
        }
        # Selecting an explicit matplotlib backend switches the renderer into
        # non-interactive (save-to-file) mode.
        self.interactive = True
        mpl_backend = kwargs.get("mpl_backend", None)
        if mpl_backend is not None:
            self.interactive = False
            mpl.use(mpl_backend)
        # The renderer configuration is the first positional argument; treat
        # a missing one like an explicit None instead of raising IndexError.
        renderer_config = args[0] if args else None
        self.savefig_format = None
        if renderer_config is not None and "mpl_savefig_format" in renderer_config:
            self.savefig_format = renderer_config["mpl_savefig_format"]
        self.artifacts = []

    def _aspect_fnc(self, ax, plotobj, im):
        """
        Apply axis limits, aspect ratio and (optionally) a colorbar.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The Matplotlib axes to be manipulated.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.
        im : matplotlib mappable or None
            The object a colorbar should be generated for; usually the
            output of an ``ax.imshow`` call. Set to None to skip the
            colorbar.

        Returns
        -------
        None.
        """
        # Set axis limits (maybe make this a separate func?)
        ax.set_xlim(plotobj.get_xlim())
        ax.set_ylim(plotobj.get_ylim())

        if plotobj.aspect is not None:
            ax.set_aspect(plotobj.aspect)

        # Only draw a colorbar when there is a mappable and it is meant to be
        # visible.
        if im is not None and plotobj.colorbar_visible:
            colorbar_kwargs = {}
            if plotobj.aspect == "auto":
                colorbar_kwargs["pad"] = 0.02
            if plotobj.cbar_kwargs is not None:
                colorbar_kwargs.update(plotobj.cbar_kwargs)
            # Attach explicitly to this axes/figure rather than relying on
            # pyplot's "current axes" state. A caller-supplied "ax"/"cax"
            # still wins.
            colorbar_kwargs.setdefault("ax", ax)
            ax.figure.colorbar(im, **colorbar_kwargs)

        return

    def _set_axes(self, ax, plotobj):
        """
        Helper function to set axes parameters from a Plot object.

        Labels and scales are always applied. The view is autoscaled to all
        data, and then any explicitly set limit on the Plot overrides the
        corresponding bound.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The Matplotlib axes to be manipulated.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.
        """
        ax.set_xlabel(plotobj.x_label)
        ax.set_ylabel(plotobj.y_label)
        ax.set_xscale(plotobj.x_scale)
        ax.set_yscale(plotobj.y_scale)
        ax.autoscale()  # Auto-include all data, unless limits modified below
        if plotobj.x_min is not None:
            ax.set_xlim(left=plotobj.x_min)
        if plotobj.x_max is not None:
            ax.set_xlim(right=plotobj.x_max)
        if plotobj.y_min is not None:
            ax.set_ylim(bottom=plotobj.y_min)
        if plotobj.y_max is not None:
            ax.set_ylim(top=plotobj.y_max)

    @staticmethod
    def _apply_tick_spec(axis, major, minor):
        """
        Configure the major/minor tick locators of a single matplotlib Axis.

        Parameters
        ----------
        axis : matplotlib.axis.Axis
            ``ax.xaxis`` or ``ax.yaxis``.
        major : iterable, number or None
            An iterable gives explicit tick positions. A positive number
            places ticks at multiples of that value. A negative number
            removes all major ticks. None or 0 leaves the axis untouched.
        minor : iterable, number or None
            An iterable gives explicit positions. A positive number places
            ticks at multiples of it. 0 selects automatic minor ticks.
            None or a negative value leaves the axis untouched.
        """
        if major is not None:
            if is_iterable(major):
                axis.set_ticks(major)
            elif major > 0:
                axis.set_major_locator(MultipleLocator(base=major))
            elif major < 0:
                axis.set_ticks([])
        if minor is not None:
            locator = None
            if is_iterable(minor):
                # FixedLocator takes the positions positionally; it has no
                # ``base`` parameter.
                locator = FixedLocator(minor)
            elif minor > 0:
                locator = MultipleLocator(base=minor)
            elif minor == 0:
                locator = AutoMinorLocator()
            if locator is not None:
                axis.set_minor_locator(locator)

    def _set_ticks(self, ax, plotobj):
        """
        Apply tick visibility and tick locations from a Plot object.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The Matplotlib axes to be manipulated.
        plotobj : child of adari_core.plots.plot.Plot
            The Plot object.
        """
        if plotobj.tick_visibility is not None:
            tv = plotobj.tick_visibility
            ax.tick_params(
                which="both",
                top=tv["top"],
                labeltop=tv["labeltop"],
                bottom=tv["bottom"],
                labelbottom=tv["labelbottom"],
                right=tv["right"],
                labelright=tv["labelright"],
                left=tv["left"],
                labelleft=tv["labelleft"],
            )

        self._apply_tick_spec(ax.xaxis, plotobj.x_major_ticks, plotobj.x_minor_ticks)
        self._apply_tick_spec(ax.yaxis, plotobj.y_major_ticks, plotobj.y_minor_ticks)

    def _get_rectangles_from_coords(self, coords):
        """
        Partition a collection of grid cells into rectangular areas.

        Used to select ranges for multi-cell plots. Rectangles are grown
        greedily in three steps:

        1. Start from the smallest remaining (column, row) cell.
        2. Extend right along that row while the columns stay consecutive.
        3. Extend down while every cell of the column span is present in the
           next row.

        Parameters
        ----------
        coords : iterable of two-tuples
            Cells to partition, as (column, row) pairs.

        Returns
        -------
        list of two-tuples of two-tuples
            One ``((first_col, last_col), (first_row, last_row))`` entry per
            rectangle; all bounds are inclusive.
        """
        coords = list(coords)
        # Sorting (column-major) makes the greedy extension straightforward.
        remaining = sorted(coords)
        # Vertical growth can never go beyond the largest row present.
        max_row = max((c[1] for c in coords), default=-1)
        rects = []
        while remaining:
            # cols are index 0, rows are index 1
            first_pos = remaining.pop(0)
            first_col, first_row = first_pos[0], first_pos[1]
            last_col = first_col

            # Expand horizontally along the top row until the columns are no
            # longer consecutive.
            for pos in [c for c in remaining if c[1] == first_row]:
                if pos[0] != last_col + 1:
                    break  # This is the furthest right this rectangle can go
                remaining.remove(pos)
                last_col = pos[0]

            # Expand vertically while the full column span exists in the next
            # row.
            last_row = first_row
            for row in range(first_row + 1, max_row + 1):
                next_band = [(col, row) for col in range(first_col, last_col + 1)]
                if not all(cell in remaining for cell in next_band):
                    break  # The rectangle cannot be expanded any further
                for cell in next_band:
                    remaining.remove(cell)
                last_row = row

            rects.append(((first_col, last_col), (first_row, last_row)))
        return rects

    def _plot_vline(self, axes, datadict, linestyle, linewidth):
        """
        Draw vertical line(s) at the x positions in ``datadict["data"]``.

        The x positions are in data coordinates. The vertical extent
        (``datadict["ymin"]``/``["ymax"]``, default 0 and 1) is in axes
        fractions, so the lines span the visible plot height.
        """
        trans = mpl.transforms.blended_transform_factory(axes.transData, axes.transAxes)
        axes.vlines(
            datadict["data"],
            ymin=datadict.get("ymin", 0),
            ymax=datadict.get("ymax", 1),
            transform=trans,
            colors=datadict["color"],
            linestyle=linestyle,
            linewidth=linewidth,
        )

    def _plot_annotate(self, axes, datadict, notes):
        """
        Label each (x, y) point of ``datadict["data"]`` with the matching note.

        Each label is offset 5 points to the right of its data point. Extra
        points or notes beyond the shorter of the two sequences are ignored,
        because the inputs are zipped.
        """
        # No explicit ``transform``: Annotation positions its text from
        # ``xycoords`` (data, by default) and ``textcoords`` and would
        # override one anyway.
        for x, y, note in zip(datadict["data"][0], datadict["data"][1], notes):
            axes.annotate(
                note,
                (x, y),
                xytext=(5, 0),
                textcoords="offset points",
            )

    def render_panels(self, panels_dict):
        """
        Render all the panels.

        In the "normal" case this is just a loop over ``render_panel``. For
        the pdf driver, a single multi-page PDF (named after the first
        panel's ``report_name``) is opened before the loop. It is always
        closed afterwards, even if a panel fails to render.
        """
        results = []
        self.artifacts = []
        if not panels_dict:
            return results

        if self.savefig_format == "pdf":
            first_meta = next(iter(panels_dict.values()))
            self.pdf_filename = first_meta["report_name"] + ".pdf"
            self.pdf_pages = PdfPages(self.pdf_filename)

        try:
            for pnl, meta in panels_dict.items():
                res = self.render_panel(
                    pnl,
                    meta["report_name"],
                    meta["report_description"],
                    meta["report_tags"],
                    meta.get("report_prodcatg", ""),
                    meta.get("input_files", None),
                )
                results.append(res)
        finally:
            if self.savefig_format == "pdf":
                self.pdf_pages.close()
        return results

    @staticmethod
    def _add_cutout_rect(central_plot):
        """
        Draw the cutout window of a CentralImagePlot on its reference image.

        The window is clipped to the reference image's plot bounds
        (PIPE-11005) before being added as a rectangle.
        """
        x_min, x_max, y_min, y_max = central_plot.get_ref_image_window_coords()
        (
            im_left,
            im_right,
            im_bottom,
            im_top,
        ) = central_plot.ref_image.get_image_plot_bounds()
        # NOTE(review): clipping compares the window coords directly with the
        # plot bounds, which assumes both use the same coordinate system.
        x_min = max(x_min, im_left)
        y_min = max(y_min, im_bottom)
        x_max = min(x_max, im_right)
        y_max = min(y_max, im_top)
        # Ref point, Width, Height
        central_plot.ref_image.add_rect((x_min, y_min), x_max - x_min, y_max - y_min)

    def _link_plot_objects(self, pnl):
        """
        Make connections between Plot objects before any axes are drawn.

        For each assigned Plot in the Panel:

        - A CutPlot adds a line at its cut position to each source image
          (vertical for an "x" cut, horizontal for a "y" cut).
        - A CentralImagePlot adds its cutout rectangle to its reference
          image. The same applies to every CentralImagePlot inside a
          CombinedPlot.
        """
        for loc in pnl.get_assigned_locations():
            to_plot = pnl._plots[loc]
            if isinstance(to_plot, CutPlot):
                for data_array in to_plot.data.values():
                    if to_plot.cut_ax == "x":
                        data_array["data"].add_vert_line(data_array["cut_pos"], label=None)
                    elif to_plot.cut_ax == "y":
                        data_array["data"].add_horiz_line(data_array["cut_pos"], label=None)
            elif isinstance(to_plot, CentralImagePlot):
                self._add_cutout_rect(to_plot)
            elif (
                isinstance(to_plot, CombinedPlot)
                and CentralImagePlot in to_plot._set_of_existing_bases()
            ):
                # Handle every CentralImagePlot member, not just the first.
                for member in to_plot.data.keys():
                    if isinstance(member, CentralImagePlot):
                        self._add_cutout_rect(member)

    @staticmethod
    def _share_linked_axes(axarr, links, which):
        """
        Chain-share the x or y axis along each link of (x, y) panel coords.

        Each axes in a link shares ``which`` ("x" or "y") with the previous
        one. Entries that resolve to the same axes object as the previous one
        (e.g. the first element) are skipped rather than shared with
        themselves.
        """
        for link in links:
            first_x, first_y = link[0][0], link[0][1]
            prev_ax = axarr[first_y, first_x]
            for x, y in link:
                current_ax = axarr[y, x]
                if current_ax is not prev_ax:
                    getattr(current_ax, "share" + which)(prev_ax)
                prev_ax = current_ax

    def _save_panel(
        self,
        fig,
        pnl_name,
        pnl_description,
        pnl_tags,
        pnl_prodcatg,
        pnl_input_files,
    ):
        """
        Save a rendered figure and append its metadata to ``self.artifacts``.

        Raises
        ------
        ValueError
            If ``self.savefig_format`` is neither "png" nor "pdf".
        """
        this_artifact = {
            "description": pnl_description,
            "tags": pnl_tags,
            "prodcatg": pnl_prodcatg,
        }
        if pnl_input_files is not None:
            this_artifact["input_files"] = pnl_input_files
        if self.savefig_format == "png":
            fig.savefig(pnl_name + ".png")
            filename = os.path.join(os.getcwd(), pnl_name + ".png")
            this_artifact["mime_type"] = "image/png"
        elif self.savefig_format == "pdf":
            # Append a page to the multi-page PDF opened by render_panels.
            fig.savefig(self.pdf_pages)
            filename = os.path.join(os.getcwd(), self.pdf_filename)
            this_artifact["artifact_section"] = self.pdf_pages.get_pagecount()
            this_artifact["mime_type"] = "application/pdf"
        else:
            raise ValueError(
                "Cannot save panel {!r}: unsupported mpl_savefig_format {!r} "
                "(expected 'png' or 'pdf')".format(pnl_name, self.savefig_format)
            )
        this_artifact["filename"] = filename
        self.artifacts.append(this_artifact)

    def render_panel(
        self,
        pnl,
        pnl_name,
        pnl_description,
        pnl_tags,
        pnl_prodcatg,
        pnl_input_files=None,
    ):
        """Render the given Panel.

        Parameters
        ----------
        pnl : :any:`adari_core.plots.panel.Panel`
            The Panel to be rendered.
        pnl_name : str
            The Panel name; used as the output file stem in png mode.
        pnl_description: str
            The Panel description that will be added to the metadata
        pnl_tags : iterable of str
            A list of tags that are added to the metadata
        pnl_prodcatg : str
            Product category as defined by ESO Phase III
        pnl_input_files : optional
            Recorded in the artifact metadata when not None.

        Returns
        -------
        None
            In interactive mode the figure is shown. Otherwise it is saved
            and recorded in ``self.artifacts``. In both cases it is closed
            afterwards.

        Raises
        ------
        ValueError
            In non-interactive mode, if ``savefig_format`` is neither "png"
            nor "pdf".
        """
        self.check_plot_obj_class(pnl, Panel)
        px = 1 / plt.rcParams["figure.dpi"]  # pixel in inches
        # FIXME Avoid magic number
        # FIXME How to cope with this maths and 'extended' plots?
        # Each plot cell is 300x300 pixels; one extra inch is added to the
        # width.
        figsize = (
            pnl.x_stretch * pnl._x * 300 * px + 1,
            pnl.y_stretch * pnl._y * 300 * px,
        )

        # PIPE-11154: figure.max_open_warning does not need silencing, because
        # every figure is closed as soon as it has been rendered (see below).
        try:
            fig = plt.figure(figsize=figsize, tight_layout=True)
        except UserWarning:
            # Only reachable when warnings are escalated to exceptions; retry
            # without the tight layout engine.
            fig = plt.figure(figsize=figsize)

        try:
            if pnl.right_subplot is not None:
                fig.subplots_adjust(right=pnl.right_subplot)

            if pnl.title:
                fig.suptitle(pnl.title)

            gs = GridSpec(
                pnl._y,
                pnl._x,
                fig,
                width_ratios=pnl.width_ratios,
                height_ratios=pnl.height_ratios,
            )
            axarr = np.empty((pnl._y, pnl._x), dtype=object)

            # Cross-Plot connections (cut lines, cutout rectangles) must exist
            # before the referenced images are drawn.
            self._link_plot_objects(pnl)

            for loc in pnl.get_assigned_locations():
                to_plot = pnl._plots[loc]
                extent = pnl._extents[loc]
                ax = fig.add_subplot(
                    gs[
                        loc[0] : loc[0] + extent[0],
                        loc[1] : loc[1] + extent[1],
                    ],
                    title=to_plot.title,
                )
                self.get_render_function(to_plot)(to_plot, ax)
                axarr[loc] = ax

            # Link axes only after everything is rendered, so all gridspec
            # axes have been initialised.
            self._share_linked_axes(axarr, pnl._plots_linked_in_x, "x")
            self._share_linked_axes(axarr, pnl._plots_linked_in_y, "y")

            if self.interactive:
                # NOTE(review): the figure is closed right after show(); on
                # GUI backends this may dismiss the window - confirm intent.
                fig.show()
            else:
                self._save_panel(
                    fig,
                    pnl_name,
                    pnl_description,
                    pnl_tags,
                    pnl_prodcatg,
                    pnl_input_files,
                )
        finally:
            # Release the figure even if rendering or saving failed.
            plt.close(fig)
        return None

    def plot_image(self, plotobj, ax, set_axes=True):
        """
        Render a representation of an ImagePlot object.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.images.ImagePlot`
            The ImagePlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.

        Raises
        ------
        ValueError
            If ``ax`` is None or the colormap cannot be parsed.
        """
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")

        try:
            # Colormap is stored as a list of hex strings in the colormap mixin.
            # Do not give it a name.
            cmap_used = LinearSegmentedColormap.from_list("", plotobj.colormap)
        except ValueError as err:
            raise ValueError("plot _colormap must be a list of hex strings") from err
        v_min, v_max = plotobj.get_vlim()

        plot_left, plot_right, plot_bottom, plot_top = plotobj.get_image_plot_bounds()

        im = ax.imshow(
            plotobj.data,
            cmap=cmap_used,
            vmin=v_min,
            vmax=v_max,
            origin="lower",
            aspect=plotobj.aspect,
            interpolation=plotobj.interp,
            extent=[plot_left, plot_right, plot_bottom, plot_top],
        )

        # Cut-position lines; zorder is above the rectangles (zorder 10).
        for xpos, label in plotobj.vert_lines.items():
            ax.axvline(x=xpos, label=label, color="blue", zorder=20)
        for ypos, label in plotobj.horiz_lines.items():
            ax.axhline(y=ypos, label=label, color="blue", zorder=20)

        self._aspect_fnc(ax, plotobj, im)

        # Each rect key is (anchor, width, height).
        for rect in plotobj.rects:
            ax.add_patch(
                patches.Rectangle(
                    rect[0],
                    rect[1],
                    rect[2],
                    ec="green",
                    fc="none",
                    ls="-",
                    lw=1.5,
                    zorder=10,  # Moderate value to allow space either side
                )
            )

        if set_axes:
            self._set_axes(ax, plotobj)

        self._set_ticks(ax, plotobj)

        return

    def plot_line(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a LinePlot object.

        A datadict with a truthy ``"vline"`` entry is drawn as vertical
        line(s). Any other datadict with ``"data"`` is drawn as a line
        through the (x, y) arrays.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.LinePlot`
            The LinePlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.
        """
        self.check_plot_obj_class(plotobj, LinePlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        for label, datadict in plotobj.data.items():
            linestyle = self._compute_linestyle_properties(plotobj, datadict)
            linewidth = self._compute_linewidth_properties(plotobj, datadict)

            if datadict.get("vline"):
                self._plot_vline(ax, datadict, linestyle, linewidth)
            elif "data" in datadict:
                # A "vline": False entry must not suppress an ordinary line.
                ax.plot(
                    datadict["data"][0],
                    datadict["data"][1],
                    label=label,
                    color=datadict["color"],
                    linestyle=linestyle,
                    linewidth=linewidth,
                    zorder=20,  # Higher than image rectangles
                )

        for textdict in plotobj._text:
            ax.text(
                textdict["x"],
                textdict["y"],
                textdict["text"],
                rotation=textdict["rotation"],
                horizontalalignment=textdict["horizontalalignment"],
                verticalalignment=textdict["verticalalignment"],
                fontsize=textdict["fontsize"],
            )

        if set_axes:
            self._set_axes(ax, plotobj)

        self._set_ticks(ax, plotobj)

        if plotobj.legend:
            ax.legend()
        self._aspect_fnc(ax, plotobj, None)
        return

    def plot_scatter(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a ScatterPlot object.

        Datasets with ``"xerror"``/``"yerror"`` are drawn with error bars.
        Optional ``"vline"`` and ``"annotate"`` (with ``"notes"``) entries
        add vertical lines and per-point labels.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.points.ScatterPlot`
            The ScatterPlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.

        Raises
        ------
        ValueError
            If ``ax`` is None, or annotation is requested without notes.
        """
        self.check_plot_obj_class(plotobj, ScatterPlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        for label, datadict in plotobj.data.items():
            markersize, marker = self._compute_marker_properties(plotobj, datadict)
            marker_commands = self.MARKER_RENDER_COMMANDS[marker]
            hollow = "facecolors" in marker_commands
            color = datadict["color"]

            if "xerror" in datadict or "yerror" in datadict:
                # errorbar() forwards extra keywords to Line2D, which has no
                # c/edgecolors/facecolors properties: translate the marker
                # styling into Line2D terms.
                errorbar_marker = {"marker": marker_commands["marker"], "color": color}
                if hollow:
                    errorbar_marker["markerfacecolor"] = "none"
                    errorbar_marker["markeredgecolor"] = color
                ax.errorbar(
                    datadict["data"][0],
                    datadict["data"][1],
                    xerr=datadict.get("xerror"),
                    yerr=datadict.get("yerror"),
                    linestyle="",
                    ecolor=color,
                    label=label,
                    markersize=markersize,
                    **errorbar_marker,
                )
            else:
                # Hollow markers take their colour on the edges only.
                color_commands = {"edgecolors": color} if hollow else {"c": color}
                ax.scatter(
                    datadict["data"][0],
                    datadict["data"][1],
                    label=label,
                    s=markersize**2,  # scatter sizes are in points**2
                    **color_commands,
                    **marker_commands,
                )

            if datadict.get("vline"):
                linestyle = self._compute_linestyle_properties(plotobj, datadict)
                linewidth = self._compute_linewidth_properties(plotobj, datadict)
                self._plot_vline(ax, datadict, linestyle, linewidth)

            if datadict.get("annotate"):
                if "notes" not in datadict:
                    raise ValueError("Must supply a list of notes to annotate")
                self._plot_annotate(ax, datadict, datadict["notes"])

        if set_axes:
            self._set_axes(ax, plotobj)
        self._set_ticks(ax, plotobj)
        self._aspect_fnc(ax, plotobj, None)
        if plotobj.legend:
            ax.legend()
        return

    def plot_histogram(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a HistogramPlot object.

        The binning range is (v_min, v_max) when both are set on the Plot;
        otherwise it is taken from ``plotobj.get_vlim()``.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.histogram.HistogramPlot`
            The HistogramPlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.
        """
        self.check_plot_obj_class(plotobj, HistogramPlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        # Explicit v_min/v_max take precedence over the automatic algorithm.
        rmin = plotobj.v_min
        rmax = plotobj.v_max
        if rmin is None or rmax is None:
            rmin, rmax = plotobj.get_vlim()

        bins_set = plotobj.bins if plotobj.bins is not None else HISTOGRAM_DEFAULT_BINS

        for label, datadict in plotobj.data.items():
            ax.hist(
                datadict["data"],
                histtype="step",
                label=label,
                color=datadict["color"],
                bins=bins_set,
                range=(rmin, rmax),
                log=plotobj.is_logscale(),
            )
        if set_axes:
            self._set_axes(ax, plotobj)
        if plotobj.legend:
            ax.legend()
        self._aspect_fnc(ax, plotobj, None)
        return

    def plot_collapse(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a CollapsePlot object.

        Each dataset is drawn as a post-step curve against its element index.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.collapse.CollapsePlot`
            The CollapsePlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.
        """
        self.check_plot_obj_class(plotobj, CollapsePlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        for label, datadict in plotobj.data.items():
            ax.step(
                range(len(datadict["data"])),
                datadict["data"],
                label=label,
                color=datadict["color"],
                where="post",
            )
        if set_axes:
            self._set_axes(ax, plotobj)
        self._aspect_fnc(ax, plotobj, None)

        # Have a legend as shown in sample images
        if plotobj.legend:
            ax.legend()

        return

    def plot_combined(self, plotobj, ax, **kwargs):
        """
        Render a representation of a CombinedPlot object.

        Member plots are drawn onto the same axes in ascending order of their
        stored value (their zorder), each with ``set_axes=False``. The
        CombinedPlot's own axes settings are applied once at the end.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.combined.CombinedPlot`
            The CombinedPlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        **kwargs
            Forwarded to every member plot's render function.
        """
        self.check_plot_obj_class(plotobj, CombinedPlot)

        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        sorted_plot_objs = sorted(plotobj.data.items(), key=lambda item: item[1])

        for plot, _zorder in sorted_plot_objs:
            self._plot_render_functions[str(plot)](plot, ax, set_axes=False, **kwargs)
        self._set_axes(ax, plotobj)
        self._aspect_fnc(ax, plotobj, None)

        return

    def plot_cut(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a CutPlot object.

        For each source image, the row/column at the cut position is
        extracted and drawn as a step curve. The x values start at the
        image's origin and advance by its pixel step.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.cut.CutPlot`
            The CutPlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Whether to apply the Plot's labels, scales and limits.

        Raises
        ------
        ValueError
            If ``ax`` is None or the cut position lies outside an image.
        """
        self.check_plot_obj_class(plotobj, CutPlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axes object")
        for label, datadict in plotobj.data.items():
            image = datadict["data"]
            try:
                cut_values = image.get_axis_cut(
                    plotobj._cut_ax,
                    round(image.get_ind_coord(datadict["cut_pos"], plotobj._cut_ax)),
                )
            except IndexError as e:
                raise ValueError(
                    "Unable to cut image {} at pos {} ({})".format(
                        plotobj, datadict["cut_pos"], str(e)
                    )
                ) from e
            if plotobj.cut_ax == "y":
                origin, step = image.x_origin, image.x_step
            else:
                origin, step = image.y_origin, image.y_step
            # One x value per cut sample. Stepping an arange up to
            # origin + len would yield the wrong number of points whenever
            # step != 1.
            x_values = origin + step * np.arange(len(cut_values))
            ax.step(
                x_values,
                cut_values,
                label=label,
                color=datadict["color"] if datadict["color"] else "b",
            )
        if set_axes:
            self._set_axes(ax, plotobj)
        # Have a legend as shown in sample images
        if plotobj.legend:
            ax.legend()
        self._aspect_fnc(ax, plotobj, None)

    def plot_text(self, plotobj, ax, set_axes=True):
        """
        Render a representation of a TextPlot object.

        Strings are laid out column by column, split evenly across
        ``plotobj.columns``. Positions are in axes-fraction units: by default
        they start at the top-left (0, 1), or at ``(_xref, _yref)`` when
        those are set. Consecutive lines are ``v_space`` apart. The axes
        frame and both axes are hidden.

        Parameters
        ----------
        plotobj : :any:`adari_core.plots.text.TextPlot`
            The TextPlot object to be rendered.
        ax : matplotlib.axes.Axes
            The matplotlib axes to render the Plot with/onto.
        set_axes : bool, optional
            Accepted for signature compatibility with the other ``plot_*``
            methods (``plot_combined`` passes ``set_axes=False``). It is
            ignored, because text has no data axes to configure.

        Returns
        -------
        matplotlib.axes.Axes
            The axes that was drawn on.
        """
        self.check_plot_obj_class(plotobj, TextPlot)
        if ax is None:
            raise ValueError("Must supply a Matplotlib axis object")
        column_div = float(len(plotobj.data)) / float(plotobj.columns)
        column_width = 1.0 / float(plotobj.columns)
        xref = getattr(plotobj, "_xref", None)
        yref = getattr(plotobj, "_yref", None)

        for i, s in enumerate(plotobj.data):
            col_this = int(float(i) / column_div)
            line_offset = plotobj.v_space * int(i % column_div)
            text_x = col_this * column_width
            if xref is not None:
                text_x = xref + text_x
            text_y = (yref if yref is not None else 1.0) - line_offset

            ax.text(
                text_x,
                text_y,
                s,
                fontsize=plotobj._fontsize,
                horizontalalignment=plotobj._halign,
                verticalalignment=plotobj._valign,
            )

        # Hide the x and y axis
        ax.set_frame_on(False)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        return ax
