# SPDX-License-Identifier: BSD-3-Clause
from .plot import Plot
from .axes import (
    Plot2DMixin,
    XRangeMixin,
    YRangeMixin,
)
from itertools import chain
import numpy as np
import inspect
import logging

logger = logging.getLogger(__name__)

#: Mixins that determine plot compatibility. Plots can only be combined in a
#: single CombinedPlot if their classes inherit exactly the same subset of
#: these mixins (checked in ``CombinedPlot._is_data_valid``).
_COMPATIBILITY_MIXINS = {
    Plot2DMixin,
    XRangeMixin,
    YRangeMixin,
}


class CombinedPlot(
    Plot2DMixin, Plot  # 'highest-level' Mixin the Combined class needs to support
):
    """Display multiple ADARI Plot types on a single axis.

    There are some instances where multiple plot types should be shown
    together (e.g., a ScatterPlot overlaid over a LinePlot, or over an
    ImagePlot). This Plot type provides a method for generating such plots.

    The data attribute of CombinedPlot (a dict) holds multiple Plot objects
    (as keys), along with a z-order value (as values). Plots with a higher
    z-order value are plotted over those with a lower z-order value.

    Axis limits (``x_min``, ``x_max``, ``y_min``, ``y_max``) and labels
    (``title``, ``x_label``, ``y_label``) that are not set explicitly on the
    CombinedPlot are derived from the attached plots.
    """

    def __str__(self):
        return "CombinedPlot"

    def __init__(self, *args, **kwargs):
        """Create an empty CombinedPlot.

        Raises
        ------
        ValueError
            If data is passed positionally or via the ``data`` keyword;
            plots must be attached afterwards with :meth:`add_data`.
        """
        if len(args) > 0 or kwargs.get("data") is not None:
            raise ValueError(
                "{} does not support instantiating with data; "
                "add data using add_data "
                "later".format(self.__str__())
            )
        # args is guaranteed empty at this point.
        super().__init__(**kwargs)
        # Maps attached Plot -> z-order (int). Overrides any value the parent
        # initialiser may have stored.
        self._data = {}

    # ------------------------------------------------------------------
    # Private helpers shared by the axis-limit and label properties
    # ------------------------------------------------------------------
    def _require_axis(self, axis_mixin, axis_name):
        """Raise unless the attached plots provide the requested axis.

        Parameters
        ----------
        axis_mixin : type
            The range mixin for the axis (``XRangeMixin`` or ``YRangeMixin``).
        axis_name : str
            Axis letter (``"x"`` or ``"y"``), used in the error message.

        Raises
        ------
        RuntimeError
            If no attached plot derives from ``axis_mixin`` or
            ``Plot2DMixin`` (which includes the case of no attached plots).
        """
        existing_bases = self._set_of_existing_bases()
        if axis_mixin in existing_bases or Plot2DMixin in existing_bases:
            return
        logger.debug(existing_bases)
        raise RuntimeError(
            "The plots currently assigned to this "
            "CombinedPlot do not have {}-axes.".format(axis_name)
        )

    def _combined_limit(self, attr, own_value, is_better):
        """Combine an axis limit across the attached plots.

        Parameters
        ----------
        attr : str
            Name of the limit attribute to read from each attached plot.
        own_value : numeric or None
            The value set on this CombinedPlot itself; returned if not None.
        is_better : callable
            ``is_better(candidate, current)`` returns True when ``candidate``
            should replace ``current`` (``<`` for a lower bound, ``>`` for an
            upper bound).

        Returns
        -------
        numeric or None
            ``own_value`` if set, else the best non-None value among the
            attached plots, else None.
        """
        if own_value is not None:
            return own_value
        best = None
        for plot in self._data.keys():
            candidate = getattr(plot, attr)
            if candidate is None:
                continue
            # Explicit None check (rather than catching TypeError) so that
            # genuinely incomparable values raise instead of silently
            # overwriting the current best.
            if best is None or is_better(candidate, best):
                best = candidate
        return best

    def _label_from_plots(self, attr, own_label):
        """Return ``own_label`` if non-empty, else the first non-empty label.

        The attached plots are searched in descending z-order. Plots with
        equal z-order keep their insertion order because ``sorted`` is
        stable. If no plot has a non-empty label, ``own_label`` is returned.
        """
        if own_label != "":
            return own_label
        by_zorder_desc = sorted(self.data.items(), key=lambda item: -item[1])
        for plot, _zorder in by_zorder_desc:
            label = getattr(plot, attr)
            if label != "":
                return label
        return own_label

    # Overwrite the x and y axis commands to account for having
    # multiple plots in the system
    @XRangeMixin.x_min.getter
    def x_min(self):
        self._require_axis(XRangeMixin, "x")
        return self._combined_limit(
            "x_min", self._x_min, lambda new, cur: new < cur
        )

    @XRangeMixin.x_max.getter
    def x_max(self):
        self._require_axis(XRangeMixin, "x")
        return self._combined_limit(
            "x_max", self._x_max, lambda new, cur: new > cur
        )

    @YRangeMixin.y_min.getter
    def y_min(self):
        self._require_axis(YRangeMixin, "y")
        return self._combined_limit(
            "y_min", self._y_min, lambda new, cur: new < cur
        )

    @YRangeMixin.y_max.getter
    def y_max(self):
        self._require_axis(YRangeMixin, "y")
        return self._combined_limit(
            "y_max", self._y_max, lambda new, cur: new > cur
        )

    @Plot.title.getter
    def title(self):
        return self._label_from_plots("title", self._title)

    @XRangeMixin.x_label.getter
    def x_label(self):
        return self._label_from_plots("x_label", self._x_label)

    @YRangeMixin.y_label.getter
    def y_label(self):
        return self._label_from_plots("y_label", self._y_label)

    # Super get_xlim and get_ylim so we can change the docstring
    def get_xlim(self):
        """
        Get the x-axis bounds for this CombinedPlot.

        The behaviour of get_xlim is different for a CombinedPlot compared
        to other Plot types:

        - If x_min (or x_max) is set on the CombinedPlot object itself, this
          value will be returned for the lower (or upper) x bound.
        - If x_min (or x_max) is not set on the CombinedPlot object, the
          minimum x_min (or maximum x_max) value of the assigned Plot objects
          will be returned instead.

        Returns
        -------
        x_min, x_max : numeric
            The x-bounds of the CombinedPlot object.
        """
        return super().get_xlim()

    def get_ylim(self):
        """
        Get the y-axis bounds for this CombinedPlot.

        The behaviour of get_ylim is different for a CombinedPlot compared
        to other Plot types:

        - If y_min (or y_max) is set on the CombinedPlot object itself, this
          value will be returned for the lower (or upper) y bound.
        - If y_min (or y_max) is not set on the CombinedPlot object, the
          minimum y_min (or maximum y_max) value of the assigned Plot objects
          will be returned instead.

        Returns
        -------
        y_min, y_max : numeric
            The y-bounds of the CombinedPlot object.
        """
        return super().get_ylim()

    def _set_of_existing_bases(self):
        """Return the union of the MROs of all attached plots' classes."""
        return set(
            chain.from_iterable(
                inspect.getmro(plot.__class__) for plot in self.data.keys()
            )
        )

    def _incompatibility_reason(self, data):
        """Return why ``data`` cannot be attached, or None if it can."""
        if not isinstance(data, Plot):
            return "Input data for CombinedPlot is " "ADARI Plot objects"
        if isinstance(data, CombinedPlot):
            return "Layering multiple " "CombinedPlots is not " "supported"
        # The new Plot must share the same compatibility mixins as the plots
        # already attached (nothing to compare against when empty).
        existing_bases = self._set_of_existing_bases()
        if len(existing_bases) > 0:
            input_bases = set(inspect.getmro(data.__class__))
            if input_bases.intersection(
                _COMPATIBILITY_MIXINS
            ) != existing_bases.intersection(_COMPATIBILITY_MIXINS):
                return (
                    "{} is not plot-compatible with the Plot types already "
                    "assigned to this CombinedPlot".format(data.__str__())
                )
        return None

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """Check whether ``data`` can be attached to this CombinedPlot.

        The parent's check runs first; if it returns a non-None verdict that
        verdict is returned unchanged. Otherwise ``data`` must be a
        non-Combined ADARI Plot compatible with the plots already attached.

        Returns
        -------
        bool
            True if valid; False if invalid and ``error_on_fail`` is False.

        Raises
        ------
        ValueError
            If invalid and ``error_on_fail`` is True.
        """
        s = super()._is_data_valid(
            data, fail_on_none=fail_on_none, error_on_fail=error_on_fail
        )
        if s is not None:
            return s

        # Explicit checks rather than `assert`, which is stripped under -O.
        reason = self._incompatibility_reason(data)
        if reason is None:
            return True
        if error_on_fail:
            raise ValueError(reason)
        return False

    def add_data(self, data, z_order=None, **kwargs):
        """
        Attach a Plot object to this CombinedPlot.

        Parameters
        ----------
        data : child of :any:`adari_core.plots.plot.Plot`
            The Plot object to attach to the CombinedPlot
        z_order: int, optional
            The z_order of the attached Plot. Plots with a higher z-order value
            are plotted above those with a lower z-order value. Default is None,
            at which point the plot will be assigned a z-order one higher than
            the highest z-order currently assigned (1 if no plots are attached).
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        zorder : int
            The z-order assigned for the attached Plot

        Raises
        ------
        ValueError
            When the Plot attempting to be attached is invalid or incompatible,
            or is already assigned to this CombinedPlot
        """
        self._is_data_valid(data, fail_on_none=True, error_on_fail=True)
        if data in self._data:
            raise ValueError(
                "This {} is already assigned to this {}.".format(
                    data.__str__(), self.__str__()
                )
            )
        if z_order is None:
            # Builtin max keeps the result a plain int (np.max gave np.int64).
            z_order = max(self._data.values(), default=0) + 1
        self._data[data] = int(z_order)
        return self._data[data]

    def remove_data(self, data):
        """
        Attempt to remove a Plot from the CombinedPlot.

        Parameters
        ----------
        data : Child of :any:`adari_core.plots.plot.Plot` object
            The Plot object to attempt to strip from the CombinedPlot

        Raises
        ------
        ValueError
            If ``data`` is not a Plot, or is not attached to this CombinedPlot.
        """
        if not isinstance(data, Plot):
            raise ValueError("Input data must be a Plot")
        try:
            del self._data[data]
        except KeyError:
            # Call __str__ so the message contains the names, not bound methods.
            raise ValueError(
                "The specified {} is not attached "
                "to this {}".format(data.__str__(), self.__str__())
            ) from None
