import logging
import os
import uuid
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Dict, Tuple, Set
from uuid import UUID

import yaml

from edps.client.CalselectorRequest import CalselectorRequest
from edps.client.FitsFile import FitsFile as ClassifiedFile
from edps.client.GraphType import GraphType
from edps.client.JobInfo import JobInfo
from edps.client.ParameterSetDTO import ParameterSetDTO
from edps.client.ProcessingRequest import ProcessingRequest
from edps.client.RequestParameters import RequestParameters
from edps.client.WorkflowDTO import WorkflowDTO
from edps.client.WorkflowStateDTO import WorkflowStateDTO
from edps.client.search import SearchFilter
from edps.config.configuration import Configuration, AppConfig
from edps.generator.constants import META_WORKFLOW_NAME, WorkflowName, WorkflowPath
from edps.interfaces.JobGenerator import JobGenerator, GenerationResult
from edps.interfaces.JobsRepository import JobNotFoundError, JobDetails
from edps.utils import merge_dicts
from .assoc_util import AssociationArgs, AssociationPreference, AssociationBreakpoints, RequestType
from .classifier import Classifier
from .constants import ARCFILE
from .fits import FitsFileFactory, FitsFile
from .input_files import InputFilesResolver
from .job import Job
from .parameters import ParameterSets, ParametersProvider
from .stateful_task import CleanupRequestArgs
from .workflow_manager import WorkflowManager


class LocalJobGenerator(JobGenerator):
    """``JobGenerator`` that builds and inspects jobs in-process through a ``WorkflowManager``.

    Construction eagerly loads association breakpoints, parameter sets, static calibrations
    and the configured calibrations. Creating an instance therefore performs file I/O, and
    network I/O too when the breakpoints URL is an ``http`` URL.
    """

    def __init__(self, config: Configuration, fits_file_factory: FitsFileFactory,
                 workflows: Dict[WorkflowName, WorkflowPath], meta_workflow: List[str]):
        """
        :param config: application configuration.
        :param fits_file_factory: factory used to parse input FITS files.
        :param workflows: workflow name -> workflow directory. Used to locate parameter files
            and static calibrations.
        :param meta_workflow: names of the workflows whose calibrations are also loaded into
            the meta workflow.
        """
        self.logger = logging.getLogger('LocalJobGenerator')
        self.workflow_manager = WorkflowManager(config)
        self.fits_factory = fits_file_factory
        self.workflows = workflows
        self.meta_workflow = meta_workflow
        self.association_preference = AssociationPreference.from_str(config.assoc_preference)
        self.logger.info("association_preference=%s", self.association_preference)
        self.association_breakpoints = self.load_association_breakpoints(config)
        self.parameter_sets: ParameterSets = self.load_parameters(config.param_config_file, self.workflows)
        # Static calibrations are loaded first, then the ones listed in the calibration config.
        self.load_static_calibrations()
        self.calibrations_config = self.load_calibrations_config(config)
        self.load_calibrations("all")
        self.esorex_path = config.esorex_path

    @staticmethod
    def has_invalid_targets(targets: List[str], meta_targets: List[str], resolved_targets: List[str]) -> bool:
        """Return True when targets or meta-targets were requested but none could be resolved."""
        return bool(targets or meta_targets) and not resolved_targets

    def classify_files(self, paths: List[str], workflow_name: str) -> List[ClassifiedFile]:
        """Parse and classify the files under ``paths`` using the rules of ``workflow_name``.

        :return: the classified files converted to client DTOs.
        """
        workflow = self.workflow_manager.get_stateful_workflow(workflow_name)
        classified_files = self._parse_and_classify(workflow, self.resolve_inputs(paths))
        return [file.as_fits_file_dto() for file in classified_files]

    def create_jobs(self, request: ProcessingRequest, request_type: RequestType) -> Tuple[GenerationResult, List[str], UUID]:
        """Create the jobs for a processing request in the request's workflow.

        :raises FileNotFoundError: if any path in ``request.inputs`` does not exist, or if the
            requested targets/meta-targets cannot be resolved for the workflow.
        """
        nonexistent_files = [p for p in request.inputs if not os.path.exists(p)]
        if nonexistent_files:
            raise FileNotFoundError(f"the following files do not exist: {nonexistent_files}")

        request_id = uuid.uuid4()
        self.logger.debug("create_jobs request %s %s %s %s %s", request_id, request.inputs, request.workflow,
                          request.targets, request.meta_targets)
        workflow = self.workflow_manager.get_stateful_workflow(request.workflow)
        expanded_targets = self._resolve_targets(workflow, request)

        input_files = self.resolve_inputs(request.inputs, request.excluded_inputs)
        classified_files = self._parse_and_classify(workflow, input_files)
        for file in classified_files:
            self.logger.debug("%s %s", request_id, file)
        assoc_args = AssociationArgs(breakpoints=self.association_breakpoints,
                                     preference=self.resolve_assoc_preference(request),
                                     request_type=request_type)
        return workflow.create_jobs(request_id, classified_files, expanded_targets, assoc_args,
                                    ParametersProvider(self.parameter_sets, request.parameters), request.meta_targets)

    def cleanup_request(self, workflow_name: str, targets: List[str], cleanup_args: CleanupRequestArgs):
        """Delegate request cleanup for ``targets`` of ``workflow_name`` to the workflow manager."""
        self.workflow_manager.cleanup_request(workflow_name, targets, cleanup_args)

    def get_workflow(self, workflow_name: str) -> WorkflowDTO:
        """Return the workflow definition as a DTO."""
        return self.workflow_manager.get_workflow(workflow_name).as_dto()

    def get_workflow_state(self, workflow_name: str) -> WorkflowStateDTO:
        """Return the stateful workflow as a DTO."""
        return self.workflow_manager.get_stateful_workflow(workflow_name).as_dto()

    def get_workflow_jobs(self, workflow_name: str) -> List[JobInfo]:
        """Return every historical job of the workflow as ``JobInfo``."""
        jobs = self.workflow_manager.get_stateful_workflow(workflow_name).get_historical_jobs()
        return [job.as_job_info() for job in jobs]

    def get_workflow_job(self, workflow_name: str, job_id: str) -> JobInfo:
        """Return the historical job with ``job_id`` from the workflow.

        :raises JobNotFoundError: if no historical job has that id.
        """
        for job in self.get_workflow_jobs(workflow_name):
            if job.job_id == job_id:
                return job
        raise JobNotFoundError(f"job {job_id} not found")

    def get_graph(self, workflow_name: str, graph_type: GraphType) -> str:
        """Return the workflow graph rendered as ``graph_type``."""
        return self.workflow_manager.get_workflow(workflow_name).generate_graph(graph_type)

    def get_assoc_map(self, workflow_name: str) -> str:
        """Return the association map generated by the workflow."""
        return self.workflow_manager.get_workflow(workflow_name).generate_assoc_map()

    def reset_workflow(self, workflow_name: str):
        """Delegate resetting the workflow to the workflow manager."""
        self.workflow_manager.reset_workflow(workflow_name)

    def get_targets(self, workflow_name: str, targets: List[str], meta_targets: List[str]) -> str:
        """Resolve targets and meta-targets for the workflow.

        :return: the resolved targets joined with commas.
        """
        workflow = self.workflow_manager.get_workflow(workflow_name)
        resolved_targets = workflow.get_targets(targets, meta_targets)
        # Compute ancestors only when the debug message will actually be emitted.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug("targets %s ancestors %s", resolved_targets, workflow.get_ancestors(resolved_targets))
        return ','.join(resolved_targets)

    def update_calib_db(self, workflow_name: str, paths: List[str]):
        """Classify the files under ``paths`` and load them into the workflow's calibration DB."""
        self.logger.info('updating calibdb for workflow %s using %s', workflow_name, paths)
        workflow = self.workflow_manager.get_stateful_workflow(workflow_name)
        classified_files = self._parse_and_classify(workflow, self.resolve_inputs(paths))
        workflow.load_calib_db(classified_files, ParametersProvider(self.parameter_sets, RequestParameters()))

    def get_parameter_sets(self, workflow_name: str) -> List[ParameterSetDTO]:
        """Return the parameter sets available for the workflow."""
        return self.parameter_sets.get_parameter_sets(workflow_name)

    def get_default_params(self, workflow_name: str, task_name: str) -> Dict[str, object]:
        """Return the default recipe parameters of ``task_name``, obtained via the configured esorex."""
        workflow = self.workflow_manager.get_workflow(workflow_name)
        return workflow.get_default_params(self.esorex_path, task_name)

    def get_recipe_params(self, workflow_name: str, task_name: str, parameter_set: str) -> Dict[str, str]:
        """Return the task's default parameters merged with those of ``parameter_set``.

        Merging is done by ``merge_dicts(defaults, parameter_set_values)``. A task that is
        absent from the parameter set contributes an empty mapping.
        """
        workflow = self.workflow_manager.get_workflow(workflow_name)
        default_parameters = workflow.get_default_params(self.esorex_path, task_name)
        recipe_parameters = self.parameter_sets.get_recipe_parameters(workflow_name, parameter_set).get(task_name, {})
        return merge_dicts(default_parameters, recipe_parameters)

    def remove_jobs(self, jobs: List[UUID]):
        """Delegate removal of the given job ids to the workflow manager."""
        self.workflow_manager.remove_jobs(jobs)

    def get_association_report(self, job: JobInfo) -> str:
        """Build a text report describing how the inputs of ``job`` were associated.

        The report starts with the job id and its reference file. Each associated input then
        gets a separator, its association details and the match functions the job's task
        uses for that input.
        """
        separator = '-' * 50
        result = [f'Association report for job: {job.job_id}',
                  f'Reference file (job): {job.exemplar_as_str()}']
        task_interpreter = self.workflow_manager.get_interpreter_for_task(job.task_id)
        for details in job.association_details:
            result.append(separator)
            result.append(str(details))
            match_functions = task_interpreter.get_match_functions_for_associated_input(details.task_name)
            result.append('\n'.join(match_functions))
        return '\n'.join(result)

    def find_historical_job(self, job: JobInfo) -> Optional[Job]:
        """Return the first historical job matching ``job`` across ``job.workflow_names``, or None."""
        for workflow_name in job.workflow_names:
            workflow = self.workflow_manager.get_stateful_workflow(workflow_name)
            found_job = workflow.get_historical_job(job)
            if found_job:
                return found_job
        return None

    def load_static_calibrations(self):
        """Load each workflow's static calibrations into its calibration DB.

        The static directory is ``<workflow_path>/../../../datastatic/<pipeline_id>``, where
        ``pipeline_id`` is the name of the workflow path's parent directory.
        """
        for workflow_name, workflow_path in self.workflows.items():
            pipeline_dir = Path(workflow_path).parent
            static_calib_path = pipeline_dir.parent.parent / 'datastatic' / pipeline_dir.name
            self.update_calib_db(workflow_name, [str(static_calib_path)])

    def load_calibrations_config(self, config: Configuration) -> Dict:
        """Read the calibration config YAML (workflow name -> list of paths).

        :return: the parsed mapping. It is empty when no file is configured or the file is empty.
        """
        config_file = config.calib_config_file
        if not config_file:
            self.logger.info("Skipping loading of calibrations.")
            return {}
        with open(config_file, encoding='utf-8') as file:
            # safe_load returns None for an empty document; load_calibrations needs a mapping.
            return yaml.safe_load(file) or {}

    def load_calibrations(self, workflow_name: str):
        """Load the configured calibration paths into calibration DBs.

        ``"all"`` loads every configured workflow, then loads the meta workflow with the
        paths of those configured workflows that belong to it. ``META_WORKFLOW_NAME`` loads
        only the meta workflow and logs an error for each member that has no configuration.
        Any other name loads just that workflow, or logs an error if it is not configured.
        """
        if workflow_name == "all":
            meta_paths = []
            for wkf, paths in self.calibrations_config.items():
                self.update_calib_db(wkf, paths)
                if wkf in self.meta_workflow:
                    meta_paths.extend(paths)
            self.update_calib_db(META_WORKFLOW_NAME, meta_paths)
        elif workflow_name == META_WORKFLOW_NAME:
            meta_paths = []
            for wkf in self.meta_workflow:
                if wkf in self.calibrations_config:
                    meta_paths.extend(self.calibrations_config[wkf])
                else:
                    self.logger.error("CalibDB is not configured for workflow %s", wkf)
            self.update_calib_db(META_WORKFLOW_NAME, meta_paths)
        elif workflow_name in self.calibrations_config:
            self.update_calib_db(workflow_name, self.calibrations_config[workflow_name])
        else:
            self.logger.error("CalibDB is not configured for workflow %s", workflow_name)

    def get_parameter_files(self, workflow: str, path: str) -> List[str]:
        """Return the files in ``path`` whose names end in ``parameters.yaml``, sorted by path.

        Sorting makes the order independent of the filesystem's listing order. Logs a warning
        when no such file exists.
        """
        files = sorted(os.path.join(path, name) for name in os.listdir(path) if name.endswith('parameters.yaml'))
        if not files:
            self.logger.warning("No parameter files found for workflow %s", workflow)
        return files

    def load_parameters(self, config_file: str, workflows: Dict[WorkflowName, WorkflowPath]) -> ParameterSets:
        """Build the parameter sets.

        If ``config_file`` is an existing file, it is read as YAML. Otherwise each workflow
        directory's ``*parameters.yaml`` files are used. A warning is logged when a configured
        file is missing.
        """
        if config_file and os.path.isfile(config_file):
            with open(config_file, encoding='utf-8') as fid:
                # safe_load returns None for an empty document; keep the config a mapping.
                config = yaml.safe_load(fid) or {}
        else:
            if config_file:
                self.logger.warning("Parameter configuration file '%s' not found, using workflow parameter files",
                                    config_file)
            config = {wkf: self.get_parameter_files(wkf, path) for wkf, path in workflows.items()}
        return ParameterSets(config)

    def load_association_breakpoints(self, config: Configuration) -> AssociationBreakpoints:
        """Load association breakpoints from ``config.assoc_breakpoints_url``.

        - A value starting with ``http`` is downloaded and then cached in ``~/.edps``. Caching
          is best-effort.
        - Any other value is read as a local file.
        - If loading fails, the cached copy is used when present; otherwise no breakpoints
          are applied.
        - An empty URL also means no breakpoints.
        """
        breakpoints_url = config.assoc_breakpoints_url
        if not breakpoints_url:
            self.logger.warning("Breakpoints URL in the application.properties is set to empty, no association breakpoints will be applied.")
            return AssociationBreakpoints([])
        static_breakpoints_file = os.path.join(os.path.expanduser("~"), ".edps", AppConfig.BREAKPOINTS_CONFIG)
        try:
            if not breakpoints_url.startswith('http'):
                return AssociationBreakpoints.from_file(breakpoints_url)
            loaded_breakpoints = AssociationBreakpoints.from_url(breakpoints_url)
        except Exception as e:
            self.logger.warning("Failed to load association breakpoints from configured location: '%s' due to '%s'", breakpoints_url, e)
            return self._load_cached_breakpoints(static_breakpoints_file)
        # Saving is kept outside the load `try`: failing to write the local cache must not
        # discard breakpoints that were just downloaded successfully.
        try:
            loaded_breakpoints.save(static_breakpoints_file)
        except Exception as e:
            self.logger.warning("Failed to cache association breakpoints to '%s' due to '%s'",
                                static_breakpoints_file, e)
        return loaded_breakpoints

    def load_past_jobs(self, complete: List[JobDetails]):
        """Load completed jobs into every workflow listed in each job's configuration."""
        jobs_by_workflow: Dict[str, List[JobDetails]] = defaultdict(list)
        for job in complete:
            for workflow_name in job.configuration.workflow_names:
                jobs_by_workflow[workflow_name].append(job)
        for workflow_name, jobs in jobs_by_workflow.items():
            workflow = self.workflow_manager.get_stateful_workflow(workflow_name)
            self.logger.info("Loading %d past jobs into workflow %s", len(jobs), workflow_name)
            workflow.load_past_jobs(jobs, self.fits_factory)

    def resolve_assoc_preference(self, request: ProcessingRequest) -> AssociationPreference:
        """Return the request's association preference, falling back to the configured one."""
        return request.association_preference or self.association_preference

    def expand_targets(self, search_filter: SearchFilter) -> SearchFilter:
        """Delegate target expansion of ``search_filter`` to the workflow manager."""
        return self.workflow_manager.expand_targets(search_filter)

    @staticmethod
    def check_header_keywords(header: Dict, keywords: Set[str]):
        """Ensure every keyword in ``keywords`` is present in ``header``.

        :raises KeyError: listing the missing keywords, sorted so the message is deterministic.
        """
        missing_keywords = sorted(key for key in keywords if key not in header)
        if missing_keywords:
            raise KeyError(f"Missing keywords {missing_keywords}")

    def create_calselector_jobs(self, request: CalselectorRequest) -> Tuple[GenerationResult, List[str], UUID]:
        """Create jobs from virtual files built from the header dicts in ``request.inputs``.

        Association always uses ``AssociationPreference.MASTER`` and ``RequestType.CALSELECTOR``.

        :raises KeyError: if a header lacks one of the workflow's grouping keywords.
        :raises FileNotFoundError: if the requested targets/meta-targets cannot be resolved.
        """
        request_id = uuid.uuid4()
        self.logger.debug("create_calselector_jobs %s %s", request_id, request)
        workflow = self.workflow_manager.get_stateful_workflow(request.workflow)
        grouping_keywords = workflow.get_grouping_keywords()
        for hdr in request.inputs:
            self.check_header_keywords(hdr, grouping_keywords)

        expanded_targets = self._resolve_targets(workflow, request)

        files = [FitsFile(hdr[ARCFILE], hdr, virtual=True) for hdr in request.inputs]
        classified_files = Classifier.classify_files(files, workflow.get_classification_rules())
        for file in classified_files:
            self.logger.debug("%s %s", request_id, file)
        assoc_args = AssociationArgs(breakpoints=self.association_breakpoints,
                                     preference=AssociationPreference.MASTER,
                                     request_type=RequestType.CALSELECTOR)
        return workflow.create_jobs(request_id, classified_files, expanded_targets, assoc_args,
                                    ParametersProvider(self.parameter_sets, request.parameters),
                                    request.meta_targets)

    @staticmethod
    def resolve_inputs(input_files: List[str], excluded_inputs: Optional[List[str]] = None) -> List[Path]:
        """Resolve ``input_files`` to concrete paths via ``InputFilesResolver``, minus ``excluded_inputs``."""
        resolver = InputFilesResolver(input_files, excluded_inputs or [])
        return resolver.resolve_inputs()

    def _parse_and_classify(self, workflow, paths: List[Path]):
        """Parse ``paths`` and classify them with ``workflow``'s keywords and classification rules."""
        return Classifier.parse_and_classify(self.fits_factory, paths, workflow.get_keywords(),
                                             workflow.get_classification_rules())

    def _resolve_targets(self, workflow, request) -> List[str]:
        """Expand ``request.targets``/``request.meta_targets`` for ``workflow``.

        ``request`` is a processing or calselector request exposing ``targets``,
        ``meta_targets`` and ``workflow``.

        :raises FileNotFoundError: if targets were requested but none could be resolved.
        """
        expanded_targets = workflow.get_targets(request.targets, request.meta_targets)
        if self.has_invalid_targets(request.targets, request.meta_targets, expanded_targets):
            msg = f"could not resolve targets {request.targets} meta-targets {request.meta_targets} for workflow '{request.workflow}'"
            raise FileNotFoundError(msg)
        self.logger.debug("expanded targets %s %s %s", request.targets, request.meta_targets, expanded_targets)
        return expanded_targets

    def _load_cached_breakpoints(self, static_breakpoints_file: str) -> AssociationBreakpoints:
        """Return the locally cached breakpoints, or empty breakpoints if no cache file exists."""
        if os.path.exists(static_breakpoints_file):
            return AssociationBreakpoints.from_file(static_breakpoints_file)
        self.logger.warning(
            "Failed to load breakpoints from static file '%s'. Association breakpoints will not be applied.",
            static_breakpoints_file)
        return AssociationBreakpoints([])
