import re
import threading
from typing import List, Dict, Callable, Optional, Tuple, Iterable
from uuid import UUID

from edps.client.FitsFile import FitsFile
from edps.client.search import SearchFilter
from edps.executor.filtering import ProductFilter
from edps.interfaces.JobsRepository import JobDetails, JobNotFoundError
from edps.jobs.JobsDB import JobsDB, JobExistenceFunction


class InMemoryJobsDB(JobsDB):
    """JobsDB implementation that keeps all JobDetails in process memory.

    Writes are applied to ``mutable_cache_state`` while holding ``lock``. After every
    write a fresh shallow copy is published as ``reader_cache_snapshot``. Readers only
    iterate the snapshot, which is rebound (never mutated in place), so they do not
    need the lock. The copy is shallow: the JobDetails objects themselves are shared
    between both dicts.
    """

    def __init__(self, get_all: Callable[[], List[JobDetails]] = lambda: []):
        """Populate the cache with the jobs returned by ``get_all`` (none by default)."""
        # Create the lock before the state it guards.
        self.lock = threading.RLock()
        # Keyed by ``job.configuration.job_id``; lookups elsewhere use ``str(job_id)``.
        self.mutable_cache_state: Dict[str, JobDetails] = self.__load(get_all)
        self.reader_cache_snapshot: Dict[str, JobDetails] = self.mutable_cache_state.copy()

    def get_all(self) -> List[JobDetails]:
        """Return all cached jobs as a new list built from the current snapshot."""
        return list(self.threadsafe_cache_values())

    def get_job_details_simple(self, job_id: UUID,
                               is_job_present: JobExistenceFunction = lambda x, y: True) -> JobDetails:
        """Return the cached details of ``job_id`` after calling ``update_status`` on them.

        :param job_id: id of the job to look up.
        :param is_job_present: forwarded to ``JobDetails.update_status``.
        :raises JobNotFoundError: if no job with this id is cached.
        """
        details = self.reader_cache_snapshot.get(str(job_id))
        if details is None:
            raise JobNotFoundError(f'Job not found {str(job_id)}')
        # The returned object is the one held by the cache (the snapshot copy is shallow),
        # so whatever update_status changes is presumably visible to later readers too.
        details.update_status(is_job_present)
        return details

    def get_jobs_list(self, pattern: str, offset: int, limit: int) -> List[JobDetails]:
        """Return jobs for which ``pattern`` matches the start of any searchable field.

        Matching is case-insensitive ``re.match`` against the command, job id, task
        name, instrument and status name, tried in that order.

        :raises re.error: if ``pattern`` is not a valid regular expression.
        """
        # NOTE(review): offset and limit are accepted but not applied; every matching
        # job is returned. Confirm whether callers expect pagination here.
        regex = re.compile(pattern, flags=re.IGNORECASE)
        return [job for job in self.threadsafe_cache_values() if (
                regex.match(job.configuration.command) or
                regex.match(job.configuration.job_id) or
                regex.match(job.configuration.task_name) or
                regex.match(job.configuration.instrument) or
                regex.match(job.result.status.name)
        )]

    def get_jobs_list_filter_with_limits(self, search_filter: SearchFilter,
                                         start: int, end: Optional[int]) -> List[JobDetails]:
        """Return jobs matching ``search_filter``, sliced to ``[start:end]``.

        Jobs are scanned in snapshot order. Once ``end`` matches have been collected the
        scan stops, because any further match would be sliced away anyway.
        """
        instrument_regex = re.compile(search_filter.instrument, flags=re.IGNORECASE)
        wanted_meta_targets = set(search_filter.meta_targets) if search_filter.meta_targets else None
        # Dates on the jobs are compared against the ISO-8601 strings of the filter bounds.
        from_ts = search_filter.completion_time_from.isoformat()
        to_ts = search_filter.completion_time_to.isoformat()
        from_submission_ts = search_filter.submission_time_from.isoformat()
        to_submission_ts = search_filter.submission_time_to.isoformat()

        results = []
        for job in self.threadsafe_cache_values():
            if end is not None and len(results) >= end:
                break
            config = job.configuration
            # NOTE(review): when meta_targets is set but targets is empty, the target clause
            # below evaluates to False, so such a filter matches no job. Confirm intended.
            if (instrument_regex.match(str(config.instrument)) and
                    (str(config.task_name) in search_filter.targets if search_filter.targets
                     else not search_filter.meta_targets) and
                    (wanted_meta_targets is None or not wanted_meta_targets.isdisjoint(config.meta_targets)) and
                    from_ts <= job.result.completion_date <= to_ts and
                    search_filter.mjdobs_from <= config.mjdobs <= search_filter.mjdobs_to and
                    from_submission_ts <= job.submission_date <= to_submission_ts):
                results.append(job)
        return results[start:end]

    def find_complete_job(self, command: str, inputs: List[FitsFile], recipe_parameters: Dict,
                          input_filter: ProductFilter, output_filter: ProductFilter,
                          is_job_present: JobExistenceFunction = lambda x, y: True) -> Optional[UUID]:
        """Return the id of a complete, non-rejected job with an identical configuration.

        A job matches when its command, recipe parameters, input files, and input/output
        filter modes and categories all equal the given ones.

        :param is_job_present: not consulted by this implementation (presumably kept to
            match the JobsDB interface).
        :return: the first matching job id in snapshot order, or None if there is none.
        """
        # Stored input_files are compared as a list, so the candidate inputs are sorted by
        # name first; presumably stored lists are kept in the same order.
        sorted_inputs = sorted(inputs, key=lambda fits: fits.name)
        for job in self.threadsafe_cache_values():
            if (job.rejected is False and
                    job.configuration.command == command and
                    job.result.recipe_parameters == recipe_parameters and
                    job.result.input_files == sorted_inputs and job.is_complete() and
                    job.configuration.input_filter_mode == input_filter.get_mode() and
                    set(job.configuration.input_filter) == input_filter.get_categories() and
                    job.configuration.output_filter_mode == output_filter.get_mode() and
                    set(job.configuration.output_filter) == output_filter.get_categories()
            ):
                return UUID(job.configuration.job_id)
        return None

    def find_associated_jobs(self, parent_jobs: List[UUID]) -> List[UUID]:
        """Return ids of jobs whose ``associated_job_ids`` include any of ``parent_jobs``."""
        parent_set = {str(job_id) for job_id in parent_jobs}
        return [UUID(job.configuration.job_id) for job in self.threadsafe_cache_values()
                if job.configuration.associated_job_ids is not None
                and not parent_set.isdisjoint(job.configuration.associated_job_ids)]

    def find_children_jobs(self, parent_jobs: List[UUID]) -> Tuple[List[UUID], List[UUID]]:
        """Return ids of jobs that take any of ``parent_jobs`` as input.

        :return: a pair ``(single_input, multi_input)``. The first list holds children
            whose ``input_job_ids`` has exactly one entry; the second holds those with
            more than one.
        """
        parent_set = {str(job_id) for job_id in parent_jobs}
        single_input: List[UUID] = []
        multi_input: List[UUID] = []
        for job in self.threadsafe_cache_values():
            input_ids = job.configuration.input_job_ids
            if input_ids is None or parent_set.isdisjoint(input_ids):
                continue
            if len(input_ids) == 1:
                single_input.append(UUID(job.configuration.job_id))
            elif len(input_ids) > 1:
                multi_input.append(UUID(job.configuration.job_id))
        return single_input, multi_input

    def insert_multiple(self, data: List[JobDetails]):
        """Insert or overwrite the given jobs, keyed by their configured job id."""
        self.execute_threadsafe(lambda: self.__unsafe_insert_multiple(data))

    def __unsafe_insert_multiple(self, data: List[JobDetails]):
        # Caller must hold self.lock.
        self.mutable_cache_state.update({job.configuration.job_id: job for job in data})

    def set_job_details(self, job_id: UUID, new_value: JobDetails):
        """Store ``new_value`` under ``job_id``, replacing any existing entry."""
        self.execute_threadsafe(lambda: self.__unsafe_set_job_details(job_id, new_value))

    def __unsafe_set_job_details(self, job_id: UUID, new_value: JobDetails):
        # Caller must hold self.lock.
        self.mutable_cache_state[str(job_id)] = new_value

    def remove_jobs(self, job_ids: List[UUID]):
        """Remove the given jobs; ids that are not cached are ignored."""
        self.execute_threadsafe(lambda: self.__unsafe_remove_jobs(job_ids))

    def __unsafe_remove_jobs(self, job_ids: List[UUID]):
        # Caller must hold self.lock.
        for job_id in job_ids:
            self.mutable_cache_state.pop(str(job_id), None)

    def threadsafe_cache_values(self) -> Iterable[JobDetails]:
        """Return a view over the current reader snapshot.

        Writers only rebind ``reader_cache_snapshot`` to a fresh dict and never mutate
        it in place, so iterating this view is not disturbed by concurrent writes.
        """
        return self.reader_cache_snapshot.values()

    def execute_threadsafe(self, cache_operation):
        """Run ``cache_operation`` under the lock, then publish a new reader snapshot."""
        with self.lock:
            cache_operation()
            self.reader_cache_snapshot = self.mutable_cache_state.copy()

    def clear(self):
        """Drop every cached job."""
        # Take the lock so this cannot interleave with execute_threadsafe and leave the
        # snapshot out of sync with the mutable state.
        with self.lock:
            self.mutable_cache_state = {}
            self.reader_cache_snapshot = {}

    def close(self):
        """Release the cached jobs; equivalent to ``clear``."""
        self.clear()

    @staticmethod
    def __load(get_all: Callable[[], List[JobDetails]]) -> Dict[str, JobDetails]:
        # Index the initial jobs by their configured job id.
        return {job.configuration.job_id: job for job in get_all()}
