import logging
import os
import re
import shutil
import threading
from typing import List, Dict, Optional, Callable, Tuple, Any
from uuid import UUID

from tinydb import TinyDB, JSONStorage, Query
from tinydb.middlewares import CachingMiddleware, Middleware
from tinydb.queries import QueryInstance

from edps.client.FitsFile import FitsFile
from edps.client.ProcessingJob import ProcessingJob
from edps.client.search import SearchFilter
from edps.config.configuration import Configuration
from edps.executor.filtering import ProductFilter
from edps.interfaces.JobsRepository import JobDetails, JobNotFoundError
from edps.jobs.JobsDB import JobsDB, JobExistenceFunction
from edps.utils import timer


class ThreadSafeMiddleware(Middleware):
    """
    tinydb storage middleware that guards the wrapped storage with a re-entrant lock
    and skips writes when the free disk space drops below a configured minimum.
    """

    def __init__(self, storage_cls, db_dir: str, min_disk_space_mb: int):
        """
        :param storage_cls: storage class handed over to the tinydb ``Middleware`` base class.
        :param db_dir: directory whose free disk space is checked before every write.
        :param min_disk_space_mb: minimum free space, in units of 10**6 bytes, required to write.
        """
        self.db_dir = db_dir
        self.min_disk_space_mb = min_disk_space_mb
        self.lock = threading.RLock()
        super().__init__(storage_cls)
        self.logger = logging.getLogger('ThreadSafeMiddleware')

    def __enter__(self):
        """Acquire the middleware lock."""
        self.lock.acquire()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the middleware lock; exceptions are not suppressed."""
        self.lock.release()

    def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
        """Read from the wrapped storage while holding the lock."""
        with self.lock:
            return self.storage.read()

    def write(self, data: Dict[str, Dict[str, Any]]) -> None:
        """
        Write ``data`` to the wrapped storage while holding the lock.

        If the free space on ``db_dir`` is below ``min_disk_space_mb`` the write is
        skipped and an error is logged instead.
        """
        with self.lock:
            _total, _used, free_bytes = shutil.disk_usage(self.db_dir)
            available_mb = free_bytes // 10 ** 6
            if available_mb < self.min_disk_space_mb:
                # Not enough room: report it and leave the file on disk untouched.
                self.logger.error("Can't save EDPS DB to disk: available disk space %d MB, min %d MB",
                                  available_mb, self.min_disk_space_mb)
                return None
            self.logger.debug("Writing data to disk")
            return self.storage.write(data)


class ThreadSafeDbWrapper:
    """
    Wrapper around a TinyDB instance that runs every database operation
    (including ``close``) while holding a single re-entrant lock.
    """

    def __init__(self, db: TinyDB):
        """
        :param db: the TinyDB instance to wrap; its ``default_table_name`` attribute is
                   set to ``'_default'`` so that ``all()`` can look the table up by name.
        """
        self.lock = threading.RLock()
        self.db = db
        db.default_table_name = '_default'

    def __enter__(self):
        """Acquire the wrapper lock and return the wrapper."""
        self.lock.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the wrapper lock; exceptions are not suppressed."""
        self.lock.release()

    def all(self):
        """Return all documents of the default table."""
        with self:
            return self.db.table(self.db.default_table_name).all()

    def drop_tables(self):
        """Remove all tables, using ``purge_tables`` when the TinyDB object provides it."""
        with self:
            # Older TinyDB releases expose purge_tables(); newer ones renamed it to drop_tables().
            if hasattr(self.db, 'purge_tables'):
                self.db.purge_tables()
            else:
                self.db.drop_tables()

    def search(self, query):
        """Return the documents matching ``query``."""
        with self:
            return self.db.search(query)

    def update(self, value, query):
        """Update the documents matching ``query`` with ``value``."""
        with self:
            return self.db.update(value, query)

    def insert_multiple(self, documents):
        """Insert several documents at once."""
        with self:
            return self.db.insert_multiple(documents)

    def insert(self, document):
        """Insert a single document."""
        with self:
            return self.db.insert(document)

    def remove(self, query: Query) -> List[int]:
        """Remove the documents matching ``query`` and return what TinyDB's ``remove`` returns."""
        with self:
            return self.db.remove(cond=query)

    def close(self):
        """Close the underlying database."""
        # Take the lock like every other operation so that closing cannot
        # interleave with a search/insert/update running in another thread.
        with self:
            self.db.close()


class CacheFlusher:
    """
    Calls a flush callable at a fixed interval until it is told to stop.
    Used to periodically flush the tinydb caching middleware.
    """

    def __init__(self, flush: Callable, delay_seconds: int):
        """
        :param flush: callable invoked on every period.
        :param delay_seconds: time to wait between two flushes.
        """
        # NOTE(review): the logger name matches ThreadSafeMiddleware's, probably a copy/paste;
        # left unchanged because logging configuration may rely on it.
        self.logger = logging.getLogger('ThreadSafeMiddleware')
        self.delay_seconds = delay_seconds
        self.stopped = threading.Event()
        self.flush = flush

    def run(self):
        """Flush every ``delay_seconds`` until ``stop()`` is called."""
        while True:
            # wait() returns True as soon as the stop event is set, False on timeout.
            if self.stopped.wait(self.delay_seconds):
                return
            self.invoke_flush()

    @timer("edps_periodic_flush")
    def invoke_flush(self):
        """Invoke the flush callable once, logging (not propagating) any failure."""
        try:
            self.flush()
        except Exception as error:
            self.logger.error("Failed to flush changes to disk", exc_info=error)

    def stop(self):
        """Signal ``run()`` to return."""
        self.stopped.set()


class TinyJobsDB(JobsDB):
    """
    JobsDB implementation that keeps job details as documents in a TinyDB JSON file.

    Every database access goes through a ThreadSafeDbWrapper. When the configured flush
    size is positive, writes are buffered by tinydb's CachingMiddleware and a background
    thread running a CacheFlusher periodically flushes that cache.
    """

    def __init__(self, config: Configuration):
        """
        Open the database at ``config.db_path``, creating its parent directory if needed.

        :param config: configuration providing ``db_path``, ``db_flush_size``,
                       ``db_flush_timeout`` and ``min_disk_space_mb``.
        """
        self.flusher = None
        flush_size = config.db_flush_size
        db_dir = os.path.dirname(os.path.realpath(config.db_path))
        os.makedirs(db_dir, exist_ok=True)
        middleware = ThreadSafeMiddleware(JSONStorage, db_dir, config.min_disk_space_mb)
        if flush_size > 0:
            storage = CachingMiddleware(middleware)
            # NOTE: assigned on the class, so it applies to every CachingMiddleware in the process.
            CachingMiddleware.WRITE_CACHE_SIZE = flush_size
            self.flusher = CacheFlusher(storage.flush, config.db_flush_timeout)
        else:
            storage = middleware
        self.db = ThreadSafeDbWrapper(TinyDB(config.db_path, storage=storage))
        if self.flusher:
            # The thread keeps running until close() stops the flusher.
            threading.Thread(target=self.flusher.run).start()
        self.logger = logging.getLogger("LocalJobsRepository")

    def get_all(self) -> List[JobDetails]:
        """Return the details of every job stored in the database."""
        return [JobDetails.from_dict(details) for details in self.db.all()]

    def get_job_details_simple(self, job_id: UUID,
                               is_job_present: JobExistenceFunction = lambda x, y: True) -> JobDetails:
        """
        Return the details of the job with the given id.

        :param job_id: id of the job to look up.
        :param is_job_present: function forwarded to ``JobDetails.from_dict``.
        :raises JobNotFoundError: if no stored job has this id.
        """
        job: ProcessingJob = Query()
        details = self.db.search(job.configuration.job_id == str(job_id))
        if not details:
            raise JobNotFoundError(f'Job not found {str(job_id)}')
        result = JobDetails.from_dict(details[0], is_job_present)
        self.logger.debug("Returning job %s details %s", job_id, result)
        return result

    def get_jobs_list(self, pattern: str, offset: int, limit: int) -> List[JobDetails]:
        """
        Return a page of the jobs whose command, job id, task name, instrument or status
        matches ``pattern`` (a regular expression, applied case-insensitively).

        :param pattern: regular expression passed to tinydb's ``matches``.
        :param offset: index of the first matching job to return.
        :param limit: maximum number of jobs to return.
        """
        job: ProcessingJob = Query()
        matching_jobs = self.db.search(
            job.configuration.command.matches(pattern, flags=re.IGNORECASE) |
            job.configuration.job_id.matches(pattern, flags=re.IGNORECASE) |
            job.configuration.task_name.matches(pattern, flags=re.IGNORECASE) |
            job.configuration.instrument.matches(pattern, flags=re.IGNORECASE) |
            job.status.matches(pattern, flags=re.IGNORECASE)
        )
        self.logger.debug("Searching for jobs for pattern '%s', found: %s", pattern, len(matching_jobs))
        return [JobDetails.from_dict(details) for details in matching_jobs[offset:offset + limit]]

    def get_jobs_list_filter_with_limits(self, search_filter: SearchFilter, start: int,
                                         end: Optional[int]) -> List[JobDetails]:
        """
        Return the slice ``[start:end]`` of the jobs satisfying every criterion of ``search_filter``:
        instrument (regular expression), targets / meta targets (see ``match_targets``),
        completion and submission date ranges (compared as ISO-format strings) and mjdobs range.

        :param search_filter: the criteria to apply.
        :param start: index of the first matching job to return.
        :param end: index one past the last matching job to return, or None for no upper bound.
        """
        job: ProcessingJob = Query()
        matching_jobs = self.db.search(
            (job.configuration.instrument.matches(search_filter.instrument))
            &
            (self.match_targets(job, search_filter))
            &
            (search_filter.completion_time_from.isoformat() <= job.completion_date)
            &
            (job.completion_date <= search_filter.completion_time_to.isoformat())
            &
            (search_filter.mjdobs_from <= job.configuration.mjdobs)
            &
            (job.configuration.mjdobs <= search_filter.mjdobs_to)
            &
            (search_filter.submission_time_from.isoformat() <= job.submission_date)
            &
            (job.submission_date <= search_filter.submission_time_to.isoformat())
        )
        self.logger.debug("Found: %s", len(matching_jobs))
        return [JobDetails.from_dict(details) for details in matching_jobs[start:end]]

    def find_complete_job(self, command: str, inputs: List[FitsFile], recipe_parameters: Dict,
                          input_filter: ProductFilter, output_filter: ProductFilter,
                          is_job_present: JobExistenceFunction = lambda x, y: True) -> Optional[UUID]:
        """
        Find a non-rejected, complete job that was run with the same command, input files,
        recipe parameters and input/output filters.

        :param command: command the job was run with.
        :param inputs: input files; they are compared after sorting by name.
        :param recipe_parameters: recipe parameters the job was run with.
        :param input_filter: input product filter (mode and categories are compared).
        :param output_filter: output product filter (mode and categories are compared).
        :param is_job_present: function forwarded to ``JobDetails.from_dict``.
        :return: id of the first matching job that is complete, or None if there is none.
        """
        job: ProcessingJob = Query()
        self.logger.debug("Searching for jobs matching command '%s', inputs '%s' and parameters '%s'",
                          command, inputs, recipe_parameters)
        sorted_inputs = [x.model_dump() for x in sorted(inputs, key=lambda y: y.name)]
        # '== False' is intentional: it builds a tinydb query, it is not a Python truth test.
        matching_jobs = self.db.search((job.rejected == False) &
                                       (job.configuration.command == command) &
                                       (job.recipe_parameters == recipe_parameters) &
                                       (job.input_files == sorted_inputs) &
                                       (job.configuration.input_filter_mode == input_filter.get_mode()) &
                                       (job.configuration.input_filter.all(input_filter.get_categories())) &
                                       (job.configuration.output_filter_mode == output_filter.get_mode()) &
                                       (job.configuration.output_filter.all(output_filter.get_categories()))
                                       )
        self.logger.debug("Found: %s", len(matching_jobs))
        for details in matching_jobs:
            candidate: JobDetails = JobDetails.from_dict(details, is_job_present)
            if candidate.is_complete():
                return UUID(candidate.configuration.job_id)
        return None

    def find_associated_jobs(self, parent_jobs: List[UUID]) -> List[UUID]:
        """Return the ids of the jobs whose ``associated_job_ids`` contain any of ``parent_jobs``."""
        parent_jobs = list(map(str, parent_jobs))
        job: ProcessingJob = Query()
        associated_jobs = self.db.search(
            job.configuration.associated_job_ids.any(parent_jobs)
        )
        return [self.get_id(details) for details in associated_jobs]

    def find_children_jobs(self, parent_jobs: List[UUID]) -> Tuple[List[UUID], List[UUID]]:
        """
        Find the jobs whose ``input_job_ids`` contain any of ``parent_jobs``.

        :return: a pair of id lists: jobs having exactly one input job id,
                 and jobs having more than one input job id.
        """
        parent_jobs = list(map(str, parent_jobs))
        job: ProcessingJob = Query()
        main_jobs = self.db.search(
            job.configuration.input_job_ids.any(parent_jobs)
        )
        main_jobs_with_one_input = [main_job for main_job in main_jobs if
                                    len(main_job['configuration']['input_job_ids']) == 1]
        main_jobs_with_more_inputs = [main_job for main_job in main_jobs if
                                      len(main_job['configuration']['input_job_ids']) > 1]

        return ([self.get_id(details) for details in main_jobs_with_one_input],
                [self.get_id(details) for details in main_jobs_with_more_inputs])

    def insert_multiple(self, data: List[JobDetails]):
        """Insert the given jobs, each converted with ``to_dict()``."""
        self.db.insert_multiple([job.to_dict() for job in data])

    def set_job_details(self, job_id: UUID, new_value: JobDetails):
        """Update the stored document of job ``job_id`` with the fields of ``new_value``."""
        job: ProcessingJob = Query()
        self.db.update(new_value.to_dict(), job.configuration.job_id == str(job_id))

    def remove_jobs(self, job_ids: List[UUID]):
        """Remove the jobs with the given ids from the database."""
        self.logger.debug("Jobs %s deletion", job_ids)
        self.db.remove(Query().configuration.job_id.one_of([str(job_id) for job_id in job_ids]))

    def clear(self):
        """Drop all tables of the database."""
        self.db.drop_tables()

    def close(self):
        """Stop the periodic flusher, if any, and close the database."""
        if self.flusher:
            self.flusher.stop()
        self.db.close()

    @staticmethod
    def get_id(details: Dict) -> UUID:
        """Extract the job id from a stored job document."""
        return UUID(details['configuration']['job_id'])

    def match_targets(self, job: Query, search_filter: SearchFilter) -> QueryInstance:
        """
        Build the query fragment restricting jobs by targets and meta targets.

        - neither set: every job matches;
        - both set: the task name must be one of the targets AND the job's meta targets
          must contain any of the requested meta targets;
        - only one set: only that criterion is applied.
        """
        targets = search_filter.targets
        meta_targets = search_filter.meta_targets
        if not targets and not meta_targets:
            return QueryInstance(lambda x: True, None)
        if targets and meta_targets:
            return (job.configuration.task_name.one_of(targets)
                    &
                    job.configuration.meta_targets.any(meta_targets))
        if targets:
            return job.configuration.task_name.one_of(targets)
        # Only meta targets were requested: filter on them rather than on an empty
        # target list, which would match no job at all.
        return job.configuration.meta_targets.any(meta_targets)
