import itertools
import logging
from collections import defaultdict
from typing import List, Dict, Iterator

from .task import Task


def get_children(nodes: List[Task]) -> Dict[Task, List[Task]]:
    """
    Invert the parent links of a graph.

    :param nodes: list of graph nodes; each node exposes an iterable ``parents``
    :return: defaultdict mapping every parent to the nodes that list it as a parent,
             in the order those nodes appear in ``nodes``
    """
    child_map: Dict[Task, List[Task]] = defaultdict(list)
    edges = ((parent, child) for child in nodes for parent in child.parents)
    for parent, child in edges:
        child_map[parent].append(child)
    return child_map


def compute_averages(nodes: List[Task]) -> Dict[str, float]:
    """
    Compute the average number of inputs for each node type.

    Nodes are grouped by ``node.action.get_label()`` and each group is reduced with ``avg``.
    :param nodes: list of graph nodes
    :return: mapping from node label to the average number of inputs of the nodes
             carrying that label, with keys in sorted label order
    """
    def label_of(task):
        return task.action.get_label()

    averages = {}
    # groupby only merges *adjacent* items, so sort by the same key first.
    for label, members in itertools.groupby(sorted(nodes, key=label_of), key=label_of):
        averages[label] = avg(members)
    return averages


def avg(nodes: Iterator[Task]) -> float:
    """
    Return the mean number of input files over ``nodes``.

    Falls back to 1.0 when ``nodes`` is empty or when the nodes have no input files
    at all, so the result is never 0 and is always safe to use as a divisor
    (GraphWeights divides input counts by it).
    :param nodes: iterable of graph nodes; it is consumed exactly once
    :return: average ``len(node.action.input_files)``, or 1.0 in the fallback cases
    """
    node_list = list(nodes)  # materialise once: a groupby group can only be iterated once
    total_inputs = sum(len(n.action.input_files) for n in node_list)
    if total_inputs == 0:  # also covers an empty node_list
        return 1.0
    return total_inputs / len(node_list)


class GraphWeights:
    """
    Assign a weight to every task of a graph.

    A task's *simple* weight is its number of input files divided by the average
    number of inputs of tasks with the same action label. Its *composite* weight
    is its simple weight plus the composite weights of all its children. A
    descendant reachable along several paths therefore contributes once per path.
    Composite weights are computed lazily and memoised.
    """

    def __init__(self, nodes: List[Task]):
        self.logger = logging.getLogger('GraphWeights')
        # Memoised composite weights, filled lazily by get_weight().
        self.composite_weights: Dict[Task, float] = {}
        # parent -> direct children, derived from each node's `parents`.
        self.children = get_children(nodes)
        # action label -> average number of inputs for that label.
        # It is never 0, because avg() falls back to 1.
        self.averages = compute_averages(nodes)
        # Tasks that are not in `nodes` get a simple weight of 0.0.
        self.simple_weights: Dict[Task, float] = defaultdict(float)
        for node in nodes:
            average = self.averages[node.action.get_label()]
            self.simple_weights[node] = len(node.action.input_files) / average

    def get_weight(self, task: Task) -> float:
        """
        Return the (memoised) composite weight of ``task``.

        :raises ValueError: if ``task`` can reach itself through child links (cyclic graph)
        """
        if task not in self.composite_weights:
            self._fill_weights(task)
        return self.composite_weights[task]

    def _fill_weights(self, root: Task) -> None:
        """Compute and memoise composite weights for ``root`` and all its uncached descendants."""
        # Explicit post-order stack instead of recursion, so long parent->child
        # chains do not exceed Python's recursion limit.
        in_progress = set()
        stack = [(root, False)]
        while stack:
            task, children_done = stack.pop()
            if task in self.composite_weights:
                continue
            if children_done:
                in_progress.discard(task)
                # Every child is already memoised here, so compute_weight does not descend further.
                self.composite_weights[task] = self.compute_weight(task)
                continue
            if task in in_progress:
                # Only descendants of an unfinished task are popped before it
                # finishes, so reaching it again means a child links back to an ancestor.
                raise ValueError(f"Cycle detected in task graph at task {task!r}")
            in_progress.add(task)
            stack.append((task, True))
            # Push in reverse so children are visited in list order.
            for child in reversed(self.children.get(task, ())):
                if child not in self.composite_weights:
                    stack.append((child, False))

    def compute_weight(self, task: Task) -> float:
        """Return the simple weight of ``task`` plus the composite weights of its children."""
        # .get avoids inserting empty entries into the children defaultdict.
        children_weights = [self.get_weight(child) for child in self.children.get(task, ())]
        total_weight = self.simple_weights[task] + sum(children_weights)
        # Guarded, so the arguments are only evaluated when DEBUG is on.
        # averages.get tolerates tasks whose label was never seen in `nodes`.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(
                "Computing weight for task %s. Node has %s inputs, average is %s, isolated node weight is: %s, children weights are: %s, total weight: %s",
                task, len(task.action.input_files), self.averages.get(task.action.get_label()),
                self.simple_weights[task], children_weights, total_weight)
        return total_weight
