from edps import subworkflow, task, QC1_CALIB, ReportInput

from .fors_datasources import *
from .fors_task_functions import *


def create_calibration(main_input, bias, wave, ins_mode, task_recipe, task_report_flat, task_report_wave, maxret):
    """Assemble and build the calibration task for one instrument mode.

    The task is named ``'calibration_' + ins_mode`` and runs ``task_recipe``.

    Args:
        main_input: Passed to ``with_main_input``.
        bias: Associated input, passed together with ``[MASTERBIAS]``.
        wave: Associated input, passed with ``max_ret=maxret``.
        ins_mode: Suffix of the task name; also decides which (if any)
            job-processing function is attached ('mxu' or 'hc_lss').
        task_recipe: Recipe name handed to ``with_recipe``.
        task_report_flat: Report attached with ``ReportInput.RECIPE_INPUTS_OUTPUTS``.
        task_report_wave: Report attached with ``ReportInput.RECIPE_INPUTS_OUTPUTS``.
        maxret: Value forwarded as ``max_ret`` for the ``wave`` input.

    Returns:
        The object returned by the task builder's ``build()``.
    """
    builder = task('calibration_' + ins_mode)
    builder = builder.with_recipe(task_recipe)

    # Reports are attached in this fixed order: raw display first, then the
    # caller-selected flat and wavelength reports.
    reports = (
        ('fors_rawdisp', ReportInput.RECIPE_INPUTS),
        (task_report_flat, ReportInput.RECIPE_INPUTS_OUTPUTS),
        (task_report_wave, ReportInput.RECIPE_INPUTS_OUTPUTS),
    )
    for report_name, report_input in reports:
        builder = builder.with_report(report_name, report_input)

    builder = builder.with_main_input(main_input)
    builder = builder.with_associated_input(wave, max_ret=maxret)
    builder = builder.with_associated_input(bias, [MASTERBIAS])
    builder = builder.with_associated_input(master_linecat)
    builder = builder.with_associated_input(grism_table)
    # The distortion table is only associated when the is_pmos condition holds.
    builder = builder.with_associated_input(distortion_table, condition=is_pmos)
    builder = builder.with_dynamic_parameter('ins_mode', which_observation_type)
    builder = builder.with_meta_targets([QC1_CALIB])

    # Only two modes get an extra job-processing function; all others get none.
    job_hook = None
    if ins_mode == 'mxu':
        job_hook = set_fors_calib_mxu
    elif ins_mode == 'hc_lss':
        job_hook = set_endwavelength
    if job_hook is not None:
        builder = builder.with_job_processing(job_hook)

    return builder.build()


# This sub-workflow creates the calibration tasks for the various instrument modes.
# It returns a 7-tuple of tasks in the order:
# lss, mos, mxu, std, pmos, hc_lss, hc_std.
@subworkflow("spectra_calibrations", "")
def spectra_calibrations(bias):
    # One row per instrument mode:
    # (main input, wave input, mode label, recipe, flat report, max_ret for the wave input)
    mode_setups = (
        (raw_screen_flat_lss, raw_wave, 'lss', 'fors_calib', 'fors_flat_spec_lss', 1),
        (raw_screen_flat_mos, raw_wave, 'mos', 'fors_calib', 'fors_flat_spec_mos', 1),
        (raw_screen_flat_mxu, raw_wave, 'mxu', 'fors_calib', 'fors_flat_spec_mos', 1),
        (raw_screen_flat_std, raw_wave, 'std', 'fors_calib', 'fors_flat_spec_mos', 1),
        (raw_screen_flat_pmos, raw_wave_pmos, 'pmos', 'fors_pmos_calib', 'fors_flat_spec_mos', 1000),
        (raw_flat_hc_lss, raw_wave_hc_lss, 'hc_lss', 'fors_calib', 'fors_flat_spec_lss', 1),
        (raw_flat_hc_std, raw_wave_hc_std, 'hc_std', 'fors_calib', 'fors_flat_spec_mos', 1),
    )

    # Every mode shares the same bias input and the same wavelength report.
    return tuple(
        create_calibration(flat_input, bias, wave_input, mode, recipe, flat_report, 'fors_wave_cal', max_ret)
        for flat_input, wave_input, mode, recipe, flat_report, max_ret in mode_setups
    )
