# SPDX-License-Identifier: BSD-3-Clause
"""Clipping for ADARI Core plotting.

This module allows for common data scaling variables
to be available for use in different classes (e.g. Colormaps, Histograms).
"""
import numpy as np
from astropy.visualization import PercentileInterval, ZScaleInterval

from adari_core.utils.utils import assert_none_or_numeric, value_none_or_numeric

# -- HELPER FUNCTIONS -- #
# TODO: consider moving these helpers into ClippingMixin.


def _percentile(image, percentiles):
    """
    Compute data percentile values from image data.

    Parameters
    ----------
    image : :any:`numpy.array`-like
        The input image to analyze.
    percentiles : iterable of numeric values
        The percentile values to compute for `image`.

    Returns
    -------
    percentile_values : list of numeric values
        The requested image percentile values. Note that there is a one-to-one
        correspondence between the requested percentiles and the output values,
        i.e., if the requested percentiles are not ordered, the output
        values will be shown in that order.

    """
    sorted_im = np.sort(image.ravel())
    try:
        return [sorted_im[int(v / 100.0 * image.size)] for v in percentiles]
    except IndexError:
        raise ValueError("Unable to determined percentiles for empty data set")


def _clean_data(data):
    """
    Helper function to clean data by stripping out bad values.

    Parameters
    ----------
    data : iterable of values
        Data to be cleaned

    Returns
    -------
    data_clean : `np.array`
        Cleaned data array, with invalid values removed/masked.

    Raises
    ------
    ValueError
        If input data is empty, or otherwise unbale to be filtered.
    """
    try:
        data_clean = np.array(data)[np.isfinite(data)]
        data_clean = data_clean[~np.isnan(data_clean)]
        assert not np.any(
            np.asarray(data_clean.shape) == 0
        ), "Unable to computer bounds for an empty array"
    except TypeError:
        raise ValueError("The data seems to be empty or None")
    except AssertionError as e:
        raise ValueError(str(e))
    return data_clean


# -- SCALING FUNCTIONS -- #
def clipping_auto(data, percentiles=(25, 50, 75), **kwargs):
    """
    Calculate clipping points based on percentile points.

    The bounds extrapolate ten times the distance between the middle
    percentile value and each of the outer ones, i.e.
    ``v_min = mid - 10 * (mid - low)`` and ``v_max = mid + 10 * (high - mid)``.

    Parameters
    ----------
    data : :any:`np.array` or similar
        Data to compute the ranges from. Non-finite values are ignored.
    percentiles : len-3 iterable of numeric values, optional
        Strictly increasing percentiles in ``[0, 100]`` giving the low,
        middle and high reference points. Defaults to ``(25, 50, 75)``.

    Returns
    -------
    v_min, v_max : numeric
        The computed bounds.

    Raises
    ------
    ValueError
        If unexpected keyword arguments are given, if `percentiles` is not a
        valid set of three increasing values in ``[0, 100]``, or if `data`
        contains no finite values.
    """
    if kwargs:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=auto. Expecting percentiles only."
        )

    # Explicit checks (rather than ``assert``) so the validation is still
    # performed when Python runs with optimizations enabled (``-O``).
    if len(percentiles) != 3:
        raise ValueError("Auto-clipping requires exactly three percentile values")
    pct = np.asarray(percentiles)
    if not np.all((pct >= 0.0) & (pct <= 100.0)):
        raise ValueError("Percentile values must be in the range [0, 100]")
    if not np.all(np.diff(pct) > 0):
        raise ValueError("Percentages must be monotonically increasing")

    data_clean = _clean_data(data)

    low, median, high = _percentile(data_clean, percentiles)
    v_min = median - 10 * (median - low)
    v_max = median + 10 * (high - median)

    return v_min, v_max


def clipping_sigma(data, nsigma=3, **kwargs):
    """
    Provide data bounds based off simple sigma clipping.

    The bounds are ``median +/- nsigma * std`` of the finite data values.

    Parameters
    ----------
    data : :any:`np.array` or similar
        The data to calculate bounds for. Non-finite values are ignored.
    nsigma : int
        The number of standard deviations to clip above/below the median.
        Non-integer values are truncated with ``int()``; the result must be
        greater than zero.

    Returns
    -------
    v_min, v_max : numeric
        The computed bounds.

    Raises
    ------
    ValueError
        If unexpected keyword arguments are given, if `nsigma` cannot be
        interpreted as a positive integer, or if `data` has no finite values.
    """
    if kwargs:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=sigma. Expecting nsigma only."
        )

    try:
        nsigma = int(nsigma)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            "Unable to interpret nsigma={} as an integer".format(nsigma)
        ) from err
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if nsigma <= 0:
        raise ValueError("nsigma must be greater than zero")

    data_clean = _clean_data(data)

    sig = np.std(data_clean)
    med = np.median(data_clean)

    return med - nsigma * sig, med + nsigma * sig


def clipping_mad(data, nmad=3, **kwargs):
    """
    Provide data bounds based off simple median absolute deviation (MAD)
    clipping.

    The bounds are ``median +/- nmad * MAD`` of the finite data values, where
    ``MAD = median(|data - median(data)|)``. The MAD is used as-is (it is not
    rescaled to a Gaussian-equivalent standard deviation).

    Parameters
    ----------
    data : :any:`np.array` or similar
        The data to calculate bounds for. Non-finite values are ignored.
    nmad : int
        The number of MADs to clip above/below the median. Non-integer values
        are truncated with ``int()``; the result must be greater than zero.

    Returns
    -------
    v_min, v_max : numeric
        The computed bounds.

    Raises
    ------
    ValueError
        If unexpected keyword arguments are given, if `nmad` cannot be
        interpreted as a positive integer, or if `data` has no finite values.
    """
    if kwargs:
        raise ValueError("Invalid v_clip_kwargs for v_clip=mad. Expecting nmad only.")

    try:
        nmad = int(nmad)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            "Unable to interpret nmad={} as an integer".format(nmad)
        ) from err
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if nmad <= 0:
        raise ValueError("nmad must be greater than zero")

    data_clean = _clean_data(data)

    med = np.median(data_clean)
    mad = np.median(np.abs(data_clean - med))

    return med - nmad * mad, med + nmad * mad


def clipping_minmax(data, **kwargs):
    """
    Provide data bounds given by the smallest and largest finite data values.

    Parameters
    ----------
    data : :any:`np.array` or similar
        The data to calculate bounds for. Non-finite values are ignored.

    Returns
    -------
    v_min, v_max : numeric
        The minimum and maximum of the finite data.

    Raises
    ------
    ValueError
        If any keyword arguments are given (this method takes none), or if
        `data` cannot be cleaned to a non-empty set of finite values.
    """
    if len(kwargs) > 0:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=minmax. Expecting empty kwargs only."
        )

    finite_values = _clean_data(data)
    lowest = np.min(finite_values)
    highest = np.max(finite_values)
    return lowest, highest


def clipping_percentile(data, percentile=80.0, *, n_samples=None, **kwargs):
    """
    Interval based on keeping a specified fraction of pixels.

    Parameters
    ----------
    data : :any:`np.array` or similar
        The data to calculate bounds for. Non-finite values are ignored.
    percentile : float
        [0-100] The fraction of pixels that should be within the computed
        limits. Those limits enclose the center of the pixel distribution.
    n_samples : int, optional
        Maximum number of values to use. If this is specified, and there
        are more values in the dataset than this, then values are randomly
        sampled from the array (with replacement). Passed straight through to
        :class:`astropy.visualization.PercentileInterval`.

    Returns
    -------
    v_min, v_max : float
        The computed lower and upper limits.

    Raises
    ------
    ValueError
        If unexpected keyword arguments are given, if `percentile` is not a
        number in ``[0, 100]``, or if `data` has no finite values.
    """
    if kwargs:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=percentile. "
            "Expecting percentile and n_samples only."
        )

    try:
        in_range = 0 <= percentile <= 100
    except TypeError as err:
        raise ValueError(
            "Unable to interpret percentile={} as a number".format(percentile)
        ) from err
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not in_range:
        raise ValueError("percentile must be in the range [0-100]")

    data_clean = _clean_data(data)
    return PercentileInterval(percentile, n_samples=n_samples).get_limits(data_clean)


def clipping_val(data, low=-100.0, high=100.0, **kwargs):
    """
    Return the fixed bounds ``(low, high)``, independent of the data.

    Parameters
    ----------
    data : :any:`np.array` or similar
        Ignored; accepted so this method has the same call signature as the
        other clipping methods.
    low : float
        The low value to clip.
    high : float
        The high value to clip.

    Returns
    -------
    v_min, v_max : numeric
        ``low`` and ``high``, unchanged.

    Raises
    ------
    ValueError
        If any keyword arguments other than `low` and `high` are given.
    """
    if kwargs:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=val. Expecting low and high values only."
        )

    return low, high


def clipping_zscale(data, **kwargs):
    """
    Provide data bounds computed by astropy's ``ZScaleInterval``.

    The interval is built with ``ZScaleInterval``'s default parameters and
    applied to the finite values of `data`.

    Parameters
    ----------
    data : :any:`np.array` or similar
        The data to calculate bounds for. Non-finite values are ignored.

    Returns
    -------
    v_min, v_max : numeric
        The computed bounds.

    Raises
    ------
    ValueError
        If any keyword arguments are given (this method takes none), or if
        `data` cannot be cleaned to a non-empty set of finite values.
    """
    if len(kwargs) > 0:
        raise ValueError(
            "Invalid v_clip_kwargs for v_clip=zscale. Expecting empty kwargs only."
        )

    finite_values = _clean_data(data)
    zscale = ZScaleInterval()
    return zscale.get_limits(finite_values)


#: The association of clipping-method names (the values accepted by
#: ``ClippingMixin.v_clip``) to the functions that compute the bounds.
#: Each function is called as ``func(data, **v_clip_kwargs)`` and returns a
#: ``(v_min, v_max)`` pair.
_V_CLIPPING_METHODS = {
    "auto": clipping_auto,
    "sigma": clipping_sigma,
    "mad": clipping_mad,
    "minmax": clipping_minmax,
    "percentile": clipping_percentile,
    "val": clipping_val,
    "zscale": clipping_zscale,
}


class ClippingMixin(object):
    """
    Mixin giving plot classes (e.g. colormaps, histograms) a configurable
    data ("v") range.

    The range can either be fixed explicitly via :any:`v_min` / :any:`v_max`,
    or computed from the data with one of the clipping methods registered in
    :any:`_V_CLIPPING_METHODS` (see :any:`set_v_clip_method`).

    Classes using this mixin must provide a ``_data`` attribute. It holds
    either the data itself, or a dict whose values are dicts that may carry
    the data under a ``"data"`` key (the histogram layout).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # some variables with sensible defaults
        # whether v_min/v_max or v_clip/v_clip_kwargs take precedence should be
        # determined in the backend routines (e.g. plot_histogram in matplotlib.py)
        # NOTE(review): these values bypass the validating property setters.
        self._v_min = kwargs.get("v_min", None)
        self._v_max = kwargs.get("v_max", None)
        self._v_clip = kwargs.get("v_clip", "sigma")
        self._v_clip_kwargs = kwargs.get("v_clip_kwargs", {})

    # GETTER/SETTER FUNCTIONS
    @property
    def v_max(self):
        """
        The *assigned* v_max value for this scale.

        Note that this property sets/holds the *assigned* value for the upper
        end of the scale. If :any:`v_clip` and :any:`v_clip_kwargs`
        are set, this value of :any:`v_max` will override that calculation.

        Setting this value to None signals that it should be computed from
        the specified :any:`v_clip` clipping if set, otherwise the backend is
        free to choose its own upper bound when plotting.
        """
        return self._v_max

    @v_max.setter
    def v_max(self, v):
        try:
            assert_none_or_numeric(v)
        except AssertionError as e:
            raise ValueError(str(e))
        self._v_max = v

    @property
    def v_min(self):
        """
        The *assigned* v_min value for this scale.

        Note that this property sets/holds the *assigned* value for the lower
        end of the scale. If :any:`v_clip` and :any:`v_clip_kwargs`
        are set, this value of :any:`v_min` will override that calculation.

        Setting this value to None signals that it should be computed from
        the specified :any:`v_clip` clipping if set, otherwise the backend is
        free to choose its own lower bound when plotting.
        """
        return self._v_min

    @v_min.setter
    def v_min(self, v):
        try:
            assert_none_or_numeric(v)
        except AssertionError as e:
            raise ValueError(str(e))
        self._v_min = v

    def get_vlim(self):
        """
        Return the v-limits (min & max) of this Plot.

        The limits returned are computed as follows:

        - If :any:`v_min` and/or :any:`v_max` are set, those
          values take precedence;
        - If a value for :any:`v_min` and/or :any:`v_max` isn't set,
          a value will be
          computed based off the :any:`v_clip` property, if possible;
        - Otherwise, :obj:`None` will be returned for the missing limit(s).

        The clipping computation is skipped entirely when both limits are set
        explicitly, so data that cannot be clipped (e.g. empty) does not cause
        an error in that case.

        Returns
        -------
        (numeric, numeric)
            The bounds of the data range, i.e., ``(v_min, v_max)``.
        """
        v_min, v_max = self.v_min, self.v_max
        if v_min is None or v_max is None:
            v_min_c, v_max_c = self._clip_limits_from_data()
            if v_min is None:
                v_min = v_min_c
            if v_max is None:
                v_max = v_max_c
        return v_min, v_max

    def _clip_limits_from_data(self):
        """
        Compute ``(v_min, v_max)`` from ``self._data`` using :any:`v_clip`.

        For dict data (histogram layout), the clipping method is applied to
        every value carrying a ``"data"`` entry and the resulting lower and
        upper limits are averaged. ``(None, None)`` is returned when no
        clipping method is set, or when the dict holds no ``"data"`` entries.
        """
        if self._v_clip is None:
            return None, None

        if not isinstance(self._data, dict):
            return _V_CLIPPING_METHODS[self._v_clip](
                self._data, **self._v_clip_kwargs
            )

        datasets = [val["data"] for val in self._data.values() if "data" in val]
        if not datasets:
            # Averaging empty lists would yield NaN limits; report "unknown".
            return None, None

        clip_fn = _V_CLIPPING_METHODS[self._v_clip]
        limits = [clip_fn(d, **self._v_clip_kwargs) for d in datasets]
        return (
            np.average([lo for lo, _ in limits]),
            np.average([hi for _, hi in limits]),
        )

    def set_vlim(self, vmin=None, vmax=None):
        """Define the v limits of this data scale.

        Parameters
        ----------
        vmin, vmax : numeric values (int, float, etc.)
            The minimum and maximum bounds of the data range of this
            Plot. Either can be set to None. A single positional argument may
            be supplied instead, of the form `(vmin, vmax)`.
        """
        if vmax is None and np.iterable(vmin):
            vmin, vmax = vmin
        value_none_or_numeric(vmin)
        value_none_or_numeric(vmax)
        self.v_min = vmin
        self.v_max = vmax

    @property
    def v_clip(self):
        """
        The clipping method to use for this data scale.

        Should match one of the keys of :any:`_V_CLIPPING_METHODS`, or be
        :obj:`None` to disable data-driven clipping. To set a method together
        with its keyword arguments, use the :any:`set_v_clip_method` helper
        function.
        """
        return self._v_clip

    @v_clip.setter
    def v_clip(self, m):
        # None is accepted because __init__ and get_vlim already support it.
        if m is not None and m not in _V_CLIPPING_METHODS:
            raise ValueError(
                "Clipping method must be one of: "
                "{}".format(",".join(_V_CLIPPING_METHODS.keys()))
            )
        self._v_clip = m

    @property
    def v_clip_kwargs(self):
        """
        The kwargs associated with the clipping method given by self.v_clip
        """
        return self._v_clip_kwargs

    def set_v_clip_method(self, m, **kwargs):
        """
        Set the v-clipping method for this Plot.

        The clipping method describes how the plot backends will scale
        the v-axis (i.e., the data) of the output Plot. This is analogous
        to setting the x- and y-limits of the axes.

        Parameters
        ----------
        m : str
            The clipping method to use. Must be a key to
            :obj:`_V_CLIPPING_METHODS`.
        kwargs : keyword arguments
            The keyword arguments from this function are passed directly on
            to the clipping function specified.

        Raises
        ------
        ValueError
            If `m` is not a known clipping method, or the clipping function
            rejects `kwargs` when dry-run on dummy data.
        """
        # Look the method up separately, so that a KeyError raised *inside* a
        # clipping function is not misreported as an invalid method name.
        try:
            clip_fn = _V_CLIPPING_METHODS[m]
        except (KeyError, TypeError):
            raise ValueError("{!r} is not a valid clipping method".format(m)) from None

        # Run the kwargs through the desired clipping function with dummy
        # data - this ensures we have all the necessary kwargs
        dummy_data = np.ones((10, 10))
        try:
            clip_fn(dummy_data, **kwargs)
        except TypeError as err:
            raise ValueError(
                "The kwargs you provided for {} are "
                "incomplete - please review the documentation "
                "for {}".format(m, clip_fn)
            ) from err

        self.v_clip = m
        self._v_clip_kwargs = kwargs
