import os.path
import re
import tarfile
import unittest
from collections import Counter
from typing import List, Tuple

from astropy.io import fits

from edps.client.BreakpointsStateDTO import BreakpointsStateDTO
from edps.client.Error import RestError
from edps.client.FitsFile import FitsFile
from edps.client.JobInfo import JobInfo
from edps.client.JobSummary import Setup, Header
from edps.client.ProcessingJob import ProcessingJob
from edps.client.ProcessingJobStatus import ProcessingJobStatus
from edps.client.ProcessingResponse import ProcessingResponse
from edps.client.Rejection import Rejection
from edps.client.WorkflowDTO import CalselectorJob
from edps.client.monad import Either
from edps.generator.fits import FitsFileFactory


class OrganizationAssertions:
    """Fluent assertions over the result of an organization request.

    ``result`` holds either the list of organized jobs (right side) or the ``RestError`` returned by
    the request (left side). Every assertion reports failures through ``test`` and returns ``self``
    so that checks can be chained, e.g. ``assertions.is_success().has_jobs(3)``.
    """

    def __init__(self, test: unittest.TestCase, result: Either[List[JobInfo], RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'OrganizationAssertions':
        """Assert that the request succeeded (``result`` is a right value)."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def is_failure(self) -> 'OrganizationAssertions':
        """Assert that the request failed (``result`` is a left value)."""
        self.test.assertTrue(self.result.is_left(), "Request succeeded")
        return self

    def has_jobs(self, count: int) -> 'OrganizationAssertions':
        """Assert that exactly ``count`` jobs were returned."""
        self.test.assertEqual(count, len(self.result.get()), "Number of jobs doesn't match")
        return self

    def has_complete_jobs(self, count: int) -> 'OrganizationAssertions':
        """Assert that exactly ``count`` of the returned jobs are complete."""
        jobs = self.result.get()
        num_complete_jobs = sum(1 for job in jobs if job.complete)
        self.test.assertEqual(count, num_complete_jobs,
                              f"Number of complete jobs doesn't match. Actual: {self._describe_jobs(jobs)}")
        return self

    def has_incomplete_jobs(self, count: int) -> 'OrganizationAssertions':
        """Assert that exactly ``count`` of the returned jobs are incomplete."""
        jobs = self.result.get()
        num_incomplete_jobs = sum(1 for job in jobs if not job.complete)
        self.test.assertEqual(count, num_incomplete_jobs,
                              f"Number of incomplete jobs doesn't match. Actual: {self._describe_jobs(jobs)}")
        return self

    def has_job_matching_assoc_level(self, assoc_level: float, count: int) -> 'OrganizationAssertions':
        """Assert that exactly ``count`` jobs have an association level equal to ``assoc_level``."""
        actual_count = sum(1 for job in self.result.get() if job.assoc_level == assoc_level)
        self.test.assertEqual(count, actual_count,
                              f"Number of jobs with association level {assoc_level} doesn't match")
        return self

    def has_job_matching_pattern(self, pattern: str) -> 'OrganizationAssertions':
        """Assert that at least one job matches the regex ``pattern`` (see ``job_matches``)."""
        found = any(self.job_matches(job, pattern) for job in self.result.get())
        self.test.assertTrue(found, f"No job matching pattern '{pattern}' found")
        return self

    def has_job_matching_setup(self, setup: Setup) -> 'OrganizationAssertions':
        """Assert that at least one job has a setup equal to ``setup``."""
        found = any(job.setup == setup for job in self.result.get())
        self.test.assertTrue(found, f"No job matching setup '{setup}' found")
        return self

    def has_job_matching_header(self, header: Header) -> 'OrganizationAssertions':
        """Assert that at least one job has a header equal to ``header``."""
        found = any(job.header == header for job in self.result.get())
        self.test.assertTrue(found, f"No job matching header '{header}' found")
        return self

    def has_job_matching_report(self, report_name: str, report_input: str, driver: str) -> 'OrganizationAssertions':
        """Assert that some job declares a report with the given name, input and driver."""
        found = any(report.name == report_name and report.input == report_input and report.driver == driver
                    for job in self.result.get() for report in job.reports)
        self.test.assertTrue(found, f"No job matching report '{report_name},{report_input},{driver}' found")
        return self

    def has_job_matching(self, recipe: str, input_prefixes: List[str],
                         assoc_prefixes: List[str] = None) -> 'OrganizationAssertions':
        """Assert that a single job matches ``recipe`` and the given input/associated prefixes.

        ``assoc_prefixes`` defaults to no associated files.
        """
        return self.has_jobs_matching([(recipe, input_prefixes, assoc_prefixes or [])])

    def has_jobs_matching(self, expected: List[Tuple[str, List[str], List[str]]]) -> 'OrganizationAssertions':
        """Assert that every ``(recipe, input_prefixes, assoc_prefixes)`` triple is matched by some job.

        See ``is_job_matching`` for the matching rules. On failure, the message lists the actual
        recipes and file basenames of all jobs.
        """
        jobs = self.result.get()
        actual_recipes = [job.command for job in jobs]
        actual_input_prefixes = [os.path.basename(f.name) for job in jobs for f in job.input_files]
        actual_assoc_prefixes = [os.path.basename(f.name) for job in jobs for f in job.associated_files]
        for recipe, input_prefixes, assoc_prefixes in expected:
            message = (f"No job found matching recipe '{recipe}' (actual: {actual_recipes})"
                       f" input prefixes '{input_prefixes}' (actual: {actual_input_prefixes})"
                       f" assoc prefixes '{assoc_prefixes}' (actual: {actual_assoc_prefixes})")
            self.test.assertTrue(self._has_matching_job(recipe, input_prefixes, assoc_prefixes), message)
        return self

    def has_job_belonging_to_workflow(self, workflow_name: str) -> 'OrganizationAssertions':
        """Assert that at least one job lists ``workflow_name`` among its workflow names."""
        found = any(workflow_name in job.workflow_names for job in self.result.get())
        self.test.assertTrue(found, f"No job found matching workflow '{workflow_name}'")
        return self

    def _has_matching_job(self, recipe: str, input_prefixes: List[str], assoc_prefixes: List[str]) -> bool:
        """Return True if any returned job satisfies ``is_job_matching`` for the given spec."""
        return any(OrganizationAssertions.is_job_matching(recipe, input_prefixes, assoc_prefixes, job)
                   for job in self.result.get())

    def has_parents(self, count: int) -> 'OrganizationAssertions':
        """Assert that at least one job has exactly ``count`` parents."""
        counts = [len(job.parents_ids) for job in self.result.get()]
        self.test.assertTrue(count in counts, f"No job found matching number of parents {count} (actual: {counts})")
        return self

    def has_any_parents(self) -> 'OrganizationAssertions':
        """Assert that at least one job has one or more parents (any number, not exactly one)."""
        found = any(len(job.parents_ids) > 0 for job in self.result.get())
        self.test.assertTrue(found, "No job found with any parents")
        return self

    @staticmethod
    def is_job_matching(recipe: str, input_prefixes: List[str], assoc_prefixes: List[str], job: JobInfo) -> bool:
        """Return True if ``job`` runs ``recipe`` and its input/associated files match the prefixes (see ``files_match``)."""
        return (job.command == recipe and
                OrganizationAssertions.files_match(job.input_files, input_prefixes) and
                OrganizationAssertions.files_match(job.associated_files, assoc_prefixes))

    @staticmethod
    def files_match(input_files: List[FitsFile], prefixes: List[str]) -> bool:
        """Return True when the files and prefixes cover each other.

        Every file must contain at least one prefix, and every prefix must be contained in at least one
        file. Note that a "prefix" matches anywhere in ``file.name`` (substring test, not ``startswith``).
        Two empty lists match.
        """
        return (all(OrganizationAssertions.has_matching_file(file, prefixes) for file in input_files) and
                all(OrganizationAssertions.has_matching_prefix(prefix, input_files) for prefix in prefixes))

    @staticmethod
    def has_matching_prefix(prefix: str, files: List[FitsFile]) -> bool:
        """Return True if ``prefix`` occurs in the name of any of ``files``."""
        return any(prefix in file.name for file in files)

    @staticmethod
    def has_matching_file(file: FitsFile, prefixes: List[str]) -> bool:
        """Return True if any of ``prefixes`` occurs in ``file.name``."""
        return any(prefix in file.name for prefix in prefixes)

    @staticmethod
    def job_matches(job: JobInfo, pattern) -> bool:
        """Return True if ``pattern`` (``re.match``, anchored at the start) matches any job identifier.

        The job id, command, command type, input/associated job ids and the input/associated files
        (via ``FitsFile.match``) are checked. Unset (``None``) scalar fields are skipped instead of
        making ``re.match`` raise ``TypeError``.
        """
        return (any(value is not None and re.match(pattern, value)
                    for value in (job.job_id, job.command, job.command_type)) or
                any(f.match(pattern) for f in job.input_files) or
                any(f.match(pattern) for f in job.associated_files) or
                any(re.match(pattern, uid) for uid in job.input_job_ids) or
                any(re.match(pattern, uid) for uid in job.associated_job_ids))

    @staticmethod
    def _describe_jobs(jobs: List[JobInfo]) -> List[dict]:
        """Summarise jobs as task/recipe/completeness dicts for use in failure messages."""
        return [{'task': job.task_name, 'recipe': job.command, 'complete': job.complete} for job in jobs]


class ExecutionAssertions:
    """Chainable checks on the response of a processing (execution) request.

    ``result`` carries a ``ProcessingResponse`` on success or a ``RestError`` on failure. ``test`` is
    the integration test that issued the request; it also supplies the ``wait_until``, ``job_complete``
    and ``job_halted`` polling helpers used below.
    """

    def __init__(self, test: 'BaseIT', result: Either[ProcessingResponse, RestError]):
        self.test = test
        self.result = result

    def _response(self):
        """Unwrap the successful processing response."""
        return self.result.get()

    def is_success(self) -> 'ExecutionAssertions':
        """Check that the request produced a response rather than an error."""
        succeeded = self.result.is_right()
        self.test.assertTrue(succeeded, "Request failed")
        return self

    def has_jobs(self, count: int) -> 'ExecutionAssertions':
        """Check how many jobs the request scheduled."""
        scheduled = self._response().jobs
        self.test.assertEqual(count, len(scheduled), "Number of jobs doesn't match")
        return self

    def has_files_with_no_jobs(self, count: int) -> 'ExecutionAssertions':
        """Check how many submitted files did not end up in any job."""
        orphaned = self._response().files_with_no_jobs
        self.test.assertEqual(count, len(orphaned),
                              "Number of files with no job doesn't match")
        return self

    def has_incomplete_jobs(self, count: int) -> 'ExecutionAssertions':
        """Check how many incomplete jobs the response reports."""
        pending = self._response().incomplete_jobs
        self.test.assertEqual(count, len(pending), "Number of incomplete jobs doesn't match")
        return self

    def wait_until_complete(self, timeout: int = 30) -> 'ExecutionAssertions':
        """Poll each scheduled job in turn, passing ``timeout`` to ``wait_until`` for every job."""
        for scheduled_job in self._response().jobs:
            self.test.wait_until(lambda: self.test.job_complete(scheduled_job), timeout)
        return self

    def wait_until_jobs_halted(self, count: int, timeout: int = 30) -> 'ExecutionAssertions':
        """Poll until ``job_halted(count)`` holds, giving up after ``timeout`` (as interpreted by ``wait_until``)."""
        halted_check = lambda: self.test.job_halted(count)
        self.test.wait_until(halted_check, timeout)
        return self


class JobDetailsAssertions:
    """Chainable checks on the details of a single processing job.

    ``result`` holds a ``ProcessingJob`` on success or a ``RestError`` on failure. Each method reports
    failures through ``test`` and returns ``self``.
    """

    def __init__(self, test: unittest.TestCase, result: Either[ProcessingJob, RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'JobDetailsAssertions':
        """Assert that the job details request succeeded."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def is_complete(self) -> 'JobDetailsAssertions':
        """Assert that the job status is ``ProcessingJobStatus.COMPLETED``."""
        self.test.assertEqual(ProcessingJobStatus.COMPLETED, self.result.get().status)
        return self

    def is_missing(self) -> 'JobDetailsAssertions':
        """Assert that the job status is ``ProcessingJobStatus.MISSING``."""
        self.test.assertEqual(ProcessingJobStatus.MISSING, self.result.get().status)
        return self

    def is_rejected(self) -> 'JobDetailsAssertions':
        """Assert that the job is flagged as rejected."""
        self.test.assertTrue(self.result.get().rejected, "Job is not rejected")
        return self

    def is_failed(self) -> 'JobDetailsAssertions':
        """Assert that the job status is ``ProcessingJobStatus.FAILED``."""
        self.test.assertEqual(ProcessingJobStatus.FAILED, self.result.get().status)
        return self

    def outputs_match(self, pattern: str) -> 'JobDetailsAssertions':
        """Assert that every output file basename matches ``pattern`` (``re.match``, anchored at the start).

        A job without output files passes. On failure, the message lists the non-matching basenames.
        """
        names = [os.path.basename(f.name) for f in self.result.get().output_files]
        mismatched = [name for name in names if not re.match(pattern, name)]
        self.test.assertFalse(mismatched, f"Output files not matching pattern '{pattern}': {mismatched}")
        return self


class RejectionAssertions:
    """Chainable checks on the outcome of a job rejection request.

    ``result`` holds a ``Rejection`` on success or a ``RestError`` on failure.
    """

    def __init__(self, test: unittest.TestCase, result: Either[Rejection, RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'RejectionAssertions':
        """Assert that the rejection request succeeded."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def has_rejected_ids(self, ids: List[str]) -> 'RejectionAssertions':
        """Assert that every id in ``ids`` is among the rejected job ids (other ids may also be rejected)."""
        rejected = self.result.get().rejected_job_ids
        missing = [job_id for job_id in ids if job_id not in rejected]
        self.test.assertFalse(missing, f"Job ids not rejected: {missing} (rejected: {rejected})")
        return self

    def has_incomplete_ids(self, ids: List[str]) -> 'RejectionAssertions':
        """Assert that every id in ``ids`` is among the incomplete job ids (other ids may also be listed)."""
        incomplete = self.result.get().incomplete_job_ids
        missing = [job_id for job_id in ids if job_id not in incomplete]
        self.test.assertFalse(missing, f"Job ids not incomplete: {missing} (incomplete: {incomplete})")
        return self

    def has_files_to_resubmit(self, files: List['TestFitsFile']) -> 'RejectionAssertions':
        """Assert that the files to resubmit are exactly ``files``, compared by basename."""
        actual = {os.path.basename(f) for f in self.result.get().files_to_resubmit}
        expected = {os.path.basename(f.path) for f in files}
        self.test.assertSetEqual(actual, expected)
        return self


class AssociationReportAssertions:
    """Chainable checks on a raw association report.

    ``result`` holds the report bytes on success or a ``RestError`` on failure.
    """

    def __init__(self, test: unittest.TestCase, result: Either[bytes, RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'AssociationReportAssertions':
        """Assert that the report request succeeded."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def has_content(self, content: bytes) -> 'AssociationReportAssertions':
        """Assert that ``content`` occurs somewhere in the report bytes."""
        # A custom message is used instead of assertIn so that a large report is not dumped into the failure output.
        self.test.assertTrue(content in self.result.get(), f"Report does not contain {content!r}")
        return self


class BreakpointListAssertions:
    """Chainable checks on a breakpoint-listing request.

    ``result`` holds a ``BreakpointsStateDTO`` on success or a ``RestError`` on failure.
    """

    def __init__(self, test: unittest.TestCase, result: Either[BreakpointsStateDTO, RestError]):
        self.test, self.result = test, result

    def is_success(self) -> 'BreakpointListAssertions':
        """Check that listing the breakpoints did not return an error."""
        request_ok = self.result.is_right()
        self.test.assertTrue(request_ok, "Request failed")
        return self


class DatasetPackageAssertions:
    """Chainable filesystem checks on dataset packages written below ``dataset_base_dir``.

    Unlike the other assertion helpers, this one inspects the directory tree directly instead of a
    request result. Every method returns ``self`` so that checks can be chained.
    """

    def __init__(self, test: unittest.TestCase, dataset_base_dir: str):
        self.test = test
        self.dataset_base_dir = dataset_base_dir

    def has_no_dataset(self, directory: str) -> 'DatasetPackageAssertions':
        """Assert that ``directory`` (relative to the base dir) does not exist."""
        dataset_dir = os.path.join(self.dataset_base_dir, directory)
        self.test.assertFalse(os.path.exists(dataset_dir), f"Dataset directory '{dataset_dir}' exists")
        return self

    def has_dataset_package_matching(self, directory: str, files_count: int) -> 'DatasetPackageAssertions':
        """Assert that ``directory`` is a directory holding exactly ``files_count`` entries (not recursive)."""
        dataset_dir = os.path.join(self.dataset_base_dir, directory)
        # isdir (not exists) so a plain file yields an assertion failure instead of an os.listdir error.
        self.test.assertTrue(os.path.isdir(dataset_dir), f"Dataset directory '{dataset_dir}' not found")
        entries = os.listdir(dataset_dir)
        self.test.assertEqual(files_count, len(entries),
                              f"Number of entries in '{dataset_dir}' doesn't match: {entries}")
        return self

    def contains_file(self, file: str) -> 'DatasetPackageAssertions':
        """Assert that ``file`` (relative to the base dir) exists."""
        real_path = os.path.join(self.dataset_base_dir, file)
        self.test.assertTrue(os.path.exists(real_path), f"File '{real_path}' not found")
        return self


class ClassificationAssertions:
    """Chainable checks on the result of a file classification request.

    ``result`` holds the classified ``FitsFile`` list on success or a ``RestError`` on failure.
    """

    def __init__(self, test: unittest.TestCase, result: Either[List[FitsFile], RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'ClassificationAssertions':
        """Check that classification returned files rather than an error."""
        ok = self.result.is_right()
        self.test.assertTrue(ok, "Request failed")
        return self

    def has_files(self, count: int) -> 'ClassificationAssertions':
        """Check how many files were classified."""
        classified = self.result.get()
        self.test.assertEqual(count, len(classified), "Number of files doesn't match")
        return self

    def has_matching_classification(self, classification: str) -> 'ClassificationAssertions':
        """Check that at least one classified file has category ``classification``."""
        classified = self.result.get()
        matched = any(fits_file.category == classification for fits_file in classified)
        failure_message = f"No file matching classification '{classification}' found"
        self.test.assertTrue(matched, failure_message)
        return self


class CalselectorAssertions:
    """Chainable checks on the jobs returned by a calselector request.

    ``result`` holds a list of ``CalselectorJob`` on success or a ``RestError`` on failure.
    """

    def __init__(self, test: unittest.TestCase, result: Either[List[CalselectorJob], RestError]):
        self.test = test
        self.result = result

    def is_success(self) -> 'CalselectorAssertions':
        """Assert that the calselector request succeeded."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def has_jobs(self, count: int) -> 'CalselectorAssertions':
        """Assert that exactly ``count`` jobs were returned."""
        self.test.assertEqual(count, len(self.result.get()), "Number of jobs doesn't match")
        return self

    def has_category(self, category: str) -> 'CalselectorAssertions':
        """Assert that at least one job's exemplar file has category ``category``."""
        condition = any(job.exemplar.category == category for job in self.result.get())
        self.test.assertTrue(condition, f"No job matching category '{category}' found")
        return self

    def has_associated_category(self, category: str) -> 'CalselectorAssertions':
        """Assert that some associated input of some job lists ``category`` in its input categories."""
        # any() stops at the first hit; the original nested-loop `break` only left the innermost loop.
        found = any(category in inp.input_categories for inp in self._associated_inputs())
        self.test.assertTrue(found, f"No associated input matching category '{category}' found")
        return self

    def _associated_inputs(self):
        """Yield every associated input from every associated input group of every job."""
        for job in self.result.get():
            for group in job.associated_input_groups:
                yield from group.associated_inputs


class Phase3PackageAssertions:
    """Chainable checks on the files of a Phase 3 package written to ``output_path``.

    ``result`` is the outcome of the packaging request (a string on success, ``RestError`` on failure).
    All other assertions inspect the files found directly inside ``output_path`` (not recursively).
    """

    def __init__(self, test: unittest.TestCase, result: Either[str, RestError], output_path: str):
        self.test = test
        self.result = result
        self.output_path = output_path

    def is_success(self) -> 'Phase3PackageAssertions':
        """Assert that the packaging request succeeded."""
        self.test.assertTrue(self.result.is_right(), "Request failed")
        return self

    def has_fits_files_with_prefix(self, prefix: str, count: int) -> 'Phase3PackageAssertions':
        """Assert that exactly ``count`` ``.fits`` files have a basename starting with ``prefix``."""
        files = [f for f in os.listdir(self.output_path) if f.startswith(prefix) and f.endswith(".fits")]
        found = len(files)
        self.test.assertEqual(count, found, f"Expected {count} fits files with prefix '{prefix}' but found {found}")
        return self

    def has_fits_files_with_category(self, prodcatg: str, count: int) -> 'Phase3PackageAssertions':
        """Assert that exactly ``count`` ``.fits`` files have PRODCATG equal to ``prodcatg``."""
        files = [FitsFileFactory.create_fits_file(f, {"PRODCATG"}) for f in self._files() if f.endswith(".fits")]
        counter = Counter()
        for f in files:
            with f:
                counter.update([f.get_keyword_value("PRODCATG", None)])
        found = counter.get(prodcatg, 0)
        self.test.assertEqual(count, found, f"Expected {count} fits files with category {prodcatg} but found {found}")
        return self

    def has_fits_files_with_key_value_comment(self, key: str, value: str, comment: str) -> 'Phase3PackageAssertions':
        """Assert that some ``.fits`` file's primary header has ``key`` == ``value`` with the given comment."""
        fits_files = [f for f in self._files() if f.endswith(".fits")]
        found = False
        for f in fits_files:
            with fits.open(f) as hdul:
                hdr = hdul[0].header
                if key in hdr and hdr[key] == value and hdr.comments[key] == comment:
                    found = True
                    break
        self.test.assertTrue(found, f"No fits file found with {key}={value} and comment '{comment}'")
        return self

    def has_png_files(self, count: int) -> 'Phase3PackageAssertions':
        """Assert that exactly ``count`` ``.png`` files are present."""
        found = len([f for f in self._files() if f.endswith(".png")])
        self.test.assertEqual(count, found, f"Expected {count} PNG files but found {found}")
        return self

    def has_tar_files(self, count: int) -> 'Phase3PackageAssertions':
        """Assert that exactly ``count`` ``.tar`` files are present."""
        found = len([f for f in self._files() if f.endswith(".tar")])
        self.test.assertEqual(count, found, f"Expected {count} tar files but found {found}")
        return self

    def all_tar_matching_science_files(self) -> 'Phase3PackageAssertions':
        """Assert that every ``.tar`` file has a same-named ``.fits`` file whose PRODCATG starts with ``SCIENCE.``.

        Passes when there are no tar files.
        """
        tar_files = [f for f in self._files() if f.endswith(".tar")]
        fits_files = [FitsFileFactory.create_fits_file(f, {"PRODCATG"}) for f in self._files() if f.endswith(".fits")]
        science_files = set()
        for f in fits_files:
            with f:
                prodcatg = f.get_keyword_value("PRODCATG", "")
                # Guard against a non-string keyword value, which would make .startswith raise.
                if isinstance(prodcatg, str) and prodcatg.startswith("SCIENCE."):
                    science_files.add(self._stem(f.file_path))
        unmatched = [tar for tar in tar_files if self._stem(tar) not in science_files]
        self.test.assertFalse(unmatched, f"Tar files without matching science file: {unmatched}")
        return self

    def has_matching_tar(self, ext: str, count: int) -> 'Phase3PackageAssertions':
        """Assert that some ``.tar`` file has exactly ``count`` members, all with names ending in ``ext``."""
        tar_files = [f for f in self._files() if f.endswith(".tar")]
        found = any(self._tar_matches(tar_path, ext, count) for tar_path in tar_files)
        self.test.assertTrue(found, f"Tar matching specification {ext}:{count} not found")
        return self

    @staticmethod
    def _tar_matches(tar_path: str, ext: str, count: int) -> bool:
        """Return True when the archive at ``tar_path`` has ``count`` members, each ending in ``ext``."""
        with tarfile.open(tar_path, "r") as tar_file:
            contents = tar_file.getnames()
        return len(contents) == count and all(name.endswith(ext) for name in contents)

    @staticmethod
    def _stem(path: str) -> str:
        """Return the basename of ``path`` without its final extension.

        Only the last extension is stripped, unlike ``str.replace``, which removes every occurrence.
        """
        return os.path.splitext(os.path.basename(path))[0]

    def _files(self) -> List[str]:
        """Return the full paths of the entries directly inside ``output_path``."""
        return [os.path.join(self.output_path, f) for f in os.listdir(self.output_path)]
