import dataclasses
import inspect
import logging
from dataclasses import dataclass, field
from logging import Logger
from typing import List, Callable, Optional, Dict, Set

from edps.client.WorkflowDTO import AssociationConfigurationDTO
from edps.utils import is_lambda
from .association_breakpoints import ANY
from .fits import ClassifiedFitsFile
from .time_range import RelativeTimeRange

# Predicate over two classified FITS files; AssociationConfiguration.are_associated
# calls it as match_function(ref, f) and treats the result as "these two files match".
MatchFunction = Callable[[ClassifiedFitsFile, ClassifiedFitsFile], bool]
# List of keyword names attached to a configuration; when non-empty it is what
# AssociationConfiguration.get_text displays instead of the match function's source.
# NOTE(review): presumably FITS header keywords compared by the match function — confirm.
MatchKeywords = List[str]


def get_association_configuration_logger() -> Logger:
    """Return the logger registered under the name 'AssociationConfiguration'."""
    logger_name = 'AssociationConfiguration'
    return logging.getLogger(logger_name)


@dataclass
class AssociationConfiguration:
    """Rule deciding whether a file is associated with a reference file.

    Two files are associated when the file lies inside ``time_range`` relative to
    the reference AND ``match_function(ref, f)`` is truthy (see ``are_associated``).

    Attributes:
        level: Numeric level, forwarded unchanged to the DTO. Its meaning is
            defined by the callers and is not visible in this module.
        match_function: Predicate called as ``match_function(ref, f)``.
        match_keywords: Keyword names; when non-empty they are shown by
            ``get_text`` instead of the source of ``match_function``.
        time_range: Window, relative to the reference file, in which association
            is allowed.
        logger: Logger used to report exceptions raised while matching.
    """
    level: float
    match_function: MatchFunction
    match_keywords: MatchKeywords
    time_range: RelativeTimeRange
    logger: Logger = field(default_factory=get_association_configuration_logger)

    def as_str(self) -> str:
        """Return a two-line description: level and time range, then ``get_text()``."""
        return f"level: {self.level}, time range: {self.time_range}\n{self.get_text()}"

    def get_text(self) -> str:
        """Return ``match(<keywords>)`` if keywords are set, else the match function's source.

        The fallback uses ``inspect.getsource``, which raises if the source of
        ``match_function`` cannot be retrieved.
        """
        return f"match({self.match_keywords})" if self.match_keywords else inspect.getsource(self.match_function)

    def get_match_function_name(self) -> Optional[str]:
        """Return ``match_function.__name__``, or ``None`` when ``is_lambda`` reports a lambda."""
        return self.match_function.__name__ if not is_lambda(self.match_function) else None

    def are_associated(self, ref: ClassifiedFitsFile, f: ClassifiedFitsFile) -> bool:
        """Tell whether ``f`` is associated with the reference file ``ref``.

        The time range is checked first; ``match_function`` is only invoked when
        that check passes. Any exception raised by either check is logged at
        DEBUG level (with traceback) and treated as "not associated".
        """
        try:
            return self.time_range.is_within_range(ref, f) and self.match_function(ref, f)
        except Exception as e:
            # Deliberate best effort: a failing match must not abort the caller.
            self.logger.debug("association exception ref %s f %s : %s", ref.get_path(), f.get_path(), e, exc_info=e)
            return False

    def as_dto(self) -> AssociationConfigurationDTO:
        """Convert this configuration to an ``AssociationConfigurationDTO``.

        The match function is represented by its name only (``None`` for lambdas).
        """
        return AssociationConfigurationDTO(level=self.level, time_range=self.time_range.as_dto(),
                                           match_function=self.get_match_function_name(),
                                           match_keywords=self.match_keywords)

    def clone_with_time_range(self, time_range: RelativeTimeRange) -> 'AssociationConfiguration':
        """Return a copy of this configuration with ``time_range`` replaced."""
        return dataclasses.replace(self, time_range=time_range)

    def apply_breakpoints(self, breakpoints_by_category: Dict[str, List[float]], categories: Set[str],
                          mjd_obs: float) -> 'AssociationConfiguration':
        """Return a copy whose time range is clipped at the nearest breakpoints.

        Breakpoints registered for any of ``categories`` or for ``ANY`` are
        considered. On each side of ``mjd_obs`` the breakpoint closest to
        ``mjd_obs`` that falls strictly inside the current time range becomes
        the new bound of that side; a side without such a breakpoint keeps its
        current bound.

        Args:
            breakpoints_by_category: Breakpoint values keyed by category; compared
                directly against ``mjd_obs``, so they must be on the same scale.
            categories: Categories whose breakpoints apply (``ANY`` is always added).
            mjd_obs: Reference value the time range is relative to.

        Returns:
            A new ``AssociationConfiguration``; ``self`` is not modified.
        """
        all_breakpoints = [bp
                           for category in categories.union({ANY})
                           for bp in breakpoints_by_category.get(category, [])]
        # Bounds are signed offsets from mjd_obs (the new bounds computed below are
        # ``bp - mjd_obs``, i.e. <= 0 on the left and >= 0 on the right). The left
        # test therefore compares the signed offset: testing ``abs(offset) < left``
        # against a negative ``left`` could never succeed and would silently ignore
        # every breakpoint on the left side.
        left_offsets = [bp - mjd_obs for bp in all_breakpoints
                        if bp <= mjd_obs and bp - mjd_obs > self.time_range.left]
        right_offsets = [bp - mjd_obs for bp in all_breakpoints
                         if bp >= mjd_obs and bp - mjd_obs < self.time_range.right]
        # Closest breakpoint on each side: largest offset on the left, smallest on the right.
        new_left = max(left_offsets, default=self.time_range.left)
        new_right = min(right_offsets, default=self.time_range.right)
        new_time_range = RelativeTimeRange(new_left, new_right)
        return self.clone_with_time_range(new_time_range)
