import logging
import threading
from collections import defaultdict, deque
from concurrent.futures.thread import ThreadPoolExecutor
from time import sleep
from typing import List, Dict, Set, Callable

from .task import Task
from .weight import get_children, GraphWeights


class JobsSubmitter:
    """Base class for strategies deciding the order in which tasks reach an executor.

    Concrete strategies override :meth:`order_and_submit`.
    """

    def order_and_submit(self, nodes: List[Task], executor: ThreadPoolExecutor, callback: str):
        """Submit the given tasks to ``executor``.

        :param nodes: tasks to schedule
        :param executor: pool that runs each task's ``execute``
        :param callback: value forwarded to ``Task.execute`` by the concrete strategies
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError


class StaticOrderSubmitter(JobsSubmitter):
    """Submits every task immediately, in the order chosen by a sorting function.

    The executor queues the tasks in that order. No readiness check is done here.
    """

    def __init__(self, ordering_function: Callable[[List[Task]], List[Task]]):
        """:param ordering_function: maps the list of tasks to the submission order"""
        self.ordering_function = ordering_function
        self.logger = logging.getLogger("StaticOrderSubmitter")

    def order_and_submit(self, nodes: List[Task], executor: ThreadPoolExecutor, callback: str):
        """Order ``nodes`` with ``ordering_function`` and submit each ``execute`` call.

        :param nodes: tasks to schedule
        :param executor: pool that runs the tasks
        :param callback: forwarded as the single argument of ``Task.execute``
        """
        ordered = self.ordering_function(nodes)
        log = self.logger
        log.debug("Submission of nodes for executing using static ordering: %s", ordered)
        for task in ordered:
            log.debug("Submitting %s", task)
            executor.submit(task.execute, callback)


class DynamicSubmitter(JobsSubmitter):
    """Submits tasks from a background thread once all of their parents are ready.

    Readiness is polled every ``poll_interval`` seconds. Tasks that become ready during the
    same poll are submitted heaviest first, according to ``GraphWeights``.
    """

    def __init__(self, poll_interval: float = 2):
        """:param poll_interval: seconds to sleep between readiness checks (default 2)"""
        self.logger = logging.getLogger("DynamicSubmitter")
        self.poll_interval = poll_interval

    def order_and_submit(self, nodes: List[Task], executor: ThreadPoolExecutor, callback: str):
        """Start the scheduler thread and return immediately.

        :param nodes: tasks to schedule (this list is not modified)
        :param executor: pool that runs the tasks
        :param callback: forwarded as the single argument of ``Task.execute``
        """
        self.logger.debug("Submission of nodes for executing using dynamic ordering: %s", nodes)
        threading.Thread(target=self.run, name="Dynamic scheduler thread", args=(nodes, executor, callback)).start()

    def run(self, nodes: List[Task], executor: ThreadPoolExecutor, callback: str):
        """Scheduler loop: repeatedly submit every task whose parents are all ready.

        NOTE(review): if a parent's ``ready`` flag is never set (for example, the parent
        failed), this loop polls forever. Confirm that ``Task`` always sets it.
        """
        weights = GraphWeights(nodes)
        # Work on a private copy. Removing from ``nodes`` itself would silently empty the
        # caller's list from this background thread.
        to_process = list(nodes)
        while to_process:
            ready = [node for node in to_process if self.is_ready(node)]
            ready_ordered = sorted(ready, key=weights.get_weight, reverse=True)
            for node in ready_ordered:
                executor.submit(node.execute, callback)
                to_process.remove(node)
            self.logger.debug("Dynamic ordering loop, tasks ready for execution: %s, still %s left", ready_ordered,
                              len(to_process))
            # No point waiting once everything has been submitted.
            if to_process:
                sleep(self.poll_interval)

    def is_ready(self, node: Task):
        """Return True when every parent of ``node`` has its ``ready`` flag set.

        ``ready`` is presumably a ``threading.Event`` (it exposes ``is_set``).
        """
        return all(parent.ready.is_set() for parent in node.parents)


def topological_sort_dfs(nodes: List[Task]) -> List[Task]:
    r"""
    Orders nodes using DFS-like traversal. It goes as deep as possible (down to a leaf) before moving to a parallel branch. This means if we have:
          A
        /   \
       B     C
      / \   / \
     D   E F   G
    The order will favour A->B->D->E before proceeding into the C branch.

    The idea is to start at a leaf node and work our way up to find all its ancestors, scheduling them before the leaf.
    Leaves are processed in a particular order: a leaf from the same sub-graph as the previous one if possible.
    To achieve this, the incoming graph is split into separate sub-graphs (see ``find_subgraphs``) and leaves are grouped per sub-graph.
    Parents reachable from ``nodes`` are included in the result even if they are not in ``nodes`` themselves.
    :param nodes: list of nodes
    :return: list of nodes in topological order, focusing on reaching a leaf node as fast as possible
    """
    children = get_children(nodes)
    # Leaves (nodes without children), grouped by sub-graph so consecutive leaves share a sub-graph.
    # A flat comprehension replaces ``sum(lists, [])``, which copies the accumulator on every step.
    leaves: List[Task] = [node for sub_graph in find_subgraphs(nodes) for node in sub_graph if not children[node]]
    result: List[Task] = []
    emitted: Set[Task] = set()  # mirrors ``result`` for O(1) duplicate checks
    seen: Set[Task] = set()

    to_process: List[Task] = list(leaves)
    while to_process:
        current = to_process[-1]
        seen.add(current)
        unseen_parents = [m for m in current.parents if m not in seen]
        if unseen_parents:
            # Resolve the parents first; ``current`` stays on the stack beneath them.
            to_process.extend(unseen_parents)
            continue
        # All parents already handled, so ``current`` can be emitted.
        to_process.pop()
        # A node may be pushed more than once (by several children) but is emitted only once.
        if current not in emitted:
            emitted.add(current)
            result.append(current)
    return result


def topological_sort_bfs(nodes: List[Task]) -> List[Task]:
    r"""
    Orders nodes using BFS-like traversal. It spreads evenly between all child branches. This means if we have:
          A
        /   \
       B     C
      / \   / \
     D   E F   G
    The order will favour A->BC->DEFG

    The idea is to start from root nodes (with no parents), then proceed to nodes which depend only on those roots, then to nodes which depend on previous ones etc.
    Nodes on a given level are ordered by their weight from ``GraphWeights``, placing "heavier" nodes first.
    Weight of a node comes from num_inputs/avg_num_inputs for the given task type + weights of all its children.
    :param nodes: list of nodes; duplicates are ignored
    :return: list of nodes in topological order, focusing on spreading over the tree
    :raises ValueError: if some nodes can never become ready (cycle, or a parent missing from ``nodes``)
    """
    weights = GraphWeights(nodes)
    # De-duplicate while keeping input order. With duplicates, the old ``len(seen) < len(nodes)``
    # condition could never be satisfied.
    pending = list(dict.fromkeys(nodes))
    result: List[Task] = []
    seen: Set[Task] = set()

    while pending:
        current_level = [node for node in pending if all(parent in seen for parent in node.parents)]
        if not current_level:
            # Without this guard the loop would spin forever.
            raise ValueError(f"Cannot order tasks, their parents are missing or form a cycle: {pending}")
        current_level.sort(key=weights.get_weight, reverse=True)
        seen.update(current_level)
        result.extend(current_level)
        pending = [node for node in pending if node not in seen]
    return result


def topological_sort_bfs_action_type(nodes: List[Task]) -> List[Task]:
    r"""
    Orders nodes using BFS-like traversal but groups nodes of the same "type". E.g. if there are FLATs and DARKs on the same level,
    the whole group is processed together before moving to another group.
    This means if we have:
              A
          /       \
         B        C
      /    \    /    \
     DARK FLAT DARK FLAT
    The order will favour A->BC->DARKs->FLATs
    The idea is to start from root nodes (with no parents), then proceed to nodes which depend only on those roots, then to nodes which depend on previous ones etc., just like for BFS,
    but within a given level the nodes are sorted by ``node.action.get_label()``.
    :param nodes: list of nodes; duplicates are ignored
    :return: list of nodes in topological order, focusing on processing the same type together
    :raises ValueError: if some nodes can never become ready (cycle, or a parent missing from ``nodes``)
    """
    # De-duplicate while keeping input order. With duplicates, the old ``len(seen) < len(nodes)``
    # condition could never be satisfied.
    pending = list(dict.fromkeys(nodes))
    result: List[Task] = []
    seen: Set[Task] = set()

    while pending:
        current_level = [node for node in pending if all(parent in seen for parent in node.parents)]
        if not current_level:
            # Without this guard the loop would spin forever.
            raise ValueError(f"Cannot order tasks, their parents are missing or form a cycle: {pending}")
        # Stable sort: nodes with the same label keep their input order.
        current_level.sort(key=lambda node: node.action.get_label())
        seen.update(current_level)
        result.extend(current_level)
        pending = [node for node in pending if node not in seen]
    return result


def get_job_submitter(key: str) -> JobsSubmitter:
    """Return the submitter registered under ``key`` in the module-level ``ordering`` registry.

    The returned instance is shared: the same object is handed out on every call.
    :param key: name of the ordering strategy
    :raises KeyError: if ``key`` is not registered; the message lists the valid keys
    """
    try:
        return ordering[key]
    except KeyError:
        # Same exception type as before, but tell the caller what the valid choices are.
        raise KeyError(f"Unknown job ordering {key!r}, expected one of: {', '.join(ordering)}") from None


# Registry of available ordering strategies, looked up by ``get_job_submitter``.
# Each value is a single shared submitter instance.
ordering: Dict[str, JobsSubmitter] = dict(
    dfs=StaticOrderSubmitter(topological_sort_dfs),
    bfs=StaticOrderSubmitter(topological_sort_bfs),
    type=StaticOrderSubmitter(topological_sort_bfs_action_type),
    dynamic=DynamicSubmitter(),
)


def find_subgraphs(nodes: List[Task]) -> List[List[Task]]:
    """Split the task graph into its connected components, ignoring edge direction.

    Parent links are treated as undirected edges. Parents not present in ``nodes`` are still
    reached and included in the component of their children.
    The result is deterministic: components appear in the order of their first node in
    ``nodes``, and each component lists its nodes in breadth-first discovery order.
    :param nodes: list of nodes
    :return: list of components, each a list of nodes
    """

    def bfs(start: Task, graph: Dict[Task, Dict[Task, None]]) -> List[Task]:
        """Return every node connected to ``start``, in breadth-first discovery order."""
        # A dict serves as an insertion-ordered set.
        visited: Dict[Task, None] = {start: None}
        queue = deque([start])  # popleft is O(1), unlike list.pop(0)
        while queue:
            current = queue.popleft()
            for neighbour in graph[current]:
                if neighbour not in visited:
                    visited[neighbour] = None
                    queue.append(neighbour)
        return list(visited)

    # Undirected adjacency. Using dicts instead of sets keeps iteration order independent of hashing.
    graph: Dict[Task, Dict[Task, None]] = defaultdict(dict)
    for node in nodes:
        for parent in node.parents:
            graph[parent][node] = None
            graph[node][parent] = None
    subgraphs = []
    visited = set()
    for node in nodes:
        if node not in visited:
            found_nodes = bfs(node, graph)
            visited.update(found_nodes)
            subgraphs.append(found_nodes)
    return subgraphs
