import logging
import math
from collections import defaultdict
from typing import List, Dict, Tuple, Callable, TypeVar

from astropy import units
from astropy.coordinates import SkyCoord

from edps.generator.constants import INF, RA, DEC
from edps.generator.fits import ClassifiedFitsFile, FitsFile
from edps.generator.parameters import Parameters

# Mapping from a group name to the list of classified FITS files assigned to that group.
Groups = Dict[str, List[ClassifiedFitsFile]]
# Signature of a plain function usable as a grouping strategy: it receives the current groups and the
# parameters object and returns the new groups (this is the type accepted by FunctionGrouper).
GroupingFunction = Callable[[Groups, Parameters], Groups]
# Generic element type; used by ClusterGrouping.find_median, which returns an element of the list it is given.
T = TypeVar('T')


class Grouper:
    """Common interface of all grouping strategies.

    A grouper takes an existing partition of files into named groups and returns a new partition.
    This base class provides no behaviour of its own; subclasses override :meth:`split`.
    """

    def split(self, groups: Groups, parameters: Parameters) -> Groups:
        """Return a new partition derived from ``groups``.

        Args:
            groups: mapping of group name to the files currently in that group.
            parameters: parameters object handed to the strategy.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


class FunctionGrouper(Grouper):
    """Grouper that delegates the whole split to a user-supplied callable."""

    def __init__(self, function: GroupingFunction):
        """Remember the callable that performs the grouping.

        Args:
            function: callable invoked as ``function(groups, parameters)`` by :meth:`split`.
        """
        self.function = function

    def split(self, groups: Groups, parameters: Parameters) -> Groups:
        """Call the stored function with ``groups`` and ``parameters`` and return its result unchanged."""
        grouping_function = self.function
        regrouped = grouping_function(groups, parameters)
        return regrouped


class KeywordGrouping(Grouper):
    """Grouper that subdivides every incoming group according to a list of keywords."""

    def __init__(self, keywords: List[str]):
        """Store the keywords used for grouping.

        Args:
            keywords: keyword names; they are resolved through ``parameters.resolve_keywords``
                each time :meth:`split` is called.
        """
        self.keywords = keywords

    def split(self, groups: Groups, parameters: Parameters) -> Groups:
        """Split each group into sub-groups keyed by the files' group ids.

        The key of a sub-group is the key of the originating group concatenated with
        ``file.get_group_id(resolved_keywords)``.

        Args:
            groups: mapping of group name to the files currently in that group.
            parameters: used to resolve the configured keywords.

        Returns:
            A plain ``dict`` mapping each new group key to its files. Groups that contain no
            files produce no entry.
        """
        resolved_keywords = parameters.resolve_keywords(self.keywords)
        subgroups = defaultdict(list)
        for base_group_key, group_files in groups.items():
            for file in group_files:
                subgroups[base_group_key + file.get_group_id(resolved_keywords)].append(file)
        # Return a regular dict rather than the defaultdict: with a defaultdict, a caller that merely
        # looks up a non-existent key would silently insert an empty group into the result.
        return dict(subgroups)


class FilesCluster:
    """Container for the files collected so far into one cluster.

    Adding a file never modifies an existing cluster: :meth:`try_add` returns either a new
    cluster containing the additional file or the unchanged cluster itself.
    """

    def __init__(self, files: List[ClassifiedFitsFile] = None):
        """Create a cluster.

        Args:
            files: initial members. ``None`` or an empty list yields a fresh empty list; a
                non-empty list is stored as given (it is not copied).
        """
        if files:
            self.files = files
        else:
            self.files = []

    def try_add(self, potential_file: ClassifiedFitsFile,
                distance: Callable[[ClassifiedFitsFile, ClassifiedFitsFile, str], float],
                keyword: str,
                all_threshold: float,
                any_threshold: float) -> 'FilesCluster':
        """Return the cluster that results from attempting to add ``potential_file``.

        An empty cluster accepts any file. Otherwise the file is accepted only if its distance
        to at least one member is ``<= any_threshold`` and its distance to every member is
        ``<= all_threshold``.

        Args:
            potential_file: candidate file.
            distance: called as ``distance(potential_file, member, keyword)`` for every member.
            keyword: passed through to ``distance``.
            all_threshold: upper bound for the distance to every member.
            any_threshold: upper bound that at least one member distance must satisfy.

        Returns:
            A new ``FilesCluster`` including the candidate when it is accepted, otherwise ``self``.
        """
        members = self.files
        if not members:
            return FilesCluster([potential_file])

        # Distances are always computed against every current member, in member order.
        separations = [distance(potential_file, member, keyword) for member in members]
        if not any(separation <= any_threshold for separation in separations):
            return self
        if not all(separation <= all_threshold for separation in separations):
            return self
        return FilesCluster(members + [potential_file])


class ClusterGrouping(Grouper):
    """Base class for groupers that split each group into clusters of mutually close files.

    Closeness is measured by :meth:`distance`, which subclasses must implement. Two thresholds,
    read from workflow parameters, control cluster membership (see ``FilesCluster.try_add``):
    the "all" threshold bounds the distance of a file to every member of its cluster and the
    "any" threshold must be satisfied by the distance to at least one member.
    """

    def __init__(self, keyword: str, max_diameter: str, max_separation: str):
        """Configure the grouper.

        Args:
            keyword: keyword driving the clustering; resolved via ``parameters.resolve_keyword``.
            max_diameter: name of the workflow parameter that holds the "all" threshold.
            max_separation: name of the workflow parameter that holds the "any" threshold.
        """
        self.max_diameter = max_diameter
        self.max_separation = max_separation
        self.keyword = keyword
        self.logger = logging.getLogger('ClusterGrouping')

    def split(self, groups: Groups, parameters: Parameters) -> Groups:
        """Split every group into clusters and return all clusters in one mapping.

        Args:
            groups: mapping of group name to the files currently in that group.
            parameters: used to resolve the keyword and both thresholds.

        Returns:
            Mapping of cluster name to the files of that cluster, over all input groups.
        """
        resolved_keyword = parameters.resolve_keyword(self.keyword)
        resolved_all_threshold = self.get_all_threshold(parameters)
        resolved_any_threshold = self.get_any_threshold(parameters)
        return {cluster_name: cluster for name, files in groups.items()
                for cluster_name, cluster in self.split_group_into_clusters(name, files, resolved_keyword, resolved_all_threshold, resolved_any_threshold).items()}

    def split_group_into_clusters(self, group_name: str, files: List[ClassifiedFitsFile], keyword: str, all_threshold: float, any_threshold: float) -> Groups:
        """Repeatedly extract one cluster from ``files`` until no file is left.

        Each cluster is stored under a key built from the ``dp_id`` of its center file and
        ``group_name``.
        """
        clusters = {}
        remaining_files = files
        while remaining_files:
            center, cluster, remaining_files = self.find_cluster(remaining_files, keyword, all_threshold, any_threshold)
            clusters[f'cluster_{center.fitsfile.dp_id}_{group_name}'] = cluster
        return clusters

    def find_cluster(self, files: List[ClassifiedFitsFile], keyword: str, all_threshold: float, any_threshold: float) -> Tuple[
        ClassifiedFitsFile, List[ClassifiedFitsFile], List[ClassifiedFitsFile]]:
        """Build one cluster around the median file of ``files``.

        Returns:
            A tuple ``(median, cluster, remaining)`` where ``remaining`` holds the files of
            ``files`` that are not in ``cluster``.
        """
        median = self.find_median(files, keyword)
        cluster = self.find_neighbours(median, files, keyword, all_threshold, any_threshold)
        # Formatting every file of the cluster is only worth doing when the message is actually emitted.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug("Median %s  matches cluster %s by keyword %s and all_threshold %s, any_threshold %s", self.repr_file(median, keyword),
                              [self.repr_file(f, keyword) for f in cluster], keyword, all_threshold, any_threshold)
        remaining = [file for file in files if file not in cluster]
        return median, cluster, remaining

    def find_neighbours(self, median: ClassifiedFitsFile, files: List[ClassifiedFitsFile], keyword: str, all_threshold: float, any_threshold: float) -> List[ClassifiedFitsFile]:
        """Greedily collect files into a cluster, visiting them in order of increasing distance to ``median``.

        The first visited file is always accepted (an empty cluster accepts any file); every
        further file is accepted only if it satisfies both thresholds against the files
        accepted so far.
        """
        ordered_files = sorted(files, key=lambda file: self.distance(median, file, keyword))
        new_cluster = FilesCluster()
        for file in ordered_files:
            new_cluster = self.try_add_file_to_cluster(new_cluster, file, keyword, all_threshold, any_threshold)
        return new_cluster.files

    def try_add_file_to_cluster(self, current_cluster: FilesCluster, file: ClassifiedFitsFile, keyword: str, all_threshold: float, any_threshold: float) -> FilesCluster:
        """Return ``current_cluster.try_add(...)`` using this grouper's :meth:`distance`."""
        return current_cluster.try_add(file, self.distance, keyword, all_threshold, any_threshold)

    def repr_file(self, file: ClassifiedFitsFile, keyword: str) -> str:
        """Return a short description of ``file`` (its ``dp_id`` and the clustering value) for log messages."""
        if keyword == 'SKY.POSITION':
            return f'{file.dp_id}({keyword}={file.get_keyword_value(RA, INF), file.get_keyword_value(DEC, INF)})'
        else:
            return f'{file.dp_id}({keyword}={file.get_keyword_value(keyword, INF)})'

    def find_median(self, data: List[T], keyword: str) -> T:
        """Return the middle element of ``data`` (index ``len(data) // 2``).

        ``data`` is used in the order given; ``keyword`` is not used by this implementation.

        Raises:
            IndexError: if ``data`` is empty.
        """
        return data[len(data) // 2]

    def get_all_threshold(self, parameters: Parameters) -> float:
        """Return the threshold read from the workflow parameter named by ``max_diameter``."""
        return self.get_threshold(parameters, self.max_diameter)

    def get_any_threshold(self, parameters: Parameters) -> float:
        """Return the threshold read from the workflow parameter named by ``max_separation``."""
        return self.get_threshold(parameters, self.max_separation)

    def get_threshold(self, parameters: Parameters, parameter_name: str) -> float:
        """Read a workflow parameter and convert it to ``float``.

        Args:
            parameters: queried through ``get_workflow_param``.
            parameter_name: name of the workflow parameter.

        Returns:
            The parameter value as ``float``; ``INF`` if the value is ``None`` or cannot be
            converted to a number (an error is logged in the latter case).
        """
        try:
            resolved_threshold = parameters.get_workflow_param(parameter_name)
            return float(resolved_threshold) if resolved_threshold is not None else INF
        except (TypeError, ValueError) as e:
            # float() raises TypeError for unsupported types but ValueError for non-numeric strings
            # such as "abc"; both mean the parameter is not numerical.
            logging.getLogger("ClusterGrouping").error(f"Value of the {parameter_name} parameter should be numerical.", exc_info=e)
            return INF

    def distance(self, file1: ClassifiedFitsFile, file2: ClassifiedFitsFile, keyword: str) -> float:
        """Return the distance between two files; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class SimpleClusterGrouping(ClusterGrouping):
    """Cluster grouping on a single keyword; the distance is the absolute difference of the keyword values."""

    def __init__(self, keyword: str, max_diameter: str, max_separation: str):
        """Forward all arguments unchanged to ``ClusterGrouping.__init__``."""
        super().__init__(keyword, max_diameter, max_separation)

    def split_group_into_clusters(self, group_name: str, files: List[ClassifiedFitsFile], keyword: str, all_threshold: float, any_threshold: float) -> Groups:
        """Sort ``files`` by their value of ``keyword`` and then cluster them with the inherited algorithm.

        ``INF`` is passed to ``get_keyword_value`` as the fallback argument when reading the sort key.
        """

        def keyword_value(candidate):
            return candidate.get_keyword_value(keyword, INF)

        files_in_value_order = sorted(files, key=keyword_value)
        return super().split_group_into_clusters(group_name, files_in_value_order, keyword, all_threshold, any_threshold)

    def distance(self, file1: ClassifiedFitsFile, file2: ClassifiedFitsFile, keyword: str) -> float:
        """Return ``|value1 - value2|`` for the values of ``keyword`` in the two files.

        Any exception raised while reading or subtracting the values is logged as a warning
        and ``INF`` is returned instead.
        """
        try:
            first_value = file1.get_keyword_value(keyword, INF)
            second_value = file2.get_keyword_value(keyword, INF)
            difference = first_value - second_value
            return math.fabs(difference)
        except Exception as e:
            self.logger.warning(f"Failed to compute clustering distance. Verify if value of {keyword} keyword is numerical.", exc_info=e)
            return INF


class SkyPositionClusterGrouping(ClusterGrouping):
    """Cluster grouping on sky position; the distance is the angular separation computed from RA and DEC."""

    def __init__(self, keyword: str, max_diameter: str, max_separation: str):
        """Initialise the base class and switch to a logger named ``SkyPositionClusterGrouping``."""
        super().__init__(keyword, max_diameter, max_separation)
        self.logger = logging.getLogger('SkyPositionClusterGrouping')

    def find_median(self, files: List[ClassifiedFitsFile], keyword: str) -> ClassifiedFitsFile:
        """Return the file of ``files`` closest to the coordinate-wise median position.

        The median RA and median DEC are taken independently from the sorted RA and DEC values
        (``0`` is passed to ``get_keyword_value`` as fallback). A virtual file carrying these two
        values is built, and the real file with the smallest :meth:`distance` to it is returned;
        on ties the first such file in ``files`` wins.
        """
        sorted_ras = sorted(file.get_keyword_value(RA, 0) for file in files)
        sorted_decs = sorted(file.get_keyword_value(DEC, 0) for file in files)
        ra_median = super().find_median(sorted_ras, keyword)
        dec_median = super().find_median(sorted_decs, keyword)
        self.logger.debug("Median file has RA %s and DEC %s", ra_median, dec_median)

        synthetic_header = {RA: ra_median, DEC: dec_median}
        median_file = ClassifiedFitsFile(FitsFile("", synthetic_header, virtual=True),
                                         "Synthetic median file", None)

        def distance_to_median(candidate):
            return self.distance(median_file, candidate, keyword)

        file = min(files, key=distance_to_median)
        self.logger.debug("Closest median match is %s", file)
        return file

    def distance(self, file1: ClassifiedFitsFile, file2: ClassifiedFitsFile, keyword: str) -> float:
        """Return the angular separation between the two files' (RA, DEC) positions, in degrees.

        RA and DEC are read with ``0`` as the fallback argument and interpreted as degrees.
        ``keyword`` is not used.
        """
        # Read all four values first, then build the coordinates.
        ra_dec_pairs = [(file.get_keyword_value(RA, 0), file.get_keyword_value(DEC, 0)) for file in (file1, file2)]
        first_position, second_position = [SkyCoord(unit=units.deg, ra=ra, dec=dec) for ra, dec in ra_dec_pairs]
        return first_position.separation(second_position).degree
