import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict

import requests

# Module logger used for the debug/warning messages emitted while parsing,
# loading and saving association breakpoints.
logger = logging.getLogger("AssociationBreakpoint")

# String constant 'ANY'. Not referenced in the visible part of this file;
# presumably a wildcard value for breakpoint fields — TODO confirm with callers.
ANY = 'ANY'


@dataclass
class AssociationBreakpoint:
    """One association breakpoint record.

    The field names match the keys expected by :meth:`from_dict`.
    """

    instrument: str
    dpInstrument: str
    mjdObs: float
    proCatg: str
    rawType: str
    doClass: str

    @classmethod
    def from_dict(cls, d: Dict) -> 'AssociationBreakpoint':
        """Create a breakpoint from a mapping that holds every field key.

        Keys not listed below are ignored, and values are passed through
        without conversion.

        :param d: mapping with the keys ``instrument``, ``dpInstrument``,
            ``mjdObs``, ``proCatg``, ``rawType`` and ``doClass``.
        :raises KeyError: if one of those keys is missing (the first missing
            key in the order listed is the one reported).
        """
        wanted = ('instrument', 'dpInstrument', 'mjdObs', 'proCatg', 'rawType', 'doClass')
        kwargs = {name: d[name] for name in wanted}
        return cls(**kwargs)


class AssociationBreakpoints:
    """Association breakpoints indexed by instrument and category.

    Each breakpoint's ``mjdObs`` is recorded under its ``doClass`` and, when
    different, also under its ``proCatg``. Every per-category list is sorted
    in ascending order. The raw input list is kept in ``self.json`` so it can
    be written back with :meth:`save`.
    """

    def __init__(self, breakpoints_in: List[Dict]):
        """Parse raw breakpoint dicts and build the instrument/category index.

        :param breakpoints_in: list of dicts, each carrying the keys read by
            ``AssociationBreakpoint.from_dict``.
        :raises KeyError: if an entry lacks one of those keys.
        """
        logger.debug("Parsing %d association breakpoints", len(breakpoints_in))
        # Keep the raw payload so save() can write it back unchanged.
        self.json = breakpoints_in
        breakpoints = [AssociationBreakpoint.from_dict(x) for x in breakpoints_in]
        # Nested defaultdicts: looking up an unseen instrument/category yields
        # an empty list (and inserts it) instead of raising KeyError.
        self.breakpoints_by_ins_cat: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
        for bp in breakpoints:
            by_cat = self.breakpoints_by_ins_cat[bp.instrument]
            by_cat[bp.doClass].append(bp.mjdObs)
            # Avoid recording the same MJD twice when both categories coincide.
            if bp.proCatg != bp.doClass:
                by_cat[bp.proCatg].append(bp.mjdObs)
        for by_cat in self.breakpoints_by_ins_cat.values():
            for mjds in by_cat.values():
                mjds.sort()

    @classmethod
    def from_url(cls, url: str, timeout: float = 15) -> 'AssociationBreakpoints':
        """Download the breakpoint list from ``url`` and parse it.

        :param url: location of a JSON document containing the breakpoint list.
        :param timeout: seconds passed to ``requests.get`` (default 15).
        :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        logger.debug("Loading association breakpoints from URL %s", url)
        resp = requests.get(url, timeout=timeout)
        # Fail with a clear HTTP error rather than trying to parse an error
        # page (or an error JSON object) as a list of breakpoints.
        resp.raise_for_status()
        return cls(resp.json())

    @classmethod
    def from_file(cls, filename: str) -> 'AssociationBreakpoints':
        """Load and parse the breakpoint list from a JSON file.

        :param filename: path to a UTF-8 encoded JSON file.
        """
        logger.debug("Loading association breakpoints from FILE %s", filename)
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(filename, encoding="utf-8") as fid:
            return cls(json.load(fid))

    def save(self, output_path: str):
        """Write the raw breakpoint list to ``output_path`` as JSON.

        Best-effort: any failure is logged as a warning and not raised.
        """
        try:
            # Serialize before opening the file so a serialization error
            # cannot leave a truncated/empty file behind.
            payload = json.dumps(self.json)
            with open(output_path, "w", encoding="utf-8") as breakpoints_file:
                breakpoints_file.write(payload)
        except Exception as e:  # deliberate best-effort: a failed save is only logged
            logger.warning("Failed to save downloaded association breakpoints to '%s' due to '%s'", output_path, e)
        else:
            # Logged only once the file has been fully written and closed.
            logger.debug("Saved updated breakpoints list to '%s'", output_path)
