# SPDX-License-Identifier: BSD-3-Clause
import re
import numpy as np
from astropy.time import Time
from astropy import units


def is_iterable(v):
    """
    Check whether the input object ``v`` is iterable.

    Parameters
    ----------
    v
        Object to be tested.

    Returns
    -------
    bool
        True if ``iter(v)`` succeeds, False if it raises TypeError.

    Notes
    -----
    Only an iterator is requested from ``v``; no elements are drawn, so
    generators and other one-shot iterators are not consumed by the check.
    """
    try:
        iter(v)
    except TypeError:
        return False
    return True


def assert_none_or_numeric(v):
    """
    Raise an AssertionError if the input is not None or a numeric value

    A value counts as numeric when it is not iterable and numpy infers a
    numeric dtype for it (strings, for instance, are rejected because they
    are iterable).

    Parameters
    ----------
    v
        Input value to be tested

    Raises
    -------
    AssertionError
        If input value is not considered to be numeric or None
    """
    if v is None:
        return
    # Check iterability first, so sequences are rejected before numpy
    # conversion. Ragged sequences would otherwise make numpy raise ValueError
    # instead of the documented AssertionError.
    is_numeric = not is_iterable(v) and np.issubdtype(
        np.asarray([v]).dtype, np.number
    )
    # Raise explicitly rather than via ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if not is_numeric:
        raise AssertionError(
            "Input value must be None or numeric (input: {})".format(v)
        )


def value_none_or_numeric(v):
    """
    Raise a ValueError if the input is not None or a numeric value

    This is the ValueError-raising counterpart of ``assert_none_or_numeric``.

    Parameters
    ----------
    v
        Input value to be tested

    Returns
    -------
    None

    Raises
    -------
    ValueError
        If input value is not considered to be numeric or None. The
        AssertionError raised by ``assert_none_or_numeric`` is chained as the
        cause.
    """
    try:
        assert_none_or_numeric(v)
    except AssertionError as e:
        raise ValueError(str(e)) from e


def assert_numeric(v):
    """
    Raise an AssertionError if the input is not a numeric value.

    A value counts as numeric when it is not iterable and numpy infers a
    numeric dtype for it. ``None`` is rejected.

    Parameters
    ----------
    v
        Input value to be tested

    Raises
    -------
    AssertionError
        If input value is not considered to be numeric
    """
    # Check iterability first, so sequences (including ragged ones that
    # numpy refuses to convert) are rejected before any array is built.
    is_numeric = not is_iterable(v) and np.issubdtype(
        np.asarray([v]).dtype, np.number
    )
    # Raise explicitly rather than via ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if not is_numeric:
        raise AssertionError("Input value must be numeric (input: {})".format(v))


def fetch_kw_or_default(ext, kw: str, default=None):
    """Return a header keyword's value, or a fallback if it is missing/blank.

    The fallback is ``default`` when one is given. With ``default=None`` it
    is a placeholder string of the form ``"???-<kw>-???"``, so missing
    values stand out in reports. A keyword present with an empty-string
    value is treated the same as a missing one.

    Parameters
    ----------
    ext : :any:`astropy.io.fits.hdu`-derived HDU object
        The :any:`astropy` extension whose header is searched.
    kw : str
        Name of the header keyword.
    default : type matching the expected header keyword value, optional
        Value to return when the keyword is absent or blank.

    Returns
    -------
    value : varying types
        The keyword value, or the fallback described above.
    """
    fallback = "???-{}-???".format(kw) if default is None else default
    value = ext.header.get(kw, fallback)
    return fallback if value == "" else value


def fetch_kw_or_error(ext, kw: str):
    """Get a keyword value from a header, raise an error if not found

    This function interrogates the given :any:`astropy` Header for the presence
    of a keyword, and returns the value of that keyword if it exists. If it
    does not exist, it raises an error.

    Parameters
    ----------
    ext : :any:`astropy.io.fits.hdu`-derived HDU object
        The :any:`astropy` extension to check
    kw : str
        The keyword to search for.

    Returns
    -------
    value : varying types
        The value of the requested keyword.

    Raises
    ------
    ValueError
        Raised if the requested keyword is not found. The underlying
        KeyError is chained as the exception's ``__cause__``.
    """
    try:
        return ext.header[kw]
    except KeyError as err:
        raise ValueError(
            f"Header keyword {kw} not found in the given extension"
        ) from err


def fetch_one_kw_of_range(
    ext,
    kw,
    dummy="x",
    start=1,
    stop=7,  # does not execute on stop value
    step=1,
    fmt="1d",
    default_keynum="_",
    default_value="None found",
):
    """Get a keyword value from a header when the keyword can take one of a range
    of numbers, or return a default

    This function interrogates the given :any:`astropy` Header for the presence
    of a keyword that includes a numerical value. It iterates through a range
    of such values until the associated keyword is found in the header, and
    returns the successful numerical value together with that keyword's value.
    A keyword counts as found when it is present with a value other than
    ``None`` or the empty string. Falsy values such as ``0`` or ``False`` are
    therefore accepted. If no header keywords exist in the specified range, a
    default is returned.

    Parameters
    ----------
    ext : :any:`astropy.io.fits.hdu`-derived HDU object
        The :any:`astropy` extension to check.
    kw : str
        The keyword to search for.
    dummy : str
        The literal substring of kw that is replaced by the number. It is
        matched verbatim, not as a regular expression, and every occurrence
        is replaced. Default: "x".
    start : int
        The starting value for searching the numerical range. Default: 1.
    stop : int
        The stopping value for searching the numerical range. Note that the
        stop value itself is not checked (as per normal 'range' behaviour).
        May not be equal to start. Default: 7.
    step : int
        The step value for the numerical values to be checked. May not be zero.
        Must be of correct sign to reach/exceed stop from start. Default: 1.
    fmt : str
        String format for the representation of the number. "02d" would allow
        zero-padding of a 2-digit number. Default: "1d".
    default_keynum : str
        The returned key if no keywords in the searched range are found
        in the header. Default: "_".
    default_value : str
        The returned value if no keywords in the searched range are found
        in the header. Default: "None found".

    Returns
    -------
    keyvalpair : 2-element tuple (int, varying types)
        The first element contains the first number that successfully locates a keyword
        in the header, or the default_keynum if none are found. The second element
        contains the value of the keyword created with keynum, or the default_value if
        none are found.

    Raises
    ------
    ValueError
        If start == stop, step == 0, step points away from stop, or dummy
        does not occur in kw.
    """
    if start == stop:
        raise ValueError(
            "fetch_one_kw_of_range: start and stop values may not be equal"
        )
    if step == 0:
        raise ValueError("fetch_one_kw_of_range: step size may not be zero")
    # Integer equivalent of "sign((stop - start) / step) > 0".
    if (stop - start) * step <= 0:
        raise ValueError(
            "fetch_one_kw_of_range: step has wrong sign to reach stop from start"
        )
    # dummy is a literal placeholder, not a pattern. Escape it so characters
    # such as "." or "*" only match themselves.
    placeholder = re.compile(re.escape(dummy))
    if not placeholder.search(kw):
        raise ValueError(
            f"fetch_one_kw_of_range: no replacement string '{dummy}' "
            f"found in input keyword '{kw}'"
        )
    for knum in range(start, stop, step):
        kval = ext.header.get(placeholder.sub(f"{knum:{fmt}}", kw), None)
        # Absent keywords (None) and blank values are skipped, but
        # legitimate falsy values such as 0 or False are returned.
        if kval is not None and kval != "":
            return knum, kval
    return default_keynum, default_value


def flatten_list(ll):
    """
    Remove one level of nesting from a list of iterables.

    Parameters
    ----------
    ll : list
        List whose elements are themselves iterables (e.g. lists).

    Returns
    -------
    list
        A new list holding the items of every sub-iterable, in order.
    """
    flattened = []
    for sublist in ll:
        flattened.extend(sublist)
    return flattened


def get_factors(x: int):
    """
    Return the factors for the input integer.

    Parameters
    ----------
    x : int
        The input integer to check. Must be >= 1.

    Returns
    -------
    list of int
        The distinct positive factors of `x` in ascending order. Each factor
        appears exactly once, including the square root of a perfect square.

    Raises
    ------
    ValueError
        If `x` cannot be cast to int, or is < 1.
    """
    try:
        x = int(x)
    except TypeError as err:
        raise ValueError(f"Unable to cast {x} to int") from err
    # Explicit check rather than ``assert`` so the validation survives ``-O``.
    if x <= 0:
        raise ValueError("x must be > 0")

    factors = set()
    f = 1
    # Integer bound f*f <= x avoids the float rounding of a sqrt-based limit.
    while f * f <= x:
        if x % f == 0:
            # A set stops f == x // f (perfect squares) being counted twice.
            factors.update((f, x // f))
        f += 1
    return sorted(factors)


def get_no_factors(x):
    """
    Count how many factors integer ``x`` has (1 and ``x`` itself included).

    Parameters
    ----------
    x : int
        Integer whose factors are counted. Must be >= 1.

    Returns
    -------
    int
        The length of the list returned by ``get_factors(x)``.
    """
    factors = get_factors(x)
    return len(factors)


def get_nth_prime(n: int):
    """
    Get the n-th prime number.

    Recall that 1 is *not* a prime number (nor is it a composite). Therefore,
    the first prime number is 2.

    Parameters
    ----------
    n : int
        The (1-based) index of the prime to return.

    Returns
    -------
    int
        The n-th prime integer.

    Raises
    ------
    ValueError
        If `n` cannot be cast to int, or is < 1.
    """
    try:
        n = int(n)
    except TypeError as err:
        raise ValueError(f"Unable to cast {n} to an integer") from err
    if n < 1:
        raise ValueError(f"n must be > 0; you gave {n}")

    # Build the primes iteratively. The previous recursive version needed a
    # stack depth of n and raised RecursionError for n around 1000.
    primes = [2]
    candidate = 3
    while len(primes) < n:
        is_prime = True
        # A composite always has a prime factor <= its square root, so trial
        # division by the primes already found is sufficient.
        for p in primes:
            if p * p > candidate:
                break
            if candidate % p == 0:
                is_prime = False
                break
        if is_prime:
            primes.append(candidate)
        candidate += 2  # even numbers > 2 are never prime
    return primes[n - 1]


def get_time_from_hdul(hdul, utc_offset_hrs=3):
    """Compute the civil date of the observation from a given HDUL file and the current time.

    Parameters
    ----------
    hdul : astropy.io.fits.HDUList
        File from which the ``DATE-OBS`` keyword is read (Primary extension).
    utc_offset_hrs : int or float, optional
        Offset from UTC, in hours, applied to both the observation time and
        the current time. Default: 3.
        NOTE(review): this default was previously described as "Garching
        time", but Garching is UTC+1 (CET) / UTC+2 (CEST). Confirm the
        intended value.

    Returns
    -------
    civil_date : str
        Observation date shifted by the offset, e.g. ``"2023-05-03 UTC+3"``.
    present_date : str
        Current date and time shifted by the offset, e.g.
        ``"2023-05-03  14:07:09.5 UTC+3"``.
    """
    # The ``+`` format flag always prints the sign, so negative offsets give
    # "UTC-4" rather than "UTC+-4".
    utc_label = f"UTC{utc_offset_hrs:+}"

    # Observation time
    dateobs = Time(hdul["PRIMARY"].header.get("DATE-OBS"))
    dateobs += utc_offset_hrs * units.hour
    year, month, day, _, _, _ = dateobs.ymdhms
    civil_date = "{}-{:02.0f}-{:02.0f} {}".format(year, month, day, utc_label)

    # Current time
    now = Time.now() + utc_offset_hrs * units.hour
    present_date = "{}-{:02.0f}-{:02.0f}  {:02.0f}:{:02.0f}:{:04.1f} {}".format(
        *now.ymdhms, utc_label
    )
    return civil_date, present_date


def expand_ext_argument(ext, n, name):
    """
    Expand a single extension-type argument into a list of arguments.

    Parameters
    ----------
    ext : int, str, or list
        Either a single argument denoting a FITS file extension (an int,
        numpy integer or str), or a sequence of such arguments, which must
        have at least `n` elements to be valid.
    n : int
        The length of file inputs list the ext argument must be matched against
    name : str
        The name of the variable being expanded, as a string (used for
        error reporting)

    Returns
    -------
    exts:
        ``[ext] * n`` for a single extension reference. Otherwise `ext`
        itself, returned unchanged, so it may hold more than `n` entries.

    Raises
    ------
    ValueError
        If a sequence with fewer than `n` elements is given.
    """
    if isinstance(ext, (str, int, np.integer)):
        # A single extension reference applies to every file in the list.
        exts = [ext] * n
    else:  # Assume a sized iterable of extension references
        exts = ext
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if len(exts) < n:
        raise ValueError(f"Not enough {name} values to match all file lists")

    return exts


def round_arbitrary(x, base=50):
    """
    Round ``x`` to the nearest multiple of ``base``.

    Parameters
    ----------
    x : int or float
        The number to be rounded.
    base : int or float
        The step whose multiples ``x`` is rounded to. Default: 50.

    Returns
    -------
    The multiple of ``base`` closest to ``x``. The built-in ``round`` is
    used, so exact halves go to the even multiple (e.g. 125 -> 100 for
    base 50).
    """
    multiples = round(x / base)
    return base * multiples


def get_wavelength_from_header(h, axis=1):
    """
    Provide the wavelength vector of a linear spectral axis from a HDU header.

    For axis ``n``, the wavelength of (1-based) pixel ``p`` is
    ``CRVALn + (p - CRPIXn) * step``. Here ``step`` is ``CDELTn``, or the
    diagonal CD-matrix element ``CDn_n`` when ``CDELTn`` is absent.

    Parameters
    ----------
    h : :any:`astropy.io.fits.hdu`-derived HDU object
        HDU whose header holds the WCS keywords.
    axis : int, optional
        FITS axis number (1-based) of the spectral axis. Default: 1.

    Returns
    -------
    list
        ``[wavelength, wavelength_unit]``: a numpy array with ``NAXISn``
        wavelengths, and the ``CUNITn`` header string.

    Raises
    ------
    KeyError
        If a required keyword (CRVALn, CRPIXn, NAXISn, CUNITn, or both of
        CDELTn / CDn_n) is missing.

    Notes
    -----
    NOTE(review): the PC matrix, off-diagonal CD terms and non-linear axis
    types (e.g. logarithmic CTYPE) are not taken into account.
    """
    header = h.header
    wl_c = header[f"CRVAL{axis}"]
    pix_c = header[f"CRPIX{axis}"]
    n_pix = header[f"NAXIS{axis}"]
    cdelt_key = f"CDELT{axis}"
    # Some headers describe the pixel step with the CD matrix instead of CDELT.
    if cdelt_key in header:
        wl_del = header[cdelt_key]
    else:
        wl_del = header[f"CD{axis}_{axis}"]
    pixels = np.arange(1, n_pix + 1)  # FITS pixel coordinates are 1-based
    wavelength = wl_c + (pixels - pix_c) * wl_del
    wavelength_unit = header[f"CUNIT{axis}"]
    return [wavelength, wavelength_unit]


def _idp_column_unit(h, column):
    """Return the ``TUNITn`` header value of the binary-table column ``column``."""
    names = [name.upper() for name in h.data.columns.names]
    # FITS column numbers (the ``n`` in TTYPEn/TUNITn) are 1-based.
    number = names.index(column.upper()) + 1
    return h.header[f"TUNIT{number}"]


def read_idp_spectrum(h):
    """
    Read the spectrum stored in the first row of an IDP binary-table HDU.

    Parameters
    ----------
    h : :any:`astropy.io.fits.hdu`-derived HDU object
        Binary-table HDU with a ``WAVE`` column and a flux column, located
        with ``get_idp_flux_column_from_hdu``.

    Returns
    -------
    wavelength : array
        Wavelengths of the first table row, converted to nm.
    flux : array
        Flux of the first table row multiplied by the conversion factor of
        its unit. This is a new array; the HDU data is left untouched.
    flux_unit : str
        Display label for the flux unit.
    """
    wave = h.data["WAVE"][0]
    wave_unit = _idp_column_unit(h, "WAVE")
    wavelength = wave / compute_wave_conversion_from_unit(wave_unit)  # to nm

    flux_column = get_idp_flux_column_from_hdu(h)
    # Take the unit from the flux column itself: it is not necessarily
    # column 2 (e.g. when both FLUX_REDUCED and FLUX columns exist).
    raw_flux_unit = _idp_column_unit(h, flux_column)
    # Multiply out of place. ``h.data[...][0]`` may be a view into the HDU
    # data, and ``*=`` would rescale the caller's table on every call.
    flux = h.data[flux_column][0] * compute_flux_conversion_from_unit(raw_flux_unit)
    flux_unit = pretty_print_units(raw_flux_unit)

    return wavelength, flux, flux_unit


def get_idp_flux_column_from_hdu(h):
    """
    Work out which binary-table column holds the flux.

    The first ``TUTYPn`` keyword whose value is a flux-axis utype
    (``spec:Data.FluxAxis.Value`` or ``Spectrum.Data.FluxAxis.Value``)
    determines the column: its ``TTYPEn`` name is returned. Without such a
    keyword, the column names ``FLUX`` and then ``FLUX_REDUCED`` are tried.

    Parameters
    ----------
    h: :any: `astropy.io.fits.hdu`-derived HDU object

    Returns
    -------
    A string with the name of the column

    Raises
    ------
    ValueError
        If no flux column can be identified.
    """
    flux_utypes = ("spec:Data.FluxAxis.Value", "Spectrum.Data.FluxAxis.Value")
    for keyword in h.header:
        if not keyword.startswith("TUTYP"):
            continue
        if h.header[keyword] in flux_utypes:
            return h.header["TTYPE" + keyword[5:]]

    names = h.data.columns.names
    for candidate in ("FLUX", "FLUX_REDUCED"):
        if candidate in names:
            return candidate

    raise ValueError("Cannot determine flux column")


# Wavelength-unit string -> number of that unit in one nanometre.
# Dividing a wavelength expressed in the unit by its factor gives nm.
WAVE_UNIT_CONVERSION_PREDEFINED = dict(
    angstrom=10.0,
    Angstrom=10.0,
    nm=1.0,
    um=1.0e-3,
)


def compute_wave_conversion_from_unit(wave_unit):
    """
    Return the factor converting ``wave_unit`` wavelengths to nanometres.

    Nanometres are the standard representation for IDP. The factor is the
    number of ``wave_unit`` in one nanometre, so a wavelength expressed in
    ``wave_unit`` is converted to nm by *dividing* by it.

    Parameters
    ----------
    wave_unit : str
        The representation of the unit, e.g. ``"Angstrom"`` or ``"nm"``.

    Returns
    -------
    float
        The conversion factor from wave_unit to nanometres.

    Raises
    ------
    ValueError
        If ``wave_unit`` is not a key of ``WAVE_UNIT_CONVERSION_PREDEFINED``.
    """
    try:
        return WAVE_UNIT_CONVERSION_PREDEFINED[wave_unit]
    except KeyError:
        # Format with an f-string so that non-str units (e.g. None) still
        # produce the documented ValueError, not a TypeError from "+".
        raise ValueError(f"Cannot map the wave unit: {wave_unit}") from None


# Flux-unit string -> multiplicative factor applied to flux values.
# Fluxes in erg/s/cm^2/angstrom are scaled by 1e16, matching the
# "F / 10^{-16} ..." label in PRETTY_UNITS_PREDEFINED. Unitless and ADU
# fluxes are left unchanged.
FLUX_UNIT_CONVERSION_PREDEFINED = dict(
    [
        ("", 1.0),
        ("adu", 1.0),
        ("erg.cm**(-2).s**(-1).angstrom**(-1)", 1.0e16),
    ]
)


def compute_flux_conversion_from_unit(flux_unit):
    """
    Return the factor that flux values in ``flux_unit`` are multiplied by.

    The factor converts fluxes to the standard representation for IDP,
    depending on the string that represents the unit.

    Parameters
    ----------
    flux_unit : str
        The representation of the unit (e.g. a ``TUNITn`` header value).

    Returns
    -------
    float
        The conversion factor from flux_unit.

    Raises
    ------
    ValueError
        If ``flux_unit`` is not a key of ``FLUX_UNIT_CONVERSION_PREDEFINED``.
    """
    try:
        return FLUX_UNIT_CONVERSION_PREDEFINED[flux_unit]
    except KeyError:
        # Format with an f-string so that non-str units (e.g. None) still
        # produce the documented ValueError, not a TypeError from "+".
        raise ValueError(f"Cannot map the flux unit: {flux_unit}") from None


# Flux-unit string -> LaTeX-style display label.
# The raw string keeps the LaTeX backslashes (\mathrm, \AA) without relying
# on invalid escape sequences, which Python 3.12+ reports as SyntaxWarning.
PRETTY_UNITS_PREDEFINED = {
    "erg.cm**(-2).s**(-1).angstrom**(-1)": r"$F$$ / 10^{-16} \mathrm{erg/s/cm^2/\AA}$",
    "adu": "ADU",
    "": "$F$",
}


def pretty_print_units(text_label, scale=1):
    """
    Translate a unit string into its predefined display label.

    Parameters
    ----------
    text_label : str
        Unit string, e.g. a ``TUNITn`` header value.
    scale : int or float, optional
        Not used by the current implementation; accepted for interface
        compatibility.

    Returns
    -------
    str
        The label from ``PRETTY_UNITS_PREDEFINED`` when ``text_label`` is a
        known unit, otherwise ``text_label`` unchanged.
    """
    return PRETTY_UNITS_PREDEFINED.get(text_label, text_label)


def format_kw_or_default(ext, kw, frmt, default="N/A"):
    """
    Format a header keyword value, falling back to a default string.

    Parameters
    ----------
    ext : :any:`astropy.io.fits.hdu`-derived HDU object
        The :any:`astropy` extension to check.
    kw : str
        The header keyword.
    frmt : str
        %-style format string applied to the keyword value, e.g. ``"%.2f"``.
    default : str, default="N/A"
        Returned (as a string) when the keyword is missing or blank.

    Returns
    -------
    str
        ``frmt % value`` when the keyword holds a value different from
        ``default``, otherwise ``str(default)``.
    """
    # NOTE(review): with default=None, fetch_kw_or_default returns a
    # "???-<kw>-???" placeholder for missing keywords, which is then passed
    # to frmt. Confirm callers never pass default=None with numeric formats.
    value = fetch_kw_or_default(ext, kw, default=default)
    return str(default) if value == default else frmt % value

