import logging
from typing import List, Optional, Dict
from uuid import UUID

from .assoc_util import AssociationArgs
from .assoc_util import limit, ReferenceFile, order_jobs_by, order_files_by, order_assoc_results_by
from .association_strategy import get_association_strategy, get_assoc_threshold
from .fits import ClassifiedFitsFile
from .job import Job, AssociationResult, AssociatedFitsFile
from .parameters import Parameters, JobParameters
from .stateful_task import StatefulTaskBase, StatefulTask, StatefulDataSource
from .task import AssociatedInput, AssociationConfiguration
from .. import MJD_OBS


class Edge:
    """Connection from an ``origin`` node to a ``destination`` task.

    Base class for the concrete main-input and associated-input edges; it only
    records the two endpoints as public attributes.
    """

    def __init__(self, origin: StatefulTaskBase, destination: StatefulTask):
        # Subclasses and callers read both endpoints directly.
        self.origin, self.destination = origin, destination


class MainInputEdge(Edge):
    """Edge connecting a task to the node that provides its main input.

    Construction is inherited unchanged from ``Edge`` (``origin``, ``destination``).
    """

    def create_jobs(self, request_id: UUID, parameters: Parameters, meta_targets: Optional[List[str]]) -> List[Job]:
        """Ask the origin node to create the jobs of the destination task.

        :param request_id: identifier of the request, forwarded to the origin.
        :param parameters: parameters forwarded to the origin.
        :param meta_targets: optional meta-target names forwarded to the origin.
        :return: whatever ``origin.create_jobs_for_task`` returns.
        """
        origin = self.origin
        destination = self.destination
        return origin.create_jobs_for_task(request_id, destination, parameters, meta_targets)


class AssociatedInputEdge(Edge):
    """Edge to a node that provides an associated input of a task.

    The association settings are copied from the ``AssociatedInput`` details:
    the required/maximum number of items (``min_ret``/``max_ret``), an activation
    ``condition``, the ordered association configurations and the sort keys.
    Subclasses must implement ``get_association_result`` and ``clone``.
    """

    def __init__(self, origin: StatefulTaskBase, destination: StatefulTask, details: AssociatedInput):
        super().__init__(origin, destination)
        # The raw details are kept: subclasses rebuild themselves from them in
        # ``clone`` and ``apply_breakpoints`` reads their categories.
        self.associated_input = details
        self.min_ret = details.min_ret
        self.max_ret = details.max_ret
        self.condition = details.condition
        self.assoc_configs: List[AssociationConfiguration] = details.assoc_configs
        self.sort_keys = details.sort_keys
        # The edge is named after the node it comes from.
        self.name = self.origin.name

    @property
    def is_optional(self) -> bool:
        """True when the edge requires no items at all (``min_ret == 0``)."""
        return self.min_ret == 0

    def is_active(self, parameters: JobParameters) -> bool:
        """Evaluate this edge's activation condition for ``parameters``."""
        condition = self.condition
        return condition(parameters)

    def get_match_functions_text(self) -> List[str]:
        """Return the text of every association configuration, in configuration order."""
        return [config.get_text() for config in self.assoc_configs]

    def create_association_result(self, partial_result: AssociationResult, assoc_level: float) -> AssociationResult:
        """Turn a partial result into a complete ``AssociationResult`` for this edge.

        The files and jobs are taken from ``partial_result``; completeness comes
        from ``partial_result.compute_is_complete()``. The result is tagged with
        this edge's optionality, the given association level and the edge name.
        """
        return AssociationResult(
            associated_files=partial_result.associated_files,
            associated_jobs=partial_result.associated_jobs,
            is_complete=partial_result.compute_is_complete(),
            is_optional=self.is_optional,
            assoc_level=assoc_level,
            task_name=self.name,
        )

    def get_association_result(self, request_id: UUID, exemplar: ClassifiedFitsFile, assoc_args: AssociationArgs,
                               parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Find the inputs associated with ``exemplar`` through this edge (abstract)."""
        raise NotImplementedError

    def clone(self) -> 'AssociatedInputEdge':
        """Return a fresh copy of this edge (abstract)."""
        raise NotImplementedError

    def apply_breakpoints(self, breakpoints_by_category: Dict[str, List[float]],
                          mjd_obs: float) -> 'AssociatedInputEdge':
        """Return an edge whose association configurations take breakpoints into account.

        If the origin's ``should_apply_breakpoints()`` is falsy, ``self`` is
        returned as is. Otherwise a clone is made and each of its configurations
        is replaced by ``config.apply_breakpoints(breakpoints_by_category,
        categories, mjd_obs)``.
        """
        if not self.origin.should_apply_breakpoints():
            return self
        cloned = self.clone()
        cloned.assoc_configs = [
            cfg.apply_breakpoints(breakpoints_by_category, self.associated_input.categories, mjd_obs)
            for cfg in cloned.assoc_configs
        ]
        return cloned


class AlternativeAssociatedInputEdges:
    """Group of interchangeable associated-input edges.

    Only the edges whose condition holds for the job parameters are queried.
    Their non-``None`` results are ordered with ``order_assoc_results_by`` using
    ``sort_keys``, and the first one is used.
    """

    def __init__(self, edges: List[AssociatedInputEdge], sort_keys: List[str]):
        self.edges = edges
        self.sort_keys = sort_keys
        edge_names = [edge.name for edge in edges]
        # The group is named after its members; an empty group gets a fixed placeholder.
        self.name = '.'.join(edge_names) if edge_names else "EMPTY_ALTERNATIVE_INPUT"

    def active_edges(self, parameters: JobParameters) -> List[AssociatedInputEdge]:
        """Return the edges whose condition is satisfied by ``parameters``, in group order."""
        return [edge for edge in self.edges if edge.is_active(parameters)]

    def is_optional(self) -> bool:
        """True when every alternative is optional (also True for an empty group)."""
        return all(edge.is_optional for edge in self.edges)

    def get_association_result(self, request_id: UUID, exemplar: ClassifiedFitsFile, assoc_args: AssociationArgs,
                               parameters: JobParameters, only_complete: bool = True) -> AssociationResult:
        """Return the preferred association result among the active alternatives.

        :param request_id: request identifier, forwarded to each edge.
        :param exemplar: file the inputs are associated with.
        :param assoc_args: association arguments, forwarded to each edge.
        :param parameters: job parameters used to select active edges and forwarded to them.
        :param only_complete: forwarded to each edge.
        :return: the first result after ordering. If no active edge yields a
            result, an empty result is returned. That empty result is marked
            complete only when the group has exactly one edge and it is
            inactive. Its task name is the first active edge's name, or the
            group name when none is active.
        """
        # Evaluate the edge conditions once so the queried edges and the
        # empty-result metadata are derived from the same selection.
        active_edges = self.active_edges(parameters)
        results = [edge.get_association_result(request_id, exemplar, assoc_args, parameters, only_complete)
                   for edge in active_edges]
        results = [res for res in results if res is not None]
        results = order_assoc_results_by(results, exemplar, self.sort_keys, parameters)
        if len(results) > 0:
            return results[0]

        is_complete = len(active_edges) == 0 and len(self.edges) == 1
        task_name = active_edges[0].name if active_edges else self.name
        return AssociationResult(associated_files=[], associated_jobs=[],
                                 is_complete=is_complete, is_optional=self.is_optional(),
                                 assoc_level=0, task_name=task_name)

    def get_associated_inputs(self, parameters: JobParameters) -> List[AssociatedInput]:
        """Return the ``AssociatedInput`` details of the currently active edges."""
        return [edge.associated_input for edge in self.active_edges(parameters)]

    def __repr__(self):
        return self.name

    def apply_breakpoints(self, breakpoints_by_category: Dict[str, List[float]],
                          mjd_obs: float) -> 'AlternativeAssociatedInputEdges':
        """Return a new group whose edges have had ``apply_breakpoints`` applied."""
        breakpointed_edges = [edge.apply_breakpoints(breakpoints_by_category, mjd_obs) for edge in self.edges]
        return AlternativeAssociatedInputEdges(breakpointed_edges, self.sort_keys)


class AssociatedInputDataSource(AssociatedInputEdge):
    """Associated-input edge whose origin is a ``StatefulDataSource``.

    Only files are associated through this edge; the results it builds never
    contain jobs.
    """

    def __init__(self, origin: StatefulDataSource, destination: StatefulTask, details: AssociatedInput):
        super().__init__(origin, destination, details)
        self.input_data_source = origin
        self.logger = logging.getLogger('AssociatedInputDataSource')

    def get_association_result(self, request_id: UUID, exemplar: ClassifiedFitsFile, assoc_args: AssociationArgs,
                               parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Return the closest associated files; ``assoc_args`` and ``only_complete`` are not used here."""
        return self.get_closest_associated_files(request_id, exemplar, parameters)

    def get_closest_associated_files(self, request_id: UUID, exemplar: ClassifiedFitsFile,
                                     parameters: JobParameters) -> Optional[AssociationResult]:
        """Try the association configurations in order and return the first match.

        Configurations whose ``level`` is above ``get_assoc_threshold(parameters)``
        are skipped. The first configuration that yields a truthy result is
        wrapped by ``create_association_result`` with that configuration's level.

        :return: the association result, or ``None`` if no configuration matched.
        """
        # The threshold is computed from ``parameters`` alone, so compute it once.
        threshold = get_assoc_threshold(parameters)
        for index, assoc_config in enumerate(self.assoc_configs):
            if assoc_config.level <= threshold:
                # Build the reference only for configurations that are actually tried.
                ref_file = ReferenceFile(assoc_config, exemplar)
                result = self.get_closest_associated_files_for_level(request_id, ref_file, index, parameters)
                if result:
                    return self.create_association_result(result, assoc_config.level)
        return None

    def get_closest_associated_files_for_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                               parameters: JobParameters) -> Optional[AssociationResult]:
        """Look up the data source's files for one association configuration.

        :param index: position of the configuration; used only in log messages.
        :return: ``None`` when nothing is found and nothing is required, or when
            fewer than ``min_ret`` files are found. Otherwise returns up to
            ``max_ret`` files, ordered by ``sort_keys`` relative to the
            reference file.
        """
        associated_files = self.input_data_source.get_associated_files(request_id, ref_file, parameters)
        num_found = len(associated_files)
        if num_found == 0 and self.min_ret == 0:
            return None
        elif num_found >= self.min_ret:
            self.logger.debug("[%s index=%d] Found associated files for %s", str(self.origin), index, str(ref_file))
            sorted_files = order_files_by(associated_files, ref_file.ref_file, self.sort_keys, parameters)
            selected_files = [AssociatedFitsFile(f, f[MJD_OBS]) for f in limit(sorted_files, self.max_ret)]
            return AssociationResult(associated_files=selected_files, associated_jobs=[])
        else:
            # Log the reference file, matching the success message and AssociatedInputTask.
            self.logger.debug("[%s index=%d] Could not find associated files for %s. requested %d found %d",
                              str(self.origin), index, str(ref_file), self.min_ret, num_found)
            return None

    def clone(self) -> 'AssociatedInputDataSource':
        """Return a new edge built from the same origin, destination and details."""
        return AssociatedInputDataSource(self.origin, self.destination, self.associated_input)


class AssociatedInputTask(AssociatedInputEdge):
    """Associated-input edge whose origin is another task (``StatefulTask``).

    The origin task can supply files, grouped by the associated input's
    ``categories``, or jobs. ``get_association_result`` delegates to the
    strategy returned by ``get_association_strategy``. Presumably that strategy
    calls back into the ``*_for_level`` methods below; verify in
    association_strategy.
    """

    def __init__(self, origin: StatefulTask, destination: StatefulTask, details: AssociatedInput):
        super().__init__(origin, destination, details)
        self.input_task = origin
        self.categories = details.categories
        self.logger = logging.getLogger('AssociatedInputTask')

    def get_association_result(self, request_id: UUID, exemplar: ClassifiedFitsFile, assoc_args: AssociationArgs,
                               parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Delegate to the association strategy selected for ``assoc_args`` and this edge."""
        assoc_strategy = get_association_strategy(assoc_args, self)
        return assoc_strategy.get_association_result(request_id, exemplar, parameters, only_complete)

    def get_closest_associated_files_for_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                               parameters: JobParameters) -> Optional[AssociationResult]:
        """Look up the origin task's files for one association configuration.

        Files are fetched per category. Returns ``None`` when no category yields
        any file, or when any category yields fewer than ``min_ret`` files.
        Otherwise, each category's files are ordered by ``sort_keys`` relative to
        the reference file and limited to ``max_ret``. All of them are returned
        together.

        :param index: position of the configuration; used only in log messages.
        """
        associated_files_by_category = self.input_task.get_associated_files(request_id, self.categories, ref_file,
                                                                            parameters)
        associated_groups = associated_files_by_category.values()
        # Generators let all()/any() short-circuit instead of building throwaway lists.
        is_empty = len(associated_groups) == 0 or all(len(files) == 0 for files in associated_groups)
        is_incomplete = len(associated_groups) > 0 and any(len(files) < self.min_ret for files in associated_groups)
        if is_incomplete:
            num_files_per_catg = {category: len(files) for category, files in associated_files_by_category.items()}
            self.logger.debug("[%s index=%d] Could not find associated files for %s. requested %d found %s",
                              str(self.origin), index, str(ref_file), self.min_ret, str(num_files_per_catg))
        if is_empty or is_incomplete:
            return None

        self.logger.debug("[%s index=%d] Found associated files for %s", str(self.origin), index, str(ref_file))
        closest_associated_files = []
        for files in associated_groups:
            # Ordering and the max_ret limit are applied per category, not globally.
            closest_associated_files.extend(
                limit(order_files_by(files, ref_file.ref_file, self.sort_keys, parameters), self.max_ret))
        selected_files = [AssociatedFitsFile(f, f[MJD_OBS]) for f in closest_associated_files]
        return AssociationResult(associated_files=selected_files, associated_jobs=[])

    def get_closest_associated_jobs_for_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                              parameters: JobParameters,
                                              only_complete: bool) -> Optional[AssociationResult]:
        """Look up the origin task's jobs for one association configuration.

        :param index: position of the configuration; used only in log messages.
        :param only_complete: forwarded to ``input_task.get_associated_jobs``.
        :return: ``None`` when no job is found and none is required, or when
            fewer than ``min_ret`` jobs are found. Otherwise returns up to
            ``max_ret`` jobs, ordered by ``sort_keys`` relative to the reference
            file.
        """
        associated_jobs = self.input_task.get_associated_jobs(request_id, ref_file, parameters, only_complete)
        if len(associated_jobs) == 0 and self.min_ret == 0:
            return None
        elif len(associated_jobs) >= self.min_ret:
            self.logger.debug("[%s index=%d] Found associated jobs for %s", str(self.origin), index, str(ref_file))
            sorted_jobs = order_jobs_by(associated_jobs, ref_file.ref_file, self.sort_keys, parameters)
            return AssociationResult(associated_files=[], associated_jobs=limit(sorted_jobs, self.max_ret))
        else:
            self.logger.debug("[%s index=%d] Could not find associated jobs for %s. requested %d found %d",
                              str(self.origin), index, str(ref_file), self.min_ret, len(associated_jobs))
            return None

    def clone(self) -> 'AssociatedInputTask':
        """Return a new edge built from the same origin, destination and details."""
        return AssociatedInputTask(self.origin, self.destination, self.associated_input)
