import logging
import os.path
from collections import deque
from threading import Event, RLock
from typing import List

from edps.executor.constants import ESOREX_STDOUT, ESOREX_STDERR, OMP_NUM_THREADS
from edps.interfaces.JobScheduler import JobScheduler
from edps.utils import run_command


class Allocation:
    """A queued request for ``k`` shares that is signalled once the shares are granted."""

    def __init__(self, k: int):
        # Signalled by whoever grants the shares; waiters block on it.
        self.event = Event()
        # Number of shares this request asks for.
        self.k = k

    def wait(self):
        """Block until :meth:`set` has been called and return ``Event.wait``'s result."""
        granted = self.event.wait()
        return granted

    def set(self):
        """Mark the request as granted, waking any thread blocked in :meth:`wait`."""
        signal = self.event
        return signal.set()


class MultiAcquisitionOrderedSemaphore:
    def __init__(self, count: int):
        self.count = count
        self.mutex = RLock()
        self.waiting = deque()
        self.logger = logging.getLogger('MultiAcquisitionOrderedSemaphore')

    def acquire(self, k: int) -> 'MultiAcquisitionOrderedSemaphore.SemaphoreResource':
        return self.SemaphoreResource(self, k)

    def _acquire(self, k: int):
        allocation = Allocation(k)
        with self.mutex:
            self.logger.debug(f"Attempting to acquire {k} shares, put into waiting queue.")
            self.waiting.append(allocation)
            self.notify_waiting()
        allocation.wait()

    def _release(self, k: int):
        with self.mutex:
            self.count += k
            self.logger.debug(f"Releasing {k} shares")
            self.notify_waiting()

    def notify_waiting(self):
        self.logger.debug("Waking up threads.")
        while self.waiting and self.waiting[0].k <= self.count:
            allocation = self.waiting.popleft()
            self.logger.debug(f"Shares available {self.count} >= {allocation.k}, allocation successful")
            self.count -= allocation.k
            allocation.set()

    class SemaphoreResource:
        def __init__(self, semaphore: 'MultiAcquisitionOrderedSemaphore', k: int):
            self.semaphore = semaphore
            self.k = k

        def __enter__(self):
            self.semaphore._acquire(self.k)

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.semaphore._release(self.k)


class LocalJobScheduler(JobScheduler):
    """Run recipe commands locally while bounding the total number of cores reserved.

    Each job reserves cores from a shared FIFO semaphore for the duration of its run and is
    told how many it may use through the ``OMP_NUM_THREADS`` environment variable.
    """

    def __init__(self, cores: int):
        self.logger = logging.getLogger('LocalJobScheduler')
        # Upper bound on the cores any single job may reserve.
        self.total_cores = cores
        self.semaphore = MultiAcquisitionOrderedSemaphore(cores)

    def execute(self, command: List[str], directory: str, recipe_requested_cores: int) -> int:
        """Run ``command`` in ``directory`` once enough cores are free.

        Standard output and error are written to the ESOREX_STDOUT / ESOREX_STDERR files
        inside ``directory``. Errors while opening those files are not caught here.

        :param command: command line to run.
        :param directory: working directory, also where the output files are created.
        :param recipe_requested_cores: cores the recipe asks for; raised to at least 1 and
            capped at ``total_cores``.
        :return: the value returned by ``run_command`` (presumably the exit code — confirm),
            or -1 if running the command raised an exception.
        """
        # A request of 0 or less would export a non-positive OMP_NUM_THREADS and reserve no
        # cores (a negative request would even enlarge the free pool), letting the job
        # escape the core limit. The outer min() keeps total_cores as the hard cap.
        cores_to_allocate = min(self.total_cores, max(1, recipe_requested_cores))
        with open(os.path.join(directory, ESOREX_STDOUT), 'wb') as stdout_file:
            with open(os.path.join(directory, ESOREX_STDERR), 'wb') as stderr_file:
                env = os.environ.copy()
                env[OMP_NUM_THREADS] = str(cores_to_allocate)
                try:
                    with self.semaphore.acquire(cores_to_allocate):
                        return run_command(command, directory, stdout_file, stderr_file, env)
                except Exception as e:
                    self.logger.warning("Executing command '%s' in directory '%s' failed with %s", command, directory,
                                        e, exc_info=e)
                    return -1
