from astropy.io import fits
from edps import List, ClassifiedFitsFile, get_parameter, JobParameters, Job, File
from edps import RecipeInvocationArguments, InvokerProvider, ProductRenamer, RecipeInvocationResult, RecipeInputs

from . import xshooter_keywords as kwd
from . import xshooter_rules as rules


# Functions to assign a value to a dynamic parameter based on input data

def which_arm(files: List[ClassifiedFitsFile]):
    """Return the X-shooter arm of the first input file.

    The rule checks are applied in the order UVB, VIS, NIR, AGC and every
    check is evaluated; if more than one matches, the last matching one wins.
    Returns "NONE" when no rule matches.
    NOTE(review): presumably the rules are mutually exclusive — confirm.
    """
    first_file = files[0]
    arm_rules = (
        ("UVB", rules.is_uvb),
        ("VIS", rules.is_vis),
        ("NIR", rules.is_nir),
        ("AGC", rules.is_agc),
    )
    detected = "NONE"
    for arm_name, matches in arm_rules:
        if matches(first_file):
            detected = arm_name
    return detected


def _first_keyword_contains(files, keyword, token) -> bool:
    # Substring test on one header keyword of the first input file; the
    # placeholder 'NONE' is used when the keyword is absent.
    return token in files[0].get_keyword_value(keyword, 'NONE')


def which_ifu(files: List[ClassifiedFitsFile]):
    """Return True when the first file's ins_opti2_id keyword contains "IFU"."""
    return _first_keyword_contains(files, kwd.ins_opti2_id, "IFU")


def which_jh(files: List[ClassifiedFitsFile]):
    """Return True when the first file's ins_opti5_name keyword contains "JH"."""
    return _first_keyword_contains(files, kwd.ins_opti5_name, "JH")


# Functions to read Job parameters (static or dynamic)

def _workflow_arm(params: JobParameters):
    # Raw value of the "arm" workflow parameter; the predicates below compare
    # it against "UVB", "VIS", "NIR" and "AGC".
    return params.get_workflow_param("arm")


def _workflow_flag_is_true(params: JobParameters, name: str) -> bool:
    # Identity comparison: only the boolean True counts as set; truthy values
    # such as the string "TRUE" or the number 1 do not.
    return params.get_workflow_param(name) is True


def is_UVB(params: JobParameters) -> bool:
    """Return True if the "arm" workflow parameter is "UVB"."""
    return _workflow_arm(params) == "UVB"


def is_not_UVB(params: JobParameters) -> bool:
    """Return True whenever the "arm" workflow parameter is anything other than "UVB"."""
    return _workflow_arm(params) != "UVB"


def is_VIS(params: JobParameters) -> bool:
    """Return True if the "arm" workflow parameter is "VIS"."""
    return _workflow_arm(params) == "VIS"


def is_NIR(params: JobParameters) -> bool:
    """Return True if the "arm" workflow parameter is "NIR"."""
    return _workflow_arm(params) == "NIR"


def is_AGC(params: JobParameters) -> bool:
    """Return True if the "arm" workflow parameter is "AGC"."""
    return _workflow_arm(params) == "AGC"


def is_IFU(params: JobParameters) -> bool:
    """Return True only if the "ifu" workflow parameter is exactly True."""
    return _workflow_flag_is_true(params, "ifu")


def is_JH(params: JobParameters) -> bool:
    """Return True only if the "jh" workflow parameter is exactly True."""
    return _workflow_flag_is_true(params, "jh")


# Arm x observing-mode combinations; the arm test is always evaluated first.

def is_NIR_slit(params: JobParameters) -> bool:
    """NIR arm, not IFU."""
    return is_NIR(params) and not is_IFU(params)


def is_NIR_ifu(params: JobParameters) -> bool:
    """NIR arm in IFU mode."""
    return is_NIR(params) and is_IFU(params)


def is_NIR_normal(params: JobParameters) -> bool:
    """NIR arm without the JH flag."""
    return is_NIR(params) and not is_JH(params)


def is_NIR_JH(params: JobParameters) -> bool:
    """NIR arm with the JH flag."""
    return is_NIR(params) and is_JH(params)


def is_UVB_slit(params: JobParameters) -> bool:
    """UVB arm, not IFU."""
    return is_UVB(params) and not is_IFU(params)


def is_UVB_ifu(params: JobParameters) -> bool:
    """UVB arm in IFU mode."""
    return is_UVB(params) and is_IFU(params)


def is_VIS_slit(params: JobParameters) -> bool:
    """VIS arm, not IFU."""
    return is_VIS(params) and not is_IFU(params)


def is_VIS_ifu(params: JobParameters) -> bool:
    """VIS arm in IFU mode."""
    return is_VIS(params) and is_IFU(params)


def use_night_response(params: JobParameters) -> bool:
    """Return True when the "response" parameter is "night" or "NIGHT"."""
    accepted_values = ("night", "NIGHT")
    return get_parameter(params, "response") in accepted_values


def use_darks(params: JobParameters) -> bool:
    """Return True when the "use_optical_dark" parameter is the string "TRUE"."""
    flag = get_parameter(params, "use_optical_dark")
    return flag == "TRUE"


def use_dark_UVB(params: JobParameters) -> bool:
    """Optical darks are requested and the arm is UVB."""
    if not use_darks(params):
        return False
    return is_UVB(params)


def use_dark_VIS(params: JobParameters) -> bool:
    """Optical darks are requested and the arm is VIS."""
    if not use_darks(params):
        return False
    return is_VIS(params)


def physical_mode(params: JobParameters) -> bool:
    """Return True when the "reduction_mode" parameter is "physical"."""
    return get_parameter(params, "reduction_mode") == "physical"


def polynomial_mode(params: JobParameters) -> bool:
    """Return True when the "reduction_mode" parameter is "poly"."""
    return get_parameter(params, "reduction_mode") == "poly"


def polynomial_mode_UVB(params: JobParameters) -> bool:
    """Polynomial reduction mode on the UVB arm."""
    return polynomial_mode(params) and is_UVB(params)


def polynomial_mode_VIS(params: JobParameters) -> bool:
    """Polynomial reduction mode on the VIS arm."""
    return polynomial_mode(params) and is_VIS(params)


def polynomial_mode_NIR(params: JobParameters) -> bool:
    """Polynomial reduction mode on the NIR arm."""
    return polynomial_mode(params) and is_NIR(params)


def physical_mode_UVB(params: JobParameters) -> bool:
    """Physical-model reduction mode on the UVB arm."""
    return physical_mode(params) and is_UVB(params)


def physical_mode_VIS(params: JobParameters) -> bool:
    """Physical-model reduction mode on the VIS arm."""
    return physical_mode(params) and is_VIS(params)


def physical_mode_NIR(params: JobParameters) -> bool:
    """Physical-model reduction mode on the NIR arm."""
    return physical_mode(params) and is_NIR(params)


def _flat_source(params: JobParameters):
    # Raw value of the "use_flat" workflow parameter; compared below against
    # "standard" and "science".
    return params.get_workflow_param("use_flat")


def use_flat_standard(params: JobParameters) -> bool:
    """Return True if the "use_flat" workflow parameter is "standard"."""
    return _flat_source(params) == "standard"


def use_flat_science(params: JobParameters) -> bool:
    """Return True if the "use_flat" workflow parameter is "science"."""
    return _flat_source(params) == "science"


# Flat selection combined with arm and slit/IFU mode. The flat test is
# evaluated first, so the arm/mode predicate is only consulted when it passes.

def use_flat_science_uvb_slit(params: JobParameters) -> bool:
    """"science" flat, UVB arm, slit mode."""
    return use_flat_science(params) and is_UVB_slit(params)


def use_flat_science_vis_slit(params: JobParameters) -> bool:
    """"science" flat, VIS arm, slit mode."""
    return use_flat_science(params) and is_VIS_slit(params)


def use_flat_science_nir_slit(params: JobParameters) -> bool:
    """"science" flat, NIR arm, slit mode."""
    return use_flat_science(params) and is_NIR_slit(params)


def use_flat_standard_uvb_slit(params: JobParameters) -> bool:
    """"standard" flat, UVB arm, slit mode."""
    return use_flat_standard(params) and is_UVB_slit(params)


def use_flat_standard_vis_slit(params: JobParameters) -> bool:
    """"standard" flat, VIS arm, slit mode."""
    return use_flat_standard(params) and is_VIS_slit(params)


def use_flat_standard_nir_slit(params: JobParameters) -> bool:
    """"standard" flat, NIR arm, slit mode."""
    return use_flat_standard(params) and is_NIR_slit(params)


def use_flat_science_uvb_ifu(params: JobParameters) -> bool:
    """"science" flat, UVB arm, IFU mode."""
    return use_flat_science(params) and is_UVB_ifu(params)


def use_flat_science_vis_ifu(params: JobParameters) -> bool:
    """"science" flat, VIS arm, IFU mode."""
    return use_flat_science(params) and is_VIS_ifu(params)


def use_flat_science_nir_ifu(params: JobParameters) -> bool:
    """"science" flat, NIR arm, IFU mode."""
    return use_flat_science(params) and is_NIR_ifu(params)


def use_flat_standard_uvb_ifu(params: JobParameters) -> bool:
    """"standard" flat, UVB arm, IFU mode."""
    return use_flat_standard(params) and is_UVB_ifu(params)


def use_flat_standard_vis_ifu(params: JobParameters) -> bool:
    """"standard" flat, VIS arm, IFU mode."""
    return use_flat_standard(params) and is_VIS_ifu(params)


def use_flat_standard_nir_ifu(params: JobParameters) -> bool:
    """"standard" flat, NIR arm, IFU mode."""
    return use_flat_standard(params) and is_NIR_ifu(params)


def _telluric_mode(params: JobParameters):
    # Value of the "telluric_correction_mode" parameter; compared below
    # against "standard", "science" and "none".
    return get_parameter(params, "telluric_correction_mode")


def telluric_on_standard(params: JobParameters) -> bool:
    """telluric_correction_mode is "standard" and the arm is not UVB."""
    return _telluric_mode(params) == "standard" and is_not_UVB(params)


def telluric_on_science(params: JobParameters) -> bool:
    """telluric_correction_mode is "science" and the arm is not UVB."""
    return _telluric_mode(params) == "science" and is_not_UVB(params)


def attempt_telluric_correction(params: JobParameters) -> bool:
    """telluric_correction_mode is anything but "none" and the arm is not UVB."""
    return _telluric_mode(params) != "none" and is_not_UVB(params)


def avoid_telluric_correction(params: JobParameters) -> bool:
    """Exact negation of attempt_telluric_correction (always True for UVB)."""
    attempt = attempt_telluric_correction(params)
    return not attempt


# --- Function to set parameters for xsh_mbias, depending on input files
def set_bias_processing(job: Job):
    """Set xsh_mbias recipe parameters according to the arm.

    AGC: pd_noise_compute = "TRUE" plus struct_refx = struct_refy = 100.
    UVB or VIS: pd_noise_compute = "TRUE" only. Any other arm: nothing changes.
    """
    params = job.parameters
    agc = is_AGC(params)
    if not (agc or is_UVB(params) or is_VIS(params)):
        return

    overrides = {"xsh.xsh_mbias.pd_noise_compute": "TRUE"}
    if agc:
        overrides["xsh.xsh_mbias.struct_refx"] = 100
        overrides["xsh.xsh_mbias.struct_refy"] = 100
    for key, value in overrides.items():
        params.recipe_parameters[key] = value


# --- Function to set parameters for xsh_mdark, depending on input files
def set_dark_processing(job: Job):
    """Set xsh_mdark recipe parameters according to the arm.

    AGC and VIS get different pre-overscan-corr values; NIR gets the
    fpn_hsize / fpn_nsamples settings. Other arms (e.g. UVB) are untouched.
    """
    params = job.parameters
    if is_AGC(params):
        overrides = {"xsh.xsh_mdark.pre-overscan-corr": 0}
    elif is_VIS(params):
        overrides = {"xsh.xsh_mdark.pre-overscan-corr": 3}
    elif is_NIR(params):
        overrides = {
            "xsh.xsh_mdark.fpn_hsize": 4,
            "xsh.xsh_mdark.fpn_nsamples": 100,
        }
    else:
        overrides = {}
    for key, value in overrides.items():
        params.recipe_parameters[key] = value


# --- Function to set parameters for xsh_predict, depending on input files
def set_predict_processing(job: Job):
    """Set xsh_predict recipe parameters.

    Arc-line detection and model-scenario settings are applied to every arm
    first. Then the model iteration count and annealing factor are set per
    arm, and VIS additionally gets pre-overscan-corr = 3.
    """
    params = job.parameters
    common = {
        "xsh.xsh_predict.detectarclines-min-sn": 5.0,
        "xsh.xsh_predict.model-scenario": 8,
        "xsh.xsh_predict.detectarclines-find-lines-center": "gaussian",
    }
    for key, value in common.items():
        params.recipe_parameters[key] = value

    if is_UVB(params):
        per_arm = {
            "xsh.xsh_predict.model-maxit": 500,
            "xsh.xsh_predict.model-anneal-factor": 1.0,
        }
    elif is_VIS(params):
        per_arm = {
            "xsh.xsh_predict.model-maxit": 500,
            "xsh.xsh_predict.model-anneal-factor": 1.0,
            "xsh.xsh_predict.pre-overscan-corr": 3,
        }
    elif is_NIR(params):
        per_arm = {
            "xsh.xsh_predict.model-maxit": 1000,
            "xsh.xsh_predict.model-anneal-factor": 0.5,
        }
    else:
        per_arm = {}
    for key, value in per_arm.items():
        params.recipe_parameters[key] = value


# --- Function to set parameters for xsh_orderpos, depending on input files
def set_orderpos_processing(job: Job):
    """Set xsh_orderpos recipe parameters for UVB or VIS; other arms are untouched."""
    params = job.parameters
    if is_UVB(params):
        overrides = {"xsh.xsh_orderpos.detectcontinuum-clip-res-max": 0.4}
    elif is_VIS(params):
        overrides = {
            "xsh.xsh_orderpos.detectcontinuum-clip-res-max": 0.5,
            "xsh.xsh_orderpos.pre-overscan-corr": 3,
            "xsh.xsh_orderpos.detectcontinuum-ordertab-step-y": 1,
        }
    else:
        overrides = {}
    for key, value in overrides.items():
        params.recipe_parameters[key] = value


# --- Functions to set detmon recipe parameters depending on the input file ------------------------

def set_detmon_ir_params(job: Job):
    # --llx
    # x coordinate of the lower-left point of the region of interest. If not
    # modified, default value will be 1.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.llx"] = 871
    # --lly
    # y coordinate of the lower-left point of the region of interest. If not
    # modified, default value will be 1.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.lly"] = 709
    # --urx
    # x coordinate of the upper-right point of the region of interest. If not
    # modified, default value will be X dimension of the input image.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.urx"] = 1065
    # --ury
    # y coordinate of the upper-right point of the region of interest. If not
    # modified, default value will be Y dimension of the input image.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.ury"] = 734
    # --autocorr
    # De-/Activate the autocorr option.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.autocorr"] = "TRUE"
    # --m
    # Maximum x-shift for the autocorr.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.m"] = 3
    # --n
    # Maximum y-shift for the autocorr.
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.n"] = 3
    # --filter
    # Upper limit of Median flux to be filtered.
    # detmon.detmon_ir_lg.filter=42000
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.filter"] = 40000
    # --tolerance
    # Tolerance for pair discrimination.
    # detmon.detmon_ir_lg.tolerance=0.001
    job.parameters.recipe_parameters["detmon.detmon_ir_lg.tolerance"] = 0.1


def set_detmon_opt_params(job: Job, *, kappa=None, niter=None, llx, lly, urx, ury, m, n, filter,
                          saturation_limit=65000):
    """Override detmon_opt_lg recipe parameters (used by detmon_setting for UVB, VIS and AGC input).

    Values are written into ``job.parameters.recipe_parameters``. Anything
    not set here keeps the recipe default.

    :param job: job whose recipe parameters are updated in place.
    :param kappa: kappa for the kappa-sigma clipping (gain); only set when not None.
    :param niter: number of iterations used to compute the rms (gain); only set when not None.
    :param llx, lly: lower-left corner of the region of interest (recipe default: 1).
    :param urx, ury: upper-right corner of the region of interest (recipe default: image X/Y size).
    :param m, n: maximum x / y shift for the autocorrelation.
    :param filter: upper limit of the median flux to be filtered. The name shadows the
        builtin but is part of the keyword interface used by callers.
    :param saturation_limit: frames with mean counts above this limit (and with EXPTIME >= that)
        are not used in the calculation.
    """
    recipe_parameters = job.parameters.recipe_parameters

    # --ref_level: user reference level.
    recipe_parameters["detmon.detmon_opt_lg.ref_level"] = 65535
    # --saturation_limit
    # detmon.detmon_opt_lg.saturation_limit=65535.0
    recipe_parameters["detmon.detmon_opt_lg.saturation_limit"] = saturation_limit
    # --pix2pix: compute with pixel-to-pixel accuracy.
    recipe_parameters["detmon.detmon_opt_lg.pix2pix"] = "TRUE"
    # --bpmbin: binary bad-pixel-map option.
    recipe_parameters["detmon.detmon_opt_lg.bpmbin"] = "TRUE"
    # A limit of 46000 would suit VIS better: on imaging flats with median > 45600 the first
    # pixels saturate (the 50000 ADU reference counts level would then need changing too).
    recipe_parameters["detmon.detmon_opt_lg.gain_threshold"] = 40000

    # kappa / niter are optional overrides. Compare against None rather than relying on
    # truthiness, so an explicit 0 is forwarded to the recipe instead of being silently dropped.
    if kappa is not None:
        recipe_parameters["detmon.detmon_opt_lg.kappa"] = kappa
    if niter is not None:
        recipe_parameters["detmon.detmon_opt_lg.niter"] = niter

    # Region of interest.
    # detmon.detmon_opt_lg.urx=1992
    recipe_parameters["detmon.detmon_opt_lg.llx"] = llx
    recipe_parameters["detmon.detmon_opt_lg.lly"] = lly
    recipe_parameters["detmon.detmon_opt_lg.urx"] = urx
    recipe_parameters["detmon.detmon_opt_lg.ury"] = ury
    # Maximum x / y shift for the autocorrelation.
    recipe_parameters["detmon.detmon_opt_lg.m"] = m
    recipe_parameters["detmon.detmon_opt_lg.n"] = n
    # --filter: upper limit of the median flux to be filtered.
    # detmon.detmon_opt_lg.filter=51000
    recipe_parameters["detmon.detmon_opt_lg.filter"] = filter


# detmon_opt_lg settings for UVB/VIS data, keyed by (arm, binx, biny).
_DETMON_OPT_BINNING_SETTINGS = {
    # UVB binning 1x1
    ("UVB", 1, 1): dict(kappa=5.0, llx=1500, lly=1700, urx=1600, ury=1800, m=6, n=6, filter=65000),
    # UVB binning 1x2
    ("UVB", 1, 2): dict(kappa=5.0, niter=5, llx=1500, lly=850, urx=1600, ury=900, m=6, n=6, filter=63000),
    # UVB binning 2x2
    ("UVB", 2, 2): dict(kappa=5.0, llx=750, lly=850, urx=800, ury=900, m=6, n=6, filter=65000),
    # VIS binning 1x1
    ("VIS", 1, 1): dict(llx=1500, lly=1700, urx=1600, ury=1800, m=10, n=10, filter=65000),
    # VIS binning 1x2
    ("VIS", 1, 2): dict(llx=1500, lly=850, urx=1600, ury=900, m=10, n=10, filter=60000),
    # VIS binning 2x2
    ("VIS", 2, 2): dict(llx=750, lly=850, urx=800, ury=900, m=10, n=10, filter=65000),
}


def detmon_setting(job: Job):
    """Choose detmon recipe parameters from the arm and binning of the first input file.

    NIR data get the detmon_ir_lg settings. AGC data, and UVB/VIS data with
    1x1, 1x2 or 2x2 binning, get detmon_opt_lg settings. Any other
    combination leaves the recipe defaults untouched.
    """
    first_file = job.input_files[0]
    arm = first_file.get_keyword_value(kwd.seq_arm, None)
    binx = first_file.get_keyword_value(kwd.det_binx, None)
    biny = first_file.get_keyword_value(kwd.det_biny, None)

    if arm == "NIR":
        set_detmon_ir_params(job)
    elif arm == "AGC":
        set_detmon_opt_params(job, niter=5, llx=200, lly=240, urx=320, ury=360, m=6, n=6, filter=65000, kappa=4,
                              saturation_limit=60000)
    else:
        optical_settings = _DETMON_OPT_BINNING_SETTINGS.get((arm, binx, biny))
        if optical_settings is not None:
            set_detmon_opt_params(job, **optical_settings)


# --- Function to set parameters for xsh_mflat, depending on input files
def set_flat_parameters(job: Job):
    """Set xsh_mflat recipe parameters, then drop the setup keywords that do not apply.

    VIS (slit or IFU) overrides pre-overscan-corr and detectorder-min-sn
    relative to the values in the xshooter_parameter file. IFU data use the
    "sobel" slice-trace method.
    """
    params = job.parameters
    overrides = {}
    if is_VIS(params):
        overrides["xsh.xsh_mflat.pre-overscan-corr"] = 3
        overrides["xsh.xsh_mflat.detectorder-min-sn"] = 25
    if is_IFU(params):
        overrides["xsh.xsh_mflat.detectorder-slice-trace-method"] = "sobel"
    for key, value in overrides.items():
        params.recipe_parameters[key] = value

    fix_setup_keywords(job)


# --- Function to set parameters for xsh_2dmap, depending on input files
def set_map2d_parameters(job: Job):
    """Set xsh_2dmap recipe parameters, then drop the setup keywords that do not apply.

    Arc-line detection and model-scenario settings are common to all arms.
    UVB and VIS share the same model iteration count and annealing factor;
    NIR uses its own values.
    """
    params = job.parameters
    common = {
        "xsh.xsh_2dmap.detectarclines-min-sn": 5.0,
        "xsh.xsh_2dmap.detectarclines-find-lines-center": "gaussian",
        "xsh.xsh_2dmap.model-scenario": 8,
    }
    for key, value in common.items():
        params.recipe_parameters[key] = value

    if is_UVB(params) or is_VIS(params):
        per_arm = {
            "xsh.xsh_2dmap.model-maxit": 500,
            "xsh.xsh_2dmap.model-anneal-factor": 1.0,
        }
    elif is_NIR(params):
        per_arm = {
            "xsh.xsh_2dmap.model-maxit": 1000,
            "xsh.xsh_2dmap.model-anneal-factor": 0.5,
        }
    else:
        per_arm = {}
    for key, value in per_arm.items():
        params.recipe_parameters[key] = value

    fix_setup_keywords(job)


# --- Function to set parameters for xsh_wavecal, depending on input files
def set_wavecal_parameters(job: Job):
    """Set xsh_wavecal recipe parameters, then drop the setup keywords that do not apply.

    The arc-line search half-window is always 13; VIS also gets pre-overscan-corr = 3.
    """
    recipe_parameters = job.parameters.recipe_parameters
    recipe_parameters["xsh.xsh_wavecal.followarclines-search-window-half-size"] = 13
    if is_VIS(job.parameters):
        recipe_parameters["xsh.xsh_wavecal.pre-overscan-corr"] = 3

    fix_setup_keywords(job)


# --- Function to set parameters for xsh_flexcomp, depending on input files
def set_flexures_parameters(job: Job):
    """For VIS only, set xsh_flexcomp pre-overscan-corr to 3; other arms are untouched."""
    if not is_VIS(job.parameters):
        return
    job.parameters.recipe_parameters["xsh.xsh_flexcomp.pre-overscan-corr"] = 3


# --- Function to drop the setup keywords that do not apply to the arm of the input data
def fix_setup_keywords(job: Job):
    """Remove from ``job.setup`` the optics keywords that do not belong to the data's arm.

    The arm is read from the first input file. ins_opti3_name is kept only for
    UVB, ins_opti4_name only for VIS and ins_opti5_name only for NIR. Missing
    keys are ignored, and any other arm leaves the setup untouched.
    """
    arm = job.input_files[0].get_keyword_value(kwd.seq_arm, None)
    if arm == "UVB":
        irrelevant = (kwd.ins_opti4_name, kwd.ins_opti5_name)
    elif arm == "VIS":
        irrelevant = (kwd.ins_opti3_name, kwd.ins_opti5_name)
    elif arm == "NIR":
        irrelevant = (kwd.ins_opti3_name, kwd.ins_opti4_name)
    else:
        irrelevant = ()
    for keyword in irrelevant:
        job.setup.pop(keyword, None)


def select_reports(job: Job):
    """Drop irrelevant setup keywords; for IDP runs keep only the "xshooter_science" report.

    The "is_idp" workflow parameter defaults to "FALSE" and is compared as the string "TRUE".
    """
    fix_setup_keywords(job)
    if job.parameters.get_workflow_param("is_idp", "FALSE") != "TRUE":
        return
    job.reports = [report for report in job.reports if report.name == "xshooter_science"]


# If the input spectra to combine have different OBSTECH values (or any of them lacks OBSTECH),
# combine_spectra() below sets the combination recipe not to check the inputs for Phase III
# consistency. In this case, the final combined spectrum might not be Phase III compliant.
# --- Generic function to run a recipe
def run_recipe(input_file, associated_files, parameters, recipe_name, args, invoker, renamer) -> RecipeInvocationResult:
    """Invoke a pipeline recipe on explicit inputs and return the invocation result.

    :param input_file: main inputs, a list of ``File(full_path, category, "")``.
    :param associated_files: calibrations, a list of ``File(full_path, category, "")``.
    :param parameters: non-default recipe parameters,
        e.g. ``{'parameter_name1': value1, 'parameter_name2': value2}``.
    :param recipe_name: name of the recipe to run.
    :param args: invocation arguments of the calling task; only ``job_dir`` and
        ``logging_prefix`` are used.
    :param invoker: recipe invoker supplied by the calling task.
    :param renamer: product renamer forwarded to ``invoker.invoke``.
    :return: whatever ``invoker.invoke`` returns (called with ``create_subdir=True``).
    """
    recipe_inputs = RecipeInputs(main_upstream_inputs=input_file,
                                 associated_upstream_inputs=associated_files)
    invocation_arguments = RecipeInvocationArguments(
        inputs=recipe_inputs,
        parameters=parameters,
        job_dir=args.job_dir,
        input_map={},
        logging_prefix=args.logging_prefix,
    )
    return invoker.invoke(recipe_name, invocation_arguments, renamer, create_subdir=True)


def combine_spectra(args: RecipeInvocationArguments, invoker_provider: InvokerProvider,
                    renamer: ProductRenamer) -> RecipeInvocationResult:
    """Run esotk_spectrum1d_combine on this task's inputs, each tagged SPECTRUM_1D.

    The OBSTECH keyword is read from the primary header of every input. If the
    values differ, or any input lacks OBSTECH, the recipe parameter
    ``esotk_spectrum1d_combine.noIDP`` is set to "TRUE" on a copy of the task
    parameters, so the inputs are not checked for Phase III consistency.
    """
    spectra = [File(f.name, "SPECTRUM_1D", "") for f in args.inputs.combined]
    recipe_parameters = args.parameters.copy()

    def primary_obstech(path):
        # OBSTECH from the primary HDU header, or None when the keyword is absent.
        with fits.open(path) as hdul:
            return hdul[0].header.get("OBSTECH", None)

    obstech_values = {primary_obstech(spectrum.file_path) for spectrum in spectra}
    inconsistent = len(obstech_values) > 1 or None in obstech_values
    if inconsistent:
        recipe_parameters['esotk_spectrum1d_combine.noIDP'] = "TRUE"

    return run_recipe(spectra, [], recipe_parameters, 'esotk_spectrum1d_combine', args,
                      invoker_provider.recipe_invoker, renamer)
