import logging
from typing import List, Optional
from uuid import UUID

from .assoc_util import AssociationArgs, RequestType
from .edge import MainInputEdge, AlternativeAssociatedInputEdges
from .job import Job
from .parameters import ParametersProvider
from .stateful_task import StatefulTask
from .task import AssociatedInputGroup

# Associated-input groups for one specific job after breakpoints have been applied
# (the return type of TaskInterpreter.apply_breakpoints). Structurally it is the same
# list type as the un-breakpointed groups; the alias only documents that intent.
BreakpointedAssociatedInputGroups = List[AlternativeAssociatedInputEdges]


class TaskInterpreter:
    """Turn a task's main input into jobs and attach associated inputs to them.

    An interpreter pairs a ``StatefulTask`` with the main-input edge that creates its
    jobs and with the groups of alternative associated-input edges that supply each
    job's associations.
    """

    def __init__(self, task: StatefulTask, main_input: MainInputEdge,
                 associated_input_groups: List[AlternativeAssociatedInputEdges], associate_incomplete: bool):
        """
        :param task: task whose jobs are created and to which they are finally added.
        :param main_input: edge whose ``create_jobs`` produces the candidate jobs.
        :param associated_input_groups: groups of alternative associated-input edges,
            processed for every active job.
        :param associate_incomplete: for ORGANIZATION requests, retry an association that
            came back empty with the last ``get_association_result`` argument set to False
            (presumably a relaxed lookup -- confirm in the edge implementation).
        """
        self.associate_incomplete = associate_incomplete
        self.task = task
        self.main_input = main_input
        self.associated_input_groups = associated_input_groups
        self.logger = logging.getLogger('TaskInterpreter')

    def interpret(self, request_id: UUID, assoc_args: AssociationArgs, parameters_provider: ParametersProvider,
                  meta_targets: Optional[List[str]]):
        """Create this task's jobs for one request and register the active ones with the task.

        Jobs for which ``job.is_active()`` is false are discarded. For every active job the
        associated-input groups are first narrowed by breakpoints. CALSELECTOR requests then
        only attach associated inputs. Every other request type runs the full association,
        the optional job-editing function and the recipe-parameter override.

        :param request_id: identifier of the request being processed.
        :param assoc_args: association arguments (request type, breakpoints, ...).
        :param parameters_provider: source of job parameters and per-request recipe overrides.
        :param meta_targets: forwarded unchanged to ``MainInputEdge.create_jobs``.
        """
        # NOTE(review): TaskInterpreter defines no __str__, so str(self) is the default object
        # repr. Confirm get_parameters does not expect something like str(self.task).
        parameters = parameters_provider.get_parameters(str(self), self.task.task.workflow_names, request_id)
        jobs = self.main_input.create_jobs(request_id, parameters, meta_targets)
        active_jobs = [job for job in jobs if job.is_active()]
        num_jobs = len(jobs)
        num_active_jobs = len(active_jobs)
        if num_jobs != num_active_jobs:
            self.logger.debug("[%s %s] created %d jobs, %d active jobs", self.task, request_id, num_jobs,
                              num_active_jobs)

        for job in active_jobs:
            breakpointed_associated_input_groups = self.apply_breakpoints(job, assoc_args)
            if assoc_args.request_type == RequestType.CALSELECTOR:
                self.add_associated_inputs_to_job(job, breakpointed_associated_input_groups)
            else:
                self.add_associations_and_update_job(job, breakpointed_associated_input_groups, request_id,
                                                     assoc_args, parameters_provider)

        self.task.add_jobs(request_id, active_jobs)

    def add_associated_inputs_to_job(self, job: Job, associated_input_groups: BreakpointedAssociatedInputGroups):
        """Attach each group's associated inputs (if any) to ``job`` as an ``AssociatedInputGroup``.

        Groups that yield no inputs for the job's parameters are skipped. No completeness
        bookkeeping is done on this path.
        """
        for group in associated_input_groups:
            associated_inputs = group.get_associated_inputs(job.parameters)
            if associated_inputs:
                input_group = AssociatedInputGroup(associated_inputs=associated_inputs, sort_keys=group.sort_keys)
                job.add_associated_input_group(input_group)

    def add_associations_and_update_job(self, job: Job, associated_input_groups: BreakpointedAssociatedInputGroups,
                                        request_id: UUID, assoc_args: AssociationArgs,
                                        parameters_provider: ParametersProvider):
        """Associate ``job``, run the job-editing hook, apply recipe overrides and update completeness.

        ``job.is_complete`` ends up true only if it was already true and every mandatory
        association was complete. A warning is logged for jobs that end up incomplete.
        """
        # Run the association unconditionally. The former form
        # ``job.is_complete and self.add_associations_to_job(...)`` short-circuited, so a job
        # that was already incomplete silently received no association results at all.
        associations_complete = self.add_associations_to_job(request_id, job, assoc_args, associated_input_groups)
        job.is_complete = job.is_complete and associations_complete
        if self.task.job_editing_function:
            self.edit_job(job)
        self.override_recipe_parameters(job, parameters_provider)
        if not job.is_complete:
            self.logger.warning("Incomplete job %s", job.as_dict())

    def add_associations_to_job(self, request_id: UUID, job: Job, assoc_args: AssociationArgs,
                                associated_input_groups: BreakpointedAssociatedInputGroups) -> bool:
        """Query every associated-input group for ``job`` and record each result on the job.

        For ORGANIZATION requests with ``associate_incomplete`` set, an empty result is retried
        with the final ``get_association_result`` argument set to False (presumably a relaxed
        lookup -- confirm in the edge implementation). All groups are always processed, even
        after an incomplete one is found.

        :return: False if any result is neither complete nor optional, True otherwise.
        """
        is_complete = True
        for assoc_input in associated_input_groups:
            result = assoc_input.get_association_result(request_id, job.exemplar, assoc_args, job.parameters, True)
            if result.is_empty() and self.associate_incomplete and assoc_args.request_type == RequestType.ORGANIZATION:
                result = assoc_input.get_association_result(request_id, job.exemplar, assoc_args, job.parameters, False)
            job.add_association_result(result)
            if not result.is_complete and not result.is_optional:
                # Pass the objects themselves; the logger applies %s (i.e. str()) only if emitted.
                self.logger.warning("Missing associated inputs %s for job %s", assoc_input, job)
                is_complete = False
        return is_complete

    def get_match_functions_for_associated_input(self, task_name: str) -> List[str]:
        """Return the match-function text of the first associated-input edge named ``task_name``.

        Groups and their edges are searched in order. Returns ``[]`` when no edge matches.
        """
        for group in self.associated_input_groups:
            for edge in group.edges:
                if edge.name == task_name:
                    return edge.get_match_functions_text()
        return []

    def edit_job(self, job: Job):
        """Invoke the task's job-editing function on ``job``.

        Any exception it raises is logged with its traceback and swallowed. The job keeps
        whatever state the function left it in.
        """
        self.logger.debug('Invoking job editing function for task %s and job %s', self.task.name, job.as_dict())
        try:
            self.task.job_editing_function(job)
        except Exception as e:
            # Deliberately best-effort: a faulty user-supplied editing function must not abort interpretation.
            self.logger.error('Caught exception in job editing function: %s', e, exc_info=e)
        else:
            self.logger.debug('Successfully executed job editing function: %s', job.as_dict())

    @staticmethod
    def override_recipe_parameters(job: Job, parameters_provider: ParametersProvider):
        """Merge request-level recipe-parameter overrides into ``job.parameters.recipe_parameters``.

        Workflows are tried in ``job.workflow_names`` order. Only the first one with a
        non-empty override is applied (first match wins); later workflows are ignored.
        """
        for workflow in job.workflow_names:
            override_parameters = parameters_provider.get_request_recipe_parameters(workflow, job.task_name)
            if override_parameters:
                job.parameters.recipe_parameters.update(override_parameters)
                return

    def apply_breakpoints(self, job: Job, assoc_args: AssociationArgs) -> BreakpointedAssociatedInputGroups:
        """Apply the breakpoints for ``job``'s instrument to every associated-input group.

        Breakpoints are looked up by ``job.get_instrument()``, falling back to an empty dict.
        Each group receives them together with ``job.get_mjdobs()``.
        """
        breakpoints_by_category = assoc_args.breakpoints.breakpoints_by_ins_cat.get(job.get_instrument(), {})
        return [group.apply_breakpoints(breakpoints_by_category, job.get_mjdobs()) for group in self.associated_input_groups]
