import dataclasses
import importlib
import inspect
from types import ModuleType
from typing import Dict, List, Optional, Set, Tuple
from uuid import UUID

from edps.config.configuration import Configuration
from .stateful_task import CleanupRequestArgs
from .stateful_workflow import StatefulWorkflow
from .task import Task, DataSource, BaseClassificationRule
from .task_interpreter import TaskInterpreter
from .workflow import Workflow
from ..client.search import SearchFilter


class WorkflowNotFoundError(Exception):
    """Signals that a requested workflow module could not be imported.

    :param message: human-readable description, also exposed as ``self.message``.
    """

    def __init__(self, message: str):
        # Forward to Exception so ``args`` / ``str(e)`` carry the message even when the
        # error is constructed with ``message=...``: BaseException.__new__ only captures
        # positional arguments. This also keeps the exception picklable in that case.
        super().__init__(message)
        self.message = message


class WorkflowManager:
    """Import workflow modules by name and cache the workflow objects built from them.

    ``workflow_cache`` maps a module name to the :class:`Workflow` assembled from that
    module's members; ``stateful_workflow_cache`` maps the same name to a
    :class:`StatefulWorkflow` wrapping it. The two caches can be reset independently.
    """

    def __init__(self, config: Configuration):
        # Workflow definitions keyed by module name.
        self.workflow_cache: Dict[str, Workflow] = {}
        # Stateful wrappers keyed by module name (see reset_workflow / update_workflow).
        self.stateful_workflow_cache: Dict[str, StatefulWorkflow] = {}
        # Passed to every StatefulWorkflow this manager creates.
        self.config = config

    def get_workflow(self, workflow_name: str) -> Workflow:
        """Return the workflow defined by module ``workflow_name``, building it on first use.

        :raises WorkflowNotFoundError: if the module cannot be imported.
        """
        if workflow_name not in self.workflow_cache:
            workflow_mod = self.import_workflow(workflow_name)
            self.workflow_cache[workflow_name] = self.create_workflow(workflow_mod)
        return self.workflow_cache[workflow_name]

    def get_stateful_workflow(self, workflow_name: str) -> StatefulWorkflow:
        """Return the cached StatefulWorkflow for ``workflow_name``, creating it if needed.

        :raises WorkflowNotFoundError: propagated from :meth:`get_workflow`.
        """
        if workflow_name not in self.stateful_workflow_cache:
            workflow = self.get_workflow(workflow_name)
            self.stateful_workflow_cache[workflow_name] = StatefulWorkflow(workflow, self.config)
        return self.stateful_workflow_cache[workflow_name]

    def update_workflow(self, name: str, workflow: Workflow):
        """Replace the cached workflow ``name`` and wrap it in a fresh StatefulWorkflow.

        Any previously cached StatefulWorkflow for ``name`` is discarded.
        """
        self.workflow_cache[name] = workflow
        self.stateful_workflow_cache[name] = StatefulWorkflow(workflow, self.config)

    def reset_workflow(self, workflow_name: str):
        """Drop the cached StatefulWorkflow for ``workflow_name`` (or all of them for ``"all"``).

        ``workflow_cache`` is left untouched, so the workflow definition is not re-imported.
        """
        if workflow_name == "all":
            self.stateful_workflow_cache = {}
        else:
            self.stateful_workflow_cache.pop(workflow_name, None)

    def remove_jobs(self, jobs: List[UUID]):
        """Ask every cached StatefulWorkflow to remove the given job ids."""
        for workflow in self.stateful_workflow_cache.values():
            workflow.remove_jobs(jobs)

    def get_interpreter_for_task(self, task_id: str) -> Optional[TaskInterpreter]:
        """Return the first non-None interpreter for ``task_id`` across cached stateful workflows."""
        for workflow in self.stateful_workflow_cache.values():
            interpreter = workflow.get_interpreter_for_task(task_id)
            if interpreter is not None:
                return interpreter
        return None

    @staticmethod
    def import_workflow(workflow_name: str) -> ModuleType:
        """Import (and always reload) the module ``workflow_name``.

        The module is reloaded even right after its first import, so its top-level code is
        re-executed on every call.

        :raises WorkflowNotFoundError: on any ModuleNotFoundError raised while importing or
            reloading. NOTE(review): this also covers a missing dependency imported by the
            workflow module itself, not only a missing workflow module.
        """
        try:
            module = importlib.import_module(workflow_name)
            return importlib.reload(module)
        except ModuleNotFoundError as e:
            raise WorkflowNotFoundError(f'workflow not found: {workflow_name}') from e

    def create_workflow(self, mod: ModuleType) -> Workflow:
        """Build a Workflow from the classification rules, data sources and tasks found in ``mod``."""
        classification_rules: Set[BaseClassificationRule] = set()
        data_sources: Set[DataSource] = set()
        tasks: Set[Task] = set()
        for _, obj in inspect.getmembers(mod):
            classification_rules, data_sources, tasks = self.parse_object(classification_rules, data_sources, tasks, mod, obj)
        return self.create_workflow_from_objects(mod, tasks, data_sources, classification_rules)

    def parse_object(self, classification_rules, data_sources, tasks, mod, obj):
        """Collect workflow objects reachable from ``obj`` into the given sets.

        - Classification rules, data sources and tasks are added directly. Data sources and
          tasks are tagged with ``mod.__name__``.
        - Submodules whose name contains ``'wkf'`` are built into a workflow recursively and
          merged in. NOTE(review): there is no cycle guard, so mutually importing ``wkf``
          modules would recurse without bound.
        - Dataclass instances are searched field by field.

        :return: the ``(classification_rules, data_sources, tasks)`` sets. These may be new
            set objects when a submodule was merged.
        """
        if isinstance(obj, BaseClassificationRule):
            classification_rules.add(obj)
        elif isinstance(obj, DataSource):
            obj.with_workflow_name(mod.__name__)
            data_sources.add(obj)
        elif isinstance(obj, Task):
            obj.with_workflow_name(mod.__name__)
            tasks.add(obj)
        elif isinstance(obj, ModuleType) and 'wkf' in obj.__name__:
            w = self.create_workflow(obj)
            tasks = tasks.union(w.tasks)
            data_sources = data_sources.union(w.data_sources)
            classification_rules = classification_rules.union(w.classification_rules)
        # Public, documented replacement for the private dataclasses._is_dataclass_instance.
        elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            for field_name in obj.__dataclass_fields__:
                # InitVar pseudo-fields (and init=False fields never assigned) are not
                # instance attributes; None matches no branch above and is simply skipped.
                element = getattr(obj, field_name, None)
                classification_rules, data_sources, tasks = self.parse_object(classification_rules, data_sources, tasks, mod, element)
        return classification_rules, data_sources, tasks

    @staticmethod
    def create_workflow_from_objects(mod: ModuleType, tasks: Set[Task], data_sources: Set[DataSource],
                                     classification_rules: Set[BaseClassificationRule]) -> Workflow:
        """Assemble a Workflow for ``mod`` from the given objects plus everything they depend on.

        Upstream tasks and data sources are pulled in via :meth:`expand_tasks`. The
        classification rules attached to every data source and task are added, and every
        task is tagged with ``mod.__name__``. The workflow title is ``mod.__title__`` if
        defined, otherwise ``mod.__name__``. None of the input sets is modified.
        """
        data_sources, tasks = WorkflowManager.expand_tasks(data_sources, tasks)
        # Copy so the caller's set is not mutated, mirroring what expand_tasks does.
        classification_rules = set(classification_rules)
        for data_source in data_sources:
            classification_rules.update(data_source.classification_rules)
        for task in tasks:
            task.with_workflow_name(mod.__name__)
            classification_rules.update(task.classification_rules)
        title = getattr(mod, '__title__', mod.__name__)
        return Workflow(list(classification_rules), list(data_sources), list(tasks), title=title)

    @staticmethod
    def expand_tasks(data_sources: Set[DataSource], tasks: Set[Task]) -> Tuple[Set[DataSource], Set[Task]]:
        """Return copies of the inputs extended with every transitive upstream task and data source.

        Upstream objects are each task's ``main_input`` and the ``input_task`` of every
        associated input in its ``associated_input_groups``.
        """
        tasks = set(tasks)
        data_sources = set(data_sources)
        # Seed with the starting tasks: they are already queued for expansion, so a starting
        # task that is also someone's parent must not be queued (and expanded) a second time.
        seen = set(tasks)
        tasks_to_expand = list(tasks)
        while tasks_to_expand:
            task = tasks_to_expand.pop()
            parents = [task.main_input] + [associated_input.input_task for group in task.associated_input_groups
                                           for associated_input in group.associated_inputs]
            for parent_task in parents:
                if parent_task in seen:
                    continue
                seen.add(parent_task)
                if isinstance(parent_task, DataSource):
                    data_sources.add(parent_task)
                elif isinstance(parent_task, Task):
                    tasks_to_expand.append(parent_task)
                    tasks.add(parent_task)
        return data_sources, tasks

    def cleanup_request(self, workflow_name: str, targets: List[str], cleanup_args: CleanupRequestArgs):
        """Forward a cleanup request for ``targets`` to the stateful workflow ``workflow_name``.

        :raises WorkflowNotFoundError: propagated from :meth:`get_stateful_workflow`.
        """
        workflow = self.get_stateful_workflow(workflow_name)
        if workflow:
            workflow.cleanup(targets, cleanup_args)

    def expand_targets(self, search_filter: SearchFilter) -> SearchFilter:
        """Return a copy of ``search_filter`` whose ``targets`` are resolved by all cached workflows.

        Only workflows already in ``workflow_cache`` contribute; nothing is imported here.
        """
        targets = []
        for workflow in self.workflow_cache.values():
            targets += workflow.get_targets(search_filter.targets, search_filter.meta_targets)
        return dataclasses.replace(search_filter, targets=targets)
