from typing import List, Set

from edps import FilterMode
from edps.client.FitsFile import FitsFile


class ProductFilter:
    """Base class for predicates that decide whether a product file is kept.

    Concrete subclasses override :meth:`is_accepted`. The ``mode`` and the
    ``categories`` given at construction are stored unchanged and exposed
    through :meth:`get_mode` and :meth:`get_categories`.
    """

    def __init__(self, mode: str, categories: Set[str]):
        # Stored verbatim: no copy and no validation happen here.
        self.mode, self.categories = mode, categories

    def get_mode(self) -> str:
        """Return the mode this filter was constructed with."""
        return self.mode

    def get_categories(self) -> Set[str]:
        """Return the category set this filter was constructed with (not a copy)."""
        return self.categories

    def is_accepted(self, file: FitsFile) -> bool:
        """Return True if ``file`` passes the filter; subclasses must override this."""
        raise NotImplementedError


class FilterFactory(object):
    """Build the ProductFilter that matches a category list and a mode string."""

    @staticmethod
    def create_filter(categories: List[str], mode: str) -> ProductFilter:
        """Return a filter for ``categories`` interpreted according to ``mode``.

        :param categories: categories to select or reject. Any iterable of
            strings is accepted. ``None`` or an empty collection means
            "no restriction".
        :param mode: mode name, compared case-insensitively (via ``upper()``)
            against ``FilterMode.REJECT``. Any other value, including ``None``,
            falls through to selection.
        :returns: an ``AcceptAllFilter`` when no categories are given, a
            ``RejectFilter`` for the reject mode, otherwise a ``SelectFilter``.
        """
        # Materialise once. The emptiness test then also works for None and
        # for iterables without __len__ (e.g. generators), which previously
        # made len() raise even though set() accepts them.
        category_set = set(categories) if categories is not None else set()
        if not category_set:
            # Nothing to select or reject: every file is accepted,
            # whatever the mode.
            return AcceptAllFilter()
        # Guard against a missing mode instead of failing on None.upper().
        if mode is not None and mode.upper() == FilterMode.REJECT:
            return RejectFilter(category_set)
        return SelectFilter(category_set)


class AcceptAllFilter(ProductFilter):
    """Filter that lets every file through, regardless of its category."""

    def __init__(self):
        # Mode and categories are placeholders only: is_accepted never reads them.
        super().__init__(mode=FilterMode.SELECT, categories=set())

    def is_accepted(self, file: FitsFile) -> bool:
        """Accept ``file`` unconditionally."""
        return True


class SelectFilter(ProductFilter):
    """Filter that keeps only files whose category is in the given set."""

    def __init__(self, categories: Set[str]):
        super().__init__(mode=FilterMode.SELECT, categories=categories)

    def is_accepted(self, file: FitsFile) -> bool:
        """Return True when ``file.category`` is one of the selected categories."""
        category = file.category
        selected = self.get_categories()
        return category in selected

class RejectFilter(ProductFilter):
    """Filter that drops files whose category is in the given set."""

    def __init__(self, categories: Set[str]):
        super().__init__(mode=FilterMode.REJECT, categories=categories)

    def is_accepted(self, file: FitsFile) -> bool:
        """Return True unless ``file.category`` is one of the rejected categories."""
        category = file.category
        rejected = self.get_categories()
        return category not in rejected
