from dataclasses import dataclass
from typing import Optional

from edps import subworkflow, task, Task, ReportInput, science

from .eris_datasources import *
from .eris_task_function import is_pupil_tracking, not_pupil_tracking_and_not_astrometry, is_astrometry
from .eris_task_function import is_short, is_short_general, is_short_broad, is_long, is_long_bra, \
    is_standard_star, is_nix_lss
from .eris_task_function import which_band, which_target, which_nix_mode, set_cube_collapse


# Container for the tasks that nix_reduction creates when it runs the multi-step cascade
# (observ_mode "long_slit" or "images"). Steps that a given mode does not create stay None.
@dataclass
class NixProcessedObservations:
    # First reduction step: the task returned by nix_science (recipe 'eris_nix_cal_det').
    science_reduction: Task
    # Sky-subtraction task (recipe 'eris_nix_lss_skysub' or 'eris_nix_img_skysub').
    sky_subtraction: Optional[Task] = None
    # Straightening / WCS-calibration task (recipe 'eris_nix_lss_straighten' or 'eris_nix_img_cal_wcs').
    astrometry: Optional[Task] = None
    # Photometric-calibration task (recipe 'eris_nix_img_cal_phot'); nix_reduction sets it only for "images".
    photometry: Optional[Task] = None
    # Stacking task (recipe 'eris_nix_lss_stack' or 'eris_nix_img_hdrl_stack').
    stacking: Optional[Task] = None


def associate_flat_calibrations(task_building, nix_flat_lamp, nix_flat_twilight, nix_flat_sky, observ_mode):
    """Add the sky, lamp and twilight flat-field associations to a task builder.

    The sky and twilight associations are always added; each carries its own ``condition``
    callable. The lamp-flat associations depend on ``observ_mode``:

    * "coronagraphy" / "sam": one optional (min_ret=0, is_short_broad) and one mandatory
      (min_ret=1, is_short) association;
    * "cube" / "images": one mandatory (min_ret=1, is_short_general) association;
    * any other mode (e.g. "long_slit"): no lamp-flat association.

    :param task_building: task builder to extend.
    :param nix_flat_lamp: source of the lamp-flat products.
    :param nix_flat_twilight: source of the twilight-flat products.
    :param nix_flat_sky: source of the sky-flat products.
    :param observ_mode: observing mode ("cube", "coronagraphy", "sam", "images" or "long_slit").
    :return: the builder returned by the last ``with_associated_input`` call.
    """
    builder = task_building

    # --- Sky flats: same products each time, one association per (match rules, condition) pair.
    sky_flat_selection = (
        (match_nix_skyflat, is_long),
        (match_skyflat_long_bra, is_long_bra),
        (match_skyflat_lss, is_nix_lss),
    )
    for rules, applies in sky_flat_selection:
        builder = builder.with_associated_input(
            nix_flat_sky,
            [MASTER_FLAT_SKY_HIFREQ, MASTER_SKY_LOFREQ, MASTER_BPM_SKY],
            match_rules=rules,
            condition=applies)

    # --- Lamp flats: the (min_ret, condition) pairs depend on the observing mode
    # (the same for science and on-sky calibrations).
    if observ_mode in ("coronagraphy", "sam"):
        lamp_flat_selection = ((0, is_short_broad), (1, is_short))
    elif observ_mode in ("cube", "images"):
        lamp_flat_selection = ((1, is_short_general),)
    else:
        lamp_flat_selection = ()
    for minimum, applies in lamp_flat_selection:
        builder = builder.with_associated_input(
            nix_flat_lamp,
            [MASTER_FLAT_LAMP_HIFREQ, MASTER_FLAT_LAMP_LOFREQ, MASTER_BPM_LAMP],
            min_ret=minimum,
            condition=applies)

    # --- Twilight flats: per the original author's note these are meant only for non-long-slit
    # observations taken with short and short_broad filters, and the conditions used here are
    # valid only for non-long-slit data.
    twilight_flat_selection = (
        (is_short_broad, match_twflat_short_broad),
        (is_short, match_twflat_short),
    )
    for applies, rules in twilight_flat_selection:
        builder = builder.with_associated_input(
            nix_flat_twilight,
            [MASTER_FLAT_TWILIGHT_LOFREQ],
            min_ret=1,
            condition=applies,
            match_rules=rules)

    return builder


def nix_science(raw_inputs, linearity_nix, dark_nix, nix_flat_lamp, nix_flat_twilight, nix_flat_sky, observ_mode,
                observ_type, metatargets):
    """Build the first NIX reduction task (recipe 'eris_nix_cal_det').

    The task is named 'process_nix_<observ_type>_<observ_mode>'. For "science" data in the
    "cube", "coronagraphy" and "sam" modes the ``science`` meta-target is appended to a copy
    of ``metatargets``; the caller's list is never modified.

    :param raw_inputs: main input of the task.
    :param linearity_nix: source of GAIN_INFO, COEFFS_CUBE and BP_MAP_NL.
    :param dark_nix: source of MASTER_DARK_IMG.
    :param nix_flat_lamp: lamp-flat source, forwarded to associate_flat_calibrations.
    :param nix_flat_twilight: twilight-flat source, forwarded to associate_flat_calibrations.
    :param nix_flat_sky: sky-flat source, forwarded to associate_flat_calibrations.
    :param observ_mode: "cube", "coronagraphy", "sam", "images" or "long_slit".
    :param observ_type: "science" or "on_sky_calib".
    :param metatargets: list of meta-targets assigned to the task.
    :return: the built task.
    """
    single_step_mode = observ_mode in ("cube", "coronagraphy", "sam")
    if observ_type == "science" and single_step_mode:
        targets = metatargets + [science]
    else:
        targets = metatargets

    builder = task('_'.join(('process_nix', observ_type, observ_mode)))
    builder = builder.with_recipe('eris_nix_cal_det')
    builder = builder.with_main_input(raw_inputs)
    for parameter_name, provider in (("band_used", which_band),
                                     ("std_or_science", which_target),
                                     ("nix_mode", which_nix_mode)):
        builder = builder.with_dynamic_parameter(parameter_name, provider)
    builder = builder.with_associated_input(linearity_nix, [GAIN_INFO, COEFFS_CUBE, BP_MAP_NL])
    builder = builder.with_associated_input(dark_nix, [MASTER_DARK_IMG])
    builder = builder.with_associated_input(nix_wcs_refine)
    builder = builder.with_associated_input(nix_phot_data, condition=is_standard_star)
    builder = builder.with_meta_targets(targets)

    # Original author's note: when a coronagraphic exposure is taken with DET.FRAME.FORMAT==cube,
    # the recipe parameter eris.eris_nix_cal_det.collapse_cube is set to 1 (done by set_cube_collapse).
    if observ_mode == "coronagraphy":
        builder = builder.with_job_processing(set_cube_collapse)

    builder = associate_flat_calibrations(builder, nix_flat_lamp, nix_flat_twilight, nix_flat_sky, observ_mode)

    # Long-slit data get the master star trace; every other mode gets the WCS matched catalogue.
    long_slit = observ_mode == 'long_slit'
    builder = builder.with_associated_input(nix_master_startrace if long_slit else nix_wcs_matched_catalogue)

    # Optional sky-jitter frames (0 to 100 of them) for the modes that have a dedicated source.
    if long_slit:
        builder = builder.with_associated_input(raw_sky_jitter_lss, min_ret=0, max_ret=100)
    elif observ_mode == "cube":
        builder = builder.with_associated_input(raw_sky_jitter_cube, min_ret=0, max_ret=100)
    elif observ_mode == "images":
        builder = builder.with_associated_input(raw_sky_jitter_image, min_ret=0, max_ret=100)
        if observ_type == "on_sky_calib":
            # On-sky calibration images additionally get the raw-display report.
            builder = builder.with_report("eris_rawdisp", ReportInput.RECIPE_INPUTS)

    return builder.build()


def nix_reduction(raw_inputs, linearity_nix, dark_nix, nix_flat_lamp, nix_flat_twilight, nix_flat_sky, observ_mode,
                  observ_type, metatargets):
    """Build the reduction cascade for on-sky NIX exposures (science and on-sky calibrations).

    Task names are derived from ``observ_type`` and ``observ_mode``.

    Reduction steps:

    1. 'eris_nix_cal_det' (see nix_science), which removes the instrument signature from the
       data. It is created for every observing mode. Coronagraphy, SAM and CUBE-format
       exposures stop here.
    2. "long_slit" and "images" are processed further:
       a) sky subtraction, b) straightening / astrometric calibration,
       c) photometric calibration ("images" only), d) stacking.

    :param raw_inputs: main input of the first reduction step.
    :param linearity_nix: linearity calibration source, forwarded to nix_science.
    :param dark_nix: dark calibration source, forwarded to nix_science.
    :param nix_flat_lamp: lamp-flat source, forwarded to nix_science.
    :param nix_flat_twilight: twilight-flat source, forwarded to nix_science.
    :param nix_flat_sky: sky-flat source, forwarded to nix_science.
    :param observ_mode: "cube", "coronagraphy", "sam", "images" or "long_slit".
    :param observ_type: "science" or "on_sky_calib".
    :param metatargets: list of meta-targets; never modified in place.
    :return: the first-step task for "cube", "coronagraphy" and "sam"; a NixProcessedObservations
        for "long_slit" and "images"; None for any other ``observ_mode`` (the first-step task is
        still created in that case).
    """
    # 1) Instrument-signature removal, executed for all types of input.
    science_task = nix_science(raw_inputs, linearity_nix, dark_nix, nix_flat_lamp, nix_flat_twilight, nix_flat_sky,
                               observ_mode, observ_type, metatargets)
    if observ_mode in ("cube", "coronagraphy", "sam"):  # Only the first reduction step.
        return science_task

    # 2) The next steps require different treatment for long-slit spectra and images.
    if observ_mode == 'long_slit':
        return _nix_reduce_long_slit(science_task, observ_type, observ_mode, metatargets)
    if observ_mode == "images":
        return _nix_reduce_images(science_task, observ_type, observ_mode, metatargets)

    # Unsupported observing mode: nothing beyond the first step is defined for it.
    return None


def _nix_task_name(step, observ_type, observ_mode):
    # Compose the task name '<step>_nix_<observ_type>_<observ_mode>'.
    return step + '_nix_' + observ_type + '_' + observ_mode


def _nix_final_metatargets(metatargets, observ_type):
    # Meta-targets for the last steps of the cascade: science data additionally get the
    # `science` meta-target. A new list is returned; the input list is left untouched.
    if observ_type == "science":
        return metatargets + [science]
    return metatargets


def _nix_stacking_builder(recipe, main_input, observ_type, observ_mode, metatargets):
    # Start the (not yet built) 'combine_nix_*' task shared by the long-slit and image cascades.
    return (task(_nix_task_name('combine', observ_type, observ_mode))
            .with_condition(not_pupil_tracking_and_not_astrometry)
            .with_dynamic_parameter("is_pupil_tracking", is_pupil_tracking)
            .with_dynamic_parameter("is_astrometry", is_astrometry)
            .with_recipe(recipe)
            .with_main_input(main_input)
            .with_meta_targets(metatargets))


def _nix_reduce_long_slit(science_task, observ_type, observ_mode, metatargets):
    # Long-slit cascade after the first step: sky subtraction -> straightening -> stacking.
    # 2a) sky subtraction
    sky_subtraction = (task(_nix_task_name('sky_subtraction', observ_type, observ_mode))
                       .with_recipe('eris_nix_lss_skysub')
                       .with_main_input(science_task)
                       .with_meta_targets(metatargets)
                       .build())
    # 2b) astrometry / straightening
    astrometry = (task(_nix_task_name('astrometry', observ_type, observ_mode))
                  .with_recipe('eris_nix_lss_straighten')
                  .with_main_input(sky_subtraction)
                  .with_associated_input(nix_master_startrace)
                  .with_meta_targets(metatargets)
                  .build())

    # 2c) photometry: not done for the long_slit observing mode.
    # 2d) stacking; only this last step carries the extra `science` meta-target.
    final_targets = _nix_final_metatargets(metatargets, observ_type)
    stacking = _nix_stacking_builder('eris_nix_lss_stack', astrometry, observ_type, observ_mode,
                                     final_targets).build()

    return NixProcessedObservations(science_reduction=science_task,
                                    sky_subtraction=sky_subtraction,
                                    astrometry=astrometry,
                                    stacking=stacking)


def _nix_reduce_images(science_task, observ_type, observ_mode, metatargets):
    # Image cascade after the first step: sky subtraction -> WCS calibration -> photometry -> stacking.
    # 2a) sky subtraction
    sky_subtraction = (task(_nix_task_name('sky_subtraction', observ_type, observ_mode))
                       .with_recipe('eris_nix_img_skysub')
                       .with_main_input(science_task)
                       .with_meta_targets(metatargets)
                       .build())
    # 2b) astrometric calibration
    astrometry = (task(_nix_task_name('astrometry', observ_type, observ_mode))
                  .with_recipe('eris_nix_img_cal_wcs')
                  .with_main_input(sky_subtraction)
                  .with_meta_targets(metatargets)
                  .with_associated_input(nix_wcs_matched_catalogue)
                  .build())

    # Photometry and stacking carry the extra `science` meta-target for science data.
    final_targets = _nix_final_metatargets(metatargets, observ_type)

    # 2c) photometry
    photometry = (task(_nix_task_name('photometry', observ_type, observ_mode))
                  .with_recipe('eris_nix_img_cal_phot')
                  .with_main_input(astrometry)
                  .with_associated_input(nix_phot_data)
                  .with_min_group_size(2)
                  .with_meta_targets(final_targets)
                  .build())
    # 2d) stacking
    stacking_builder = _nix_stacking_builder('eris_nix_img_hdrl_stack', photometry, observ_type, observ_mode,
                                             final_targets)
    if observ_type == "on_sky_calib":
        # On-sky calibration images additionally get the standard-star report.
        stacking_builder = stacking_builder.with_report("eris_img_std_star", ReportInput.RECIPE_INPUTS_OUTPUTS)
    stacking = stacking_builder.build()

    return NixProcessedObservations(science_reduction=science_task,
                                    sky_subtraction=sky_subtraction,
                                    astrometry=astrometry,
                                    photometry=photometry,
                                    stacking=stacking)


@subworkflow("science_reduction_nix", "")
def nix_science_reduction(raw_inputs, linearity_nix, dark_nix, nix_flat_lamp, nix_flat_twilight, nix_flat_sky,
                          observ_mode, observ_type, metatargets):
    # Subworkflow "science_reduction_nix": every argument is handed unchanged to nix_reduction
    # and its result is returned as is.
    processed = nix_reduction(raw_inputs=raw_inputs,
                              linearity_nix=linearity_nix,
                              dark_nix=dark_nix,
                              nix_flat_lamp=nix_flat_lamp,
                              nix_flat_twilight=nix_flat_twilight,
                              nix_flat_sky=nix_flat_sky,
                              observ_mode=observ_mode,
                              observ_type=observ_type,
                              metatargets=metatargets)
    return processed


@subworkflow("on_sky_calibrations_reduction_nix", "")
def nix_on_sky_calibrations_reduction(raw_inputs, linearity_nix, dark_nix, nix_flat_lamp, nix_flat_twilight,
                                      nix_flat_sky, observ_mode, observ_type, metatargets):
    # Subworkflow "on_sky_calibrations_reduction_nix": every argument is handed unchanged to
    # nix_reduction and its result is returned as is.
    processed = nix_reduction(raw_inputs=raw_inputs,
                              linearity_nix=linearity_nix,
                              dark_nix=dark_nix,
                              nix_flat_lamp=nix_flat_lamp,
                              nix_flat_twilight=nix_flat_twilight,
                              nix_flat_sky=nix_flat_sky,
                              observ_mode=observ_mode,
                              observ_type=observ_type,
                              metatargets=metatargets)
    return processed
