import logging
import os
import socket
import uuid
from datetime import datetime
from typing import List, Tuple, Dict, Optional
from uuid import UUID

import psutil as psutil

from edps.cleanup.Cleaner import Cleaner
from edps.client.CalselectorRequest import CalselectorRequest
from edps.client.FitsFile import FitsFile
from edps.client.FlatOrganization import DatasetsDTO, LabelledDatasetDTO
from edps.client.GraphType import GraphType
from edps.client.JobInfo import JobInfo
from edps.client.ParameterSetDTO import ParameterSetDTO
from edps.client.ProcessingJob import ProcessingJob
from edps.client.ProcessingJobStatus import ProcessingJobStatus, JobStatus
from edps.client.ProcessingRequest import ProcessingRequest
from edps.client.ProcessingResponse import ProcessingResponse
from edps.client.Rejection import Rejection
from edps.client.ReportsConfiguration import ReportsConfiguration, ALL_REPORTS
from edps.client.RunReportsRequestDTO import RunReportsRequestDTO
from edps.client.WorkflowDTO import CalselectorJob
from edps.client.WorkflowDTO import WorkflowDTO
from edps.client.WorkflowStateDTO import WorkflowStateDTO
from edps.client.search import SearchFilter
from edps.config.configuration import Configuration
from edps.executor.LocalJobExecutor import LocalJobExecutor
from edps.executor.breakpoint import BreakpointManager
from edps.executor.dataset_packager import DatasetPackagerFactory
from edps.executor.recipe import InvokerProvider
from edps.generator.LocalJobGenerator import LocalJobGenerator
from edps.generator.assoc_util import RequestType
from edps.generator.constants import WorkflowName, WorkflowPath, ParameterSetName
from edps.generator.fits import FitsFileFactory
from edps.generator.stateful_task import CleanupRequestArgs
from edps.interfaces.JobExecutor import JobExecutor
from edps.interfaces.JobGenerator import JobGenerator
from edps.interfaces.JobScheduler import JobScheduler
from edps.interfaces.JobsRepository import JobsRepository, ExecutionResult, JobNotFoundError
from edps.jobs.LocalJobsRepository import LocalJobsRepository
from edps.metrics.meter_registry import MeterRegistry
from edps.metrics.publisher import ElasticPublisher, NopPublisher
from edps.phase3.phase3 import Phase3Configuration, build_phase3_dataset
from edps.scheduler.LocalJobScheduler import LocalJobScheduler
from edps.utils import timer


class EDPS:
    """Facade that wires together job generation, execution, persistence and cleanup.

    An instance owns a jobs repository, a scheduler, an executor, a job generator, a breakpoint manager and a
    periodic cleaner, and exposes the operations built on top of them (processing, organisation, job queries,
    rejection/deletion, resubmission, shutdown).
    """
    # Meta-workflow definition stored on the class: overwritten from the configuration in ``__init__`` and read
    # by the static ``create_job_generator`` factory.
    meta_workflow: List[str] = []
    logger = logging.getLogger("EDPS")

    def __init__(self, config: Configuration, workflows: Dict[WorkflowName, WorkflowPath]):
        """Create all collaborators and run the startup cleanup.

        :param config: application configuration used to build every component.
        :param workflows: mapping of workflow name to workflow path.
        :raises Exception: any error raised while building the components or during startup cleanup is logged,
            followed by a call to ``shutdown()`` to release whatever was already created, and then re-raised.
        """
        EDPS.meta_workflow = config.meta_workflow
        publisher = NopPublisher()
        if config.metrics_enabled:
            publisher = ElasticPublisher(config.metrics_url, config.metrics_index, config.metrics_user,
                                         config.metrics_password)
        # The instance is not kept here; ``shutdown()`` later reaches it through ``MeterRegistry.instance``
        # (presumably the constructor registers itself there - confirm in MeterRegistry).
        MeterRegistry(publisher, {"hostname": socket.gethostname()}, config.metrics_interval)
        self.configuration = config
        self.workflows = workflows
        # Pre-set to None so that ``shutdown()`` can be called safely after a partial start-up.
        self.repository = None
        self.breakpoint_manager = None
        self.executor = None
        self.cleaner = None
        try:
            self.repository = self.create_jobs_repository(config)
            self.scheduler = self.create_jobs_scheduler(config)
            self.breakpoint_manager = BreakpointManager()
            fits_file_factory = FitsFileFactory()
            dataset_packager_factory = DatasetPackagerFactory(workflows, config)
            self.invoker_provider = InvokerProvider(esorex_path=config.esorex_path,
                                                    job_scheduler=self.scheduler,
                                                    default_omp_threads=config.default_omp_threads)
            self.executor = self.create_executor(config, self.repository, self.invoker_provider,
                                                 self.breakpoint_manager, dataset_packager_factory)
            self.generator = self.create_job_generator(config, fits_file_factory, self.workflows)
            self.cleaner = Cleaner(config, self.delete_job_cascade, self.repository)
            self.clean_startup(config.resume_on_startup)
        except Exception as e:
            self.logger.error("EDPS startup failed", exc_info=e)
            self.shutdown()
            raise

    def load_calibrations(self, workflow_name: str):
        """Ask the generator to load calibrations for ``workflow_name``."""
        self.generator.load_calibrations(workflow_name)

    @timer("edps_reduce_duration")
    def process_data(self, request: ProcessingRequest) -> ProcessingResponse:
        """Generate the jobs for a reduction request, submit them for execution and summarise the outcome.

        :param request: the processing request.
        :return: summaries of the created and incomplete jobs, plus the files and groups reported by the
            generator as having no jobs / being incomplete.
        """
        submission_date = datetime.now().isoformat()
        result, targets, request_id = self.generator.create_jobs(request, RequestType.REDUCTION)
        # Jobs flagged as "future" are not recorded as submitted.
        submitted_jobs = [job_info for job_info in result.jobs_as_info if not job_info.is_future()]
        self.repository.set_jobs_submitted(submitted_jobs, submission_date)
        jobs_to_break = [UUID(job.job_id) for job in result.jobs_as_info if job.task_name in request.breakpoint_tasks]
        self.breakpoint_manager.create_breakpoints(jobs_to_break)
        # FIXME: we're using workflow parameters set name to resolve packaging config
        self.execute_jobs(result.jobs_as_info, request.package_base_dir, request.callback, submission_date, targets,
                          request.renaming_prefix, request.reports_configuration,
                          request.parameters.workflow_parameter_set)
        cleanup_args = CleanupRequestArgs(request_id=request_id,
                                          keep_files=self.configuration.keep_request_files,
                                          keep_jobs=True)
        self.generator.cleanup_request(request.workflow, targets, cleanup_args)
        return ProcessingResponse(jobs=[job.to_summary_dto() for job in result.jobs_as_info],
                                  incomplete_jobs=[job.to_summary_dto() for job in result.incomplete_jobs],
                                  files_with_no_jobs=result.files_with_no_jobs,
                                  incomplete_groups=result.incomplete_groups)

    def classify_files(self, request: ProcessingRequest) -> List[FitsFile]:
        """Classify the request inputs against the request workflow using the generator."""
        return self.generator.classify_files(request.inputs, request.workflow)

    @timer("edps_organise_duration")
    def organise_data(self, request: ProcessingRequest) -> List[JobInfo]:
        """Generate jobs for an organisation request without executing them.

        :return: the complete jobs followed by the incomplete ones.
        """
        result, targets, request_id = self.generator.create_jobs(request, RequestType.ORGANIZATION)
        cleanup_args = CleanupRequestArgs(request_id=request_id,
                                          keep_files=self.configuration.keep_request_files,
                                          keep_jobs=False)
        self.generator.cleanup_request(request.workflow, targets, cleanup_args)
        return result.jobs_as_info + result.incomplete_jobs

    def organise_data_flatten(self, request: ProcessingRequest) -> DatasetsDTO:
        """Generate jobs for an organisation request and return them as labelled datasets.

        A dataset is built only for jobs that no other organised job lists among its input or associated
        job ids.
        """
        result, targets, request_id = self.generator.create_jobs(request, RequestType.ORGANIZATION)
        cleanup_args = CleanupRequestArgs(request_id=request_id,
                                          keep_files=self.configuration.keep_request_files,
                                          keep_jobs=False)
        self.generator.cleanup_request(request.workflow, targets, cleanup_args)
        organized_jobs = {job.job_id: job for job in result.jobs_as_info + result.incomplete_jobs}
        # Ids referenced by at least one job, collected in a single pass.
        dependent_jobs = {dependency
                          for job in organized_jobs.values()
                          for dependency in job.input_job_ids + job.associated_job_ids}
        data_sets = [LabelledDatasetDTO.from_job_info(job, organized_jobs) for job in organized_jobs.values() if
                     job.job_id not in dependent_jobs]
        return DatasetsDTO(datasets=data_sets)

    @timer("edps_reports_duration", level=logging.INFO)
    def run_reports(self, request: RunReportsRequestDTO) -> List[str]:
        """Delegate a run-reports request to the executor and return its result."""
        return self.executor.run_reports(request)

    @timer("edps_get_jobs_duration")
    def get_jobs(self, pattern: str, offset: int, limit: int) -> List[ProcessingJob]:
        """Return the repository's job list for ``pattern``, ``offset`` and ``limit`` as processing jobs."""
        return [job.to_processing_job() for job in self.repository.get_jobs_list(pattern, offset, limit)]

    @timer("edps_get_jobs_filter_duration")
    def get_jobs_filter(self, search_filter: SearchFilter) -> List[ProcessingJob]:
        """Return the jobs matching ``search_filter`` after the generator has expanded its targets."""
        search_filter = self.generator.expand_targets(search_filter)
        return [job.to_processing_job() for job in self.repository.get_jobs_list_filter(search_filter)]

    @timer("edps_get_job_details_duration")
    def get_job_details(self, job_id: UUID) -> ProcessingJob:
        """Return the repository's details of ``job_id`` converted to a processing job."""
        return self.repository.get_job_details(job_id).to_processing_job()

    def get_job_status(self, job_id: UUID) -> JobStatus:
        """Return the status of ``job_id`` derived from the repository's simple job details."""
        job_details = self.repository.get_job_details_simple(job_id)
        return job_details.to_job_status()

    @timer("edps_phase3_package_duration")
    def package_phase3(self, configuration: Phase3Configuration) -> List[str]:
        """Build a phase 3 dataset under the configured base directory."""
        return build_phase3_dataset(self.configuration.base_dir, configuration, self.repository)

    def get_job_report(self, job_id: UUID, panel_name: str) -> Tuple[str, bytes]:
        """Return the media type and content of the report panel named ``panel_name``.

        :raises FileNotFoundError: if none of the job's reports has a panel with that file name.
        """
        for report in self.repository.get_job_details_simple(job_id).result.reports:
            for panel in report.panels:
                if panel.file_name == panel_name:
                    report_path = os.path.join(report.report_name, panel.file_name)
                    return panel.media_type, self.executor.get_job_file(job_id, report_path)
        raise FileNotFoundError(panel_name)

    def get_job_log(self, job_id: UUID, log_name: str) -> bytes:
        """Return the content of the job log named ``log_name``.

        :raises FileNotFoundError: if the job has no log with that file name.
        """
        for log in self.repository.get_job_details_simple(job_id).result.logs:
            if log.file_name == log_name:
                return self.executor.get_job_file(job_id, log_name)
        raise FileNotFoundError(log_name)

    def list_workflows(self) -> List[str]:
        """Return the names of the workflows this instance was created with."""
        return list(self.workflows)

    def get_workflow(self, workflow_name: str) -> WorkflowDTO:
        """Return the generator's description of ``workflow_name``."""
        return self.generator.get_workflow(workflow_name)

    def get_workflow_state(self, workflow_name: str) -> WorkflowStateDTO:
        """Return the generator's state for ``workflow_name``."""
        return self.generator.get_workflow_state(workflow_name)

    def get_workflow_jobs(self, workflow_name: str) -> List[JobInfo]:
        """Return the jobs the generator holds for ``workflow_name``."""
        return self.generator.get_workflow_jobs(workflow_name)

    def get_workflow_job(self, workflow_name: str, job_id: str) -> JobInfo:
        """Return the job ``job_id`` the generator holds for ``workflow_name``."""
        return self.generator.get_workflow_job(workflow_name, job_id)

    def get_graph(self, workflow_name: str, graph_type: GraphType) -> str:
        """Return the generator's graph of ``workflow_name`` for the given graph type."""
        return self.generator.get_graph(workflow_name, graph_type)

    def get_assoc_map(self, workflow_name: str) -> str:
        """Return the generator's association map of ``workflow_name``."""
        return self.generator.get_assoc_map(workflow_name)

    def reset_workflow(self, workflow_name: str):
        """Ask the generator to reset ``workflow_name``."""
        self.generator.reset_workflow(workflow_name)

    def get_targets(self, workflow_name: str, targets: List[str], meta_targets: List[str]) -> str:
        """Return the generator's targets for ``workflow_name`` given ``targets`` and ``meta_targets``."""
        return self.generator.get_targets(workflow_name, targets, meta_targets)

    def get_parameter_sets(self, workflow_name: str) -> List[ParameterSetDTO]:
        """Return the parameter sets the generator knows for ``workflow_name``."""
        return self.generator.get_parameter_sets(workflow_name)

    def get_default_params(self, workflow_name: str, task_name: str) -> Dict[str, object]:
        """Return the generator's default parameters of ``task_name`` in ``workflow_name``."""
        return self.generator.get_default_params(workflow_name, task_name)

    def get_recipe_params(self, workflow_name: str, task_name: str, parameter_set: str) -> Dict[str, str]:
        """Return the generator's recipe parameters of ``task_name`` for the given parameter set."""
        return self.generator.get_recipe_params(workflow_name, task_name, parameter_set)

    def reject_job(self, job_id: UUID) -> Rejection:
        """Reject ``job_id`` in the repository and drop the affected jobs from the generator.

        :return: the rejection outcome built from the repository's rejection result.
        """
        self.logger.info("Rejecting job %s", job_id)
        rejection_result = self.repository.reject_job(job_id)
        self.logger.info("Rejecting job %s triggered rejecting %s and making %s incomplete, files %s need resubmission",
                         job_id, rejection_result.rejected_jobs, rejection_result.incomplete_jobs,
                         rejection_result.files_to_resubmit)
        self.generator.remove_jobs(rejection_result.rejected_jobs + rejection_result.incomplete_jobs)
        return Rejection.from_rejection_result(rejection_result)

    def delete_job_cascade(self, job_id: UUID) -> List[UUID]:
        """Delete ``job_id`` together with every job the repository reports as related to it.

        :return: the ids of the jobs removed from the repository.
        :raises ValueError: if any of the jobs to delete is currently scheduled for execution.
        """
        rejected, incomplete = self.repository.find_related_jobs(job_id)
        jobs_to_delete = rejected + incomplete
        if self.executor.any_scheduled(set(jobs_to_delete)):
            raise ValueError(
                f"Deleting jobs {jobs_to_delete} triggered by deleting root {job_id} failed because one or more of the jobs are scheduled for run.")
        deleted_jobs = self.repository.remove_jobs(jobs_to_delete)
        self.generator.remove_jobs(deleted_jobs)
        return deleted_jobs

    def delete_independent_jobs_subset(self, job_ids: List[UUID]) -> List[UUID]:
        """Delete those of the requested jobs whose related jobs are all being deleted as well.

        :param job_ids: the jobs requested for deletion.
        :return: the ids of the jobs removed from the repository.
        :raises ValueError: if any of the requested jobs is currently scheduled for execution.
        """
        children_for_job = {job_id: set(sum(self.repository.find_related_jobs(job_id), [])) for job_id in set(job_ids)}
        # Dropping one candidate can leave another candidate with a related job outside the candidate set,
        # so keep filtering until the set of candidates stops shrinking.
        while True:
            current = set(children_for_job.keys())
            children_for_job = {
                job_id: children for (job_id, children) in children_for_job.items() if not children.difference(current)
            }
            if children_for_job.keys() == current:
                break
        jobs_to_delete = list(children_for_job.keys())
        self.logger.debug("Requested deleting: %s, found independent set ready for deletion: %s",
                          job_ids, jobs_to_delete)
        if self.executor.any_scheduled(set(job_ids)):
            raise ValueError(f"Deleting jobs {job_ids} failed because one or more of the jobs are scheduled for run.")
        deleted_jobs = self.repository.remove_jobs(jobs_to_delete)
        self.generator.remove_jobs(deleted_jobs)
        return deleted_jobs

    def get_association_report(self, job_id: UUID) -> str:
        """Return the generator's association report for the stored configuration of ``job_id``."""
        job_info = self.repository.get_job_details_simple(job_id).configuration
        return self.generator.get_association_report(job_info)

    def resubmit_job_with_id(self, job_id: UUID) -> UUID:
        """Look up the stored configuration of ``job_id`` and resubmit it (see ``resubmit_job``)."""
        job = self.repository.get_job_details(job_id).configuration
        return self.resubmit_job(job)

    def resubmit_job(self, job: JobInfo) -> UUID:
        """Re-execute a previously generated job.

        The generator is asked for the historical job matching ``job``; the jobs returned by its
        ``get_all_parents()`` are then submitted for execution.

        :return: the id of ``job`` as a UUID.
        :raises JobNotFoundError: if the generator cannot find the historical job.
        """
        # FIXME: this logic uses only what we persist about the job, which means it does not consider "targets" or "renaming_prefix"
        job_to_execute = self.generator.find_historical_job(job)
        if not job_to_execute:
            raise JobNotFoundError("Failed to find job {}".format(job.job_id))
        all_jobs = job_to_execute.get_all_parents()
        self.execute_jobs([parent.as_job_info() for parent in all_jobs])
        return uuid.UUID(job.job_id)

    def execute_jobs(self, jobs: List[JobInfo], package_base_dir: Optional[str] = None, callback: Optional[str] = None,
                     submission_time: Optional[str] = None, targets: Optional[List[str]] = None,
                     renaming_prefix: Optional[str] = None, reports_configuration: ReportsConfiguration = ALL_REPORTS,
                     parameter_set_name: Optional[ParameterSetName] = None):
        """Hand ``jobs`` over to the executor.

        :param submission_time: ISO-formatted submission timestamp; when omitted, the current time at the
            moment of the call is used.
        """
        # Resolved per call: a ``datetime.now()`` default in the signature would be evaluated only once, when
        # the class is defined, and every later submission would carry that stale timestamp.
        if submission_time is None:
            submission_time = datetime.now().isoformat()
        self.executor.execute(jobs, submission_time, package_base_dir, callback, targets, renaming_prefix,
                              reports_configuration, parameter_set_name)

    @staticmethod
    def create_executor(config: Configuration, repository: JobsRepository, invoker_provider: InvokerProvider,
                        breakpoint_manager: BreakpointManager,
                        dataset_packager_factory: DatasetPackagerFactory) -> JobExecutor:
        """Build the job executor (a ``LocalJobExecutor``)."""
        return LocalJobExecutor(config, repository, invoker_provider, breakpoint_manager, dataset_packager_factory)

    @staticmethod
    def create_job_generator(config: Configuration, fits_file_factory: FitsFileFactory,
                             workflows: Dict[WorkflowName, WorkflowPath]) -> JobGenerator:
        """Build the job generator (a ``LocalJobGenerator``) using the class-level meta workflow."""
        return LocalJobGenerator(config, fits_file_factory, workflows, EDPS.meta_workflow)

    @staticmethod
    def create_jobs_repository(config: Configuration) -> JobsRepository:
        """Build the jobs repository.

        :raises NotImplementedError: if ``config.local_repository`` is not set.
        """
        if not config.local_repository:
            raise NotImplementedError("Only local repository is supported right now")
        return LocalJobsRepository(config)

    @staticmethod
    def create_jobs_scheduler(config: Configuration) -> JobScheduler:
        """Build the job scheduler (a ``LocalJobScheduler``) from ``config.cores``."""
        return LocalJobScheduler(config.cores)

    def shutdown(self):
        """Shut down the breakpoint manager, cleaner, executor, repository and meter registry, in that order.

        Components that were never created (still ``None`` after a failed start-up) are skipped.
        """
        self.logger.info("Shutting down EDPS.")
        self.logger.info("Muting breakpoints...")
        if self.breakpoint_manager:
            self.breakpoint_manager.shutdown()
        self.logger.info("Waiting for periodic cleaner to stop...")
        if self.cleaner:
            self.cleaner.shutdown()
        self.logger.info("Waiting for on-going executions to complete...")
        if self.executor:
            self.executor.shutdown()
        self.logger.info("Waiting for all DB writes to flush...")
        if self.repository:
            self.repository.close()
        MeterRegistry.instance.shutdown()
        self.logger.info("Clean shutdown complete.")

    def clean_startup(self, resume: bool):
        """Load past jobs into the generator, then fail (and optionally resume) jobs left pending."""
        self.load_past_jobs()
        self.fix_pending_jobs(resume)

    def fix_pending_jobs(self, resume):
        """Mark every pending job as FAILED and, if ``resume`` is set, resubmit the non-rejected ones.

        :param resume: whether to resubmit the pending jobs; also stored as the ``interrupted`` flag of the
            failure result.
        """
        pending = self.repository.find_all_pending_jobs()
        self.logger.info("Cleaning up %d pending jobs", len(pending))
        for job in pending:
            self.repository.set_job_status(UUID(job.configuration.job_id),
                                           ExecutionResult(status=ProcessingJobStatus.FAILED, interrupted=resume))
        jobs_to_resume = [job for job in pending if not job.rejected]
        if resume:
            self.logger.info("Resuming pending jobs %s", [job.configuration.job_id for job in jobs_to_resume])
            for job in jobs_to_resume:
                try:
                    self.resubmit_job(job.configuration)
                except JobNotFoundError as e:
                    # A job that cannot be found again is only reported; the remaining jobs are still resumed.
                    self.logger.warning(e.message)

    def load_past_jobs(self):
        """Feed all non-rejected jobs from the repository to the generator."""
        complete = list(self.repository.find_all_non_rejected_jobs())
        self.logger.info("Loading %d past jobs for association", len(complete))
        self.generator.load_past_jobs(complete)

    def list_pipelines(self) -> List[str]:
        """Return the entries of the configured pipeline directory, excluding ``README``."""
        return [name for name in os.listdir(self.configuration.pipeline_path) if name != 'README']

    def create_calselector_jobs(self, request: CalselectorRequest) -> List[CalselectorJob]:
        """Generate calselector jobs for ``request``; neither request files nor jobs are kept afterwards."""
        result, targets, request_id = self.generator.create_calselector_jobs(request)
        cleanup_args = CleanupRequestArgs(request_id=request_id,
                                          keep_files=False,
                                          keep_jobs=False)
        self.generator.cleanup_request(request.workflow, targets, cleanup_args)
        return result.calselector_jobs

    def get_scheduled_jobs(self) -> List[UUID]:
        """Return the ids of the jobs the executor reports as scheduled."""
        return self.executor.get_scheduled_jobs()

    def halt_executions(self, terminate_running=True):
        """Stop breakpoints, optionally terminate running recipe/report processes, and wait until idle.

        :param terminate_running: when true, every descendant process of this process whose command line
            contains the configured esorex or genreport path is sent a terminate request.
        """
        if self.breakpoint_manager:
            self.breakpoint_manager.shutdown()
        if terminate_running:
            for proc in psutil.Process(os.getpid()).children(recursive=True):
                try:
                    cmdline = proc.cmdline()
                except (psutil.AccessDenied, psutil.NoSuchProcess):
                    # Either we may not inspect the process, or it exited after being enumerated.
                    continue
                if not any(self.configuration.esorex_path in param or self.configuration.genreport_path in param
                           for param in cmdline):
                    continue
                self.logger.warning(f"Kill {cmdline}")
                try:
                    proc.terminate()
                except psutil.NoSuchProcess:
                    # The process finished on its own in the meantime; nothing left to terminate.
                    continue
                except psutil.AccessDenied:
                    self.logger.warning(f"Failed to terminate {proc}")
        self.executor.wait_until_nothing_scheduled()
        # Reset the flag so breakpoints can be used again after the halt.
        if self.breakpoint_manager:
            self.breakpoint_manager.stopped = False
