import functools
import itertools
import logging
import os.path
from collections import defaultdict
from dataclasses import dataclass
from typing import Set, Dict, List, Optional
from uuid import UUID, uuid4

from edps.client.FitsFile import FitsFile
from edps.client.WorkflowStateDTO import DataSourceStateDTO, TaskStateDTO
from . import job_factory
from .assoc_util import ReferenceFile
from .constants import MJD_OBS, INF, ONLY_PRODUCTS_AS_MAIN_INPUT
from .fits import ClassifiedFitsFile
from .grouping import Groups
from .job import Job
from .parameters import Parameters, ParametersProvider, JobParameters
from .task import DataSource, Task, TaskBase, AssociatedInput


@dataclass()
class CleanupRequestArgs:
    """Arguments passed to ``cleanup`` of stateful data sources and tasks.

    ``keep_files`` and ``keep_jobs`` control whether the request's files and
    completed jobs are retained as history before the request's own state is
    dropped; both default to ``True``.
    """
    request_id: UUID  # request whose state is being released
    keep_files: bool = True  # retain the request's files as history
    keep_jobs: bool = True  # retain the request's completed jobs as history


class StatefulTaskBase:
    """Common base of the stateful wrappers around a ``TaskBase``.

    Exposes the wrapped node's name, identifier and flags, and provides the
    MJD-OBS sorting and grouper-based grouping helpers used by subclasses.
    Request-specific operations are left to subclasses.
    """

    def __init__(self, task_base: TaskBase):
        self.task_base = task_base
        self.name = task_base.name

    @property
    def identifier(self) -> UUID:
        """Identifier of the wrapped task base."""
        return self.task_base.identifier

    def is_associated_input(self):
        """Return the wrapped task base's ``is_assoc_input`` flag."""
        return self.task_base.is_assoc_input

    def cleanup(self, args: CleanupRequestArgs):
        """Release the state held for ``args.request_id``. Must be overridden."""
        raise NotImplementedError

    def create_jobs_for_task(self, request_id: UUID, task: 'StatefulTask', parameters: Parameters,
                             meta_targets: Optional[List[str]]) -> List[Job]:
        """Create jobs of ``task`` fed by this node for the given request. Must be overridden."""
        raise NotImplementedError

    def as_associated_input_edge(self, task: 'StatefulTask', details: AssociatedInput):
        """Build the edge object connecting this node to ``task`` as associated input. Must be overridden."""
        raise NotImplementedError

    def should_apply_breakpoints(self) -> bool:
        """Return the wrapped task base's ``should_apply_breakpoints`` flag."""
        return self.task_base.should_apply_breakpoints

    @staticmethod
    def sort_jobs_by_mjdobs(jobs: List[Job]) -> List[Job]:
        """Return a new list of ``jobs`` ordered by the MJD-OBS keyword of each job's exemplar.

        Jobs whose exemplar lacks the keyword use ``INF`` as their sort value.
        """
        def exemplar_mjd_obs(job: Job):
            return job.exemplar.get_keyword_value(MJD_OBS, INF)

        return sorted(jobs, key=exemplar_mjd_obs)

    @staticmethod
    def sort_files_by_mjdobs(files: List[ClassifiedFitsFile]) -> List[ClassifiedFitsFile]:
        """Return a new list of ``files`` ordered by their MJD-OBS keyword (``INF`` when missing)."""
        def file_mjd_obs(fits_file: ClassifiedFitsFile):
            return fits_file.get_keyword_value(MJD_OBS, INF)

        return sorted(files, key=file_mjd_obs)

    def group_files(self, files: List[ClassifiedFitsFile], parameters: Parameters) -> Dict[str, List[ClassifiedFitsFile]]:
        """Split ``files`` into keyed groups using the task base's groupers.

        Grouping starts from a single group under the key ``""``, and each grouper
        in turn re-splits the current groups. Empty groups are dropped, and every
        remaining group is sorted by MJD-OBS.
        """
        current = {"": files}
        for grouper in self.task_base.groupers:
            current = grouper.split(current, parameters)
        result = {}
        for key, members in current.items():
            if len(members) > 0:
                result[key] = self.sort_files_by_mjdobs(members)
        return result


@dataclass
class DataSourceRequestState:
    """Files and groups a data source produced for one request."""

    # Every file accepted by the data source for the request.
    files: Set[ClassifiedFitsFile]
    # Groups that reached the data source's minimum group size.
    groups: Groups
    # Groups that stayed below the minimum group size.
    incomplete_groups: Groups

    def __str__(self):
        n_files = len(self.files)
        n_groups = len(self.groups)
        n_incomplete = len(self.incomplete_groups)
        return f"{n_files} files, {n_groups} groups, {n_incomplete} incomplete groups"

    @classmethod
    def empty(cls) -> 'DataSourceRequestState':
        """Return a state holding no files and no groups."""
        return cls(set(), {}, {})


class StatefulDataSource(StatefulTaskBase):
    """Per-request and historical file state of a single data source.

    For every request the data source keeps the files it accepted, meaning those
    whose classification rule belongs to the data source. It also keeps the
    groups built from them, split into complete and incomplete groups by
    ``min_group_size``. Files can additionally be retained as "historical"
    files that every later request sees via ``all_files``.
    """

    def __init__(self, data_source: DataSource):
        super().__init__(data_source)
        self.data_source = data_source
        self.logger = logging.getLogger('StatefulDataSource')
        # Per-request state keyed by request id; removed again in cleanup().
        self.request_state: Dict[UUID, DataSourceRequestState] = dict()
        # Files retained across requests (filled by cleanup() and reload_files()).
        self.historical_files: Set[ClassifiedFitsFile] = set()

    def reload_files(self, files: List[ClassifiedFitsFile], parameters_provider: ParametersProvider):
        """Add the files from ``files`` accepted by this data source to the historical files.

        The files are interpreted under a throw-away request id whose state is
        discarded afterwards. Historical files whose path no longer exists on
        disk are dropped at the same time.
        """
        request_id = uuid4()
        self.interpret(request_id, files, parameters_provider)
        new_files = self.request_state.pop(request_id, DataSourceRequestState.empty()).files
        # Forget files that were removed from disk since they were recorded.
        existing_historical_files = {f for f in self.historical_files if os.path.exists(f.get_path())}
        self.historical_files = existing_historical_files.union(new_files)

    def interpret(self, request_id: UUID, files: List[ClassifiedFitsFile], parameters_provider: ParametersProvider):
        """Select the files of ``files`` belonging to this data source and store their grouping for ``request_id``."""
        parameters = parameters_provider.get_parameters(str(self), self.data_source.workflow_names, request_id)
        self.logger.debug("[%s %s] interpret", self, request_id)
        new_files = {file for file in files if file.classification_rule_id in self.data_source.classification_rules_ids}
        self.request_state[request_id] = self.create_request_state(new_files, parameters)
        self.logger.debug("[%s %s] %s", self, request_id, self.request_state[request_id])

    def create_request_state(self, new_files: Set[ClassifiedFitsFile],
                             parameters: Parameters) -> DataSourceRequestState:
        """Group ``new_files`` and split the groups by ``min_group_size`` into complete and incomplete ones."""
        if not new_files:
            return DataSourceRequestState(new_files, {}, {})
        groups = self.group_files(list(new_files), parameters)
        min_size = self.task_base.min_group_size
        complete_groups = {key: files for key, files in groups.items() if len(files) >= min_size}
        incomplete_groups = {key: files for key, files in groups.items() if len(files) < min_size}
        return DataSourceRequestState(new_files, complete_groups, incomplete_groups)

    def cleanup(self, args: CleanupRequestArgs):
        """Drop the state of ``args.request_id``.

        If this data source is an associated input and ``args.keep_files`` is set,
        the request's files are first added to the historical files.
        """
        self.logger.debug("[%s %s] cleanup", self, args.request_id)
        state = self.request_state.pop(args.request_id, None)
        if state is not None and self.is_associated_input() and args.keep_files:
            self.historical_files.update(state.files)

    def get_incomplete_groups(self, request_id: UUID) -> List[List[ClassifiedFitsFile]]:
        """Return the groups of the request that are smaller than ``min_group_size``."""
        return list(self.request_state[request_id].incomplete_groups.values())

    def create_jobs_for_task(self, request_id: UUID, task: 'StatefulTask', parameters: Parameters,
                             meta_targets: Optional[List[str]]) -> List[Job]:
        """Create one job of ``task`` per complete group of the request."""
        return [task.create_job_with_input_files(files, parameters, meta_targets) for files in
                self.request_state[request_id].groups.values()]

    def get_associated_files(self, request_id: UUID, ref_file: ReferenceFile, parameters: JobParameters) -> List[ClassifiedFitsFile]:
        """Return the historical or request files that ``ref_file`` considers associated."""
        return [file for file in self.all_files(request_id) if ref_file.is_associated(file, parameters)]

    def all_files(self, request_id: UUID) -> Set[ClassifiedFitsFile]:
        """Return a new set with the historical files plus the request's own files."""
        return self.historical_files.union(self.request_state[request_id].files)

    def as_associated_input_edge(self, task: 'StatefulTask', details: AssociatedInput):
        from .edge import AssociatedInputDataSource
        return AssociatedInputDataSource(self, task, details)

    def as_dto(self) -> DataSourceStateDTO:
        """Describe the data source's historical files as a DTO."""
        return DataSourceStateDTO(
            name=self.name,
            files=[FitsFile(name=x.get_path(), category=x.classification) for x in self.historical_files]
        )

    def __str__(self):
        return f'DataSource {self.name} {self.identifier}'

    def __repr__(self):
        return str(self)


# Calibration DB: maps a file classification (category) to the set of files with that classification.
CalibDB = Dict[str, Set[ClassifiedFitsFile]]


class StatefulTask(StatefulTaskBase):
    """Per-request and historical job/product state of a single task.

    State kept per instance:

    * ``request_files``: per request id, the task products found among the
      request's files (see ``interpret``).
    * ``request_jobs``: per request id, the jobs created for this task.
    * ``historical_jobs``: completed jobs retained from finished requests. Each
      job maps to itself, so an equal job can be swapped for the stored instance.
    * ``calib_db``: task products retained across requests, keyed by classification.
    """

    def __init__(self, task: Task, associate_existing_jobs):
        super().__init__(task)
        self.task = task
        self.task_id = task.task_id
        self.main_input = task.main_input
        self.associated_input_groups = task.associated_input_groups
        self.job_editing_function = task.job_editing_function
        # When true, historical jobs are also candidates for association (see associatable_jobs).
        self.associate_existing_jobs = associate_existing_jobs
        self.logger = logging.getLogger('StatefulTask')
        self.request_files: Dict[UUID, CalibDB] = dict()
        # defaultdict so add_jobs can append for a new request; read-only accessors use .get().
        self.request_jobs: Dict[UUID, List[Job]] = defaultdict(list)
        self.historical_jobs: Dict[Job, Job] = {}
        self.calib_db: CalibDB = {}

    def log_status(self, request_id: UUID):
        """Log, at debug level, the job counts for ``request_id`` and the calib DB summary."""
        self.logger.debug("[%s %s] %d complete jobs, %d incomplete jobs, %d historical jobs, %s files",
                          self, request_id, len(self.current_request_complete_jobs(request_id)),
                          len(self.current_request_incomplete_jobs(request_id)), len(self.historical_jobs),
                          self.calib_db_for_log())

    def calib_db_for_log(self) -> str:
        """Return a ``[(category, count), ...]`` summary of ``calib_db``."""
        return str([(k, len(v)) for k, v in self.calib_db.items()])

    def interpret(self, request_id: UUID, files: List[ClassifiedFitsFile]):
        """Record the task products among ``files`` as the files of ``request_id`` (``calib_db`` not included)."""
        self.request_files[request_id] = self.create_calib_db(files, keep_history=False)

    def update_calib_db(self, files: List[ClassifiedFitsFile]):
        """Add the task products among ``files`` to ``calib_db``."""
        self.calib_db = self.create_calib_db(files, keep_history=True)

    def recreate_calib_db(self, files: List[ClassifiedFitsFile]):
        """Replace ``calib_db`` with the task products among ``files``."""
        self.calib_db = self.create_calib_db(files, keep_history=False)

    def create_calib_db(self, files: List[ClassifiedFitsFile], keep_history: bool) -> CalibDB:
        """Build a new calibration DB from the task products among ``files``.

        :param files: candidate files; only those for which ``self.task.is_task_product``
            is true are added, keyed by their classification.
        :param keep_history: if true, start from a copy of the current ``calib_db``
            instead of an empty DB.
        :return: a new mapping; ``self.calib_db`` itself is never modified here.
        """
        calib_db = defaultdict(set)
        if keep_history:
            # Copy every per-category set. A shallow dict copy would share the sets
            # with self.calib_db, and the .add() below would mutate it in place.
            for category, category_files in self.calib_db.items():
                calib_db[category] = set(category_files)
        accepted = 0
        for file in files:
            if self.task.is_task_product(file):
                calib_db[file.classification].add(file)
                accepted += 1
        self.logger.debug("Task %s loaded %d MASTER file(s) from supplied %d files, result state: %s", self.task.name,
                          accepted, len(files), str([(k, len(v)) for k, v in calib_db.items()]))
        return dict(calib_db)

    def replace_with_equivalent_historical_jobs(self, jobs: List[Job]) -> List[Job]:
        """Replace each job that has an equal historical job with that stored instance."""
        return [self.historical_jobs.get(job, job) for job in jobs]

    def cleanup(self, args: CleanupRequestArgs):
        """Release the state of ``args.request_id``.

        * With ``args.keep_jobs``, the request's complete jobs with a truthy ``future``
          are added to ``historical_jobs``.
        * With ``args.keep_files``, the request's task products are merged into ``calib_db``.
        * The per-request files and jobs are dropped in every case.
        """
        self.logger.debug("[%s %s] cleanup", self, args.request_id)
        if args.keep_jobs:
            jobs_to_keep = {job: job for job in self.request_jobs.get(args.request_id, []) if job.is_complete and job.future}
            # FIXME: race condition if we get two requests, can be an issue if identical jobs gets created by both requests
            self.historical_jobs.update(jobs_to_keep)
        if args.keep_files:
            current_request_files = self.request_files.get(args.request_id, {})
            # Materialize the union of categories before the loop reassigns calib_db entries.
            all_categories = self.calib_db.keys() | current_request_files.keys()
            for category in all_categories:
                self.calib_db[category] = self.calib_db.get(category, set()).union(current_request_files.get(category, set()))
        self.request_files.pop(args.request_id, None)
        self.request_jobs.pop(args.request_id, None)

    def get_associated_jobs(self, request_id: UUID, ref_file: ReferenceFile, parameters: JobParameters,
                            only_complete: bool) -> List[Job]:
        """Return the associatable jobs whose product ``ref_file`` considers associated."""
        return [job for job in self.associatable_jobs(request_id, only_complete) if ref_file.is_associated(job.product, parameters)]

    def get_associated_files_by_category(self, request_id: UUID, category: str, ref_file: ReferenceFile,
                                         parameters: JobParameters) -> List[ClassifiedFitsFile]:
        """Return the ``calib_db`` and request files of ``category`` that ``ref_file`` considers associated."""
        files = self.calib_db.get(category, set()).union(self.request_files[request_id].get(category, set()))
        return [file for file in files if ref_file.is_associated(file, parameters)]

    def get_associated_files(self, request_id: UUID, categories: Set[str], ref_file: ReferenceFile,
                             parameters: JobParameters) -> Dict[str, List[ClassifiedFitsFile]]:
        """Return the associated files for each category in ``categories``."""
        return {category: self.get_associated_files_by_category(request_id, category, ref_file, parameters) for category
                in categories}

    def current_request_complete_jobs(self, request_id: UUID) -> List[Job]:
        """Return the request's jobs that are complete."""
        # .get() rather than indexing: indexing the defaultdict would insert an empty
        # entry for unknown or already cleaned-up requests (e.g. log_status after cleanup).
        return [job for job in self.request_jobs.get(request_id, []) if job.is_complete]

    def current_request_complete_groups(self, request_id: UUID, parameters: Parameters) -> List[List[Job]]:
        """Return the request's complete jobs grouped by ``group_jobs``."""
        return self.group_jobs(self.current_request_complete_jobs(request_id), parameters)

    def current_request_incomplete_groups(self, request_id: UUID, parameters: Parameters) -> List[List[Job]]:
        """Return the request's incomplete jobs grouped by ``group_jobs``."""
        return self.group_jobs(self.current_request_incomplete_jobs(request_id), parameters)

    def group_jobs(self, jobs: List[Job], parameters: Parameters) -> List[List[Job]]:
        """Group ``jobs`` by applying the task's groupers to their exemplars.

        Without groupers every job forms its own group. Otherwise only groups with at
        least ``min_group_size`` jobs are returned, each sorted by MJD-OBS.
        """
        if not self.task_base.groupers:
            return [[job] for job in jobs]
        # NOTE(review): jobs sharing an equal exemplar collapse to one entry here -- confirm that cannot happen.
        jobs_for_exemplar = {job.exemplar: job for job in jobs}
        groups = self.group_files(list(jobs_for_exemplar.keys()), parameters)
        groups = {key: [jobs_for_exemplar[exemplar] for exemplar in files] for key, files in groups.items()}
        return [self.sort_jobs_by_mjdobs(group) for group in groups.values() if len(group) >= self.task_base.min_group_size]

    def current_request_incomplete_jobs(self, request_id: UUID) -> List[Job]:
        """Return the request's jobs that are not complete."""
        return [job for job in self.request_jobs.get(request_id, []) if not job.is_complete]

    def associatable_jobs(self, request_id: UUID, only_complete: bool) -> Set[Job]:
        """Return the request's jobs, plus historical ones if ``associate_existing_jobs`` is set.

        With ``only_complete``, incomplete jobs are excluded.
        """
        request_jobs = self.request_jobs.get(request_id, [])
        association_pool = (request_jobs + list(self.historical_jobs)) if self.associate_existing_jobs else request_jobs
        return {job for job in association_pool if job.is_complete or not only_complete}

    def add_jobs(self, request_id: UUID, jobs: List[Job]):
        """Append ``jobs`` to the request's jobs, reusing equal historical jobs where available."""
        # FIXME: if we ever got two identical jobs from generator then this would cause issues, because we would not "merge" them into equivalent one here
        self.request_jobs[request_id].extend(self.replace_with_equivalent_historical_jobs(jobs))

    def create_jobs_for_task(self, request_id: UUID, task: 'StatefulTask', parameters: Parameters,
                             meta_targets: Optional[List[str]]) -> List[Job]:
        """Create jobs of ``task`` from this task's raw-data jobs and from its product files."""
        jobs_from_raw = self.create_jobs_from_raw_data(meta_targets, parameters, request_id, task)
        jobs_from_products = self.create_jobs_from_products(request_id, task, parameters, meta_targets)
        return jobs_from_raw + jobs_from_products

    def create_jobs_from_raw_data(self, meta_targets, parameters, request_id, task):
        """Create jobs of ``task`` fed by this task's job groups.

        Jobs built from incomplete groups are marked as incomplete. Returns an
        empty list when ``use_raw_data`` is false.
        """
        if self.use_raw_data(parameters, task):
            complete_jobs = [task.create_job_with_input_jobs(jobs, parameters, meta_targets) for jobs in
                             self.current_request_complete_groups(request_id, parameters)]
            incomplete_jobs = [task.create_job_with_input_jobs(jobs, parameters, meta_targets) for jobs in
                               self.current_request_incomplete_groups(request_id, parameters)]
            for job in incomplete_jobs:
                job.is_complete = False
            return complete_jobs + incomplete_jobs
        return []

    def use_raw_data(self, parameters, task):
        """Return False only if ``task`` accepts products and the workflow parameter
        ONLY_PRODUCTS_AS_MAIN_INPUT is exactly ``True``."""
        return not (task.can_accept_products() and parameters.get_workflow_param(ONLY_PRODUCTS_AS_MAIN_INPUT) is True)

    def create_job_with_input_files(self, files: List[ClassifiedFitsFile], parameters: Parameters,
                                    meta_targets: Optional[List[str]]) -> Job:
        """Create a job of this task with ``files`` as input files."""
        return job_factory.create_job(task_details=self.task.as_task_details(), input_files=files,
                                      parameters=parameters, meta_targets=meta_targets)

    def create_job_with_input_jobs(self, jobs: List[Job], parameters: Parameters,
                                   meta_targets: Optional[List[str]]) -> Job:
        """Create a job of this task with ``jobs`` as input jobs."""
        return job_factory.create_job(task_details=self.task.as_task_details(), input_jobs=jobs, parameters=parameters,
                                      meta_targets=meta_targets)

    def create_jobs_from_products(self, request_id: UUID, task: 'StatefulTask',
                                  parameters: Parameters, meta_targets: Optional[List[str]]) -> List[Job]:
        """Create one job of ``task`` per group of this request's product files.

        Only files whose classification ``task`` accepts are used.
        """
        relevant_classifications = {rule.classification for rule in task.task.accepted_classification_rules}
        request_files_by_classification = self.request_files[request_id]
        relevant_files = set().union(*(request_files_by_classification.get(c, set()) for c in relevant_classifications))
        groups = self.group_files(list(relevant_files), parameters).values()
        return [task.create_job_with_input_files(files, parameters, meta_targets) for files in groups]

    def as_associated_input_edge(self, task: 'StatefulTask', details: AssociatedInput):
        from .edge import AssociatedInputTask
        return AssociatedInputTask(self, task, details)

    def remove_jobs(self, jobs_to_remove: List[UUID]):
        """Drop the historical jobs whose id is in ``jobs_to_remove``."""
        ids_to_remove = set(jobs_to_remove)
        self.historical_jobs = {job: job for job in self.historical_jobs if UUID(job.id) not in ids_to_remove}

    def get_historical_job(self, job_id: str) -> Optional[Job]:
        """Return the historical job with id ``job_id``, or ``None`` if there is none."""
        for job in self.historical_jobs.values():
            if job.id == job_id:
                return job
        return None

    def can_accept_products(self) -> bool:
        """Return True if the task has at least one accepted classification rule."""
        return len(self.task.accepted_classification_rules) > 0

    def as_dto(self) -> TaskStateDTO:
        """Describe the task's ``calib_db`` files and historical job ids as a DTO."""
        files = list(itertools.chain.from_iterable(self.calib_db.values()))
        return TaskStateDTO(name=self.name,
                            files=[FitsFile(name=x.get_path(), category=x.classification) for x in files],
                            jobs=[x.id for x in self.historical_jobs])

    def __str__(self):
        return f'Task {self.name} {self.identifier}'

    def __repr__(self):
        return str(self)
