import logging
import os.path
import threading
import weakref
from concurrent.futures.thread import ThreadPoolExecutor
from pathlib import PurePath
from typing import List, Dict, Set, Optional
from uuid import UUID

from edps.client.JobInfo import JobInfo
from edps.config.configuration import Configuration
from edps.domain.dataset import Datasets
from edps.interfaces.JobExecutor import JobExecutor
from edps.interfaces.JobsRepository import JobsRepository, JobNotFoundError
from . import ordering
from .action import File, FutureAction, Command, FileProcessingAction
from .breakpoint import BreakpointManager
from .cli_command import CLICommand
from .cloner import JobCloner
from .command import DummyCommand
from .constants import ESOREX_INPUT, ERROR_LOG, README_RE_RUN
from .dataset_packager import DatasetPackagerFactory, DatasetPackager
from .filtering import FilterFactory, ProductFilter
from .ordering import JobsSubmitter
from .recipe import InvokerProvider
from .renamer import PrefixProductRenamer, NopProductRenamer, ProductRenamer
from .reporting import ReportingScript, ReportGenerator
from .task import Task
from ..client.ReportsConfiguration import ReportsConfiguration, ALL_REPORTS
from ..client.RunReportsRequestDTO import RunReportsRequestDTO
from ..generator.constants import ParameterSetName


class LocalJobExecutor(JobExecutor):

    def __init__(self, config: Configuration,
                 repository: JobsRepository,
                 invoker_provider: InvokerProvider,
                 breakpoint_manager: BreakpointManager,
                 dataset_packager_factory: DatasetPackagerFactory):
        """
        Set up the executor from the configuration and its collaborators.

        :param config: provides directories, flags, the worker count and the ordering strategy
        :param repository: used to look up job details and to store job results
        :param invoker_provider: source of the invokers handed to the commands built here
        :param breakpoint_manager: passed on to every CLI command created by this executor
        :param dataset_packager_factory: used to obtain a dataset packager for each job
        """
        self.logger = logging.getLogger('LocalCascadeExecutor')

        # Collaborators supplied by the caller.
        self.repository = repository
        self.invoker_provider = invoker_provider
        self.breakpoint_manager = breakpoint_manager
        self.dataset_packager_factory = dataset_packager_factory

        # Plain values copied from the configuration.
        self.genreport_path = config.genreport_path
        self.base_dir = config.base_dir
        self.package_base_dir = config.package_base_dir
        self.dummy = config.dummy
        self.continue_on_error = config.continue_on_error
        self.output_prefix = config.output_prefix
        self.reexecution_window = config.reexecution_window
        self.save_report_inputs = config.save_report_inputs

        # Objects built from the configuration.
        self.executor_pool = ThreadPoolExecutor(max_workers=config.processes)
        self.submitter: JobsSubmitter = ordering.get_job_submitter(config.ordering)
        self.job_cloner = JobCloner(self.repository, self.base_dir)

        # Bookkeeping: tasks handed to the pool (weakly referenced) and the renamers created so far.
        self.scheduled_jobs = weakref.WeakSet()
        self.renamers: Dict[str, ProductRenamer] = {}

    def execute(self, job_configurations: List[JobInfo], submission_date: str, package_base_dir: str, callback_url: str, targets: Optional[List[str]],
                renaming_prefix: Optional[str], reports_configuration: ReportsConfiguration, parameter_set_name: Optional[ParameterSetName]):
        try:
            execution_tasks = self.load_jobs_for_execution(job_configurations, submission_date, package_base_dir or self.package_base_dir, targets or [],
                                                           renaming_prefix, reports_configuration, parameter_set_name)
            self.scheduled_jobs.update(execution_tasks)
            self.submitter.order_and_submit(execution_tasks, self.executor_pool, callback_url)
        except Exception as e:
            # FIXME: If this happens, it will deadlock executions, so maybe we should fail all new jobs?
            self.logger.error("Failed to build tasks for execution", exc_info=e)

    def get_job_file(self, job_id: UUID, filename: str) -> bytes:
        file_path = self.get_abs_filepath(job_id, filename)
        with open(file_path, "rb") as data_file:
            return data_file.read()

    def any_scheduled(self, jobs_to_delete: Set[UUID]) -> bool:
        """
        WeakSet scheduled_jobs holds weak references to all tasks submitted to the executor.
        Tasks are referenced only by each other and by the thread pool executor queue.
        Once the task execution completes, all strong references are gone, and it will automatically disappear from the set.
        This way we can easily track all tasks which are still waiting to be executed.

        :param jobs_to_delete: IDs of jobs we're trying to delete
        :return: True if any of the listed jobs are still in the execution queue
        """
        self.logger.debug("Currently scheduled jobs: %s, trying to delete %s", self.scheduled_jobs, jobs_to_delete)
        return bool({task.job_id for task in self.scheduled_jobs.copy()} & jobs_to_delete)

    def get_scheduled_jobs(self) -> List[UUID]:
        return [task.job_id for task in self.scheduled_jobs.copy()]

    def get_job_dir(self, job_id: UUID) -> Optional[str]:
        """
        Find the directory which actually contains the files of a job.

        A job directory is accepted when it contains ESOREX_INPUT or ERROR_LOG. Otherwise the job is
        assumed to point at an equivalent job, whose ID is taken either from the path of the first
        recorded output file or from the README_RE_RUN file, and the lookup continues with that job.

        The chain of equivalent jobs is followed iteratively and every visited ID is remembered, so a
        job that refers back to itself (or to a job already on the chain) yields None instead of
        recursing without bound.

        :param job_id: ID of the job to resolve
        :return: absolute path of the resolved job directory, or None if no directory could be found
        :raises JobNotFoundError: propagated from the repository (presumably raised for unknown IDs)
        """
        visited: Set[UUID] = set()
        current_id = job_id
        while current_id not in visited:
            visited.add(current_id)
            job_details = self.repository.get_job_details(current_id)
            task_dir = os.path.join(self.base_dir, job_details.configuration.instrument, job_details.configuration.task_name)
            job_dir = os.path.join(task_dir, str(current_id))
            if os.path.exists(os.path.join(job_dir, ESOREX_INPUT)) or os.path.exists(os.path.join(job_dir, ERROR_LOG)):
                return os.path.abspath(job_dir)
            elif job_details.result.output_files:
                # The first path component below the task directory is the ID of the job owning the output.
                current_id = UUID(PurePath(os.path.relpath(job_details.result.output_files[0].name, task_dir)).parts[0])
            elif os.path.exists(os.path.join(job_dir, README_RE_RUN)):
                # FIXME: we need a better way to indicate the re-used job ID
                with open(os.path.join(job_dir, README_RE_RUN), 'r') as f:
                    content = f.read()
                # The ID is everything after the first quote, minus the final character of the file.
                current_id = UUID(content[content.index("'") + 1:-1])
            else:
                return None
        self.logger.warning("Job %s refers back to already visited job %s, unable to resolve its directory", job_id, current_id)
        return None

    def get_abs_filepath(self, job_id: UUID, filename: str) -> str:
        job_dir = self.get_job_dir(job_id)
        if job_dir:
            filepath = os.path.join(job_dir, filename)
            if os.path.exists(filepath):
                return os.path.abspath(filepath)
            else:
                raise FileNotFoundError(filepath)
        raise FileNotFoundError(job_id)

    def run_reports(self, request: RunReportsRequestDTO) -> List[str]:
        self.logger.info("Re-Running reports %s for jobs %s", request.report_types, request.job_ids)
        reported_jobs = []
        for job_id in request.job_ids:
            job_id = UUID(job_id)
            job_dir = self.get_job_dir(job_id)
            if job_dir:
                try:
                    job = self.repository.get_job_details(job_id)
                    scripts = self.create_reporting_scripts(job.configuration)
                    report_generator = ReportGenerator(job_dir, scripts, lambda: "re-running-reports", ALL_REPORTS, job.configuration.task_name)
                    report_entries = []
                    for input_type in request.report_types:
                        report_results = report_generator.generate_reports(input_type)
                        report_entries.extend(report_results)
                    report_names = set([report.report_name for report in report_entries])
                    job_result = self.repository.get_job_details(job_id).result
                    report_entries.extend([previous_entry for previous_entry in job_result.reports if previous_entry.report_name not in report_names])
                    job_result.reports = report_entries
                    self.repository.set_job_status(job_id, job_result)
                    reported_jobs.append(job.configuration.job_id)
                except JobNotFoundError as e:
                    self.logger.warning("Unable to find job {}", e)
            else:
                self.logger.warning("Unable to find job dir {}", job_id)
        return reported_jobs

    def load_jobs_for_execution(self, job_configurations: List[JobInfo], submission_date: str, package_base_dir: str, targets: List[str],
                                renaming_prefix: Optional[str], reports_configuration: ReportsConfiguration, parameter_set_name: Optional[ParameterSetName]) -> List[Task]:
        """
        Create one task per job configuration and link every task to its parent tasks.

        :param job_configurations: jobs to convert into tasks
        :param submission_date: forwarded to the dataset packagers
        :param package_base_dir: forwarded to the dataset packagers
        :param targets: names of the tasks which count as targets
        :param renaming_prefix: optional prefix used for product renaming
        :param reports_configuration: forwarded to the created commands
        :param parameter_set_name: forwarded to the dataset packagers
        :return: the created tasks
        :raises KeyError: if a job refers to a parent job which is not part of job_configurations
        """
        datasets = Datasets(job_configurations)
        target_names = set(targets)

        # First pass: build every task, keyed by the job ID as given in the configuration.
        tasks_by_id = {}
        for job in job_configurations:
            packager = self.get_packager(parameter_set_name, datasets, job, submission_date, package_base_dir,
                                         self._is_target(job.task_name, target_names))
            tasks_by_id[job.job_id] = self.create_task(job, packager, renaming_prefix, reports_configuration)

        # Second pass: all tasks exist now, so the parent links can be resolved.
        for job in job_configurations:
            task = tasks_by_id[job.job_id]
            task.main_parents.extend(tasks_by_id[parent_id] for parent_id in job.input_job_ids)
            task.associated_parents.extend(tasks_by_id[parent_id] for parent_id in job.associated_job_ids)
        return list(tasks_by_id.values())

    def create_task(self, job: JobInfo, packager: DatasetPackager, renaming_prefix: Optional[str],
                    reports_configuration: ReportsConfiguration) -> Task:
        """
        Build the task for a single job, including its command and file processing action.

        :param job: configuration of the job
        :param packager: dataset packager handed to the task
        :param renaming_prefix: optional prefix used for product renaming
        :param reports_configuration: forwarded to the created command
        :return: the task, created with empty parent lists
        :raises NotImplementedError: if the command type of the job is not supported
        """
        work_dir = os.path.join(self.base_dir, job.instrument, job.task_name)
        main_inputs = [File(entry.name, entry.category, 'main_input_file') for entry in job.input_files]
        associated_inputs = [File(entry.name, entry.category, 'associated_file') for entry in job.associated_files]
        recipe_parameters = job.parameters.recipe_parameters
        input_filter = FilterFactory.create_filter(job.input_filter, job.input_filter_mode)
        output_filter = FilterFactory.create_filter(job.output_filter, job.output_filter_mode)
        input_map = job.input_map
        renamer = self.get_renamer(job.instrument, renaming_prefix)
        scripts = self.create_reporting_scripts(job)
        job_id = UUID(job.job_id)
        if job.command_type not in ('recipe', 'future', 'function', 'shell'):
            raise NotImplementedError(job.command_type)
        command = self.create_command(work_dir, input_map, job, job_id, scripts, renamer, output_filter, reports_configuration)
        action = FileProcessingAction(main_inputs, associated_inputs, self.repository, input_filter, command, self.reexecution_window)
        if job.is_future():
            # Future jobs wrap the regular file processing action.
            action = FutureAction(action, main_inputs, associated_inputs)
        return Task(job_id, action, [], [], self.repository, packager, recipe_parameters, self.continue_on_error)

    def create_reporting_scripts(self, job: JobInfo) -> List[ReportingScript]:
        """
        Create a reporting script for every report declared by the job.

        :param job: job configuration whose reports are converted
        :return: one ReportingScript per entry of job.reports, in the same order
        """
        scripts = []
        for report in job.reports:
            script = ReportingScript(self.genreport_path, report.name, report.input, report.driver, self.save_report_inputs)
            scripts.append(script)
        return scripts

    def get_packager(self, parameter_set_name: Optional[ParameterSetName], datasets: Datasets, job: JobInfo, submission_date: str, package_base_dir: str,
                     is_target: bool) -> DatasetPackager:
        return self.dataset_packager_factory.get_packager(job.workflow_names,
                                                          job.task_name,
                                                          parameter_set_name,
                                                          datasets.get_datasets_for_job(job.job_id),
                                                          submission_date,
                                                          package_base_dir,
                                                          is_target)

    def create_command(self, base_dir: str, input_map: Dict[str, str], job: JobInfo, job_id: UUID,
                       reporting_scripts: List[ReportingScript], product_renamer: ProductRenamer, output_filter: ProductFilter,
                       reports_configuration: ReportsConfiguration) -> Command:
        if job.command_type == 'function':
            return CLICommand(job_id, job.command, base_dir, input_map, reporting_scripts, product_renamer,
                              self.invoker_provider.function_invoker, self.breakpoint_manager, self.job_cloner,
                              output_filter, reports_configuration, job.task_name)
        elif self.dummy:
            return DummyCommand(job_id, job.command, base_dir, reporting_scripts, job.task_name, output_filter)
        elif job.command_type == 'shell':
            invoker = self.invoker_provider.cli_invoker
        elif job.command_type in ('recipe', 'future'):
            invoker = self.invoker_provider.recipe_invoker
        else:
            raise NotImplementedError(job.command_type)
        return CLICommand(job_id, job.command, base_dir, input_map, reporting_scripts, product_renamer, invoker,
                          self.breakpoint_manager, self.job_cloner, output_filter, reports_configuration, job.task_name)

    def get_renamer(self, instrument: str, renaming_prefix: Optional[str]) -> ProductRenamer:
        prefix = renaming_prefix or self.output_prefix
        complete_prefix = prefix + '.' + instrument if prefix else None
        renamer = PrefixProductRenamer(complete_prefix) if complete_prefix else NopProductRenamer()
        return self.renamers.setdefault(complete_prefix, renamer)

    @staticmethod
    def _is_target(task_name: str, target_tasks: Set[str]) -> bool:
        return task_name in target_tasks

    def shutdown(self):
        self.executor_pool.shutdown()

    def wait_until_nothing_scheduled(self):
        """
        Block until scheduled_jobs is empty, checking once per second.

        The first check happens only after the first one-second pause.
        """
        pause = threading.Event()
        queue_drained = False
        while not queue_drained:
            # Nothing sets this event, so wait() simply acts as a one-second pause.
            pause.wait(1)
            queue_drained = not self.scheduled_jobs
