# SPDX-License-Identifier: BSD-3-Clause
"""
Specifications for creating image-based Plots.
"""

from .plot import Plot
from .axes import Plot2DMixin, XOriginMixin, YOriginMixin
from .colormaps import ColorbarMixin
import numpy as np
import math

# Default extent (width/height, in pixels) of the central cutout taken by
# CentralImagePlot when no ``extent`` is supplied.
CENTRAL_IMAGE_PLOT_EXTENT_DEFAULT = 200


class ImagePlot(Plot2DMixin, XOriginMixin, YOriginMixin, ColorbarMixin, Plot):
    """
    Hold the necessary information & data to plot an image within ADARI.

    Besides the keyword arguments handled by the base classes, the following
    optional keyword arguments are read on construction:

    - ``rotate``: rotation applied by :meth:`add_data`; must be a multiple
      of 90 (degrees).
    - ``flip``: ``'x'`` or ``'y'``; flip applied by :meth:`add_data`
      *before* any rotation.
    - ``interpolation``: stored as ``self.interp`` (default None).
    - ``cbar_kwargs``: stored as ``self.cbar_kwargs`` (default None).
    - ``x_label`` / ``y_label``: axis labels (defaults ``'x'`` / ``'y'``).
    """

    def __init__(self, *args, **kwargs):
        # These must be set *before* super().__init__. Presumably the base
        # initialiser attaches positional data via add_data(), which reads
        # self.flip and self.rotate -- NOTE(review): confirm against Plot.
        self.rotate = kwargs.get("rotate", None)
        self.flip = kwargs.get("flip", None)
        # default is None = antialiasing for png output
        self.interp = kwargs.get("interpolation", None)
        self.cbar_kwargs = kwargs.get("cbar_kwargs", None)
        super().__init__(*args, **kwargs)
        # Override whatever default x_label / y_label the base classes set
        self.x_label = kwargs.get("x_label", "x")
        self.y_label = kwargs.get("y_label", "y")

    def __str__(self):
        return "ImagePlot"

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """
        Check that ``data`` can be used as image data.

        The base-class check runs first; if it returns anything other than
        None, that value is returned unchanged. Otherwise ``data`` must
        convert to a 2-dimensional numpy array with a numeric dtype.

        Parameters
        ----------
        data : array-like
            Candidate image data.
        fail_on_none : bool, optional
            Forwarded to the base-class check.
        error_on_fail : bool, optional
            If True (default), raise on invalid data instead of returning
            False.

        Returns
        -------
        bool
            True if the data is valid; False if it is not and
            ``error_on_fail`` is False.

        Raises
        ------
        ValueError
            If the data is invalid and ``error_on_fail`` is True.
        """
        s = super()._is_data_valid(
            data, fail_on_none=fail_on_none, error_on_fail=error_on_fail
        )
        if s is not None:
            return s

        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would let any data through.
        tmp_data = np.asarray(data)
        if tmp_data.ndim != 2:
            problem = "Image array must be 2-dimensional (ndim={})".format(
                tmp_data.ndim
            )
        elif not np.issubdtype(tmp_data.dtype, np.number):
            problem = "Image array must have a numeric data type"
        else:
            return True

        if error_on_fail:
            raise ValueError(problem)
        return False

    def add_data(self, d):
        """
        Attach image data to this ImagePlot.

        Note that a repeat call of this function will overwrite any existing
        data in the ImagePlot.

        The data is first flipped (if ``self.flip`` is set) and then rotated
        (if ``self.rotate`` is set) before being stored.

        Parameters
        ----------
        d : 2D array of numeric values
            The image data to be attached to this ImagePlot.

        Raises
        ------
        ValueError
            If ``d`` is not valid image data, ``self.flip`` is not ``'x'`` or
            ``'y'``, or ``self.rotate`` is not a multiple of 90.
        """
        self._is_data_valid(d, fail_on_none=True, error_on_fail=True)
        # Always store an ndarray (asanyarray keeps subclasses such as masked
        # arrays) so later code relying on ``.shape`` also works when ``d``
        # was a nested list and no flip/rotation is requested.
        d = np.asanyarray(d)

        # flipping occurs before the image rotation
        flipval = self.flip
        if flipval is not None:
            if flipval not in ("x", "y"):
                raise ValueError("add_data: flip must be either x or y")
            # 'x' flips along array axis 0 (rows), 'y' along axis 1 (columns)
            d = np.flip(d, axis=0 if flipval == "x" else 1)

        rotval = self.rotate
        if rotval is not None:
            # only accept values of rotate that are multiples of 90 degrees
            quarter_turns, remainder = divmod(rotval, 90)
            if remainder != 0:
                raise ValueError("add_data: rotate must be a multiple of 90 degrees")
            d = np.rot90(d, k=int(quarter_turns), axes=(0, 1))

        self._data = d

    def get_axis_cut(self, axis, axis_pos):
        """
        Return a one-dimensional cut through the image data.

        Parameters
        ----------
        axis : str
            ``'x'`` returns the column ``data[:, axis_pos]``; ``'y'`` returns
            the row ``data[axis_pos, :]``.
        axis_pos : int-like
            Array index of the cut; converted with ``int()`` (so floats are
            truncated).

        Returns
        -------
        numpy array
            The requested column or row.

        Raises
        ------
        ValueError
            If ``axis`` is not ``'x'`` or ``'y'``.
        """
        axis_pos = int(axis_pos)
        if axis == "x":
            return self._data[:, axis_pos]
        elif axis == "y":
            return self._data[axis_pos, :]
        else:
            raise ValueError("Must get axis in either x or y direction")

    def get_central_region(self, extent=200):
        """
        Get the central region (square) of this Plot's image data.

        Parameters
        ----------
        extent : int
            The extent (i.e., width/height) of the region to extract, in pixels.
            If the extent is greater than a particular axis size, the entire
            axis in that direction will be returned.

        Returns
        -------
        img_cent : numpy array
            The data array representing the central region of the ImagePlot
            data. Along every axis at least ``extent`` pixels long, exactly
            ``extent`` pixels are returned.

        Raises
        ------
        ValueError
            If ``extent`` cannot be converted to an int.
        """
        try:
            extent = int(extent)
        except TypeError as err:
            raise ValueError(
                "Unable to understand {} as an int".format(extent)
            ) from err

        data = self.data

        def _axis_slice(size):
            # Start half an extent below the axis centre and take exactly
            # ``extent`` pixels, clipped to the array. Slice stops are
            # exclusive, so the upper clip is ``size`` (not ``size - 1``).
            start = size // 2 - extent // 2
            return slice(max(0, start), min(size, start + extent))

        return data[_axis_slice(data.shape[0]), _axis_slice(data.shape[1])]

    def _get_axis_origin_step(self, axis):
        """
        Return ``(origin, step)`` for ``axis`` (already validated as 'x'/'y').

        Raises
        ------
        RuntimeError
            If either the origin or the step for that axis is None.
        """
        if axis == "x":
            orig, step = self.x_origin, self.x_step
        else:
            orig, step = self.y_origin, self.y_step
        if orig is None:
            raise RuntimeError("Plot {} origin has somehow become None!".format(axis))
        if step is None:
            raise RuntimeError("Plot {} step has somehow become None!".format(axis))
        return orig, step

    def get_data_coord(self, ind: int, axis: str, pos="cent"):
        """
        Return a data coordinate for this ImagePlot.

        This function takes an array coordinate, and converts it to a data
        coordinate.

        By default, the returned data coordinate will be of the center of the
        pixel. However, by altering the ``pos`` argument, the user can ask
        for the lower or upper data coordinate bound of the pixel instead.

        The ADARI convention for index-to-data coordinate conversion is that the
        centre of the pixel at index ``[0,0]`` has data coordinates
        of ``(x_origin, y_origin)`` (which defaults to ``(1, 1)``).

        Parameters
        ----------
        ind : int
            The array coordinate (i.e., index) to get the data coordinate for.
            Non-integer values are floored.
        axis : str
            The axis we are requesting the coordinate for. Must be 'x' or 'y'.
        pos : str, optional
            The in-pixel position to request the data coordinate for. Must
            be one of 'low', 'high', or 'cent' (which is the default).

        Returns
        -------
        float
            The requested data coordinate.

        Raises
        ------
        ValueError
            If ``ind`` cannot be floored, or ``axis``/``pos`` is not one of
            the accepted values.
        RuntimeError
            If the origin or step for ``axis`` is None.
        """
        try:
            ind = math.floor(ind)
        except TypeError as err:
            raise ValueError(
                "Unable to interpret {} as an integer".format(ind)
            ) from err

        if axis not in ("x", "y"):
            raise ValueError("Axis must be one of: 'x', 'y'")
        if pos not in ("low", "cent", "high"):
            raise ValueError("pos must be one of: 'low', 'cent', 'high'")

        orig, step = self._get_axis_origin_step(axis)

        # The pixel centre sits exactly at the index position; its edges are
        # half a step either side.
        if pos == "low":
            return orig + (ind - 0.5) * step
        if pos == "high":
            return orig + (ind + 0.5) * step
        return orig + ind * step

    def get_ind_coord(self, datapos: float, axis: str):
        """
        Return the data array index for a given data coordinate.

        This function takes a data coordinate value, and returns the
        corresponding index value.

        Note that an exact (i.e., float) value
        will be returned; it is up to the user to decide what to do with this
        value. Typically, rounding to the nearest integer will give the index
        of the pixel which contains ``datapos``. For data on the exact cusp of
        a pixel (i.e., a pixel coordinate of N.5), it is up to the user to
        decide whether that data point should belong to pixel N or pixel (N+1).

        The ADARI convention for index-to-data coordinate conversion is that the
        centre of the pixel at index ``[0,0]`` has data coordinates
        of ``(x_origin, y_origin)`` (which defaults to ``(1, 1)``).

        This function does *not* throw an Error if the returned index is going
        to be negative.

        Parameters
        ----------
        datapos : float
            The data value to get the index for.
        axis : str
            The axis we are requesting the coordinate for. Must be 'x' or 'y'.

        Returns
        -------
        ind : float
            The *exact* index of the requested data point.

        Raises
        ------
        ValueError
            If ``datapos`` cannot be converted to float, or ``axis`` is not
            'x' or 'y'.
        RuntimeError
            If the origin or step for ``axis`` is None.
        """
        try:
            datapos = float(datapos)
        except TypeError as err:
            raise ValueError(
                "Unable to interpret {} as a float".format(datapos)
            ) from err

        if axis not in ("x", "y"):
            raise ValueError("Axis must be one of: 'x', 'y'")

        orig, step = self._get_axis_origin_step(axis)
        return (datapos - orig) / step

    def get_image_plot_bounds(self):
        """
        Get the plotting bounds for this image.

        The bounds are the outer pixel edges: the first pixel's centre is at
        the origin, so its low edge is half a step below it, and the image
        spans ``shape * step`` from there. This matches
        ``get_data_coord(0, axis, "low")`` and
        ``get_data_coord(n - 1, axis, "high")``.

        Returns
        -------
        plot_left, plot_right, plot_bottom, plot_top : float
            The values necessary for the ImagePlot to be exactly bound by
            the plotting axes.
        """
        plot_left = self.x_origin - 0.5 * self.x_step
        plot_right = plot_left + self.data.shape[1] * self.x_step
        plot_bottom = self.y_origin - 0.5 * self.y_step
        plot_top = plot_bottom + self.data.shape[0] * self.y_step
        return plot_left, plot_right, plot_bottom, plot_top


class CentralImagePlot(ImagePlot):
    """
    Create a CentralImagePlot from an ImagePlot.

    This subclass of ImagePlot takes a single ImagePlot as its data
    argument, and automatically constructs an ImagePlot representing the
    central region of this data. It does so by computing the correct values
    for the following class attributes:

    - data
    - x_origin & x_step
    - y_origin & y_step

    The user can modify these values after the fact, but it is *not*
    recommended.

    The input ImagePlot should be fully populated with data, origins and steps
    (as required) *before* it is used in the CentralImagePlot.
    """

    def __str__(self):
        return "CentralImagePlot"

    def __init__(self, *args, extent=CENTRAL_IMAGE_PLOT_EXTENT_DEFAULT, **kwargs):
        # Take the data (first positional argument) out of ``args`` so the
        # base initialiser does not attach it itself; it is attached below
        # via add_data() so that the requested ``extent`` is honoured.
        data = None
        if args and args[0] is not None:
            data, args = args[0], args[1:]

        # Initialise this class's own state *before* the base initialiser so
        # these defaults can never overwrite anything set up during it
        # (presumably it may call add_data for data passed some other way).
        self.extent = extent
        self._ref_image = None
        self._x_ind_min = None
        self._x_ind_max = None
        self._y_ind_min = None
        self._y_ind_max = None

        super().__init__(*args, **kwargs)

        # Re-instate the skipped data insert
        if data is not None:
            self.add_data(data, extent=self.extent)

    @staticmethod
    def _compute_central_array_bounds(im_shape, extent: int):
        """
        Compute the array bounds for a central image extraction.

        Along each axis the cutout is ``extent`` pixels long, clipped to the
        image. When the axis length and ``extent`` have different parity the
        cutout cannot be exactly centred; by convention the spare pixel falls
        on the low-index side of the centre.

        Parameters
        ----------
        im_shape : 2-tuple or similar
            The shape of the input image data array
        extent : int
            The extent of the central image to produce, in pixels

        Returns
        -------
        xmin, xmax, ymin, ymax : int
            The array coordinates to use to trim the central image, e.g.,
            for a standard numpy data array, the cutout would be:
            ``[ymin:ymax, xmin:xmax]`` (noting the normal numpy array axis
            convention). The max values are exclusive slice stops.
        """

        def _bounds(size):
            # Floor division puts the spare pixel (mismatched parity) on the
            # low side; it also floors correctly when extent > size.
            start = (size - extent) // 2
            return max(0, start), min(size, start + extent)

        y_min, y_max = _bounds(im_shape[0])
        x_min, x_max = _bounds(im_shape[1])
        return x_min, x_max, y_min, y_max

    def get_ref_image_window_coords(self):
        """
        Get the bounding box coordinates of the CentralImagePlot on the parent
        ImagePlot.

        Returns
        -------
        x_min, x_max, y_min, y_max : float
            The min and max points of the rectangle in both axes, in data
            coordinates. Note that these coordinates are made so to include
            entire pixels.

        Raises
        ------
        RuntimeError
            If no reference image has been attached yet.
        """
        ref = self.ref_image
        if ref is None:
            raise RuntimeError(
                "No reference image has been set; call add_data first."
            )
        x_min_i, x_max_i, y_min_i, y_max_i = self._compute_central_array_bounds(
            ref.data.shape, self.extent
        )
        # The *_max_i values are exclusive slice stops, so the last pixel
        # included is (max_i - 1); its upper edge lies at (max_i - 0.5).
        x_min = ref.x_origin + ref.x_step * (x_min_i - 0.5)
        x_max = ref.x_origin + ref.x_step * (x_max_i - 0.5)
        y_min = ref.y_origin + ref.y_step * (y_min_i - 0.5)
        y_max = ref.y_origin + ref.y_step * (y_max_i - 0.5)

        return x_min, x_max, y_min, y_max

    def _is_data_valid(self, data, fail_on_none=True, error_on_fail=True):
        """
        Check that ``data`` is an ImagePlot holding a numpy data array.

        The base-class checks are deliberately not run; the checks below
        cover what is needed. ``fail_on_none`` is accepted for signature
        compatibility but not consulted: None is always invalid here.

        Returns
        -------
        bool
            True if valid; False if invalid and ``error_on_fail`` is False.

        Raises
        ------
        ValueError
            If the data is invalid and ``error_on_fail`` is True.
        """
        # Explicit checks rather than ``assert`` (stripped under -O)
        if data is None:
            problem = "Input data cannot be None"
        elif not isinstance(data, ImagePlot):
            problem = "Input data must be an ImagePlot"
        elif not isinstance(data.data, np.ndarray):
            problem = "Input ImagePlot must have a populated data array"
        else:
            return True

        if error_on_fail:
            raise ValueError(problem)
        return False

    def add_data(self, d, extent=CENTRAL_IMAGE_PLOT_EXTENT_DEFAULT):
        """
        Attach the central region of an ImagePlot to this CentralImagePlot.

        Sets ``data``, the origins/steps (where the reference image has a
        step), the axis limits and ``ref_image``.

        Parameters
        ----------
        d : ImagePlot
            The reference image; must already hold a numpy data array.
        extent : int, optional
            Size of the central cutout in pixels; stored as ``self.extent``.

        Raises
        ------
        ValueError
            If ``d`` is not a populated ImagePlot, or ``extent`` is not a
            positive int.
        """
        # Input checking & setting
        self._is_data_valid(d)
        self.extent = extent

        x_min, x_max, y_min, y_max = self._compute_central_array_bounds(
            d.data.shape, self.extent
        )

        # Basic slicing: this is a view onto the reference image's data.
        cutout = d.data[y_min:y_max, x_min:x_max]
        if np.all(np.isnan(cutout)):
            # Replace an all-NaN cutout with zeros of the same shape. Use
            # zeros, not np.empty, which returns uninitialised (arbitrary)
            # values. NOTE(review): presumably this exists so downstream
            # plotting / colour scaling gets finite values -- confirm.
            cutout = np.zeros(cutout.shape)
        self._data = cutout

        # Record the cutout's (exclusive-stop) bounds within the reference
        self._x_ind_min, self._x_ind_max = x_min, x_max
        self._y_ind_min, self._y_ind_max = y_min, y_max

        # Compute the origin and step values required
        if d.x_step is not None:
            self.x_step = d.x_step
            self.x_origin = d.x_origin + d.x_step * x_min
        if d.y_step is not None:
            self.y_origin = d.y_origin + d.y_step * y_min
            self.y_step = d.y_step

        self._ref_image = d

        # Set the axis limits based on the coordinates and origin
        # This ensures that, if used in a CombinedPlot, the CentralImagePlot
        # axes range governs the global plot range (unless overridden by
        # another Plot object with a higher z-order)
        # NOTE(review): y_min uses the 'high' edge of row 0 and y_max the
        # 'low' edge of the last row (the reverse of x); presumably tied to
        # image-origin orientation -- confirm this is intended.
        self.x_min = self.get_data_coord(0, "x", "low")
        self.x_max = self.get_data_coord(self.data.shape[1] - 1, "x", "high")
        self.y_min = self.get_data_coord(0, "y", "high")
        self.y_max = self.get_data_coord(self.data.shape[0] - 1, "y", "low")

    @property
    def extent(self):
        """
        The size of the central image cutout (in pixels).
        """
        return self._extent

    @extent.setter
    def extent(self, e):
        # int() truncates floats; non-numeric strings raise ValueError directly
        try:
            e = int(e)
        except TypeError as err:
            raise ValueError("Unable to interpret {} as an int".format(e)) from err

        if e <= 0:
            raise ValueError("Extent must be greater than 0")
        self._extent = e

    @property
    def ref_image(self):
        """
        The original ImagePlot this CentralImagePlot is derived from.
        """
        return self._ref_image

    @ref_image.setter
    def ref_image(self, x):
        raise RuntimeError("You may not set ref_image directly.")
