import hashlib
import logging
import os.path
import re
import shutil
import tarfile
from typing import List, Tuple, Dict, Any, Optional, Set
from uuid import UUID

from astropy.io import fits
from astropy.io.fits import Header
from pydantic import BaseModel

from edps import FitsFile
from edps.client.ProcessingJob import ReportEntry
from edps.executor.renamer import PatternProductRenamer, ProductRenamer
from edps.executor.reporting import ReportEntryPanel
from edps.generator.fits import FitsFileFactory
from edps.interfaces.JobsRepository import JobDetails, JobsRepository

# Extra primary-header keywords added to a job's science files (keyword -> value).
AdditionalKeywords = Dict[str, Any]
# Value of a product's PIPEFILE keyword; used as lookup key when rewriting PROVn keywords.
Pipefile = str
# File name of a product after renaming.
Filename = str
# Product category, i.e. a PRODCATG or PRO.CATG header value.
Category = str

# FITS primary-header keywords read by this module.
PRODCATG = 'PRODCATG'
PRO_CATG = 'PRO.CATG'
PIPEFILE = 'PIPEFILE'


class Phase3ScienceFileConfiguration(BaseModel):
    # How to assemble the Phase 3 associations of one kind of main science file. Selected by one of
    # the main file's PRODCATG / PRO.CATG values via Phase3StreamConfiguration.science_files_configuration.

    # Categories of FITS files copied next to the main file and referenced from it as ASSONn.
    associated_fits_categories: Optional[List[str]] = []
    # Report panel categories copied next to the main file and referenced as ASSONn/ASSOCn/ASSOMn.
    associated_report_categories: Optional[List[str]] = []
    # Categories of FITS files put into the per-science-file tar.
    tar_fits_categories: Optional[List[str]] = []
    # Report panel categories put into the per-science-file tar.
    tar_report_categories: Optional[List[str]] = []
    # ASSOCn value written for the tar; None means no tar is built at all.
    tar_category: Optional[str] = None


class Phase3StreamConfiguration(BaseModel):
    # Only jobs whose task name is listed here contribute products to the stream.
    tasks: List[str]
    # Science-file configuration keyed by product category (a PRODCATG or PRO.CATG value).
    science_files_configuration: Dict[str, Phase3ScienceFileConfiguration]


class Phase3Configuration(BaseModel):
    # Root output directory; each stream is written to <output_location>/<stream name>.
    output_location: str
    # Job id (UUID string) -> additional primary-header keywords for that job's science files.
    jobs_with_keywords: Dict[str, AdditionalKeywords]
    # Stream name -> stream configuration.
    stream_configuration: Dict[str, Phase3StreamConfiguration]
    # Default pattern used to rename products.
    renaming_pattern: str
    # Product category -> pattern overriding renaming_pattern for products of that category.
    category_renaming_pattern: Dict[str, str] = {}
    # When True, parent jobs of the provided jobs are recursively included in the dataset.
    # In this case, all provided jobs must be leaves in the job graph.
    # False for HAWKI, where we create a dataset per job.
    group_jobs: bool = True
    # When True, associated science files that were not used to produce the main science file are filtered out
    # True for most instruments, for example to filter products by fiber in GIRAFFE.
    # False for MUSE where associated science files are needed, since the individual cubes are not direct inputs to the main cube.
    filter_associated: bool = True


# Module logger; every message of the Phase 3 dataset builder goes to the "Phase3" logger.
logger = logging.getLogger("Phase3")


def get_product_renamer(renaming_pattern: str, category_renaming_pattern: Dict[str, str],
                        categories: Set[str]) -> ProductRenamer:
    """Return a renamer for a product with the given categories.

    :param renaming_pattern: default pattern, used when no category has an override.
    :param category_renaming_pattern: category -> pattern overrides.
    :param categories: categories of the product (PRODCATG / PRO.CATG values).
    :return: a PatternProductRenamer built from the selected pattern.
    """
    # ``categories`` is a set of strings whose iteration order depends on per-process hash
    # randomisation. Iterate in sorted order so that, when several categories have an override,
    # the same pattern is picked on every run.
    pattern = next(
        (category_renaming_pattern[cat] for cat in sorted(categories, key=str)
         if cat in category_renaming_pattern),
        renaming_pattern,
    )
    logger.debug("Using renaming pattern for categories '%s': '%s'", categories, pattern)
    return PatternProductRenamer(pattern)


def rename(path: str, renaming_pattern: str, category_renaming_pattern: Dict[str, str]) -> Filename:
    """Return the new file name of the product at ``path``.

    The pattern is chosen from the product's own PRODCATG / PRO.CATG values, falling back to
    ``renaming_pattern``.
    """
    product_categories = get_categories(path)
    return get_product_renamer(renaming_pattern, category_renaming_pattern, product_categories).rename(path)


# DFS-21738
def build_phase3_dataset(base_dir: str, config: Phase3Configuration, repository: JobsRepository) -> List[str]:
    """Build one Phase 3 output directory per configured stream.

    Each stream is written to ``<config.output_location>/<stream name>``.

    :return: the output directories, in the iteration order of ``config.stream_configuration``.
    """
    return [
        build_phase3_stream(base_dir,
                            os.path.join(config.output_location, name),
                            config.jobs_with_keywords,
                            stream,
                            config.renaming_pattern,
                            config.category_renaming_pattern,
                            config.filter_associated,
                            config.group_jobs,
                            repository)
        for name, stream in config.stream_configuration.items()
    ]


def format_obgrade_card(value: str) -> str:
    """Return the raw OB_GRADE header card for ``value``.

    The value is left-aligned in an 8-character quoted field. For a value of up to 8 characters the
    card is exactly 80 characters, so the long explanatory comment is kept intact.
    """
    padded_value = format(value, "<8")
    explanation = "A-fully within B-mostly within C|D-out of specs X-unknown"
    return "OB_GRADE= '" + padded_value + "' / " + explanation


def filter_output_files(output_files: Set[FitsFile], main_file: str) -> Set[FitsFile]:
    """Drop science files that were not recipe inputs of ``main_file``.

    Non-science files are always kept. A science file is kept only if its basename appears among the
    ``ESO PRO REC<i> RAW<j> NAME`` values of ``main_file``.
    """
    used_inputs = get_recipe_input_files(main_file)
    kept = set()
    for candidate in output_files:
        if is_science_file(candidate.name) and os.path.basename(candidate.name) not in used_inputs:
            continue
        kept.add(candidate)
    return kept


# DFS-21368
def build_phase3_stream(base_dir: str, output_location: str,
                        jobs_with_keywords: Dict[str, AdditionalKeywords],
                        stream_configuration: Phase3StreamConfiguration,
                        renaming_pattern: str,
                        category_renaming_pattern: Dict[str, str],
                        filter_associated: bool,
                        group_jobs: bool,
                        repository: JobsRepository) -> str:
    """Write the Phase 3 products of one stream into ``output_location`` and return that directory.

    For every job in ``jobs_with_keywords`` whose task is in ``stream_configuration.tasks``, each
    science output (PRODCATG starting with SCIENCE) that has a science-file configuration for one
    of its categories is processed:

    * its existing ASSON/ASSOC/ASSOM keywords are dropped and the job's additional keywords added;
    * configured associated FITS files are copied (SCIENCE categories become ANCILLARY) and
      referenced as ASSONn;
    * configured report panels are copied and referenced as ASSONn/ASSOCn/ASSOMn;
    * if ``tar_category`` is set, a tar of further files/panels is built and referenced;
    * the renamed main file is written with checksums.

    Afterwards the PROVn keywords of the written files are remapped to the new file names.

    :param base_dir: root of the job directories; panels are read from
        ``<base_dir>/<instrument>/<task_name>/<job_id>/<report_name>/``.
    :param output_location: directory to create; ``os.makedirs`` raises if it already exists.
    :param jobs_with_keywords: job id (UUID string) -> extra primary-header keywords for that job.
    :param stream_configuration: tasks of the stream and per-category science-file configuration.
    :param renaming_pattern: default product renaming pattern.
    :param category_renaming_pattern: per-category overrides of ``renaming_pattern``.
    :param filter_associated: drop science files that are not recipe inputs of the main file.
    :param group_jobs: require the jobs to be leaves and collect files of parent jobs recursively.
    :param repository: source of job details.
    :return: ``output_location``.
    :raises RuntimeError: if a provided job is a parent of another one (with ``group_jobs``) or more
        than 999 associations are needed for one file.
    """
    job_ids = [UUID(job_id) for job_id in jobs_with_keywords.keys()]
    if group_jobs:
        verify_leaves(job_ids, repository)
    os.makedirs(output_location)
    old_to_new_prov_names: Dict[Pipefile, Filename] = {}
    jobs = [repository.get_job_details(job_id) for job_id in job_ids]
    for job in jobs:
        if job.configuration.task_name not in stream_configuration.tasks:
            continue
        logger.info("Processing job: %s %s", job.configuration.job_id, job.configuration.task_name)
        additional_keywords = jobs_with_keywords[job.configuration.job_id]
        job_dir = os.path.join(base_dir, job.configuration.instrument, job.configuration.task_name,
                               job.configuration.job_id)
        science_files = [file.name for file in job.result.output_files if is_science_file(file.name)]
        output_files, associated_files = get_associated_files(job, repository, False, group_jobs)
        for main_file in science_files:
            logger.info("Processing science file: %s", main_file)
            renamed_main_file = rename(main_file, renaming_pattern, category_renaming_pattern)
            renamed_main_prefix, _ = os.path.splitext(os.path.basename(renamed_main_file))
            configuration = _find_science_file_configuration(stream_configuration, main_file)
            if configuration is None:
                continue
            pipefile = get_pipefile(main_file)
            old_to_new_prov_names[pipefile] = renamed_main_file
            with fits.open(main_file) as hdu_list:
                primary_header = hdu_list[0].header
                next_index = 1

                # The pipeline's own association keywords are replaced by the ones built below.
                remove_asso_keywords(primary_header, next_index)
                _apply_additional_keywords(primary_header, additional_keywords)

                # Filter into a per-file variable. Reassigning ``output_files`` here would make the
                # filter of one science file also apply to every later science file of the job.
                candidate_files = (filter_output_files(output_files, main_file) if filter_associated
                                   else output_files)
                next_index, adjusted_assoc_files = _add_associated_fits_files(
                    main_file, primary_header, next_index, candidate_files | associated_files,
                    configuration, output_location, renaming_pattern, category_renaming_pattern,
                    old_to_new_prov_names)

                next_index = _add_report_panels(
                    main_file, primary_header, next_index, job.result.reports, job_dir,
                    renamed_main_prefix, configuration.associated_report_categories or [],
                    output_location)

                # add tar, but the associated files in tar should already have the prodcatg adjusted so checksums match
                if configuration.tar_category is not None:
                    tar_result = build_tar(job_dir, main_file, adjusted_assoc_files, job.result.reports,
                                           output_location, configuration, renaming_pattern,
                                           category_renaming_pattern)
                    if tar_result:
                        tar_name, tar_checksum = tar_result
                        set_keyword(main_file, primary_header, "ASSON", next_index, tar_name)
                        set_keyword(main_file, primary_header, "ASSOC", next_index, configuration.tar_category)
                        set_keyword(main_file, primary_header, "ASSOM", next_index, tar_checksum)

                hdu_list.writeto(os.path.join(output_location, renamed_main_file), checksum=True)
    fix_provenance(old_to_new_prov_names, output_location)
    return output_location


def _find_science_file_configuration(stream_configuration: Phase3StreamConfiguration,
                                     main_file: str) -> Optional[Phase3ScienceFileConfiguration]:
    """Return the configuration of the first (sorted) category of ``main_file`` that has one, else None."""
    # get_categories returns a set; sorting makes the choice reproducible across runs when several
    # categories of the same file are configured.
    for category in sorted(get_categories(main_file), key=str):
        configuration = stream_configuration.science_files_configuration.get(category)
        if configuration is not None:
            return configuration
    return None


def _apply_additional_keywords(primary_header: Header, additional_keywords: AdditionalKeywords):
    """Write the job-specific keywords into ``primary_header`` (VM_SM and OB_GRADE get fixed comments)."""
    for key, value in additional_keywords.items():
        lower_key = key.lower()
        if lower_key == 'vm_sm':
            primary_header[key] = (value, 'VM-Visitor Mode; SM-Service Mode')
        elif lower_key == 'ob_grade':
            # special handling for OB_GRADE to avoid comment truncation. The card is appended, so
            # any existing OB_GRADE card is removed first. Plain assignment, used for the other
            # keys, replaces an existing card instead of duplicating it.
            primary_header.remove('OB_GRADE', ignore_missing=True, remove_all=True)
            primary_header.append(fits.Card.fromstring(format_obgrade_card(value)))
        else:
            primary_header[key] = value


def _add_associated_fits_files(main_file: str, primary_header: Header, next_index: int,
                               candidate_files: Set[FitsFile],
                               configuration: Phase3ScienceFileConfiguration,
                               output_location: str, renaming_pattern: str,
                               category_renaming_pattern: Dict[str, str],
                               old_to_new_prov_names: Dict[Pipefile, Filename]) -> Tuple[int, List[FitsFile]]:
    """Copy configured associated FITS files next to the main file and reference them as ASSONn.

    :return: the next free association index, and the files to consider for the tar. Copied files
        point at their output copy; non-configured files are passed through unchanged.
    """
    adjusted_assoc_files = []
    # Sorted by name so ASSON numbering and tar member order do not depend on set iteration order.
    for assoc_file in sorted(candidate_files, key=lambda f: f.name):
        logger.info("Processing assoc file: %s", assoc_file)
        if not matches_catg_keywords(assoc_file.name, configuration.associated_fits_categories):
            adjusted_assoc_files.append(assoc_file)
            continue
        logger.info("Assoc file: %s is configured to be added", assoc_file)
        # NOTE that we rename the file before changing the prodcatg, this can lead to inconsistencies
        target_file_name = rename(assoc_file.name, renaming_pattern, category_renaming_pattern)
        # Register under the associated file's own PIPEFILE. Using the main file's key would clobber
        # the main file's provenance mapping.
        assoc_pipefile = get_pipefile(assoc_file.name)
        if assoc_pipefile is not None:
            old_to_new_prov_names[assoc_pipefile] = target_file_name
        set_keyword(main_file, primary_header, "ASSON", next_index, target_file_name)
        assoc_file_destination = os.path.join(output_location, target_file_name)
        shutil.copy(assoc_file.name, assoc_file_destination)
        logger.info("Assoc file: %s copied to %s", assoc_file, assoc_file_destination)
        assoc_file_prodcatg = get_prodcatg(assoc_file.name)
        # PRODCATG may be missing (the file can match through PRO.CATG alone).
        if assoc_file_prodcatg is not None and assoc_file_prodcatg.startswith("SCIENCE"):
            logger.info("Assoc file: %s is science so the category will change to ANCILLARY",
                        assoc_file)
            with fits.open(assoc_file_destination, mode='update') as target_assoc_file:
                assoc_primary_header = target_assoc_file[0].header
                assoc_primary_header[PRODCATG] = assoc_file_prodcatg.replace("SCIENCE", "ANCILLARY")
                remove_asso_keywords(assoc_primary_header, 1)
        adjusted_assoc_files.append(FitsFile(name=assoc_file_destination, category=assoc_file_prodcatg))
        next_index += 1
    return next_index, adjusted_assoc_files


def _add_report_panels(main_file: str, primary_header: Header, next_index: int,
                       reports: List[ReportEntry], job_dir: str, renamed_main_prefix: str,
                       report_categories: List[str], output_location: str) -> int:
    """Copy selected report panels next to the main file and reference them as ASSONn/ASSOCn/ASSOMn.

    :return: the next free association index.
    """
    report_index = 1
    for report in reports:
        logger.info("Processing report: %s", report.report_name)
        for panel in report.panels:
            if not should_add_panel(panel, main_file, report_categories):
                continue
            logger.info("Adding report panel: %s to main file", panel)
            panel_path = os.path.join(job_dir, report.report_name, panel.file_name)
            panel_name = rename_panel(renamed_main_prefix, panel, report_index)
            set_keyword(main_file, primary_header, "ASSON", next_index, panel_name)
            set_keyword(main_file, primary_header, "ASSOC", next_index, panel.prodcatg)
            set_keyword(main_file, primary_header, "ASSOM", next_index, file_md5(panel_path))
            shutil.copy(panel_path, os.path.join(output_location, panel_name))
            next_index += 1
            report_index += 1
    return next_index


def fix_provenance(old_to_new_prov_names: Dict[Pipefile, Filename], output_location: str):
    """Rewrite the PROVn keywords of every written product to the renamed file names.

    Each file ``<output_location>/<new name>`` is scanned for PROV1, PROV2, ... until the first
    missing index. Any value that is a key of ``old_to_new_prov_names`` is replaced by the mapped
    name. A file is only rewritten (with checksums) if at least one keyword changed.
    """
    for renamed_file in old_to_new_prov_names.values():
        target_path = os.path.join(output_location, renamed_file)
        with fits.open(target_path) as hdus:
            header = hdus[0].header
            changed = False
            position = 1
            while f"PROV{position}" in header:
                keyword = f"PROV{position}"
                previous = header[keyword]
                replacement = old_to_new_prov_names.get(previous)
                if replacement:
                    header[keyword] = replacement
                    changed = True
                    logger.info("Fixed %s from %s to %s in %s", keyword, previous, replacement, renamed_file)
                position += 1
            if changed:
                hdus.writeto(target_path, checksum=True, overwrite=True)


def rename_panel(main_prefix, panel, report_index):
    """Name a report panel after its main file: ``<prefix><ext>`` for the first, ``<prefix>_<n><ext>`` after."""
    extension = os.path.splitext(panel.file_name)[1]
    counter_suffix = "" if report_index == 1 else f"_{report_index}"
    return f"{main_prefix}{counter_suffix}{extension}"


def set_keyword(main_file: str, header: Header, keyword: str, index: int, value: str):
    """Set ``<keyword><index>`` to ``value`` in ``header``.

    :param main_file: only used in the error message.
    :raises RuntimeError: if ``index`` is above 999. Presumably this keeps e.g. ``ASSON999`` within
        the 8-character FITS keyword limit (to be confirmed).
    """
    if index > 999:
        raise RuntimeError(f"Exceeded index 999 for {keyword}={value} in {main_file}")
    numbered_keyword = keyword + str(index)
    header[numbered_keyword] = value


def build_tar(job_dir: str, main_file: str, associated_files: List[FitsFile], reports: List[ReportEntry],
              output_location: str, configuration: Phase3ScienceFileConfiguration,
              renaming_pattern: str, category_renaming_pattern: Dict[str, str]) -> Optional[Tuple[str, str]]:
    """Bundle selected associated files and report panels of ``main_file`` into ``<renamed main>.tar``.

    Associated files whose category is in ``configuration.tar_fits_categories`` are stored under
    their renamed name. Report panels whose category is in ``configuration.tar_report_categories``
    are stored as ``<renamed main prefix>[_<n>]<ext>``.

    :param job_dir: job directory; panels are read from ``<job_dir>/<report_name>/<panel file>``.
    :param main_file: science file the tar belongs to; its renamed name gives the tar name.
    :param associated_files: candidate files for the tar.
    :param reports: reports whose panels are candidates for the tar.
    :param output_location: directory the tar is written to.
    :param configuration: selects which files and panels go into the tar.
    :param renaming_pattern: default renaming pattern.
    :param category_renaming_pattern: per-category renaming patterns.
    :return: ``(tar_name, md5 hex digest)``, or ``None`` when nothing was selected. In that case
        the empty tar is deleted again.
    """
    renamed_main_file = rename(main_file, renaming_pattern, category_renaming_pattern)
    renamed_main_prefix, _ = os.path.splitext(os.path.basename(renamed_main_file))
    tar_name = f"{renamed_main_prefix}.tar"
    tar_path = os.path.join(output_location, tar_name)
    # The configuration field is Optional; None selects no panels instead of raising TypeError on ``in``.
    report_categories = configuration.tar_report_categories or []
    logger.info("Building tar: %s", tar_path)
    non_empty = False
    # The context manager closes the underlying file even if adding a member raises.
    with tarfile.open(tar_path, "w") as tar:
        for assoc_file in associated_files:
            logger.info("Processing assoc file: %s for tar", assoc_file)
            if matches_catg_keywords(assoc_file.name, configuration.tar_fits_categories):
                logger.info("Adding assoc file: %s to tar", assoc_file)
                tar.add(assoc_file.name, rename(assoc_file.name, renaming_pattern, category_renaming_pattern))
                non_empty = True
        report_index = 1
        for report in reports:
            logger.info("Processing report: %s for tar", report.report_name)
            for panel in report.panels:
                if panel.prodcatg in report_categories:
                    logger.info("Adding report panel: %s to tar", panel)
                    panel_name = rename_panel(renamed_main_prefix, panel, report_index)
                    tar.add(os.path.join(job_dir, report.report_name, panel.file_name), panel_name)
                    non_empty = True
                    report_index += 1
    if non_empty:
        return tar_name, file_md5(tar_path)
    os.remove(tar_path)
    return None


def file_md5(path) -> str:
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed in 1 MiB chunks so large products (e.g. tar bundles) are never loaded into
    memory as a whole. MD5 serves here as a data checksum (ASSOMn), not for security, so it is
    flagged ``usedforsecurity=False``. That keeps it usable on FIPS-restricted systems.
    """
    chunk_size = 1 << 20
    digest = hashlib.md5(usedforsecurity=False)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Output files of a job and, when collected recursively, of its parent jobs (see get_associated_files).
OutputFiles = Set[FitsFile]
# Files listed as associated files in a job's configuration, optionally including its parents'.
AssociatedFiles = Set[FitsFile]


def get_associated_files(job: JobDetails, repository: JobsRepository,
                         include_science: bool, recursive: bool) -> Tuple[OutputFiles, AssociatedFiles]:
    """Collect the output files and configured associated files of ``job``.

    :param job: job whose files are collected.
    :param repository: used to fetch parent jobs when recursing.
    :param include_science: keep output files whose PRODCATG starts with SCIENCE. Parent jobs are
        always visited with ``include_science=True``.
    :param recursive: when True and ``job`` has no ``input_files``, merge in the files of every job
        in ``input_job_ids``, recursively.
    :return: ``(output_files, associated_files)`` as sets.
    """
    collected_outputs = set()
    for produced in job.result.output_files:
        if include_science or not is_science_file(produced.name):
            collected_outputs.add(produced)
    collected_associated = set(job.configuration.associated_files)

    if job.configuration.input_files or not recursive:
        return collected_outputs, collected_associated

    # Fetch all parent jobs first, then descend into each of them.
    parents = [repository.get_job_details(UUID(parent_id)) for parent_id in job.configuration.input_job_ids]
    for parent in parents:
        parent_outputs, parent_associated = get_associated_files(parent, repository, True, True)
        collected_outputs |= parent_outputs
        collected_associated |= parent_associated
    return collected_outputs, collected_associated


def verify_leaves(job_ids: List[UUID], repository: JobsRepository):
    """Verify that the provided job IDs are leaves in the job graph.

    :raises RuntimeError: if any provided job is listed in ``input_job_ids`` of another provided job.
    """
    details = [repository.get_job_details(job_id) for job_id in job_ids]
    provided_ids = {str(job_id) for job_id in job_ids}
    for job in details:
        parent_ids = job.configuration.input_job_ids
        if not parent_ids:
            continue
        overlap = set(parent_ids).intersection(provided_ids)
        if overlap:
            raise RuntimeError(f"Jobs {overlap} are not leaves, they are main parents of {job.configuration.job_id}")


def is_science_file(path: str) -> bool:
    """Return True if the PRODCATG of ``path`` starts with ``SCIENCE``; False if not or if it is missing."""
    prodcatg_value = get_prodcatg(path)
    # Explicit None check so a real bool is returned. The previous ``value and ...`` form returned
    # None or '' for a missing or empty PRODCATG, contradicting the ``-> bool`` annotation.
    return prodcatg_value is not None and prodcatg_value.startswith("SCIENCE")


def get_prodcatg(path: str) -> Optional[str]:
    """Return the PRODCATG value of the FITS file at ``path``, or None if it is not present."""
    return FitsFileFactory.extract_provided_keywords(path, {PRODCATG}).get(PRODCATG)


def get_pipefile(path: str) -> Optional[str]:
    """Return the PIPEFILE value of the FITS file at ``path``, or None if it is not present."""
    return FitsFileFactory.extract_provided_keywords(path, {PIPEFILE}).get(PIPEFILE)


def get_categories(path: str) -> Set[Category]:
    """Return the PRODCATG and PRO.CATG values of the FITS file at ``path``, skipping missing ones."""
    found = FitsFileFactory.extract_provided_keywords(path, {PRODCATG, PRO_CATG})
    categories = set()
    for value in found.values():
        if value is not None:
            categories.add(value)
    return categories


def get_recipe_input_files(path: str) -> set[str]:
    """Return the raw input file names recorded in ``ESO PRO REC<i> RAW<j> NAME`` keywords of ``path``.

    Only the primary header is inspected.
    """
    raw_name_keyword = re.compile(r'^.*ESO PRO REC\d+ RAW\d+ NAME$')
    with fits.open(path) as hdu_list:
        primary_header = hdu_list[0].header
        matching_keys = [key for key in primary_header if raw_name_keyword.match(key)]
        return {primary_header[key] for key in matching_keys}


def matches_catg_keywords(path: str, configured_values: Optional[List[str]]) -> bool:
    """Return True if any PRODCATG / PRO.CATG value of ``path`` is in ``configured_values``.

    ``None`` or an empty list never matches.
    """
    wanted = set(configured_values or ())
    return not wanted.isdisjoint(get_categories(path))


def remove_asso_keywords(primary_header: Header, index: int):
    """Remove ASSOC<i>, ASSOM<i> and ASSON<i> for i = index, index + 1, ...

    Removal stops at the first i without an ASSON<i> card. The ASSOC<i>/ASSOM<i> of that last i
    are still removed. Association keywords after a gap in the ASSON numbering are left in place.
    ``remove`` is called without ``remove_all``, so only one card per keyword is dropped.
    """
    current = index
    while True:
        suffix = str(current)
        for optional_keyword in ("ASSOC" + suffix, "ASSOM" + suffix):
            primary_header.remove(optional_keyword, ignore_missing=True)
        try:
            primary_header.remove("ASSON" + suffix)
        except KeyError:
            return
        current += 1


def should_add_panel(panel: ReportEntryPanel, main_file: str, report_categories: Optional[List[str]]) -> bool:
    """Return True if ``panel`` should be associated with ``main_file``.

    The panel's category must be in ``report_categories``; ``None`` is treated as "no categories",
    like ``matches_catg_keywords`` does. The panel must also either list no input files or list
    ``main_file`` among them.
    """
    # The configuration fields feeding this are Optional; ``x in None`` would raise TypeError.
    category_to_be_added = panel.prodcatg in (report_categories or ())
    # A panel without recorded input files applies to every science file of the report.
    file_used_in_report = not panel.input_files or main_file in panel.input_files
    return category_to_be_added and file_used_in_report
