# SPDX-License-Identifier: BSD-3-Clause
"""This module specifies the mixins and plot classes for creating histograms.
"""

from .plot import Plot
from .images import ImagePlot
from .axes import Plot2DMixin, XOriginMixin, YOriginMixin
from adari_core.utils.clipping import ClippingMixin
import numpy as np
import logging
import math

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Single-character color codes. HistogramPlot.add_data indexes into this
# string (cycling by the number of data sets already stored) when the caller
# does not supply an explicit color.
COLORS = "rgbkycm"

# Bin count that HistogramPropsMixins.set_int_bins falls back to when no bin
# count has been set yet.
HISTOGRAM_DEFAULT_BINS = 20


class HistogramPropsMixins(ClippingMixin, object):
    """Create the required Plot attributes for a histogram.

    Provides a validated ``bins`` property plus helpers that derive a bin
    count from the current value limits. The helpers rely on
    ``self.get_vlim()`` / ``self.set_vlim()``, which are not defined in this
    class (presumably supplied by ``ClippingMixin`` -- confirm).
    """

    def __init__(self, *args, **kwargs):
        # _bins must exist before super().__init__ runs in case a parent
        # initialiser touches the ``bins`` property.
        self._bins = None
        super().__init__(*args, **kwargs)
        self.bins = kwargs.get("bins", None)
        # set default to be sigma with nsigma=1
        self._v_clip = kwargs.get("v_clip", "sigma")
        self._v_clip_kwargs = kwargs.get("v_clip_kwargs", {})
        if self._v_clip == "sigma" and not self._v_clip_kwargs:
            self._v_clip_kwargs = {"nsigma": 1}

    @property
    def bins(self):
        """
        int : The number of bins.
        """
        return self._bins

    @bins.setter
    def bins(self, bin_count):
        # None is allowed and means "no bin count set".
        if bin_count is None:
            self._bins = None
            return

        try:
            bin_count = int(bin_count)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Unable to cast {bin_count} to an int") from err
        # Explicit check rather than ``assert``: assertions are removed when
        # Python runs with -O, which would silently disable this validation.
        if bin_count <= 0:
            raise ValueError("Bin count must be > 0")
        self._bins = bin_count

    def set_bins(self, bin_size=3.0):
        """Set the number of bins given a bin size.

        The bin count is ``int((vmax - vmin) / bin_size)`` computed from the
        integer-truncated value limits. If a bin count is already set, the
        smaller of the existing and the computed count is kept.

        Parameters
        ----------
        - bin_size: float
        The (initial) bin width to attempt. Note that it is truncated to an
        integer before use.

        Raises
        ------
        ValueError
            If ``bin_size`` cannot be cast to an integer or truncates to a
            value below 1, if either value limit is None, or if the
            resulting bin count is not a positive integer.
        """
        # Input checking
        try:
            bin_size = int(bin_size)
        except TypeError as err:
            raise ValueError(
                f"Unable to cast {bin_size} to an integer value"
            ) from err
        if bin_size <= 0:
            # A zero width would otherwise surface as ZeroDivisionError below.
            raise ValueError("bin_size must be at least 1 after integer truncation")

        vmin, vmax = self.get_vlim()
        # Check for None *before* casting; int(None) would raise TypeError
        # and make this validation unreachable.
        if vmin is None or vmax is None:
            raise ValueError("Input vmin/vmax values cannot be None")
        vmin = int(vmin)
        vmax = int(vmax)
        logger.debug(f"Have vmin={vmin}, vmax={vmax}, delta={vmax - vmin}")

        nbin = int((vmax - vmin) / bin_size)
        if self.bins is None:
            self.bins = nbin
        else:
            self.bins = min(nbin, self.bins)

    def set_int_bins(self, bin_size: int = 1):
        """
        Set the number of bins given a bin size, assuming all data are integers.

        Starting from the current bin count (or ``HISTOGRAM_DEFAULT_BINS``
        when none is set), this picks a final bin width and bin count, then
        rewrites the value limits via ``set_vlim`` so that they span exactly
        ``final_nbin * final_bin_size`` around the midpoint of the original
        limits.

        Parameters
        ----------
        - bin_size: int
        The minimum bin width. It is used as the final width whenever the
        width implied by the current bin count would be smaller.

        Raises
        ------
        ValueError
            If either value limit is None.
        """
        vmin, vmax = self.get_vlim()
        if vmin is None or vmax is None:
            raise ValueError("Input vmin/vmax values cannot be None")

        logger.debug(
            f"Started with vmin={vmin}, vmax={vmax}, bin_size={bin_size}, nbin={self.bins}"
        )

        if self.bins is None:
            self.bins = HISTOGRAM_DEFAULT_BINS
        # Width each bin would have if the inclusive range were split into
        # the current number of bins.
        init_bin_size = (vmax - vmin + 1) / self.bins
        logger.debug(
            f"int_bin_size={init_bin_size}, delta={vmax-vmin}, nbin={self.bins}"
        )

        if init_bin_size < bin_size:
            # Implied bins are narrower than requested: use the requested
            # width and recompute the bin count from the range.
            final_nbin = math.ceil(
                (
                    vmax
                    - vmin
                    + (vmin - math.floor(vmin / bin_size) * bin_size) / bin_size
                    + 0.5
                )
            )
            final_bin_size = bin_size
        else:
            # Keep the bin count and round the implied width to the nearest
            # integer.
            final_bin_size = math.floor(init_bin_size + 0.5)
            final_nbin = self.bins

        # Re-centre the limits on the midpoint of the original range; the
        # -0.5 term offsets the lower limit by half a bin width.
        x_mid = (vmax - vmin) / 2.0 + vmin
        new_vmin = final_bin_size * (
            math.floor(x_mid / final_bin_size) - math.floor(final_nbin / 2.0) - 0.5
        )
        new_vmax = new_vmin + final_nbin * final_bin_size

        self.set_vlim(new_vmin, new_vmax)
        self.bins = final_nbin

        logger.debug(
            f"Ended with vmin={new_vmin}, vmax={new_vmax}, bin_size={final_bin_size}, final_nbin={final_nbin}"
        )


class HistogramPlot(
    Plot2DMixin, XOriginMixin, YOriginMixin, HistogramPropsMixins, Plot
):
    """
    A Plot for displaying data in a histogram format.

    The user is only required to supply an array of data points and a bin
    size. The plotting backend(s) will automatically split these data into
    bins as requested.

    Parameters
    ----------
    master_data : optional, iterable of numeric values, or None
        Data to be displayed with the special 'master' formatting. Defaults
        to None.
    raw_data : optional, iterable of numeric values, or None
        Data to be displayed with the special 'raw' formatting. Defaults to
        None.
    bins : optional, int, or none
        The number of data bins. Defaults to None, at which point the plotting
        backend will automatically bin the data.
    x_label, y_label : optional, str
        Axis labels. Default to "counts" and "frequency".
    master_label, raw_label : optional, str
        Labels under which master/raw data are stored. Default to "master"
        and "raw".
    """

    # FIXME Need to convert to allow multiple data sets
    # How to deal with binning in this scenario?
    # A: Take a 'bins' parameter for each data set
    # Note, this may also include custom bins

    def __init__(self, *args, master_data=None, raw_data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.x_label = kwargs.get("x_label", "counts")
        self.y_label = kwargs.get("y_label", "frequency")
        self.master_label = kwargs.get("master_label", "master")
        self.raw_label = kwargs.get("raw_label", "raw")
        # Requirements state that default y scale should be log for histograms
        # NOTE(review): this assignment runs after super().__init__, so it
        # appears to override any y_scale passed in kwargs -- confirm intended.
        self.y_scale = "log"
        # Data sets are stored per label: {label: {"data": ..., "color": ...}}
        self._data = {}
        if raw_data is not None:
            self.add_raw_data(raw_data)
        if master_data is not None:
            self.add_master_data(master_data)

    def __str__(self):
        return "HistogramPlot"

    def is_logscale(self):
        """Return True if the y axis scale is set to "log"."""
        return self.y_scale == "log"

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """
        Check that `data` can be used as histogram input.

        The parent-class check runs first; any non-None verdict it gives is
        returned unchanged. Otherwise an ImagePlot is accepted as-is, and
        anything else must convert to a numpy array with a numeric dtype.

        Returns
        -------
        bool
            True if the data is valid, False if it is invalid and
            `error_on_fail` is False.

        Raises
        ------
        ValueError
            If the data is invalid and `error_on_fail` is True.
        """
        # Run the base level check
        s = super()._is_data_valid(
            data, fail_on_none=fail_on_none, error_on_fail=error_on_fail
        )
        if s is not None:
            return s

        if isinstance(data, ImagePlot):
            return True  # Support for Imageplot data types

        # Histogram can handle any type of data with any dimensions
        # (flattened); the only requirement is a numeric numpy dtype.
        # An explicit check is used instead of ``assert`` because assertions
        # are stripped when Python runs with -O.
        try:
            is_numeric = bool(np.issubdtype(np.asarray(data).dtype, np.number))
        except (TypeError, ValueError):
            # Data that numpy cannot convert at all is also invalid.
            is_numeric = False
        if is_numeric:
            return True

        if error_on_fail:
            raise ValueError(
                "Histogram data must be convertible to a numpy array of numbers"
            )
        return False

    def add_data(self, data, *args, label=None, color=None, **kwargs):
        """
        Add data to this Plot.

        Parameters
        ----------
        data : iterable of numeric values, or ImagePlot
            The data to be added. Can be any number of dimensions, as it is
            flattened during storage. For an ImagePlot, its ``data``
            attribute is used.
        label : str
            The label to give this particular data set. This must be unique
            for this Plot object.
        color : str
            A valid color specification to plot these data with. Defaults to
            None, at which point the next entry of ``COLORS`` is assigned,
            cycling by the number of data sets already stored.

        Raises
        ------
        ValueError
            If the data is invalid, or if data with the same label has
            already been added.
        """
        self._is_data_valid(data, fail_on_none=True, error_on_fail=True)
        if self._data is None:
            self._data = {}  # Compatibility with parent Plot functionality
        if label in self._data:
            raise ValueError(
                "This {} already contains data "
                "labelled '{}'".format(self.__str__(), label)
            )
        if color is None:
            color = COLORS[len(self._data) % len(COLORS)]
        # Extract the underlying array from an ImagePlot; use anything else
        # directly. Either way a flattened copy is stored.
        source = data.data if isinstance(data, ImagePlot) else data
        data_flattened = np.array(source).flatten()
        self._data[label] = {"data": data_flattened, "color": color}

    def add_master_data(self, data):
        """
        Helper for setting master data with preconfigured colors and label.

        This function is equivalent to calling:
        :any:`add_data` `(data, label=self.master_label, color="red")`,
        where ``master_label`` defaults to "master".

        Parameters
        ----------
        data : iterable of numeric values
            The data to be added. Can be any number of dimensions
            as it is flattened during storage. Typically, master data is a
            2D image, but this function does not enforce that restriction.

        Raises
        ------
        ValueError
            If the data is invalid
        """
        self.add_data(data, label=self.master_label, color="red")

    def add_raw_data(self, data):
        """
        Helper for setting raw data with preconfigured colors and label.

        This function is equivalent to calling:
        :any:`add_data` `(data, label=self.raw_label, color="black")`,
        where ``raw_label`` defaults to "raw".

        Parameters
        ----------
        data : iterable of numeric values
            The data to be added. Can be any number of dimensions as it is
            flattened during storage. Typically, raw data is a
            2D image, but this function does not enforce that restriction.

        Raises
        ------
        ValueError
            If the data is invalid
        """
        self.add_data(data, label=self.raw_label, color="black")

    def remove_data(self, key):
        """
        Remove the given data set from the Plot.

        Parameters
        ----------
        key : str
            The data key (label) to access the data.

        Returns
        -------
        data : dict or None
            The removed entry, a dict with keys ``"data"`` (the flattened
            numpy array) and ``"color"``. `None` will be returned if the
            data was not found.
        """
        # _data may be None (add_data tolerates this for compatibility with
        # the parent Plot); treat that the same as "not found".
        if self._data is None:
            return None
        return self._data.pop(key, None)
