#!/usr/bin/env python

import numpy
import os
from scipy.io import FortranFile
from argparse import ArgumentParser
import subprocess
import matplotlib
# Prefer the non-interactive Agg backend: it can render frames without an X11 connection.
try:
    matplotlib.use('agg')
except ImportError:
    pass
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from scipy import signal
import multiprocessing as mp
import time
import f90nml
import numpy as np

from numpy.polynomial.polynomial import polyfit


def a2z(scalefactor):
    """Return the redshift ``1 / scalefactor - 1`` for the given scale factor."""
    return 1. / scalefactor - 1.


def label(xy, text):
    """Draw ``text`` in white, horizontally centred on ``xy[0]``.

    The text is anchored 15 units past ``xy[1]`` so that it does not sit on
    top of the artist it annotates.
    """
    x_pos = xy[0]
    y_pos = xy[1] + 15
    plt.text(x_pos, y_pos, text, ha="center", size=14, color='white')


def load_map(args, k, i, mapkind=None):
    """Read the pixel data of one movie frame from a Fortran ``.map`` file.

    The file is expected at
    ``<args.dir>/movie<args.proj>/<kind>_<i, zero-padded to 5 digits>.map``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dir`` and ``proj`` (anything ``int()`` accepts) and,
        when ``mapkind`` is None, ``kind`` as a space-separated string.
    k : int
        Index of the wanted entry in ``args.kind``; ignored when ``mapkind``
        is given.
    i : int
        Frame number.
    mapkind : str, optional
        Explicit map kind that overrides the lookup in ``args.kind``.

    Returns
    -------
    numpy.ndarray
        The third record of the file, read as a flat float32 array.
    """
    if mapkind is None:
        mapkind = args.kind.split(' ')[k]
    map_file = "%s/movie%d/%s_%05d.map" % (args.dir, int(args.proj), mapkind, i)

    # The context manager closes the file even when a record is malformed.
    with FortranFile(map_file) as f:
        # The two header records must be consumed to reach the image record.
        # Unpacking also checks that they hold 4 and 2 values respectively.
        _t, _dx, _dy, _dz = f.read_reals(np.float64)
        _nx, _ny = f.read_ints(np.int32)
        dat = f.read_reals(np.float32)

    return dat


def load_namelist_info(args):
    """Extract the movie-related settings from the run's Fortran namelist.

    The namelist path is ``args.namelist``; when that is the empty string,
    ``<args.dir>/output_00002/namelist.txt`` is used instead.

    Returns
    -------
    tuple
        ``(boxlen, proj_axis, nx, ny, max_iter, cosmo)`` where ``nx``/``ny``
        come from ``nw_frame``/``nh_frame`` and ``max_iter`` from ``imovout``
        in ``MOVIE_PARAMS``.

    Raises
    ------
    KeyError
        If ``boxlen`` is absent while ``cosmo`` is false, or if any of the
        ``MOVIE_PARAMS`` entries cannot be read.
    """
    # an empty --namelist selects the default copy inside the run directory
    namelist = args.namelist if args.namelist != '' else args.dir + '/output_00002/namelist.txt'

    nmlf = f90nml.read(namelist)
    cosmo = nmlf['RUN_PARAMS']['cosmo']

    try:
        boxlen = nmlf['AMR_PARAMS']['boxlen']
    except KeyError:
        # only cosmological runs may omit the box length; it then defaults to 1
        if not cosmo:
            raise KeyError("Box length not found in namelist for isolated run. Check your namelist.")
        boxlen = 1.

    try:
        movie_params = nmlf['MOVIE_PARAMS']
        proj_axis, nx, ny, max_iter = [movie_params[key]
                                       for key in ('proj_axis', 'nw_frame', 'nh_frame', 'imovout')]
    except KeyError:
        raise KeyError("Frame sizes in pixels (e.g. nw_frame), number of frames (imovout) or projections (proj_axis)  "
                       "could not be read from RAMSES namelist. They are necessary to generate images and movies "
                       "from frames. Check your namelist and its path.")

    return boxlen, proj_axis, nx, ny, max_iter, cosmo


def load_units(i, args):
    """Read the unit conversion factors stored in a frame's info file.

    The file ``<args.dir>/movie<proj>/info_<i, 5 digits>.txt`` of the first
    requested projection is scanned. The third whitespace-separated token on
    lines 16, 17 and 18 (1-based) is taken as ``unit_l``, ``unit_d`` and
    ``unit_t`` respectively.

    Parameters
    ----------
    i : int
        Frame number.
    args : argparse.Namespace
        Must provide ``dir`` and ``proj``. ``proj`` may be a space-separated
        string of projection numbers or a single number.

    Returns
    -------
    tuple of float
        ``(unit_l, unit_d, unit_t, unit_m)`` with
        ``unit_m = unit_d * unit_l**3 / 1.9891e33``.

    Raises
    ------
    ValueError
        If the info file ends before all three unit lines were read.
    """
    if isinstance(args.proj, str):
        proj_list = [int(item) for item in args.proj.split(' ')]
    else:
        proj_list = [args.proj]

    info_path = '{dir}/movie{proj}/info_{num:05d}.txt'.format(dir=args.dir, proj=proj_list[0], num=i)

    # zero-based line index -> name of the unit stored on that line
    unit_lines = {15: 'unit_l', 16: 'unit_d', 17: 'unit_t'}
    units = {}
    with open(info_path) as infof:
        for j, line in enumerate(infof):
            if j in unit_lines:
                units[unit_lines[j]] = float(line.split()[2])
            if j >= 17:
                break

    if len(units) != len(unit_lines):
        raise ValueError("Could not read unit_l, unit_d and unit_t from %s: file is too short." % info_path)

    unit_l = units['unit_l']
    unit_d = units['unit_d']
    unit_t = units['unit_t']
    unit_m = unit_d * unit_l ** 3 / 1.9891000e33  # in MSun, matching amr/constants

    return unit_l, unit_d, unit_t, unit_m


def make_image(i, args, proj_list, proj_axis, nx, ny, boxlen, kind, cosmo, scale_l, cmin, cmax):
    """Render frame ``i`` of every requested projection and save it as a PNG.

    For each projection the matching map is loaded, converted to physical
    units for the kinds handled below (``dens`` and the velocity components),
    optionally log-scaled, and drawn with ``imshow``. Projections are placed
    side by side in one figure, which is written once at the end.

    Parameters
    ----------
    i : int
        Frame number.
    args : argparse.Namespace
        Parsed command-line options. ``args.proj`` is overwritten with the
        projection currently being drawn, because ``load_map`` reads it from
        there.
    proj_list : list of int
        Projection numbers to draw.
    proj_axis, boxlen, scale_l :
        Not used by this function; accepted so existing callers keep working.
    nx, ny : int
        Width and height of one frame in pixels.
    kind : list of str
        Map kind for each projection, parallel to ``proj_list``.
    cosmo : bool
        When true the plot is annotated with the redshift, otherwise with
        the elapsed time.
    cmin, cmax : sequence
        Per-projection polynomial coefficients (lowest degree first) for the
        colour limits; only read when ``args.poly > 0``.
    """
    velocity_kinds = ("vx", "vy", "vz")
    particle_kinds = ("stars", "dm")
    # 'star' is kept alongside 'stars' so that either spelling gets its label
    kind_labels = {
        'dens': "Gas density",
        'temp': "Gas temperature",
        'dm': "Dark matter",
        'star': "Stars",
        'stars': "Stars",
        'var8': "Oxygen metallicity",
    }
    nproj = len(proj_list)

    fig = plt.figure(frameon=False)
    # one nx-by-ny panel per projection at the 100 dpi used by savefig below
    fig.set_size_inches(nx / 100. * max(nproj, 1), ny / 100.)

    unit_l, unit_d, unit_t, unit_m = load_units(i, args)  # need to load units here due to cosmo

    for p, proj in enumerate(proj_list):
        # load_map() takes the projection number from args, so publish it there
        args.proj = proj
        dat = load_map(args, p, i)

        is_velocity = kind[p] in velocity_kinds
        is_particle = kind[p] in particle_kinds

        # The third token of line 10 (cosmological runs) or line 9 (otherwise)
        # of the info file is used for the redshift / time annotation.
        info_path = "%s/movie%d/info_%05d.txt" % (args.dir, int(args.proj), i)
        with open(info_path) as infof:
            for j, line in enumerate(infof):
                if cosmo:
                    if j == 9:
                        redshift = a2z(float(line.split()[2]))
                elif j == 8:
                    t = float(line.split()[2])
                if j >= 9:
                    break

        if kind[p] == 'dens':
            dat *= unit_d    # in g/cc
        if is_velocity:
            dat *= (unit_l/unit_t)/1e5   # in km/s

        if args.outfile is None:
            outfile = "%s/pngs/%s_%05d.png" % (args.dir, kind[p], i // int(args.step) - int(args.fmin))
        else:
            outfile = args.outfile

        if args.logscale:
            dat = numpy.array(dat)
            if is_particle:
                dat += 1e-12  # floor so the logarithm taken below stays finite
        # Reshape data to 2d
        dat = dat.reshape(ny, nx)

        if is_particle:  # PSF convolution
            # scipy.signal.gaussian was removed from recent SciPy releases;
            # the window functions live in scipy.signal.windows.
            window = signal.windows.gaussian(100, 1)
            kernel = numpy.outer(window, window)
            dat = signal.fftconvolve(dat, kernel, mode='same')

        rawmin = numpy.amin(dat)
        rawmax = numpy.amax(dat)

        # Bounds
        plotmin = rawmin if args.min is None else float(args.min)
        plotmax = rawmax if args.max is None else float(args.max)

        # Log scale?
        if args.logscale and not is_velocity:  # never logscale for velocities
            dat = numpy.log10(dat)
            rawmin = numpy.log10(rawmin)
            rawmax = numpy.log10(rawmax)
            plotmin = numpy.log10(plotmin)
            plotmax = numpy.log10(plotmax)

        # Auto-adjust dynamic range?
        if args.autorange:
            # Overrides any provided bounds
            NBINS = 200
            # Compute histogram
            (hist, bins) = numpy.histogram(dat, NBINS, (rawmin, rawmax), density=True)
            chist = numpy.cumsum(hist)
            chist = chist / numpy.amax(chist)
            # Black point at the 15% level of the cumulative histogram;
            # the white point stays at the data maximum.
            clip_k = chist.searchsorted(0.15)
            plotmin = bins[clip_k]
            plotmax = rawmax

        if args.poly > 0:
            # evaluate the fitted limit polynomials at this frame number
            degrees = range(args.poly + 1)
            plotmin = sum(cmin[p][d] * i ** d for d in degrees)
            plotmax = sum(cmax[p][d] * i ** d for d in degrees)

        if is_velocity:
            # symmetric limits around zero
            plotmax = max(abs(rawmin), rawmax)
            plotmin = -plotmax

        # Plotting

        # A 1 x nproj grid: a 1 x 1 grid would reject every panel after the first.
        ax = fig.add_subplot(1, max(nproj, 1), p + 1)
        ax.axis([0, nx, 0, ny])

        cmap = args.cmap_str
        im = ax.imshow(dat, interpolation='nearest', cmap=cmap,
                       vmin=plotmin, vmax=plotmax, aspect='auto')
        # Booleans are required here: the string 'off' is truthy and would
        # leave ticks and tick labels switched on.
        ax.tick_params(bottom=False, top=False, left=False, right=False)   # removes ticks
        ax.tick_params(labelbottom=False, labeltop=False, labelleft=False, labelright=False)   # removes tick labels

        labels_color = 'w'

        if args.colorbar:
            cbar = plt.colorbar(im, orientation='horizontal', pad=0.01, shrink=0.98, aspect=50)
            cbar.solids.set_rasterized(True)
            cbar.outline.set_linewidth(3)
            bar_font_color = 'k'
            cbar.ax.tick_params(labelcolor=bar_font_color, labelsize=20, width=3, direction='in', length=8)

        if not args.clean_plot:
            patches = []

            # Add label to top right corner depending on which variable is plotted
            kind_string = kind_labels.get(kind[p], " ")
            ax.text(0.85, 0.95, kind_string, verticalalignment='bottom', horizontalalignment='left',
                    transform=ax.transAxes, color=labels_color, fontsize=20)

            if cosmo:
                ax.text(0.05, 0.95, '$z={redshift:.2f}$'.format(redshift=redshift),
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color=labels_color, fontsize=20)
            else:
                t *= unit_t/86400/365.25  # time in years
                if 1e3 <= t < 1e6:
                    scale_t = 1e3
                    t_unit = 'kyr'
                elif 1e6 <= t < 1e9:
                    scale_t = 1e6
                    t_unit = 'Myr'
                elif t >= 1e9:  # inclusive, so exactly 1e9 yr is shown as 1.0 Gyr
                    scale_t = 1e9
                    t_unit = 'Gyr'
                else:
                    scale_t = 1
                    t_unit = 'yr'

                ax.text(0.05, 0.95, '$%.1f$ $%s$' % (t/scale_t, t_unit),
                        verticalalignment='bottom', horizontalalignment='left',
                        transform=ax.transAxes,
                        color=labels_color, fontsize=18)

            # no patches are created above, so this collection is empty
            collection = PatchCollection(patches, facecolor=labels_color)
            ax.add_collection(collection)

    # corrects window extent
    plt.subplots_adjust(left=0., bottom=0., right=1., top=1., wspace=0., hspace=0.)
    plt.savefig(outfile, dpi=100)
    plt.close(fig)

    return


def fit_min_max(args, p, max_iter):
    """Fit polynomials in frame number to the per-frame data minimum and maximum.

    Every frame in ``range(fmin + step, max_iter + 1, step)`` is loaded for
    the ``p``-th map kind and scaled the same way as for plotting. Its
    minimum and maximum (or their base-10 logarithms when ``args.logscale``
    is set) are collected and fitted with polynomials of degree ``args.poly``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (``kind``, ``fmin``, ``step``,
        ``logscale``, ``poly`` plus whatever ``load_map``/``load_units`` read).
    p : int
        Index of the map kind within the space-separated ``args.kind``.
    max_iter : int
        Last frame number to include.

    Returns
    -------
    tuple
        ``(p, cmin, cmax)``; the coefficient arrays are as returned by
        ``numpy.polynomial.polynomial.polyfit``. ``p`` is echoed back so that
        results from a worker pool can be matched to their projection.
    """
    kind_p = args.kind.split(' ')[p]
    frames = list(range(int(args.fmin)+int(args.step), max_iter+1, int(args.step)))

    # Plain lists avoid re-allocating a growing array for every frame.
    # float() keeps the samples in double precision for the fit.
    mins = []
    maxs = []

    for i in frames:
        dat = load_map(args, p, i)
        unit_l, unit_d, unit_t, unit_m = load_units(i, args)

        if kind_p == 'dens':
            dat *= unit_d    # in g/cc
        if kind_p in ['vx', 'vy', 'vz']:
            dat *= (unit_l/unit_t)/1e5   # in km/s
        if kind_p in ['stars', 'dm']:
            dat += 1e-12

        frame_min = numpy.amin(dat)
        frame_max = numpy.amax(dat)
        if args.logscale:
            frame_min = numpy.log10(frame_min)
            frame_max = numpy.log10(frame_max)
        mins.append(float(frame_min))
        maxs.append(float(frame_max))

    cmin = polyfit(frames, mins, args.poly)
    cmax = polyfit(frames, maxs, args.poly)

    return p, cmin, cmax


def main():
    """Command-line entry point: render movie frames to PNGs and encode them.

    Steps:

    1. Parse the command line and read the movie settings from the namelist.
    2. Determine the frame range, either from ``--fmax`` or by scanning the
       ``info_*.txt`` files of the first requested projection.
    3. Optionally (``--poly > 0``) fit polynomials to the colour limits.
    4. Render every frame into ``<dir>/pngs`` (serially or with a worker pool).
    5. Encode the frames with ffmpeg, make the movie world-readable and
       delete the frames. The frames are kept when ffmpeg fails.

    Raises
    ------
    RuntimeError
        For an unknown bar length unit, a CPU count below 1, a missing set of
        info files, or a non-zero ffmpeg exit status.
    """

    # Parse command line arguments
    parser = ArgumentParser(description="Script to create RAMSES movies")
    parser.add_argument("-l", "--logscale", action="store_true", default=False,
                        help="use log color scaling [%(default)s]")
    parser.add_argument("-m", "--min", metavar="VALUE",
                        help='min value', default=None)
    parser.add_argument("-M", "--max", metavar="VALUE",
                        help='max value', default=None)
    parser.add_argument("-f", "--fmin", metavar="VALUE",
                        help='frame min value [%(default)d]', default=0, type=int)
    parser.add_argument("-F", "--fmax", metavar="VALUE",
                        help='frame max value [%(default)d]', default=-1, type=int)
    # PWD is not set in every environment; fall back to the real cwd then
    parser.add_argument("-d", "--dir",
                        help='map directory [current working dir]',
                        default=os.environ.get('PWD', os.getcwd()), metavar="VALUE")
    parser.add_argument("-p", "--proj", default='1', type=str,
                        help="projection index [%(default)s]")
    parser.add_argument("-s", "--step",
                        help="framing step [%(default)d]", default=1, type=int)
    parser.add_argument('-k', '--kind',
                        help="kind of plot [%(default)s]", default='dens')
    parser.add_argument('-a', '--autorange', action='store_true',
                        help='use automatic dynamic range (overrides min & max) [%(default)s]', default=False)
    parser.add_argument('--clean_plot', action='store_true',
                        help='do not annotate plot with bar and timestamp [%(default)s]', default=False)
    parser.add_argument('-c', '--colormap', dest="cmap_str", metavar='CMAP',
                        help='matplotlib color map to use [%(default)s]', default="bone")
    parser.add_argument('-b', '--barlen', metavar='VALUE',
                        help='length of the bar (specify unit!) [%(default)d]', default=5, type=float)
    parser.add_argument('-B', '--barlen_unit', metavar='VALUE',
                        help='unit of the bar length (AU/pc/kpc/Mpc) [%(default)s]', default='kpc')
    parser.add_argument("-o", "--output", dest="outfile", metavar="FILE",
                        help='output image file [<map_file>.png]', default=None)
    parser.add_argument('--nocolorbar', dest='colorbar', action='store_false',
                        help='add colorbar [%(default)s]', default=True)
    parser.add_argument('-n', '--ncpu', metavar="VALUE", type=int,
                        help='number of CPUs for multiprocessing [%(default)d]', default=1)
    parser.add_argument('-P', '--poly', metavar="VALUE", type=int,
                        help='polynomial degree for fitting min and max [off]', default=-1)
    parser.add_argument('-N', '--namelist', metavar="VALUE", type=str,
                        help='path to namelist, if empty take default [%(default)s]', default='')

    args = parser.parse_args()

    # Reject an unusable CPU count before any work is done.
    if args.ncpu < 1:
        raise RuntimeError('Wrong number of CPUs! Exiting!')

    proj_list = [int(item) for item in args.proj.split(' ')]
    kind = args.kind.split(' ')
    nproj = len(proj_list)

    # bar length unit expressed in parsecs
    bar_units = {'pc': 1e0, 'kpc': 1e3, 'Mpc': 1e6, 'AU': 1./206264.806}
    if args.barlen_unit not in bar_units:
        raise RuntimeError('Wrong length unit!')
    scale_l = bar_units[args.barlen_unit]

    # load basic info once, instead of at each loop
    boxlen, proj_axis, nx, ny, max_iter, cosmo = load_namelist_info(args)

    if int(args.fmax) > 0:
        max_iter = int(args.fmax)
    else:
        from glob import glob
        # Take the frame numbers from the info files of the first requested
        # projection. Note that this also replaces any --fmin given.
        info_files = glob('%s/movie%d/info_*.txt' % (args.dir, proj_list[0]))
        frame_numbers = []
        for path in info_files:
            digits = ''.join(ch for ch in os.path.basename(path) if ch.isdigit())
            if digits:
                frame_numbers.append(int(digits))
        if not frame_numbers:
            raise RuntimeError('No info_*.txt files found in %s/movie%d; cannot determine the frame range.'
                               % (args.dir, proj_list[0]))
        args.fmin = min(frame_numbers)
        max_iter = max(frame_numbers)

    frames = list(range(int(args.fmin)+int(args.step), max_iter+1, int(args.step)))

    # Progressbar imports (optional dependency)
    try:
        from progressbar import ProgressBar, Percentage, Bar, ETA
        progressbar_avail = True
    except ImportError:
        progressbar_avail = False

    # for each projection fit mins and maxs with polynomial
    ncoef = max(args.poly+1, 0)
    cmins = numpy.zeros((nproj, ncoef))
    cmaxs = numpy.zeros((nproj, ncoef))

    if args.poly > 0:
        if args.ncpu > 1:
            pool = mp.Pool(processes=min(args.ncpu, nproj))
            results = [pool.apply_async(fit_min_max, args=(args, p, max_iter,)) for p in range(nproj)]
            pool.close()
            pool.join()

            # each result carries its own projection index, so completion
            # order does not matter; get() also re-raises worker errors
            for idx, cmin, cmax in [res.get() for res in results]:
                cmins[idx] = cmin
                cmaxs[idx] = cmax

        else:
            if progressbar_avail:
                widgets = ['Working...', Percentage(), Bar(marker='='), ETA()]
                pbar = ProgressBar(widgets=widgets, maxval=nproj).start()
            else:
                print('Working!')

            for d in range(nproj):
                _, cmins[d], cmaxs[d] = fit_min_max(args, d, max_iter)
                if progressbar_avail:
                    pbar.update(d+1)

            if progressbar_avail:
                pbar.finish()

        print('Polynomial coefficients fitted!')

    # creating images
    if progressbar_avail:
        widgets = ['Working...', Percentage(), Bar(marker='#'), ETA()]
        pbar = ProgressBar(widgets=widgets, maxval=max_iter+1).start()
    else:
        print('Working!')

    if not os.path.exists("%s/pngs/" % args.dir):
        os.makedirs("%s/pngs/" % args.dir)

    if args.ncpu > 1:
        pool = mp.Pool(processes=args.ncpu)
        results = [pool.apply_async(make_image, args=(i, args, proj_list, proj_axis, nx, ny, boxlen, kind,
                                                      cosmo, scale_l, cmins, cmaxs,))
                   for i in frames]
        # poll until every frame has been rendered, updating the progress bar
        while True:
            inc_count = sum(1 for x in results if not x.ready())
            if inc_count == 0:
                break

            if progressbar_avail:
                pbar.update(max_iter+1-inc_count)
            time.sleep(.1)

        pool.close()
        pool.join()

        # Surface worker failures instead of encoding an incomplete frame set.
        for res in results:
            res.get()

    else:
        for i in frames:
            make_image(i, args, proj_list, proj_axis, nx, ny, boxlen, kind, cosmo, scale_l, cmins, cmaxs)
            if progressbar_avail:
                pbar.update(i)

    if progressbar_avail:
        pbar.finish()

    # movie name for montage
    frame = "{dir}/pngs/{kind}_%05d.png".format(dir=args.dir, kind=args.kind)
    mov = "{dir}/{kind}{proj}.mp4".format(dir=args.dir, kind=args.kind, proj=args.proj)

    print('Calling ffmpeg!')
    # Argument lists (no shell) keep paths with spaces or shell
    # metacharacters intact.
    status = subprocess.call(["ffmpeg", "-loglevel", "quiet", "-i", frame, "-y", "-vcodec", "h264",
                              "-pix_fmt", "yuv420p", "-r", "25", "-qp", "15", mov])
    if status != 0:
        # keep the rendered frames so the encoding step can be retried
        raise RuntimeError('ffmpeg failed with exit status %d; frames were kept in %s/pngs' % (status, args.dir))

    print('Movie created! Cleaning up!')
    subprocess.call(["rm", "-r", "{dir}/pngs".format(dir=args.dir)])
    subprocess.call(["chmod", "a+r", mov])

# Run the movie pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
