import dataclasses
import json
import logging
import os
import threading
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Optional, Set
from uuid import uuid4

import time


def get_timestamp() -> str:
    """Return the current local time as an ISO 8601 string.

    The value comes from a naive ``datetime.now()`` and is truncated to
    millisecond precision, e.g. ``2024-01-31T12:34:56.789``.
    """
    now = datetime.now()
    return now.isoformat(timespec='milliseconds')


@dataclass
class ReductionConfig:
    """Settings attached to a reduction.

    Equality compares only ``parameter_set``, ``workflow_parameters`` and
    ``recipe_parameters``; ``comment``, ``report_type`` and ``timestamp`` are
    not taken into account. Because ``__eq__`` is defined without
    ``__hash__``, instances are unhashable.
    """
    comment: str = ''
    report_type: str = 'none'
    parameter_set: str = 'science_parameters'
    workflow_parameters: dict = field(default_factory=dict)
    recipe_parameters: dict = field(default_factory=dict)
    timestamp: str = field(default_factory=get_timestamp)

    def __eq__(self, other: object):
        # Let Python fall back to its default handling for foreign types
        # (e.g. ``config == None``) instead of raising AttributeError.
        if not isinstance(other, ReductionConfig):
            return NotImplemented
        return (self.parameter_set == other.parameter_set and
                self.workflow_parameters == other.workflow_parameters and
                self.recipe_parameters == other.recipe_parameters)

    @classmethod
    def from_dict(cls, data):
        """Build a config from a plain mapping such as a decoded JSON object.

        Keys that are missing or set to ``None`` fall back to the dataclass
        defaults, so a partial record yields the same values as the
        constructor would (empty comment, fresh parameter dicts, a new
        timestamp). ``None`` or an empty mapping yields a default config.

        :param data: mapping with any subset of the field names, or ``None``
        :return: a new ``ReductionConfig``
        """
        data = data or {}
        # Only pass values that are actually present; everything else is
        # filled in by the field defaults / default factories.
        present = {f.name: data[f.name]
                   for f in dataclasses.fields(cls)
                   if data.get(f.name) is not None}
        return cls(**present)


@dataclass
class ClassifiedFile:
    """A file name paired with the category it was classified as.

    Instances compare and hash by ``(name, category)``, so they can be
    collected in sets and used as dictionary keys.
    """
    name: str
    category: str

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented for foreign types lets Python fall back to
        # its default comparison instead of raising AttributeError.
        if not isinstance(other, ClassifiedFile):
            return NotImplemented
        return other.name == self.name and other.category == self.category

    def __hash__(self):
        return hash((self.name, self.category))

    @classmethod
    def from_dict(cls, d):
        """Build an instance from a mapping with ``'name'`` and ``'category'`` keys.

        :raises KeyError: if either key is missing
        """
        return cls(name=d['name'], category=d['category'])


@dataclass
class Reduction:
    """One reduction of a dataset by a workflow, with its inputs and settings.

    Equality is content based: two reductions are equal when dataset,
    workflow, target, the set of input files and the config all match. The
    id, timestamp, tasks, job ids and the archived/completed flags are not
    compared.
    """
    timestamp: str
    dataset: str
    workflow: str
    tasks: List[str]
    target: Optional[str]
    obs_target: Optional[str]
    input_files: List[ClassifiedFile]
    job_ids: List[str] = field(default_factory=list)
    config: ReductionConfig = field(default_factory=ReductionConfig)
    id: str = field(default_factory=lambda: str(uuid4()))
    archived: bool = False
    completed: bool = False

    def __hash__(self):
        # NOTE(review): the hash is derived from the unique id while __eq__
        # compares content, so two equal reductions with different ids
        # normally hash differently and will not be matched by set or dict
        # lookups. Left unchanged because the id is the only field that is
        # not expected to be mutated after creation; confirm that callers
        # working with sets of reductions do not rely on content equality.
        return hash(self.id)

    def __eq__(self, other: object):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError.
        if not isinstance(other, Reduction):
            return NotImplemented
        return (self.dataset == other.dataset and
                self.workflow == other.workflow and
                self.target == other.target and
                # order of the input files is irrelevant
                set(self.input_files) == set(other.input_files) and
                self.config == other.config)

    @classmethod
    def from_dict(cls, data):
        """Build a reduction from a plain mapping such as a decoded JSON object.

        ``'input_files'`` is required; all other keys are optional. A record
        without a ``'config'`` entry gets a default ``ReductionConfig``.

        :param data: mapping holding the field values
        :return: a new ``Reduction``
        :raises KeyError: if ``'input_files'`` is missing
        """
        config_data = data.get('config')
        # A missing/null config used to crash inside ReductionConfig.from_dict.
        config = (ReductionConfig.from_dict(config_data)
                  if config_data is not None else ReductionConfig())
        return cls(timestamp=data.get('timestamp'),
                   dataset=data.get('dataset'),
                   workflow=data.get('workflow'),
                   tasks=data.get('tasks', []),
                   target=data.get('target'),
                   obs_target=data.get('obs_target'),
                   input_files=[ClassifiedFile.from_dict(f) for f in data['input_files']],
                   job_ids=data.get('job_ids', []),
                   config=config,
                   id=data.get('id'),
                   # NOTE(review): records without an 'archived' key load as
                   # archived, unlike the constructor default (False) -
                   # presumably for records written before the flag existed;
                   # confirm.
                   archived=data.get('archived', True),
                   completed=data.get('completed', False))


class ReductionRepository:
    """In-memory collection of reductions, keyed by id and persisted as JSON.

    Mutating methods mark the repository as dirty (``need_save``). A daemon
    thread started by the constructor checks that flag every 30 seconds and
    writes the file when it is set. All access to ``reductions`` made by this
    class is guarded by a re-entrant lock.
    """

    def __init__(self, filename: str):
        """Load ``filename`` if it exists and start the background save thread.

        :param filename: path of the JSON file backing this repository
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.filename = filename
        self.reductions: Dict[str, Reduction] = self.load()
        self.lock = threading.RLock()
        self.need_save = False
        self.keep_running = True
        threading.Thread(target=self.periodic_save, daemon=True).start()

    def periodic_save(self):
        """Write pending changes to disk every 30 seconds until ``keep_running`` is cleared.

        A failed save is logged and retried on the next cycle instead of
        terminating the thread.
        """
        while self.keep_running:
            if self.need_save:
                self.logger.info('Saving reductions to %s', self.filename)
                with self.lock:
                    # Clear the flag under the lock so it always describes the
                    # state that is about to be written; a change made after
                    # the lock is released sets it again.
                    self.need_save = False
                    try:
                        self._save()
                    except Exception:
                        # Keep the thread alive and make sure the data is
                        # written on a later attempt.
                        self.need_save = True
                        self.logger.exception('Failed to save reductions to %s', self.filename)
            # A fresh Event is never set, so this is a plain 30 s pause.
            threading.Event().wait(30)

    def _save(self):
        """Serialise all reductions to ``filename``.

        The data is serialised before any file is opened and written to a
        temporary sibling file that is then moved over the target with
        ``os.replace``, so a failure while saving leaves the previous file
        untouched.
        """
        start = time.perf_counter()
        tmp_filename = self.filename + '.tmp'
        with self.lock:
            data = {key: dataclasses.asdict(obj) for key, obj in self.reductions.items()}
            try:
                with open(tmp_filename, 'w') as outfile:
                    json.dump(data, outfile)
                os.replace(tmp_filename, self.filename)
            except BaseException:
                # Best-effort removal of the partial temporary file; the
                # original error is what the caller needs to see.
                try:
                    os.remove(tmp_filename)
                except OSError:
                    pass
                raise
        self.logger.debug('Saved %d reductions in %.2f s', len(self.reductions), time.perf_counter() - start)

    def load(self) -> Dict[str, Reduction]:
        """Read the reductions stored in ``filename``.

        :return: mapping of key to ``Reduction``; empty if the file does not exist
        """
        if os.path.exists(self.filename):
            with open(self.filename) as infile:
                return {key: Reduction.from_dict(obj) for key, obj in json.load(infile).items()}
        else:
            return {}

    def get_all_as_set(self) -> Set[Reduction]:
        """Return all stored reductions as a new set."""
        with self.lock:
            return set(self.reductions.values())

    def get_all_as_list(self) -> List[Reduction]:
        """Return all stored reductions as a new list."""
        with self.lock:
            return list(self.reductions.values())

    def get_by_id(self, reduction_id: str) -> Optional[Reduction]:
        """Return the reduction stored under ``reduction_id``, or ``None`` if unknown."""
        return self.reductions.get(reduction_id)

    def get_by_dataset_and_target(self, dataset: str, target: str) -> List[Reduction]:
        """Return all reductions whose dataset and target both match."""
        with self.lock:
            return [r for r in self.reductions.values() if r.dataset == dataset and r.target == target]

    def add(self, reduction: Reduction) -> bool:
        """Store ``reduction`` unless an equal one is already present.

        Equality is decided by ``Reduction.__eq__``, not by id.

        :return: ``True`` if the reduction was added, ``False`` otherwise
        """
        with self.lock:
            if reduction not in self.reductions.values():
                self.reductions[reduction.id] = reduction
                self.need_save = True
                return True
            else:
                return False

    def update_comment(self, reduction_id: str, comment: str, mode: str = 'replace'):
        """Set the comment of a reduction; unknown ids are ignored.

        :param mode: ``'append'`` adds ``comment`` on a new line after a
            non-empty existing comment; any other value replaces it
        """
        with self.lock:
            if reduction_id in self.reductions:
                if mode == 'append' and self.reductions[reduction_id].config.comment:
                    comment = self.reductions[reduction_id].config.comment + '\n' + comment
                self.reductions[reduction_id].config.comment = comment
                self.need_save = True

    def update_report_type(self, reduction_id: str, report_type: str):
        """Set the report type of a reduction; unknown ids are ignored."""
        with self.lock:
            if reduction_id in self.reductions:
                self.reductions[reduction_id].config.report_type = report_type
                self.need_save = True

    def update_config(self, reduction_id: str, config: ReductionConfig, update_comment: bool = False):
        """Replace the config of a reduction; unknown ids are ignored.

        :param config: the new config; it is stored as is, and unless
            ``update_comment`` is true its ``comment`` is overwritten in place
            with the comment of the config being replaced
        :param update_comment: keep the comment carried by ``config``
        """
        with self.lock:
            if reduction_id in self.reductions:
                if not update_comment:
                    config.comment = self.reductions[reduction_id].config.comment
                self.reductions[reduction_id].config = config
                self.need_save = True

    def clone(self, reduction: Reduction, new_config: ReductionConfig) -> Optional[Reduction]:
        """Store a copy of ``reduction`` that uses ``new_config``.

        The copy gets a new id, the timestamp of ``new_config``, no job ids
        and cleared archived/completed flags.

        :return: the new reduction, or ``None`` if an equal one already exists
        """
        new_reduction = dataclasses.replace(reduction,
                                            timestamp=new_config.timestamp,
                                            job_ids=[],
                                            config=new_config,
                                            id=str(uuid4()),
                                            archived=False,
                                            completed=False)
        with self.lock:
            return new_reduction if self.add(new_reduction) else None

    def delete(self, reduction_id: str) -> Reduction:
        """Remove and return the reduction stored under ``reduction_id``.

        :raises KeyError: if the id is unknown
        """
        with self.lock:
            removed = self.reductions.pop(reduction_id)
            # Only mark dirty once something was actually removed.
            self.need_save = True
            return removed

    def set_archived(self, reduction_id: str, archived: bool):
        """Set the archived flag of a reduction.

        :raises KeyError: if the id is unknown
        """
        with self.lock:
            self.reductions[reduction_id].archived = archived
            self.need_save = True

    def set_completed(self, reduction_id: str):
        """Mark a reduction as completed.

        :raises KeyError: if the id is unknown
        """
        with self.lock:
            self.reductions[reduction_id].completed = True
            self.need_save = True

    def set_jobs(self, reduction_id: str, job_ids: List[str]):
        """Replace the job ids of a reduction.

        :raises KeyError: if the id is unknown
        """
        with self.lock:
            self.reductions[reduction_id].job_ids = job_ids
            self.need_save = True


def test():
    """Smoke-test the repository: add reductions, save them and load them back.

    Runs in a temporary directory so that no file is left in the working
    directory and repeated runs do not pick up the output of earlier ones.

    :raises AssertionError: if the reloaded data differs from what was saved
    """
    import tempfile
    from pprint import pprint

    with tempfile.TemporaryDirectory() as workdir:
        repo = ReductionRepository(os.path.join(workdir, 'test.json'))
        try:
            for _ in range(10):
                repo.add(Reduction(
                    timestamp=get_timestamp(),
                    dataset=str(uuid4()),
                    workflow='uves.uves_wkf',
                    tasks=['flat', 'bias'],
                    target='flat',
                    obs_target='eta carinae',
                    input_files=[ClassifiedFile(name='f1', category='c1'),
                                 ClassifiedFile(name='f2', category='c2'),
                                 ClassifiedFile(name='f3', category='c3')],
                    job_ids=[str(uuid4()) for _ in range(2)]))
            pprint(repo.reductions)
            repo._save()
            reloaded = repo.load()
        finally:
            # Let the background save thread end at its next wake-up.
            repo.keep_running = False

    if reloaded != repo.reductions:
        raise AssertionError('reloaded reductions differ from the saved ones')


# Run the smoke test when this file is executed as a script.
if __name__ == '__main__':
    test()
