from dataclasses import dataclass

from edps import Task, Optional, qc0, ReportInput
from edps import task, subworkflow, qc1calib, match_rules, FilterMode

from .visir_datasources import *
from .visir_task_functions import run_clip


@dataclass()
class VisirImages:
    """Handles to the EDPS tasks created by ``process_images`` for one data type.

    Fields are listed in pipeline order. ``photometry`` is only filled in for
    standard-star reductions and stays ``None`` for science images.
    """
    # decoding of the raw input frames
    repack: Task
    # source detection on the repacked frames
    detection: Task
    # clipping step run on the detection products
    clip: Task
    # computation of quality-control parameters
    qc: Task
    # distortion correction
    undistort: Task
    # final joined product of the chain
    processed: Task
    # standard-star photometry; absent (None) for science data
    photometry: Optional[Task] = None


def process_images(raw_input, data_type):
    """Build the chain of EDPS tasks that reduces VISIR images of one data type.

    The chain is: repack -> source detection -> clipping -> QC and distortion
    correction -> (standard stars only: photometry) -> join into the final
    processed image. Every task name is suffixed with ``data_type`` and every
    task receives the same meta-targets (``qc0`` for science data,
    ``qc1calib`` for standard stars).

    :param raw_input: main input of the first (repack) task.
    :param data_type: either ``"science"`` or ``"standard"``.
    :returns: a :class:`VisirImages` holding all created tasks; its
        ``photometry`` field is only set for ``"standard"``.
    :raises ValueError: if ``data_type`` is neither ``"science"`` nor
        ``"standard"``.
    """
    # Reject unknown types up front. Otherwise an unexpected value would get
    # the standard-star meta-targets but the science processing chain.
    if data_type not in ("science", "standard"):
        raise ValueError(
            f"data_type must be 'science' or 'standard', got {data_type!r}")
    is_standard = data_type == "standard"

    # The data_type defines the meta-target list shared by all tasks.
    metatargets = [qc1calib] if is_standard else [qc0]

    # Data decoding. The static mask uses min_ret=0 (presumably optional).
    repack_builder = (task("repack_image_" + data_type)
                      .with_recipe("visir_util_repack")
                      .with_main_input(raw_input)
                      .with_associated_input(static_mask, min_ret=0)
                      .with_meta_targets(metatargets))
    if is_standard:
        # Only the standard-star repack task gets the raw-display report.
        repack_builder = repack_builder.with_report("visir_rawdisp", ReportInput.RECIPE_INPUTS)
    repack = repack_builder.build()

    # source detection
    detection = (task("source_detection_" + data_type)
                 .with_recipe("visir_util_detect_shift")
                 .with_main_input(repack)
                 .with_meta_targets(metatargets)
                 .build())

    # clipping
    clip = (task("clip_" + data_type)
            .with_function(run_clip)
            .with_main_input(detection)
            .with_input_map({BEAM_DETECTED: BKG_CORRECTED})
            .with_meta_targets(metatargets)
            .build())

    # Associated detection/undistort products are matched on kwd.mjd (level=0).
    match_calib = (match_rules()
                   .with_match_keywords([kwd.mjd], level=0))

    # computation of quality-control parameters
    qc = (task("qc_" + data_type)
          .with_recipe("visir_util_qc")
          .with_main_input(clip)
          .with_associated_input(detection, match_rules=match_calib)
          .with_input_map({BEAM_DETECTED: BKG_CORRECTED})
          .with_meta_targets(metatargets)
          .build())

    # distortion correction
    undistort = (task("undistort_" + data_type)
                 .with_recipe("visir_util_run_swarp")
                 .with_main_input(clip)
                 .with_associated_input(detection, match_rules=match_calib)
                 .with_input_map({BEAM_DETECTED: BKG_CORRECTED})
                 .with_meta_targets(metatargets)
                 .build())

    if is_standard:
        # measure photometry of standard stars
        photometry = (task("photometry_" + data_type)
                      .with_recipe("visir_old_img_phot")
                      .with_report("visir_img_std_star", ReportInput.RECIPE_INPUTS_OUTPUTS)
                      .with_main_input(qc)
                      .with_associated_input(undistort, match_rules=match_calib)
                      .with_associated_input(photometric_catalog)
                      .with_input_filter(QC_HEADER, COADDED_CONTRIBUTION_COMBINED, mode=FilterMode.REJECT)
                      .with_input_map({COADDED_IMAGE: IM_CAL_PHOT_PREPROCESSED,
                                       COADDED_IMAGE_COMBINED: IM_CAL_PHOT_PREPROCESSED,
                                       COADDED_WEIGHT: WEIGHT_MAP,
                                       COADDED_WEIGHT_COMBINED: WEIGHT_MAP,
                                       QC_HEADER_COMBINED: QC_HEADER})
                      .with_meta_targets(metatargets)
                      .build())

        # process standard star image: join the photometry products
        processed = (task("processed_image_" + data_type)
                     .with_recipe("visir_util_join")
                     .with_main_input(photometry)
                     .with_input_map({COADDED_IMAGE: COADDED_IMAGE_COMBINED,
                                      QC_HEADER_COMBINED: QC_HEADER,
                                      COADDED_WEIGHT: WEIGHT_MAP,
                                      COADDED_WEIGHT_COMBINED: WEIGHT_MAP})
                     .with_meta_targets(metatargets)
                     .build())
    else:
        # Science data have no photometry step.
        photometry = None

        # process science image: join the QC output with the undistorted images
        processed = (task("processed_image_" + data_type)
                     .with_recipe("visir_util_join")
                     .with_main_input(qc)
                     .with_associated_input(undistort, match_rules=match_calib)
                     .with_input_filter(QC_HEADER, mode=FilterMode.REJECT)
                     .with_input_map({COADDED_IMAGE: COADDED_IMAGE_COMBINED,
                                      QC_HEADER_COMBINED: QC_HEADER,
                                      COADDED_WEIGHT: WEIGHT_MAP,
                                      COADDED_WEIGHT_COMBINED: WEIGHT_MAP})
                     .with_meta_targets(metatargets)
                     .build())

    return VisirImages(repack=repack,
                       detection=detection,
                       clip=clip,
                       qc=qc,
                       undistort=undistort,
                       photometry=photometry,
                       processed=processed)


@subworkflow("process_standard_images", "")
def process_standard_images(raw_input):
    """Reduce standard-star images by delegating to ``process_images``."""
    standard_chain = process_images(raw_input, data_type="standard")
    return standard_chain


@subworkflow("process_science_images", "")
def process_science_images(raw_input):
    """Reduce science images by delegating to ``process_images``."""
    science_chain = process_images(raw_input, data_type="science")
    return science_chain
