import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Dict, Optional

from frozendict import frozendict

from edps.metrics.meter import Meter
from edps.metrics.publisher import Publisher


@dataclass(unsafe_hash=True)
class MeterKey:
    """Hashable identity of a meter: its name plus an immutable copy of its tags.

    Equality and hashing are generated by ``dataclass`` over ``(name, tags)``,
    so instances can be used as dictionary keys. ``tags`` is converted to a
    ``frozendict`` on construction, so it is hashable. Later changes to the
    caller's original dict do not affect the key.
    """

    name: str
    tags: Dict[str, str]

    def __post_init__(self):
        # Swap the caller's mutable mapping for a frozen copy; a plain dict
        # would make the generated __hash__ fail.
        self.tags = frozendict(self.tags)


class MeterRegistry:
    """Registry of meters keyed by name and tags, published periodically.

    Publishing happens on a single-worker thread pool every ``step_sec``
    seconds until :meth:`shutdown` is called. ``shutdown`` then performs one
    final publish from the calling thread. The most recently constructed
    registry is stored in ``MeterRegistry.instance``.
    """

    # Most recently constructed registry; None until one has been created.
    instance: Optional['MeterRegistry'] = None

    def __init__(self, publisher: Publisher, default_tags: Dict[str, str], step_sec: int = 60):
        """Create the registry, publish once, and start periodic publishing.

        :param publisher: receives a list of all registered meters on every publish.
        :param default_tags: tags merged into every meter obtained via ``get_meter``;
            per-call tags override these on key collisions.
        :param step_sec: seconds to wait between periodic publishes.
        """
        self.logger = logging.getLogger("MeterRegistry")
        self.meters = {}
        # Guards ``meters``. It is mutated by callers of get_meter/register_meter
        # and snapshotted by the background publishing thread. Reentrant because
        # get_meter calls register_meter while already holding it.
        self._lock = threading.RLock()
        self.executor = ThreadPoolExecutor(max_workers=1)
        # Set by shutdown() to stop the periodic publishing.
        self.tick = threading.Event()
        self.step_sec = step_sec
        self.publisher = publisher
        self.default_tags = default_tags
        # Synchronous initial publish; no meters are registered yet at this point.
        self.publisher.publish(self._snapshot())
        self.start()
        MeterRegistry.instance = self

    def get_meter(self, name: str, tags: Optional[Dict[str, str]] = None) -> Meter:
        """Return the meter for ``name`` and the merged tags, creating it on first use.

        :param name: meter name.
        :param tags: optional extra tags; they override ``default_tags`` on collisions.
        :returns: the registered meter for this name/tag combination.
        """
        all_tags = self.default_tags.copy()
        if tags:
            all_tags.update(tags)
        key = MeterKey(name=name, tags=all_tags)
        with self._lock:
            # Check-and-create atomically. Otherwise two threads could each
            # register their own meter for the same key, and the loser's
            # recorded values would never be published.
            if key not in self.meters:
                self.register_meter(key, Meter(name, all_tags))
            return self.meters[key]

    def register_meter(self, key: MeterKey, meter: Meter):
        """Store ``meter`` under ``key``, replacing any meter already registered there."""
        with self._lock:
            self.meters[key] = meter

    def _snapshot(self) -> list:
        """Return a point-in-time list of all registered meters."""
        with self._lock:
            return list(self.meters.values())

    def start(self):
        """Schedule one publishing step on the executor."""
        self.executor.submit(self.run)

    def run(self):
        """Wait ``step_sec`` seconds; unless stopped, publish and schedule the next step.

        Each step is a separate short-lived task rather than an endless loop.
        That way the worker thread always finishes its current task, and the
        executor can still be joined at interpreter exit even if ``shutdown``
        is never called.
        """
        # Event.wait returns False on timeout (time to publish) and True once
        # shutdown() has set the event.
        if not self.tick.wait(self.step_sec):
            try:
                self.publisher.publish(self._snapshot())
            except Exception as e:
                # Best-effort: one failed publish must not stop future publishes.
                self.logger.warning("Failed to publish metrics due to %s", e)
            # Do not re-submit once shutdown() has started. The executor is (or
            # is about to be) shut down, and submit() would raise inside this
            # task's future, where nobody observes it.
            if not self.tick.is_set():
                self.start()

    def shutdown(self):
        """Stop periodic publishing and publish all meters one final time.

        Does not wait for an in-flight periodic publish. If one is running,
        it may overlap with the final publish done here.
        """
        # Signal the worker before shutting the executor down, so a step that
        # finishes now sees the flag and does not try to re-submit.
        self.tick.set()
        self.executor.shutdown(wait=False)
        self.publisher.publish(self._snapshot())
