# SPDX-License-Identifier: BSD-3-Clause
"""Module for calculating 'ports' on ESO detectors.

This module interprets the assorted ESO DET header keywords from
ESO-supplied FITS files, in order to determine the various valid
readout windows.

ESO uses two makes of detector controller:
- FIERA
- NGC
Each controller has a distinct specification for ESO DET header
keywords, so an interpreter for each detector controller is provided
separately. A helper function is provided to determine which controller
is in use for a given FITS file.
"""
import astropy.io.fits as fits
import numpy as np
import re
import logging
from adari_core.utils.utils import fetch_kw_or_default

logger = logging.getLogger(__name__)

# Regexes for the per-port "ESO DET OUT<n> ..." header keywords, matched with
# or without a leading "HIERARCH"; the port number is captured as group "n".
# The prefix is written ``(HIERARCH\s+)?ESO`` so that both "ESO DET OUT1 X"
# and "HIERARCH ESO DET OUT1 X" match (a pattern of the form
# ``(HIERARCH )? ESO`` demands a space before "ESO" and matches neither).
#
# The *_FMT strings are literal OUT1 keywords: they contain no ``{}``
# placeholder, so ``.format(n)`` returns them unchanged.
DET_OUT_X_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) X\s*$"
DET_OUT_X_REGEX = re.compile(DET_OUT_X_STR)
DET_OUT_X_FMT = "HIERARCH ESO DET OUT1 X"
DET_OUT_Y_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) Y\s*$"
DET_OUT_Y_REGEX = re.compile(DET_OUT_Y_STR)
DET_OUT_Y_FMT = "HIERARCH ESO DET OUT1 Y"
DET_OUT_NX_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) NX\s*$"
DET_OUT_NX_REGEX = re.compile(DET_OUT_NX_STR)
DET_OUT_NX_FMT = "HIERARCH ESO DET OUT1 NX"
DET_OUT_NY_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) NY\s*$"
DET_OUT_NY_REGEX = re.compile(DET_OUT_NY_STR)
DET_OUT_NY_FMT = "HIERARCH ESO DET OUT1 NY"
DET_OUT_OVSC_X_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) OVSCX\s*$"
DET_OUT_OVSC_X_REGEX = re.compile(DET_OUT_OVSC_X_STR)
DET_OUT_OVSC_X_FMT = "HIERARCH ESO DET OUT1 OVSCX"
DET_OUT_OVSC_Y_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) OVSCY\s*$"
DET_OUT_OVSC_Y_REGEX = re.compile(DET_OUT_OVSC_Y_STR)
DET_OUT_OVSC_Y_FMT = "HIERARCH ESO DET OUT1 OVSCY"
DET_OUT_PRSC_X_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) PRSCX\s*$"
DET_OUT_PRSC_X_REGEX = re.compile(DET_OUT_PRSC_X_STR)
DET_OUT_PRSC_X_FMT = "HIERARCH ESO DET OUT1 PRSCX"
DET_OUT_PRSC_Y_STR = r"\s*(HIERARCH\s+)?ESO DET OUT(?P<n>[0-9]+) PRSCY\s*$"
DET_OUT_PRSC_Y_REGEX = re.compile(DET_OUT_PRSC_Y_STR)
DET_OUT_PRSC_Y_FMT = "HIERARCH ESO DET OUT1 PRSCY"
DET_OUT_INDEX_STR = r"\s*(HIERARCH )?ESO DET OUT(?P<n>[0-9]+) INDEX\s*$"
DET_OUT_INDEX_REGEX = re.compile(DET_OUT_INDEX_STR)
DET_OUT_INDEX_FMT = "HIERARCH ESO DET OUT1 INDEX"
# Chip-size keywords: the CHIP1 pair is read for the UVB/VIS/AGC arms,
# the un-numbered CHIP pair for any other arm.
DET_CHIP_NX = "HIERARCH ESO DET CHIP1 NX"
DET_CHIP_NY = "HIERARCH ESO DET CHIP1 NY"
DET_CHIP_NIR_NX = "HIERARCH ESO DET CHIP NX"
DET_CHIP_NIR_NY = "HIERARCH ESO DET CHIP NY"
# Binning keywords of readout window 1.
DET_BINX = "HIERARCH ESO DET WIN1 BINX"
DET_BINY = "HIERARCH ESO DET WIN1 BINY"
# Standard FITS image-axis lengths.
NAXIS1 = "NAXIS1"
NAXIS2 = "NAXIS2"
# Instrument arm keyword.
SEQ_ARM = "HIERARCH ESO SEQ ARM"


def _identify_det_regions(hdu):
    """
    Identify the numbered detector output ports in a given HDU.

    A port ``n`` is reported for every header keyword that matches
    ``DET_OUT_INDEX_REGEX`` (``[HIERARCH ]ESO DET OUT<n> INDEX``). The value
    stored under that keyword is not compared with ``n``.

    Parameters
    ----------
    hdu : astropy.io.fits.ImageHDU
        The ImageHDU to extract port numbers from.

    Returns
    -------
    ports : list of int
        Port numbers in ascending order. An empty list is returned if no
        port keywords are found.
    """
    regions = []
    for key in hdu.header.keys():
        match = DET_OUT_INDEX_REGEX.match(key)
        if match:
            regions.append(int(match.group("n")))

    regions.sort()
    return regions


def _get_port_information(hdu, port):
    try:
        # Get the port values
        # Deliberately use the key syntax to ensure we find any missing
        x = hdu.header[DET_OUT_X_FMT.format(port)]
        y = hdu.header[DET_OUT_Y_FMT.format(port)]
        nx = hdu.header[DET_OUT_NX_FMT.format(port)]
        ny = hdu.header[DET_OUT_NY_FMT.format(port)]
        ovscx = hdu.header[DET_OUT_OVSC_X_FMT.format(port)]
        ovscy = hdu.header[DET_OUT_OVSC_Y_FMT.format(port)]
        prscx = hdu.header[DET_OUT_PRSC_X_FMT.format(port)]
        prscy = hdu.header[DET_OUT_PRSC_Y_FMT.format(port)]
    except KeyError as e:
        raise RuntimeError(f"Incomplete info on port {port} : {str(e)}")
    return x, nx, prscx, ovscx, y, ny, prscy, ovscy


def get_arm(hdu):
    """
    Return the arm name recorded in the ``HIERARCH ESO SEQ ARM`` keyword.

    Parameters
    ----------
    hdu : astropy.io.fits HDU
        HDU whose header is searched.

    Returns
    -------
    The raw header value stored under ``SEQ_ARM``.

    Raises
    ------
    RuntimeError
        If the keyword is absent from the header.
    """
    try:
        return hdu.header[SEQ_ARM]
    except KeyError as missing:
        raise RuntimeError(f"Incomplete info on arm :  {str(missing)}")


def _get_chip_information(hdu):
    """
    Return the chip size ``(nx, ny)`` read from the detector header keywords.

    Arms whose name contains "UVB", "VIS" or "AGC" use the
    ``ESO DET CHIP1 NX/NY`` keywords; every other arm uses
    ``ESO DET CHIP NX/NY``.

    Raises
    ------
    RuntimeError
        If the arm keyword or either chip-size keyword is missing.
    """
    arm = get_arm(hdu)
    uses_chip1 = any(tag in arm for tag in ("UVB", "VIS", "AGC"))
    if uses_chip1:
        key_nx, key_ny = DET_CHIP_NX, DET_CHIP_NY
    else:
        key_nx, key_ny = DET_CHIP_NIR_NX, DET_CHIP_NIR_NY

    try:
        return hdu.header[key_nx], hdu.header[key_ny]
    except KeyError as err:
        raise RuntimeError(f"Incomplete chip information on this HDU ({str(err)})")


def determine_port_regions(hdu):
    """
    Determine the boundaries of each detector port on the detector image.

    Not implemented yet: the body performs no work and the function currently
    returns ``None`` for any input.

    Parameters
    ----------
    hdu : ImageHDU
        The HDU containing the detector image, and associated header keywords.

    Returns
    -------
    port_dict : dict or None
        Intended to describe the bounds of the region covered by each port,
        in the form ``port_dict[n] = {'x': (x_min, x_max),
        'y': (y_min, y_max)}`` for port ``n``. Currently always ``None``.
    """
    return None


def _port_strip_bounds(port, x, nx, y, ny, axis_x, axis_y):
    """
    Return the strip anchors ``(row_lo, row_hi, col_lo, col_hi)`` for a port.

    Prescan rows/columns are removed from ``[lo, lo + prescan)`` and overscan
    rows/columns from ``[hi - overscan, hi)`` (0-based, end-exclusive).
    Only ports 1-4 have a layout here; ``None`` is returned for other ports.
    """
    half_x = axis_x // 2
    half_y = axis_y // 2
    # NOTE(review): port 1 anchors its prescan on the port size (ny - 1,
    # nx - 1) whereas ports 2 and 4 anchor on the port origin (y - 1, x - 1),
    # and every ``hi`` of ``axis - 1`` leaves the final row/column in place
    # because the ranges are end-exclusive. These values are kept unchanged;
    # confirm them against the real detector layouts before altering them.
    if port == 1:
        return ny - 1, axis_y - 1, nx - 1, axis_x - 1
    if port == 2:
        return y - 1, half_y - 1, half_x, axis_x - 1
    if port == 3:
        return half_y, axis_y - 1, half_x, axis_x - 1
    if port == 4:
        return half_y, axis_y - 1, x - 1, half_x - 1
    return None


def strip_prescan_overscan(hdu, debug=True, **kwargs):
    """
    Return a new HDU with the prescan and overscan rows/columns removed.

    Parameters
    ----------
    hdu : astropy.io.fits ImageHDU
        Detector HDU. Its header must provide NAXIS1/NAXIS2, the
        ``ESO SEQ ARM`` keyword, the chip-size keywords and the
        ``ESO DET OUT<n> ...`` port keywords.
    debug : bool, optional
        If True (default), log the computed geometry at DEBUG level.
    **kwargs
        Accepted for call compatibility; unused.

    Returns
    -------
    astropy.io.fits.ImageHDU
        A new HDU built from ``hdu.data``/``hdu.header`` with the selected
        columns (array axis 1) and rows (array axis 0) deleted. Ports other
        than 1-4 contribute nothing to the deletion.

    Raises
    ------
    RuntimeError
        If the arm, chip-size or port keywords are missing.
    """
    # Work on a new HDU; np.delete returns new arrays, so hdu.data itself is
    # never modified in place.
    im_hdu = fits.ImageHDU(data=hdu.data, header=hdu.header)

    # Binning is only reported in the debug log.
    binx = int(hdu.header.get(DET_BINX, 1))
    biny = int(hdu.header.get(DET_BINY, 1))
    axis_x = int(hdu.header.get(NAXIS1))
    axis_y = int(hdu.header.get(NAXIS2))

    chipx, chipy = _get_chip_information(hdu)
    ports = _identify_det_regions(im_hdu)

    if debug:
        logger.debug(
            f"{chipx} x {chipy} chip w/ xbin={binx}, ybin={biny}, ports={ports}"
        )

    rows_to_delete = []
    column_to_delete = []

    for port in ports:
        x, nx, prscx, ovscx, y, ny, prscy, ovscy = _get_port_information(
            im_hdu, port
        )
        if debug:
            logger.debug(f"port {port}: {nx} x {ny}")

        bounds = _port_strip_bounds(port, x, nx, y, ny, axis_x, axis_y)
        if bounds is None:
            # No layout is defined for this port; nothing is stripped for it.
            continue
        row_lo, row_hi, col_lo, col_hi = bounds

        rows_to_delete.extend(range(row_lo, row_lo + prscy))  # prescan rows
        rows_to_delete.extend(range(row_hi - ovscy, row_hi))  # overscan rows
        column_to_delete.extend(range(col_lo, col_lo + prscx))  # prescan columns
        column_to_delete.extend(range(col_hi - ovscx, col_hi))  # overscan columns

        if debug:
            logger.debug(f"Removing the following for port {port}")
            logger.debug(
                f"Rows {row_lo} to {row_lo + prscy}, and "
                f"{row_hi - ovscy} to {row_hi} (end exclusive)"
            )
            logger.debug(
                f"Columns {col_lo} to {col_lo + prscx}, and "
                f"{col_hi - ovscx} to {col_hi} (end exclusive)"
            )
            logger.debug(f" column to delete: {column_to_delete}")
            logger.debug(f" row to delete: {rows_to_delete}")

    # Ranges from different ports can coincide (e.g. ports 2 and 3 share the
    # right-hand columns), so keep each index once.
    column_to_delete = sorted(set(column_to_delete))
    rows_to_delete = sorted(set(rows_to_delete))

    if debug:
        logger.debug(f"final column to delete: {column_to_delete}")
        logger.debug(f"final row to delete: {rows_to_delete}")
        logger.debug(f"length_column = {len(column_to_delete)}")
        logger.debug(f"length_row = {len(rows_to_delete)}")

    im_hdu.data = np.delete(im_hdu.data, column_to_delete, 1)
    im_hdu.data = np.delete(im_hdu.data, rows_to_delete, 0)

    return im_hdu


def set1(hdul):
    """
    Build the detector-readout setup lines shared by several reports.

    For an arm whose name contains "NIR" the lines are DET.NDIT and DET.DIT;
    for any other arm the single line is DET.READ.CLOCK. Values come from
    ``fetch_kw_or_default(..., default="N/A")``.

    Parameters
    ----------
    hdul : astropy.io.fits.HDUList
        File whose "PRIMARY" header is inspected.

    Returns
    -------
    list of str
        ``"LABEL: value"`` lines.
    """
    primary = hdul["PRIMARY"]
    arm = get_arm(primary)

    if "NIR" in arm:
        wanted = (
            ("DET.NDIT: ", "HIERARCH ESO DET NDIT"),
            ("DET.DIT: ", "HIERARCH ESO DET DIT"),
        )
    else:
        wanted = (("DET.READ.CLOCK: ", "HIERARCH ESO DET READ CLOCK"),)

    return [
        label + str(fetch_kw_or_default(primary, keyword, default="N/A"))
        for label, keyword in wanted
    ]


def set2(hdul):
    """
    Build the ``set1`` lines plus the name of the relevant optical element.

    Parameters
    ----------
    hdul : astropy.io.fits.HDUList
        File whose "PRIMARY" header is inspected.

    Returns
    -------
    list of str
        ``set1(hdul)`` followed by one ``INS.OPTI<k>.NAME`` line. When
        ``ESO INS MODE`` contains "SLIT" this is OPTI3 for UVB, OPTI4 for VIS
        or OPTI5 for NIR (first matching arm wins). Otherwise it is OPTI2. In
        SLIT mode with an arm matching none of those three, no OPTI line is
        appended and a warning is logged.

    Raises
    ------
    RuntimeError
        If the ``ESO SEQ ARM`` or ``ESO INS MODE`` keyword is missing.
    """
    primary = hdul["PRIMARY"]
    arm = get_arm(primary)
    try:
        # Index (not .get) so a missing keyword raises KeyError; with .get the
        # value would be None and ``"SLIT" in mode`` would fail with TypeError.
        mode = primary.header["HIERARCH ESO INS MODE"]
    except KeyError as e:
        raise RuntimeError(f"Incomplete info on mode :  {str(e)}") from e

    metadata = set1(hdul)

    if "SLIT" in mode:
        for arm_tag, candidate in (("UVB", "OPTI3"), ("VIS", "OPTI4"), ("NIR", "OPTI5")):
            if arm_tag in arm:
                opti = candidate
                break
        else:
            # No optical element is mapped for this arm in SLIT mode.
            logger.warning(f"No INS.OPTI keyword mapped for arm {arm!r} in SLIT mode")
            return metadata
    else:
        opti = "OPTI2"

    metadata.append(
        f"INS.{opti}.NAME: "
        + str(
            fetch_kw_or_default(
                primary, f"HIERARCH ESO INS {opti} NAME", default="N/A"
            )
        )
    )
    return metadata


def set3(hdul):
    """Return the ``set1`` lines followed by a ``SEQ.ARM`` line."""
    return set1(hdul) + [
        "SEQ.ARM: "
        + str(
            fetch_kw_or_default(
                hdul["PRIMARY"], "HIERARCH ESO SEQ ARM", default="N/A"
            )
        )
    ]


def response(hdul):
    """Return the ``set2`` lines followed by an ``OBS.TARG.NAME`` line."""
    return set2(hdul) + [
        "OBS.TARG.NAME: "
        + str(
            fetch_kw_or_default(
                hdul["PRIMARY"], "HIERARCH ESO OBS TARG NAME", default="N/A"
            )
        )
    ]

class XshooterReportMixin:
    """
    Cooperative mixin for X-shooter report classes that sets ``self._version``.

    All constructor arguments are forwarded along the MRO via ``super()``;
    afterwards ``self._version`` is set to ``"3.7.0"`` (presumably the
    report/pipeline version — confirm).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the cooperative call, so it overrides any _version
        # a base-class __init__ may have set.
        self._version = "3.7.0"

class XshooterSetupInfo:
    """
    Setup-information lines for the individual X-shooter report types.

    Each public static method takes an HDU list and returns a list of
    ``"LABEL: value"`` strings read from its "PRIMARY" header. Values come
    from ``fetch_kw_or_default(..., default="N/A")``.
    """

    @staticmethod
    def _kw_line(hdul, label, keyword):
        """Return ``label`` followed by the string form of ``keyword`` in the primary header."""
        return label + str(
            fetch_kw_or_default(hdul["PRIMARY"], keyword, default="N/A")
        )

    @staticmethod
    def bias(hdul):
        """DET.READ.CLOCK, SEQ.ARM and PRO.DATANCOM lines."""
        kw = XshooterSetupInfo._kw_line
        return [
            kw(hdul, "DET.READ.CLOCK: ", "HIERARCH ESO DET READ CLOCK"),
            kw(hdul, "SEQ.ARM: ", "HIERARCH ESO SEQ ARM"),
            # Label was misspelt "DATACOM"; it reports the PRO DATANCOM keyword.
            kw(hdul, "PRO.DATANCOM: ", "HIERARCH ESO PRO DATANCOM"),
        ]

    @staticmethod
    def dark(hdul):
        """``set1`` lines, plus EXPTIME for arms whose name lacks "NIR"."""
        arm = get_arm(hdul["PRIMARY"])
        metadata = set1(hdul)
        if "NIR" not in arm:
            metadata.append(XshooterSetupInfo._kw_line(hdul, "EXPTIME: ", "EXPTIME"))
        return metadata

    @staticmethod
    def detmon(hdul):
        """SEQ.ARM line, preceded by DET.READ.CLOCK for arms whose name lacks "NIR"."""
        kw = XshooterSetupInfo._kw_line
        arm = get_arm(hdul["PRIMARY"])
        metadata = []
        if "NIR" not in arm:
            metadata.append(
                kw(hdul, "DET.READ.CLOCK: ", "HIERARCH ESO DET READ CLOCK")
            )
        metadata.append(kw(hdul, "SEQ.ARM: ", "HIERARCH ESO SEQ ARM"))
        return metadata

    @staticmethod
    def order_prediction(hdul):
        """Same lines as ``set3``."""
        return set3(hdul)

    @staticmethod
    def order_definition(hdul):
        """Same lines as ``set3``."""
        return set3(hdul)

    @staticmethod
    def lamp_flat(hdul):
        """Same lines as ``set2``."""
        return set2(hdul)

    @staticmethod
    def arc(hdul):
        """Same lines as ``set2``."""
        return set2(hdul)

    @staticmethod
    def wavelength_calibration_2d(hdul):
        """Same lines as ``set3``."""
        return set3(hdul)

    @staticmethod
    def flexures(hdul):
        """OBS.NAME line only."""
        return [XshooterSetupInfo._kw_line(hdul, "OBS.NAME: ", "HIERARCH ESO OBS NAME")]

    @staticmethod
    def acquisition_camera_flats(hdul):
        """SEQ.ARM and INS.FILT1.NAME lines."""
        kw = XshooterSetupInfo._kw_line
        return [
            kw(hdul, "SEQ.ARM: ", "HIERARCH ESO SEQ ARM"),
            kw(hdul, "INS.FILT1.NAME: ", "HIERARCH ESO INS FILT1 NAME"),
        ]

    @staticmethod
    def response_nod(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def response_stare(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def response_offset(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def telluric_standard_slit_nod(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def telluric_standard_slit_stare(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def telluric_standard_slit_offset(hdul):
        """Same lines as ``response``."""
        return response(hdul)

    @staticmethod
    def flatfield(hdul):
        """Same lines as ``set2``."""
        return set2(hdul)

    @staticmethod
    def specphot_star(hdul):
        """Same lines as ``response``."""
        return response(hdul)
