import logging
import os.path
import shutil
import time
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Set, Optional, Callable
from uuid import UUID

from edps.client.FitsFile import FitsFile
from edps.client.JobInfo import JobInfo
from edps.client.ProcessingJobStatus import ProcessingJobStatus
from edps.client.search import SearchFilter
from edps.config.configuration import Configuration
from edps.executor.filtering import ProductFilter
from edps.interfaces.JobsRepository import JobsRepository, ExecutionResult, JobDetails, RejectionResult, JobNotFoundError
from edps.jobs.CachingJobsDB import CachingJobsDB
from edps.jobs.InMemoryJobsDB import InMemoryJobsDB
from edps.jobs.TinyJobsDB import TinyJobsDB


class LocalJobsRepository(JobsRepository):
    """JobsRepository that delegates job storage to a local jobs DB chosen from the configuration.

    Job metadata lives in one of TinyJobsDB, CachingJobsDB or InMemoryJobsDB; per-job
    working directories are located at ``<base_dir>/<instrument>/<task_name>/<job_id>``
    (see :meth:`remove_jobs`).
    """

    def __init__(self, config: Configuration):
        """Create the repository and its backing DB.

        :param config: provides ``db_type``, ``truncate_repository`` and ``base_dir``.
        :raises NotImplementedError: if ``config.db_type`` is not 'tiny', 'cqrs', 'caching' or 'memory'.
        """
        # Logger first, so it already exists for anything the rest of __init__ triggers.
        self.logger = logging.getLogger("LocalJobsRepository")
        db_type = config.db_type
        if db_type == 'tiny':
            self.db = TinyJobsDB(config)
        elif db_type in ('cqrs', 'caching'):
            self.db = CachingJobsDB(config)
        elif db_type == 'memory':
            self.db = InMemoryJobsDB()
        else:
            raise NotImplementedError("Only 'tiny', 'cqrs', 'caching' or 'memory' DB are supported")
        if config.truncate_repository:
            self.clear()
        self.base_dir = config.base_dir

    def get_job_details(self, job_id: UUID) -> JobDetails:
        """Return the details of ``job_id``, passing :meth:`is_job_present` to the DB as presence predicate.

        :meth:`is_job_present` checks on disk that every output file exists, so prefer
        :meth:`get_job_details_simple` when only stored fields are needed.
        """
        self.logger.debug("Searching for job '%s'", job_id)
        return self.get_job_details_simple(job_id, self.is_job_present)

    def get_job_details_simple(self, job_id: UUID,
                               is_job_present: Callable[[str, List[FitsFile]], bool] = lambda x, y: True) -> JobDetails:
        """Return the details of ``job_id`` from the DB using ``is_job_present`` as presence predicate.

        The default predicate always returns True and performs no filesystem access.
        """
        return self.db.get_job_details_simple(job_id, is_job_present)

    def get_jobs_list(self, pattern: str, offset: int, limit: int) -> List[JobDetails]:
        """Return the DB's jobs matching ``pattern``, paginated with ``offset``/``limit``."""
        self.logger.debug("Searching for jobs for pattern '%s'", pattern)
        return self.db.get_jobs_list(pattern, offset, limit)

    def get_jobs_list_filter(self, search_filter: SearchFilter) -> List[JobDetails]:
        """Return every job matching ``search_filter`` (no start/end limits)."""
        # get_jobs_list_filter_with_limits already logs the search; logging here too duplicated the line.
        return self.get_jobs_list_filter_with_limits(search_filter, 0, None)

    def get_jobs_list_filter_with_limits(self, search_filter: SearchFilter,
                                         start: int, end: Optional[int]) -> List[JobDetails]:
        """Return jobs matching ``search_filter``, restricted by the DB to the ``start``/``end`` range."""
        self.logger.debug("Searching for jobs for '%s'", search_filter)
        return self.db.get_jobs_list_filter_with_limits(search_filter, start, end)

    def set_job_status(self, job_id: UUID, result: ExecutionResult):
        """Replace the stored result of ``job_id`` with ``result``.

        Configuration, submission date and rejected flag are preserved. Returns whatever the
        DB's ``set_job_details`` returns.
        """
        self.logger.debug("Job %s status change to %s", job_id, result)
        entry = self.get_job_details_simple(job_id)
        new_value = JobDetails(entry.configuration, result, entry.submission_date, entry.rejected)
        return self.db.set_job_details(job_id, new_value)

    def set_jobs_submitted(self, jobs: List[JobInfo], submission_date: str):
        """Insert ``jobs`` with a CREATED result and ``submission_date``; does nothing for an empty list."""
        if not jobs:
            return
        self.logger.debug("Jobs %s submitted for execution at %s", jobs, submission_date)
        documents = [
            JobDetails(job,
                       ExecutionResult(ProcessingJobStatus.CREATED,
                                       recipe_parameters=job.parameters.recipe_parameters),
                       submission_date)
            for job in jobs
        ]
        self.db.insert_multiple(documents)

    def insert_job(self, job: JobDetails) -> UUID:
        """Insert a single, already built ``job`` and return its id as a UUID."""
        self.db.insert_multiple([job])
        return UUID(job.configuration.job_id)

    def find_complete_job(self, command: str, inputs: List[FitsFile], recipe_parameters: Dict,
                          input_filter: ProductFilter, output_filter: ProductFilter) -> UUID:
        """Delegate the search for a matching complete job to the DB.

        :meth:`is_job_present` is supplied as presence predicate (checks output files on disk).
        """
        return self.db.find_complete_job(command, inputs, recipe_parameters, input_filter, output_filter,
                                         self.is_job_present)

    def reject_job(self, job_id: UUID) -> RejectionResult:
        """Reject ``job_id`` together with every job related to it.

        All jobs returned by :meth:`find_related_jobs` (both rejected and incomplete ones) get
        their rejected flag set. For each incomplete job, :meth:`find_non_rejected_inputs`
        collects input file names that are not inputs of ``job_id`` itself.

        :return: RejectionResult(rejected job ids, incomplete job ids, non-rejected input names).
        """
        rejected_jobs, incomplete_jobs = self.find_related_jobs(job_id)
        rejected_inputs = {f.name for f in self.get_job_details_simple(job_id).configuration.input_files}
        for related_id in rejected_jobs + incomplete_jobs:
            self.set_job_rejected(related_id)
        # Built once: find_non_rejected_inputs only reads it.
        rejected_set = set(rejected_jobs)
        non_rejected_inputs = set()
        for incomplete_id in incomplete_jobs:
            non_rejected_inputs.update(self.find_non_rejected_inputs(incomplete_id, rejected_set, rejected_inputs))
        return RejectionResult(rejected_jobs, incomplete_jobs, list(non_rejected_inputs))

    def delete_job_cascade(self, job_id: UUID) -> List[UUID]:
        """Remove ``job_id`` and all jobs related to it (see :meth:`find_related_jobs`).

        :return: ids of all jobs passed to :meth:`remove_jobs`.
        """
        rejected, incomplete = self.find_related_jobs(job_id)
        related = rejected + incomplete
        self.remove_jobs(related)
        return related

    def find_related_jobs(self, job_id: UUID) -> Tuple[List[UUID], List[UUID]]:
        """Walk the descendants of ``job_id`` level by level and split them into two groups.

        At each level :meth:`get_direct_children` returns ``(main, associated)`` children:

        * ``main`` children go to the rejected group;
        * ``associated`` children go to the incomplete group, unless their stored rejected
          flag is already set, in which case they go to the rejected group.

        The walk stops when a level yields no job that was not seen before.

        :return: ``(rejected, incomplete)``. ``rejected`` contains ``job_id`` itself; the lists are
            disjoint and their order is unspecified.
        """
        rejected_jobs = {job_id}
        incomplete_jobs = set()
        to_check = {job_id}
        while True:
            already_processed = rejected_jobs.union(incomplete_jobs)
            main, associated = self.get_direct_children(list(to_check))
            to_check = main.union(associated).difference(already_processed)
            if not to_check:
                return list(rejected_jobs), list(incomplete_jobs.difference(rejected_jobs))
            # Only the stored rejected flag is read, so skip the output-file existence checks
            # that get_job_details would perform.
            associated_details = [self.get_job_details_simple(child_id) for child_id in associated]
            already_rejected = {UUID(details.configuration.job_id)
                                for details in associated_details if details.rejected}
            rejected_jobs.update(main.union(already_rejected))
            incomplete_jobs.update(associated.difference(already_rejected))

    def get_direct_children(self, parent_jobs: List[UUID]) -> Tuple[Set[UUID], Set[UUID]]:
        """Return ``(main, associated)`` direct children of ``parent_jobs``.

        ``main`` holds the children the DB reports as having one input. ``associated`` holds the
        DB's associated jobs plus the children reported as having more inputs. (Meaning of "one
        input" vs "more inputs" is defined by ``db.find_children_jobs`` — not visible here.)
        """
        associated_jobs = self.db.find_associated_jobs(parent_jobs)
        main_jobs_with_one_input, main_jobs_with_more_inputs = self.db.find_children_jobs(parent_jobs)
        return set(main_jobs_with_one_input), set(associated_jobs + main_jobs_with_more_inputs)

    def set_job_rejected(self, job_id: UUID):
        """Set the rejected flag of ``job_id``, keeping all other stored fields."""
        self.logger.debug("Job %s rejection", job_id)
        entry = self.get_job_details_simple(job_id)
        new_value = JobDetails(entry.configuration, entry.result, entry.submission_date, True)
        self.db.set_job_details(job_id, new_value)

    def remove_jobs(self, job_ids: List[UUID]) -> List[UUID]:
        """Delete the working directories of ``job_ids``, then remove the jobs from the DB.

        Each directory ``<base_dir>/<instrument>/<task_name>/<job_id>`` is removed best-effort
        (``ignore_errors=True``). A job not found in the DB is logged as a warning and its
        directory is skipped; every id is still passed to the DB removal.

        :return: ``job_ids`` unchanged.
        """
        for job_id in job_ids:
            try:
                # Only configuration fields are needed: no output-file existence checks.
                configuration = self.get_job_details_simple(job_id).configuration
            except JobNotFoundError as e:
                self.logger.warning("failed to remove job %s: %s", job_id, e, exc_info=e)
                continue
            job_directory = os.path.join(self.base_dir, configuration.instrument,
                                         configuration.task_name, str(job_id))
            shutil.rmtree(job_directory, ignore_errors=True)
        self.db.remove_jobs(job_ids)
        return job_ids

    def find_non_rejected_inputs(self, job_id: UUID, rejected_jobs: Set[UUID], rejected_inputs: Set[str]) -> Set[str]:
        """Collect input file names of ``job_id`` and of its ancestors, skipping rejected material.

        Ancestors are followed through ``configuration.input_job_ids``. Jobs in ``rejected_jobs``
        are neither collected nor traversed, and file names in ``rejected_inputs`` are dropped.
        """
        inputs = set()
        handled_jobs = set()
        # Queue holds ids as strings and converts them with UUID() on pop; presumably
        # input_job_ids are strings too — TODO confirm against JobInfo.
        parent_jobs_to_check = {str(job_id)}
        while parent_jobs_to_check:
            current_id = UUID(parent_jobs_to_check.pop())
            if current_id in handled_jobs or current_id in rejected_jobs:
                continue
            handled_jobs.add(current_id)
            details = self.get_job_details_simple(current_id)
            inputs.update(f.name for f in details.configuration.input_files if f.name not in rejected_inputs)
            parent_jobs_to_check.update(details.configuration.input_job_ids)
        return inputs

    def find_all_pending_jobs(self) -> List[JobDetails]:
        """Return all jobs for which ``is_pending()`` or ``is_interrupted()`` is true."""
        return [job for job in self.db.get_all() if job.is_pending() or job.is_interrupted()]

    def find_all_non_rejected_jobs(self) -> List[JobDetails]:
        """Return all jobs whose rejected flag is not set."""
        return [job for job in self.db.get_all() if not job.rejected]

    def find_jobs_to_remove(self, cleanup_older_than: timedelta) -> List[UUID]:
        """Return ids of old, non-pending jobs whose whole related set is old as well.

        "Old" jobs are the non-pending ones matched by
        ``SearchFilter(submission_time_to=now - cleanup_older_than)``. A job is selected only
        if every job returned by :meth:`find_related_jobs` for it is also old, so a job is never
        selected while one of its descendants is too recent or pending.
        """
        time_to = datetime.now() - cleanup_older_than
        old_jobs = {
            UUID(job.configuration.job_id)
            for job in self.get_jobs_list_filter(SearchFilter(submission_time_to=time_to))
            if not job.is_pending()
        }
        removable = set()
        for job_id in old_jobs:
            rejected, incomplete = self.find_related_jobs(job_id)
            if set(rejected + incomplete).issubset(old_jobs):
                removable.add(job_id)
        return list(removable)

    def clear(self):
        """Clear the backing DB."""
        self.db.clear()

    def close(self):
        """Close the backing DB."""
        self.db.close()

    @staticmethod
    def is_job_present(job_id: str, outputs: List[FitsFile]) -> bool:
        """Return True if every file in ``outputs`` exists on disk (True for an empty list).

        The elapsed time is logged at DEBUG level on the "LocalJobsRepository" logger.
        """
        start_time = time.perf_counter()  # monotonic, unlike time.time()
        # Generator lets all() stop at the first missing file instead of stat-ing every output.
        result = all(os.path.exists(file.name) for file in outputs)
        process_time = (time.perf_counter() - start_time) * 1000
        # Named logger with lazy %-args: module-level logging.debug() targets the root logger
        # (calling basicConfig() if it has no handlers) and the f-string was always formatted.
        logging.getLogger("LocalJobsRepository").debug(
            "Checking if job %s is not MISSING took %.2fms", job_id, process_time)
        return result
