# SPDX-License-Identifier: BSD-3-Clause
"""This module specifies the CollapsePlot type for ADARI.

Collapse plots inherit most functionality from :any:`LinePlot` in
:any:`adari_core.plots.points`, but take a 2D image as data instead of x- and
y-points.
"""
from .points import LinePlot, COLORS
from .images import ImagePlot

import numpy as np


class CollapsePlot(LinePlot):
    """A line plot of a 2D image averaged along one axis.

    The image is reduced with :func:`numpy.nanmean` (NaNs ignored) along the
    axis selected by ``collapse_ax``. The resulting 1D profile is stored as a
    line in ``self._data``.
    """

    def __str__(self):
        return "CollapsePlot"

    def __init__(self, master_data=None, collapse_ax=None, *args, **kwargs):
        """
        Create a CollapsePlot, optionally collapsing an initial image.

        Parameters
        ----------
        master_data : 2D iterable of numeric values, or an :any:`ImagePlot`, optional
            Image to collapse immediately. If omitted, data can be added later
            with :meth:`add_data`.
        collapse_ax : str, either 'x' or 'y'
            Required. The axis to collapse ``master_data`` along. It is also
            used as the default x-axis label.
        *args, **kwargs
            Forwarded to :any:`LinePlot`. ``color`` (colour of the initial
            line) and ``x_label`` (overrides the default label) are also read
            here.

        Raises
        ------
        ValueError
            If ``collapse_ax`` is None, or if ``master_data`` is invalid
            (see :meth:`add_data`).
        """
        super().__init__(*args, **kwargs)

        # collapse_ax is required: master_data cannot be collapsed without it
        # (it would otherwise be silently ignored), and it supplies the
        # default x-axis label.
        if collapse_ax is None:
            raise ValueError("`collapse_ax` must be given (either 'x' or 'y')")

        if master_data is not None:
            self.add_data(master_data, collapse_ax, color=kwargs.get("color", None))

        # Default the x-axis label to the collapse axis; an explicit
        # ``x_label`` keyword argument takes precedence.
        self.x_label = kwargs.get("x_label", collapse_ax)

    def add_data(self, master_data, collapse_ax, *args, color=None, **kwargs):
        """
        Add data to this CollapsePlot object.

        Parameters
        ----------
        master_data : 2D iterable of numeric values, or an :any:`ImagePlot`
            The input data the CollapsePlot is built from. If an
            :any:`ImagePlot` is supplied, its ``data`` attribute is collapsed.
            The stored line is a newly allocated array, so no link between
            the two Plot objects is made or retained.
        collapse_ax : str, either 'x' or 'y'
            The axis along which to collapse the image data. 'x' averages
            over array axis 1 and 'y' over array axis 0. NaN values are
            ignored.
        color : str, optional
            A colour description for the plotted line. Defaults to None, in
            which case an entry of ``COLORS`` is chosen, cycling by the number
            of lines already stored.
        *args, **kwargs
            Accepted for interface compatibility; unused.

        Raises
        ------
        ValueError
            If ``master_data`` is not a 2D numeric array, or if
            ``collapse_ax`` is not 'x' or 'y'.

        Notes
        -----
        Lines are keyed by collapse direction ("Avg of cols" for 'x',
        "Avg of rows" for 'y'). Adding data again with the same
        ``collapse_ax`` therefore replaces the previously stored line.
        """
        # Accept an ImagePlot in place of raw data. np.nanmean never modifies
        # its input and always returns a new array, so no defensive copy of
        # the (possibly large) image is needed.
        if isinstance(master_data, ImagePlot):
            image = master_data.data
        else:
            image = master_data

        self._is_data_valid(image, fail_on_none=True, error_on_fail=True)

        if collapse_ax == "x":
            axis = 1
        elif collapse_ax == "y":
            axis = 0
        else:
            raise ValueError("collapse_ax must be 'x' or 'y'")
        collapsed_data = np.nanmean(image, axis=axis)

        if color is None:
            color = COLORS[len(self._data) % len(COLORS)]

        data_label = "Avg of cols" if collapse_ax == "x" else "Avg of rows"
        self._data[data_label] = {"data": collapsed_data, "color": color}

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """
        Check that ``data`` is a 2D array with a numeric dtype, like an image.

        Parameters
        ----------
        data : object
            Candidate image data; converted with :func:`numpy.asarray`.
        fail_on_none : bool, optional
            If True (default), ``None`` is treated as invalid. Otherwise
            ``None`` is accepted.
        error_on_fail : bool, optional
            If True (default), raise on invalid data instead of returning
            False.

        Returns
        -------
        bool
            True if the data is valid. False if it is invalid and
            ``error_on_fail`` is False.

        Raises
        ------
        ValueError
            If the data is invalid and ``error_on_fail`` is True.
        """

        def _fail(message):
            # Single place that decides between raising and returning False.
            if error_on_fail:
                raise ValueError(message)
            return False

        if data is None:
            if fail_on_none:
                return _fail("Image array must not be None")
            return True

        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would silently disable validation.
        try:
            arr = np.asarray(data)
        except (TypeError, ValueError) as err:
            # e.g. ragged nested sequences that numpy cannot make rectangular
            return _fail("Image array could not be converted: {}".format(err))

        if arr.ndim != 2:
            return _fail(
                "Image array must be 2-dimensional (ndim={})".format(arr.ndim)
            )
        if not np.issubdtype(arr.dtype, np.number):
            return _fail("Image array must have a numeric data type")
        return True
