import csv
import io
from typing import List

import networkx as nx

from .task import TaskBase, Task, DataSource


class AssocMap:
    """Association table built from a task / data-source dependency graph.

    The map is a list of rows, each a list of strings. Column ``i + 1``
    corresponds to ``self.tasks[i]``; column 0 holds row labels (empty for
    the task-to-task rows). Rows are, in order: a header block describing
    every task, one row per association input (listed under "Static
    calibrations"), a "Master calibrations" marker row, and one row per task
    (task-to-task associations, with each task's product categories on the
    diagonal). A cell is ``'x'`` when the graph has an edge from the row's
    node to the column's task.
    """

    Matrix = List[List[str]]

    def __init__(self, graph: nx.DiGraph):
        """Collect tasks and association inputs from *graph* and build the map.

        :param graph: directed acyclic graph whose nodes include ``Task`` and
            ``DataSource`` instances; an edge ``u -> task`` is rendered as an
            association of ``u`` with ``task``.
        """
        self.graph = graph
        # Sort once and derive both node lists from the same order.
        ordered_nodes = list(nx.topological_sort(self.graph))
        self.tasks = [node for node in ordered_nodes if isinstance(node, Task)]
        self.assoc_inputs = [node for node in ordered_nodes
                             if isinstance(node, DataSource) and node.is_assoc_input]
        self.assoc_map = self.create_assoc_map()

    def create_assoc_map(self) -> Matrix:
        """Assemble the full map: header rows, static and master calibrations.

        :return: list of rows, each ``len(self.tasks) + 1`` cells wide (the
            "Master calibrations" marker row is 2 cells wide when there are
            no tasks).
        """
        header = [[self.format_label("Data source")] + [t.main_input.name for t in self.tasks],
                  [self.format_label("Classification")] + [self.classification_link(t) for t in self.tasks],
                  [self.format_label("Setup")] + [t.main_input.setup_keywords_str for t in self.tasks],
                  [self.format_label("Grouping")] + [t.main_input.grouping_keywords_str for t in self.tasks],
                  [self.format_label("Association")] + [self.association_link(t) for t in self.tasks],
                  [self.format_label("Task")] + [t.name for t in self.tasks],
                  [self.format_label("Recipe")] + [t.command for t in self.tasks],
                  ["<br>"] * (len(self.tasks) + 1),
                  [self.format_label("Static calibrations")] + [""] * len(self.tasks)]
        # The marker label sits in column 1, not in the label column.
        master_calib_header = [["", self.format_label("Master calibrations")] + [""] * (len(self.tasks) - 1)]
        task2task_matrix = self.create_task2task_matrix()
        calib2task_matrix = self.create_calib2task_matrix()
        return header + calib2task_matrix + master_calib_header + task2task_matrix

    def create_task2task_matrix(self) -> Matrix:
        """Return one row per task, marking which tasks feed which."""
        matrix = self.init_task2task_matrix()
        return self.create_associations(self.tasks, self.tasks, matrix)

    def create_calib2task_matrix(self) -> Matrix:
        """Return one row per association input, marking the tasks it feeds."""
        matrix = self.init_calib2task_matrix()
        return self.create_associations(self.assoc_inputs, self.tasks, matrix)

    def init_task2task_matrix(self) -> Matrix:
        """Return an empty task-by-task matrix with product categories on the diagonal.

        Row ``i`` has an empty label cell, and cell ``[i][i + 1]`` holds the
        space-joined ``categories`` of ``self.tasks[i]``. When a task has no
        categories, that cell holds ``"UNDEFINED_PRODUCT"``.
        """
        default_pro_catg = "UNDEFINED_PRODUCT"
        width = len(self.tasks) + 1
        matrix = [[''] * width for _ in self.tasks]
        for num, task in enumerate(self.tasks):
            matrix[num][num + 1] = " ".join(task.categories) or default_pro_catg
        return matrix

    def init_calib2task_matrix(self) -> Matrix:
        """Return one row per association input: its name, then an empty cell per task."""
        return [[assoc_input.name] + [''] * len(self.tasks) for assoc_input in self.assoc_inputs]

    def create_associations(self, assoc_inputs: List[TaskBase], tasks: List[Task], matrix: Matrix) -> Matrix:
        """Mark ``'x'`` in *matrix* wherever the graph has an edge input -> task.

        :param assoc_inputs: nodes for the matrix rows (row ``r`` is ``assoc_inputs[r]``).
        :param tasks: nodes for the matrix columns (column ``c + 1`` is ``tasks[c]``).
        :param matrix: pre-initialised matrix; it is modified in place.
        :return: the same *matrix* object.
        """
        for row_num, assoc_input in enumerate(assoc_inputs):
            row = matrix[row_num]
            for col_num, task in enumerate(tasks, start=1):
                # Direct adjacency lookup instead of scanning the successors per cell.
                if self.graph.has_edge(assoc_input, task):
                    row[col_num] = 'x'
        return matrix

    def as_csv(self) -> str:
        """Render the map as CSV text, one line per row, with no trailing newline.

        Cells containing commas, quotes or newlines are quoted according to
        the CSV rules implemented by the ``csv`` module.
        """
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator="\n").writerows(self.assoc_map)
        text = buffer.getvalue()
        # csv.writer terminates every row; drop the last terminator.
        return text[:-1] if text.endswith("\n") else text

    def as_markdown(self) -> str:
        """Render the map as a Markdown table plus classification and match rule sections.

        Pipes inside cells are escaped so they cannot split table columns.
        """
        sep = "|"
        rows = [sep + sep.join(self._escape_markdown_cell(cell) for cell in row) + sep
                for row in self.assoc_map]
        # Label column left-aligned, task columns centred.
        alignment = sep + sep.join([":---"] + [":---:"] * len(self.tasks)) + sep
        lines = ["## Association map", rows[0], alignment] + rows[1:]
        lines.append("## Classification rules")
        for task in self.tasks:
            for classification, condition in task.main_input.classification_dict.items():
                lines.append(self.classification_rule(classification, condition))
        lines.append("## Match rules")
        for task in self.tasks:
            if task.assoc_configs:
                lines.append(self.match_rule(task))
        return "\n".join(lines)

    @staticmethod
    def _escape_markdown_cell(text: str) -> str:
        """Escape ``|`` so a cell's content cannot split a Markdown table column."""
        return text.replace("|", "\\|")

    @staticmethod
    def format_label(text: str) -> str:
        """Return *text* wrapped in Markdown bold markers."""
        return f'**{text}**'

    @staticmethod
    def classification_link(task: Task) -> str:
        """Return space-separated Markdown links to each classification of the task's main input."""
        return ' '.join([f'[{tag}](#{tag})' for tag in task.main_input.classification_dict])

    @staticmethod
    def association_link(task: Task) -> str:
        """Return a link to the task's match-rule section, or ``''`` if it has no assoc configs."""
        return f'[match_rule](#match_{task.main_input.name})' if task.assoc_configs else ''

    @staticmethod
    def classification_rule(classification: str, condition: str) -> str:
        """Return a Markdown section showing *condition* in a code block under *classification*."""
        return f'\n### {classification}\n```\n{condition}\n```\n[TOP](#association-map)\n'

    @staticmethod
    def match_rule(task: Task) -> str:
        """Return a Markdown section showing the task's ``assoc_config_str``."""
        return f'\n### match_{task.main_input.name}\n```\n{task.assoc_config_str}\n```\n[TOP](#association-map)\n'

    @staticmethod
    def get_function_name(text: str) -> str:
        """Return the function name from a ``def name(...)`` snippet.

        :param text: source text; only its first two whitespace-separated tokens are inspected.
        :return: the name, or ``""`` if *text* does not start with ``def`` followed by a name.
        """
        items = text.split()
        # Require a token after "def"; a bare "def" would otherwise raise IndexError.
        if len(items) > 1 and items[0] == "def":
            return items[1].split("(")[0]
        return ""

    @staticmethod
    def format_link_to_function(text: str) -> str:
        """Return a Markdown link to the function defined in *text*, or ``""`` if none."""
        function_name = AssocMap.get_function_name(text)
        if function_name:
            return f'[{function_name}](#{function_name})'
        return ""

    @staticmethod
    def format_function_source(text: str) -> str:
        """Return a Markdown section with *text* in a code block, or ``""`` if it is not a ``def``."""
        function_name = AssocMap.get_function_name(text)
        if function_name:
            return f'\n### {function_name}\n```\n{text}\n```\n[TOP](#association-map)\n'
        return ""
