import logging
import sys
import uuid
from dataclasses import dataclass, field
from itertools import chain
from typing import Set, List, Dict, Optional, Callable

from edps.client.WorkflowDTO import AssociatedInputDTO, AssociatedInputGroupDTO, DataSourceDTO, TaskDTO, TaskType
from .assoc_config import AssociationConfiguration
from .classif_rule import BaseClassificationRule
from .constants import MJD_OBS, TRUE_CONDITION
from .fits import ClassifiedFitsFile
from .grouping import Grouper
from .job import Job, ActiveCondition
from .meta_targets import ALL
from .task_details import TaskDetails, DynamicParameterProvider, ReportConfig, FilterMode, CommandConfig

# Type of a callable that takes a Job and is expected to return nothing.
# Used as the type of Task's ``job_editing_function``; presumably it edits the
# job in place — NOTE(review): confirm against the callers that invoke it.
JobProcessingFunction = Callable[[Job], None]


class TaskBase:
    """Common state shared by workflow nodes (data sources and tasks).

    Holds the classification rules that define the node's categories, the keywords used
    for grouping and setup, and the association configurations of the node.
    """

    def __init__(self, name: str, grouping_keywords: List[str], groupers: List[Grouper], min_group_size: int,
                 classification_rules: Optional[List[BaseClassificationRule]] = None,
                 setup_keywords: Optional[List[str]] = None,
                 assoc_configs: Optional[List[AssociationConfiguration]] = None,
                 condition: ActiveCondition = TRUE_CONDITION,
                 description: Optional[str] = None):
        """
        :param name: node name.
        :param grouping_keywords: grouping keywords; ``None`` is treated as empty.
        :param groupers: grouper objects, stored as given.
        :param min_group_size: minimum group size, stored as given.
        :param classification_rules: rules defining this node's categories. The list is copied, so
            later calls to ``add_classification_rules`` never modify the caller's list.
        :param setup_keywords: setup keywords; ``None`` is treated as empty.
        :param assoc_configs: association configurations; the ``match_keywords`` of all of them are
            collected into ``self.match_keywords``.
        :param condition: activation condition of this node.
        :param description: optional human-readable description.
        """
        self.name = name
        self.type: TaskType = TaskType.BASE
        # Copy: add_classification_rules() extends this list in place, and storing the caller's
        # list directly would make that extension leak into it (and into any other node sharing it).
        self.classification_rules = list(classification_rules or [])
        self.assoc_configs = assoc_configs or []
        self.is_main_input = False
        self.is_assoc_input = False
        self.condition = condition
        self.description = description
        self.identifier = uuid.uuid4()
        self.grouping_keywords = grouping_keywords or []
        self.groupers: List[Grouper] = groupers
        self.setup_keywords = setup_keywords or []
        self.match_keywords = set(chain.from_iterable(cfg.match_keywords for cfg in self.assoc_configs))
        self.min_group_size = min_group_size
        self.workflow_names = set()
        # NOTE(review): evaluated once, from the rules known at construction time; rules added later via
        # add_classification_rules() are not reflected here. Confirm this snapshot is intended.
        self.should_apply_breakpoints = not any(rule.is_product() for rule in self.classification_rules)

    def add_classification_rules(self, classification_rules: List[BaseClassificationRule]):
        """Append every rule that is not already present, preserving order.

        Duplicates inside ``classification_rules`` itself are also added only once.
        """
        for rule in classification_rules:
            if rule not in self.classification_rules:
                self.classification_rules.append(rule)

    @property
    def categories(self) -> Set[str]:
        """Set of classifications produced by this node's classification rules."""
        return {x.classification for x in self.classification_rules}

    @property
    def classification_dict(self) -> Dict[str, str]:
        """Map each rule's classification to its string representation (later rules win on clashes)."""
        return {classif_rule.classification: classif_rule.as_str() for classif_rule in self.classification_rules}

    @property
    def grouping_keywords_str(self) -> str:
        """Grouping keywords as one comma-separated, upper-cased string."""
        return ", ".join(self.grouping_keywords).upper()

    @property
    def setup_keywords_str(self) -> str:
        """Setup keywords as one comma-separated, upper-cased string."""
        return ", ".join(self.setup_keywords).upper()

    def with_workflow_name(self, name: str) -> 'TaskBase':
        """Record that this node belongs to workflow ``name`` and return ``self`` for chaining."""
        self.workflow_names.add(name)
        return self

    def as_stateful_task(self, associate_existing_jobs: bool):
        """Return the stateful counterpart of this node; must be implemented by subclasses."""
        raise NotImplementedError


class DataSource(TaskBase):
    """A workflow input node defined by classification rules.

    Unlike ``Task`` it has no command configuration. Its DTO reports the rule classifications
    as ``categories`` and always has empty ``prod_categories``.
    """

    def __init__(self, classification_rules: List[BaseClassificationRule], grouping_keywords: List[str],
                 setup_keywords: List[str], min_group_size: int, match_functions: List[AssociationConfiguration],
                 groupers: List[Grouper], subworkflow_name: List[str], name: Optional[str] = None):
        """
        ``match_functions`` become this node's association configurations.
        ``classification_rules`` and ``subworkflow_name`` may be ``None`` and are then treated as empty.
        """
        super().__init__(name=name, classification_rules=classification_rules, assoc_configs=match_functions,
                         grouping_keywords=grouping_keywords, groupers=groupers, setup_keywords=setup_keywords,
                         min_group_size=min_group_size)
        self.type = TaskType.DATA_SOURCE
        self.logger = logging.getLogger('DataSource')
        # TaskBase tolerates None rules; mirror that here instead of failing with a TypeError.
        self.classification_rules_ids = {rule.id for rule in (classification_rules or [])}
        self.subworkflow_name = subworkflow_name or []

    def as_stateful_task(self, associate_existing_jobs: bool):
        """Wrap this data source in a ``StatefulDataSource`` (``associate_existing_jobs`` is unused here)."""
        # Local import, kept from the original (likely avoids a circular import — confirm).
        from .stateful_task import StatefulDataSource
        return StatefulDataSource(self)

    def as_dto(self) -> DataSourceDTO:
        """Serialize this data source into a ``DataSourceDTO``."""
        return DataSourceDTO(name=self.name, setup=self.setup_keywords, categories=list(self.categories),
                             assoc_levels=[x.as_dto() for x in self.assoc_configs],
                             prod_categories=[], id=str(self.identifier),
                             group_by=self.grouping_keywords, min_group_size=self.min_group_size)


class AssociatedInput:
    """An input of a task that is associated with its main input.

    Constructing one marks ``input_task`` as an associated input and registers
    ``classification_rules`` on it, so they become part of the input task's categories.
    """

    def __init__(self, input_task: TaskBase, classification_rules: Optional[List[BaseClassificationRule]] = None,
                 min_ret: int = 1, max_ret: int = 1, condition: ActiveCondition = TRUE_CONDITION,
                 assoc_configs: Optional[List[AssociationConfiguration]] = None,
                 sort_keys: Optional[List[str]] = None):
        """
        :param input_task: node providing the associated inputs; mutated (flagged and given extra rules).
        :param classification_rules: rules whose classifications are this input's product categories.
        :param min_ret: minimum number of associated results.
        :param max_ret: maximum number of associated results; a value below ``min_ret`` means
            unbounded (``sys.maxsize``).
        :param condition: activation condition of this association.
        :param assoc_configs: association levels; defaults to ``input_task.assoc_configs`` when empty.
        :param sort_keys: ordering keywords; ``MJD_OBS`` is always appended as the last key.
            The caller's list is never modified.
        """
        input_task.is_assoc_input = True
        self.input_task = input_task
        self.input_type = input_task.type
        self.classification_rules = classification_rules or []
        self.categories = {x.classification for x in self.classification_rules}
        self.input_task.add_classification_rules(self.classification_rules)
        # Snapshot taken after the rules above have been registered on the input task.
        self.input_categories = self.input_task.categories
        self.min_ret = min_ret
        self.max_ret = self._fix_max_ret(min_ret, max_ret)
        self.condition = condition
        self.assoc_configs = assoc_configs or input_task.assoc_configs
        # Build a fresh list: appending to ``sort_keys`` directly mutated the caller's list, so a list
        # reused across several inputs accumulated one extra MJD_OBS per use.
        self.sort_keys = list(sort_keys or []) + [MJD_OBS]

    @staticmethod
    def _fix_max_ret(min_ret: int, max_ret: int) -> int:
        """Return ``max_ret``, or ``sys.maxsize`` (no upper limit) when it is below ``min_ret``."""
        return sys.maxsize if max_ret < min_ret else max_ret

    def as_dto(self) -> AssociatedInputDTO:
        """Serialize this associated input into an ``AssociatedInputDTO``."""
        return AssociatedInputDTO(input_id=str(self.input_task.identifier), input_name=self.input_task.name,
                                  input_type=self.input_type, input_categories=list(self.input_categories),
                                  assoc_levels=[x.as_dto() for x in self.assoc_configs], order_by=self.sort_keys,
                                  prod_categories=list(self.categories), min_ret=self.min_ret, max_ret=self.max_ret,
                                  classification_rules=[x.as_dto() for x in self.input_task.classification_rules])


@dataclass
class AssociatedInputGroup:
    """A group of associated inputs together with group-level ordering keywords."""

    # Members of the group, serialized in this order.
    associated_inputs: List[AssociatedInput]
    # Group-level ordering keywords, emitted as the DTO's ``order_by``; empty by default.
    sort_keys: List[str] = field(default_factory=list)

    def as_dto(self) -> AssociatedInputGroupDTO:
        """Serialize the group and each of its member inputs."""
        member_dtos = []
        for member in self.associated_inputs:
            member_dtos.append(member.as_dto())
        return AssociatedInputGroupDTO(associated_inputs=member_dtos, order_by=self.sort_keys)


class Task(TaskBase):
    """A workflow step that runs a command on a main input plus groups of associated inputs.

    A task takes its association configurations and setup keywords from its main input.
    Its categories (from classification rules registered on it, e.g. by ``AssociatedInput``)
    are reported as product categories.
    """

    def __init__(self, name: str, command_config: CommandConfig, subworkflow_name: List[str],
                 main_input: TaskBase, associated_input_groups: List[AssociatedInputGroup], condition: ActiveCondition,
                 meta_targets: List[str], input_filter: List[str], input_filter_mode: FilterMode,
                 output_filter: List[str], output_filter_mode: FilterMode,
                 input_map: Dict[str, str], grouping_keywords: List[str], groupers: List[Grouper], min_group_size: int,
                 job_editing_function: Optional[JobProcessingFunction],
                 dynamic_parameters: Dict[str, DynamicParameterProvider], reports: List[ReportConfig],
                 description: Optional[str], accepted_classification_rules: List[BaseClassificationRule]):
        """Build a task. As a side effect ``main_input`` is flagged as a main input.

        ``meta_targets`` and ``subworkflow_name`` may be ``None`` (treated as empty); ``ALL`` is
        always added to the meta targets. The remaining arguments are stored as given.
        """
        super().__init__(name=name, condition=condition, grouping_keywords=grouping_keywords, groupers=groupers,
                         min_group_size=min_group_size, description=description)
        self.type = TaskType.TASK
        self.logger = logging.getLogger('Task')
        self.command = command_config.command
        self.command_type = command_config.command_type
        self.command_recipes = command_config.recipes
        self.task_id = f'{self.name}.{self.command}'
        self.main_input = main_input
        main_input.is_main_input = True
        # NOTE(review): TaskBase computed match_keywords from an empty assoc_configs list, so it does
        # not reflect the configs inherited here from the main input. Confirm this is intended.
        self.assoc_configs = main_input.assoc_configs
        self.associated_input_groups = associated_input_groups
        # Tolerate None like the other list arguments instead of failing in set(None).
        self.meta_targets = set(meta_targets or [])
        self.meta_targets.add(ALL)
        self.input_filter = input_filter
        self.input_filter_mode = input_filter_mode
        self.output_filter = output_filter
        self.output_filter_mode = output_filter_mode
        self.input_map = input_map
        self.job_editing_function = job_editing_function
        self.dynamic_parameters = dynamic_parameters
        self.setup_keywords = main_input.setup_keywords
        self.reports = reports
        self.subworkflow_name = subworkflow_name or []
        self.accepted_classification_rules = accepted_classification_rules

    def is_task_product(self, f: ClassifiedFitsFile) -> bool:
        """Return True if the file's classification is one of this task's categories."""
        return f.classification in self.categories

    def has_meta_targets(self, meta_targets: Set[str]) -> bool:
        """Return True if at least one of ``meta_targets`` is a meta target of this task."""
        # isdisjoint short-circuits and avoids building an intersection set just to test emptiness.
        return not self.meta_targets.isdisjoint(meta_targets)

    @property
    def assoc_config_str(self) -> str:
        """String form of every association configuration, separated by blank lines."""
        texts = [assoc_config.as_str() for assoc_config in self.assoc_configs]
        return '\n\n'.join(texts)

    def as_task_details(self) -> TaskDetails:
        """Collect the execution-related attributes of this task into a ``TaskDetails``."""
        return TaskDetails(task_id=self.task_id, command=self.command, command_type=self.command_type,
                           task_name=self.name, input_filter=self.input_filter,
                           input_filter_mode=self.input_filter_mode, output_filter=self.output_filter,
                           output_filter_mode=self.output_filter_mode, input_map=self.input_map,
                           workflow_names=list(self.workflow_names), dynamic_parameters=self.dynamic_parameters,
                           active_condition=self.condition, setup_keywords=self.setup_keywords, reports=self.reports)

    def as_stateful_task(self, associate_existing_jobs: bool):
        """Wrap this task in a ``StatefulTask``."""
        # Local import, kept from the original (likely avoids a circular import — confirm).
        from .stateful_task import StatefulTask
        return StatefulTask(self, associate_existing_jobs)

    def flatten_associated_inputs(self) -> List[AssociatedInput]:
        """Return the associated inputs of all groups as one list, in group order."""
        return [inp for group in self.associated_input_groups for inp in group.associated_inputs]

    def as_dto(self) -> TaskDTO:
        """Serialize this task into a ``TaskDTO``.

        Associated inputs are emitted both flattened and grouped. ``dynamic_parameters`` is
        reduced to its keys.
        """
        return TaskDTO(name=self.name, task_id=self.task_id, setup=self.setup_keywords,
                       main_input_id=str(self.main_input.identifier), main_input_name=self.main_input.name,
                       associated_inputs=[x.as_dto() for x in self.flatten_associated_inputs()],
                       associated_input_groups=[x.as_dto() for x in self.associated_input_groups],
                       categories=[], assoc_levels=[x.as_dto() for x in self.assoc_configs],
                       prod_categories=list(self.categories), id=str(self.identifier),
                       group_by=self.grouping_keywords, min_group_size=self.min_group_size,
                       command=self.command, command_type=self.command_type, meta_targets=list(self.meta_targets),
                       workflows=list(self.workflow_names), dynamic_parameters=list(self.dynamic_parameters))
