from collections import defaultdict
from functools import lru_cache
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from cython_fortran_file import FortranFile as FF
from pynbody import units
from pynbody.array import SimArray
from scipy.stats import binned_statistic_dd
from tqdm.auto import tqdm

# Boltzmann constant in CGS, with its dimensions spelled out as
# cm^2 g K^-1 s^-2 using the pynbody `units` module.
boltzmann_constant_cgs = (
    1.3806488e-16 * units.cm ** 2 * units.g / units.K / units.s ** 2
)
# Mass of a hydrogen atom, in grams.
# NOTE: both constants have to be defined before the module-level name
# `units` is rebound to a plain dict further down in this file; after that
# point `units` no longer refers to the pynbody module.
mass_hydrogen_cgs = 1.67373522e-24 * units.g

# Per-field unit table: maps a particle field name to a pair
# ``(key in the info dictionary holding the code-unit -> CGS factor,
#    CGS unit string attached to the final array)``.
# Fields that are absent from this table are treated as dimensionless by the
# readers in this file.
# NOTE: this assignment rebinds the module-level name ``units`` and thereby
# hides the ``pynbody.units`` module imported at the top of the file.
units = {
    "position_x": ("unit_l", "cm"),
    "position_y": ("unit_l", "cm"),
    "position_z": ("unit_l", "cm"),
    "velocity_x": ("unit_v", "cm s**-1"),
    "velocity_y": ("unit_v", "cm s**-1"),
    "velocity_z": ("unit_v", "cm s**-1"),
    "density": ("unit_d", "g cm**-3"),
    "pressure": ("unit_pressure", "g cm**-1 s**-2"),
    "potential": ("unit_ener", "erg"),
}


@lru_cache
def read_particle_file_descriptor(fname: Path) -> dict:
    """Parse a particle file descriptor into a ``{field name: dtype}`` mapping.

    The file must start with the line ``# version:  1``, followed by one line
    describing the columns (ignored). Every following non-blank line must
    contain exactly three comma-separated entries: the variable index
    (ignored), the variable name and the variable type.

    Parameters
    ----------
    fname : Path
        Path to the descriptor file. It must be hashable because the result is
        memoised with `lru_cache`.

    Returns
    -------
    fmt : dict[str, str]
        Mapping from variable name to its type code, in file order. The same
        dict object is returned on every call with the same `fname`, so
        callers should not mutate it.

    Raises
    ------
    ValueError
        If the file has fewer than two lines, if the version header is not
        ``# version:  1``, or if a data line does not have three fields.
    """
    with open(fname, "r") as f:
        version, _columns, *lines = (line.strip() for line in f)

    # Explicit check instead of `assert`, which is removed under `python -O`.
    if version != "# version:  1":
        raise ValueError(
            f"Unsupported particle file descriptor header in {fname}: {version!r}"
        )

    fmt = {}
    for line in lines:
        # Tolerate blank lines (e.g. trailing whitespace at the end of file).
        if not line:
            continue
        _ivar, name, dtype = (field.strip() for field in line.split(","))
        fmt[name] = dtype

    return fmt


@lru_cache
def read_info_file(iout: int, folder: Union[str, Path] = "movie1") -> Dict[str, float]:
    """Parse the numeric header of ``info_XXXXX.txt`` and add derived units.

    Parameters
    ----------
    iout : int
        Output number; the file read is ``<folder>/info_{iout:05d}.txt``.
    folder : str or Path, optional
        Folder containing the info file.

    Returns
    -------
    info_data : dict[str, float]
        Every ``key = value`` pair found before the line starting with
        ``ordering type``, converted to float, plus the derived entries
        ``unit_v``, ``unit_m``, ``unit_pressure`` and ``unit_ener`` computed
        from ``unit_l``, ``unit_t`` and ``unit_d``. The result is memoised, so
        the same dict object is handed out on repeated calls.
    """
    info_path = Path(folder) / f"info_{iout:05d}.txt"

    parsed = {}
    for raw in info_path.read_text().split("\n"):
        # Everything from the "ordering type" line onwards is not parsed.
        if raw.startswith("ordering type"):
            break
        if not raw.strip():
            continue

        name, value = (part.strip() for part in raw.split("="))
        parsed[name] = float(value)

    length = parsed["unit_l"]
    time = parsed["unit_t"]

    # Derived conversion factors, inserted in this order: v, m, pressure, energy.
    parsed["unit_v"] = length / time
    mass = parsed["unit_d"] * length ** 3
    parsed["unit_m"] = mass
    parsed["unit_pressure"] = mass / length / time ** 2
    parsed["unit_ener"] = mass * length ** 2 / time ** 2

    return parsed


def read_particle_file(fname: Union[str, Path], imov: int) -> Tuple[dict, dict]:
    """Read a single frame from a particle file.

    The descriptor ``particle_file_descriptor.txt`` and the matching
    ``info_XXXXX.txt`` file are looked up in the folder containing `fname`.

    Parameters
    ----------
    fname : str or Path
        The path to the particle file to read
    imov : int
        The movie frame to read

    Returns
    -------
    data : dict[str, np.ndarray]
        The particle data. The keys are the same as in
        `particle_file_descriptor.txt`, except that ``familytag`` is replaced
        by the two decoded entries ``family`` and ``tag``. Fields listed in the
        module-level `units` table are multiplied by the matching conversion
        factor from `info_data`; other fields are left unscaled.
    info_data : dict
        Information about the simulation at the corresponding time, as parsed from the
        info_XXXXX.txt files.

    Raises
    ------
    KeyError
        If no frame numbered `imov` is present in the file, or if the
        descriptor does not declare a ``familytag`` field.

    Notes
    -----
    If there are multiple entries matching the requested `imov`, this function returns the latest.
    """

    parent = Path(fname).parent

    fmt = read_particle_file_descriptor(parent / "particle_file_descriptor.txt")
    info_data = read_info_file(imov, parent)

    # A frame is one header record followed by one record per descriptor
    # field, so this is the number of records to jump over for other frames.
    n_records_per_frame = len(fmt)

    data = {}
    found = False
    with FF(str(fname)) as f:
        # Find the end of the file so that we know when to stop reading.
        f.seek(0, 2)
        pos_end = f.tell()
        f.seek(0, 0)
        while f.tell() < pos_end:
            # Frame header: (frame number, number of particles)
            this_imov, _npart = f.read_vector("i")
            if this_imov != imov:
                f.skip(n_records_per_frame)
                continue

            found = True
            # A later matching frame overwrites an earlier one.
            for name, dtype in fmt.items():
                units_key = units.get(name, (None, None))[0]
                # Fields without a known unit get a conversion factor of 1.
                in_cgs = info_data.get(units_key, 1)
                data[name] = f.read_vector(dtype) * in_cgs

    if not found:
        raise KeyError(f"Movie frame {imov} not found in particle file {fname}")

    familytag = data.pop("familytag")

    # From pm/pm_commons.f90:int2part
    data["family"] = familytag // 256 - 128
    data["tag"] = familytag % 256 - 128

    return data, info_data


def _read_particle_single_arg(args):
    """Call `read_particle_file` with a packed ``(fname, imov)`` pair.

    Exists so that the reader can be driven by APIs that hand a single
    argument to the worker function.
    """
    path, frame = args
    return read_particle_file(path, frame)


def read_particle(imov: int, folder: Union[str, Path] = "movie1", sort_ids=False) -> dict:
    """Read a single frame from all particle files.

    Parameters
    ----------
    imov : int
        The movie frame to read
    folder : str or Path, optional
        Path to the folder containing the particle files to read
    sort_ids : boolean
        If True, sort the particle data by their identity

    Returns
    -------
    data : dict[str, np.ndarray]
        The particle data. The keys are the same as in `particle_file_descriptor.txt`,
        with ``familytag`` split into ``family`` and ``tag``, plus:

        * ``properties``: the info dictionary returned alongside one of the
          particle files,
        * ``temperature``: ``pressure / density * m_H / k_B`` in Kelvin (no
          mean molecular weight factor is applied).

        Positions are expressed in units of the box length. Unless `sort_ids`
        is True, the particle order depends on the order in which the worker
        processes finish.

    Raises
    ------
    FileNotFoundError
        If `folder` contains no file matching ``particle.bin?????``.
    """
    part_files = sorted(Path(folder).glob("particle.bin?????"))
    if not part_files:
        # Without any file there is no data and no info dictionary to return.
        raise FileNotFoundError(
            f"No particle file matching 'particle.bin?????' found in {folder}"
        )

    all_dt = defaultdict(list)

    # Created outside the `try` so that `finally` never sees an unbound name.
    progress = tqdm(desc=f"Reading part. files #{imov}", total=len(part_files))
    try:
        with Pool() as p:
            jobs = [(fname, imov) for fname in part_files]
            for dt, info_data in p.imap_unordered(_read_particle_single_arg, jobs):
                progress.update()
                for k, v in dt.items():
                    all_dt[k].append(v)
    finally:
        progress.close()

    # Merge the per-file chunks and attach the CGS unit of each field
    # (dimensionless when the field is not in the unit table).
    for k, v in all_dt.items():
        u = units.get(k, (None, "1"))[1]
        all_dt[k] = SimArray(np.concatenate(v), u)

    if sort_ids:
        order = np.argsort(all_dt["identity"])
        for k, v in all_dt.items():
            all_dt[k] = v[order]

    all_dt["properties"] = info_data

    # Box length in cm, used as the unit of the returned positions.
    boxlen_cm = info_data["boxlen"] * info_data["unit_l"]
    for axis in "xyz":
        key = f"position_{axis}"
        all_dt[key] = all_dt[key].in_units(f"{boxlen_cm} cm")

    all_dt["temperature"] = (
        all_dt["pressure"] / all_dt["density"] * mass_hydrogen_cgs / boltzmann_constant_cgs
    ).in_units("K")

    return all_dt
