import logging
from typing import Optional, Callable
from uuid import UUID

from .assoc_util import ReferenceFile, AssociationArgs, AssociationPreference
from .constants import ASSOCIATION_THRESHOLD, INF
from .fits import ClassifiedFitsFile
from .job import AssociationResult
from .. import JobParameters

# Signature of the per-level lookup callables used by AssociationStrategy:
# (request id, reference file built for one association config, that config's
# index in the config list, job parameters, only_complete flag) -> result or None.
AssociationFunction = Callable[
    [UUID, ReferenceFile, int, JobParameters, bool],
    Optional[AssociationResult],
]

logger = logging.getLogger(name="AssociationStrategy")


class AssociationStrategy:
    """Base class for strategies that select associated inputs for a reference file.

    Subclasses implement :meth:`get_association_result`, typically by combining
    :meth:`_find_best_quality_associations` with the job and/or file lookup
    helpers below. All lookups are delegated to the wrapped
    ``AssociatedInputTask``.
    """

    def __init__(self, associated_input):
        # Local import, presumably to avoid a circular import with `.edge`.
        from .edge import AssociatedInputTask
        self.associated_input: AssociatedInputTask = associated_input
        # Association configs, tried in list order by `_find_best_quality_associations`.
        self.assoc_configs = self.associated_input.assoc_configs

    def get_association_result(self, request_id: UUID, ref_file: ClassifiedFitsFile,
                               parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Return the association chosen by this strategy, or ``None``.

        :raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def _get_associated_files_for_quality_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                                parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Look up associated files for the config at ``index``.

        ``only_complete`` is accepted only to match :data:`AssociationFunction`;
        it is not forwarded, because ``get_closest_associated_files_for_level``
        is called without it.
        """
        return self.associated_input.get_closest_associated_files_for_level(request_id, ref_file, index, parameters)

    def _get_associated_jobs_for_quality_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                               parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Look up associated jobs for the config at ``index``, honouring ``only_complete``."""
        return self.associated_input.get_closest_associated_jobs_for_level(request_id, ref_file, index, parameters, only_complete)

    def _create_association_result(self, partial_result: AssociationResult, assoc_level: float) -> AssociationResult:
        """Delegate to the associated input to finalize ``partial_result`` at ``assoc_level``."""
        return self.associated_input.create_association_result(partial_result, assoc_level)

    def _find_best_quality_associations(self, request_id: UUID, exemplar: ClassifiedFitsFile,
                                        get_associations_for_level: AssociationFunction,
                                        parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Return the association for the first eligible config that yields a result.

        Configs are tried in ``self.assoc_configs`` order (presumably best
        quality first — confirm against how the configs are built). A config is
        eligible when its ``level`` is ``<=`` the workflow's association
        threshold. The first truthy result from ``get_associations_for_level``
        is passed to :meth:`_create_association_result` together with that
        config's level.

        :param exemplar: file used to build a ``ReferenceFile`` per config.
        :param get_associations_for_level: per-config lookup callable.
        :returns: the finalized result, or ``None`` if no eligible config matched.
        """
        # The threshold is the same for every config: read it once, so an
        # invalid value is logged once rather than once per config.
        threshold = get_assoc_threshold(parameters)
        for index, assoc_config in enumerate(self.assoc_configs):
            # Keep the `<=` form (rather than `>` + continue) so NaN levels or
            # thresholds are still treated as ineligible.
            if assoc_config.level <= threshold:
                # Only build the reference file for configs we actually query.
                ref_file = ReferenceFile(assoc_config, exemplar)
                result = get_associations_for_level(request_id, ref_file, index, parameters, only_complete)
                if result:
                    return self._create_association_result(result, assoc_config.level)
        return None


class PreferRawStrategy(AssociationStrategy):
    """Search every eligible level for associated jobs before looking at files.

    Files are only searched (again over every eligible level) when the job
    search produced no truthy result.
    """

    def get_association_result(self, req_id: UUID, ref_file: ClassifiedFitsFile, parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        search = self._find_best_quality_associations
        from_jobs = search(req_id, ref_file, self._get_associated_jobs_for_quality_level, parameters, only_complete)
        if from_jobs:
            return from_jobs
        # Fall back to files; this result is returned as-is, even if falsy.
        return search(req_id, ref_file, self._get_associated_files_for_quality_level, parameters, only_complete)


class PreferMasterStrategy(AssociationStrategy):
    """Search every eligible level for associated files before looking at jobs.

    Jobs are only searched (again over every eligible level) when the file
    search produced no truthy result.
    """

    def get_association_result(self, req_id: UUID, ref_file: ClassifiedFitsFile, parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        ordered_lookups = (
            self._get_associated_files_for_quality_level,
            self._get_associated_jobs_for_quality_level,
        )
        outcome = None
        for lookup in ordered_lookups:
            outcome = self._find_best_quality_associations(req_id, ref_file, lookup, parameters, only_complete)
            if outcome:
                break
        # Either the first truthy outcome, or the jobs search's (falsy) outcome.
        return outcome


class PerLevelAssociationStrategy(AssociationStrategy):
    """Base for strategies that choose between jobs and files within each level.

    Unlike the non-per-level strategies, both sources are consulted for one
    config before moving on to the next config.
    """

    def get_association_result(self, request_id: UUID, ref_file: ClassifiedFitsFile, parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        per_level_lookup = self._get_associations_for_quality_level
        return self._find_best_quality_associations(request_id, ref_file, per_level_lookup, parameters, only_complete)

    def _get_associations_for_quality_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                            parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        """Return the association for config ``index``; subclasses pick the source order.

        :raises NotImplementedError: always on this base class.
        """
        raise NotImplementedError


class PreferRawPerLevelStrategy(PerLevelAssociationStrategy):
    """At each level, try associated jobs first and fall back to files."""

    def _get_associations_for_quality_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                            parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        jobs_match = self._get_associated_jobs_for_quality_level(request_id, ref_file, index, parameters, only_complete)
        if jobs_match:
            return jobs_match
        return self._get_associated_files_for_quality_level(request_id, ref_file, index, parameters, only_complete)


class PreferMasterPerLevelStrategy(PerLevelAssociationStrategy):
    """At each level, try associated files first and fall back to jobs."""

    def _get_associations_for_quality_level(self, request_id: UUID, ref_file: ReferenceFile, index: int,
                                            parameters: JobParameters, only_complete: bool) -> Optional[AssociationResult]:
        lookup_args = (request_id, ref_file, index, parameters, only_complete)
        files_match = self._get_associated_files_for_quality_level(*lookup_args)
        return files_match if files_match else self._get_associated_jobs_for_quality_level(*lookup_args)


def get_association_strategy(assoc_args: AssociationArgs, associated_input) -> AssociationStrategy:
    """Instantiate the strategy that matches ``assoc_args.preference``.

    :param assoc_args: association arguments; only ``preference`` is read.
    :param associated_input: passed unchanged to the strategy constructor.
    :raises NotImplementedError: if no strategy matches the preference.
    """
    # Ordered pairs compared with `==` (not a dict lookup), so matching relies
    # only on equality, exactly like a chain of if/elif comparisons.
    strategy_table = (
        (AssociationPreference.RAW, PreferRawStrategy),
        (AssociationPreference.MASTER, PreferMasterStrategy),
        (AssociationPreference.RAW_PER_QUALITY_LEVEL, PreferRawPerLevelStrategy),
        (AssociationPreference.MASTER_PER_QUALITY_LEVEL, PreferMasterPerLevelStrategy),
    )
    for preference_option, strategy_cls in strategy_table:
        if assoc_args.preference == preference_option:
            return strategy_cls(associated_input)
    raise NotImplementedError(f"Association Strategy for '{assoc_args.preference}' not implemented")


def get_assoc_threshold(parameters: JobParameters) -> float:
    """Return the workflow's association threshold as a float.

    Reads the ``ASSOCIATION_THRESHOLD`` workflow parameter, passing ``INF`` as
    the fallback default (presumably used when the parameter is unset). If the
    value cannot be converted with ``float()``, an error is logged and ``INF``
    is returned instead.
    """
    try:
        return float(parameters.get_workflow_param(ASSOCIATION_THRESHOLD, INF))
    except (TypeError, ValueError) as e:
        # float() raises TypeError for non-numeric types (e.g. None) and
        # ValueError for non-numeric strings (e.g. "high"); both are
        # misconfigurations, so neither should abort association.
        logger.error(f"Value of the {ASSOCIATION_THRESHOLD} parameter should be numerical.", exc_info=e)
        return INF
