import cProfile
import functools
import json
import logging
import os
import shlex
import subprocess
import time
from datetime import datetime, timezone
from pathlib import Path
from pstats import SortKey
from typing import Any, Callable, Dict, List, Tuple, TypeVar

from pydantic import BaseModel
from starlette.responses import JSONResponse

from edps.generator.constants import MJD_UNIX_EPOCH
from edps.metrics.meter_registry import MeterRegistry

# Generic type variable for the return value of the callables that the timing and
# profiling helpers in this module wrap and pass through.
T = TypeVar('T')


def as_json(obj) -> str:
    """Serialize *obj* to a pretty-printed (2-space indented) JSON string.

    Values that ``json`` cannot encode natively are serialized through their
    instance ``__dict__``, recursively. Objects without a ``__dict__`` make
    ``json.dumps`` fail with the resulting ``AttributeError``.
    """
    def _attributes_of(value):
        return value.__dict__

    return json.dumps(obj, indent=2, default=_attributes_of)


def merge_dicts(*args: Dict) -> Dict:
    """
    Merge any number of dictionaries into a brand-new dict (shallow copy).

    When a key appears in several inputs, the value from the last one wins; the key
    keeps the position of its first occurrence. None of the inputs is modified.
    """
    return {key: value for mapping in args for key, value in mapping.items()}


def log_time() -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator factory that logs how long each call of the decorated function takes.

    Every call is timed with ``measure_time`` and reported at INFO level on the
    ``"LogTime"`` logger, together with the call's positional and keyword
    arguments. The decorated function's return value is passed through unchanged.

    Usage::

        @log_time()
        def work(...): ...
    """
    logger = logging.getLogger("LogTime")

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # functools.wraps copies __name__, __doc__, __module__, ... and sets __wrapped__,
        # so introspection and tracebacks refer to the real function, not to `wrapper`.
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            result, _ = measure_time(lambda: func(*args, **kwargs), str(func), logger, logging.INFO, args, kwargs)
            return result

        return wrapper

    return decorator


def timer(name: str, logger: logging.Logger = logging.getLogger("MetricsTimer"), level: int = logging.DEBUG) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator factory that times every call and records the duration as a metric.

    Each call is timed with ``measure_time`` and logged on *logger* at *level*.
    Its duration in milliseconds is then recorded on the ``MeterRegistry`` meter
    named *name*. The decorated function's return value is passed through unchanged.

    Args:
        name: Label for the log message and name of the meter that receives the durations.
        logger: Logger for the per-call timing message. The default is created once at
            import time. That is safe because ``logging.getLogger`` always returns the
            same object for a given name.
        level: Log level of the timing message.
    """
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Preserve the wrapped function's metadata (__name__, __doc__, __wrapped__, ...).
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            result, ms_time = measure_time(lambda: func(*args, **kwargs), name, logger, level, args, kwargs)
            MeterRegistry.instance.get_meter(name).record(ms_time)
            return result

        return wrapper

    return decorator


def measure_time(action: Callable[[], T], name: str, logger: logging.Logger, level: int, args, kwargs) -> Tuple[T, float]:
    """Run *action*, log how long it took, and return its result with the duration.

    Args:
        action: Zero-argument callable to execute.
        name: Label used at the start of the log message.
        logger: Logger that receives the timing message.
        level: Log level of the timing message.
        args: Positional arguments to show in the log message (only logged; not passed to *action*).
        kwargs: Keyword arguments to show in the log message (only logged; not passed to *action*).

    Returns:
        ``(result, elapsed_ms)``: the value returned by *action* and the elapsed time in milliseconds.

    If *action* raises, the exception propagates and nothing is logged.
    """
    # perf_counter is monotonic and high-resolution. time.time() can jump when the system
    # clock is adjusted, which would yield negative or inflated durations.
    start_time = time.perf_counter()
    result = action()
    ms_time = (time.perf_counter() - start_time) * 1000
    # Lazy %-formatting: str(args)/str(kwargs) is only computed if the level is enabled.
    logger.log(level, "%s called with %s,%s completed_in %.2fms", name, args, kwargs, ms_time)
    return result, ms_time


def profiler(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator that runs every call of *func* under ``flat_profile`` (cProfile).

    The decorated function's return value is passed through unchanged; exceptions
    propagate as-is.
    """
    # Preserve the wrapped function's metadata (__name__, __doc__, __wrapped__, ...).
    @functools.wraps(func)
    def inner(*args, **kwargs) -> T:
        return flat_profile(lambda: func(*args, **kwargs))

    return inner


def flat_profile(action: Callable[[], T]) -> T:
    """Execute *action* under cProfile and print a flat profile sorted by internal time.

    The statistics are printed only when *action* returns normally. If it raises,
    profiling is stopped and the exception propagates.

    Returns:
        Whatever *action* returned.
    """
    profile = cProfile.Profile()
    profile.enable()
    try:
        outcome = action()
    finally:
        # Stop collecting even when the action fails.
        profile.disable()
    profile.print_stats(SortKey.TIME)
    return outcome


class JSONResponseWithoutValidation(JSONResponse):
    """``JSONResponse`` that dumps pydantic models itself before encoding.

    ``render`` accepts a single ``BaseModel``, a list (whose ``BaseModel`` items
    are dumped and other items kept as-is), or any other JSON-encodable value.
    The resulting content is encoded by ``JSONResponse.render``.
    """

    @staticmethod
    def _to_jsonable(item: Any) -> Any:
        """Dump a pydantic model to JSON-compatible primitives; return anything else unchanged."""
        if isinstance(item, BaseModel):
            # mode="json" converts datetime, UUID, Enum, bytes, ... to JSON-compatible values.
            # The default "python" mode keeps them as objects that json.dumps cannot encode.
            return item.model_dump(mode="json")
        return item

    def render(self, data) -> bytes:
        if isinstance(data, list):
            # Same pydantic-v2 dump for list items as for a single model (the deprecated
            # .dict() was used before), and non-model items no longer raise AttributeError.
            content = [self._to_jsonable(item) for item in data]
        else:
            content = self._to_jsonable(data)
        # Equivalent to JSONResponse(content=content).body, without building a throw-away response.
        return super().render(content)


def mjd_to_datetime_string(mjd: float, fmt: str) -> str:
    """Format a Modified Julian Date as a UTC date/time string.

    The MJD is converted to seconds relative to ``MJD_UNIX_EPOCH`` (presumably
    the MJD of 1970-01-01T00:00Z) and then to a timezone-aware UTC datetime.

    Args:
        mjd: Modified Julian Date, in days.
        fmt: ``strftime`` format string.
    """
    elapsed_days = mjd - MJD_UNIX_EPOCH
    moment = datetime.fromtimestamp(elapsed_days * 86400, tz=timezone.utc)
    return moment.strftime(fmt)


def is_lambda(fun: Callable) -> bool:
    """Return True when *fun* is a plain Python function still named ``<lambda>``."""
    anonymous = lambda: 0  # noqa: E731 - reference lambda, used only for its type and name
    if not isinstance(fun, type(anonymous)):
        return False
    return fun.__name__ == anonymous.__name__


def create_command_script(command: str, dirname: str, env: Dict[str, str]):
    """Write an executable ``cmdline.sh`` into *dirname* that re-runs *command* with exactly *env*.

    The script calls ``env -i`` (start from an empty environment), passes every
    ``KEY=value`` pair of *env*, and then *command*. Errors while writing the
    script are logged and not re-raised.

    Args:
        command: Shell command line, inserted into the script verbatim.
        dirname: Directory in which ``cmdline.sh`` is created (overwritten if present).
        env: Environment variables to set for the command.
    """
    # shlex.quote escapes embedded quotes, spaces and other shell metacharacters. Plain
    # '...' wrapping produced a broken script as soon as a value contained a single quote.
    assignments = [f"{shlex.quote(f'{key}={value}')} \\" for key, value in env.items()]
    script = "\n".join(["#!/bin/sh", "env -i \\", *assignments, command]) + "\n"
    filepath = Path(dirname) / "cmdline.sh"
    try:
        with open(filepath, "w") as script_file:
            script_file.write(script)
        filepath.chmod(0o755)
    except Exception as e:
        logging.error(f"Error writing command script: {e}")


def run_command(command_parts: List[str], working_dir: str, stdout, stderr, env=None) -> int:
    """Run a command through the shell in *working_dir* and return its exit status.

    A reproducible ``cmdline.sh`` is first written to *working_dir* via
    ``create_command_script``.

    Args:
        command_parts: Pieces of the command line. They are joined with single spaces
            and interpreted by the shell; they are NOT quoted individually.
        working_dir: Directory the command runs in (and where the script is written).
        stdout: Passed to ``subprocess.call`` as the child's stdout.
        stderr: Passed to ``subprocess.call`` as the child's stderr.
        env: Environment for the child. ``None`` (the default) means a copy of the
            current process environment. An explicit mapping, even an empty one, is used as-is.

    Returns:
        The command's exit status as returned by ``subprocess.call``.
    """
    # NOTE(review): parts are joined without quoting and run with shell=True, presumably so
    # callers can use shell syntax. Arguments with spaces or metacharacters must be quoted
    # by the caller, and untrusted input must not be passed here.
    command = " ".join(command_parts)
    # `is None` rather than `or`: an explicitly empty environment must be honored, not
    # silently replaced by the parent's environment.
    env = dict(os.environ) if env is None else env
    create_command_script(command, working_dir, env)
    return subprocess.call(command, cwd=working_dir, stdout=stdout, stderr=stderr, shell=True, env=env)
