import ast
import inspect
import logging
import sys
from typing import List, Set, Optional, Dict

import networkx as nx

from edps.client.GraphType import GraphType
from edps.client.WorkflowDTO import WorkflowDTO
from .assoc_map import AssocMap
from .classif_rule import BaseClassificationRule, DictionaryClassificationRule
from .constants import MJD_OBS, INSTRUME, PRO_CATG, RA, DEC, DPR_CATG, DPR_TECH, DPR_TYPE
from .graph import Graph
from .recipe_config import get_recipe_params, RecipeParam
from .task import DataSource, Task, TaskBase


class KeywordVisitor(ast.NodeVisitor):
    """Collect every all-lowercase string literal found in a Python AST.

    After ``visit(tree)`` the collected literals are available in
    ``self.keywords``. Only ``str`` constants for which ``str.islower()`` is
    true are kept; bytes, numbers and strings containing upper-case letters
    are ignored.
    """

    def __init__(self):
        super().__init__()
        self.keywords: Set[str] = set()

    def _collect(self, value) -> None:
        # Shared by both visitor hooks so the filter lives in one place.
        if isinstance(value, str) and value.islower():
            self.keywords.add(value)

    def visit_Constant(self, node):
        # Python 3.8+ represents all literals (including strings) as ast.Constant.
        self._collect(node.value)

    def visit_Str(self, node):
        # Python 3.7 parsers emit ast.Str for string literals. On 3.8+ this hook
        # is not reached because visit_Constant above is defined. The ast.Str
        # node class itself is deprecated since 3.8 and removed in 3.14.
        self._collect(node.s)


class Workflow:
    """A workflow made of classification rules, data sources and tasks.

    On construction the workflow collects the set of keywords it needs (see
    ``get_keywords``) and builds a ``networkx.DiGraph`` (see ``create_graph``)
    whose edges run from an input node to the task that consumes it.
    """

    def __init__(self, classification_rules: List[BaseClassificationRule], data_sources: List[DataSource],
                 tasks: List[Task], title: str):
        self.logger = logging.getLogger('Workflow')
        self.classification_rules = classification_rules
        self.data_sources = data_sources
        self.tasks = tasks
        self.title = title
        # Both are derived once from the constructor arguments.
        self.keywords = self.get_keywords()
        self.graph = self.create_graph()
        self.logger.debug("keywords: %s", str(self.keywords))

    def validate(self):
        """Check that task names are unique.

        Raises:
            RuntimeError: for the first task whose name was already used by an
                earlier task in ``self.tasks``.
        """
        seen: Set[str] = set()
        for task in self.tasks:
            if task.name in seen:
                raise RuntimeError(f"duplicate task name {task.name}")
            seen.add(task.name)

    def create_graph(self) -> nx.DiGraph:
        """Build the task dependency graph.

        Every task becomes a node carrying a ``name`` attribute. An edge with
        ``type='input'`` connects each task's ``main_input`` to the task, and an
        edge with ``type='assoc'`` connects the ``input_task`` of every
        associated input to the task. Input nodes are added implicitly by
        ``add_edge``.
        """
        graph = nx.DiGraph()
        for task in self.tasks:
            graph.add_node(task, name=task.name)
            graph.add_edge(task.main_input, task, type='input')
            for group in task.associated_input_groups:
                for inp in group.associated_inputs:
                    graph.add_edge(inp.input_task, task, type='assoc')
        return graph

    def topological_sort(self):
        """Return the graph nodes in topological order, as produced by ``nx.topological_sort``."""
        return nx.topological_sort(self.graph)

    @staticmethod
    def parse_keywords(source_file) -> Set[str]:
        """Return the all-lowercase string literals found in a Python source file.

        Args:
            source_file: path of the Python file to parse.
        """
        # Read raw bytes so ast.parse honours the file's PEP 263 encoding
        # declaration (UTF-8 by default) instead of the platform locale encoding.
        with open(source_file, 'rb') as f:
            source = f.read()
        visitor = KeywordVisitor()
        visitor.visit(ast.parse(source))
        return visitor.keywords

    def get_workflow_names(self) -> List[str]:
        """Return the distinct ``workflow_names`` of all tasks, sorted for a deterministic order."""
        workflow_names = set()
        for task in self.tasks:
            workflow_names.update(task.workflow_names)
        return sorted(workflow_names)

    def get_keywords_from_module(self) -> List[str]:
        """Return the string constants defined in the workflows' keyword modules.

        For every distinct top-level package ``<pkg>`` of the workflow names,
        the module ``<pkg>.<pkg>_keywords`` is looked up in ``sys.modules``. It
        is only used if it has already been imported. All of its non-dunder
        ``str`` members are returned. Returns an empty list if no such module
        is loaded.
        """
        packages = sorted({name.split('.')[0] for name in self.get_workflow_names()})
        keywords: List[str] = []
        for package in packages:
            keywords_module = sys.modules.get(f"{package}.{package}_keywords")
            if keywords_module is None:
                continue
            # Skip every dunder member: besides __name__/__file__/... this
            # excludes the module docstring (__doc__), which is also a str.
            keywords.extend(
                value for name, value in inspect.getmembers(keywords_module)
                if isinstance(value, str) and not (name.startswith('__') and name.endswith('__'))
            )
        return keywords

    def get_keywords(self) -> Set[str]:
        """Return every keyword this workflow needs.

        This is the union of:
        - the keyword-module constants;
        - the grouping, setup and match keywords of all data sources;
        - the keys of every ``DictionaryClassificationRule``;
        - a fixed set of standard keywords (MJD_OBS, PRO_CATG, INSTRUME, RA, DEC and the DPR_* keywords).
        """
        keywords = set(self.get_keywords_from_module())
        for ds in self.data_sources:
            keywords.update(ds.grouping_keywords)
            keywords.update(ds.setup_keywords)
            keywords.update(ds.match_keywords)
        for cr in self.classification_rules:
            if isinstance(cr, DictionaryClassificationRule):
                keywords.update(cr.keyword_values.keys())
        keywords.update([MJD_OBS, PRO_CATG, INSTRUME, RA, DEC, DPR_CATG, DPR_TECH, DPR_TYPE])
        return keywords

    def get_grouping_keywords(self) -> Set[str]:
        """Return the union of the grouping keywords of all data sources."""
        return {kw for ds in self.data_sources for kw in ds.grouping_keywords}

    def expand_meta_targets(self, meta_targets: Optional[List[str]]) -> List[str]:
        """Return the names of tasks for which ``task.has_meta_targets(meta_targets)`` is true.

        ``None`` is treated as an empty collection of meta targets.
        """
        meta_target_set = set(meta_targets) if meta_targets is not None else set()
        return [task.name for task in self.tasks if task.has_meta_targets(meta_target_set)]

    def get_valid_targets(self, targets: Optional[List[str]]) -> List[str]:
        """Return the entries of ``targets`` that name an existing task, preserving order.

        ``None`` is treated as an empty list.
        """
        targets = targets if targets is not None else []
        task_names = {t.name for t in self.tasks}  # set: O(1) membership per target
        return [t for t in targets if t in task_names]

    def get_targets(self, targets: Optional[List[str]], meta_targets: Optional[List[str]]) -> List[str]:
        """Return the union of the valid ``targets`` and the expanded ``meta_targets``, without duplicates.

        The order of the returned list is unspecified.
        """
        expanded_targets = self.expand_meta_targets(meta_targets)
        valid_targets = self.get_valid_targets(targets)
        return list(set(expanded_targets + valid_targets))

    def get_ancestors(self, targets: List[str]) -> Set[TaskBase]:
        """Return all graph nodes from which a node named in ``targets`` is reachable.

        The target nodes themselves are not included unless another target
        depends on them.
        """
        target_nodes = [n for n in self.graph.nodes if n.name in targets]
        ancestors = set()
        for node in target_nodes:
            ancestors.update(nx.ancestors(self.graph, node))
        ancestor_names = [n.name for n in ancestors]
        self.logger.debug("[Workflow] all ancestors for targets %s: %s", targets, ancestor_names)
        return ancestors

    def dump(self):
        """Print every node's attribute dict, then every edge with its type, to stdout (debugging aid)."""
        for node in self.graph.nodes():
            print(vars(node))
        for edge, attr in self.graph.edges.items():
            print('{}->{} (type={})'.format(edge[0].name, edge[1].name, attr['type']))

    def generate_assoc_map(self) -> str:
        """Render the association map of the graph as Markdown."""
        return AssocMap(self.graph).as_markdown()

    def generate_graph(self, graph_type: GraphType) -> str:
        """Render the graph.

        Uses the simple rendering for ``GraphType.SIMPLE`` and the detailed
        rendering for any other value.
        """
        graph = Graph(self.graph, self.title)
        return graph.simple_graph() if graph_type == GraphType.SIMPLE else graph.detailed_graph()

    def get_default_params(self, esorex_path: str, task_name: str) -> Dict[str, object]:
        """Return the default parameters of all recipes of the task named ``task_name``.

        Parameters are obtained via ``get_recipe_params`` for each entry of
        ``task.command_recipes``. If several recipes define a parameter with
        the same name, the later recipe's value wins. If several tasks share
        the name, the first one is used.

        Raises:
            FileNotFoundError: if no task has the given name (kept for caller
                compatibility).
        """
        task = next((t for t in self.tasks if t.name == task_name), None)
        if task is None:
            self.logger.warning(
                "Task '%s' was not found, available tasks: '%s', unable to get default recipe parameters",
                task_name, [t.name for t in self.tasks])
            raise FileNotFoundError("Task {} not found".format(task_name))
        recipe_parameters: List[RecipeParam] = []
        for recipe in task.command_recipes:
            recipe_parameters.extend(get_recipe_params(esorex_path, recipe))
        return {parameter.name: parameter.value for parameter in recipe_parameters}

    def as_dto(self) -> WorkflowDTO:
        """Convert the workflow (keywords, rules, data sources, tasks) into a ``WorkflowDTO``."""
        return WorkflowDTO(keywords=list(self.get_keywords()),
                           classification_rules=[x.as_dto() for x in self.classification_rules],
                           data_sources=[x.as_dto() for x in self.data_sources],
                           tasks=[x.as_dto() for x in self.tasks])

    def __str__(self):
        return str(list(self.graph.edges.items()))
