import logging
import os.path
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import List, Callable, Dict, Optional

import yaml

from edps.client.FitsFile import FitsFile
from edps.config.configuration import Configuration
from edps.executor.renamer import PatternProductRenamer
from edps.generator.constants import FileCategory, TaskName, WorkflowName, WorkflowPath, ParameterSetName, ALL_TASKS
from edps.interfaces.JobsRepository import ExecutionResult


@dataclass
class PackageConfigEntry:
    """A single packaging rule: a target naming pattern plus the categories it applies to."""
    # Naming pattern; ConfigurableDatasetPackager hands it to PatternProductRenamer to build the target name.
    pattern: str
    # Category expressions; ConfigurableDatasetPackager wraps them in a CategoryMatcher, which applies
    # each one to a file's category with re.match.
    categories: List[FileCategory]


@dataclass
class TasksPackageConfig:
    """Packaging rules of one parameter set, keyed by task name."""
    task_configurations: Dict[TaskName, List[PackageConfigEntry]]

    def get(self, task_name: TaskName) -> List[PackageConfigEntry]:
        """Return the rules registered for ``task_name``.

        Falls back to the rules stored under ``ALL_TASKS`` and finally to an empty list.
        """
        # The fallback is resolved unconditionally, just as a default argument to dict.get() would be.
        fallback_rules = self.task_configurations.get(ALL_TASKS, [])
        if task_name in self.task_configurations:
            return self.task_configurations[task_name]
        return fallback_rules

    @classmethod
    def empty(cls) -> 'TasksPackageConfig':
        """Create a configuration that holds no rules for any task."""
        no_rules: Dict[TaskName, List[PackageConfigEntry]] = {}
        return cls(task_configurations=no_rules)


@dataclass
class ParameterSetsPackageConfig:
    """Packaging rules of one workflow, keyed by parameter set name."""
    parameter_set_configurations: Dict[ParameterSetName, TasksPackageConfig]
    # Name of the parameter set used when the caller does not supply one (may be None).
    default_set: Optional[ParameterSetName]

    def get(self, task_name: TaskName, parameter_set: Optional[ParameterSetName]) -> List[PackageConfigEntry]:
        """Return the rules of ``task_name`` within ``parameter_set``.

        A falsy ``parameter_set`` selects ``default_set``; an unknown set is treated as an empty configuration.
        """
        effective_set = parameter_set if parameter_set else self.default_set
        tasks_config = self.parameter_set_configurations.get(effective_set, TasksPackageConfig.empty())
        return tasks_config.get(task_name)

    @classmethod
    def empty(cls) -> 'ParameterSetsPackageConfig':
        """Create a configuration without parameter sets and without a default set."""
        return cls(default_set=None, parameter_set_configurations={})


@dataclass
class WorkflowsPackageConfig:
    """Packaging rules of all known workflows, keyed by workflow name."""
    workflow_configurations: Dict[WorkflowName, ParameterSetsPackageConfig]

    def get(self, workflow_names: List[WorkflowName], task_name: TaskName, parameter_set: Optional[ParameterSetName]) -> List[PackageConfigEntry]:
        """Return the first non-empty rule list found while scanning ``workflow_names`` in order.

        Workflows without a configuration count as empty; an empty list is returned when nothing matches.
        """
        # Lazy generator: later workflows are only consulted while the earlier ones produced no rules.
        rules_per_workflow = (
            self.workflow_configurations.get(name, ParameterSetsPackageConfig.empty()).get(task_name, parameter_set)
            for name in workflow_names
        )
        return next((rules for rules in rules_per_workflow if rules), [])


class CategoryMatcher:
    """Decides whether a file category is selected by any of a list of regular expressions."""

    def __init__(self, patterns: List[str]):
        """
        :param patterns: regular expressions applied to file categories with ``re.match``.
        """
        self.patterns = patterns

    def filter_matching(self, files: List[FitsFile]) -> List[FitsFile]:
        """Return the files whose ``category`` matches at least one pattern, preserving input order."""
        return [file for file in files if self.matches(file.category)]

    def matches(self, category: FileCategory) -> bool:
        """Return True if any pattern matches ``category``.

        ``re.match`` anchors only at the start of the string, so a pattern also selects every category it is
        a prefix of.
        """
        # A generator lets any() stop at the first hit instead of evaluating every pattern into a list first.
        return any(re.match(pattern, category) for pattern in self.patterns)


class DatasetPackager:
    """Base class of all packagers; subclasses decide what happens to the output files of an execution."""

    def package(self, result: ExecutionResult):
        """Package the files of ``result``; has to be overridden by subclasses."""
        raise NotImplementedError


class NopDatasetPackager(DatasetPackager):
    """Packager that deliberately does nothing."""

    def package(self, result: ExecutionResult):
        """Ignore ``result``; no file is touched."""
        return None


class ConfigurableDatasetPackager(DatasetPackager):
    """Packages the output files of an execution below a base directory using configurable naming patterns.

    Every configuration entry pairs a list of category expressions with a naming pattern. Each output file
    whose category matches an entry is passed to ``file_action`` together with the target path derived from
    that pattern.
    """

    def __init__(self, package_base_dir: str, category_configurations: List[PackageConfigEntry], task: TaskName, dataset_names: List[str],
                 file_action: Callable[[str, str], None], submission_date: str, skip_conflicting: bool):
        """
        :param package_base_dir: root directory of the package; stored as an absolute path.
        :param category_configurations: rules selecting the files to package and the pattern used to name them.
        :param task: name of the task, forwarded to the renamer.
        :param dataset_names: names of the datasets the job belongs to.
        :param file_action: callable invoked as ``file_action(source, target)`` to place a file in the package.
        :param submission_date: submission date, forwarded to the renamer.
        :param skip_conflicting: if True an already existing target is left untouched, otherwise a numeric
            suffix is appended to the target name until an unused name is found.
        """
        # Each matcher is paired with the naming pattern of the configuration entry it was built from.
        self.category_configurations = {CategoryMatcher(configuration.categories): configuration.pattern for configuration in category_configurations}
        self.base_dir = os.path.abspath(package_base_dir)
        self.task_name = task
        self.dataset_names = dataset_names
        self.logger = logging.getLogger('ConfigurableDatasetPackager')
        self.file_action = file_action
        self.submission_date = submission_date
        self.skip_conflicting = skip_conflicting

    def package(self, result: ExecutionResult):
        """Package all matching output files of ``result``.

        Packaging is best effort: any exception is logged and swallowed, which also ends the packaging of the
        remaining files of this result.
        """
        try:
            # FIXME: reports? logs?
            for matcher, pattern in self.category_configurations.items():
                files_to_package = matcher.filter_matching(result.output_files)
                self.package_files(pattern, files_to_package, self.submission_date)
        except Exception as e:
            self.logger.error("Building dataset package failed due to " + str(e), exc_info=e)

    def package_files(self, pattern: str, output_files: List[FitsFile], submission_date: str):
        """Add every file of ``output_files`` to the package, once per dataset if the pattern requires it.

        Files are skipped with a warning when the packager has no dataset names at all.
        """
        # job can belong to multiple datasets
        # if DATASET placeholder is not used in the pattern then there is no point in packaging the same file many times
        dataset_names = self.dataset_names if '$DATASET' in pattern else ['']
        for output_file in output_files:
            if not self.dataset_names:
                self.logger.warning("File %s not associated with any dataset name.", output_file)
                continue
            for dataset in dataset_names:
                renamer = PatternProductRenamer(pattern, dataset, submission_date, self.task_name)
                new_name = renamer.rename_file(output_file).name
                self.add_file_to_package(output_file, new_name)

    def add_file_to_package(self, output_file: FitsFile, new_name: str):
        """Place ``output_file`` in the package under ``new_name`` (relative to the base directory).

        If the target already exists the file is either skipped (``skip_conflicting``) or stored under the
        first unused name of the form ``<name>_<counter><ext>``.
        """
        # Remove every leading separator: os.path.join() would discard base_dir for a name that is still absolute.
        relative_name = new_name.lstrip(os.path.sep)
        target_path = os.path.join(self.base_dir, relative_name)
        os.makedirs(os.path.dirname(target_path), exist_ok=True)
        name, ext = os.path.splitext(target_path)
        candidate = target_path
        counter = 0
        while True:
            # lexists() also reports dangling symlinks, which exists() would treat as a free name
            # although linking onto them fails.
            if os.path.lexists(candidate):
                if self.skip_conflicting:
                    self.logger.debug("Skipping packaging '%s' to '%s' because target filename already exists", output_file.name, target_path)
                    return
                candidate = name + "_" + str(counter) + ext
                counter += 1
                continue
            self.logger.debug("Packaging '%s' to '%s'", output_file.name, candidate)
            try:
                self.file_action(output_file.name, candidate)
            except FileExistsError:
                # The name was taken between the check and the action; re-evaluate it like any other conflict.
                # Re-raise if the target still does not exist, otherwise the loop would never terminate.
                if not os.path.lexists(candidate):
                    raise
                continue
            # FIXME: an action that overwrites existing targets (e.g. a plain copy) can still clobber a file
            #  created between the check and the action.
            return


class DatasetPackagerFactory:
    """Creates dataset packagers according to the application configuration and the workflow parameter files."""

    def __init__(self, workflows: Dict[WorkflowName, WorkflowPath], config: Configuration):
        """
        :param workflows: known workflows and their paths; used to locate the default parameter files.
        :param config: application configuration providing the packaging options.
        """
        self.package_mode = config.package_mode
        self.package_default_categories = config.package_default_categories
        self.default_configuration = [PackageConfigEntry(config.package_default_pattern, config.package_default_categories)]
        self.all_configuration = [PackageConfigEntry(config.package_default_pattern, [".*"])]
        # Unknown or missing package modes fall back to symlinking.
        self.file_action = {'copy': self.__copy, 'symlink': self.__symlink, 'link': self.__hardlink}.get(self.package_mode) or self.__symlink
        self.configurations = self.load_configurations(config.param_config_file, workflows)
        self.skip_conflicting = config.package_skip_conflicting

    def get_packager(self, workflow_names: List[WorkflowName], task: TaskName, parameter_set: Optional[ParameterSetName],
                     dataset_names: List[str], submission_date: str, package_base_dir: str, is_target: bool) -> DatasetPackager:
        """Create the packager for one task.

        The rules are chosen in this order: task specific rules from the loaded configuration, a catch-all rule
        for targets when no default categories are configured, the default rule otherwise. A packager that does
        nothing is returned when ``package_base_dir`` is empty.
        """
        if not self.has_packaging_enabled(package_base_dir):
            return NopDatasetPackager()
        rules = self.configurations.get(workflow_names, task, parameter_set)
        if not rules:
            rules = self.all_configuration if self.should_package_all_results(is_target) else self.default_configuration
        return ConfigurableDatasetPackager(package_base_dir, rules, task, dataset_names,
                                           self.file_action, submission_date, self.skip_conflicting)

    def should_package_all_results(self, is_target: bool) -> bool:
        """Return True for a target task when no default categories are configured."""
        return is_target and not self.package_default_categories

    @staticmethod
    def has_packaging_enabled(package_base_dir: str) -> bool:
        """Packaging is enabled whenever a non-empty package base directory is given."""
        return bool(package_base_dir)

    @staticmethod
    def __symlink(src: str, dst: str):
        os.symlink(src, dst)

    @staticmethod
    def __hardlink(src: str, dst: str):
        os.link(src, dst)

    @staticmethod
    def __copy(src: str, dst: str):
        shutil.copy(src, dst)

    @staticmethod
    def load_configurations(config_file: Optional[str], workflows: Dict[WorkflowName, WorkflowPath]) -> WorkflowsPackageConfig:
        """Load the packaging configuration of all workflows.

        ``config_file``, if it is an existing file, maps workflow names to lists of parameter file paths.
        Otherwise every workflow uses ``<workflow path>/<workflow dir name>_parameters.yaml``.
        """
        if config_file and os.path.isfile(config_file):
            with open(config_file) as fid:
                # safe_load() returns None for an empty document; treat that as "no workflows configured".
                config = yaml.safe_load(fid.read()) or {}
        else:
            config = {wkf: [f"{Path(path) / Path(path).name}_parameters.yaml"] for wkf, path in workflows.items()}
        loaded_configurations = {wkf: DatasetPackagerFactory.load_workflow_parameter_sets(paths) for wkf, paths in config.items()}
        return WorkflowsPackageConfig(workflow_configurations=loaded_configurations)

    @staticmethod
    def load_workflow_parameter_sets(paths: List[str]) -> ParameterSetsPackageConfig:
        """Load the parameter sets of one workflow from the given parameter files.

        Missing files are skipped with a warning. If several sets are flagged ``is_default`` the last one wins.

        :raises RuntimeError: if the same parameter set name occurs more than once.
        """
        result = {}
        default_set_name = None
        # A workflow listed without any parameter files (null in YAML) simply has no parameter sets.
        for parameter_file_path in paths or []:
            if not os.path.exists(parameter_file_path):
                logging.warning(f"Skipping loading of packaging configuration from {parameter_file_path} because it's missing")
                continue
            with open(parameter_file_path) as file:
                # An empty parameter file is parsed to None; it contributes no parameter sets.
                config = yaml.safe_load(file.read()) or {}
            for set_name, set_data in config.items():
                if set_name in result:
                    raise RuntimeError(
                        f"Parameter set name '{set_name}' is duplicated, second occurrence in '{parameter_file_path}'")
                # A parameter set declared without content is parsed to None.
                set_data = set_data or {}
                result[set_name] = DatasetPackagerFactory.parse_set(set_data)
                if set_data.get('is_default', False):
                    default_set_name = set_name
        return ParameterSetsPackageConfig(parameter_set_configurations=result, default_set=default_set_name)

    @staticmethod
    def parse_set(set_data: dict) -> TasksPackageConfig:
        """Convert the ``packaging_config`` section of a parameter set into per-task packaging rules.

        The section maps task names to lists of single ``{pattern: categories}`` mappings. A missing or empty
        section, or a task without entries, yields no rules.
        """
        tasks_configurations = {}
        packaging_config = (set_data or {}).get('packaging_config') or {}
        for task_name, patterns in packaging_config.items():
            tasks_configurations[task_name] = [
                PackageConfigEntry(pattern=pattern, categories=categories)
                for entry in patterns or []
                for pattern, categories in entry.items()
            ]
        return TasksPackageConfig(task_configurations=tasks_configurations)
