import logging
import os
import time
from itertools import chain
from pathlib import Path
from typing import List, Set, Dict, Optional, Tuple
from uuid import UUID

from edps.client.FitsFile import FitsFile as ClientFitsFile
from edps.client.JobInfo import JobInfo
from edps.client.WorkflowStateDTO import WorkflowStateDTO
from edps.config.configuration import Configuration
from edps.interfaces.generation_result import GenerationResult
from . import job_factory
from .assoc_util import AssociationArgs
from .classif_rule import BaseClassificationRule
from .classifier import Classifier
from .edge import AlternativeAssociatedInputEdges
from .edge import MainInputEdge
from .fits import ClassifiedFitsFile, FitsFileFactory, FitsFile
from .job import Job, AssociationResult, AssociatedFitsFile
from .job_graph import JobInfoGraph
from .parameters import ParametersProvider
from .stateful_task import StatefulDataSource, StatefulTask, StatefulTaskBase, CleanupRequestArgs
from .task import TaskBase
from .task_interpreter import TaskInterpreter
from .workflow import Workflow
from ..interfaces.JobsRepository import JobDetails


class StatefulWorkflow:
    """Stateful wrapper around a :class:`Workflow`.

    For every task of the wrapped workflow a stateful node is created (via
    ``task_base.as_stateful_task``) and, for every :class:`StatefulTask`, a
    :class:`TaskInterpreter` wired to its main and associated inputs.  The class
    then offers the operations built on top of those nodes: job generation for a
    request, cleanup, reporting and re-creation of jobs loaded from storage.
    """

    def __init__(self, workflow: Workflow, config: Configuration):
        """Build the stateful nodes and task interpreters for ``workflow``.

        :param workflow: the workflow definition to wrap.
        :param config: only ``associate_existing_jobs`` and
            ``associate_incomplete_jobs`` are read here.
        """
        self.logger = logging.getLogger('StatefulWorkflow')
        self.workflow = workflow
        # Populated by create_stateful_node(); must exist before the nodes are created.
        self.node_cache: Dict[UUID, StatefulTaskBase] = {}
        self.stateful_nodes = self.create_stateful_nodes(config.associate_existing_jobs)
        self.stateful_tasks = [node for node in self.stateful_nodes if isinstance(node, StatefulTask)]
        self.stateful_data_sources = [node for node in self.stateful_nodes if isinstance(node, StatefulDataSource)]
        # Relies on node_cache and stateful_tasks being fully populated.
        self.task_interpreters = self.create_task_interpreters(config.associate_incomplete_jobs)

    def get_interpreter_for_task(self, task_id: str) -> Optional[TaskInterpreter]:
        """Return the interpreter registered for ``task_id`` or ``None`` if there is none."""
        return self.task_interpreters.get(task_id, None)

    def get_keywords(self) -> Set[str]:
        """Return the keywords reported by the wrapped workflow."""
        return self.workflow.get_keywords()

    def get_grouping_keywords(self) -> Set[str]:
        """Return the grouping keywords reported by the wrapped workflow."""
        return self.workflow.get_grouping_keywords()

    def get_classification_rules(self) -> List[BaseClassificationRule]:
        """Return the classification rules of the wrapped workflow."""
        return self.workflow.classification_rules

    def create_stateful_nodes(self, associate_existing_jobs: bool) -> List[StatefulTaskBase]:
        """Create one stateful node per workflow task, in ``topological_sort()`` order."""
        return [self.create_stateful_node(task, associate_existing_jobs)
                for task in self.workflow.topological_sort()]

    def create_task_interpreters(self, associate_incomplete: bool) -> Dict[str, TaskInterpreter]:
        """Create a :class:`TaskInterpreter` for every stateful task, keyed by ``task_id``.

        The stateful counterparts of the main input and of every associated input
        are looked up in ``node_cache`` by task identifier.
        """
        task_interpreters = {}
        for task in self.stateful_tasks:
            main_input_task = self.node_cache[task.main_input.identifier]
            main_input_edge = MainInputEdge(main_input_task, task)
            associated_input_edges: List[AlternativeAssociatedInputEdges] = []
            for group in task.associated_input_groups:
                # One edge per alternative input of the group.
                edges = [self.node_cache[inp.input_task.identifier].as_associated_input_edge(task, inp)
                         for inp in group.associated_inputs]
                associated_input_edges.append(AlternativeAssociatedInputEdges(edges, group.sort_keys))
            task_interpreters[task.task_id] = TaskInterpreter(task, main_input_edge, associated_input_edges,
                                                              associate_incomplete)
        return task_interpreters

    def get_target_nodes(self, targets: Optional[List[str]]) -> List[StatefulTaskBase]:
        """Return the nodes needed to process ``targets``.

        A node is selected when its name is one of ``targets`` or when its
        ``task_base`` is among ``workflow.get_ancestors(targets)``.  An empty or
        ``None`` target list selects every node, matching ``filter_tasks``.
        The relative order of ``stateful_nodes`` is preserved.
        """
        self.logger.debug("[Workflow] getting target nodes for targets %s", targets)
        if not targets:
            self.logger.debug("[Workflow] nodes to process (all) %s", self.stateful_nodes)
            return self.stateful_nodes
        ancestors = self.workflow.get_ancestors(targets)
        # Set membership avoids rescanning the target list for every node.
        target_names = set(targets)
        nodes_to_handle = [n for n in self.stateful_nodes if (n.name in target_names) or (n.task_base in ancestors)]
        self.logger.debug("[Workflow] nodes to process %s", nodes_to_handle)
        return nodes_to_handle

    def create_stateful_node(self, task_base: TaskBase, associate_existing_jobs: bool) -> StatefulTaskBase:
        """Create the stateful node for ``task_base`` and register it in ``node_cache``."""
        stateful_task = task_base.as_stateful_task(associate_existing_jobs)
        self.node_cache[task_base.identifier] = stateful_task
        return stateful_task

    def load_calib_db(self, classified_files: List[ClassifiedFitsFile], parameters_provider: ParametersProvider):
        """Reload ``classified_files`` into every data source, then let every task recreate its calib DB."""
        for data_source in self.stateful_data_sources:
            data_source.reload_files(classified_files, parameters_provider)
        for task in self.stateful_tasks:
            task.recreate_calib_db(classified_files)

    def create_jobs(self, request_id: UUID, classified_files: List[ClassifiedFitsFile],
                    targets: List[str], assoc_args: AssociationArgs, parameters_provider: ParametersProvider,
                    meta_targets: Optional[List[str]]) -> Tuple[GenerationResult, List[str], UUID]:
        """Interpret the request and build its :class:`GenerationResult`.

        Complete jobs are marked as future after the result has been built.
        Any exception raised along the way is logged and replaced by
        ``GenerationResult.empty()``; it is not propagated to the caller.

        :return: ``(result, targets, request_id)``; ``targets`` and ``request_id``
            are passed through unchanged.
        """
        try:
            self.interpret(request_id, classified_files, targets, assoc_args, parameters_provider, meta_targets)
            jobs = self.get_unique_jobs(request_id, targets)
            result = self.create_generation_result(request_id, classified_files, targets, jobs)
            self.mark_jobs_as_future([job for job in jobs if job.is_complete])
        except Exception as e:
            # FIXME: maybe this should crash?
            self.logger.error("Failed to create jobs", exc_info=e)
            result = GenerationResult.empty()
        return result, targets, request_id

    def interpret(self, request_id: UUID, classified_files: List[ClassifiedFitsFile], targets: List[str],
                  assoc_args: AssociationArgs, parameters_provider: ParametersProvider,
                  meta_targets: Optional[List[str]]):
        """Run the interpretation step on every node selected by ``targets``.

        All selected data sources are interpreted first; afterwards each selected
        task is interpreted, followed by its task interpreter and a status log.
        """
        target_nodes = self.get_target_nodes(targets)
        target_data_sources = [node for node in target_nodes if isinstance(node, StatefulDataSource)]
        target_tasks = [node for node in target_nodes if isinstance(node, StatefulTask)]
        for data_source in target_data_sources:
            data_source.interpret(request_id, classified_files, parameters_provider)
        for task in target_tasks:
            task.interpret(request_id, classified_files)
            self.task_interpreters[task.task_id].interpret(request_id, assoc_args, parameters_provider, meta_targets)
            task.log_status(request_id)

    def create_generation_result(self, request_id: UUID, classified_files: List[ClassifiedFitsFile],
                                 targets: List[str], jobs: List[Job]) -> GenerationResult:
        """Assemble the :class:`GenerationResult` for ``jobs`` and log a summary.

        ``job_ids`` and ``calselector_jobs`` cover all ``jobs``; ``jobs_as_info``
        only contains the jobs whose ``is_complete`` flag is set.
        """
        job_ids = [job.id for job in jobs]
        jobs_as_info = [job.as_job_info() for job in jobs if job.is_complete]
        calselector_jobs = [job.as_calselector_job() for job in jobs]
        files_with_no_jobs = self.get_files_with_no_jobs(jobs, classified_files)
        incomplete_groups = self.get_incomplete_groups_as_paths(request_id, targets)
        incomplete_jobs = self.get_incomplete_jobs(request_id, targets)
        self.logger.debug("create_jobs result %s job_ids %s", request_id, job_ids)
        self.logger.debug("create_jobs result %s files_with_no_jobs %s", request_id, files_with_no_jobs)
        self.logger.debug("create_jobs result %s incomplete_groups %s", request_id, incomplete_groups)
        self.logger.debug("create_jobs result %s incomplete_jobs %s", request_id, incomplete_jobs)
        fmt = "GenerationResult: %d complete jobs, %d incomplete jobs, %d incomplete groups, %d files with no jobs"
        # job_ids also contains incomplete jobs, so the complete count comes from jobs_as_info.
        self.logger.info(fmt, len(jobs_as_info), len(incomplete_jobs), len(incomplete_groups),
                         len(files_with_no_jobs))
        return GenerationResult(job_ids=job_ids, jobs_as_info=jobs_as_info,
                                files_with_no_jobs=files_with_no_jobs, incomplete_groups=incomplete_groups,
                                incomplete_jobs=incomplete_jobs, calselector_jobs=calselector_jobs)

    @staticmethod
    def mark_jobs_as_future(jobs):
        """Call ``set_future()`` on every job in ``jobs``."""
        for job in jobs:
            job.set_future()

    @staticmethod
    def get_files_with_no_jobs(jobs: List[Job], all_files: List[ClassifiedFitsFile]) -> List[str]:
        """Return the paths of the files that are not an input file of any job.

        Duplicates in ``all_files`` are reported once and the result follows the
        order of ``all_files``, so the output is reproducible between runs.
        """
        files_with_job = set(chain.from_iterable(job.input_files for job in jobs))
        # dict.fromkeys de-duplicates like a set but keeps first-seen order.
        return [file.get_path() for file in dict.fromkeys(all_files) if file not in files_with_job]

    def cleanup(self, targets: List[str], cleanup_args: CleanupRequestArgs):
        """Forward ``cleanup_args`` to every node selected by ``targets``."""
        for node in self.get_target_nodes(targets):
            node.cleanup(cleanup_args)

    def filter_tasks(self, targets: Optional[List[str]] = None) -> List[StatefulTask]:
        """Return the stateful tasks named in ``targets``; all of them if ``targets`` is empty or ``None``."""
        if not targets:
            return self.stateful_tasks
        return [task for task in self.stateful_tasks if task.name in targets]

    def to_dot(self, request_id: UUID, targets: Optional[List[str]] = None) -> str:
        """Join the ``to_dot()`` output of every job of ``request_id`` in the filtered tasks, one per line."""
        result = []
        for task in self.filter_tasks(targets):
            for job in task.request_jobs[request_id]:
                result.append(job.to_dot())
        return '\n'.join(result)

    def get_unique_jobs(self, request_id: UUID, targets: Optional[List[str]] = None) -> List[Job]:
        """Collect, without duplicates, ``get_all_parents()`` of every job of the request.

        Both complete and incomplete jobs of the filtered tasks are considered.
        NOTE(review): this presumably relies on ``get_all_parents()`` including the
        job itself - confirm against ``Job``.  The order of the result is unspecified.
        """
        unique_jobs = set()
        for task in self.filter_tasks(targets):
            for job in task.current_request_complete_jobs(request_id) + task.current_request_incomplete_jobs(
                    request_id):
                unique_jobs.update(job.get_all_parents())
        return list(unique_jobs)

    def get_targets(self, requested_targets: Optional[List[str]], meta_targets: Optional[List[str]]) -> List[str]:
        """Delegate target resolution to the wrapped workflow."""
        return self.workflow.get_targets(requested_targets, meta_targets)

    def get_incomplete_groups(self, request_id: UUID, targets: List[str]) -> List[List[ClassifiedFitsFile]]:
        """Concatenate the incomplete groups of every data source selected by ``targets``."""
        incomplete_groups = []
        data_sources = [n for n in self.get_target_nodes(targets) if isinstance(n, StatefulDataSource)]
        for data_source in data_sources:
            incomplete_groups.extend(data_source.get_incomplete_groups(request_id))
        return incomplete_groups

    def get_incomplete_groups_as_paths(self, request_id: UUID, targets: List[str]) -> List[List[str]]:
        """Same as :meth:`get_incomplete_groups`, with every file replaced by its path."""
        return [[f.get_path() for f in group] for group in self.get_incomplete_groups(request_id, targets)]

    def get_incomplete_jobs(self, request_id: UUID, targets: List[str]) -> List[JobInfo]:
        """Return the job infos of the incomplete jobs reachable from the request's incomplete jobs.

        For every incomplete job of the filtered tasks, the entries of
        ``get_all_parents()`` that are not complete are converted with
        ``as_job_info()`` and gathered in a set; the order of the result is unspecified.
        """
        result = set()
        for task in self.filter_tasks(targets):
            for job in task.current_request_incomplete_jobs(request_id):
                result.update([parent.as_job_info() for parent in job.get_all_parents() if not parent.is_complete])
        return list(result)

    def remove_jobs(self, jobs: List[UUID]):
        """Ask every stateful task to remove the given job ids."""
        for task in self.stateful_tasks:
            task.remove_jobs(jobs)

    def load_past_jobs(self, jobs: List[JobDetails], fits_factory: FitsFileFactory):
        """Recreate the loadable ``jobs`` and register them as historical jobs of their task."""
        jobs_to_load = self.remove_invalid_jobs(jobs)
        created_jobs = self.recreate_jobs(jobs_to_load, fits_factory)
        for job in created_jobs:
            stateful_task = self.get_interpreter_for_task(job.task_id).task
            stateful_task.historical_jobs[job] = job

    def remove_invalid_jobs(self, jobs: List[JobDetails]) -> List[JobDetails]:
        """Drop the jobs that cannot be loaded and everything depending on them.

        A job is invalid when its task id is not one of the current stateful
        tasks or when one of its input files does not exist on disk.  Jobs that
        list an invalid job among their input or associated job ids are removed
        too, repeatedly, until no further job is affected.

        :return: the remaining entries of ``jobs``, in their original order.
        """
        job_configs = [job.configuration for job in jobs]
        available_task_ids = {t.task_id for t in self.stateful_tasks}
        # Cheap task check first; the generator stops stat-ing files at the first missing one.
        jobs_to_keep = {job for job in job_configs if
                        (job.task_id in available_task_ids and all(os.path.exists(f.name) for f in job.input_files))}
        jobs_to_remove = set(job_configs).difference(jobs_to_keep)
        job_ids_to_remove = {job.job_id for job in jobs_to_remove}
        # Propagate the removal to dependants until a pass removes nothing.
        while True:
            new_jobs_to_remove = {job for job in jobs_to_keep if
                                  set(job.input_job_ids + job.associated_job_ids).intersection(job_ids_to_remove)}
            if not new_jobs_to_remove:
                break
            job_ids_to_remove.update({job.job_id for job in new_jobs_to_remove})
            jobs_to_keep = jobs_to_keep.difference(new_jobs_to_remove)
        if job_ids_to_remove:
            self.logger.warning("The following jobs have missing inputs or invalid task and will not be loaded: %s",
                                job_ids_to_remove)
        job_ids_to_keep = {job.job_id for job in jobs_to_keep}
        return [job for job in jobs if job.configuration.job_id in job_ids_to_keep]

    def recreate_jobs(self, jobs: List[JobDetails], fits_factory: FitsFileFactory) -> List[Job]:
        """Rebuild :class:`Job` objects from stored job details.

        Jobs whose task has no interpreter are logged and skipped.  Every
        recreated job is marked as future; links between jobs and exemplars are
        restored once all jobs exist.
        """
        created_jobs = {}
        created_jobs_as_info = []
        start = time.time()
        for job in jobs:
            job_config = job.configuration
            task_interpreter = self.get_interpreter_for_task(job_config.task_id)
            if not task_interpreter:
                self.logger.error("Failed to recreate job %s. Unknown task %s", job_config.job_id, job_config.task_id)
                continue
            task_details = task_interpreter.task.task.as_task_details()
            input_files = self.recreate_input_files(job_config.input_files, fits_factory)
            created_job = job_factory.recreate_job(task_details, job, input_files)
            created_job.set_future()
            created_jobs[job_config.job_id] = created_job
            created_jobs_as_info.append(job_config)
        self.logger.info("Recreated %d jobs from DB in %.3f s", len(created_jobs), time.time() - start)
        # Linking needs the complete id -> job mapping, hence the second pass.
        self.create_links(created_jobs, created_jobs_as_info)
        self.fix_exemplars(created_jobs, created_jobs_as_info)
        return list(created_jobs.values())

    def recreate_input_files(self, input_files: List[ClientFitsFile],
                             fits_factory: FitsFileFactory) -> List[ClassifiedFitsFile]:
        """Rebuild the classified input files of a stored job.

        Only the first file is parsed and classified through
        ``Classifier.parse_and_classify``; the remaining ones are rebuilt from
        their stored name and category, with an empty dict passed to ``FitsFile``.
        """
        result: List[ClassifiedFitsFile] = []
        if len(input_files) > 0:
            result.extend(Classifier.parse_and_classify(fits_factory, [Path(input_files[0].name)], self.get_keywords(),
                                                        self.get_classification_rules()))
            result.extend([ClassifiedFitsFile(FitsFile(f.name, {}), f.category, None) for f in input_files[1:]])
        return result

    def create_links(self, created_jobs: Dict[str, Job], jobs: List[JobInfo]):
        """Restore input-job links and association results between recreated jobs.

        Raises ``KeyError`` if a referenced job id is not a key of ``created_jobs``.
        """
        self.logger.info("Building links between jobs")
        for job in jobs:
            created_job = created_jobs[job.job_id]
            for input_job_id in job.input_job_ids:
                created_job.input_jobs.append(created_jobs[input_job_id])
            for association_details in job.association_details:
                associated_jobs = [created_jobs[j.name] for j in association_details.jobs]
                associated_files = [AssociatedFitsFile.from_name_cat_mjd(f.name, f.category, f.mjdobs) for f in
                                    association_details.files]
                result = AssociationResult(associated_files, associated_jobs, association_details.complete,
                                           association_details.optional, association_details.level,
                                           association_details.task_name)
                created_job.add_association_result(result)

    def fix_exemplars(self, created_jobs: Dict[str, Job], jobs: List[JobInfo]):
        """Set ``exemplar`` and ``product`` on every recreated job.

        Jobs are visited in the order given by ``JobInfoGraph(jobs).topological_sort()``.
        A job with input files derives both values from its first input file; otherwise
        they are taken from the product of its first input job; as a last resort a
        virtual dummy file is used and an error is logged.
        """
        self.logger.info("Fixing exemplars")
        for job in JobInfoGraph(jobs).topological_sort():
            created_job = created_jobs[job.job_id]
            if len(created_job.input_files) > 0:
                created_job.exemplar = job_factory.create_exemplar(created_job.input_files[0], job.mjdobs)
                created_job.product = job_factory.create_product(created_job.input_files[0])
            elif len(created_job.input_jobs) > 0 and created_job.input_jobs[0].product:
                created_job.exemplar = created_job.input_jobs[0].product
                created_job.product = created_job.input_jobs[0].product
            else:
                self.logger.error("THIS SHOULD NEVER HAPPEN - Creating dummy exemplar for job %s", created_job.id)
                dummy_exemplar = ClassifiedFitsFile(FitsFile("DUMMY_EXEMPLAR", {}, virtual=True), None, None)
                created_job.exemplar = dummy_exemplar
                created_job.product = dummy_exemplar

    def get_historical_job(self, job: JobInfo) -> Optional[Job]:
        """Look up ``job`` among the historical jobs of the task matching its ``task_id``.

        Returns ``None`` when no stateful task has that task id.
        """
        for task in self.stateful_tasks:
            if task.task_id == job.task_id:
                return task.get_historical_job(job.job_id)
        return None

    def as_dto(self) -> WorkflowStateDTO:
        """Return a :class:`WorkflowStateDTO` built from the DTOs of all data sources and tasks."""
        return WorkflowStateDTO(data_sources=[x.as_dto() for x in self.stateful_data_sources],
                                tasks=[x.as_dto() for x in self.stateful_tasks])

    def get_historical_jobs(self) -> List[Job]:
        """Return the values of ``historical_jobs`` of every stateful task, flattened into one list."""
        return [job for task in self.stateful_tasks for job in task.historical_jobs.values()]
