# SPDX-License-Identifier: BSD-3-Clause
"""
This module specifies an assortment of 'point'-like plots.
"""

from .plot import Plot
from .axes import Plot2DMixin, XOriginMixin, YOriginMixin
from adari_core.utils.clipping import ClippingMixin
from adari_core.utils.utils import assert_none_or_numeric
import numpy as np

# Colour codes handed out in order (wrapping around) by PointsPlotMixin.add_data
# when a data set is added without an explicit colour
COLORS = "rgbkycm"

# Marker names accepted by ScatterPlot; the first entry is the default marker
MARKERS = [
    "circle",  # Filled circle (DEFAULT)
    "ring",  # Hollow circle
    "point",  # Small point
    "plus",  # Line cross, vert/horz lines
    "x",  # Line cross, diagonal lines
    "square",  # Hollow square
    "block",  # Filled square
    "none",  # No marker
]
DEFAULT_MARKER = MARKERS[0]
DEFAULT_MARKER_SIZE = 8

# Line styles accepted by LinePlot. They are the lowest common denominator
# supported by both matplotlib and bokeh; the first entry is the default style
LINESTYLES = [
    "solid",  # A continuous line
    "dotted",  # A dotted line
    "dashed",  # A dashed line
    "dashdot",  # A sequence of segments and dots
]
DEFAULT_LINESTYLE = LINESTYLES[0]
DEFAULT_LINEWIDTH = 1.5

# Font size used by LinePlot.add_text when the caller does not give one
DEFAULT_FONTSIZE = 10


class PointsPlotMixin(ClippingMixin, object):
    """
    This mixin provides some of the common functionality for the 'points' class
    of Plot objects.

    Data sets are kept in ``self._data``, a dict keyed by the (unique) data
    label. Each value is a dict holding ``"data"`` and ``"color"`` plus any
    extra keyword arguments that were passed to :meth:`add_data`.
    """

    def __init__(self, *args, **kwargs):
        # Data may only be supplied afterwards through add_data, never at
        # construction time (neither positionally nor as ``data=``)
        if len(args) > 0 or kwargs.get("data") is not None:
            raise ValueError(
                "{} does not support instantiating with data; "
                "add data using add_data "
                "later".format(self.__str__())
            )
        super().__init__(*args, **kwargs)
        self._data = {}

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """
        Check that ``data`` is acceptable 'points' data.

        Valid data converts to a numeric numpy array that is either 1-D
        (y-values only) or 2-D with exactly two rows (x-values, y-values).

        Parameters
        ----------
        data : object
            The candidate data.
        fail_on_none : bool
            Forwarded to the parent class check.
        error_on_fail : bool
            If True, invalid data raises; otherwise False is returned.

        Returns
        -------
        bool
            True if the data is valid, False if it is not and
            ``error_on_fail`` is False. If the parent class check returns a
            non-None verdict, that verdict is returned unchanged.

        Raises
        ------
        ValueError
            If the data is invalid and ``error_on_fail`` is True.
        """
        # Parent classes get the first say; a non-None answer is final.
        # NOTE(review): presumably this is where None data is handled per
        # fail_on_none - confirm against ClippingMixin / Plot.
        s = super()._is_data_valid(
            data, fail_on_none=fail_on_none, error_on_fail=error_on_fail
        )
        if s is not None:
            return s

        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would silently accept invalid data
        problem = None
        try:
            tmp_data = np.asarray(data)
        except (ValueError, TypeError) as e:
            # e.g. ragged nested sequences in recent NumPy versions
            problem = f"Points data could not be converted to an array: {e}"
        else:
            # This may change with multi-data support
            if not 1 <= tmp_data.ndim <= 2:
                problem = "Points data must be 1- or 2-dimensional"
            elif tmp_data.ndim == 2 and tmp_data.shape[0] != 2:
                problem = (
                    "Points data must be length "
                    "2, first iterable "
                    "representing x "
                    "points, the other "
                    "representing y points"
                )
            elif not np.issubdtype(tmp_data.dtype, np.number):
                problem = "Data array must have a numeric data type"

        if problem is None:
            return True
        if error_on_fail:
            raise ValueError(problem)
        return False

    @staticmethod
    def _x_index_dtype(y_dtype, n_points):
        """
        Return a dtype that holds both ``y_dtype`` values and the x-indices
        ``0 .. n_points - 1`` exactly.

        ``y_dtype`` is returned unchanged whenever it can already represent
        every index; otherwise it is promoted just far enough.
        """
        max_index = max(n_points - 1, 0)
        if np.issubdtype(y_dtype, np.integer):
            fits = max_index <= np.iinfo(y_dtype).max
        elif np.issubdtype(y_dtype, np.inexact):
            # Floats (and complex) represent every integer up to
            # 2**(nmant + 1) exactly; beyond that indices would collide
            fits = max_index <= 2 ** (np.finfo(y_dtype).nmant + 1)
        else:
            fits = True
        if fits:
            return y_dtype
        return np.result_type(y_dtype, np.min_scalar_type(max_index))

    def add_data(self, d, *args, label=None, color=None, **kwargs):
        """
        Add data to this Plot.

        Parameters
        ----------
        d : 1D or 2D iterable of numeric values
            The data to be added. This may take one of two formats:

            - A 1D iterable (e.g., list) of points to be plotted. These points
              will be taken to be y-values, and plotted using their position
              in the iterable (e.g., list index) as the x-value; or,
            - A 2D iterable (e.g., list of lists, or numpy array) of points to
              be plotted. In this case, `d[0]` should be the list of x-points,
              and `d[1]` the list of corresponding y-points.

            Note that passing a list of (x,y) tuples or similar is *not*
            supported at present.
        label : str
            The label to give this particular data set. This must be unique
            for this Plot object.
        color : str
            A valid color specification to plot these data with. Defaults to
            None, at which point the next color of :any:`COLORS` is assigned
            (cycling once all have been used).
        kwargs : dict
            Extra per-dataset options, stored alongside the data.

        Raises
        ------
        ValueError
            If the data is invalid, if no label is given, if the label is not
            a string, or if the label is already in use.
        """

        self._is_data_valid(d, fail_on_none=True, error_on_fail=True)
        if label is None:
            raise ValueError("Please provide a label for these data")
        if not isinstance(label, str):
            raise ValueError("Data label must be a string")
        if label in self._data:
            raise ValueError(
                "This {} already contains data "
                "labelled '{}'".format(self.__str__(), label)
            )
        if color is None:
            # Assign the next available colour, wrapping around
            color = COLORS[len(self._data) % len(COLORS)]

        # 1D case - prepend the list indices as the x points. The dtype is
        # widened only when the y dtype cannot hold every index (e.g. a uint8
        # series with more than 256 points), which would otherwise overflow.
        d_arr = np.asarray(d)
        if d_arr.ndim == 1:
            out_dtype = self._x_index_dtype(d_arr.dtype, len(d_arr))
            d = np.stack(
                [
                    np.arange(len(d_arr), dtype=out_dtype),
                    d_arr.astype(out_dtype, copy=False),
                ]
            )

        self._data[label] = {"data": d, "color": color, **kwargs}


class LinePlot(Plot2DMixin, XOriginMixin, YOriginMixin, PointsPlotMixin, Plot):
    """
    This Plot displays data as a line plot.

    Besides the per-dataset options accepted by :any:`add_data`, it holds
    Plot-wide ``linestyle`` / ``linewidth`` defaults and a list of free text
    annotations added through :any:`add_text`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._linestyle = None
        self.linestyle = kwargs.get("linestyle")  # validated by the setter
        self._linewidth = None
        self.linewidth = kwargs.get("linewidth")  # validated by the setter
        self._text = []  # annotation dicts appended by add_text

    def add_data(
        self, d, *args, label=None, color=None, linestyle=None, linewidth=None, **kwargs
    ):
        """
        Add data to this Plot.

        Parameters
        ----------
        d : 1D or 2D iterable of numeric values
            The data to be added. This may take one of two formats:

            - A 1D iterable (e.g., list) of points to be plotted. These points
              will be taken to be y-values, and plotted using their position
              in the iterable (e.g., list index) as the x-value; or,
            - A 2D iterable (e.g., list of lists, or numpy array) of points to
              be plotted. In this case, `d[0]` should be the list of x-points,
              and `d[1]` the list of corresponding y-points.

            Note that passing a list of (x,y) tuples or similar is *not*
            supported at present.
        label : str
            The label to give this particular data set. This must be unique
            for this Plot object.
        color : str
            A valid color specification to plot these data with. Defaults to
            None, at which point the next color of :any:`COLORS` is assigned.
        linestyle : str
            A valid line style contained in :any:`LINESTYLES`. Defaults to None,
            at which point the Plot default value will take precedence; if that
            is not defined either, :any:`DEFAULT_LINESTYLE` will be used.
        linewidth: float
            The width of the line to plot. Defaults to None,
            at which point the Plot default value will take precedence; if that
            is not defined either, :any:`DEFAULT_LINEWIDTH` will be used.
        kwargs : dict
            Additional parameters, stored with the data set. If parameter
            "vline" is present and truthy, then rather than plotting an X vs Y
            plot, vertical lines are drawn. The "d" parameter must then be a
            1D array whose values are the X positions of the vertical lines.

        Raises
        ------
        ValueError
            If any of these conditions apply:
               - the data is invalid
               - the line style is invalid
               - the line width is not positive
               - "vline" is requested but the data is not 1D
               - no label is given.
            If any check fails, no data set is registered under ``label``.
        """
        super().add_data(d, label=label, color=color, **kwargs)

        # The line options are checked after the base class has accepted the
        # data (so invalid data is still reported first). On *any* failure the
        # new entry is rolled back: a non-numeric linewidth or 2-D "vline"
        # data must not leave a half-baked entry that blocks reuse of label.
        try:
            if linestyle is not None and linestyle not in LINESTYLES:
                raise ValueError(
                    f"{linestyle} is not a supported linestyle; "
                    f"pass one of {','.join(LINESTYLES)}"
                )
            assert_none_or_numeric(linewidth)
            if linewidth is not None and not (linewidth > 0):
                raise ValueError(
                    f"line width {linewidth} is not a positive value or None"
                )
            vline = bool(kwargs.get("vline"))
            if vline and np.asarray(d).ndim != 1:
                raise ValueError("Data for vertical lines must be 1-D")
        except Exception:
            del self._data[label]  # Don't leave half-baked definition
            raise

        entry = self._data[label]
        entry["linestyle"] = linestyle
        entry["linewidth"] = linewidth
        if vline:
            # Vertical lines keep the raw 1-D x positions rather than the
            # (index, value) pairs built by the base class for 1-D input
            entry["data"] = d

    def add_text(
        self,
        text,
        x,
        y,
        rotation=None,
        horizontalalignment=None,
        verticalalignment=None,
        fontsize=None,
    ):
        """
        Add a free text annotation to this Plot.

        Parameters
        ----------
        text : str
            The text to display.
        x : numeric
            X position where to display the text, in axes coordinates.
        y : numeric
            Y position where to display the text, in axes coordinates.
        rotation : numeric, optional
            Rotation of the text in degrees. Defaults to 0.
        horizontalalignment : str, optional
            How the text is aligned with respect to the X position; one of
            'center' (default), 'left', 'right'.
        verticalalignment : str, optional
            How the text is aligned with respect to the Y position; one of
            'center' (default), 'top', 'bottom'.
        fontsize : numeric, optional
            Font size of the text; must be strictly positive. Defaults to
            :any:`DEFAULT_FONTSIZE`.

        Raises
        ------
        ValueError
            If an alignment is not one of the allowed values, or if the font
            size is not positive.
        """
        if horizontalalignment is None:
            horizontalalignment = "center"
        elif horizontalalignment not in ["center", "left", "right"]:
            raise ValueError(
                f"{horizontalalignment} is not a valid horizontalalignment parameter; "
                "pass one of 'left', 'center', 'right'"
            )

        if verticalalignment is None:
            verticalalignment = "center"
        elif verticalalignment not in ["center", "top", "bottom"]:
            raise ValueError(
                f"{verticalalignment} is not a valid verticalalignment parameter; "
                "pass one of 'top', 'center', 'bottom'"
            )

        if rotation is None:
            rotation = 0

        if fontsize is None:
            fontsize = DEFAULT_FONTSIZE
        elif fontsize <= 0:
            # Zero is rejected too - the error promises a positive size
            raise ValueError(
                f"{fontsize} is not a valid font size, must be positive"
            )

        self._text.append(
            {
                "text": text,
                "x": x,
                "y": y,
                "rotation": rotation,
                "horizontalalignment": horizontalalignment,
                "verticalalignment": verticalalignment,
                "fontsize": fontsize,
            }
        )

    @property
    def linestyle(self):
        """
        str: The style of the line to plot (one of :any:`LINESTYLES`, or None)

        Can be overridden on
        a per-dataset basis by specifying a different value during
        :any:`add_data`.
        """
        return self._linestyle

    @linestyle.setter
    def linestyle(self, m):
        if m is None or m in LINESTYLES:
            self._linestyle = m
        else:
            raise ValueError(
                f"{m} is not a supported linestyle; "
                f"pass one of {','.join(LINESTYLES)}, or None to "
                f"use Plot or system default"
            )

    @property
    def linewidth(self):
        """
        float: The width of the line to plot (positive number, or None)

        Can be overridden on
        a per-dataset basis by specifying a different value during
        :any:`add_data`.
        """
        return self._linewidth

    @linewidth.setter
    def linewidth(self, m):
        assert_none_or_numeric(m)
        if m is None or m > 0:
            self._linewidth = m
        else:
            raise ValueError(f"line width {m} is not a positive value or None")

    def __str__(self):
        return "LinePlot"


class ScatterPlot(Plot2DMixin, XOriginMixin, YOriginMixin, PointsPlotMixin, Plot):
    """
    This Plot displays data as a scatter plot.

    Plot-wide ``marker`` / ``markersize`` defaults can be given at
    construction time and overridden per data set in :any:`add_data`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._markersize = None
        self._marker = None
        self.markersize = kwargs.get("markersize")  # validated by the setter
        self.marker = kwargs.get("marker")  # validated by the setter
        # NOTE(review): this class defines no linestyle/linewidth properties,
        # so unless a base class provides them these values are stored
        # without validation - confirm against Plot / the mixins.
        self._linestyle = None
        self.linestyle = kwargs.get("linestyle")
        self._linewidth = None
        self.linewidth = kwargs.get("linewidth")

    def add_data(
        self, d, *args, label=None, color=None, marker=None, markersize=None, **kwargs
    ):
        """
        Add data to this Plot.

        Parameters
        ----------
        d : 1D or 2D iterable of numeric values
            The data to be added. This may take one of two formats:

            - A 1D iterable (e.g., list) of points to be plotted. These points
              will be taken to be y-values, and plotted using their position
              in the iterable (e.g., list index) as the x-value; or,
            - A 2D iterable (e.g., list of lists, or numpy array) of points to
              be plotted. In this case, `d[0]` should be the list of x-points,
              and `d[1]` the list of corresponding y-points.

            Note that passing a list of (x,y) tuples or similar is *not*
            supported at present.
        label : str
            The label to give this particular data set. This must be unique
            for this Plot object.
        color : str
            A valid color specification to plot these data with. Defaults to
            None, at which point the next color of :any:`COLORS` is assigned.
        marker : str
            A valid marker contained in :any:`MARKERS`. Defaults to None,
            at which point the Plot default value will take precedence; if that
            is not defined either, :any:`DEFAULT_MARKER` will be used.
        markersize : numeric
            A marker size for the plot; must be > 0 and is stored as a float.
            Defaults to None, at which point the Plot default value will take
            precedence; if that is not defined either,
            :any:`DEFAULT_MARKER_SIZE` will be used.

        Raises
        ------
        ValueError
            If the data is invalid, if no label is given, if the marker is not
            supported, or if the marker size is not a positive number. If a
            marker option is rejected, no data set is registered under
            ``label``.
        """
        super().add_data(d, label=label, color=color, **kwargs)

        # Validate the marker options with explicit checks (an ``assert`` would
        # be stripped under ``python -O``) and roll the new entry back on
        # failure so a half-baked definition never lingers.
        try:
            # Marker values
            if marker is not None and marker not in MARKERS:
                raise ValueError(
                    f"{marker} is not a supported marker; "
                    f"pass one of {','.join(MARKERS)}"
                )

            # Marker size
            if markersize is not None:
                try:
                    markersize = float(markersize)
                except (TypeError, ValueError):
                    raise ValueError(
                        f"Unable to convert {markersize} to a " f"numeric value"
                    ) from None
                # ``not >`` also rejects NaN
                if not (markersize > 0.0):
                    raise ValueError("Can't set a negative markersize")
        except ValueError:
            del self._data[label]  # Don't leave half-baked definition
            raise

        self._data[label]["marker"] = marker
        self._data[label]["markersize"] = markersize  # float, or None

    @property
    def markersize(self):
        """
        numeric: The markersize for the scatter plot points (> 0), or None.

        Can be overridden on
        a per-dataset basis by specifying a different value during
        :any:`add_data`.
        """
        return self._markersize

    @markersize.setter
    def markersize(self, msize):
        # Input checking
        try:
            assert_none_or_numeric(msize)
        except AssertionError:
            raise ValueError(f"Unable to convert {msize} to a numeric value")
        if msize is None or msize > 0.0:
            self._markersize = msize
        else:
            raise ValueError(f"Markersize ({msize}) must be > 0")

    @property
    def marker(self):
        """
        str: The marker for the scatter plot points (one of :any:`MARKERS`,
        or None).

        Can be overridden on
        a per-dataset basis by specifying a different value during
        :any:`add_data`.
        """
        return self._marker

    @marker.setter
    def marker(self, m):
        if m is None or m in MARKERS:
            self._marker = m
        else:
            raise ValueError(
                f"{m} is not a supported marker; "
                f"pass one of {','.join(MARKERS)}, or None to "
                f"use Plot or system default"
            )

    def __str__(self):
        return "ScatterPlot"
