import dataclasses
import json
import os
import random
import secrets
import string
import unittest
import uuid
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional

from astropy.io import fits

from edps import utils
from edps.client.Error import RestError
from edps.client.FitsFile import FitsFile
from edps.client.JobInfo import JobInfo
from edps.client.ProcessingJob import ProcessingJob
from edps.client.ProcessingRequest import ProcessingRequest
from edps.client.ProcessingResponse import ProcessingResponse
from edps.client.Rejection import Rejection
from edps.client.ReportsConfiguration import ReportsConfiguration, ALL_REPORTS
from edps.client.RequestParameters import RequestParameters
from edps.client.WorkflowDTO import CalselectorJob
from edps.client.monad import Either
from edps.generator.association_breakpoints import AssociationBreakpoint
from edps.generator.association_preference import AssociationPreference
from edps.generator.fits import FitsUtils
from edps.test.dsl.assertions import OrganizationAssertions, ExecutionAssertions, JobDetailsAssertions, \
    RejectionAssertions, AssociationReportAssertions, DatasetPackageAssertions, ClassificationAssertions, \
    CalselectorAssertions, Phase3PackageAssertions


class TestTemplate:
    """Container for the test FITS files that make up a single template."""

    def __init__(self, files: List['TestFitsFile']):
        # Kept by reference: the caller's list and ``self.files`` are the same object.
        self.files: List['TestFitsFile'] = files

    @staticmethod
    def builder() -> 'TestTemplateBuilder':
        """Return a new, empty fluent builder for a template."""
        new_builder = TestTemplateBuilder()
        return new_builder


class TestTemplateBuilder:
    """Fluent builder that accumulates TestFitsFile objects into a TestTemplate."""

    def __init__(self):
        self.files: List[TestFitsFile] = []

    def with_file(self, file: 'TestFitsFile') -> 'TestTemplateBuilder':
        """Append a single file to the template."""
        self.files.append(file)
        return self

    def with_files(self, files: List['TestFitsFile']) -> 'TestTemplateBuilder':
        """Append several files to the template, preserving their order."""
        self.files.extend(files)
        return self

    def with_template_files(self, prefix: str, keywords: Dict[str, object], mjd_obs: float,
                            count: int) -> 'TestTemplateBuilder':
        """Append ``count`` generated files that share one ``tpl.start`` value.

        :param prefix: prefix of the generated names, ``{prefix}_{i}_{random}.fits`` with ``i`` from 1.
        :param keywords: keywords copied into every file. A ``"tpl.start"`` entry, if present,
            is used as the shared template start; otherwise a random one is generated.
        :param mjd_obs: ``mjd-obs`` of the first file; file ``i`` gets ``mjd_obs + i * 0.02``
            unless ``keywords`` already sets that keyword.
        :param count: number of files; also written to ``tpl.nexp`` of every file.
        :return: this builder.
        """
        # The random default is evaluated even when "tpl.start" is given (dict.get semantics).
        tpl_start = keywords.get("tpl.start", str(random.randint(0, 10000)))
        for i in range(count):
            arcfile = f'{prefix}_{i + 1}_{self.random_alphanumeric(10)}.fits'
            self.with_file(TestFitsFile.builder()
                           .with_name(arcfile)
                           .with_keyword('arcfile', arcfile)
                           .with_keywords(keywords)
                           .with_keyword("tpl.start", tpl_start)
                           .with_keyword("tpl.nexp", count)
                           .with_keyword_if_missing("tpl.expno", i + 1)
                           .with_keyword_if_missing('mjd-obs', mjd_obs + i * 0.02)
                           .build())
        return self

    @staticmethod
    def random_alphanumeric(length: int) -> str:
        """Return a random string of ``length`` ASCII letters and digits."""
        alphabet = string.ascii_letters + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(length))

    def build(self) -> 'TestTemplate':
        """Create the template from a snapshot of the files added so far.

        The list is copied, so adding files to this builder afterwards does not
        change templates that were already built.
        """
        return TestTemplate(list(self.files))


class TestFitsFile:
    """Description of a test FITS file: its path plus its header keywords."""

    def __init__(self, path: str, keywords: Dict[str, object]):
        self.path = path
        self.keywords = keywords

    @property
    def keywords_by_hdu(self) -> Dict:
        """Group the keywords by HDU.

        Each keyword name is split with ``FitsUtils.split_keyword``. The result maps
        the HDU to a ``{keyword: value}`` dict. It is a fresh ``defaultdict(dict)`` on
        every access.
        """
        grouped = defaultdict(dict)
        for name, value in self.keywords.items():
            split = FitsUtils.split_keyword(name)
            grouped[split.hdu][split.keyword] = value
        return grouped

    @staticmethod
    def builder() -> 'TestFitsFileBuilder':
        """Return a new fluent builder for a test file."""
        return TestFitsFileBuilder()


class TestFitsFileBuilder:
    """Fluent builder for TestFitsFile.

    Keyword names are normalised with ``FitsUtils.long_keyword`` before being stored.
    Spellings that normalise to the same long form therefore share one entry, and the
    last value written wins.
    """

    def __init__(self):
        # Random 10-letter default name, so unnamed files are unlikely to collide.
        self.name = ''.join(random.choices(population=string.ascii_letters, k=10)) + '.fits'
        self.keywords: Dict[str, object] = {}

    def with_name(self, name: str) -> 'TestFitsFileBuilder':
        """Set the file name (or path) of the file."""
        self.name = name
        return self

    def with_random_name(self, prefix: str = '') -> 'TestFitsFileBuilder':
        """Set a unique name of the form ``{prefix}_{uuid4}.fits``."""
        self.name = f'{prefix}_{uuid.uuid4()}.fits'
        return self

    def with_keywords(self, keywords: Dict[str, object]) -> 'TestFitsFileBuilder':
        """Set every keyword in ``keywords``, overwriting existing values."""
        for k, v in keywords.items():
            self.with_keyword(k, v)
        return self

    def with_keyword(self, key: str, val: object) -> 'TestFitsFileBuilder':
        """Set one keyword (stored under its normalised long name), overwriting any existing value."""
        self.keywords[FitsUtils.long_keyword(key)] = val
        return self

    def with_keyword_if_missing(self, key: str, val: object) -> 'TestFitsFileBuilder':
        """Set one keyword only if its normalised long name is not already present."""
        if FitsUtils.long_keyword(key) not in self.keywords:
            return self.with_keyword(key, val)
        return self

    def build(self) -> 'TestFitsFile':
        """Create the file from a snapshot of the current name and keywords.

        The keyword dict is copied, so further calls on this builder (for example when
        it is reused as a template for several files) do not mutate files that were
        already built.
        """
        return TestFitsFile(self.name, dict(self.keywords))


class TestConfiguration:
    """A test setup written to disk: input/calibration files plus an optional breakpoints file.

    Construction writes everything through ``helper``; ``cleanup`` removes it again.
    """

    def __init__(self, helper: 'TestHelper', templates: List[TestTemplate], inputs: List[TestFitsFile],
                 calib_files: List[TestFitsFile], assoc_breakpoints: List[AssociationBreakpoint]):
        self.helper = helper
        # The files as returned by helper.create_inputs (i.e. describing the files on disk).
        created_files = helper.create_inputs(templates, inputs, calib_files)
        self.inputs = created_files
        helper.create_assoc_breakpoints(assoc_breakpoints)

    @staticmethod
    def new_configuration(helper: 'TestHelper') -> 'TestConfigurationBuilder':
        """Return a builder that creates a configuration using ``helper``."""
        configuration_builder = TestConfigurationBuilder(helper)
        return configuration_builder

    def cleanup(self):
        """Delete the created files and the breakpoints file, if one exists."""
        self.helper.remove_inputs(self.inputs)
        breakpoints_path = TestHelper.BREAKPOINTS_FILENAME
        if os.path.exists(breakpoints_path):
            os.remove(breakpoints_path)


class TestConfigurationBuilder:
    """Fluent builder collecting templates, input files, calibration files and association breakpoints.

    ``build`` is a context manager that creates the TestConfiguration on entry and
    cleans it up on exit.
    """

    def __init__(self, helper: 'TestHelper'):
        self.helper = helper
        self.input_templates: List[TestTemplate] = []
        self.input_files: List[TestFitsFile] = []
        self.calib_files: List[TestFitsFile] = []
        self.assoc_breakpoints: List[AssociationBreakpoint] = []

    def with_assoc_breakpoint(self, mjd_obs: float, instrument: str, raw_type: str,
                              pro_catg: str) -> 'TestConfigurationBuilder':
        """Add an association breakpoint.

        ``instrument`` fills both ``instrument`` and ``dpInstrument``. ``raw_type`` fills
        both ``rawType`` and ``doClass``.
        """
        self.assoc_breakpoints.append(AssociationBreakpoint(instrument=instrument, dpInstrument=instrument,
                                                            mjdObs=mjd_obs, rawType=raw_type, proCatg=pro_catg,
                                                            doClass=raw_type))
        return self

    def with_templates(self, inputs: List['TestTemplate']) -> 'TestConfigurationBuilder':
        """Add several templates whose files are written to the data directory."""
        self.input_templates.extend(inputs)
        return self

    def with_template(self, input_template: 'TestTemplate') -> 'TestConfigurationBuilder':
        """Add one template whose files are written to the data directory."""
        self.input_templates.append(input_template)
        return self

    def with_input(self, input_file: 'TestFitsFile') -> 'TestConfigurationBuilder':
        """Add one file to be written to the data directory."""
        self.input_files.append(input_file)
        return self

    def with_calib_file(self, calib_file: 'TestFitsFile') -> 'TestConfigurationBuilder':
        """Add one file to be written to the calibration directory."""
        self.calib_files.append(calib_file)
        return self

    def with_inputs(self, inputs: List['TestFitsFile']) -> 'TestConfigurationBuilder':
        """Add several files to be written to the data directory."""
        self.input_files.extend(inputs)
        return self

    @contextmanager
    def build(self) -> Iterator['TestConfiguration']:
        """Create the configuration on disk, yield it, and clean it up on exit.

        Cleanup runs even if the ``with`` body raises. If creating the configuration
        itself fails, nothing is yielded and no cleanup is attempted.
        """
        configuration = TestConfiguration(self.helper, self.input_templates, self.input_files, self.calib_files,
                                          self.assoc_breakpoints)
        try:
            yield configuration
        finally:
            configuration.cleanup()


class TestHelper:
    """Writes test FITS files to disk, builds processing requests and wraps results in assertion helpers."""

    # Written to / removed from the current working directory (relative path).
    BREAKPOINTS_FILENAME = "breakpoints.json"

    def __init__(self, data_dir: str, workflow_dir: str, calib_dir: str, base_dir: str, dataset_base_dir: str, edps):
        """
        :param data_dir: directory receiving template and input files.
        :param workflow_dir: stored as-is; not used by the methods of this class.
        :param calib_dir: directory receiving calibration files.
        :param base_dir: stored as-is; not used by the methods of this class.
        :param dataset_base_dir: base directory checked by ``assert_on_dataset_package``.
        :param edps: stored as-is; presumably the EDPS client used by tests (not used here).
        """
        self.workflow_dir = workflow_dir
        self.data_dir = data_dir
        self.calib_dir = calib_dir
        self.base_dir = base_dir
        self.dataset_base_dir = dataset_base_dir
        self.edps = edps

    def create_configuration(self) -> TestConfigurationBuilder:
        """Return a configuration builder bound to this helper."""
        return TestConfiguration.new_configuration(self)

    def create_inputs(self, templates: List[TestTemplate], inputs: List[TestFitsFile],
                      calib_files: List[TestFitsFile]) -> List[TestFitsFile]:
        """Write all given files to disk and return descriptions carrying their absolute paths.

        Template files and plain inputs go to ``data_dir``; calibration files go to
        ``calib_dir``. Both directories are created if missing.
        """
        real_path_inputs = []
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(self.calib_dir, exist_ok=True)
        for template in templates:
            for file in template.files:
                real_path_inputs.append(self.create_file(file, self.data_dir))
        for file in inputs:
            real_path_inputs.append(self.create_file(file, self.data_dir))
        for file in calib_files:
            real_path_inputs.append(self.create_file(file, self.calib_dir))
        return real_path_inputs

    @staticmethod
    def create_assoc_breakpoints(assoc_breakpoints: List[AssociationBreakpoint]):
        """Dump the breakpoints as a JSON list to ``BREAKPOINTS_FILENAME``; do nothing if the list is empty."""
        if assoc_breakpoints:
            data = [dataclasses.asdict(bp) for bp in assoc_breakpoints]
            with open(TestHelper.BREAKPOINTS_FILENAME, "w") as outfile:
                json.dump(data, outfile)

    @staticmethod
    def create_file(file: TestFitsFile, directory: str) -> TestFitsFile:
        """Write ``file`` as a FITS file into ``directory``.

        HDU 0 becomes the primary HDU. Every index from 1 up to the highest HDU index
        present in the keywords becomes an image HDU, with an empty header if it has no
        keywords. An existing file at the destination is overwritten.

        :return: a new TestFitsFile with the absolute destination path and the same keywords.
        """
        # keywords_by_hdu is a property that regroups all keywords on every access; compute it once.
        by_hdu = file.keywords_by_hdu
        last_hdu = max(by_hdu) if by_hdu else 0
        hdus = []
        for hdu_num in range(last_hdu + 1):
            header = fits.Header(cards=[fits.Card(k, v) for k, v in by_hdu.get(hdu_num, {}).items()])
            hdus.append(fits.PrimaryHDU(header=header) if hdu_num == 0 else fits.ImageHDU(header=header))
        destination = os.path.abspath(os.path.join(directory, file.path))
        # Context manager closes the HDUList even if writeto raises.
        with fits.HDUList(hdus=hdus) as hdul:
            hdul.writeto(destination, overwrite=True)
        return (TestFitsFile.builder()
                .with_name(destination)
                .with_keywords(file.keywords)
                .build())

    def remove_inputs(self, inputs: List[TestFitsFile]):
        """Delete the given files if they exist.

        Relative paths are resolved against ``data_dir``. Absolute paths (as returned by
        ``create_inputs``, including calibration files) are used unchanged, because
        ``os.path.join`` discards ``data_dir`` for them.
        """
        for file in inputs:
            path = os.path.join(self.data_dir, file.path)
            if os.path.exists(path):
                os.remove(path)

    @staticmethod
    def assert_on_classification(test: unittest.TestCase,
                                 result: Either[List[FitsFile], RestError]) -> ClassificationAssertions:
        """Wrap a classification result in ClassificationAssertions bound to ``test``."""
        return ClassificationAssertions(test, result)

    @staticmethod
    def assert_on_calselector(test: unittest.TestCase,
                              result: Either[List[CalselectorJob], RestError]) -> CalselectorAssertions:
        """Wrap a calselector result in CalselectorAssertions bound to ``test``."""
        return CalselectorAssertions(test, result)

    @staticmethod
    def assert_on_organization(test: unittest.TestCase,
                               result: Either[List[JobInfo], RestError]) -> OrganizationAssertions:
        """Wrap an organization result in OrganizationAssertions bound to ``test``."""
        return OrganizationAssertions(test, result)

    @staticmethod
    def assert_on_execution(test: unittest.TestCase,
                            result: Either[ProcessingResponse, RestError]) -> ExecutionAssertions:
        """Wrap an execution result in ExecutionAssertions bound to ``test``."""
        return ExecutionAssertions(test, result)

    @staticmethod
    def assert_on_job_details(test: unittest.TestCase,
                              result: Either[ProcessingJob, RestError]) -> JobDetailsAssertions:
        """Wrap a job-details result in JobDetailsAssertions bound to ``test``."""
        return JobDetailsAssertions(test, result)

    @staticmethod
    def assert_on_rejection(test: unittest.TestCase, result: Either[Rejection, RestError]) -> RejectionAssertions:
        """Wrap a rejection result in RejectionAssertions bound to ``test``."""
        return RejectionAssertions(test, result)

    @staticmethod
    def assert_on_association_report(test: unittest.TestCase,
                                     result: Either[bytes, RestError]) -> AssociationReportAssertions:
        """Wrap an association-report result in AssociationReportAssertions bound to ``test``."""
        return AssociationReportAssertions(test, result)

    @staticmethod
    def assert_on_phase3_package(test: unittest.TestCase,
                                 result: Either[str, RestError], output_path: str) -> Phase3PackageAssertions:
        """Wrap a phase-3 packaging result in Phase3PackageAssertions for ``output_path``."""
        return Phase3PackageAssertions(test, result, output_path)

    @staticmethod
    def create_request(configuration: TestConfiguration, targets: List[str], meta_targets: List[str],
                       workflow: str, workflow_parameters: Optional[Dict] = None, workflow_parameter_set: str = "",
                       recipe_parameters: Optional[Dict] = None, recipe_parameter_set: str = "",
                       breakpoint_tasks: Optional[List[str]] = None,
                       association_preference: Optional[AssociationPreference] = None,
                       package_output_dir: Optional[str] = None, renaming_prefix: Optional[str] = None,
                       reports_configuration: ReportsConfiguration = ALL_REPORTS) -> ProcessingRequest:
        """Build a ProcessingRequest whose inputs are the paths of all files in ``configuration``.

        ``None`` for ``workflow_parameters``, ``recipe_parameters`` or ``breakpoint_tasks``
        is replaced by an empty dict/list. ``package_output_dir`` is passed as the
        request's ``package_base_dir``.
        """
        return ProcessingRequest(inputs=[f.path for f in configuration.inputs], targets=targets,
                                 meta_targets=meta_targets, workflow=workflow,
                                 parameters=RequestParameters(workflow_parameter_set=workflow_parameter_set,
                                                              workflow_parameters=workflow_parameters or {},
                                                              recipe_parameter_set=recipe_parameter_set,
                                                              recipe_parameters=recipe_parameters or {}),
                                 breakpoint_tasks=breakpoint_tasks or [],
                                 association_preference=association_preference,
                                 package_base_dir=package_output_dir,
                                 renaming_prefix=renaming_prefix,
                                 reports_configuration=reports_configuration
                                 )

    @staticmethod
    def merge_dicts(*args: Dict) -> Dict:
        """Delegate to ``edps.utils.merge_dicts``."""
        return utils.merge_dicts(*args)

    def assert_on_dataset_package(self, test: unittest.TestCase) -> DatasetPackageAssertions:
        """Return DatasetPackageAssertions for this helper's ``dataset_base_dir``."""
        return DatasetPackageAssertions(test, self.dataset_base_dir)

    @staticmethod
    def assert_on_custom_dataset_package(test: unittest.TestCase,
                                         package_output_dir: str) -> DatasetPackageAssertions:
        """Return DatasetPackageAssertions for an explicit ``package_output_dir``."""
        return DatasetPackageAssertions(test, package_output_dir)
