import logging
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path, PurePath
from threading import RLock
from typing import List, Optional, Dict, TypeVar, Callable, Set, Any, Tuple

from astropy.io import fits
from astropy.io.fits import HDUList

from edps.client.FitsFile import FitsFile as FitsFileDTO
from .constants import MJD_OBS, INSTRUME, NIGHT, HEADER_KEYWORDS

# Type variable tying the type of a caller-supplied default value to the type of the
# returned value (used by the `get_keyword_value` methods below).
T = TypeVar('T')


@dataclass
class HDUKeyword:
    """A header keyword qualified by the index of the HDU it is read from.

    ``FitsUtils.split_keyword`` produces these: ``"1:DET.ID"`` becomes
    ``HDUKeyword(hdu=1, keyword="DET.ID")``, and an unprefixed keyword uses HDU 0.
    """

    # Index into the HDU list the keyword is looked up in.
    hdu: int
    # Keyword name without the "<hdu>:" prefix, still in its short (possibly dotted) form.
    keyword: str


class FitsUtils:
    """Static helpers for addressing and reading FITS header keywords."""

    @staticmethod
    def split_keyword(keyword: str) -> HDUKeyword:
        """Split an optional ``"<hdu>:"`` prefix off *keyword*.

        ``"2:DET.ID"`` yields ``HDUKeyword(hdu=2, keyword="DET.ID")``; a keyword
        without a colon is addressed to HDU 0. A non-integer prefix, or more than
        one colon, raises ``ValueError``.
        """
        if ':' not in keyword:
            return HDUKeyword(hdu=0, keyword=keyword)
        hdu_index, bare_keyword = keyword.split(':')
        return HDUKeyword(hdu=int(hdu_index), keyword=bare_keyword)

    @staticmethod
    def index_of_primary_hdu(filename: str, hdul: HDUList) -> int:
        """Return the index of the HDU that holds the primary header of *hdul*.

        Tile-compressed (``.fz``) files keep their primary header in the first
        extension. That extension is chosen when it exists and its header contains
        INSTRUME; every other case uses HDU 0.
        """
        if not filename.endswith('.fz'):
            return 0
        extension_is_primary = len(hdul) > 1 and INSTRUME in hdul[1].header
        return 1 if extension_is_primary else 0

    @staticmethod
    def long_keyword(short_key: str) -> str:
        """Convert a short keyword into the name used in the FITS header.

        The key is upper-cased. Dotted keys (e.g. ``det.id``) are expanded to the
        hierarchical form ``HIERARCH ESO DET ID``.
        """
        upper_key = short_key.upper()
        if '.' not in upper_key:
            return upper_key
        return "HIERARCH ESO " + upper_key.replace(".", " ")

    @staticmethod
    def get_keyword_value(hdul: HDUList, keyword: str) -> Optional[Any]:
        """Read *keyword* (optionally ``"<hdu>:"``-prefixed) from *hdul*.

        Returns None when the header lacks the keyword or the addressed HDU does
        not exist.
        """
        target = FitsUtils.split_keyword(keyword)
        header_key = FitsUtils.long_keyword(target.keyword)
        try:
            header = hdul[target.hdu].header
            return header.get(header_key, None)
        except IndexError:
            return None


class FitsFile:
    """A FITS file whose header keyword values are cached in ``keyword_values``.

    ``file[keyword]`` is served from the cache. An uncached keyword is read lazily
    from disk, but only while at least one ``with`` block on this object is active.
    The file is opened on the first such lookup and closed when the outermost
    ``with`` block exits. Outside any ``with`` block, an uncached keyword is recorded
    as None without touching the file.

    Virtual files never load keywords through ``file[keyword]``. They only serve
    what is already in ``keyword_values``, and a missing key raises ``KeyError``.
    """

    def __init__(self, file_path: str, keyword_values: Dict, *, virtual: bool = False):
        self.file_path = file_path
        # Data-product id: the last path component of file_path.
        self.dp_id = PurePath(self.file_path).name
        # Cache of keyword -> value; None is stored for keywords that could not be read.
        self.keyword_values = keyword_values
        self.virtual = virtual
        # Open HDU list and index of its primary HDU; set only while a `with` block is active.
        self.fits_file: Optional[HDUList] = None
        self.primary_hdu_offset: Optional[int] = None
        # Number of currently active `with` blocks; guarded by `mutex`.
        self.ref_count = 0
        self.mutex = RLock()
        self.logger = logging.getLogger('FitsFile')

    def __enter__(self):
        """Register one more user of the file (the file itself is opened lazily)."""
        with self.mutex:
            self.ref_count += 1
        return self

    def __exit__(self, *args):
        """Unregister a user and close the file once the last user has left.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        with self.mutex:
            self.ref_count -= 1
            # Compare against None rather than relying on truthiness: an HDUList is a
            # list, so an empty one is falsy (and would never be closed). With lazy
            # loading, taking its length may also force every HDU to be read.
            if self.ref_count == 0 and self.fits_file is not None:
                self.logger.debug('__exit__ file %s', self.file_path)
                self.fits_file.close()
                self.fits_file = None
                self.primary_hdu_offset = None

    def __getitem__(self, item):
        """Return the value of keyword *item*, loading it first unless the file is virtual."""
        if not self.is_virtual():
            self.load_keyword(item)
        return self.keyword_values[item]

    def _exception_safe_fits_open(self) -> 'Tuple[Optional[HDUList], Optional[int]]':
        """Open the file and locate its primary HDU; return ``(None, None)`` on failure.

        Failures are logged, never raised. If the file opened but locating the primary
        HDU failed, the opened HDU list is closed so its file handle is not leaked.
        """
        self.logger.debug('opening file %s', self.file_path)
        hdul = None
        try:
            hdul = fits.open(self.file_path)
            return hdul, FitsUtils.index_of_primary_hdu(self.file_path, hdul)
        except Exception as e:
            self.logger.error('failed to open file %s: %s', self.file_path, e, exc_info=e)
            if hdul is not None:
                try:
                    hdul.close()
                except Exception as close_error:  # best effort; the open failure is already logged
                    self.logger.debug('failed to close file %s: %s', self.file_path, close_error)
            return None, None

    def load_keyword(self, keyword: str):
        """Ensure *keyword* has an entry in ``keyword_values``.

        Inside an active ``with`` block, the file is opened on demand and the value is
        read from the primary HDU onwards. Otherwise, or when the file cannot be opened,
        None is cached.
        """
        if keyword in self.keyword_values:
            return
        self.logger.debug('lazy loading keyword %s', keyword)
        # Hold the lock for the read as well as the open. Otherwise another thread's
        # __exit__ could close the file and reset `fits_file` to None between the
        # check and the read.
        with self.mutex:
            if self.ref_count > 0 and self.fits_file is None:
                self.fits_file, self.primary_hdu_offset = self._exception_safe_fits_open()
            if self.fits_file is not None:
                value = FitsUtils.get_keyword_value(self.fits_file[self.primary_hdu_offset:], keyword)
            else:
                value = None
            self.keyword_values[keyword] = value

    def get_keyword_value(self, keyword: str, default_value):
        """Return the value of *keyword*, or *default_value* when that value is None."""
        value = self[keyword]
        return value if value is not None else default_value

    def is_virtual(self) -> bool:
        """Return True if ``file[keyword]`` is served only from ``keyword_values``."""
        return self.virtual


class ClassifiedFitsFile:
    """A FitsFile together with the classification assigned to it.

    Two instances are equal, and hash equally, when they wrap files with the same
    path, regardless of classification. Attributes not defined here are looked up
    on the wrapped FitsFile, and ``with`` blocks are forwarded to it.
    """

    def __init__(self, fitsfile: 'FitsFile', classification: Optional[str], classification_rule_id: Optional[str]):
        self.fitsfile = fitsfile
        self.classification = classification
        self.classification_rule_id = classification_rule_id
        self.dp_id = fitsfile.dp_id

    def __hash__(self):
        return hash(self.fitsfile.file_path)

    def __eq__(self, other: object):
        # Defer to Python's default handling for unrelated types instead of raising
        # AttributeError on `other.fitsfile` (e.g. for `cff == None`).
        if not isinstance(other, ClassifiedFitsFile):
            return NotImplemented
        return self.fitsfile.file_path == other.fitsfile.file_path

    @classmethod
    @lru_cache(maxsize=1000)
    def from_fitsfile(cls, fitsfile: 'FitsFile', classification: Optional[str], classification_rule_id: Optional[str]):
        """Return an instance for these arguments, memoized (up to 1000 entries) per class."""
        return cls(fitsfile, classification, classification_rule_id)

    def get_path(self) -> str:
        """Return the path of the wrapped file."""
        return self.fitsfile.file_path

    def get_keyword_values(self) -> Dict:
        """Return the wrapped file's keyword cache (the live dict, not a copy)."""
        return self.fitsfile.keyword_values

    def get_keyword_value(self, keyword: str, default_value: 'T') -> 'T':
        """Return the value of *keyword*, or *default_value* when that value is None.

        The lookup runs inside a ``with`` block, so an uncached keyword can be read
        from disk.
        """
        with self:
            return self.fitsfile.get_keyword_value(keyword, default_value)

    def __enter__(self):
        self.fitsfile.__enter__()
        return self

    def __exit__(self, *args):
        # Forward the (exc_type, exc_value, traceback) triple as-is, not wrapped in a tuple.
        self.fitsfile.__exit__(*args)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. On an instance created without
        # __init__ (e.g. by copy.copy), `fitsfile` is not set yet. Delegating would
        # call __getattr__('fitsfile') again and recurse forever, so fail fast.
        if attr == 'fitsfile':
            raise AttributeError(f"{type(self).__name__!r} object has no attribute 'fitsfile'")
        return getattr(self.fitsfile, attr)

    def __getitem__(self, item):
        return self.fitsfile[item]

    def __str__(self):
        return f'{self.fitsfile.file_path} {self.classification} {self.fitsfile.keyword_values}'

    def __repr__(self):
        return str(self)

    def is_virtual(self) -> bool:
        """Return whether the wrapped file is virtual."""
        return self.fitsfile.is_virtual()

    def get_group_id(self, grouping_keywords: List[str]) -> str:
        """Return the string form of the list of values of *grouping_keywords*."""
        with self:
            return str([self[kw] for kw in grouping_keywords])

    def get_mjdobs(self) -> float:
        """Return the MJD_OBS value, or ``sys.float_info.max`` when it is not available."""
        with self:
            return self.get_keyword_value(MJD_OBS, sys.float_info.max)

    def as_fits_file_dto(self) -> 'FitsFileDTO':
        """Return a client DTO carrying the file path and classification."""
        return FitsFileDTO(name=self.get_path(), category=self.classification)


class ParameterResolvingClassifiedFitsFile(ClassifiedFitsFile):
    """A ClassifiedFitsFile whose keyword names pass through a resolver before lookup.

    The wrapped file, classification and rule id are copied from *file*. Every keyword
    given to ``get_keyword_value`` or ``[]`` is mapped by *keyword_resolver*, and the
    resulting name is the one read from the underlying FitsFile.
    """

    def __init__(self, file: ClassifiedFitsFile, keyword_resolver: Callable[[str], str]):
        super().__init__(
            fitsfile=file.fitsfile,
            classification=file.classification,
            classification_rule_id=file.classification_rule_id,
        )
        self.keyword_resolver = keyword_resolver

    def get_keyword_value(self, keyword: str, default_value: T) -> T:
        resolved_keyword = self.keyword_resolver(keyword)
        return super().get_keyword_value(resolved_keyword, default_value)

    def __getitem__(self, keyword):
        resolved_keyword = self.keyword_resolver(keyword)
        return self.fitsfile[resolved_keyword]


class FitsFileFactory:
    """Builds FitsFile objects from files on disk, reading a set of keywords up front."""

    logger = logging.getLogger('FitsFileFactory')

    @staticmethod
    def check_mandatory_keywords(filename: str, keyword_values: Dict):
        """Log a warning for each mandatory keyword (MJD_OBS, INSTRUME) that has no value.

        A keyword missing from *keyword_values* counts as undefined.
        """
        for keyword in (MJD_OBS, INSTRUME):
            if keyword_values.get(keyword) is None:
                FitsFileFactory.logger.warning("%s undefined %s", filename, keyword.upper())

    @staticmethod
    def extract_keywords(filename: str, additional_keywords: Set[str]) -> Dict:
        """Read the mandatory keywords, HEADER_KEYWORDS and *additional_keywords* from *filename*.

        *additional_keywords* may be any iterable of keyword names. NIGHT is also
        derived from MJD_OBS, and is None when MJD_OBS is missing.
        """
        expanded_keywords = {MJD_OBS, INSTRUME, *HEADER_KEYWORDS, *additional_keywords}
        keyword_values = FitsFileFactory.extract_provided_keywords(filename, expanded_keywords)
        mjd_obs = keyword_values.get(MJD_OBS)
        # int() truncates toward zero, so this equals floor(mjd_obs - 0.5) for mjd_obs >= 0.5.
        keyword_values[NIGHT] = int(mjd_obs - 0.5) if mjd_obs is not None else None
        return keyword_values

    @staticmethod
    def extract_provided_keywords(filename: str, keywords: Set[str]) -> Dict:
        """Open *filename* and read each of *keywords*, starting from its primary HDU.

        Keywords that cannot be found map to None. The file is closed before returning.
        """
        with fits.open(filename) as hdul:
            offset = FitsUtils.index_of_primary_hdu(filename, hdul)
            # Slice once, outside the loop: every HDUList slice builds a new HDUList.
            from_primary = hdul[offset:]
            return {keyword: FitsUtils.get_keyword_value(from_primary, keyword) for keyword in keywords}

    @staticmethod
    def _parse_fits_file(filename: str, keywords: Set[str]) -> FitsFile:
        """Return the cached FitsFile for *filename*, extracting keywords while its cache is empty.

        Once the cached instance has keyword values, keywords requested by later calls
        are not read here; they are left to FitsFile's lazy loading.
        """
        FitsFileFactory.logger.debug('parsing %s', filename)
        file = FitsFileFactory._get_file(filename)
        if not file.keyword_values:
            keyword_values = FitsFileFactory.extract_keywords(filename, keywords)
            file.keyword_values = keyword_values
            FitsFileFactory.check_mandatory_keywords(filename, keyword_values)
        return file

    @staticmethod
    @lru_cache(maxsize=10000)
    def _get_file(filename: str) -> FitsFile:
        """Return a FitsFile for *filename*; memoized (up to 10000 entries) so calls share one instance."""
        return FitsFile(filename, {})

    @staticmethod
    def create_fits_file(filename: str, keywords: Set[str]) -> Optional[FitsFile]:
        """Return the FitsFile for *filename*, or None (after logging a warning) if parsing fails."""
        try:
            return FitsFileFactory._parse_fits_file(filename, keywords)
        except Exception as e:
            FitsFileFactory.logger.warning('invalid fits file %s: %s', filename, e, exc_info=e)
            return None

    @staticmethod
    def create_fits_files(file_paths: List[Path], keywords: Set[str]) -> List[FitsFile]:
        """Return FitsFiles for *file_paths* in order, skipping files that could not be parsed."""
        fits_files = (FitsFileFactory.create_fits_file(str(p), keywords) for p in file_paths)
        return [file for file in fits_files if file is not None]
