from __future__ import annotations

import os
from collections import defaultdict
from dataclasses import dataclass
from functools import reduce
from typing import List, Set, Dict

from edps.client.JobInfo import JobInfo


@dataclass()
class Dataset:
    """A named group of job ids.

    ``name`` labels the group and ``jobs`` holds the ids of the jobs that
    belong to it.
    """

    name: str  # label of the dataset
    jobs: Set[str]  # ids of the member jobs


class Datasets:
    """Index from job id to the names of the datasets that job belongs to.

    A dataset is rooted at a job that no other job references through its
    ``input_job_ids`` or ``associated_job_ids``. The dataset consists of the
    root plus every job reachable from it through those two fields. Its name
    comes from the root's first input file, or, failing that, from the first
    job along the chain of first input jobs that has an input file.
    """

    def __init__(self, jobs: List[JobInfo]):
        datasets = self.__build_datasets(jobs)
        # job id -> names of every dataset whose job set contains that job.
        # NOTE(review): if two datasets share a name, that name is appended
        # once per dataset, so the list can contain duplicates.
        self.datasets_for_job = defaultdict(list)
        for dataset in datasets:
            for job in dataset.jobs:
                self.datasets_for_job[job].append(dataset.name)

    def __build_datasets(self, jobs: List[JobInfo]) -> List[Dataset]:
        """Create one Dataset per root job (a job no other job depends on)."""
        jobs_by_id = {job.job_id: job for job in jobs}
        # A set comprehension replaces sum(lists, []). The latter re-copies
        # the accumulated list on every addition, which is quadratic.
        dependent_jobs = {
            parent_id
            for job in jobs
            for parent_id in job.input_job_ids + job.associated_job_ids
        }
        return [Dataset(name=self.__extract_dataset_name(job, jobs_by_id),
                        jobs=self.__find_dataset_jobs(job, jobs_by_id))
                for job in jobs if job.job_id not in dependent_jobs]

    def __extract_dataset_name(self, job: JobInfo, jobs_by_id: Dict[str, JobInfo]) -> str:
        """Return the dataset name for ``job``.

        This is the stem of the job's first input file. If the job has no
        input files, the name is taken recursively from its first input job.
        If it has neither, ``"NO_INPUTS"`` is returned. Associated jobs are
        not consulted.

        Raises:
            KeyError: if the first input job id is not in ``jobs_by_id``.
        """
        if job.input_files:
            return self.extract_dataset_name_from_path(job.input_files[0].name)
        elif job.input_job_ids:
            return self.__extract_dataset_name(jobs_by_id[job.input_job_ids[0]], jobs_by_id)
        else:
            return "NO_INPUTS"

    def get_datasets_for_job(self, job_id: str) -> List[str]:
        """Return the dataset names containing ``job_id`` ([] if none)."""
        # .get avoids the defaultdict side effect of inserting an empty entry
        # for every unknown id that is looked up.
        return self.datasets_for_job.get(job_id, [])

    def __find_dataset_jobs(self, job: JobInfo, jobs_by_id: Dict[str, JobInfo]) -> Set[str]:
        """Return ``job``'s id plus the ids of all jobs it transitively depends on.

        Dependencies are followed through both ``input_job_ids`` and
        ``associated_job_ids``.

        Raises:
            KeyError: if a referenced job id is not in ``jobs_by_id``.
        """
        # The traversal is iterative and records visited ids. A recursive
        # union would re-walk shared ancestors once per path (exponential for
        # diamond-shaped graphs), could exceed the recursion limit on long
        # chains, and would never terminate on a cycle.
        found = {job.job_id}
        pending = [job]
        while pending:
            current = pending.pop()
            for parent_id in current.input_job_ids + current.associated_job_ids:
                if parent_id not in found:
                    found.add(parent_id)
                    pending.append(jobs_by_id[parent_id])
        return found

    @staticmethod
    def extract_dataset_name_from_path(path: str) -> str:
        """Return the file name of ``path`` without its last extension."""
        return os.path.splitext(os.path.basename(path))[0]
