import argparse
import json
import os.path
import subprocess
import sys
import threading
import time
from collections import Counter
from collections import defaultdict
from pprint import pprint
from typing import List

from edps import __version__
from edps.client.EDPSClient import EDPSClient
from edps.client.Error import RestError
from edps.client.FitsFile import FitsFile
from edps.client.GraphType import GraphType
from edps.client.ProcessingJobStatus import ProcessingJobStatus
from edps.client.ProcessingRequest import ProcessingRequest, ReportsConfiguration
from edps.client.RequestParameters import RequestParameters
from edps.client.RunReportsRequestDTO import RunReportsRequestDTO, ReportInput
from edps.client.WorkflowDTO import WorkflowDTO
from edps.client.monad import Either
from edps.config.configuration import AppConfig
from edps.scripts.shutdown import wait_until_cant_connect

# ANSI escape sequence: move the terminal cursor up by one line.
LINE_UP = '\033[1A'
# ANSI escape sequence: erase the entire current terminal line.
LINE_CLEAR = '\x1b[2K'

# Maps each `--report-type` command-line choice to the list of ReportInput values
# sent to the server for graphical report generation.
REPORT_TYPES = dict(
    all=[ReportInput.RECIPE_INPUTS, ReportInput.RECIPE_OUTPUTS, ReportInput.RECIPE_INPUTS_OUTPUTS],
    none=[],
    raw=[ReportInput.RECIPE_INPUTS],
    reduced=[ReportInput.RECIPE_OUTPUTS, ReportInput.RECIPE_INPUTS_OUTPUTS],
)


# FIXME: this function is not used by this script; consider removing it.
def wait_for_jobs(client: EDPSClient, job_ids: List[str]) -> None:
    """Poll job details every 5 seconds until the given jobs are done.

    Each poll prints the id, status and output files of every job still being tracked.
    A job remains tracked while its status is not COMPLETED, so FAILED jobs are
    re-polled as well. Polling stops when no jobs remain tracked or when every tracked
    job has FAILED. A job whose details cannot be fetched has the error printed and is
    dropped from further polling.

    :param client: EDPS client used to query job details.
    :param job_ids: IDs of the jobs to wait for.
    """
    pending = job_ids
    print("Waiting for jobs to complete...")
    while pending:
        still_pending = []
        failed_count = 0
        for job_id in pending:
            details = client.get_job_details(job_id)
            if not details.is_right():
                print(details.get_left())
                continue
            job = details.get()
            if job.status != ProcessingJobStatus.COMPLETED:
                still_pending.append(job_id)
            if job.status == ProcessingJobStatus.FAILED:
                failed_count += 1
            print(job_id, job.status, job.output_files)
        pending = still_pending
        # Every job still tracked has failed: nothing more will change, stop waiting.
        if failed_count == len(still_pending):
            break
        time.sleep(5)


def _format_elapsed(seconds: float) -> str:
    """Format a duration in seconds as ``HHhMMmSSs``; hours keep counting past 23."""
    minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}h{minutes:02d}m{secs:02d}s"


def monitor_jobs(client: EDPSClient, job_ids: List[str]):
    """Poll job statuses until no job in ``job_ids`` is CREATED or RUNNING.

    Every 5 seconds a status line is printed and then erased with ANSI escapes, so it
    is redrawn in place. The line shows the elapsed time, the number of jobs per status
    (CREATED is shown as PENDING) and the IDs of failed jobs. When monitoring ends, the
    final status line and a summary are printed. If at least one job COMPLETED, the
    directory of the first output file of every job that has outputs is also listed,
    ordered by completion date.

    If a status request fails and the EDPS server no longer answers, monitoring stops
    early without printing the summary.

    :param client: EDPS client used to query job status and details.
    :param job_ids: IDs of the jobs to monitor.
    """
    status_map = {"CREATED": "PENDING"}  # display label for jobs that have not started
    job_status = {job_id: "CREATED" for job_id in job_ids}
    # Monotonic clock: the elapsed time is not affected by wall-clock adjustments.
    start = time.monotonic()
    elapsed = _format_elapsed(0)
    message = ""
    while True:
        jobs_to_check = [job_id for job_id, status in job_status.items() if status in ("CREATED", "RUNNING")]
        if not jobs_to_check:
            break
        for job_id in jobs_to_check:
            result = client.get_job_status(job_id)
            if result.is_right():
                job_status[job_id] = result.get().status.value
            else:
                print(f'Failed to get job status {job_id}: {result.get_left()}')
                if not is_edps_running(client):
                    print("EDPS server has been shut down, exiting.")
                    return
        elapsed = _format_elapsed(time.monotonic() - start)
        stats = Counter(status_map.get(status, status) for status in job_status.values())
        failed = [job_id for job_id, status in job_status.items() if status == "FAILED"]
        failed_msg = f"Failed jobs: {', '.join(failed)}" if failed else ""
        message = f"{elapsed} {dict(stats)} {failed_msg}"
        print(message)
        time.sleep(5)
        # Move the cursor back onto the status line and erase it so the next update
        # overwrites it instead of scrolling the terminal.
        print(LINE_UP, end=LINE_CLEAR)
    print(message)
    print(f"{len(job_ids)} jobs processed in {elapsed}")
    if "COMPLETED" not in job_status.values():
        return
    print("You can find reduced data, logs and quality control plots in the following directories:")
    print("=======================================================================================")
    jobs = []
    for job_id in job_ids:
        details = client.get_job_details(job_id)
        if details.is_right():
            jobs.append(details.get())
        else:
            # Report the failed lookup instead of calling .get() on an error result.
            print(f'Failed to get job details {job_id}: {details.get_left()}')
    # NOTE(review): jobs that never finished may have no completion date (presumably None).
    # Comparing None with a date raises TypeError, so such jobs are sorted last.
    jobs.sort(key=lambda job: (job.completion_date is None, job.completion_date))
    for job in jobs:
        if job.output_files:
            print(os.path.dirname(job.output_files[0].name))


def print_response(response):
    """Print the value of a successful (right) response; an error response prints nothing."""
    if not response.is_right():
        return
    print(response.get())


def print_json_response(response):
    """Pretty-print (2-space indent) the JSON text carried by a successful response.

    An error (left) response prints nothing.
    """
    if not response.is_right():
        return
    parsed = json.loads(response.get())
    print(json.dumps(parsed, indent=2))


def print_targets(response: Either[WorkflowDTO, RestError]):
    """Print, as indented JSON, each meta-target mapped to the names of its tasks.

    Only successful (right) responses are printed; an error response prints nothing.
    """
    if not response.is_right():
        return
    tasks_by_meta_target = defaultdict(list)
    for task in response.get().tasks:
        for meta_target in task.meta_targets:
            tasks_by_meta_target[meta_target].append(task.name)
    print(json.dumps(tasks_by_meta_target, indent=2))


def print_files(files: List[FitsFile]):
    """Print one line per distinct file name, followed by all of its categories.

    Files without a category are listed as 'NONE'. Names appear in first-seen order.
    """
    categories_by_name = defaultdict(list)
    for fits_file in files:
        categories_by_name[fits_file.name].append(fits_file.category or 'NONE')
    for name in categories_by_name:
        print(name, ' '.join(categories_by_name[name]))


def is_edps_running(client: EDPSClient) -> bool:
    """Return True if the server answers a version request successfully.

    Any exception while contacting the server (e.g. nothing listening on the port) is
    deliberately treated as "not running" rather than propagated.
    """
    try:
        answered = client.get_edps_version().is_right()
    except Exception:
        answered = False
    return answered


def is_valid_parameter_set(client: EDPSClient, workflow: str, parameter_set: str) -> bool:
    """Check that ``parameter_set`` is one of the parameter sets offered by ``workflow``.

    The server's reply is parsed as a JSON list of objects with a "name" key. On a
    request error, the error is printed and False is returned. If the set is unknown,
    the available set names are printed and False is returned.
    """
    response = client.get_parameter_sets(workflow)
    if not response.is_right():
        print(response.get_left())
        return False
    available = [ps["name"] for ps in json.loads(response.get())]
    if parameter_set not in available:
        print(f"Parameter set '{parameter_set}' not found. Available parameter sets: {available}")
        return False
    return True


def _drain_stream(stream) -> None:
    """Read and discard lines from ``stream`` until it reaches EOF."""
    for _line in stream:
        pass


def _start_local_edps() -> None:
    """Launch ``edps-server`` in a new session and wait until it reports that it is serving.

    Startup is detected by reading the server's stderr until the text "Uvicorn running"
    appears. If stderr closes first, the collected output is printed and the client
    stops via ``sys.exit(-1)``.
    """
    proc = subprocess.Popen(args=["edps-server"],
                            start_new_session=True,
                            stdout=subprocess.DEVNULL,
                            stderr=subprocess.PIPE
                            )
    content = b''
    while b'Uvicorn running' not in content:
        line = proc.stderr.readline()
        if not line:
            # errors='replace' so undecodable bytes cannot hide the real startup failure.
            print("Failed to start EDPS: \n" + content.decode('utf-8', errors='replace'))
            sys.exit(-1)
        content += line
    # Nothing else reads this pipe after startup. If the server keeps logging to stderr
    # while this client is still running (e.g. polling in monitor_jobs), the OS pipe
    # buffer fills and the server blocks on its next write. Drain the pipe in the
    # background; a daemon thread does not keep the client alive at exit.
    threading.Thread(target=_drain_stream, args=(proc.stderr,), daemon=True).start()


def main():
    """Entry point of the EDPS command-line client.

    If no application configuration exists yet, it is created and the program exits.
    Otherwise the command line is parsed and, if no server answers at the requested
    address, a local server is started in the background (localhost only). Exactly one
    action is then performed: the first matching option in the if/elif chain below.
    The program exits with status 1 when the server reports an error for that action.
    With --shutdown, the server is shut down afterwards.
    """
    config = AppConfig()
    if not config.exists():
        config.create()
        sys.exit(0)

    parser = argparse.ArgumentParser(description=f"EDPS client version {__version__}")
    parser.add_argument("-H", "--host", help='default: %(default)s', default="localhost", type=str)
    parser.add_argument("-P", "--port", help='default: %(default)s', default=5000, type=int)
    parser.add_argument("-i", "--inputs", help='input files or directories', nargs='*')
    parser.add_argument("-t", "--targets", help='targets', nargs='*')
    parser.add_argument("-m", "--meta-targets", help='meta-targets', nargs='*')
    parser.add_argument("-w", "--workflow", help='e.g. "espresso.espresso_wkf"')
    parser.add_argument("-c", "--classify", help='classify input files', action='store_true')
    parser.add_argument("-od", "--organize-data", help='run data organization only', action='store_true')
    parser.add_argument("-f", "--flat", help='produce flat organization output', action='store_true')
    parser.add_argument("-g", "--graph", help='print workflow graph in DOT format', action='store_true')
    parser.add_argument("-g2", "--detailed-graph", help='print detailed workflow graph in DOT format',
                        action='store_true')
    parser.add_argument("-a", "--assocmap", help='print association map in MD format', action='store_true')
    parser.add_argument("-r", "--reset", help='reset given workflow', action='store_true')
    parser.add_argument("-x", "--expand-meta-targets", help='expand meta-targets', action='store_true')
    parser.add_argument("-d", "--default-parameters", help='get default parameters for task', metavar='TASK')
    parser.add_argument("-p", "--recipe-parameters", help='get recipe parameters for task', nargs='*',
                        metavar=('TASK', 'PARAMETER_SET'))
    parser.add_argument("-ps", "--parameter-sets", help='get parameter sets', action='store_true')
    parser.add_argument("-wp", "--workflow-param", help='workflow parameter', action='append', nargs=2,
                        metavar=('PARAMETER', 'VALUE'))
    parser.add_argument("-rp", "--recipe-param", help='recipe parameter', action='append', nargs=3,
                        metavar=('TASK', 'PARAMETER', 'VALUE'))
    parser.add_argument("-wps", "--workflow-parameter-set", help='workflow parameter set', default='science_parameters')
    parser.add_argument("-rps", "--recipe-parameter-set", help='recipe parameter set', default='science_parameters')
    parser.add_argument("-lt", "--list-targets", help='print workflow targets and meta-targets', action='store_true')
    parser.add_argument("-lw", "--list-workflows", help='print available workflows', action='store_true')
    parser.add_argument("-o", "--output-dir", help='Specify output directory for products')
    parser.add_argument("-shutdown", "--shutdown", help='Shutdown EDPS server.', action='store_true')
    parser.add_argument("-rt", "--report-type", help='data type for graphical reports (default: %(default)s)',
                        default='reduced', choices=['raw', 'reduced', 'all', 'none'])
    parser.add_argument("-rj", "--report-jobs", help='create graphical reports for given job IDs',
                        nargs='*')

    args = parser.parse_args()
    # Without explicit targets or meta-targets, default to the 'science' meta-target.
    if not args.targets and not args.meta_targets:
        meta_targets = ['science']
    else:
        meta_targets = args.meta_targets
    if (args.graph or args.detailed_graph or args.assocmap or args.expand_meta_targets or args.reset or
        args.default_parameters or args.recipe_parameters or args.parameter_sets or
        args.list_targets or args.inputs) and not args.workflow:
        print("Required parameter '-w WORKFLOW' not found.")
        parser.print_usage()
        sys.exit(-1)
    client = EDPSClient(args.host, args.port)
    if not is_edps_running(client):
        if args.host != 'localhost':
            print(f"Failed to find EDPS at http://{args.host}:{args.port}")
            sys.exit(-1)
        else:
            print("Local instance of EDPS is not running and will be started in the background.")
            _start_local_edps()
    response = None
    if args.graph:
        response = client.get_workflow_graph(args.workflow, graph_type=GraphType.SIMPLE)
        print_response(response)
    elif args.detailed_graph:
        response = client.get_workflow_graph(args.workflow, graph_type=GraphType.DETAILED)
        print_response(response)
    elif args.assocmap:
        response = client.get_assoc_map(args.workflow)
        print_response(response)
    elif args.expand_meta_targets:
        response = client.get_targets(args.workflow, args.targets, meta_targets)
        print_response(response)
    elif args.reset:
        response = client.reset_workflow(args.workflow)
        print_response(response)
    elif args.default_parameters:
        response = client.get_default_params(args.workflow, args.default_parameters)
        print_json_response(response)
    elif args.recipe_parameters:
        task = args.recipe_parameters[0]
        parameter_set = args.recipe_parameters[1] if len(args.recipe_parameters) > 1 else None
        response = client.get_recipe_params(args.workflow, task, parameter_set)
        print_json_response(response)
    elif args.parameter_sets:
        response = client.get_parameter_sets(args.workflow)
        print_json_response(response)
    elif args.list_targets:
        response = client.get_workflow(args.workflow)
        print_targets(response)
    elif args.list_workflows:
        response = client.list_workflows()
        print_response(response)
    elif args.report_jobs:
        request = RunReportsRequestDTO(job_ids=args.report_jobs, report_types=REPORT_TYPES[args.report_type])
        response = client.run_reports(request)
        print_response(response)
    elif args.inputs:
        if not (is_valid_parameter_set(client, args.workflow, args.workflow_parameter_set) and
                is_valid_parameter_set(client, args.workflow, args.recipe_parameter_set)):
            sys.exit(1)
        workflow_parameters = {name: value for name, value in args.workflow_param} if args.workflow_param else {}
        # Group "-rp TASK PARAMETER VALUE" triples into {task: {parameter: value}}.
        recipe_parameters = {}
        for task, name, value in (args.recipe_param or []):
            recipe_parameters.setdefault(task, {})[name] = value
        targets = args.targets or []
        meta_targets = meta_targets or []
        package_base_dir = args.output_dir
        reports_config = ReportsConfiguration(task_names=['ALL'], report_types=REPORT_TYPES[args.report_type])
        parameters = RequestParameters(workflow_parameter_set=args.workflow_parameter_set,
                                       workflow_parameters=workflow_parameters,
                                       recipe_parameter_set=args.recipe_parameter_set,
                                       recipe_parameters=recipe_parameters)
        request = ProcessingRequest(inputs=args.inputs, targets=targets, meta_targets=meta_targets,
                                    workflow=args.workflow, parameters=parameters,
                                    package_base_dir=package_base_dir, reports_configuration=reports_config)
        if args.organize_data:
            response = client.submit_to_organise(request)
            if response.is_right():
                jobs = response.get()
                tmp = "[" + ",".join([job.model_dump_json() for job in jobs]) + "]"
                print(json.dumps(json.loads(tmp), indent=2))
        elif args.flat:
            response = client.submit_to_organise_flat(request)
            if response.is_right():
                print(json.dumps(json.loads(response.get().model_dump_json()), indent=2))
        elif args.classify:
            response = client.submit_to_classify(request)
            if response.is_right():
                print_files(response.get())
        else:
            response = client.submit_processing_request(request)
            if response.is_right():
                processing_response = response.get()
                summary = defaultdict(list)
                for job in processing_response.jobs:
                    summary[job.task_name].append(f"{client.endpoint}/jobs/{job.job_id}")
                print("The following jobs have been scheduled for processing.")
                print("You can use the URL associated to each job to retrieve detailed information")
                print("about the job using an http client (curl) or an internet browser (firefox).")
                print("===========================================================================")
                pprint(dict(summary))
                monitor_jobs(client, [job.job_id for job in processing_response.jobs])
    elif not args.shutdown:
        parser.print_help()

    if response and response.is_left():
        print(f"Request failed with: {response.get_left()}")
        sys.exit(1)

    if args.shutdown and is_edps_running(client):
        print("Shutting down EDPS...")
        client.shutdown_edps()
        wait_until_cant_connect(client)


if __name__ == "__main__":
    # Run the command-line client when this module is executed as a script.
    main()
