# SPDX-License-Identifier: BSD-3-Clause
"""
This module defines the ADARI Core Panel object.
"""

import numpy as np

from .plot import Plot
from adari_core.utils.utils import assert_numeric


def verify_plot_position(p, pos):
    """
    Get assurances that position pos does exist in the given panel p

    Parameters
    ----------
    p : :obj:`adari_core.plot.panel.Panel`
        The Panel to be tested.
    pos : 2-tuple or 2-element list
        The (x, y) position to be verified. Each element must be convertible
        with ``int()``.

    Raises
    ------
    ValueError
        If `p` is not a Panel, if `pos` is not a two-element tuple/list of
        integer-like values, or if the position `pos` is not a valid position
        in Panel `p`.
    """
    # Explicit checks (rather than ``assert``) so that validation still takes
    # place when Python runs with optimizations enabled (``-O``).
    if not isinstance(p, Panel):
        raise ValueError("Please provide a Panel object")
    bad_input = "Please provide a two-tuple integer position"
    if not isinstance(pos, (tuple, list)):
        raise ValueError(bad_input)

    # Check x before y, so the first offending coordinate is the one reported
    for axis, (label, size) in enumerate((("x", p.x), ("y", p.y))):
        try:
            value = int(pos[axis])
        except (IndexError, TypeError, ValueError):
            # Too few elements, or an element that is not integer-like
            raise ValueError(bad_input) from None
        if not 0 <= value < size:
            raise ValueError(
                "{}-position {} invalid for this Panel".format(label, pos[axis])
            )


def verify_plot_position_iter(p, pos_list):
    """
    Get assurances that a list of positions pos_list exist in Panel p

    Parameters
    ----------
    p : :obj:`adari_core.plot.panel.Panel`
        The Panel to be tested.
    pos_list : list of 2-tuples
        The list of (x, y) positions, each of which is verified in order.

    Raises
    ------
    ValueError
        If `p` is not a Panel, if `pos_list` is not a list, or if any of the
        positions in `pos_list` is not a valid position in Panel `p`. The
        message identifies the (1-based) index of the offending position.
    """
    # Explicit checks (rather than ``assert``) so they survive ``python -O``
    if not isinstance(p, Panel):
        raise ValueError("Please provide a Panel object")
    if not isinstance(pos_list, list):
        raise ValueError(
            "Please provide a list of "
            "two-tuple integer positions, "
            "not {}".format(pos_list)
        )

    n_positions = len(pos_list)
    for index, pos in enumerate(pos_list, start=1):
        try:
            verify_plot_position(p, pos)
        except ValueError as e:
            raise ValueError(
                "{} - input position {} of {}".format(e, index, n_positions)
            ) from e


class Panel(object):
    """The root Panel class.

    A :class:`Panel` is a display of one or more
    :class:`adari_core.plots.plot.Plot` objects.

    Positions on the Panel are 0-indexed, starting at the top-left corner
    of the plotting array, e.g., for a 2x2 grid::

        +-----------+
        | 0,0 | 1,0 |
        +-----+-----+
        | 0,1 | 1,1 |
        +-----------+

    Every assigned plot can also be given an x-extent and y-extent, allowing
    Plots to span multiple grid positions.

    The user may specify the relative width/height of columns/rows using the
    width_ratios and height_ratios arguments. By convention, the 'standard'
    relative width/height is denoted as 1. If the Panel is enlarged after
    creation by altering the x/y attribute, values of 1 will be appended
    to any existing width_ratios/height_ratios.

    Parameters
    ----------
    x : int, optional
        The number of panel positions in the x-direction. Defaults to 1.
    y : int, optional
        The number of panel positions in the y-direction. Defaults to 1.
    title : str, optional
        The overarching title to display for this Panel. Defaults to None.
    width_ratios : array-like of length x, optional
        The relative width of the Panel columns. Column ``i`` gets a relative
        width of ``width_ratios[i] / sum(width_ratios)``. Defaults to None,
        which means every column gets the same width (and is the equivalent
        of specifying ``width_ratios=x * [1]``).
    height_ratios : array-like of length y, optional
        The relative height of the Panel rows. Row ``i`` gets a relative
        height of ``height_ratios[i] / sum(height_ratios)``. Defaults to None,
        which means every row gets the same height (and is the equivalent
        of specifying ``height_ratios=y * [1]``).
    x_stretch : float, optional
        Defines how much the overall Panel should be 'stretched' in the
        x-direction from the default (which makes square individual plots).
        Defaults to 1 (i.e., no stretch).
    y_stretch : float, optional
        Defines how much the overall Panel should be 'stretched' in the
        y-direction from the default (which makes square individual plots).
        Defaults to 1 (i.e., no stretch).
    right_subplot : float, optional
        Defines the position of the right edge of the subplots,
        as a fraction of the figure width. Defaults to None, in which case
        the value from figure.subplot.right in rcParams is used.
    """

    def __init__(
        self,
        x=1,
        y=1,
        title=None,
        width_ratios=None,
        height_ratios=None,
        x_stretch=1,
        y_stretch=1,
        right_subplot=None,
    ):
        # Start from a minimal 1x1 grid; the requested dimensions and options
        # are then applied through the property setters so that validation
        # lives in one place.
        self._x = 1
        self._y = 1
        self.title = title
        self._width_ratios = None
        self._height_ratios = None
        self._x_stretch = 1
        self._y_stretch = 1
        self._right_subplot = None
        self._plots = np.empty((1, 1), dtype=object)
        # Each entry is a (y-extent, x-extent) tuple, or None if unassigned
        self._extents = np.empty((1, 1), dtype=object)
        self._plots_linked_in_x = []
        self._plots_linked_in_y = []

        # Input check
        if not isinstance(x, int):
            raise ValueError("x size must be an integer")
        if not isinstance(y, int):
            raise ValueError("y size must be an integer")

        if x != 1:
            try:
                self.x = x
            except ValueError:
                raise ValueError("x size {} invalid".format(x))
        if y != 1:
            try:
                self.y = y
            except ValueError:
                raise ValueError("y size {} invalid".format(y))

        # Height and width ratios set once x and y are fixed
        self.width_ratios = width_ratios
        self.height_ratios = height_ratios

        self.x_stretch = x_stretch
        self.y_stretch = y_stretch

        self.right_subplot = right_subplot

    def __str__(self):
        return "Panel"

    @staticmethod
    def _resize_ratios(ratios, n):
        """Return ``ratios`` as a new list truncated, or padded with 1s, to length n."""
        resized = list(ratios)[:n]
        resized += [1] * (n - len(resized))
        return resized

    @staticmethod
    def _checked_ratios(rats, n, label, dimn):
        """Validate ``rats`` as ``n`` positive numbers and return them as a list.

        ``label`` ("Width"/"Height") and ``dimn`` ("x"/"y") are only used to
        build the error messages.
        """
        try:
            values = list(rats)
        except TypeError:
            raise ValueError(f"{label} ratios must be an array-like of numbers")
        if len(values) != n:
            raise ValueError(
                f"{label} ratios (len {len(values)}) " f"must have length {dimn} ({n})"
            )
        for r in values:
            try:
                assert_numeric(r)
            except AssertionError as e:
                raise ValueError(str(e))
            # ``not r > 0`` (rather than ``r <= 0``) also rejects NaN
            if not r > 0:
                raise ValueError(f"{label} ratio value ({r}) must be > 0")
        return values

    @staticmethod
    def _positive_float(d, msg):
        """Convert ``d`` to float, raising ValueError(msg) unless it is > 0."""
        try:
            value = float(d)
        except TypeError as e:
            raise ValueError(str(e))
        # ``not value > 0`` also rejects NaN
        if not value > 0:
            raise ValueError(msg)
        return value

    def alter_dimn(self, s):
        """Change the dimensions of the Plots grid for this Panel.

        If the requested dimension change shrinks the Panel, the top-left
        Plots are preserved; any Plots that are not within the new bounds of
        the Panel are discarded, extents are clipped to the new bounds, and
        out-of-bounds positions are dropped from all plot linkages. If the
        requested dimension change expands the Panel, additional empty
        positions are created in the bottom-right of the Panel.

        Parameters
        ----------
        s : (int, int)
            The shape to change the array to. NOTE: This follows the numpy
            convention of (y, x) for coordinate tuples.

        Raises
        ------
        ValueError
            If ``s`` is not a 2-sequence of integers, or either dimension
            is not > 0.
        """
        # Sanity check the inputs
        bad_shape = (
            "Input shape must be a 2-tuple or similar, "
            "and input values must be ints"
        )
        try:
            y_new, x_new = s[0], s[1]
        except (IndexError, KeyError, TypeError):
            raise ValueError(bad_shape)
        if not all(isinstance(v, (int, np.integer)) for v in (y_new, x_new)):
            raise ValueError(bad_shape)
        if y_new <= 0 or x_new <= 0:
            raise ValueError(
                "A shape of {} is invalid - dimensions " "must be > 0".format(s)
            )
        y_new, x_new = int(y_new), int(x_new)

        # We now need to re-compute/reset the following internals:
        # ._plots, ._extents, ._plots_linked_in_x/_y, ._x/_y,
        # and the width & height ratios.

        # Copy over the region common to the old and new grids
        y_keep, x_keep = min(self._y, y_new), min(self._x, x_new)
        new_plots = np.empty((y_new, x_new), dtype=object)
        new_extents = np.empty((y_new, x_new), dtype=object)
        new_plots[:y_keep, :x_keep] = self._plots[:y_keep, :x_keep]
        for (yy, xx), ext in np.ndenumerate(self._extents[:y_keep, :x_keep]):
            if ext is not None:
                # Clip the (y-extent, x-extent) so the Plot stays in bounds
                new_extents[yy, xx] = (min(ext[0], y_new - yy), min(ext[1], x_new - xx))
        self._plots = new_plots
        self._extents = new_extents

        # Remove any plot linkages that no longer exist. Linkage coordinates
        # are (x, y). Lists are edited in place because link_plots_in_xy
        # stores the same list object in both the x and y linkage lists.
        for linkages in (self._plots_linked_in_x, self._plots_linked_in_y):
            for link_set in linkages:
                link_set[:] = [
                    c for c in link_set if int(c[0]) < x_new and int(c[1]) < y_new
                ]

        # Update the width/height ratios: trim when shrinking, pad with 1s
        # when expanding, no-op if the ratios are None.
        if self._width_ratios is not None and x_new != self._x:
            self._width_ratios = self._resize_ratios(self._width_ratios, x_new)
        if self._height_ratios is not None and y_new != self._y:
            self._height_ratios = self._resize_ratios(self._height_ratios, y_new)

        self._x = x_new
        self._y = y_new

    @property
    def title(self):
        """
        str or None : The overarching title to display for this Panel.
        """
        return self._title

    @title.setter
    def title(self, t):
        self._title = t

    @property
    def x(self):
        """
        int : The size of the Panel grid in x.
        """
        return self._x

    @x.setter
    def x(self, xv):
        self.alter_dimn((self._y, xv))

    @property
    def y(self):
        """
        int : The size of the Panel grid in y.
        """
        return self._y

    @y.setter
    def y(self, yv):
        self.alter_dimn((yv, self._x))

    @property
    def width_ratios(self):
        """
        list or None : Relative column widths (length x), or None for equal
        widths.
        """
        return self._width_ratios

    @width_ratios.setter
    def width_ratios(self, rats):
        if rats is None:
            self._width_ratios = None
            return
        # Stored as a list so that later resizes can append/trim safely
        self._width_ratios = self._checked_ratios(rats, self.x, "Width", "x")

    @property
    def height_ratios(self):
        """
        list or None : Relative row heights (length y), or None for equal
        heights.
        """
        return self._height_ratios

    @height_ratios.setter
    def height_ratios(self, rats):
        if rats is None:
            self._height_ratios = None
            return
        # Stored as a list so that later resizes can append/trim safely
        self._height_ratios = self._checked_ratios(rats, self.y, "Height", "y")

    @property
    def x_stretch(self):
        """float : Stretch factor of the overall Panel in x (> 0)."""
        return self._x_stretch

    @x_stretch.setter
    def x_stretch(self, d):
        self._x_stretch = self._positive_float(d, "Stretch factor must be > 0")

    @property
    def y_stretch(self):
        """float : Stretch factor of the overall Panel in y (> 0)."""
        return self._y_stretch

    @y_stretch.setter
    def y_stretch(self, d):
        self._y_stretch = self._positive_float(d, "Stretch factor must be > 0")

    @property
    def right_subplot(self):
        """float or None : Right edge of the subplots, as a figure fraction."""
        return self._right_subplot

    @right_subplot.setter
    def right_subplot(self, d):
        if d is None:
            self._right_subplot = None
            return
        self._right_subplot = self._positive_float(
            d, "Right value for subplot layout must be > 0"
        )

    # TODO Need setter and getter methods for the plot array

    def assign_plot(self, plt, xpos, ypos, xext=1, yext=1):
        """Assign Plot object to Panel at given position.

        Parameters
        ----------
        plt : Child object of :class:`adari_core.plots.plot.Plot`
            The Plot to be assigned to the Panel.
        xpos : int
            The x-position to place the Plot at.
        ypos : int
            The y-position to place the Plot at.
        xext, yext : optional, int
            The extent of the Plot in x and y, respectively. Defaults to 1
            (i.e., Plot is limited to only the assigned position). The Plot
            occupies columns ``xpos .. xpos + xext - 1`` and rows
            ``ypos .. ypos + yext - 1``, all of which must lie on the Panel.

        Raises
        ------
        ValueError
            If ``plt`` is not an instance of
            :class:`adari_core.plots.plot.Plot`, or if the position/extent
            passed is not valid for this Panel instance.
        """
        # Input checking
        if not isinstance(plt, Plot):
            raise ValueError("plt must be an instance of Plot (or a child of)")
        if xpos < 0 or xpos >= self._x:
            raise ValueError("xpos must be between 0 and {}".format(self._x - 1))
        if ypos < 0 or ypos >= self._y:
            raise ValueError("ypos must be between 0 and {}".format(self._y - 1))
        # The last occupied column is xpos + xext - 1, which must be < self._x
        if xext < 1 or (xpos + xext) > self._x:
            raise ValueError(
                "xext values takes this Plot beyond the bounds " "of the Panel"
            )
        if yext < 1 or (ypos + yext) > self._y:
            raise ValueError(
                "yext values takes this Plot beyond the bounds " "of the Panel"
            )

        self._plots[ypos, xpos] = plt
        self._extents[ypos, xpos] = (yext, xext)

    def retrieve(self, x, y):
        """
        Retrieve a reference to a Plot on the Panel without removing it.

        Parameters
        ----------
        x, y : int
            The indices of the Plot position to return

        Returns
        -------
        child of Plot : The assigned Plot, or None if no plot was assigned

        Raises
        ------
        ValueError
            If ``x``/``y`` are not integer-like, or are outside the Panel
            (negative indices are rejected rather than wrapping around).
        """
        try:
            x = int(x)
            y = int(y)
        except TypeError as e:
            raise ValueError(str(e))

        if not (0 <= x < self._x and 0 <= y < self._y):
            raise ValueError(f"{x}, {y} is not a valid position on this Panel")

        return self._plots[y, x]

    def pop(self, x, y):
        """
        'Pop' (i.e., remove & return) the Plot at the given position

        Any linkages referring to this position are removed as well.

        Parameters
        ----------
        x, y : int
            The indices of the Plot position to return

        Returns
        -------
        child of Plot : The assigned Plot, or None if no plot was assigned

        Raises
        ------
        ValueError
            If ``x``/``y`` do not form a valid position on this Panel.
        """
        pl = self.retrieve(x, y)
        x, y = int(x), int(y)
        # Remove the plot (and any linkages) from the Panel
        self._plots[y, x] = None
        self._extents[y, x] = None
        # Filter in place: a link set may not contain this position at all,
        # and link_plots_in_xy shares one list between the x and y linkages.
        for linkages in (self._plots_linked_in_x, self._plots_linked_in_y):
            for link_set in linkages:
                link_set[:] = [
                    c for c in link_set if (int(c[0]), int(c[1])) != (x, y)
                ]
        return pl

    def get_plot_locations(self, plot):
        """
        Return a list of the locations of a particular Plot object in the
        panel.

        The returned locations are only the directly assigned locations, and
        do not take into account x- and y-extents.

        Parameters
        ----------
        plot : Child object of :class:`adari_core.plots.plot.Plot`
            Plot object to locate within the Panel

        Returns
        -------
        list of (xpos, ypos) : list of (int, int)
            The positions of the passed ``plot`` on this Panel, in row-major
            order. Returns an empty list if the Plot is not assigned to the
            Panel.
        """
        if not isinstance(plot, Plot):
            raise ValueError(
                "You must pass a subclass of Plot to " "get_plot_locations"
            )
        # ndenumerate yields numpy-style (y, x) indices; report them as (x, y)
        return [
            (xx, yy) for (yy, xx), p in np.ndenumerate(self._plots) if p == plot
        ]

    def get_assigned_locations(self):
        """
        Return all locations with a Plot object assigned.

        Note this returns only the locations that Plots are assigned to,
        and does not consider grid positions that a plot takes up thanks
        to Plot's x- and y-extent values.

        Returns
        -------
        positions : List of tuples
            Note this list follows the (y, x) numpy coordinate convention,
            in row-major order.
        """
        return [pos for pos, p in np.ndenumerate(self._plots) if p is not None]

    def get_taken_locations(self):
        """
        Return an array of Booleans denoting which plot positions are currently
        assigned (including those assigned to a multi-cell plot).

        Returns
        -------
        positions: numpy array of Booleans
            The positions that are currently assigned to a plot (True),
            and those that are not (False). Indexed as [y, x].
        """
        positions = np.zeros_like(self._plots, dtype=bool)
        for yy, xx in self.get_assigned_locations():
            y_ext, x_ext = self._extents[yy, xx]
            positions[yy : yy + y_ext, xx : xx + x_ext] = True
        return positions

    def link_plots_in_x(self, *coords):
        """
        Cause the x-axis of the Plots in the given positions to be linked.

        Parameters
        ----------
        coords : tuples
            The (x, y) Panel coordinates for which the assigned Plot
            objects should have their x-axes linked.

        Raises
        ------
        ValueError
            If any of the coordinates is not a valid position on this Panel.
        """
        # coords arrives as a tuple of (x, y) positions
        coords = list(coords)
        verify_plot_position_iter(self, coords)
        self._plots_linked_in_x.append(coords)

    def link_plots_in_y(self, *coords):
        """
        Cause the y-axis of the Plots in the given positions to be linked.

        Parameters
        ----------
        coords : tuples
            The (x, y) Panel coordinates for which the assigned Plot
            objects should have their y-axes linked.

        Raises
        ------
        ValueError
            If any of the coordinates is not a valid position on this Panel.
        """
        coords = list(coords)
        verify_plot_position_iter(self, coords)
        self._plots_linked_in_y.append(coords)

    def link_plots_in_xy(self, *coords):
        """
        Link both the x- and y-axes of the Plots in the given positions.

        Parameters
        ----------
        coords : tuples
            The (x, y) Panel coordinates for which the assigned Plot
            objects should have both their x- and y-axes linked.

        Raises
        ------
        ValueError
            If any of the coordinates is not a valid position on this Panel.
        """
        coords = list(coords)
        verify_plot_position_iter(self, coords)
        # The same list object is registered for both axes
        self._plots_linked_in_x.append(coords)
        self._plots_linked_in_y.append(coords)