# SPDX-License-Identifier: BSD-3-Clause
"""Module for calculating 'ports' on ESO detectors.

This module interprets the assorted ESO DET header keywords from
ESO-supplied FITS files, in order to determine the various valid
readout windows.

ESO uses two makes of detector controller:
- FIERA
- NGC
Each controller has a distinct specification for ESO DET header
keywords, so an interpreter for each detector controller is provided
separately. A helper function is provided to determine which controller
is in use for a given FITS file.
"""
import astropy.io.fits
import astropy.io.fits as fits
import numpy as np
import re
import logging

from adari_core.utils.utils import get_nth_prime

logger = logging.getLogger(__name__)

# Per-port header keywords ("ESO DET OUT<n> ...").
# *_STR / *_REGEX match a keyword name with or without the leading
# "HIERARCH " prefix and capture the port number as the named group "n".
# *_FMT build the full keyword name for a given port number.
# NOTE: there must be no space between "(HIERARCH )?" and "ESO", otherwise a
# plain "ESO DET OUT1 X" key can never match.
DET_OUT_X_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) X\s*$"
DET_OUT_X_REGEX = re.compile(DET_OUT_X_STR)
DET_OUT_X_FMT = "HIERARCH ESO DET OUT{} X"
DET_OUT_Y_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) Y\s*$"
DET_OUT_Y_REGEX = re.compile(DET_OUT_Y_STR)
DET_OUT_Y_FMT = "HIERARCH ESO DET OUT{} Y"
DET_OUT_NX_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) NX\s*$"
DET_OUT_NX_REGEX = re.compile(DET_OUT_NX_STR)
DET_OUT_NX_FMT = "HIERARCH ESO DET OUT{} NX"
DET_OUT_NY_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) NY\s*$"
DET_OUT_NY_REGEX = re.compile(DET_OUT_NY_STR)
DET_OUT_NY_FMT = "HIERARCH ESO DET OUT{} NY"
DET_OUT_OVSC_X_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) OVSCX\s*$"
DET_OUT_OVSC_X_REGEX = re.compile(DET_OUT_OVSC_X_STR)
DET_OUT_OVSC_X_FMT = "HIERARCH ESO DET OUT{} OVSCX"
DET_OUT_OVSC_Y_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) OVSCY\s*$"
DET_OUT_OVSC_Y_REGEX = re.compile(DET_OUT_OVSC_Y_STR)
DET_OUT_OVSC_Y_FMT = "HIERARCH ESO DET OUT{} OVSCY"
DET_OUT_PRSC_X_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) PRSCX\s*$"
DET_OUT_PRSC_X_REGEX = re.compile(DET_OUT_PRSC_X_STR)
DET_OUT_PRSC_X_FMT = "HIERARCH ESO DET OUT{} PRSCX"
DET_OUT_PRSC_Y_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) PRSCY\s*$"
DET_OUT_PRSC_Y_REGEX = re.compile(DET_OUT_PRSC_Y_STR)
DET_OUT_PRSC_Y_FMT = "HIERARCH ESO DET OUT{} PRSCY"
DET_OUT_INDEX_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) INDEX\s*$"
DET_OUT_INDEX_REGEX = re.compile(DET_OUT_INDEX_STR)
DET_OUT_INDEX_FMT = "HIERARCH ESO DET OUT{} INDEX"
# Whole-chip keywords
DET_CHIP_NX = "HIERARCH ESO DET CHIP NX"
DET_CHIP_NY = "HIERARCH ESO DET CHIP NY"
DET_BINX = "HIERARCH ESO DET BINX"
DET_BINY = "HIERARCH ESO DET BINY"
NAXIS1 = "NAXIS1"
NAXIS2 = "NAXIS2"


def _identify_det_regions(hdu):
    """
    Identify the numbered detector output ports in a given HDU.

    Every header keyword name matching ``ESO DET OUT<n> INDEX`` (with or
    without the ``HIERARCH`` prefix) contributes its port number ``<n>``.

    Parameters
    ----------
    hdu : astropy.io.fits.ImageHDU
        The ImageHDU to extract port numbers from (only ``hdu.header.keys()``
        is used).

    Returns
    -------
    ports : list of int
        The port numbers, in ascending numerical order. An empty list is
        returned if no ports are found.

    Notes
    -----
    Only the keyword *name* is used. The value stored in
    ``ESO DET OUT<n> INDEX`` is not compared against ``<n>``, so no error is
    raised when the two disagree.
    """
    matches = (DET_OUT_INDEX_REGEX.match(key) for key in hdu.header.keys())
    return sorted(int(m.group("n")) for m in matches if m)


def _detect_binned_headerkw(hdu):
    """
    Most ESO data express their PRSC and OVSC keywords in terms of data units
    (i.e., they already incorporate binning). However, some instruments
    (e.g., ESPRESSO) do not. Therefore, we need to intelligently figure out
    if those keywords incorporate the binning or not.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header.

    Returns
    -------
    binned : bool
        Are the header keywords binned (True) or not (False).
    """
    pass


def _get_port_information(hdu, port):
    """
    Read the geometry keywords for a single detector output port.

    Parameters
    ----------
    hdu : ImageHDU
        HDU whose header holds the ``ESO DET OUT<port> ...`` keywords.
    port : int
        Port number ``<n>`` to look up.

    Returns
    -------
    tuple
        ``(x, nx, prscx, ovscx, y, ny, prscy, ovscy)``: the port's X, NX,
        PRSCX, OVSCX, Y, NY, PRSCY and OVSCY values, exactly as stored in the
        header (no binning correction is applied here).

    Raises
    ------
    RuntimeError
        If any of the eight keywords is missing from the header.
    """
    header = hdu.header
    try:
        # Deliberately index with [] (not .get) so a missing keyword raises
        x = header[DET_OUT_X_FMT.format(port)]
        y = header[DET_OUT_Y_FMT.format(port)]
        nx = header[DET_OUT_NX_FMT.format(port)]
        ny = header[DET_OUT_NY_FMT.format(port)]
        ovscx = header[DET_OUT_OVSC_X_FMT.format(port)]
        ovscy = header[DET_OUT_OVSC_Y_FMT.format(port)]
        prscx = header[DET_OUT_PRSC_X_FMT.format(port)]
        prscy = header[DET_OUT_PRSC_Y_FMT.format(port)]
    except KeyError as e:
        raise RuntimeError(f"Incomplete info on port {port} : {str(e)}") from e
    return x, nx, prscx, ovscx, y, ny, prscy, ovscy


def _get_chip_information(hdu):
    """
    Read the chip dimensions from the HDU header.

    Parameters
    ----------
    hdu : ImageHDU
        HDU whose header holds ``ESO DET CHIP NX`` / ``ESO DET CHIP NY``.

    Returns
    -------
    nx, ny : tuple
        The CHIP NX and CHIP NY values as stored in the header.

    Raises
    ------
    RuntimeError
        If either keyword is missing.
    """
    try:
        # Deliberately index with [] so a missing keyword raises KeyError
        nx = hdu.header[DET_CHIP_NX]
        ny = hdu.header[DET_CHIP_NY]
    except KeyError as e:
        raise RuntimeError(
            f"Incomplete chip information on this HDU ({str(e)})"
        ) from e

    return nx, ny


def _compute_port_extents(n, prsc, ovsc, bin):
    return (n + prsc + ovsc) // bin


def _update_port_array_lists(
    port=None,
    binx=1,
    biny=1,
    x=None,
    nx=None,
    prscx=None,
    ovscx=None,
    y=None,
    ny=None,
    prscy=None,
    ovscy=None,
    port_array_lists=None,
):

    # Input checking
    try:
        assert not np.isnan(
            np.array([port, binx, biny, x, nx, prscx, ovscx, y, ny, prscy, ovscy])
        ).any(), ("Must pass a full argument set to " "_generate_port_array_lists")
        assert not np.any(np.array([binx, biny, nx, ny]) < 1), (
            "Must not have spacing variables that " "are 0 or negative"
        )
        assert not np.any(np.array([prscx, prscy, ovscx, ovscy]) < 0), (
            "Pre/overscan cannot " "be negative"
        )
    except AssertionError as e:
        raise ValueError(str(e))

    # Let's find where this port is *capable* of reading out
    # Note this is different to working out which pixels the port *does*
    # read
    # The general idea is if
    # X/bin +/- NX/bin (and similar for y)
    # exceeds the data array bounds in one direction, the port cannot
    # read the pixels in that direction
    # Don't use the pre/overscan information for this computation
    x_extent = _compute_port_extents(nx, 0, 0, binx)
    y_extent = _compute_port_extents(ny, 0, 0, biny)

    # Remember *index* coords are 0-index, *pixel* coords are 1-index
    y_up = (y - 1) // biny + y_extent <= port_array_lists.shape[0]
    y_down = y // biny - y_extent >= 0
    x_right = (x - 1) // binx + x_extent <= port_array_lists.shape[1]
    x_left = x // binx - x_extent >= 0
    if not (x_left or x_right) or not (y_up or y_down):
        raise RuntimeError("I've found an impossible set of port values!")

    # Use a numpy nditer for most efficient in-situ updates
    i_min = port_array_lists.shape[1]
    i_max = 0
    j_min = port_array_lists.shape[0]
    j_max = 0
    if y_up:
        j_min = min(j_min, (y - 1) // biny)
        j_max = max(j_max, (y - 1) // biny + y_extent)
    if y_down:
        j_min = min(j_min, (y - 1) // biny - y_extent + 1)
        j_max = max(j_max, (y - 1) // biny + 1)
    if x_left:
        i_min = min(i_min, (x - 1) // binx - x_extent + 1)
        i_max = max(i_max, (x - 1) // binx + 1)
    if x_right:
        i_min = min(i_min, (x - 1) // binx)
        i_max = max(i_max, (x - 1) // binx + x_extent)

    port_array_lists[j_min:j_max, i_min:i_max] *= get_nth_prime(port)

    # Sanity check
    assert np.count_nonzero(port_array_lists % get_nth_prime(port) == 0), (
        f"Failed to add " f"port {port}"
    )

    return  # port_array_lists updated in-situ


def _calculate_port_array(hdu, kw_binned=None, binx=1, biny=1, debug=False):
    """
    Compute a map giving the readout port of every real (binned) chip pixel.

    The grid has shape ``(CHIP NY // biny, CHIP NX // binx)``, taken from the
    header. It covers real pixels only; pre/overscan is added separately by
    :any:`_add_prsc_ovsc_to_ports`.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header.
    kw_binned : bool or None, optional
        Currently ignored (binned-keyword detection is not implemented).
    binx, biny : int, optional
        Binning factors in x and y. Default 1.
    debug : bool, optional
        Currently unused.

    Returns
    -------
    port_array : numpy array
        Integer array of shape ``(CHIP NY // biny, CHIP NX // binx)`` where
        each pixel contains the number of the port it is read out through.

    Raises
    ------
    ValueError
        If ``hdu`` is not an ImageHDU, or a port's header values are invalid.
    RuntimeError
        If header information is incomplete, or the ports cannot be assigned
        unambiguously to every pixel.
    """
    # Input checking (explicit raise so it survives ``python -O``)
    if not isinstance(
        hdu, (astropy.io.fits.ImageHDU, astropy.io.fits.hdu.image.ImageHDU)
    ):
        raise ValueError(f"hdu must be an astropy ImageHDU ({type(hdu)})")

    # NOTE: automatic detection of binned pre/overscan keywords is not
    # implemented (_detect_binned_headerkw is a stub). ``kw_binned`` is
    # therefore ignored and the supplied binx/biny are always used.

    logger.debug("Starting sequence")
    chip_nx, chip_ny = _get_chip_information(hdu)
    shape = (chip_ny // biny, chip_nx // binx)
    # Product of the primes of every port that could read each pixel
    # (1 == no candidate port yet)
    port_array_lists = np.full(shape, 1, dtype=int)
    # Final port assignment (-1 == not yet assigned)
    port_array = np.full(shape, -1, dtype=int)
    logger.debug("Created initial port_array & port_array_lists")

    # Determine the port numbers for this HDU
    port_numbers = _identify_det_regions(hdu)

    for port in port_numbers:
        x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(hdu, port)
        _update_port_array_lists(
            port=port,
            binx=binx,
            biny=biny,
            x=x,
            nx=nx,
            prscx=prscx,
            ovscx=ovscx,
            y=y,
            ny=ny,
            prscy=prscy,
            ovscy=ovscy,
            port_array_lists=port_array_lists,
        )

    # Confirm there are no pixels without a possible port
    n_orphans = np.count_nonzero(port_array_lists == 1)
    if n_orphans:
        raise RuntimeError(f"{n_orphans} pixels have no port possible")

    logger.debug("Populated port_array_lists")

    # We have an array with the 'possible' ports. We now need to reduce this
    # down into an array of actual ports. We do this iteratively:
    # - If the number of pixels where a port is the *only* possibility equals
    #   the port size (x_extent * y_extent), those pixels are the port.
    # - Otherwise, if the number of pixels where the port is *a* possibility
    #   equals the port size, those pixels are the port.
    # - Otherwise we probably have a mixed set of possibilities that we can't
    #   easily disentangle. Normally there is a row or column on the edge of
    #   a region that has only one port possibility, and we build the port
    #   region logically from there.
    # Each found port is removed from consideration, which may unlock others;
    # repeat until all ports are found or no progress is made.
    logger.debug("Preparing for port search...")

    def candidate_mask(prime):
        # True wherever `prime` divides the product, i.e. wherever that port
        # is still a possibility. Native numpy modulo gives the same result
        # as the former np.vectorize loop, but faster, and it also works on
        # zero-size arrays.
        return port_array_lists % prime == 0

    port_primes = {port: get_nth_prime(port) for port in port_numbers}

    logger.debug("Starting port search...")
    while port_numbers:
        ports_in = len(port_numbers)

        logger.debug(f"    Iterating over {len(port_numbers)} remaining ports")
        # Iterate over a reversed *copy*: found ports are removed from
        # port_numbers inside the loop
        for port in port_numbers[::-1]:
            x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(hdu, port)
            prime = port_primes[port]
            port_found = False
            port_x_extent = _compute_port_extents(nx, 0, 0, binx)
            port_y_extent = _compute_port_extents(ny, 0, 0, biny)
            port_size = port_x_extent * port_y_extent
            # Pixels where this port is the *only* possibility
            only_this_port = port_array_lists == prime

            if np.count_nonzero(only_this_port) == port_size:
                # Single-port-possible search
                logger.debug(f"   Found port {port} via single-port-possible")
                port_found = True
                port_array[only_this_port] = port
            elif np.count_nonzero(candidate_mask(prime)) == port_size:
                # Extent limited to port bounds search
                logger.debug(f"   Found port {port} via extent limited")
                port_found = True
                port_array[candidate_mask(prime)] = port
            else:
                # Edge row/column search, seeded by single-option pixels
                single_locs = np.argwhere(only_this_port)
                if len(single_locs) > 0:
                    # Bounding box of the single-option pixels plus the
                    # port's root (origin) pixel
                    root_j = (y - 1) // biny
                    root_i = (x - 1) // binx
                    j_min = min(np.min(single_locs[:, 0]), root_j)
                    j_max = max(np.max(single_locs[:, 0]), root_j)
                    i_min = min(np.min(single_locs[:, 1]), root_i)
                    i_max = max(np.max(single_locs[:, 1]), root_i)
                    # Can these bounds make the port? If so, that must be it
                    if (j_max - j_min + 1 == port_y_extent) and (
                        i_max - i_min + 1 == port_x_extent
                    ):
                        logger.debug(
                            f"   Found port {port} via single-opts-plus-root"
                        )
                        port_found = True
                        port_array[j_min : j_max + 1, i_min : i_max + 1] = port
                    else:
                        # Pure row/column search
                        row_inds = np.unique(single_locs[:, 0])
                        col_inds = np.unique(single_locs[:, 1])
                        row_lo, row_hi = np.min(row_inds), np.max(row_inds)
                        col_lo, col_hi = np.min(col_inds), np.max(col_inds)
                        n_single = len(single_locs)
                        if len(row_inds) * port_x_extent == n_single:
                            # Complete single-option rows, each spanning the
                            # full port width; extend to port_y_extent rows.
                            # (The check previously used port_y_extent, which
                            # is only correct for square ports.)
                            go_up = row_lo + port_y_extent <= port_array.shape[0]
                            go_down = row_hi - port_y_extent + 1 >= 0
                            if go_up and not go_down:
                                logger.debug(f"   Found port {port} via rows up")
                                port_found = True
                                port_array[
                                    row_lo : row_lo + port_y_extent,
                                    col_lo : col_hi + 1,
                                ] = port
                            elif go_down and not go_up:
                                logger.debug(f"   Found port {port} via rows down")
                                port_found = True
                                port_array[
                                    row_hi - port_y_extent + 1 : row_hi + 1,
                                    col_lo : col_hi + 1,
                                ] = port
                        elif len(col_inds) * port_y_extent == n_single:
                            # Complete single-option columns, each spanning
                            # the full port height; extend to port_x_extent
                            # columns
                            go_right = col_lo + port_x_extent <= port_array.shape[1]
                            # Must be >= 0 (mirrors go_down); the previous
                            # `<= 0` inverted the bounds test
                            go_left = col_hi - port_x_extent + 1 >= 0
                            if go_right and not go_left:
                                logger.debug(f"   Found port {port} via cols right")
                                port_found = True
                                port_array[
                                    row_lo : row_hi + 1,
                                    col_lo : col_lo + port_x_extent,
                                ] = port
                            elif go_left and not go_right:
                                logger.debug(f"   Found port {port} via cols left")
                                port_found = True
                                port_array[
                                    row_lo : row_hi + 1,
                                    col_hi - port_x_extent + 1 : col_hi + 1,
                                ] = port

            # If found, blank out the parts of port_array_lists that
            # correspond to this port
            if port_found:
                logger.debug(
                    f"   Removing port {port} from consideration "
                    f"({np.count_nonzero(port_array == port)} "
                    f"points)..."
                )
                # Strip this port out of the (remaining) port_array_lists and
                # port_numbers
                port_array_lists[candidate_mask(prime)] //= prime
                port_array_lists[port_array == port] = 1
                port_numbers.remove(port)
                logger.debug("   ...done!")
                # Don't break - might hit multiples of these in one iteration

        # Check this pass actually did something, otherwise the while loop
        # would never terminate
        if len(port_numbers) == ports_in:
            raise RuntimeError(f"Unable to differentiate between ports {port_numbers}")

    # Output & sanity checking (port_numbers is necessarily empty here)
    if np.any(port_array_lists > 1):
        raise RuntimeError(
            "There's some port possibilities still "
            "lingering in port_array_lists, even though "
            "the assignment loop completed!"
        )
    if np.any(port_array < 0):
        raise RuntimeError(
            f"{np.count_nonzero(port_array == -1)} "
            f"pixels have no port found, despite all "
            f"known ports having been computed"
        )

    return port_array


def _add_prsc_ovsc_to_ports(hdu, port_array, binx=1, biny=1, debug=False):
    """
    Insert pre- and over-scan regions into a computed port array.

    The :any:`_calculate_port_array` function computes a port array which
    is for the real pixels in the input HDU only. This function builds a new
    array with the shape of ``hdu.data``. Each port is given a rectangle of
    ``(NY + PRSCY + OVSCY) // biny`` rows by ``(NX + PRSCX + OVSCX) // binx``
    columns, so its prescan/overscan rows and columns are included.

    Ports are placed in order of their region's smallest row index, then
    smallest column index, in ``port_array``. Each port goes to the first
    still-unassigned position (scanning row by row) whose rectangle is
    entirely unassigned. Placement works in array coordinates, not detector
    X/Y.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header.
    port_array : numpy array
        An array the same shape as the data array of the input `hdu` *less the
        pre and overscan rows/columns*, where
        each pixel contains the port number that the pixel is read out
        through. Should be the output of :any:`_calculate_port_array`.
    binx, biny : int, optional
        Binning factors used to convert the pre/overscan-inclusive extents.
    debug : bool, optional
        Currently unused.

    Returns
    -------
    port_array_exp : numpy array
        A version of port_array that matches the dimensions of the data array
        in the input HDU; that is, the port_array_exp includes the pre- and
        over-scan pixels that are present.

    Raises
    ------
    RuntimeError
        If any output pixel is left without a port (this check is an
        ``assert`` internally, so it is skipped under ``python -O``).
    ValueError
        From ``np.min``/``np.max`` if a header port has no pixels in
        ``port_array``.
    """
    # Get the ports for this HDU
    ports = _identify_det_regions(hdu)

    # Create an empty port_array_exp (0 marks "no port assigned yet")
    port_array_exp = np.full(hdu.data.shape, 0, dtype=np.int64)
    port_info = {}

    # Loop through the ports in turn, and set the relevant pixels on the
    # port_array_exp to match that port
    for port in ports:
        # Get the info for this port
        x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(hdu, port)
        # Find the min and max positions of this port on the input port array
        port_array_where = np.argwhere(port_array == port)
        x_min = np.min(port_array_where[:, 1])
        x_max = np.max(port_array_where[:, 1])
        y_min = np.min(port_array_where[:, 0])
        y_max = np.max(port_array_where[:, 0])
        logger.debug(
            f"Bounds for port {port} are " f"[{x_min}:{x_max}], [{y_min}:{y_max}]"
        )

        # Determine the port extent, including pre/overscan
        x_extent = _compute_port_extents(nx, prscx, ovscx, binx)
        y_extent = _compute_port_extents(ny, prscy, ovscy, biny)

        # Append this info to the port in the form of:
        # (x_min, x_max, y_min, y_max, x_extent, y_extent)
        port_info[port] = (x_min, x_max, y_min, y_max, x_extent, y_extent)

    # Now we need to cycle through the ports in sequence, populating the
    # extended array
    # We don't care about if we're adding pre or overscan pixels; we're
    # more concerned with getting the ports in the correct distribution,
    # with enough pixels assigned to cover everything
    # This is why we're running off the array coordinates in the input
    # port_array, not the detector x and y
    logger.debug("Port info is:")
    logger.debug("port: (x_min, x_max, y_min, y_max, x_extent, y_extent)")
    for k, v in port_info.items():
        logger.debug(f"{k}: {v}")
    # Entries are (port, x_extent, y_extent, y_min, x_min), sorted so ports
    # are placed by smallest row index first, then smallest column index
    port_exec_list = [(p, v[4], v[5], v[2], v[0]) for p, v in port_info.items()]
    port_exec_list.sort(key=lambda i: (i[3], i[4]))
    logger.debug("Port execution list:")
    logger.debug(port_exec_list)
    for port, x_extent, y_extent, _, _ in port_exec_list:
        logger.debug(f"Inserting port {port}")
        # Find the first position where we could place this
        # N x (y_ind, x_ind) array
        avail_spots = np.argwhere(port_array_exp == 0)
        # Get the spot lowest in y, then lowest in x
        # (np.lexsort uses its *last* key as the primary sort key)
        avail_spots_inds = np.lexsort((avail_spots[:, 1], avail_spots[:, 0]))
        # Cycle through the available spots until the port can be assigned
        for inds in avail_spots_inds:
            spot = avail_spots[inds]
            # Can we put this port here? If so, do it and break
            # NOTE(review): numpy slices are clipped at the array edge, so
            # near the right/bottom edge this test can pass for a window
            # smaller than y_extent x x_extent, and the port is then written
            # truncated. If no spot passes at all, the port is silently
            # skipped; only the final "no port assigned" check catches that.
            if np.all(
                port_array_exp[
                    spot[0] : spot[0] + y_extent, spot[1] : spot[1] + x_extent
                ]
                == 0
            ):
                port_array_exp[
                    spot[0] : spot[0] + y_extent, spot[1] : spot[1] + x_extent
                ] = port
                break

    # Sanity check the output (assert-based: stripped under ``python -O``)
    try:
        assert np.count_nonzero(port_array_exp == 0) == 0, (
            "Some pixels " "still have no " "port assigned!"
        )
    except AssertionError as e:
        raise RuntimeError(str(e))

    return port_array_exp


def determine_port_regions(hdu):
    """
    Determine the boundaries of each detector port on the detector image.

    Not yet implemented: ``hdu`` is not inspected and ``None`` is returned.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header keywords.

    Returns
    -------
    port_dict : dict or None
        Intended: a dictionary describing the bounds of the region covered by
        each port, of the form
        ``port_dict[n] = {'x': (x_min, x_max), 'y': (y_min, y_max)}`` for
        port ``n``. Currently always ``None``.
    """
    # TODO: implement; returns None until then (unchanged behaviour)
    return None


def strip_prescan_regions(
    hdu, strip_x=True, strip_y=True, instrume=None, xbin=1, ybin=1
):
    """
    Generic prescan removal (currently performs no stripping).

    A new ImageHDU is built from ``hdu.data`` and ``hdu.header`` and returned
    unchanged. ``strip_x``, ``strip_y``, ``instrume``, ``xbin`` and ``ybin``
    are accepted but currently ignored.

    Returns
    -------
    astropy.io.fits.ImageHDU
        New HDU wrapping the input data and header. NOTE(review): the data
        array is passed straight through and may be shared with ``hdu``, so
        confirm before mutating it in place.
    """
    return fits.ImageHDU(data=hdu.data, header=hdu.header)


def strip_overscan_regions(
    hdu, strip_x=True, strip_y=True, instrume=None, xbin=1, ybin=1
):
    """
    Generic overscan removal (currently performs no stripping).

    A new ImageHDU is built from ``hdu.data`` and ``hdu.header`` and returned
    unchanged. ``strip_x``, ``strip_y``, ``instrume``, ``xbin`` and ``ybin``
    are accepted but currently ignored.

    Returns
    -------
    astropy.io.fits.ImageHDU
        New HDU wrapping the input data and header. NOTE(review): the data
        array is passed straight through and may be shared with ``hdu``, so
        confirm before mutating it in place.
    """
    return fits.ImageHDU(data=hdu.data, header=hdu.header)


def strip_prescan_overscan(
    hdu, strip_x=True, strip_y=True, instrume=None, xbin=1, ybin=1, **kwargs
):
    """
    Remove prescan and then overscan regions from a detector image.

    Runs :func:`strip_prescan_regions` and then applies
    :func:`strip_overscan_regions` to *its output*, forwarding all options.
    Previously the overscan step was applied to the original ``hdu``, which
    discarded the prescan result, and the binning was hard-coded to 1.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header.
    strip_x, strip_y : bool, optional
        Whether to strip along x / y.
    instrume : str or None, optional
        Instrument name, forwarded to both steps.
    xbin, ybin : int, optional
        Binning factors, forwarded to both steps.
    **kwargs
        Ignored.

    Returns
    -------
    ImageHDU
        The HDU returned by the overscan step.
    """
    options = dict(
        strip_x=strip_x, strip_y=strip_y, instrume=instrume, xbin=xbin, ybin=ybin
    )
    ret = strip_prescan_regions(hdu, **options)
    ret = strip_overscan_regions(ret, **options)
    return ret


def strip_prescan_overscan_espresso(hdu, debug=False, **kwargs):
    """
    Remove a band of rows from an ESPRESSO detector image.

    Ports 1-8 are read from the top-left corner and run along the top of the
    CCD. Ports 9-16 are read from the bottom-right corner and run along the
    bottom. Rows ``y_trim_min`` to ``y_trim_max - 1`` are deleted:
    ``y_trim_min`` comes from the port keywords of ports <= 8 and
    ``y_trim_max`` from those of ports > 8 (presumably the pre/overscan rows
    between the two halves). In each group, the last port in ascending order
    sets the value.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image and header. ``CD1_1`` and
        ``CD2_2`` (default 1) are used as the x/y binning factors.
    debug : bool, optional
        Currently unused; diagnostics always go to the module logger at
        DEBUG level.
    **kwargs
        Ignored.

    Returns
    -------
    ImageHDU
        New HDU with the input header and the trimmed data.

    Raises
    ------
    RuntimeError
        If the chip keywords are missing, or the header does not describe at
        least one port <= 8 and one port > 8 (so the trim range is undefined).
    """
    # Create a new HDU to manipulate; the data is replaced (not modified in
    # place) below
    im_hdu = fits.ImageHDU(data=hdu.data, header=hdu.header)

    # Figure out the binning
    binx = int(hdu.header.get("CD1_1", 1))
    biny = int(hdu.header.get("CD2_2", 1))

    chipx, chipy = _get_chip_information(hdu)
    logger.debug(f"{chipx} x {chipy} chip w/ xbin={binx}, ybin={biny}")

    ports = _identify_det_regions(im_hdu)
    y_trim_min = None
    y_trim_max = None
    for port in ports:
        x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(im_hdu, port)
        logger.debug(f"port {port}: {nx} x {ny}")

        if port <= 8:
            y_trim_min = y // biny + (prscy + ny) // biny - 1
        else:
            y_trim_max = y // biny + 2 * ovscy // biny - (prscy + ny) // biny

    # Without both bounds the subtraction/range below would fail with an
    # opaque TypeError
    if y_trim_min is None or y_trim_max is None:
        raise RuntimeError(
            "Cannot determine ESPRESSO trim rows: need at least one port <= 8 "
            f"and one port > 8 (found ports {ports})"
        )

    logger.debug(
        f"y_trim_min={y_trim_min}, y_trim_max={y_trim_max}, "
        f"delta={y_trim_max - y_trim_min}"
    )

    # Strip out the relevant rows
    im_hdu.data = np.delete(im_hdu.data, range(y_trim_min, y_trim_max), 0)

    return im_hdu


def correct_port_headers_uves(hdulist, debug=False):
    """
    Patch UVES port keywords in place for the 4pts/625kHz/lg and
    2pts/625kHz/lg read modes (PIPE-10833).

    The read mode comes from ``ESO DET READ SPEED`` in ``hdulist[0]``. If it
    is missing, or contains neither mode string, nothing is changed.
    Otherwise every extension in ``hdulist[1:]`` is patched:

    - a negative ``CRPIX1`` is increased by 50;
    - for each port ``n`` that has an ``ESO DET OUT<n> OVSCX`` value, the
      overscan width is added to ``NX``, ``OVSCX`` is set to 0, and ``X`` is
      increased by ``2 * (n - 1) * OVSCX``.

    Parameters
    ----------
    hdulist : sequence of HDUs
        Typically an ``astropy.io.fits.HDUList``.
    debug : bool, optional
        If True, log progress at DEBUG level.

    Returns
    -------
    None
        Headers are modified in place.
    """
    readspeed = hdulist[0].header.get("ESO DET READ SPEED")
    if debug:
        logger.debug(f"Dealing with readspeed {readspeed}")
    if readspeed is None:
        return
    if not any(mode in readspeed for mode in ("4pts/625kHz/lg", "2pts/625kHz/lg")):
        return

    for ext in hdulist[1:]:
        hdr = ext.header
        # Correct CRPIX if special (i.e. <0)
        if hdr.get("CRPIX1", 1) < 0:
            hdr["CRPIX1"] += 50
        for port in _identify_det_regions(ext):
            ovscx = hdr.get(f"ESO DET OUT{port} OVSCX")
            if debug:
                logger.debug(
                    f"Updating extname {hdr.get('EXTNAME')}, "
                    f"port {port} with ovscx={ovscx}"
                )
            if ovscx is None:
                continue
            # ovscx was captured above, before OVSCX is zeroed
            hdr[f"ESO DET OUT{port} NX"] += ovscx
            hdr[f"ESO DET OUT{port} OVSCX"] = 0
            hdr[f"ESO DET OUT{port} X"] += 2 * (port - 1) * ovscx


def strip_prescan_overscan_uves(hdu, debug=False, **kwargs):
    """
    Trim UVES data down to the real-pixel columns of each port.

    For each port, in ascending port order, a half-open column range
    ``[x_trim_min, x_trim_max)`` is computed. The inputs are the port
    keywords, ``CDELT1`` (used as the x binning) and ``CRPIX1``. The ranges
    are then cut out of the data and concatenated left to right along
    axis 1. Rows are not trimmed.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header.
    debug : bool, optional
        If True, log intermediate values at DEBUG level.
    **kwargs
        Ignored.

    Returns
    -------
    ImageHDU
        New HDU with the input header and the trimmed data. If no ports are
        found, the data is returned untrimmed. NOTE(review): with a single
        port the result is a basic slice of the input array and may share
        memory with ``hdu.data``.
    """
    # Create a new HDU to manipulate - don't want to harm real data
    # (im_hdu.data is only ever rebound below, never written in place)
    im_hdu = fits.ImageHDU(data=hdu.data, header=hdu.header)

    # Figure out the binning
    binx = int(hdu.header.get("CDELT1", 1))
    x_origin_rebin = int(hdu.header.get("CRPIX1", 1) * binx)
    # The x-origin for blue arm data is 0 - this doesn't make much sense
    # given the other port descriptors, so if this happens, co-opt the value
    # to be 1
    if x_origin_rebin == 0:
        x_origin_rebin = 1

    # One (x_trim_min, x_trim_max) column range per port
    trim_bounds = []
    for port in _identify_det_regions(im_hdu):
        x, nx, prscx, ovscx, _, ny, _, _ = _get_port_information(im_hdu, port)
        LtR = False
        if (
            x == 1 or x + x_origin_rebin + prscx + ovscx - 1 == 1
        ):  # Left-to-right detector read
            LtR = True
            if x == 1:
                if debug:
                    logger.debug("LtR, x==1 x_trim_min")
                x_trim_min = x + x_origin_rebin + prscx - 2
            else:
                if debug:
                    logger.debug("LtR, x!=1 x_trim_min")
                    logger.debug(
                        f"x = {x}, x_origin_rebin={x_origin_rebin}, ovscx={ovscx}, prscx={prscx}"
                    )
                x_trim_min = x + x_origin_rebin + 2 * prscx + ovscx - 2
            x_trim_max = x_trim_min + nx
        else:  # Right-to-left read (i.e., CCD-44/WIN2 in red)
            # NOTE(review): the integer offsets in this branch are not
            # derived anywhere in this file; confirm against the data
            x_trim_max = x + x_origin_rebin + 1
            x_trim_max //= binx
            x_trim_max += (
                ovscx + 2 * prscx - 3 + binx
            )  # Correct for overscan, pre-scan & index diffs
            if x_origin_rebin > 0:
                x_trim_max -= (port - 1) * (prscx + ovscx)
            if (
                x_trim_max > im_hdu.data.shape[1]
            ):  # Special case for blue 625kHz_2pt - bring back to bounds
                x_trim_max = im_hdu.data.shape[1] - prscx
            x_trim_min = x_trim_max - nx

        trim_bounds.append((x_trim_min, x_trim_max))

        if debug:
            logger.debug("---")
            logger.debug(f"HDU {hdu.name}")
            logger.debug(f"left-to-right read: {LtR}")
            logger.debug(f"port {port} nx x ny: {nx} x {ny}")
            logger.debug(f"x_origin_rebin={x_origin_rebin}, x={x}, binx={binx}")
            logger.debug(f"prscx={prscx}, ovscx={ovscx}")
            logger.debug(
                f"Input data size: {im_hdu.data.shape[1]} x " f"{im_hdu.data.shape[0]}"
            )
            logger.debug(f"x_trim_min={x_trim_min}, x_trim_max={x_trim_max}")
            logger.debug(f"delta={x_trim_max - x_trim_min}")

    # Keep only the trimmed column ranges (axis 1), concatenated in port
    # order. Later slices are taken from an untouched copy of the input,
    # because im_hdu.data is replaced by the first slice.
    initial_done = False
    initial_data = np.copy(im_hdu.data)
    for trim_bound in trim_bounds:
        if not initial_done:
            im_hdu.data = im_hdu.data[:, trim_bound[0] : trim_bound[1]]
            initial_done = True
        else:
            im_hdu.data = np.concatenate(
                (im_hdu.data, initial_data[:, trim_bound[0] : trim_bound[1]]), axis=1
            )

    return im_hdu


def _muse_port_strip_ranges(port, x, y, prscx, ovscx, prscy, ovscy, axis_x, axis_y):
    """Return the 0-based index ranges to strip for one MUSE readout port.

    The image is treated as a 2x2 mosaic split at ``axis_x // 2`` and
    ``axis_y // 2``:

    - ports 1 and 2 occupy the first half of the rows, ports 3 and 4 the
      second half;
    - ports 1 and 4 occupy the first half of the columns, ports 2 and 3 the
      second half.

    Within each port's quadrant, ``prscy`` rows are stripped from the start
    and ``ovscy`` rows from the end; ``prscx`` columns are stripped from the
    start and ``ovscx`` columns from the end.

    Parameters
    ----------
    port : int
        Port number; only 1-4 are recognised.
    x, y : int
        1-based port origin from the header. ``y`` is only used for ports
        1 and 2, ``x`` only for ports 1 and 4 (the quadrants that begin at
        the image edge).
    prscx, ovscx, prscy, ovscy : int
        Pre-scan / over-scan widths in x and y, in pixels.
    axis_x, axis_y : int
        Full image size (NAXIS1, NAXIS2).

    Returns
    -------
    tuple of range or None
        ``(first_rows, last_rows, first_cols, last_cols)`` as half-open
        ``range`` objects, or ``None`` if ``port`` is not 1-4.
    """
    half_x = axis_x // 2
    half_y = axis_y // 2

    # Quadrant bounds as half-open intervals [start, stop).
    if port in (1, 2):
        row_start, row_stop = y - 1, half_y
    elif port in (3, 4):
        row_start, row_stop = half_y, axis_y
    else:
        return None

    if port in (1, 4):
        col_start, col_stop = x - 1, half_x
    else:
        col_start, col_stop = half_x, axis_x

    # ``stop`` is exclusive, so the trailing ranges end exactly at the
    # quadrant edge (the previous code used ``edge - 1`` as the stop, which
    # kept the last scan row/column and dropped the last science one).
    return (
        range(row_start, row_start + prscy),
        range(row_stop - ovscy, row_stop),
        range(col_start, col_start + prscx),
        range(col_stop - ovscx, col_stop),
    )


def strip_prescan_overscan_muse(hdu, debug=True, **kwargs):
    """Strip the pre-scan and over-scan regions from a MUSE image HDU.

    Port geometry is read with ``_get_port_information`` for every port
    returned by ``_identify_det_regions``. The matching rows and columns
    (see ``_muse_port_strip_ranges``) are then deleted from a new HDU.

    Parameters
    ----------
    hdu : astropy.io.fits HDU
        Input image HDU. Its data array is not modified in place.
    debug : bool, optional
        If True (default), emit diagnostic messages at DEBUG level.
    **kwargs
        Accepted for call-signature compatibility; unused.

    Returns
    -------
    astropy.io.fits.ImageHDU
        New HDU, built from ``hdu``'s data and header, holding the trimmed
        data.
    """
    # Build a new HDU. ``np.delete`` returns new arrays, so ``hdu.data``
    # itself is never modified.
    im_hdu = fits.ImageHDU(data=hdu.data, header=hdu.header)

    # Binning and chip size are only reported in the debug output.
    binx = int(hdu.header.get(DET_BINX, 1))
    biny = int(hdu.header.get(DET_BINY, 1))
    axis_x = int(hdu.header.get(NAXIS1))
    axis_y = int(hdu.header.get(NAXIS2))

    chipx, chipy = _get_chip_information(hdu)

    ports = _identify_det_regions(im_hdu)

    if debug:
        logger.debug(
            f"{chipx} x {chipy} chip w/ xbin={binx}, ybin={biny}, "
            f"ports={ports}"
        )

    rows_to_delete = set()
    columns_to_delete = set()

    for port in ports:
        x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(im_hdu, port)
        if debug:
            logger.debug(f"port {port}: {nx} x {ny}")

        strip = _muse_port_strip_ranges(
            port, x, y, prscx, ovscx, prscy, ovscy, axis_x, axis_y
        )
        if strip is None:
            logger.warning(f"Unexpected MUSE port {port}; not stripped")
            continue

        first_rows, last_rows, first_cols, last_cols = strip
        rows_to_delete.update(first_rows)
        rows_to_delete.update(last_rows)
        columns_to_delete.update(first_cols)
        columns_to_delete.update(last_cols)

        if debug:
            logger.debug(f"Removing the following for port {port}")
            logger.debug(
                f"Rows [{first_rows.start}, {first_rows.stop}) and "
                f"[{last_rows.start}, {last_rows.stop}) (half-open)"
            )
            logger.debug(
                f"Columns [{first_cols.start}, {first_cols.stop}) and "
                f"[{last_cols.start}, {last_cols.stop}) (half-open)"
            )
            logger.debug(f" column to delete: {sorted(columns_to_delete)}")
            logger.debug(f" row to delete: {sorted(rows_to_delete)}")

    # Sorted, de-duplicated index lists for np.delete and for logging.
    column_indices = sorted(columns_to_delete)
    row_indices = sorted(rows_to_delete)

    if debug:
        logger.debug(f"final column to delete: {column_indices}")
        logger.debug(f"final row to delete: {row_indices}")
        logger.debug(f"length_column = {len(column_indices)}")
        logger.debug(f"length_row = {len(row_indices)}")

    im_hdu.data = np.delete(im_hdu.data, column_indices, axis=1)
    im_hdu.data = np.delete(im_hdu.data, row_indices, axis=0)

    return im_hdu
