from dataclasses import dataclass  # noqa

from edps import subworkflow, qc1calib, task, qc0, calchecker, Task, copy_upstream

from .matisse_datasources import *
from .matisse_task_functions import *


@dataclass
class ProcessedObservations:
    """Bundle of the Task objects built by ``process_observations`` for one data type.

    One field per reduction step, so callers can attach further processing to any
    intermediate product. The recipe noted next to each field is the one configured
    for that task in ``process_observations``.
    """
    calibration: Task  # mat_cal_image: calibration of the raw images
    extraction: Task  # mat_ext_beams: extraction of photometry
    correction: Task  # mat_est_corr: correction of the extracted flux
    optical_path_difference: Task  # mat_est_opd: optical path difference estimate
    coherent_processing: Task  # mat_proc_coher: raw differential phase and visibility
    incoherent_processing: Task  # mat_proc_incoher: raw squared visibility and closure phase


@subworkflow("reduce_observations", "")
def process_observations(main_input, shutter_closed, detector_calibration, lamp_flat, distortion_calibrations,
                         kappa_matrix, data_type):
    """Reduce MATISSE exposures of one data type into raw (uncalibrated) interferometric products.

    Computes the uncalibrated visibilities, differential phase, closure phase and correlated
    flux by chaining mat_cal_image -> (mat_ext_beams, mat_est_corr) -> mat_est_opd ->
    (mat_proc_coher, mat_proc_incoher). The same chain serves test, calibrator and science
    exposures. ``data_type`` ("monitoring", "science" or "calibrator") is appended to every
    task name, and it selects both the meta-targets and the flat-field association rule.

    Returns a ProcessedObservations holding the built tasks. The helper task that copies
    the kappa matrix is internal and not returned.
    """
    # Every task of this subworkflow is named "<step>_<data_type>".
    suffix = "_" + data_type

    # Science reductions are tagged qc0, every other data type qc1calib; calchecker always applies.
    quality_target = qc0 if data_type == "science" else qc1calib
    metatarget_list = [quality_target, calchecker]

    # Flat-field association rule for each known data type; any other data type gets no lamp flat.
    flat_match_rules = {
        "monitoring": assoc_flats_to_lamp,
        "calibrator": assoc_flats_to_calibrator,
        "science": assoc_flats_to_science,
    }

    # Calibration of the raw images.
    calib_builder = (task("calibration" + suffix)
                     .with_recipe("mat_cal_image")
                     .with_main_input(main_input)
                     .with_associated_input(shutter_closed, min_ret=0, max_ret=100)
                     .with_dynamic_parameter("which_resolution", which_resolution)
                     .with_associated_input(detector_calibration, [BADPIX, NONLINEARITY])
                     .with_alternative_associated_inputs(distortion_calibrations)
                     .with_associated_input(raw_sky, min_ret=0, max_ret=100)
                     .with_meta_targets(metatarget_list))
    if data_type in flat_match_rules:
        # The lamp flat is attached last, after the meta-targets, as in the original chain.
        calib_builder = calib_builder.with_associated_input(lamp_flat, [OBS_FLATFIELD],
                                                            match_rules=flat_match_rules[data_type])
    calibration_task = calib_builder.build()

    # Fetch the KAPPA_MATRIX associated with the same raw files that triggered this reduction.
    # LM-band data request it with min_ret=1, N-band data with min_ret=0 (presumably optional there).
    kappa_task = (task("get_kappa" + suffix)
                  .with_function(copy_upstream)
                  .with_main_input(main_input)
                  .with_dynamic_parameter("which_band", which_band)
                  .with_associated_input(kappa_matrix, [KAPPA_MATRIX], min_ret=1, condition=is_LM)
                  .with_associated_input(kappa_matrix, [KAPPA_MATRIX], min_ret=0, condition=is_N)
                  .build())

    # Extraction of photometry, using the kappa matrix that belongs to the same raw inputs.
    extraction_task = (task("extraction" + suffix)
                       .with_recipe("mat_ext_beams")
                       .with_main_input(calibration_task)
                       .with_associated_input(kappa_task, match_rules=same_raw_inputs)
                       .with_job_processing(set_ext_beams)
                       .with_meta_targets(metatarget_list)
                       .build())

    # Correction of the extracted flux.
    correction_task = (task("correction" + suffix)
                       .with_recipe("mat_est_corr")
                       .with_main_input(calibration_task)
                       .with_job_processing(set_est_corr)
                       .with_meta_targets(metatarget_list)
                       .build())

    # Optical path difference, computed from the corrected data.
    opd_task = (task("optical_path_difference" + suffix)
                .with_recipe("mat_est_opd")
                .with_main_input(correction_task)
                .with_meta_targets(metatarget_list)
                .build())

    # Coherent processing: raw differential phase and visibility.
    coherent_task = (task("coherent_processing" + suffix)
                     .with_recipe("mat_proc_coher")
                     .with_main_input(opd_task)
                     .with_associated_input(correction_task, match_rules=same_raw_inputs)
                     .with_associated_input(extraction_task, match_rules=same_raw_inputs)
                     .with_meta_targets(metatarget_list)
                     .build())

    # Incoherent processing: average power and bispectrum over all selected frames,
    # from which the raw squared visibility and raw closure phase are derived.
    incoherent_task = (task("incoherent_processing" + suffix)
                       .with_recipe("mat_proc_incoher")
                       .with_main_input(correction_task)
                       .with_associated_input(extraction_task, match_rules=same_raw_inputs)
                       .with_meta_targets(metatarget_list)
                       .build())

    return ProcessedObservations(
        calibration=calibration_task,
        extraction=extraction_task,
        correction=correction_task,
        optical_path_difference=opd_task,
        coherent_processing=coherent_task,
        incoherent_processing=incoherent_task,
    )
