from dataclasses import dataclass  # noqa

from edps import subworkflow, task, qc1calib, alternative_associated_inputs, Task, \
    AssociatedInputGroupBuilder

from .matisse_datasources import *
from .matisse_task_functions import is_very_high_resolution, is_standard_resolutions


@dataclass
class DistortionProducts:
    """Bundle of tasks and input selectors returned by ``compute_distortion_coefficients``.

    Field order is also the positional order of the generated ``__init__``;
    the builder in this module passes every field by keyword.
    """
    # Task running recipe 'mat_est_shift' on raw_distortion (low, medium and high resolution modes).
    distortion_coefficients: Task
    # Task running recipe "mat_cal_image" on raw_distortion_very_high_resolution (very high resolution mode).
    calibrate_wave: Task
    # Task running recipe "mat_wave_cal" with calibrate_wave as its main input (very high resolution mode).
    distortion_coefficients_very_high_resolution: Task
    # Alternative-input group selecting SHIFT_MAP from distortion_coefficients when
    # is_standard_resolutions holds, or from distortion_coefficients_very_high_resolution
    # when is_very_high_resolution holds.
    distortion_calibrations: AssociatedInputGroupBuilder
    # Same selection as distortion_calibrations, but built with min_ret=0 on both alternatives
    # (presumably allowing zero matching SHIFT_MAP products — confirm against EDPS docs).
    distortion_calibrations_optional: AssociatedInputGroupBuilder


@subworkflow("distortion_coefficients", "")
def compute_distortion_coefficients(detector_calibration, lamp_flat):
    """Declare the MATISSE distortion-calibration tasks and their SHIFT_MAP selectors.

    Two processing chains are built:

    * low, medium and high resolution: ``distortion_coefficients`` runs
      'mat_est_shift' on ``raw_distortion``;
    * very high resolution: ``calibrate_wave_distortion`` runs "mat_cal_image"
      on ``raw_distortion_very_high_resolution``, and its output feeds
      ``distortion_coefficients_very_high_resolution`` ("mat_wave_cal").

    Args:
        detector_calibration: source of the BADPIX and NONLINEARITY associated inputs.
        lamp_flat: source of the OBS_FLATFIELD associated input.

    Returns:
        DistortionProducts holding the three built tasks plus two alternative-input
        groups that take SHIFT_MAP from whichever chain matches the resolution
        predicate (``is_standard_resolutions`` / ``is_very_high_resolution``).
        The ``_optional`` group passes ``min_ret=0`` to both alternatives.
    """

    def with_detector_and_flat(builder):
        # Both image-level recipes take the same detector-calibration and flat-field inputs,
        # appended after any inputs already attached to the builder.
        builder = builder.with_associated_input(detector_calibration, [BADPIX, NONLINEARITY])
        return builder.with_associated_input(lamp_flat, [OBS_FLATFIELD])

    # Low, medium and high resolution modes.
    standard_builder = task("distortion_coefficients").with_recipe('mat_est_shift')
    standard_builder = with_detector_and_flat(standard_builder.with_main_input(raw_distortion))
    distortion_coefficients = standard_builder.with_meta_targets([qc1calib]).build()

    # Very high resolution mode, step 1: image-level calibration of the raw frames.
    vhr_image_builder = (task("calibrate_wave_distortion")
                         .with_recipe("mat_cal_image")
                         .with_main_input(raw_distortion_very_high_resolution)
                         .with_associated_input(shift_map)
                         .with_associated_input(kappa_matrix_static))
    calibrate_wave = with_detector_and_flat(vhr_image_builder).with_meta_targets([qc1calib]).build()

    # Very high resolution mode, step 2: derive coefficients from the calibrated images.
    vhr_coeff_builder = task("distortion_coefficients_very_high_resolution").with_recipe("mat_wave_cal")
    vhr_coeff_builder = vhr_coeff_builder.with_main_input(calibrate_wave).with_associated_input(shift_map)
    distortion_coefficients_very_high_resolution = vhr_coeff_builder.with_meta_targets([qc1calib]).build()

    def shift_map_alternatives(**extra):
        # One SHIFT_MAP source per resolution family; any extra keyword arguments
        # (e.g. min_ret) are forwarded unchanged to both alternatives.
        group = alternative_associated_inputs()
        group = group.with_associated_input(distortion_coefficients, [SHIFT_MAP],
                                            condition=is_standard_resolutions, **extra)
        return group.with_associated_input(distortion_coefficients_very_high_resolution, [SHIFT_MAP],
                                           condition=is_very_high_resolution, **extra)

    # Keyword arguments are evaluated left to right, so the mandatory group is built
    # before the optional one.
    return DistortionProducts(
        distortion_coefficients=distortion_coefficients,
        calibrate_wave=calibrate_wave,
        distortion_coefficients_very_high_resolution=distortion_coefficients_very_high_resolution,
        distortion_calibrations=shift_map_alternatives(),
        distortion_calibrations_optional=shift_map_alternatives(min_ret=0),
    )
