# SPDX-License-Identifier: BSD-3-Clause
"""
ADARI Core mixins for handling plot axes
"""

from adari_core.utils.utils import is_iterable, value_none_or_numeric
import numpy as np

SUPPORTED_AXIS_SCALES = [
    "linear",
    "log",
]


class XRangeMixin(object):
    """
    Plot object Mixin with x range-like properties.

    Provides validated ``x_min``/``x_max`` limits, an ``x_scale``, an
    ``x_label``, major/minor tick specifications and a registry of vertical
    lines (``vert_lines``). Each of the scalar properties may be supplied as
    a keyword argument of the same name to ``__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populate the backing attributes first so the instance is in a
        # consistent state before the validating setters run below.
        self._x_min = None
        self._x_max = None
        self._x_scale = "linear"
        self._x_label = ""
        self._x_major_ticks = None
        self._x_minor_ticks = None
        self._vert_lines = {}

        self.x_min = kwargs.get("x_min", None)
        self.x_max = kwargs.get("x_max", None)
        self.x_scale = kwargs.get("x_scale", "linear")
        self.x_label = kwargs.get("x_label", "")
        self.x_major_ticks = kwargs.get("x_major_ticks", None)
        self.x_minor_ticks = kwargs.get("x_minor_ticks", None)

    @property
    def x_label(self):
        """str : The x-axis label for this Plot object."""
        return self._x_label

    @x_label.setter
    def x_label(self, label):
        # Any value is accepted and stored as its string representation.
        self._x_label = str(label)

    @property
    def x_min(self):
        """numeric : The minimum x-axis value for this Plot."""
        return self._x_min

    @x_min.setter
    def x_min(self, x):
        value_none_or_numeric(x)
        self._x_min = x

    @property
    def x_max(self):
        """numeric : The maximum x-axis value for this Plot."""
        return self._x_max

    @x_max.setter
    def x_max(self, x):
        value_none_or_numeric(x)
        self._x_max = x

    @property
    def x_scale(self):
        """str : The scale type of the x-axis.

        Allowed values are those listed in ``SUPPORTED_AXIS_SCALES``
        (currently "linear" and "log"). Values are stored lower-cased.
        """
        # NOTE: this must be a plain string literal; a ``"...".format()``
        # expression here is not a docstring and leaves ``__doc__`` as None.
        return self._x_scale

    @x_scale.setter
    def x_scale(self, s):
        try:
            scale = s.lower()
        except AttributeError:
            # Non-string input (e.g. None or a number) has no ``lower``.
            raise ValueError(
                "Please provide a valid axis scale: {}".format(
                    ", ".join(SUPPORTED_AXIS_SCALES)
                )
            ) from None
        if scale not in SUPPORTED_AXIS_SCALES:
            raise ValueError("{} is not a supported axis scale".format(s))
        self._x_scale = scale

    @property
    def x_major_ticks(self):
        """
        Iterable of numeric values, number, boolean or None.

        If it is an iterable: an array with the positions of the X major
        tick marks.
        If it is a number: X major ticks are placed at regular intervals of
        that size. If the number is zero, X major ticks are spaced
        automatically; if it is negative, X major ticks are disabled.
        If None (the default): X major ticks are spaced automatically.
        """
        return self._x_major_ticks

    @x_major_ticks.setter
    def x_major_ticks(self, v):
        # Validate every element of an iterable, or the scalar itself.
        if v is not None:
            if is_iterable(v):
                for item in v:
                    value_none_or_numeric(item)
            else:
                value_none_or_numeric(v)
        self._x_major_ticks = v

    @property
    def x_minor_ticks(self):
        """
        Iterable of numeric values, number, boolean or None.

        If it is an iterable: an array with the positions of the X minor
        tick marks.
        If it is a number: X minor ticks are placed at regular intervals of
        that size. If the number is zero, X minor ticks are spaced
        automatically; if it is negative, X minor ticks are disabled.
        If None (the default): no X minor ticks are used.
        """
        return self._x_minor_ticks

    @x_minor_ticks.setter
    def x_minor_ticks(self, v):
        # Validate every element of an iterable, or the scalar itself.
        if v is not None:
            if is_iterable(v):
                for item in v:
                    value_none_or_numeric(item)
            else:
                value_none_or_numeric(v)
        self._x_minor_ticks = v

    @property
    def vert_lines(self):
        """dict : Mapping of x position (float) to line label."""
        return self._vert_lines

    @vert_lines.setter
    def vert_lines(self, v):
        raise ValueError(
            "Please use the add_vert_line function to add "
            "a vertical line to the Plot."
        )

    def set_xlim(self, xmin=None, xmax=None):
        """Define the x-limits of this Plot.

        Note that these values represent the displayed axis bounds only; they
        do not alter the data bounds/values in any way.

        Parameters
        ----------
        xmin, xmax : numeric values (int, float, etc.)
            The minimum and maximum bounds of the x-range of this
            Plot. Either can be set to None. A single positional argument may
            be supplied instead, of the form `(xmin, xmax)`.
        """
        if xmax is None and np.iterable(xmin):
            xmin, xmax = xmin
        # Validate both bounds before assigning either, so an invalid xmax
        # does not leave a half-updated range behind.
        value_none_or_numeric(xmin)
        value_none_or_numeric(xmax)
        self.x_min = xmin
        self.x_max = xmax

    def get_xlim(self):
        """Get the x-limits of this Plot.

        Returns
        -------
        x_min, x_max : numeric
            The minimum and maximum x values of this x-range.
        """
        return self.x_min, self.x_max

    def add_vert_line(self, xpos, label=None):
        """
        Add a vertical line to this Plot.

        Parameters
        ----------
        xpos : numeric
            The position to add the line in, in *data* coordinates.
        label : str, optional
            Line label. Defaults to None.

        Raises
        ------
        ValueError
            If ``xpos`` cannot be converted to a float, or a vertical line
            already exists at that position.
        """
        # Input checking
        try:
            xpos = float(xpos)
        except (TypeError, ValueError) as err:
            raise ValueError(
                "Unable to convert {} to a coordinate".format(xpos)
            ) from err

        # Do not allow overwrites
        if xpos in self._vert_lines:
            raise ValueError(
                "There is already a vertical line assigned to "
                "{} (labelled {})".format(xpos, self._vert_lines[xpos])
            )
        self._vert_lines[xpos] = label


class YRangeMixin(object):
    """
    Plot object with y range-like properties.

    Provides validated ``y_min``/``y_max`` limits, a ``y_scale``, a
    ``y_label``, major/minor tick specifications and a registry of horizontal
    lines (``horiz_lines``). Each of the scalar properties may be supplied as
    a keyword argument of the same name to ``__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populate the backing attributes first so the instance is in a
        # consistent state before the validating setters run below.
        self._y_min = None
        self._y_max = None
        self._y_scale = "linear"
        self._y_label = ""
        self._y_major_ticks = None
        self._y_minor_ticks = None
        self._horiz_lines = {}

        self.y_min = kwargs.get("y_min", None)
        self.y_max = kwargs.get("y_max", None)
        self.y_scale = kwargs.get("y_scale", "linear")
        self.y_label = kwargs.get("y_label", "")
        self.y_major_ticks = kwargs.get("y_major_ticks", None)
        self.y_minor_ticks = kwargs.get("y_minor_ticks", None)

    @property
    def y_label(self):
        """str : The y-axis label for this Plot object."""
        return self._y_label

    @y_label.setter
    def y_label(self, label):
        # Any value is accepted and stored as its string representation.
        self._y_label = str(label)

    @property
    def y_min(self):
        """numeric : The minimum y-axis value for this Plot."""
        return self._y_min

    @y_min.setter
    def y_min(self, y):
        value_none_or_numeric(y)
        self._y_min = y

    @property
    def y_max(self):
        """numeric : The maximum y-axis value for this Plot."""
        return self._y_max

    @y_max.setter
    def y_max(self, y):
        value_none_or_numeric(y)
        self._y_max = y

    @property
    def y_scale(self):
        """str : The scale type of the y-axis.

        Allowed values are those listed in ``SUPPORTED_AXIS_SCALES``
        (currently "linear" and "log"). Values are stored lower-cased.
        """
        # NOTE: this must be a plain string literal; a ``"...".format()``
        # expression here is not a docstring and leaves ``__doc__`` as None.
        return self._y_scale

    @y_scale.setter
    def y_scale(self, s):
        try:
            scale = s.lower()
        except AttributeError:
            # Non-string input (e.g. None or a number) has no ``lower``.
            raise ValueError(
                "Please provide a valid axis scale: {}".format(
                    ", ".join(SUPPORTED_AXIS_SCALES)
                )
            ) from None
        if scale not in SUPPORTED_AXIS_SCALES:
            raise ValueError("{} is not a supported axis scale".format(s))
        self._y_scale = scale

    @property
    def y_major_ticks(self):
        """
        Iterable of numeric values, number, boolean or None.

        If it is an iterable: an array with the positions of the Y major
        tick marks.
        If it is a number: Y major ticks are placed at regular intervals of
        that size. If the number is zero, Y major ticks are spaced
        automatically; if it is negative, Y major ticks are disabled.
        If None (the default): Y major ticks are spaced automatically.
        """
        return self._y_major_ticks

    @y_major_ticks.setter
    def y_major_ticks(self, v):
        # Validate every element of an iterable, or the scalar itself.
        if v is not None:
            if is_iterable(v):
                for item in v:
                    value_none_or_numeric(item)
            else:
                value_none_or_numeric(v)
        self._y_major_ticks = v

    @property
    def y_minor_ticks(self):
        """
        Iterable of numeric values, number, boolean or None.

        If it is an iterable: an array with the positions of the Y minor
        tick marks.
        If it is a number: Y minor ticks are placed at regular intervals of
        that size. If the number is zero, Y minor ticks are spaced
        automatically; if it is negative, Y minor ticks are disabled.
        If None (the default): no Y minor ticks are used.
        """
        return self._y_minor_ticks

    @y_minor_ticks.setter
    def y_minor_ticks(self, v):
        # Validate every element of an iterable, or the scalar itself.
        if v is not None:
            if is_iterable(v):
                for item in v:
                    value_none_or_numeric(item)
            else:
                value_none_or_numeric(v)
        self._y_minor_ticks = v

    @property
    def horiz_lines(self):
        """dict : Mapping of y position (float) to line label."""
        return self._horiz_lines

    # The setter must reuse the property's name; naming it differently
    # would leave ``horiz_lines`` read-only (raising AttributeError).
    @horiz_lines.setter
    def horiz_lines(self, v):
        raise ValueError(
            "Please use the add_horiz_line function to add "
            "a horizontal line to the Plot."
        )

    # Backward-compatible alias for the property previously (mis)named
    # ``horiz_line``; behaves identically to ``horiz_lines``.
    horiz_line = horiz_lines

    def set_ylim(self, ymin=None, ymax=None):
        """Define the y limits of this Plot.

        Note that these values represent the displayed axis bounds only; they
        do not alter the data bounds/values in any way.

        Parameters
        ----------
        ymin, ymax : numeric values (int, float, etc.)
            The minimum and maximum bounds of the y-range of this
            Plot. Either can be set to None. A single positional argument may
            be supplied instead, of the form `(ymin, ymax)`.
        """
        if ymax is None and np.iterable(ymin):
            ymin, ymax = ymin
        # Validate both bounds before assigning either, so an invalid ymax
        # does not leave a half-updated range behind.
        value_none_or_numeric(ymin)
        value_none_or_numeric(ymax)
        self.y_min = ymin
        self.y_max = ymax

    def get_ylim(self):
        """Get the y limits of this plot.

        Returns
        -------
        y_min, y_max : numeric
            The minimum and maximum y values of this y range
        """
        return self.y_min, self.y_max

    def add_horiz_line(self, ypos, label=None):
        """
        Add a horizontal line to this Plot.

        Parameters
        ----------
        ypos : numeric
            The position to add the line in, in *data* coordinates.
        label : str, optional
            Line label. Defaults to None.

        Raises
        ------
        ValueError
            If ``ypos`` cannot be converted to a float, or a horizontal line
            already exists at that position.
        """
        # Input checking
        try:
            ypos = float(ypos)
        except (TypeError, ValueError) as err:
            raise ValueError(
                "Unable to convert {} to a coordinate".format(ypos)
            ) from err

        # Do not allow overwrites
        if ypos in self._horiz_lines:
            raise ValueError(
                "There is already a horizontal line assigned to "
                "{} (labelled {})".format(ypos, self._horiz_lines[ypos])
            )
        self._horiz_lines[ypos] = label


class Plot2DMixin(XRangeMixin, YRangeMixin):
    """
    Super-mixin for defining a two-axis plot.

    Combines :py:class:`XRangeMixin` and :py:class:`YRangeMixin`, and adds an
    aspect ratio, a registry of rectangles (``rects``) and a tick-visibility
    setting.

    Note this mixin does *not* include the :py:class:`XOriginMixin` or
    :py:class:`YOriginMixin`. These need to be added manually depending on
    the Plot type.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._aspect = None
        self._rects = {}
        self._tick_visibility = None
        # Route the constructor value through the setter so it is validated
        # (and numeric values normalised to float) like later assignments.
        self.aspect = kwargs.get("aspect", None)
        self.tick_visibility = kwargs.get("tick_visibility", None)

    @property
    def aspect(self):
        """
        numeric or str : The aspect ratio of this plot.

        Either None, a number with 0 < aspect < :any:`numpy.inf` (stored as a
        float), or one of the strings "auto" or "equal".
        """
        return self._aspect

    @aspect.setter
    def aspect(self, a):
        if a is None:
            self._aspect = None
            return

        # Explicit checks rather than ``assert``: assertions are removed when
        # Python runs with ``-O``, which would silently disable validation.
        if isinstance(a, str):
            if a not in ("auto", "equal"):
                raise ValueError(
                    "Aspect ratio as a string must be either auto or equal"
                )
            self._aspect = a
            return

        try:
            in_range = 0 < a < np.inf
        except TypeError as err:
            raise ValueError(
                f"Unable to interpret {a} as a numerical aspect ratio"
            ) from err
        if not in_range:
            raise ValueError(
                "Aspect ratio must be in range (0, inf) or equal or auto"
            )
        self._aspect = float(a)

    @property
    def rects(self):
        """dict : Mapping of ``(anchor, width, height)`` keys to labels."""
        return self._rects

    @rects.setter
    def rects(self, v):
        raise ValueError(
            "Please use the add_rect function to add a rectangle to the Plot."
        )

    @property
    def tick_visibility(self):
        """dictionary : Whether to draw the respective ticks and their labels."""
        return self._tick_visibility

    @tick_visibility.setter
    def tick_visibility(self, tv):
        # NOTE(review): stored without validation; the expected dict layout
        # is defined by the renderer, not here.
        self._tick_visibility = tv

    # TODO Update to allow color & line style to be passed
    def add_rect(self, anchor: (float, float), width: float, height: float, label=None):
        """
        Add a rectangle to this Plot object.

        Parameters
        ----------
        anchor : (float, float)
            The anchor point (x, y) of the rectangle, in data coordinates.
        width : float
            The width of the rectangle, in data coordinates.
        height : float
            The height of the rectangle, in data coordinates.
        label : str, optional
            The label for this rectangle (defaults to None).

        Raises
        ------
        ValueError
            If the inputs cannot be interpreted as numbers, or an identical
            rectangle has already been added.
        """
        # Input checking. LookupError covers anchors that are too short or
        # not index-addressable by position.
        try:
            xy = (float(anchor[0]), float(anchor[1]))
        except (TypeError, ValueError, LookupError) as err:
            raise ValueError(
                "Unable to interpret {} as an (x, y) position".format(anchor)
            ) from err

        try:
            w = float(width)
            h = float(height)
        except (TypeError, ValueError) as err:
            # Report the caller's original values, not partially converted ones.
            raise ValueError(
                "Unable to parse {} and {} as width & height".format(width, height)
            ) from err

        # Form a box key from the inputs
        box_key = (xy, w, h)

        # Do not allow overwrite
        if box_key in self._rects:
            raise ValueError(
                "There is already a box defined by {} "
                "named {}".format(box_key, self._rects[box_key])
            )
        self._rects[box_key] = label


class XOriginMixin(object):
    """
    This mixin is designed for Plots which take a grid of values to plot
    without explicit reference to the x coordinates of those values (e.g.,
    an image array).

    The x coordinate of grid index ``i`` is presumably
    ``x_origin + i * x_step`` — TODO confirm against the renderer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholders matching the defaults; overwritten (with validation)
        # by the setters immediately below.
        self._x_origin = 0
        self._x_step = 1
        self.x_origin = kwargs.get("x_origin", 0)
        self.x_step = kwargs.get("x_step", 1)

    @property
    def x_origin(self):
        """numeric : origin point of the x-axis.

        For Plot types where data does not consist of discrete points (e.g.,
        an ImagePlot), this origin value represents the centre of the data
        point (e.g., the centre of the first pixel in an ImagePlot)."""
        return self._x_origin

    @x_origin.setter
    def x_origin(self, x):
        value_none_or_numeric(x)
        self._x_origin = x

    @property
    def x_step(self):
        """numeric : Size of the data steps in the x-direction (> 0, or None)."""
        return self._x_step

    @x_step.setter
    def x_step(self, d):
        value_none_or_numeric(d)
        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``. ``not d > 0.0`` also rejects NaN.
        if d is not None and not d > 0.0:
            raise ValueError("Step must be greater than 0")
        self._x_step = d


class YOriginMixin(object):
    """
    This mixin is designed for Plots which take a grid of values to plot
    without explicit reference to the y coordinates of those values (e.g.,
    an image array).

    The y coordinate of grid index ``j`` is presumably
    ``y_origin + j * y_step`` — TODO confirm against the renderer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholders matching the defaults; overwritten (with validation)
        # by the setters immediately below.
        self._y_origin = 0
        self._y_step = 1
        self.y_origin = kwargs.get("y_origin", 0)
        self.y_step = kwargs.get("y_step", 1)

    @property
    def y_origin(self):
        """numeric : origin point of the y-axis.

        For Plot types where data does not consist of discrete points (e.g.,
        an ImagePlot), this origin value represents the centre of the data
        point (e.g., the centre of the first pixel in an ImagePlot)."""
        return self._y_origin

    @y_origin.setter
    def y_origin(self, y):
        value_none_or_numeric(y)
        self._y_origin = y

    @property
    def y_step(self):
        """numeric : Size of the data steps in the y-direction (> 0, or None)."""
        return self._y_step

    @y_step.setter
    def y_step(self, d):
        value_none_or_numeric(d)
        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``. ``not d > 0.0`` also rejects NaN.
        if d is not None and not d > 0.0:
            raise ValueError("Step must be greater than 0")
        self._y_step = d
