import os.path
from collections import defaultdict
from dataclasses import dataclass
from typing import Union, Set, List, Tuple, Dict

import networkx as nx
from jinja2 import Environment, FileSystemLoader

from .constants import TRUE_CONDITION
from .meta_targets import QC1_CALIB, SCIENCE
from .task import Task, DataSource, AssociatedInputGroup
from .task_details import CommandType

# Jinja2 template names; they are looked up through the FileSystemLoader that
# Graph roots at this module's directory.
SIMPLE_GRAPH_TEMPLATE, DETAILED_GRAPH_TEMPLATE = (
    "/templates/simple_graph.dot",
    "/templates/detailed_graph.dot",
)
# Colour names returned by Graph.get_color_for_meta_targets: science,
# QC1 calibration, and everything else, respectively.
BLUE, GREEN, WHITE = 'deepskyblue3', 'green', 'white'


@dataclass(frozen=True)
class AlternativeAssociatedInputs:
    """Template-friendly summary of one ``AssociatedInputGroup`` of a task.

    NOTE: ``frozen=True`` forbids reassigning fields, but because
    ``input_names`` is a list the generated ``__hash__`` raises ``TypeError``;
    instances cannot be used as dict keys or set members.
    """

    # Names of the input tasks of the group, in the group's order.
    input_names: List[str]
    # Number of associated inputs in the group.
    num_inputs: int
    # True if any associated input has a condition other than TRUE_CONDITION.
    is_conditional: bool
    # True if every associated input has ``min_ret == 0``
    # (vacuously True for a group with no inputs).
    is_optional: bool

    @classmethod
    def from_associated_input_group(cls, group: AssociatedInputGroup) -> "AlternativeAssociatedInputs":
        """Summarise ``group`` into an ``AlternativeAssociatedInputs``.

        Args:
            group: input group whose ``associated_inputs`` are inspected.

        Returns:
            A new instance describing the group's input names, count,
            conditionality and optionality.
        """
        inputs = group.associated_inputs
        return cls(input_names=[inp.input_task.name for inp in inputs],
                   num_inputs=len(inputs),
                   # Generator expressions let any()/all() stop at the first
                   # decisive element instead of building a throwaway list.
                   is_conditional=any(inp.condition != TRUE_CONDITION for inp in inputs),
                   is_optional=all(inp.min_ret == 0 for inp in inputs))


class Graph:
    """Render a workflow dependency graph through the ``.dot`` Jinja2 templates.

    The wrapped ``networkx`` graph has ``Task`` and ``DataSource`` nodes.
    Tasks are grouped by the first entry of their ``subworkflow_name``; tasks
    without one are grouped under the key ``"main"``.
    """

    # Colours cycled over the distinct edge sources of the detailed graph.
    _DETAILED_EDGE_PALETTE = (
        "#0000ff",
        "#ff0000",
        "#00ff00",
        "#ff8d6a",
        "#5e2903",
        "#00ffff",
        "#9904ff",
        "#007358",
        "#e4de01",
        "#ea09ff",
        "#ac4552",
        "#2fabc1",
        "#cd9763",
        "#04ffb0",
        "#bcadff",
    )
    # Values cycled over the raw data types of the simple graph. They alternate
    # low/high; presumably they are indices into a colour scheme defined by the
    # simple template (TODO confirm against simple_graph.dot).
    _SIMPLE_EDGE_PALETTE = (1, 11, 2, 10, 3, 9, 4, 8, 5, 7)

    def __init__(self, graph: nx.DiGraph, title: str):
        """Index ``graph`` by node kind and by subworkflow.

        Args:
            graph: dependency graph whose nodes are ``Task``/``DataSource`` objects.
            title: workflow name, passed to the templates as ``workflow``.
        """
        self.graph = graph
        # Sort once; tasks and data sources are both kept in this order.
        ordered_nodes = list(nx.topological_sort(self.graph))
        self.tasks = [node for node in ordered_nodes if isinstance(node, Task)]
        self.data_sources = [node for node in ordered_nodes if isinstance(node, DataSource)]
        self.workflow_name = title
        # Tasks and the union of their meta targets, per group ("main" if ungrouped).
        self.subworkflow_tasks = defaultdict(list)
        self.subworkflow_metatargets = defaultdict(set)
        for task in self.tasks:
            subworkflow = task.subworkflow_name[0] if task.subworkflow_name else "main"
            self.subworkflow_tasks[subworkflow].append(task)
            self.subworkflow_metatargets[subworkflow].update(task.meta_targets)
        # Template names are resolved relative to the directory of this module.
        self.environment = Environment(loader=FileSystemLoader(os.path.dirname(__file__)),
                                       trim_blocks=True, lstrip_blocks=True)

    @staticmethod
    def get_name_or_subworkflow(node: Union[Task, DataSource]) -> str:
        """Return the first subworkflow name of a grouped task, else ``node.name``."""
        if isinstance(node, Task) and node.subworkflow_name:
            return node.subworkflow_name[0]
        return node.name

    @property
    def main_workflow_edges(self) -> List[Tuple[str, str]]:
        """Task-sourced edges between different groups, as ``(source, target)`` names.

        Grouped tasks are collapsed to their subworkflow name, and edges whose
        endpoints collapse to the same name are dropped. Edges leaving a
        ``DataSource`` are excluded. The list may contain repeated pairs.
        """
        edges = []
        for source, target in self.graph.edges:
            source_name = self.get_name_or_subworkflow(source)
            target_name = self.get_name_or_subworkflow(target)
            if isinstance(source, Task) and source_name != target_name:
                edges.append((source_name, target_name))
        return edges

    @property
    def subworkflow_edges(self) -> Dict[str, List[Tuple[Task, Task]]]:
        """Edges whose two endpoints are tasks of the same group.

        Returns a ``defaultdict(list)`` keyed by group name (``"main"``
        included). Only groups with at least one internal edge are present,
        in ``subworkflow_tasks`` order, and each list follows the graph's
        edge order.
        """
        # One pass over the edges with an O(1) task -> group lookup, instead of
        # rescanning every edge with list-membership tests for every group.
        group_of = {task: subw for subw, tasks in self.subworkflow_tasks.items() for task in tasks}
        grouped = defaultdict(list)
        for source, target in self.graph.edges:
            if source in group_of and target in group_of and group_of[source] == group_of[target]:
                grouped[group_of[source]].append((source, target))
        # Re-key in subworkflow_tasks order so template iteration order is stable.
        edges = defaultdict(list)
        for subw in self.subworkflow_tasks:
            if subw in grouped:
                edges[subw] = grouped[subw]
        return edges

    @staticmethod
    def get_color_for_meta_targets(meta_targets: Set[str]) -> str:
        """Return BLUE if SCIENCE is targeted, else GREEN for QC1_CALIB, else WHITE."""
        if SCIENCE in meta_targets:
            return BLUE
        if QC1_CALIB in meta_targets:
            return GREEN
        return WHITE

    def is_subworkflow(self, task: str) -> bool:
        """Return True if ``task`` is a group key of ``subworkflow_tasks``.

        Note that ``"main"`` is such a key whenever ungrouped tasks exist.
        """
        return task in self.subworkflow_tasks

    @staticmethod
    def _assign_colors(nodes, palette) -> Dict:
        """Map each distinct item of ``nodes`` (first-seen order) to a palette entry.

        Entries are handed out in palette order and wrap around when there are
        more distinct items than palette entries.
        """
        unique_nodes = dict.fromkeys(nodes)
        return {node: palette[i % len(palette)] for i, node in enumerate(unique_nodes)}

    def detailed_graph(self) -> str:
        """Render ``DETAILED_GRAPH_TEMPLATE`` for this workflow and return the text.

        The template context holds, per task, its input groups, meta-target
        colour and displayed command. It also holds per-group colours, task
        lists and edges, plus edge colours for the collapsed main graph and
        for each group.
        """
        template = self.environment.get_template(DETAILED_GRAPH_TEMPLATE)
        task_inputs = {}
        task_colors = {}
        task_commands = {}
        for task in self.tasks:
            task_inputs[task.name] = [AlternativeAssociatedInputs.from_associated_input_group(group)
                                      for group in task.associated_input_groups]
            task_colors[task.name] = self.get_color_for_meta_targets(task.meta_targets)
            if task.command_type == CommandType.FUNCTION:
                # Show only the last dotted component of the function path.
                task_commands[task.name] = 'function.' + task.command.split('.')[-1]
            else:
                task_commands[task.name] = task.command

        subworkflow_colors = {subw: self.get_color_for_meta_targets(meta_targets)
                              for subw, meta_targets in self.subworkflow_metatargets.items()}

        # Both properties walk every graph edge; compute each exactly once.
        main_workflow_edges = self.main_workflow_edges
        subworkflow_edges = self.subworkflow_edges

        # One colour per distinct source of the collapsed main graph; tasks of a
        # coloured subworkflow inherit that subworkflow's colour.
        main_workflow_colors = self._assign_colors((source for source, _ in main_workflow_edges),
                                                   self._DETAILED_EDGE_PALETTE)
        subworkflow_task_colors = {}
        for name, color in main_workflow_colors.items():
            if self.is_subworkflow(name):
                subworkflow_task_colors.update({task.name: color for task in self.subworkflow_tasks[name]})
        main_workflow_colors.update(subworkflow_task_colors)

        # Same scheme inside each group: one colour per distinct source task, so
        # a task with several outgoing edges no longer skips palette entries.
        subworkflow_edge_colors = {
            subw: self._assign_colors((source.name for source, _ in edges), self._DETAILED_EDGE_PALETTE)
            for subw, edges in subworkflow_edges.items()
        }

        context = {
            "workflow": self.workflow_name,
            # .get avoids inserting an empty "main" group into the defaultdict.
            "tasks": self.subworkflow_tasks.get("main", []),
            "task_colors": task_colors,
            "task_commands": task_commands,
            "subworkflow_colors": subworkflow_colors,
            "subworkflow_tasks": {k: v for k, v in self.subworkflow_tasks.items() if k != "main"},
            "task_inputs": task_inputs,
            "main_workflow_edges": main_workflow_edges,
            "subworkflow_edges": subworkflow_edges,
            "main_workflow_colors": main_workflow_colors,
            "subworkflow_edge_colors": subworkflow_edge_colors
        }
        return template.render(context)

    def simple_graph(self) -> str:
        """Render ``SIMPLE_GRAPH_TEMPLATE`` for this workflow and return the text.

        Nodes are the main-input data source names, the ungrouped task names
        and the subworkflow names. Edges leave tasks, or leave a data source
        that equals the target's ``main_input``. Endpoints are collapsed to
        subworkflow names, and self-loops created by the collapse are dropped.
        """
        template = self.environment.get_template(SIMPLE_GRAPH_TEMPLATE)
        raw_types = {ds.name for ds in self.data_sources if ds.is_main_input}
        tasks = {t.name for t in self.tasks if not t.subworkflow_name}
        subworkflows = {t.subworkflow_name[0] for t in self.tasks if t.subworkflow_name}
        task_nodes = set(self.tasks)  # O(1) membership in the edge filter below
        edges = [(self.get_name_or_subworkflow(n0), self.get_name_or_subworkflow(n1))
                 for n0, n1 in self.graph.edges if n0 in task_nodes or n0 == n1.main_input]
        edges = [(e0, e1) for e0, e1 in edges if e0 != e1]
        # One palette entry per distinct raw data type, in first-edge order.
        edge_colors = self._assign_colors((e0 for e0, _ in edges if e0 in raw_types),
                                          self._SIMPLE_EDGE_PALETTE)
        context = {
            "workflow": self.workflow_name,
            "raw_types": raw_types,
            "tasks": tasks,
            "subworkflows": subworkflows,
            "edges": edges,
            "edge_colors": edge_colors
        }
        return template.render(context)
