#!/usr/bin/python2

#===============================================================================
errormsg="""
-----------------------------
     MERGERTREEPLOT.PY
-----------------------------

This python script creates a plot of the merger tree with a given clump/halo
as the root. By default, this script will find the directory where z = 0 and
look for the clump/halo in that output_XXXXX directory.
You need to call this script from the directory where the output_XXXXX
directories are stored.



-------------
 Usage:
-------------
     mergertreeplot.py [-options] <halo-ID>

     For the fastest run (but possibly a really bad image), don't use any
     option flags.
     If compiling the .tex file fails due to memory restrictions, try adding
     "\RequirePackage{luatex85}" as the first line and compile with lualatex.

     If you're trying to plot a really big tree, the resulting image might not
     be very good and require some manual tinkering in the plotting settings.
     In this case, I recommend using the -wb and -fb flags to write and read
     the tree, once it has been created, to speed up things (See debugging
     options for details).



-------------
 Options:
-------------

    -g, --galaxy            if creating particle plots (-pp flag), also plot
                            the current galaxies in the clumps as small stars
                            Galaxies with host clumps will be the same colour
                            as the clump, orphans will be black.

    -h, --help              print this help message

    -lc, --label-clumps     if creating particle plots (-pp flag), label the
                            clumps you plotted in the scatterplot

    -m, --movie             if creating particle plots (-pp flag), create
                            particle plots for a 'movie': Plot only x-y plane
                            instead of all 3 planes, annotate plot with only
                            snapshot number, use matplotlib.scatter for nicer
                            plot instead of mpl_scatter_density (this makes the
                            plotting slower).
                            If you want to disable axis labelling, you can do
                            so manually in the _tweak_particleplot(...)
                            function (look for "TODO" comment).
                            If you want to have a fixed axis range, you can do
                            so manually in the _plot_particles(...) function
                            (look for "TODO" comment).

    -nc, --no-colors        don't plot lines in different colors, only thin
                            black lines.

    -njd,                   don't draw the dashed lines for jumpers that
    --no-jumper-diagonals   signify which clump they appear to have merged
                            into before re-emerging later.

    -nl, --no-labels        don't label the clumps in the tree plot
                            (these are not the same labels addressed with the
                            -lc flag)

    -pdf, --pdf             save image as a pdf file instead of a png file.
                            This file is not generated by LaTeX.
                            If you only want a TeX file without compilation,
                            use -T or --tex instead.

    -p, --pretty            create an image of the tree that gives little
                            information, but looks pretty. This flag simply
                            activates the following flags:
                                --no-labels
                                --no-jumper-diagonals
                                --tex
                                --no-colors

    -pp, --plotparticles    also create a plot of particles at each output,
                            marking the clumps in the tree with their colour
                             WARNING:    Needs the unbinding_XXXXX.outYYYYY
                             WARNING:    files and the mpl-scatter-density
                             WARNING:    library for this function.
                             WARNING:    The current implementation is not
                             WARNING:    suitable to create plots of large
                             WARNING:    simulations.

    -s <output_nr>          don't start at output_XXXXX directory where z=0,
    --start_at <output_nr>  but at given output_<output_nr>.
                            <output_nr> does not need to be formatted, e.g.
                            5 is good enough, no need to type 00005.
                            Naturally, the <halo-ID> must be a halo/clump
                            in this output directory.

    -t, --use-time          use time instead of redshift for y label.
                            For non-cosmo runs, you NEED to use this, unless
                            you want z=0 everywhere. (In case you forgot,
                            you will be warned.)
                            if you used the -pp flag, not using the -t flag
                            will be interpreted as plotting for a cosmo run
                            (=> periodic boundary conditions are assumed)

    -T, --tex               save image as a tikz file for LaTeX instead of a
                            png image. Image will be saved as .tex instead
                            of .png, the filename will be the same.
                             WARNING: Needs the matplotlib2tikz library for
                             WARNING: this function

    -v, --verbose           verbose: print more details of what you're doing





-----------------------
 Debugging Options:
-----------------------
In case you're trying to debug something or just tweak around plot parameters
to get a nicer image, some debug options are available:

    -wb, --write-backup     After reading in mergertree data and generating the
                            tree, dump the data using pickle to file so you can
                            skip that step next time you're launching the
                            script with the -fb flag.
                            This only stores the tree data, not any plotting
                            preferences you set, so you may safely change them
                            when reading the tree from this dump with -fb.

    -fb, --from-backup      instead of reading in data and recreating the tree,
                            read in backup files generated with the -wb flag.
                            This only loads the tree data, not any plotting
                            preferences you set, so you may safely change them
                            when reading the tree from this dump.

    -v, --verbose           verbose: print more details of what you're doing





---------------
More details
---------------

Each drawn branch will have a color pair: an inner and an outer color. This
is to help identify them uniquely. If you nevertheless need/want to add
more colors, add them to the colorlist in global_params.__init__().
If a drawn clump doesn't have a direct progenitor, this will be signified by
a progenitor with clump ID 0. It's supposed to signify that the clump has
formed in the time between the last and this output. All other nodes will be
labelled by the drawn clump ID at that output step by default.
If a clump re-emerges more than one timestep later, there will be some clump
IDs missing in the branch as it skips a few output steps. Provided the
descendant that it appears to have been merged into and later re-emerges from
is in the tree, this will be symbolised by a dashed line. But it is possible
that such a descendant is not in this tree, so in some cases, there may be
gaps without the corresponding dashed lines.

By default, this script creates an image called
 "merger_tree_<root_output_dir_nr>_halo_<halo_nr>.png"
where <root_output_dir_nr> is zero-padded to 5 digits.

If you chose the --plotparticles flag, it will first make a directory
called
    "merger_tree_<root_output_dir_nr>_halo_<halo_nr>/"
and save the tree image as well as
    "particleplot_XXXXX.png"
images inside the directory.

If you furthermore also chose the --label-clumps flag, instead it will create
    "particleplot-with_labels_XXXXX.png"
images inside the directory, so the others won't be overwritten.



---------------
Requirements
---------------

This script was tested with python 2.7.12, matplotlib 1.5.1, numpy 1.14.2,
mpl_scatter_density 0.3, matplotlib2tikz 0.6.16 and fortranfile.py as it
is in the ramses repository as of 03/2018. mpl_scatter_density and
fortranfile.py are used to plot the particles, but the tree plot itself
doesn't require them.

mpl_scatter_density can be found on
https://github.com/astrofrog/mpl-scatter-density and/or installed via pip:
$ pip install --user mpl-scatter-density

matplotlib2tikz can be found on
https://github.com/nschloe/matplotlib2tikz and/or installed via pip:
$ pip install --user matplotlib2tikz
"""

#===============================================================================



#=======================
# Import libraries
#=======================
import numpy as np
import matplotlib as mpl
# select the non-interactive Agg backend; this is done before pyplot is
# imported so that the backend choice takes effect
mpl.use('Agg')
import matplotlib.pyplot as plt
# turn off matplotlib's interactive mode
plt.ioff()

# use a serif font family with Computer Modern fonts.
# NOTE: TeX rendering of text itself is switched off below (usetex=False);
# only the font selection is TeX-like.
mpl.rc('font', **{'family':'serif',
    'serif':['Computer Modern Roman'],
    'monospace': ['Computer Modern Typewriter']}
    )
mpl.rc('text', usetex=False)





#===================
class _branch_x:
#===================
    """
    This class makes objects to store in lists to
    dynamically allocate and track the x coordinate
    of branches, because in lists, objects are pointed to,
    not copied.
    """

    def __init__(self):
        self.x = 0
        self.ymin = 0
        self.ymax = 0
        return

    def set_x(self, x):
        self.x = x
        return

    def set_y(self,y):
        if self.ymax == 0:
            self.ymax = y
            self.ymin = y
        else:
            self.ymax = max(self.ymax, y)
            self.ymin = min(self.ymin, y)




#======================
class global_params:
#======================
    """
    An object to store all global parameters in, so you can only
    pass 1 object to functions that need it.

    """


    #======================
    def __init__(self):
    #======================
        """
        Initialises object: sets the default parameter values, the table of
        accepted command line arguments and the list of branch colors.
        """

        self.from_backup    = False # whether to read in data from backup instead of generating new tree
        self.galaxies       = False # whether to draw little stars for galaxies of clumps
        self.halo           = 0     # the root halo for which to plot for
        self.label_clumps   = False # whether to label clumps in particle plots
        self.lastdir        = ''    # last output directory
        self.lastdirnr      = 0     # number of the last output directory (set in get_output_info)
        self.movie          = False # whether to create particle plots pretty for movies.
        self.ncpu           = 0     # how many processors were used for simulation run
        self.nocolors       = False # whether to use only black lines
        self.nojumperdraw   = False # whether to draw apparent mergers for jumpers
        self.noutput        = 0     # number of outputs to work with
        self.outputfilename = ''    # filename for finished tree image
        self.pdf            = False # whether to save a pdf instead of a png
        self.plotparticles  = False # whether to also create plots with particles
        self.prefix         = ''    # filename prefix
        self.start          = 0     # which output directory to start with
        self.tex            = False # whether to save figure as a tex file instead of a png file
        self.treeclumplabels= True  # whether to label clumps in tree plots
        self.use_t          = False # whether to use time instead of redshift for y-axis labels
        self.verbose        = False # whether to print more details of what this program is doing
        self.workdir        = ''    # current working directory
        self.write_backup   = False # whether to write backup dump




        # dictionnary of accepted keyword command line arguments
        self.accepted_args = {
            '-fb':                  self.set_read_backup,
            '--from-backup':        self.set_read_backup,
            '-g':                   self.set_galaxies,
            '--galaxy':             self.set_galaxies,
            '--galaxies':           self.set_galaxies,
            '-h':                   self.print_help,
            '--help':               self.print_help,
            '-lc':                  self.set_clumplabels,
            '--label-clumps' :      self.set_clumplabels,
            '-m':                   self.set_movie,
            '--movie':              self.set_movie,
            '-nc':                  self.set_nocolors,
            '--no-colors':          self.set_nocolors,
            '-njd':                 self.set_nojumperdraw,
            '--no-jumper-diagonals':self.set_nojumperdraw,
            '-nl':                  self.set_notreeclumplabels,
            '--no-labels':          self.set_notreeclumplabels,
            '-pdf':                 self.set_pdf,
            '--pdf':                self.set_pdf,
            '-p':                   self.set_pretty,
            '--pretty':             self.set_pretty,
            '-pp':                  self.set_plotparticles,
            '--plotparticles' :     self.set_plotparticles,
            '-s' :                  self.set_start,
            '--start_at':           self.set_start,
            '-T' :                  self.set_tex,
            '--tex' :               self.set_tex,
            '-t' :                  self.set_yaxis_labels,
            '--use-time' :          self.set_yaxis_labels,
            '-v' :                  self.set_verbose,
            '--verbose' :           self.set_verbose,
            '-wb':                  self.set_write_backup,
            '--write-backup':       self.set_write_backup
            }


        # Define a long list of colors for the branches
        self.colorlist=[
                'black',
                'red',
                'green',
                'gold',
                'cornflowerblue',
                'lime',
                'magenta',
                'orange',
                'mediumpurple',
                'deeppink',
                'lightgreen',
                'saddlebrown',
                'orchid',
                'mediumseagreen']

        return





    #=============================
    def read_cmdlineargs(self, argv=None):
    #=============================
        """
        Reads in the command line arguments and stores them in the
        global_params object.

        Parameters:
            argv:   list of command line arguments whose first element is the
                    script name. Defaults to sys.argv.
        Returns:
            nothing
        Exits (status 1) on an unrecognised argument, a missing or invalid
        start number, or a missing / non-positive halo ID.
        """
        import sys

        if argv is None:
            argv = sys.argv

        # flags that consume the following argument as their value
        args_with_value = ('-s', '--start_at')

        nargs = len(argv)
        i = 1 # first cmdlinearg is filename of this file, so skip it

        while i < nargs:
            arg = argv[i].strip()
            if arg in self.accepted_args:
                if arg in args_with_value:
                    if i + 1 >= nargs:
                        self._abort('"'+arg+'" needs an output directory number as argument.')
                    try:
                        startnr = int(argv[i+1])
                    except ValueError:
                        self._abort('"'+argv[i+1]+'" is not a valid start number for an output directory.')
                    self.accepted_args[arg](startnr)
                    i += 1 # the value has been consumed as well
                else:
                    self.accepted_args[arg]()
            else:
                # anything that isn't a flag must be the halo ID
                try:
                    self.halo = int(arg)
                except ValueError:
                    self._abort("I didn't recognize the argument '"+arg+"'")

            i += 1


        # defensive programming
        if self.halo <= 0:
            self._abort("No or wrong halo given. Halo ID must be > 0")

        return




    #==========================
    def get_output_info(self):
    #==========================
        """
        Read in the output info based on the files in the current
        working directory.
        Reads in last directory (name and number), ncpu, noutputs.
        Exits (status 1) if no output_XXXXX entries exist or the requested
        start directory is missing.
        """

        from os import getcwd
        from os import listdir

        self.workdir = getcwd()
        outputlist = sorted(fname for fname in listdir(self.workdir)
                            if fname.startswith('output_'))

        if not outputlist:
            self._abort("I didn't find any output_XXXXX directories in current working directory.",
                        "Are you in the correct workdir?")

        self.lastdir = outputlist[-1]
        self.lastdirnr = int(self.lastdir[-5:])
        self.noutput = len(outputlist)

        if self.start > 0:
            # check that directory exists
            startnrstr = str(self.start).zfill(5)
            if 'output_'+startnrstr not in outputlist:
                self._abort("Didn't find specified starting directory output_"+startnrstr)


        # read ncpu from the last word of the first line of the infofile
        # in the last output directory
        infofile = self.lastdir+'/'+'info_'+self.lastdir[-5:]+'.txt'
        with open(infofile, 'r') as f:
            ncpuline = f.readline()

        self.ncpu = int(ncpuline.split()[-1])

        return




    #========================
    def _abort(self, *messages):
    #========================
        """
        Print the given error message lines and a hint on how to print the
        help message, then exit with a nonzero exit status.
        """
        import sys

        for msg in messages:
            print(msg)
        print("use mergertreeplot.py -h or --help to print help message.")
        sys.exit(1)




    #========================
    def print_help(self):
    #========================
        """
        Print help message and exit. Used as class method for simplicity to
        fit in with other cmdline arg triggered actions.
        """
        import sys

        print(errormsg)
        sys.exit()



    #========================
    def print_params(self):
    #========================
        """
        Prints the current parameters to screen.
        Parameters:
            none
        Returns:
            nothing
        """

        print(" ===================================================================")
        print(" ")
        print(" Started mergertreeplot.py")
        print(" Working parameters are:")
        print(" ")
        print('{0:45}{1:<10d}'.format(      " Halo:",                                   self.halo))
        print('{0:45}{1:<14}'.format(       " Last available directory:",               self.lastdir))
        if self.start > 0:
            print('{0:45}{1:<10d}'.format(  " You've given starting snapshot number",   self.start))
        print('{0:45}{1:<10d}'.format(      " Number of outputs:",                      self.noutput))
        print('{0:45}{1:<10d}'.format(      " Ncpus used for sim:",                     self.ncpu))
        print(" ")

        if self.use_t:
            labelname = 'time (this also assumes non-cosmo simulation)'
        else:
            labelname = 'redshift'

        print('{0:45}{1:}'.format(          " Clumps in tree plot will be labelled?",   self.treeclumplabels))
        print('{0:45}{1:<14}'.format(       " y-axis label will be:",                   labelname))

        if self.tex:
            form = 'tex'
        elif self.pdf:
            form = 'pdf'
        else:
            form = 'png'
        print('{0:45}{1:<4}'.format(        " Image format:",                           form))
        print('{0:45}{1:}'.format(          " Using colors for tree?",                  (not self.nocolors)))
        print('{0:45}{1:}'.format(          " Pseudomergers of jumpers will be drawn?", (not self.nojumperdraw)))


        print('{0:45}{1:}'.format(          " Particles will be plotted?",              self.plotparticles))
        if self.plotparticles:
            print('{0:45}{1:}'.format(      " Clumps in particle plots will be labelled?", self.label_clumps))
            print('{0:45}{1:}'.format(      " Formatting/annotating image for movie?",  self.movie))


        print(" ")
        print('{0:45}{1:}'.format(          " Dumping tree backup?",                    self.write_backup))
        print('{0:45}{1:}'.format(          " Reading data from tree backup?",          self.from_backup))


        print(" ")
        print(" ===================================================================")

        return



    #========================
    # Setter methods
    #========================

    def set_clumplabels(self):
        self.label_clumps = True
        return

    def set_galaxies(self):
        self.galaxies = True
        return

    def set_movie(self):
        self.movie = True
        return

    def set_notreeclumplabels(self):
        self.treeclumplabels = False
        return

    def set_nocolors(self):
        self.nocolors = True
        # keep the list length so that color indices stay valid
        self.colorlist = ['black'] * len(self.colorlist)
        return

    def set_nojumperdraw(self):
        self.nojumperdraw = True
        return

    def set_outputfilename(self, outnr):
        if self.plotparticles:
            self.set_prefix(outnr)

        # output filename for image (without extension)
        self.outputfilename = self.prefix+"merger_tree_"+str(outnr).zfill(5)+"_halo_"+str(self.halo)
        return

    def set_pdf(self):
        self.pdf = True
        return

    def set_plotparticles(self):
        self.plotparticles = True
        return

    def set_prefix(self, outnr):
        import  os
        # get subdirectory name
        self.prefix = 'merger_tree_'+str(outnr).zfill(5)+'_halo_'+str(self.halo)+'/'
        # create subdirectory if it doesn't exist
        if not os.path.exists(self.prefix):
            os.makedirs(self.prefix)
        return

    def set_pretty(self):
        self.set_notreeclumplabels()
        self.set_nojumperdraw()
        self.set_tex()
        self.set_nocolors()
        return

    def set_read_backup(self):
        self.from_backup = True
        return

    def set_start(self, startnr):
        self.start = startnr
        return

    def set_tex(self):
        self.tex = True
        return

    def set_verbose(self):
        self.verbose = True
        return

    def set_write_backup(self):
        self.write_backup = True
        return

    def set_yaxis_labels(self):
        self.use_t = True
        return


# Single module-wide parameter object; the functions below read their
# settings from it (e.g. params.colorlist, params.nocolors).
params = global_params()



#==================
class Node:
#==================
    """
    One node of the merger tree, i.e. one clump/halo at one snapshot.

    Attributes set from the constructor arguments:
        id:             (sub)halo ID at this snapshot
        y:              y axis value for the plot (time/redshift)
        main_desc_id:   ID of the "main descendant"; needed to distinguish
                        jumps across multiple snapshots.
    """

    #-------------------------------------
    def __init__(self, ID, y, desc_id):
    #-------------------------------------
        """
        Arguments:
            ID:         (sub)halo ID [any number]
            y:          y axis value for plot (time/redshift) [any number]
            desc_id:    ID of the "main descendant", used to recognise
                        jumpers across multiple snapshots [any number]
        """

        # identity of this clump, straight from the arguments
        self.id = ID
        self.y = y
        self.main_desc_id = desc_id

        # progenitors: progs[k] is a Node, is_main_prog[k] tells whether
        # that Node is the main progenitor of this one
        self.progs = []
        self.is_main_prog = []

        # plotting / tree-walking state, filled in later
        self.x = None               # x position object, assigned later
        self.is_jumper = False      # whether this clump re-emerges later
        self.branches = []          # all branches at this node
        self.branches_level = 0     # how many times this branch splits
        self.branches_tot = 0       # total number of branches of this branch
        self.start_of_branch = None # Node at which this clump's branch starts
        self.walked = False         # whether this node was already visited

    #-----------------------------------------------
    def add_progenitor(self, prog_node, is_main):
    #-----------------------------------------------
        """
        Register a progenitor of this node.

        Arguments:
            prog_node:  Node of the progenitor [class Node object]
            is_main:    whether it is this node's main progenitor [boolean]
        """
        self.progs += [prog_node]
        self.is_main_prog += [is_main]





#=====================================
def _draw_tree(node, ax, strategy):
#=====================================
    """
    Recursively draw the connections between a node and its progenitors.

    The node is marked as walked. For every progenitor at a different x
    value, the connecting line is drawn by _draw_tree_on_plot. The function
    then descends into every progenitor that hasn't been walked yet, and
    into jumpers even if they have. Finally, the node's label is drawn if
    params.treeclumplabels is set.

    Arguments:
        node:       class Node object whose progenitors are to be plotted
        ax:         axis object of the plot
        strategy:   for what to adapt plot for

    returns:
        nothing
    """
    node.walked = True

    for prog in node.progs:
        # progenitors on the same x share a straight branch line, which is
        # drawn separately, so only connect differing x values here
        if prog.x.x != node.x.x:
            _draw_tree_on_plot(node, prog, ax, strategy)

        # descend into unvisited progenitors; jumpers are always revisited
        must_descend = prog.is_jumper or not prog.walked
        if must_descend:
            _draw_tree(prog, ax, strategy)

    # labels go on last
    if params.treeclumplabels:
        _draw_tree_labels(node, ax, strategy)





#====================================================
def _draw_straight_lines(x_values, ax, strategy):
#====================================================
    """
    draws a straight line from start to end of branch of all branches.
    Parameters:
        x_values:   x values of branches. List of class _branch_x objects
        ax:         ax object on which to plot
        strategy:   for what to adapt plot for

    Returns:
        nothing
    """

    # Determine line-style
    # Dashed for mergers that will re-emerge later
    linestyle='-'
    alp = 1
    zorder = 1
    if strategy=='small':
        linewidth = 2
        outerlinewidth = 4
    elif strategy=='default':
        linewidth = 4
        outerlinewidth = 8
    elif strategy=='ultrathin':
        linewidth = 0.25
        outerlinewidth = 0.5


    for x in x_values:
        xx = [x.x, x.x]
        y = [x.ymin, x.ymax]

        # get line colors
        # color index is same as x!
        inner, outer = _get_plotcolors(x.x)
        myfacecolor = params.colorlist[inner]
        myedgecolor = params.colorlist[outer]

        #---------------
        # Plot the line
        #---------------

        if not params.nocolors:
            # plot outer line
            ax.plot(xx, y,
                    color = myedgecolor,
                    lw=outerlinewidth,
                    ls='-', # outer line is never dashed
                    alpha=alp,
                    zorder=zorder)

        # plot inner line
        ax.plot(xx, y,
                color = myfacecolor,
                lw=linewidth,
                ls=linestyle,
                zorder=zorder)



    return





#==================================================
def _draw_tree_on_plot(node, prog, ax, strategy):
#==================================================
    """
    The actual drawing function of the tree. Draws a line between Node node
    and its progenitor prog.
    If the progenitor re-imerges at a later timestep, draw a dotted line
    instead of a solid line to signify where it looks like it  merged into.

    Arguments:
        node:       class Node object of descendant
        prog:       class Node object of progenitor
        ax:         axis object of the plot
        strategy:   for what to adapt plot for

    returns:
        nothing
    """


    # get x and y values for plot
    x = [node.x.x, prog.x.x]
    y = [node.y, prog.y]


    # get line colors
    # color index is same as x!
    inner, outer = _get_plotcolors(prog.x.x)
    myfacecolor = params.colorlist[inner]
    myedgecolor = params.colorlist[outer]

    # Determine line-style
    # Dashed for mergers that will re-emerge later

    linestyle='-'
    alp = 1
    zorder = 1

    if strategy=='small':
        linewidth = 2
        outerlinewidth = 4
    elif strategy=='default':
        linewidth = 4
        outerlinewidth = 8
    elif strategy=='ultrathin':
        linewidth = 0.25
        outerlinewidth = 0.5

    # if jumper: make dashed line
    if  node.id != prog.main_desc_id:
        if params.nojumperdraw:
            return
        else:
            linestyle = '--'
            zorder = 0
            alp = 0.2

            if strategy=='small':
                linewidth = 1
                outerlinewidth = 2
            elif strategy=='default':
                linewidth = 3
                outerlinewidth = 4
            elif strategy=='ultrathin':
                linewidth = 0.1
                outerlinewidth = 0.17



    #---------------
    # Plot the line
    #---------------

    if not params.nocolors:
        # plot outer line
        ax.plot(x, y,
                color = myedgecolor,
                lw=outerlinewidth,
                ls='-', # outer line is never dashed
                alpha=alp,
                zorder=zorder)

    # plot inner line
    ax.plot(x, y,
            color = myfacecolor,
            lw=linewidth,
            ls=linestyle,
            zorder=zorder)



    return





#============================================
def _draw_tree_labels(node, ax, strategy):
#============================================
    """
    Draw a label for the node with the clump ID.

    Arguments:
        node:       class Node object of descendant
        ax:         axis object of the plot
        strategy:   how to adapt plot
        onlylabel:  whether to only draw nodes, no lines

    returns:
        nothing
    """


    if strategy=="small":
        textsize = 4
    elif strategy=="default":
        textsize = 6
    elif strategy=="ultrathin":
        textsize = 2

    #---------------
    # Annotation
    #---------------

    inner, outer = _get_plotcolors(node.x.x)
    myfacecolor = params.colorlist[inner]
    myedgecolor = params.colorlist[outer]

    # Annotate the dots with the clump ID

    if strategy=="small":
        bbox_props = dict(boxstyle="square,pad=0.05",
            fc=myfacecolor,
            ec=myedgecolor,
            lw=0.5,
            alpha=1)

    elif strategy=="default":
        bbox_props = dict(boxstyle="square,pad=0.15",
            fc=myfacecolor,
            ec=myedgecolor,
            lw=1.5,
            alpha=1)

    elif strategy=="ultrathin":
        bbox_props = dict(boxstyle="square,pad=0.03",
            fc=myfacecolor,
            ec=myedgecolor,
            lw=0.1,
            alpha=1)

    ax.text(node.x.x, node.y, str(node.id),
            size=textsize,
            bbox=bbox_props,
            horizontalalignment = 'center',
            verticalalignment = 'center',
            rotation=45,
            zorder=2)


    return





#==============================================
def dump_backup(tree, outputnrs, t):
#==============================================
    """
    Dump backup of tree, outputnrs, t and some params attributes
    so you can restart after tree-making.
    The file is written to
        <params.workdir>/dump_mergertreeplot_halo-<params.halo>.pickle

    Parameters:
        tree:       list of lists of nodes, of the tree
        outputnrs:  numpy array of snapshot nrs
        t:          numpy array of times/redshifts of snapshots

    Returns:
        Nothing
    """
    import pickle

    dumpfname = params.workdir+'/dump_mergertreeplot_halo-'+str(params.halo)+'.pickle'

    if params.verbose:
        print("Dumping backup to "+dumpfname)

    # pickle requires a binary file: text mode fails on python 3 and may
    # mangle line endings on some platforms. The objects are dumped one
    # after another; presumably the (not shown) loader reads them back in
    # exactly this order, so don't reorder.
    with open(dumpfname, 'wb') as dumpfile:
        pickle.dump(tree,                   dumpfile)
        pickle.dump(outputnrs,              dumpfile)
        pickle.dump(t,                      dumpfile)
        pickle.dump(params.outputfilename,  dumpfile)
        pickle.dump(params.start,           dumpfile)
    return





#==============================================
def _get_plotcolors(color_index):
#==============================================
    """
    Gets inner and outer plot line / scatterpoint colors.
    NEVER returns zero, zeroth index is black, reserved for special cases.
    (= for particles not in the tree)

    Parameters:
        color_index:    index for this color-pair

    Returns:
        inner, outer:   indices for inner / outer colors from colorlist
    """

    shortlen = len(params.colorlist) - 1
    multiples = color_index//shortlen
    inner = color_index - (multiples) * shortlen + 1
    while multiples >= shortlen:
        multiples = multiples//shortlen
    outer = -(multiples+1)
    if outer+shortlen+1 == inner:
        outer -= 1
        if outer+shortlen==0:
            outer -= 1

    return inner, outer





#=========================================
def _get_x(node, x_values, straight):
#=========================================
    """
    Assign x values (horizontal plot positions) to a node and its progenitors.

    Function summary:
    if straight = True:
        Look for the first progenitor of this node that hasn't been
        walked over yet and is a main progenitor. It inherits this node's
        x object, is marked as walked, and (if its ID > 0) this function
        is called recursively on it with straight=True. The function then
        returns immediately, so the second loop below is skipped for this
        node: the line of main progenitors is followed first.
        If no such progenitor exists, fall through to the second loop.
    else:
        For every progenitor that is neither a main progenitor nor a
        jumper: if it hasn't been walked over already, create a new
        _branch_x object and insert it into x_values directly to the right
        (even position of the progenitor in its start_of_branch.branches)
        or directly to the left (odd position) of this node's x object.
        Then, if the progenitor's ID > 0, call this function recursively
        on it with straight=True.
    Afterwards (reached when straight=False, or when straight=True found
    no unwalked main progenitor):
        call this function recursively with straight=False for every
        progenitor with ID > 0 that is not a jumper or is a main progenitor.

    Arguments:
        node:       class Node object to check for
        x_values:   list of _branch_x objects, ordered from left to right
        straight:   whether to go straight first

    returns:
        x_values:   updated list of _branch_x objects (a new list object is
                    built whenever an x object is inserted, so always use
                    the returned list)

    """


    # Loop over progenitors
    for i, prog in enumerate(node.progs):

        if straight:
            # Only go down main progenitors
            if not prog.walked and node.is_main_prog[i]:
                # Prog inherits descendant's x
                prog.x = node.x
                prog.x.set_y(prog.y) # update min/max y
                prog.walked = True
                if prog.id > 0:
                    # call yourself recursively
                    x_values = _get_x(prog, x_values, straight=True)
                # when you're done, return! Need to go down branches
                # from top to bottom!
                return x_values


        else:
            if (not node.is_main_prog[i]) and (not prog.is_jumper):

                # if the progenitor doesn't have an x already, give it one
                if not prog.walked:
                    ids = (branch.id for branch in prog.start_of_branch.branches)
                    ids = list(ids)

                    # position of this progenitor among the branches of the
                    # node where its branch starts; its parity picks the side
                    bind = ids.index(prog.id)
                    new_x = _branch_x()
                    new_x.set_y(prog.y)

                    # figure out which way to go
                    # (x = position of this node's x object in x_values)
                    x = x_values.index(node.x)
                    if bind%2==0: # go right
                        x_values = x_values[:x+1]+[new_x]+x_values[x+1:]
                        prog.x = x_values[x+1]
                    else: # go left
                        x_values = x_values[:x]+[new_x]+x_values[x:]
                        prog.x = x_values[x]

                    prog.walked = True

                # descend down branch along main progenitors only
                if prog.id > 0:
                    x_values = _get_x(prog, x_values, straight=True)


    # When you're done going down straight lines, now check for progenitor's progenitors
    # IMPORTANT: DO THIS ONLY AFTER PREVIOUS FOR LOOP IS DONE
    for i,prog in enumerate(node.progs):
        # jumpers are only descended into when they are main progenitors
        descend = (not prog.is_jumper) or node.is_main_prog[i]
        if prog.id > 0 and descend:
            x_values = _get_x(prog, x_values, straight=False)



    return x_values





#======================================================================================
def make_tree(progenitors, descendants, progenitor_outputnrs, outputnrs, t):
#======================================================================================
    """
    Makes a tree out of the read-in lists.
    The tree is a list containing lists of class Node objects
    for each output step; tree[0] contains only the root node.

    Side effects:
        params.start is overwritten with the index (in outputnrs) of the
        root snapshot, and params.set_outputfilename is called with the
        root's snapshot number.

    parameters:
        progenitors:            list of numpy arrays of progenitors for each timestep.
                                Negative entries mark jumpers.
        descendants:            list of numpy arrays of descendants for each timestep.
                                A descendant ID equal to -(branch ID) marks a merger.
        progenitor_outputnrs:   list of numpy arrays of the output numbers when
                                any given progenitor was an active clump
        outputnrs:              numpy array of snapshot numbers
        t:                      numpy array of times/redshifts of snapshots

    returns:
        tree:                   list of lists of nodes constituting the tree
    """

    #------------------
    # Setup
    #------------------

    halo = params.halo
    lastdirnr = params.lastdirnr

    nout_present = len(progenitors)
    # nout_present + output_start == lastdirnr, so the expression
    # "nout_present + output_start - snapnr" below equals lastdirnr - snapnr
    output_start = lastdirnr - nout_present

    if params.verbose:
        print("Creating tree.")
        print(" ")



    #------------------------------------
    # Find which output to start with
    #------------------------------------
    # if params.start != 0, then you have manually set the start snapshot.
    # startind is the position of the root snapshot in outputnrs.
    startind = 0
    if params.start > 0:
        # params.start holds the requested snapshot *number*; translate it
        # into its position in outputnrs. (Indexing outputnrs with a boolean
        # mask would return the snapshot number itself, not its index.)
        matches = np.where(np.asarray(outputnrs) == params.start)[0]
        if matches.size == 0:
            print("Couldn't find snapshot", params.start, "among the available outputs.")
            quit()
        startind = int(matches[0])
    else:
        if not params.use_t:
            # find output closest to z=0
            startind = np.argmin(np.absolute(t))
        else:
            print("Since you're using -t, I can't find the z=0 directory.")

        if params.verbose:
            print("Snapshot of the tree root is", outputnrs[startind])
    # from here on, params.start stores the root's index, not a snapshot number
    params.start = startind


    # set filename
    params.set_outputfilename(outputnrs[startind])

    #---------------------------------------------------------------------
    # Handle jumpers: Find out whether they'll eventually really merge
    # into the tree or exit later.
    #---------------------------------------------------------------------

    if params.verbose:
        print("Sorting out jumpers")

    # jumpers_not_in_tree[level] collects the IDs of jumper progenitors at
    # level "level" (= lastdirnr - snapshot nr) that must not enter the tree.
    jumpers_not_in_tree = [[] for i in range(nout_present+1)]

    # Jumpers found in outputs indexed before the root's index are excluded
    # unconditionally.
    for out in range(0, startind):
        for i in range(progenitors[out].shape[0]):
            if progenitors[out][i] < 0 : # found a jumper
                snapnr = progenitor_outputnrs[out][i]       # snapshot nr for progenitor
                ind = nout_present + output_start - snapnr  # index in tree / tree level
                jumpers_not_in_tree[ind].append(-progenitors[out][i])


    for out in range(startind,nout_present):
        for i in range(progenitors[out].shape[0]):
            if progenitors[out][i] < 0 : # found a jumper


                snapnr = progenitor_outputnrs[out][i]       # snapshot nr for progenitor
                ind = nout_present + output_start - snapnr  # index in tree / tree level

                current_id = descendants[out][i]


                if current_id in jumpers_not_in_tree[out]:
                    # descendant is already excluded, so is this jumper
                    jumpers_not_in_tree[ind].append(-progenitors[out][i])

                else:
                    # go down tree until you find the root, or an already excluded jumper
                    out2 = out
                    reached_end = (out2 == startind) # don't loop if you're at root
                    while not reached_end:
                        parent_ind = progenitors[out2-1] == current_id

                        if not parent_ind.any(): # if there is not at least one True
                            break

                        current_id = np.abs(descendants[out2-1][parent_ind])

                        if current_id in jumpers_not_in_tree[out2-1]:
                            # just break prematurely, it will be added outside the while loop.
                            break
                        out2 -= 1
                        reached_end = (out2==startind)


                    # now check that you got the right root, or whether you made it far enough in time
                    if (current_id != halo) or (out2!=startind):
                        jumpers_not_in_tree[ind].append(-progenitors[out][i])




    #---------------------
    # initialise tree
    #---------------------

    # create empty list for each output
    tree = [[] for i in range(startind, nout_present+1)]

    # enter root
    rootnode = Node(halo, outputnrs[startind], 0)
    tree[0]=[rootnode]



    #---------------------
    # Make tree
    #---------------------

    print(" ")
    print("----------------------------------------------------------------------")

    # Loop over snapshots from startind until first snapshot containing tree data
    for out in range(startind, nout_present):
        found_one = False
        treeind = out - startind

        # for each branch of the tree at that snapshots:
        for branch in tree[treeind]:

            # find for which progenitor the descendant matches
            for i in range(progenitors[out].shape[0]) :
                # if progenitor is not a jumper that will re-emerge later:
                progid = abs(progenitors[out][i])           # progenitor ID

                if progid not in jumpers_not_in_tree[out+1]:
                    snapnr = progenitor_outputnrs[out][i]       # snapshot nr for progenitor
                    ind = nout_present + output_start - snapnr - startind # index in tree / tree level

                    if abs(descendants[out][i]) == branch.id:
                        found_one = True
                        is_main = descendants[out][i] == branch.id
                        # is a main progenitor if desc ID == branch ID
                        # is a merger if desc ID == -branch ID

                        # Check first if progenitor is already in list
                        prog_not_in_list = True

                        # progenitor ID 0 always gets a new node of its own
                        if (progid > 0):
                            if (len(tree[ind]) > 0):
                                for j, candidate in enumerate(tree[ind]):
                                    if (candidate.id == progid):
                                        # Then this progenitor is already in the tree.
                                        branch.add_progenitor(tree[ind][j], is_main)
                                        prog_not_in_list = False
                                        break

                        if prog_not_in_list:
                            # create new node for progenitor
                            newnode = Node(progid, snapnr, branch.id)
                            # skipped at least one snapshot -> jumper
                            if branch.y-snapnr > 1 :
                                newnode.is_jumper = True

                            # add new node to tree
                            tree[ind].append(newnode)

                            # link to its progenitor: you know in which outputnr it will be
                            # since it's the last added element, it's index will be len(tree at that outputnr) - 1
                            branch.add_progenitor(tree[ind][-1], is_main)


                        # Print informations
                        output_string = 'Adding progenitor '
                        output_string += '{0:7d}'.format(progid)+' '
                        output_string += 'for descendant '
                        output_string += '{0:7d}'.format(branch.id)+' '
                        output_string += "|| snapshot "
                        output_string += '{0:3d}'.format(snapnr)+' '
                        output_string += "->"
                        output_string += '{0:3d}'.format(branch.y)

                        if (is_main):
                            if branch.y-snapnr > 1 :
                                output_string += "   found jumper"
                            print(output_string)

                        else:
                            if branch.y-snapnr > 1 :
                                output_string += "   found jumper and merger"
                            else:
                                output_string += "   found merger"
                            print(output_string)

        # make new line after every snapshot read-in
        if found_one:
            print("----------------------------------------------------------------------")

    print("\n\n")


    # remove empty lists at the end of the tree:
    # clumps might not have been around since the first output
    # (tree[0] always holds the root, so this loop terminates)

    list_is_empty = (len(tree[-1]) == 0)

    while list_is_empty:
        del tree[-1]
        list_is_empty = len(tree[-1]) == 0


    del jumpers_not_in_tree


    return tree





#===========================================================================
def _plot_particles(partdata, galaxydata, time, clumps_in_tree, colors, outnr):
#===========================================================================
    """
    This function plots the particles for this output. Is meant
    to be called for every output by the plot_treeparticles function.

    Particles with clump ID 0 are drawn with params.colorlist[0]; the
    particles of each clump in clumps_in_tree are drawn with the inner
    color of that clump's color pair. Only particles inside the plot
    borders are drawn. In movie mode (params.movie) only the x-y plane is
    plotted with matplotlib's scatter; otherwise the x-y, y-z and x-z
    planes are plotted with mpl_scatter_density.

    Parameters:
        partdata:       list containing numpy arrays of particle positions, clump
                        IDs and particle IDs. (If params.galaxies=False, idp=None)
        galaxydata:     list containing numpy arrays of galaxy positions, clump
                        IDs and particle IDs. (If params.galaxies=False, all=None)
        time:           time or redshift of output
        clumps_in_tree: list of clumps in tree at this output
        colors:         the colors of clumps in trees [list of integers]
        outnr:          current snapshot number

    Returns:
        nothing
    """

    import matplotlib.gridspec as gridspec
    import matplotlib.colors as mc
    import mpl_scatter_density  # needed: registers the 'scatter_density' projection


    if params.verbose:
        print("Creating particle figure")

    x = partdata[0]
    y = partdata[1]
    z = partdata[2]
    clumpid = partdata[3]


    #---------------------------
    # Set up figure
    #---------------------------

    if params.movie:
        fig = plt.figure(facecolor='white',figsize=(8,8), dpi=150)
        ax1 = fig.add_subplot(111)

    else:
        fig = plt.figure(facecolor='white', figsize=(15,6), dpi=150)

        # add subplots on fixed positions
        gs = gridspec.GridSpec(1, 3,
                           width_ratios=[1, 1, 1], height_ratios=[1],
                           left=0.04, bottom=0.13, right=0.98, top=0.9, wspace=0.12,)
        ax1 = fig.add_subplot(gs[0], aspect='equal', projection='scatter_density')
        ax2 = fig.add_subplot(gs[1], aspect='equal', projection='scatter_density')
        ax3 = fig.add_subplot(gs[2], aspect='equal', projection='scatter_density')

        # (axis, horizontal data, vertical data) for the x-y, y-z and x-z planes
        projections = ((ax1, x, y), (ax2, y, z), (ax3, x, z))



    #------------------------------
    # Setup plot region
    #------------------------------

    # mask of particles belonging to any clump in the tree
    to_plot_all = np.zeros(x.shape, dtype='bool')

    # color index per clump (movie mode) or per particle (otherwise);
    # particles not in the tree keep index 0
    if params.movie:
        colorarray = [0 for clump in clumps_in_tree]
    else:
        colorarray = np.zeros(x.shape, dtype='int')


    for i, clump in enumerate(clumps_in_tree):
        is_this_clump = (clumpid == clump)
        to_plot_all = np.logical_or(to_plot_all, is_this_clump)
        inner, _ = _get_plotcolors(colors[i])
        if params.movie:
            colorarray[i] = inner
        else:
            colorarray[is_this_clump] = inner

    borders = _setup_plot_region(x, y, z, to_plot_all)
    xmin, xmax, ymin, ymax, zmin, zmax = borders


    # If you want to manually set up axis limits, then do it here, e.g.:
    #  xmin = 400
    #  xmax = 800
    #  ymin = 300
    #  ymax = 700
    #  zmin = 0         # don't forget to include big enough zmin/zmax!
    #  zmax = 1000
    #  borders = (xmin, xmax, ymin, ymax, zmin, zmax) # overwrite borders for later use


    # reset to_plot here!
    # plot only particles within the limits
    to_plot_all = ((x <= xmax) & (x >= xmin))
    to_plot_all = to_plot_all & ((y <= ymax) & (y >= ymin))
    to_plot_all = to_plot_all & ((z <= zmax) & (z >= zmin))



    #---------------------------
    # Actual plotting
    #---------------------------

    # first plot particles not in tree
    to_plot = to_plot_all & (clumpid==0)

    # clump currently being plotted; reported if a ValueError escapes below
    c = None

    try:

        if params.movie:

            msize = 4
            alpha = 0.5
            lw = 0

            try:
                ax1.scatter(x[to_plot], y[to_plot],
                    c=params.colorlist[0],
                    zorder=0,
                    s=msize,
                    alpha=alpha,
                    lw=lw)
            except ValueError:
                # in case there were no particles to plot:
                print("No particles outside of tree were found to be plotted.")


            # now plot particles in tree
            for i, c in enumerate(clumps_in_tree):
                to_plot = to_plot_all & (clumpid==c)
                ax1.scatter(x[to_plot], y[to_plot],
                    c=params.colorlist[colorarray[i]],
                    zorder=1,
                    s=msize,
                    lw=lw,
                    alpha=0.7)


        else: # not movie

            # set up colormap: one discrete bin per entry of params.colorlist,
            # so the integers in colorarray map directly onto colors
            bounds = np.linspace(0, len(params.colorlist), len(params.colorlist)+1)
            cmap = mc.ListedColormap(params.colorlist, name='My colormap')
            norm = mc.BoundaryNorm(bounds, len(params.colorlist))


            try:
                for ax, horiz, vert in projections:
                    ax.scatter_density(horiz[to_plot], vert[to_plot], c=colorarray[to_plot],
                        cmap=cmap, norm=norm, dpi=150, alpha=1, zorder=0)
            except ValueError:
                # in case there were no particles to plot:
                print("No particles outside of tree were found to be plotted.")


            # now plot particles in tree
            for i, c in enumerate(clumps_in_tree):
                to_plot = to_plot_all & (clumpid==c)
                for ax, horiz, vert in projections:
                    ax.scatter_density(horiz[to_plot], vert[to_plot], c=colorarray[to_plot],
                        cmap=cmap, norm=norm, dpi=72, alpha=1, zorder=1)


                if params.label_clumps:
                    #------------------------
                    # create clump labels
                    #------------------------
                    inner, outer = _get_plotcolors(colors[i])

                    bbox_props = dict(boxstyle="round,pad=0.15",
                            fc=params.colorlist[inner],
                            ec=params.colorlist[outer],
                            lw=1.5,
                            alpha=1)

                    # put the label at the mean particle position in each plane
                    for ax, horiz, vert in projections:
                        ax.text(np.mean(horiz[to_plot]), np.mean(vert[to_plot]), str(c),
                            size=6,
                            bbox=bbox_props,
                            horizontalalignment='center',
                            verticalalignment='center',
                            )


    except ValueError:
        #in case there were no particles to plot:
        print("WARNING: ValueError while plotting clump", c, "raised.")
        print("WARNING: Maybe clump has no particles for some reason?")
        print("WARNING: Or are you playing around with the code?")

        # raw_input only exists on Python 2; fall back to input on Python 3
        try:
            ask = raw_input
        except NameError:
            ask = input

        while True:
            answer = ask("Do you still want to continue? (y/n) ")
            if (answer == 'y'):
                break
            elif (answer == 'n'):
                print("Exiting.")
                quit()
            else:
                print("Please answer with 'y' or 'n'")


    if params.galaxies:
        _plot_galaxies(fig, galaxydata, clumps_in_tree, colors, borders)



    #-------------------------------
    # Tweak and save figure
    #-------------------------------

    _tweak_particleplot(fig, time, borders, outnr)

    _save_fig(fig, plotparticles=True)

    plt.close()

    return





#=============================================
def plot_tree(tree, yaxis_int, yaxis_phys):
#=============================================
    """
    The main function for plotting the tree.
    After some preparation (counting and sorting branches), it distributes
    x-values to branches, draws them, tweaks the axes and saves the figure.

    Arguments:
        tree:       list of lists of class Node objects, constituting the tree
        yaxis_int:  array for y axis ticks containing integers
                    (= output numbers)
        yaxis_phys: array for y axis ticks containing physical data
                    (= redshift or time)

    returns:
        nothing
    """


    #-----------------
    # Preparation
    #-----------------

    if params.verbose:
        print("Preparing to plot tree.")

    # First find the number of branches for each branch.
    tree[0][0].start_of_branch = tree[0][0]
    _walk_tree(tree[0][0], tree[0][0])

    # Now recursively sum up the number of branches
    # to know how much space to leave between them for plotting
    if params.verbose:
        print("Sorting branches. This might take a while.")
    _sum_branches(tree[0][0])

    # lastly, sort the root's branches
    _sort_branch(tree[0][0])

    # reset whether nodes have been walked for _get_x()
    for level in tree:
        for branch in level:
            branch.walked = False




    #--------------------------------
    # start distributing x values
    #--------------------------------

    if params.verbose:
        print("Assigning positions to nodes and branches.")

    # give initial values for root and borders for axes
    x_values = list([_branch_x()])
    tree[0][0].x = x_values[0]
    tree[0][0].x.set_y(tree[0][0].y)

    # first follow the root's main progenitor line, then the side branches
    x_values = _get_x(tree[0][0], x_values, straight=True)
    x_values = _get_x(tree[0][0], x_values, straight=False)



    # once you're done, assign actual integer values for nodes
    for i,x in enumerate(x_values):
        x.set_x(i)

    # get x-axis borders for plot: a 5% margin on both sides.
    # Float literals keep the margins from vanishing under Python 2
    # integer division (where 1/20 == 0).
    nx = len(x_values)
    borders = [-nx/20.0, nx*(1+1/20.0)]





    #---------------------
    # Plot the tree
    #---------------------

    if params.verbose:
        print("Plotting tree with", tree[0][0].branches_tot+1, \
                "branches in total over", tree[0][0].y - tree[-1][0].y + 1, \
                "output steps.")
        print("Creating figure.")

    # create figure
    fig = plt.figure(facecolor='white')
    ax = fig.add_subplot(111)


    # reset whether nodes have been walked for _draw_tree()
    for level in tree:
        for branch in level:
            branch.walked = False

    # check whether you need to use small versions
    strategy = "default"
    if tree[0][0].branches_tot > 400:
        strategy = "small"
        if params.verbose :
            print("This plot apparently has a loooot of branches. Plotting everything smaller.")
    if (not params.treeclumplabels) and params.tex:
        strategy = "ultrathin"

    # draw the tree
    _draw_straight_lines(x_values, ax, strategy)
    _draw_tree(tree[0][0], ax, strategy)

    # Tweak the plot
    _tweak_treeplot(fig, yaxis_int, yaxis_phys, borders)

    # Save the figure
    _save_fig(fig, strategy)




    return





#=================================================
def plot_treeparticles(tree, yaxis_phys):
#=================================================
    """
    Plot the particles of the clumps in the tree, one figure per output.

    For every tree level holding at least one clump with ID != 0, the
    particle data (and, if params.galaxies is set, the galaxy data) of the
    corresponding output_XXXXX directory is read and handed to
    _plot_particles. Tree level i is mapped to snapshot number
    params.lastdirnr - params.start - i.

    NOTE: the clump colors are derived from the x-value of each node, so
    the x-values must already have been assigned (e.g. by plot_tree).

    Parameters:
        tree:       list of list of class Node objects containing the tree
        yaxis_phys: array for y axis ticks containing physical data
                    (= redshift or time)

    Returns:
        nothing
    """

    if params.verbose:
        for message in (" ", " ",
                "======================================================================",
                " ", "Started plotting tree particles.", " "):
            print(message)


    first_outnr = params.lastdirnr - params.start
    outnrs = range(first_outnr, first_outnr - len(tree), -1)

    for level, nodes in enumerate(tree):

        # clumps present at this level (ID 0 placeholders skipped) and their colors
        clumps_in_tree = [node.id for node in nodes if node.id != 0]
        colors = [node.x.x for node in nodes if node.id != 0]

        # nothing to plot at this level
        if not clumps_in_tree:
            continue

        outnr = outnrs[level]
        snapshot = str(outnr).zfill(5)
        srcdir = 'output_' + snapshot

        x, y, z, clumpid = _read_particle_data(srcdir)

        if params.galaxies:
            xg, yg, zg, gid = _read_galaxy_data(srcdir)
        else:
            xg = yg = zg = gid = None

        time = yaxis_phys[level]

        if params.label_clumps:
            basename = 'particleplot-with-labels_'
        else:
            basename = 'particleplot_'
        params.outputfilename = params.prefix + basename + snapshot

        _plot_particles([x, y, z, clumpid], [xg, yg, zg, gid], time,
                        clumps_in_tree, colors, outnr)

        # release the particle arrays before the next output is read in
        del x, y, z, clumpid, clumps_in_tree, colors


    return






#======================================================================
def _plot_galaxies(fig,galaxydata,clumps_in_tree,colors,borders):
#======================================================================
    """
    Plots the galaxies as star markers.

    Galaxies whose host clump ID (> 0) is in clumps_in_tree are drawn in
    that clump's inner color; galaxies with a host clump that is not in
    the tree are not drawn. Orphan galaxies (clump ID 0) inside the plot
    borders are drawn in black. In movie mode only fig.axes[0] (x-y plane)
    is used; otherwise fig must have exactly three axes, used for the
    x-y, y-z and x-z planes.

    Parameters:

        fig:            matplotlib figure object
        galaxydata:     list of numpy arrays containing galaxy positions and IDs
                        ([x, y, z, clump ID])
        clumps_in_tree: list of clumps currently in tree
        colors:         x-value of clump in tree, used to determine its color
        borders:        list of plot borders (xmin, xmax, ...)

    Returns:
        Nothing
    """

    #---------------------------------
    # Non-Orphans
    #---------------------------------

    if params.verbose:
        print("Plotting galaxies with associated clumps.")

    xmin, xmax, ymin, ymax, zmin, zmax = borders

    # (axis, index of horizontal coordinate, index of vertical coordinate)
    if params.movie:
        projections = ((fig.axes[0], 0, 1),)
    else:
        ax1, ax2, ax3 = fig.axes
        projections = ((ax1, 0, 1), (ax2, 1, 2), (ax3, 0, 2))


    xg = galaxydata[0]
    yg = galaxydata[1]
    zg = galaxydata[2]
    galid = galaxydata[3]

    msize = 200
    lw = 2
    ec = 'black'

    def _scatter_stars(coords, facecolor):
        """Draw star markers at coords = (x, y, z) on every projection."""
        for ax, h, v in projections:
            ax.scatter(coords[h], coords[v],
                marker="*",
                s=msize,
                facecolor=facecolor,
                edgecolor=ec,
                lw=lw,
                zorder=4)


    # select the galaxies with a host clump once, not on every iteration
    mask = galid>0
    host_ids = galid[mask]
    host_coords = (xg[mask], yg[mask], zg[mask])

    # clump ID -> color index; the first occurrence of a clump wins
    color_of_clump = {}
    for clump, color_index in zip(clumps_in_tree, colors):
        color_of_clump.setdefault(clump, color_index)

    for i, g in enumerate(host_ids):
        if g in color_of_clump:
            color = params.colorlist[_get_plotcolors(color_of_clump[g])[0]]
            _scatter_stars([coord[i] for coord in host_coords], color)



    #-------------------------------
    # Orphans
    #-------------------------------


    # Find which orphans are supposed to be plotted
    to_plot = galid == 0
    to_plot = to_plot & ((xg <= xmax) & (xg >= xmin))
    to_plot = to_plot & ((yg <= ymax) & (yg >= ymin))
    to_plot = to_plot & ((zg <= zmax) & (zg >= zmin))

    if params.verbose:
        print("Plotting ", np.count_nonzero(to_plot), "orphan galaxies.")

    _scatter_stars((xg[to_plot], yg[to_plot], zg[to_plot]), 'black')

    return






#==============================================
def read_backup():
#==============================================
    """
    Read backup dump of tree, outputnrs, t and some params attributes
    so you can restart after tree-making.

    Reads <params.workdir>/dump_mergertreeplot_halo-<params.halo>.pickle,
    which holds, in this order: tree, outputnrs, t, params.outputfilename,
    params.start. The last two are written back into params.

    NOTE: pickle.load can execute arbitrary code; only load dumps that this
    script wrote itself.

    Returns:
        tree:       list of lists of nodes, of the tree
        outputnrs:  numpy array of snapshot nrs
        t:          numpy array of times/redshifts of snapshots
    """

    import pickle

    if params.verbose:
        print("Reading data from backup.")

    dumpfname = params.workdir+'/dump_mergertreeplot_halo-'+str(params.halo)+'.pickle'

    # pickle data is binary: text mode fails on Python 3
    with open(dumpfname, 'rb') as dumpfile:
        tree                    = pickle.load(dumpfile)
        outputnrs               = pickle.load(dumpfile)
        t                       = pickle.load(dumpfile)
        params.outputfilename   = pickle.load(dumpfile)
        params.start            = pickle.load(dumpfile)

    return tree, outputnrs, t






#======================================
def _read_galaxy_data(srcdir):
#======================================
    """
    reads in galaxy data as written by the mergertree patch.
    NOTE: requires galaxies_XXXXX.txtYYYYY files.

    parameters:
        srcdir:         Path of output_XXXXX directory to work with

    returns:
        xg, yg, zg:     numpy arrays of x,y,z positions of galaxies
        galid:          numpy array of associated clump IDs of galaxies.
                        if element = 0, then galaxy is orphan.
    """

    import warnings
    import gc
    from os import listdir

    if params.verbose:
        print("Reading in galaxy data.")

    srcdirlist = listdir(srcdir)

    if 'galaxies_'+srcdir[-5:]+'.txt00001' not in srcdirlist:
        print("Couldn't find galaxies_"+srcdir[-5:]+".txt00001 in", srcdir)
        print("To plot particles, I require the galaxies output.")
        print("use mergertreeplot.py -h or --help to print help message.")
        quit()

    # create lists where to store stuff
    dir_template = 'output_'


    xlist = [0]*params.ncpu
    ylist = [0]*params.ncpu
    zlist = [0]*params.ncpu
    idlist = [0]*params.ncpu
    idplist = [0]*params.ncpu

    i = 0

    for cpu in range(params.ncpu):
        srcfile = srcdir+'/galaxies_'+srcdir[-5:]+'.txt'+str(cpu+1).zfill(5)

        data = np.atleast_2d(np.loadtxt(srcfile, usecols=[0,2,3,4], skiprows=1, dtype='float'))


        if data.shape[1] > 0 :

            idlist[i]  = data[:,0].astype('int')
            xlist[i]   = data[:,1]
            ylist[i]   = data[:,2]
            zlist[i]   = data[:,3]
            i+=1

    if i > 0:
        xg     = np.concatenate(xlist[:i])
        yg     = np.concatenate(ylist[:i])
        zg     = np.concatenate(zlist[:i])
        galid  = np.concatenate(idlist[:i])

    else:
        xg = None
        yg = None
        zg = None
        galid = None
        print("Didn't find any galaxy data in files. Setting params.galaxies = False and continuing")
        params.galaxies = False

    return xg, yg, zg, galid






#===================================
def read_mergertree_data():
#===================================
    """
    reads in mergertree data as written by the mergertree patch.
    NOTE: requires mergertree_XXXXX.txtYYYYY files.
    Reads in all available data, as snapshots past z=0 will also be necessary.
    The script will figure out where to start in make_tree(...)

    Outputs are visited from output_<params.lastdirnr> backwards, for
    params.noutput outputs in total:
      - for every output except the oldest one, columns 0, 1, 2 of each
        mergertree_XXXXX.txtYYYYY file (YYYYY = 00001 ... params.ncpu, one
        header line skipped) are read as descendant ID, progenitor ID and
        progenitor output number;
      - for every output, the value after '=' on line 9 of info_XXXXX.txt is
        stored as the time (params.use_t), or the value on line 10 is
        converted to 1/value - 1 and stored as redshift (not params.use_t).

    If not params.use_t and the first two stored values are equal, the user is
    asked interactively whether to continue; answering 'n' calls quit().

    returns:
        descendants, progenitors:   lists of numpy arrays of descendants and progenitors,
                                    starting with the last output step.
        progenitor_outputnrs:       list of numpy arrays of the output number at which the
                                    progenitor is
        outputnrs :                 numpy array of the output numbers at which descendants
                                    were taken from
        time :                      numpy array of times (or redshifts) corresponding to
                                    each output step; entries for outputs whose info file
                                    could not be read stay 0

    """

    import warnings
    import gc

    noutput = params.noutput
    ncpu = params.ncpu
    use_t = params.use_t

    if params.verbose:
        print("Reading in mergertree data.")




    # create lists where to store stuff; outputs without data keep a
    # single-zero placeholder array
    progenitors = [np.zeros(1, dtype='int') for i in range(noutput)]
    descendants = [np.zeros(1, dtype='int') for i in range(noutput)]
    progenitor_outputnrs = [np.zeros(1, dtype='int') for i in range(noutput)]
    outcount = 0

    # output numbers, newest first: lastdirnr, lastdirnr-1, ...
    startnr=params.lastdirnr
    outputnrs = np.array(range(startnr, startnr-noutput, -1))
    time = np.zeros(noutput)

    dir_template = 'output_'


    #---------------------------
    # Loop over directories
    #---------------------------

    for output in range(noutput):
        # loop through every output: Progenitor data only starts at output_00002,
        # but you'll need time/redshift data from output_00001!

        # Start with last directory (e.g. output_00060),
        # work your way to first directory (e.g. output_00001)
        dirnr =  str(startnr - output).zfill(5)
        srcdir = dir_template + dirnr

        if output < noutput-1: # don't try to read progenitor stuff from output_00001
            #------------------------------
            # Read in progenitor data
            #------------------------------

            # If any mergertree.txt* file of this output cannot be opened
            # (IOError, e.g. no halos in the simulation yet), the progenitor
            # data of this whole output is skipped and a message is printed;
            # the loop then continues with the next (older) output.
            # NOTE(review): one missing CPU file also discards the files of
            # this output that were already read successfully.
            try:

                fnames = ["".join([srcdir,'/mergertree_',srcdir[-5:],'.txt',str(cpu+1).zfill(5)]) for cpu in range(ncpu)]

                datalist = [np.zeros((1,3),dtype='int') for i in range(ncpu)]
                # i counts the files that were read successfully; only
                # datalist[:i] holds real data
                i = 0
                for f in fnames:
                    with warnings.catch_warnings():
                        warnings.filterwarnings('error') # treat warnings as errors so I can catch them
                        try:
                            datalist[i] = np.atleast_2d(np.loadtxt(f, dtype='int', skiprows=1, usecols=([0, 1, 2])))
                            i += 1
                        except Warning:
                            # any warning from loadtxt (presumably a file
                            # without data rows) -> skip this file
                            continue

                if i > 0:
                    fulldata = np.concatenate((datalist[:i]), axis=0)

                    descendants[output] = fulldata[:,0]
                    progenitors[output] = fulldata[:,1]
                    progenitor_outputnrs[output] = fulldata[:,2]

                    outcount = output # store the last output where you added stuff

            except IOError: # If file doesn't exist
                print("Didn't find any progenitor data in", srcdir)





        try:
            #-------------------------------------
            # get time, even for output_00001
            #-------------------------------------
            fileloc = srcdir+'/info_'+dirnr+'.txt'
            infofile = open(fileloc)
            for i in range(8):
                infofile.readline() # skip first 8 lines

            if not use_t:
                infofile.readline() # skip another line for redshift

            # value after '=' on the selected line
            timeline = infofile.readline()
            timestring, equal, timeval = timeline.partition("=")
            timefloat = float(timeval)

            if not use_t:
                # convert the read value to redshift: z = 1/value - 1
                timefloat = 1.0/timefloat - 1

            time[output] = timefloat

        except IOError: # If file doesn't exist
            # stop reading older outputs; their time entries stay 0
            print("Didn't find any info data in ", srcdir)
            break


    # keep only entries that contain data
    # NOTE(review): with '> 1', outcount == 1 leaves the lists untrimmed;
    # confirm whether '> 0' was intended (outcount == 0 is ambiguous, it is
    # also the initial value when no data was found at all).
    if outcount > 1:
        descendants = descendants[:outcount+1]
        progenitors = progenitors[:outcount+1]
        progenitor_outputnrs = progenitor_outputnrs[:outcount+1]



    #----------------------------------------------------------
    # print warning if -t cmdline arg might've been forgotten
    #----------------------------------------------------------

    if (len(time)>1 and (not use_t)):
        if (time[0] == time[1]):
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print("WARNING: The first two elements in the physical y-axis data are the same.")
            print("WARNING: If you are trying to plot a non-cosmo simulation, you should use")
            print("WARNING: the -t flag. Otherwise, you'll always have z = 0 on the y-axis.")
            print("WARNING: use mergertreeplot.py -h or --help to print help message.")
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            # interactive confirmation (Python 2 raw_input)
            acceptable_answer = False
            while not acceptable_answer:
                answer = raw_input("Do you still want to continue? (y/n) ")
                if (answer == 'y'):
                    acceptable_answer = True
                    break
                elif (answer == 'n'):
                    print("Exiting.")
                    quit()
                else:
                    print("Please answer with 'y' or 'n'")

    # collect garbage
    gc.collect()

    return descendants, progenitors, progenitor_outputnrs, outputnrs, time






#========================================
def _read_particle_data(srcdir):
#========================================
    """
    Reads in the particle data from directory srcdir.
    NOTE: requires part_XXXXX.outYYYYY and unbinding_XXXXX.outYYYYY files

    The particle files are read in two passes:
      1. read every header to get the per-CPU particle counts, so the
         output arrays can be allocated once;
      2. reopen every file, skip its header, and read x, y, z. The clump IDs
         come from the matching unbinding file.
    Every Fortran file is closed right after use, so the number of files
    open at the same time does not grow with params.ncpu.

    Parameters:
        srcdir:     String of directory where to read data from.
                    Its last 5 characters must be the output number.

    returns:
        x,y,z:      numpy arrays of particle positions
        clumpid:    numpy arrays of particle clump IDs (absolute values)
    """

    from os import listdir
    import fortranfile as ff

    if params.verbose:
        print("Reading in particles of output", int(srcdir[-5:]))

    outnr = srcdir[-5:]

    if 'unbinding_'+outnr+'.out00001' not in listdir(srcdir):
        print("Couldn't find unbinding_"+outnr+".out00001 in", srcdir)
        print("To plot particles, I require the unbinding output.")
        print("use mergertreeplot.py -h or --help to print help message.")
        quit()


    def _partfname(cpu):
        """File name of the particle file of (0-based) CPU index cpu."""
        return srcdir+'/part_'+outnr+'.out'+str(cpu+1).zfill(5)

    def _read_header(partfile):
        """Consume the particle file header records; return the local particle count."""
        partfile.readInts()             # ncpu
        partfile.readInts()             # ndim
        npart = partfile.readInts()     # number of particles in this file
        partfile.readInts()             # localseed
        partfile.readInts()             # nstar_tot
        partfile.readReals('d')         # mstar_tot
        partfile.readReals('d')         # mstar_lost
        partfile.readInts()             # nsink
        return npart


    #-----------------------
    # First read headers
    #-----------------------
    nparts = np.zeros(params.ncpu, dtype='int')

    for cpu in range(params.ncpu):
        partfile = ff.FortranFile(_partfname(cpu))
        try:
            nparts[cpu] = _read_header(partfile)
        finally:
            partfile.close()


    #-------------------
    # Allocate arrays
    #-------------------
    nparttot = nparts.sum()
    x = np.zeros(nparttot, dtype='float')
    y = np.zeros(nparttot, dtype='float')
    z = np.zeros(nparttot, dtype='float')
    clumpid = np.zeros(nparttot, dtype='int')


    #----------------------
    # Read particle data
    #----------------------

    # start index of each CPU's chunk in the global arrays
    start_ind = np.zeros(params.ncpu, dtype='int')
    start_ind[1:] = np.cumsum(nparts)[:-1]

    for cpu in range(params.ncpu):
        chunk = slice(start_ind[cpu], start_ind[cpu]+nparts[cpu])

        partfile = ff.FortranFile(_partfname(cpu))
        try:
            _read_header(partfile)  # skip header, counts are already known
            x[chunk] = partfile.readReals('d')
            y[chunk] = partfile.readReals('d')
            z[chunk] = partfile.readReals('d')
        finally:
            partfile.close()

        unbfile = srcdir+'/unbinding_'+outnr+'.out'+str(cpu+1).zfill(5)
        unbffile = ff.FortranFile(unbfile)
        try:
            clumpid[chunk] = unbffile.readInts()
        finally:
            unbffile.close()


    clumpid = np.absolute(clumpid)


    return x, y, z, clumpid






#====================================================================
def _save_fig(fig, strategy="default", plotparticles=False):
#====================================================================
    """
    Save figure to params.workdir/params.outputfilename.<ext>, then close it.

    The format is chosen as follows:
        params.tex and not plotparticles:   standalone TikZ/pgfplots .tex file
        params.pdf and not plotparticles:   .pdf
        otherwise:                          .png at 150 dpi

    fig:            pyplot figure object to save (and close)
    strategy:       how to adapt plot; "ultrathin" forces a fixed
                    24.7 x 16.0 size for the .tex output
    plotparticles:  whether you're currently plotting particles or the tree
                    (particle plots are always saved as png)

    returns:
        nothing

    """

    #--------------------------------------
    if params.tex and not plotparticles:
    #--------------------------------------
        from tikzplotlib import save as tikz_save

        # Create standalone tikz image
        fig_path = params.workdir+'/'+params.outputfilename + '.tex'
        if params.verbose:
            print("saving figure as "+fig_path)

        w,h = fig.get_size_inches()
        w /= 1.15
        h /= 1.25

        if strategy=="ultrathin":
            # use A4 size, with margins subtracted
            w =  24.7
            h =  16.0

        # things will appear smaller in LaTeX; use cm instead of inches and reduce size a bit
        # NOTE(review): recent tikzplotlib releases name these keywords
        # axis_height / axis_width; confirm against the installed version.
        tikz_save(fig_path, figure=fig, show_info=False,strict=True, figureheight=str(h)+'cm',figurewidth=str(w)+'cm')

        # tweak written file (open() works on Python 2 and 3; the builtin
        # file() only exists on Python 2)
        with open(fig_path, 'r') as original:
            data = original.read()
        data = _tweak_texfile(data, standalone=True)
        with open(fig_path, 'w') as modified:
            modified.write(data)



    #--------------------------------------
    else:
    #--------------------------------------

        # Save the figure we were given, not whatever pyplot considers the
        # current figure.
        if params.pdf and not plotparticles:
            fig_path = params.workdir+'/'+params.outputfilename + '.pdf'
            if params.verbose:
                print("saving figure as "+fig_path)
                print("This may take a while.")
            fig.savefig(fig_path, format='pdf')

        else:
            fig_path = params.workdir+'/'+params.outputfilename + '.png'
            if params.verbose:
                print("saving figure as "+fig_path)
                print("This may take a while.")

            dpi = 150
            fig.savefig(fig_path, format='png', facecolor=fig.get_facecolor(), transparent=False, dpi=dpi)



    print("Saved", fig_path)
    print(" ")
    # close exactly this figure to release its memory
    plt.close(fig)

    return






#======================================================
def _setup_plot_region(x, y, z, to_plot):
#======================================================
    """
    This function sets up the plot region:
        - find borders for the plot: min/max in x,y,z direction
        - account for periodic scenarios
    Parameters:
        x, y, z :   numpy arrays of x, y, z positions of particles.
                    NOTE: if not params.use_t (cosmological run assumed), these
                    arrays may be modified IN PLACE: part of the positions is
                    shifted by +1 or -1 to undo periodic wrapping. The code
                    assumes a periodic box spanning [0, 1] in every direction.
        to_plot :   numpy boolean array of particles which are to be plotted
                    (must select at least one particle, otherwise .min()
                    raises)

    Returns:
        borders :   tuple of borders:  (xmin, xmax, ymin, ymax, zmin, zmax).
                    Every axis is centered on the selected particles and spans
                    center +- maxd, where maxd is the largest extent of the
                    selection over x, y, z, so all three panels have equal size.
    """


    xmin = x[to_plot].min()
    xmax = x[to_plot].max()
    ymin = y[to_plot].min()
    ymax = y[to_plot].max()
    zmin = z[to_plot].min()
    zmax = z[to_plot].max()

    xc = (xmax + xmin)/2
    yc = (ymax + ymin)/2
    zc = (zmax + zmin)/2

    dx = xmax - xmin
    dy = ymax - ymin
    dz = zmax - zmin



    if not params.use_t:

        #----------------------------
        # Check for periodicity
        #----------------------------
        # if not use_t, then assume cosmo run
        # find out whether you need to shift for periodicity

        moved_x = False
        moved_y = False
        moved_z = False

        # An extent larger than half the box is taken to mean that the
        # selection straddles the periodic boundary: move one half of the box
        # by one box length so the selection becomes contiguous, then
        # recompute center and extent.
        if dx > 0.5:
            if xc > 0.5:
                x[x>=0.5] -= 1
            else:
                x[x<0.5] += 1
            xmin = x[to_plot].min()
            xmax = x[to_plot].max()
            xc = (xmax + xmin)/2
            dx = xmax - xmin
            moved_x = True

        if dy > 0.5:
            if yc > 0.5:
                y[y>=0.5] -= 1
            else:
                y[y<0.5] += 1
            ymin = y[to_plot].min()
            ymax = y[to_plot].max()
            yc = (ymax + ymin)/2
            dy = ymax - ymin
            moved_y = True

        # NOTE(review): unlike x and y, this check also triggers on zmax>1 or
        # zmin<0. For positions inside [0, 1] those extra conditions are
        # false; confirm whether the asymmetry is intended.
        if dz > 0.5 or zmax>1 or zmin<0:
            if zc > 0.5:
                z[z>=0.5] -= 1
            else:
                z[z<0.5] += 1
            zmin = z[to_plot].min()
            zmax = z[to_plot].max()
            zc = (zmax + zmin)/2
            dz = zmax - zmin
            moved_z = True


        # tentative equal-size window, used only to test whether the plot
        # region would reach outside the box
        maxd=max(dx, dy, dz)
        xmax = xc + maxd
        xmin = xc - maxd
        ymax = yc + maxd
        ymin = yc - maxd
        zmax = zc + maxd
        zmin = zc - maxd

        # If the window crosses a box face along an axis that was not shifted
        # yet, shift that axis as well, so that periodic neighbours on the
        # other side of the face end up next to the selection.
        if not moved_x:
            if xmax>1 or xmin<0:
                if xc > 0.5:
                    x[x>=0.5] -= 1
                else:
                    x[x<0.5] += 1
                xmin = x[to_plot].min()
                xmax = x[to_plot].max()
                xc = (xmax + xmin)/2
                dx = xmax - xmin

        if not moved_y:
            if ymax>1 or ymin<0:
                if yc > 0.5:
                    y[y>=0.5] -= 1
                else:
                    y[y<0.5] += 1
                ymin = y[to_plot].min()
                ymax = y[to_plot].max()
                yc = (ymax + ymin)/2
                dy = ymax - ymin

        if not moved_z:
            if zmax>1 or zmin<0:
                if zc > 0.5:
                    z[z>=0.5] -= 1
                else:
                    z[z<0.5] += 1
                zmin = z[to_plot].min()
                zmax = z[to_plot].max()
                zc = (zmax + zmin)/2
                dz = zmax - zmin



    # make all figures of equal size: center +- largest extent on every axis
    maxd=max(dx, dy, dz)
    xmax = xc + maxd
    xmin = xc - maxd
    ymax = yc + maxd
    ymin = yc - maxd
    zmax = zc + maxd
    zmin = zc - maxd


    borders = (xmin, xmax, ymin, ymax, zmin, zmax)


    return borders






#===========================
def _sort_branch(branch):
#===========================
    """
    Sort the list of branches of a given start of a branch (class Node object)
    by increasing time of appearance. The earliest branches (lowest on y-axis)
    come first. If there are multiple mergers at the same time, put the one
    with more own branches further out.

    Parameters:
        branch:     class Node objecs whose .branches list is to be sorted

    returns:
        nothing
    """

    import gc

    if len(branch.branches) > 0:
        branchlist = (bbranch for bbranch in branch.branches)
        branchlist = list(branchlist)

        nbranch = len(branchlist)
        times = np.zeros(nbranch, dtype='int')
        all_branches = np.zeros(nbranch, dtype='int')
        branchlist_id = np.zeros(nbranch, dtype='int')

        for b in range(nbranch):
            times[b]=branchlist[b].y
            all_branches[b]=branchlist[b].branches_tot
            branchlist_id[b] = branchlist[b].id

        # find out if you have multiple mergers at same timestep
        needs_branchcount_sorting = False

        for t in times:
            temp = times[times==t]
            if ((temp/temp).sum()>1):
                needs_branchcount_sorting = True
                break

        sort_ind = times.argsort()
        times = times[sort_ind]
        branchlist_id = branchlist_id[sort_ind]


        if needs_branchcount_sorting:

            all_branches = all_branches[sort_ind]

            # at least the next element is the same
            i = 0
            while i < (len(times)-1):
                if (times[i+1]==times[i]):
                    # if the following is the same:
                    startind = i
                    endind = i+1

                    # find where the same end
                    while times[startind]==times[endind]:
                        if endind == len(times)-1:
                            break
                        else:
                            endind += 1

                    # sort all of them by ascending branch_tot
                    branchtots = all_branches[startind:endind+1]
                    sort_ind2 = branchtots.argsort()

                    branchlist_id[startind:endind+1] = branchlist_id[startind+sort_ind2]


                    i = endind

                else:
                    i+=1


        # overwrite branches' branch list
        # work with branch.ids to not copy entire objects all the time
        for i in range(len(times)):
            for bbranch in branchlist:
                if bbranch.id == branchlist_id[i]:
                    branch.branches[i] = bbranch
                    break

        # release memory
        del branchlist
        del branchlist_id
        del times
        del all_branches
        del sort_ind

    gc.collect()


    return






#============================
def _sum_branches(node):
#============================
    """
    Recursively sum up the total number of branches of each "root"

    Arguments:
        node:   class Node object to check for

    returns:
        nothing
    """


    # each branch root has itself as a branch.
    # If the tree splits up somewhere else along the way,
    # then there must be > 1 branch in the root node.

    if len(node.branches) > 0:
        node.branches_level = 1
        for i, branch in enumerate(node.branches):
            # don't call yourself again
            _sum_branches(branch)
            node.branches_tot += branch.branches_tot
            node.branches_tot += 1 # count yourself too :)

            node.branches_level += branch.branches_level
            _sort_branch(branch)

    return






#==========================================================
def _tweak_particleplot(fig, time, borders, outnr):
#==========================================================
    """
    Tweak the particle plot. Set axis limits and ticks, background color,
    label axes, add a title (or, in movie mode, a snapshot-number label),
    and some other cosmetics.

    Parameters:
        fig:            figure object; its axes are unpacked as a single axis
                        in movie mode (params.movie) and as three axes
                        (x-y, y-z, x-z panels) otherwise
        time:           time or redshift of output
        borders:        tuple of plot edges: xmin, xmax, ..., zmin, zmax
        outnr:          current snapshot number (used for the movie label)

    Returns:
        nothing
    """

    if params.movie:
        ax1 = fig.axes[0]
    else:
        ax1, ax2, ax3 = fig.axes

    xmin, xmax, ymin, ymax, zmin, zmax = borders

    #-------------------
    # set axes limits
    #-------------------

    ax1.set_xlim([xmin, xmax])
    ax1.set_ylim([ymin, ymax])

    if not params.movie:
        ax2.set_xlim([ymin, ymax])
        ax2.set_ylim([zmin, zmax])

        ax3.set_xlim([xmin, xmax])
        ax3.set_ylim([zmin, zmax])



    #------------------------------------------
    # set tick params (especially digit size)
    #------------------------------------------

    if not params.movie:
        ax1.tick_params(axis='both', which='major', labelsize=8,top=5)
        ax2.tick_params(axis='both', which='major', labelsize=8,top=5)
        ax3.tick_params(axis='both', which='major', labelsize=8,top=5)
    # TODO: If you want to remove ticks for the particle plots in movie mode,
    # Do it here:
    #  else:
    #      ax1.get_yaxis().set_visible(False)
    #      ax1.get_xaxis().set_visible(False)




    #--------------
    # label axes
    #--------------

    if not params.movie:

        ax1.set_xlabel(r'x', labelpad=4, family='serif',size=12)
        ax1.set_ylabel(r'y', labelpad=4, family='serif',size=12)

        ax2.set_xlabel(r'y', labelpad=4, family='serif',size=12)
        ax2.set_ylabel(r'z', labelpad=4, family='serif',size=12)

        ax3.set_xlabel(r'x', labelpad=4, family='serif',size=12)
        ax3.set_ylabel(r'z', labelpad=4, family='serif',size=12)




        #--------------
        # Add title
        #--------------

        title = "Particles in tree at output "+str(int(params.outputfilename[-5:]))+"; "

        if params.use_t:
            title += "t = "
        else:
            title += "z = "


        # pad the time/redshift with zeros at the end
        num = str(round(time, 3))
        if (round(time, 3)) < 0:
            # one longer for negative z
            end = 7
        else:
            end = 6
        while len(num) < end:
            num += '0'
        title += num

        fig.suptitle(title, family='serif', size=18)



    else:
        plt.sca(ax1)
        # Put the snapshot number into the top-right corner. Axes-fraction
        # coordinates keep the label inside the axes wherever the plotted
        # region lies; a data-coordinate offset such as xmax*0.99 scales with
        # the distance from the origin, not with the plotted range, and can
        # place the label outside small regions far from 0.
        ax1.annotate(str(outnr),
            xy=(0.99, 0.99),
            xycoords='axes fraction',
            fontsize=18,
            backgroundcolor='white',
            horizontalalignment='right',
            verticalalignment='top'
            )

        plt.tight_layout()



    #------------------------
    # Set background color
    #------------------------
    # set_axis_bgcolor was removed from newer matplotlib; fall back to
    # set_facecolor there.
    try:
        ax1.set_axis_bgcolor('aliceblue')
        if not params.movie:
            ax2.set_axis_bgcolor('aliceblue')
            ax3.set_axis_bgcolor('aliceblue')
    except AttributeError:
        ax1.set_facecolor('aliceblue')
        if not params.movie:
            ax2.set_facecolor('aliceblue')
            ax3.set_facecolor('aliceblue')

    return






#=========================================
def _tweak_texfile(data, standalone):
#=========================================
    """
    Tweaks the TeX-file around. Adds header, footer and
    other modifications.

    Parameters:
        data:       read-in tex file
        standalone: whether to add header and footer to create standalone tikz image
    Returns:
        data:   tweaked tex file to be saved
    """

    # Add header and footer to make standalone file
    header = "\\documentclass[crop,tikz]{standalone}%\n"
    header +="\\usepackage[utf8]{inputenc}\n"
    header +="\\usepackage{pgfplots}\n"
    header +="\\begin{document}\n"

    footer = "\n\\end{document}\n"


    # smaller tick labels
    before, keyword, after = data.partition('\\begin{tikzpicture}\n')
    add = "\n\\pgfplotsset{every tick label/.append style={font=\\tiny}}\n"
    # no arrows on axes
    add += "\\pgfplotsset{ every non boxed x axis/.append style={x axis line style=-},\n every non boxed y axis/.append style={y axis line style=-}}\n"
    add += "\\pgfplotsset{ compat=1.13 }\n\n"

    # find ytick labels in file
    yticks=""
    lines = after.split("\n")
    for l in lines:
        if "ytick=" in l:
            yticks = l
            break

    # no x labels (yes, do it twice), add explicit yticks for first axis
    after=after.replace("xtick pos=both,", "xmajorticks=false,\n"+yticks)
    after=after.replace("xtick pos=both,", "xmajorticks=false,\n")

    # remove axis on top, this screws with the grid
    after=after.replace("axis on top,\n", "")
    after=after.replace("axis on top,\n", "")


    # recombine the parts, add header and footer
    if standalone:
        data = header + before + keyword + add + after + footer
    else:
        data = before + keyword + add + after

    return data






#====================================================================
def _tweak_treeplot(fig, yaxis_int, yaxis_phys, borders):
#====================================================================
    """
    tweaks the plot. Removes x-axis ticks, labels y-axis
    ticks nicely, adds right y axis and sets the figure size.

    If params.tex is set and params.treeclumplabels is not, only a single
    y axis labelled with the physical values is drawn (no title, no grid).
    Otherwise the left axis shows output numbers, a twin right axis shows
    the physical values, and a title and grid are added.

    parameters:
        fig:        pyplot figure object
        yaxis_int:  numpy array for y axis ticks containing integers
                    (= output numbers)
        yaxis_phys: array for y axis ticks containing physical data
                    (= redshift or time), indexed like yaxis_int
        borders:    borders for x axis

    returns:
        nothing
    """

    #-----------------
    # preparation
    #-----------------

    if params.verbose:
        print("Tweaking the plot.")

    halo = params.halo
    snapnr = yaxis_int[params.start]


    # outputnr[ind] corresponds to yaxis_int[ind-1+params.start] in both
    # cases; entries outside that range are padding.
    if params.start > 0:
        # cut down y-axis length
        outputnr = yaxis_int[params.start-1:].tolist() + [yaxis_int[-1]-1] + [yaxis_int[-1]-2]
    else:
        # add additional yaxis points for padding
        outputnr = yaxis_int.tolist()
        outputnr = [yaxis_int[0]+1] + outputnr + [yaxis_int[-1]-1]

    noutput = len(outputnr)



    #------------------
    #  Prepare y axis
    #------------------

    # determine how many y axis ticks you want (about 10): step = noutput/10
    # rounded to the nearest integer, at least 2. Float division keeps the
    # +0.5 rounding meaningful under Python 2 integer division.
    nyticks_step = int(noutput/10.0 + 0.5)
    if nyticks_step < 2:
        nyticks_step = 2

    yticks = []
    yticks_labels = []      # explicitly add tick labels for tex files
    yticks_right = []
    yticks_right_labels=[]

    nphys = len(yaxis_phys)

    # first and last entry are always padding -> never ticked
    for ind in range(1, noutput - 1):

        if ind % nyticks_step != 0:
            continue

        # Physical value belonging to outputnr[ind]. Trailing padding
        # entries may lie past the end of yaxis_phys; skip them instead of
        # raising an IndexError.
        iphys = ind - 1 + params.start
        if iphys >= nphys:
            continue

        yticks.append(outputnr[ind])
        yticks_labels.append(str(outputnr[ind]))
        yticks_right.append(outputnr[ind])
        tstr = '{0:6.3f}'.format(yaxis_phys[iphys])
        yticks_right_labels.append(tstr)



    #---------------------
    # Set ticks and title
    #---------------------


    if params.tex and (not params.treeclumplabels):
        # assume only esthetical plot. Make no title, label only redshift.
        ax = fig.axes[0]
        ax.set_yticks(yticks_right)
        ax.set_ylim([outputnr[-1]-0.5, outputnr[0]+0.5])
        ax.set_yticklabels(yticks_right_labels)
        if params.use_t:
            ax.set_ylabel("t [code units]", size=14)
        else:
            ax.set_ylabel("redshift z", size=14)
        ax.tick_params(axis='both', labelsize=10)


    else:
        # left y ticks
        ax = fig.axes[0]
        ax.set_yticks(yticks)
        ax.set_ylim([outputnr[-1]-0.5, outputnr[0]+0.5])
        ax.set_ylabel('output number', size=14)
        ax.set_yticklabels(yticks_labels)
        ax.tick_params(axis='both', labelsize=10)

        # right y ticks
        ax2 = ax.twinx()
        ax2.set_yticks(yticks_right)
        ax2.set_ylim([outputnr[-1]-0.5, outputnr[0]+0.5])
        ax2.set_yticklabels(yticks_right_labels)
        if params.use_t:
            ax2.set_ylabel("t [code units]", size=14)
        else:
            ax2.set_ylabel("redshift z", size=14)
        ax2.tick_params(axis='both', labelsize=10)

        # title
        title = "Merger tree for halo "+str(halo)+" at output "+str(snapnr)
        ax.set_title(title, size=18, y=1.01)

        # add grid
        ax.grid(which='major')


    # x axis and ticks
    ax.set_xlim(borders)
    ax.set_xticks([]) # no x ticks



    #-------------------------
    # Other cosmetics
    #-------------------------

    # set figure size: grows with the number of outputs (height) and the
    # x range (width), clamped to sensible limits
    height = 0.3*len(yaxis_int)
    if height < 12:
        height = 12

    dx = borders[1]-borders[0] - 2
    width = 0.4*dx
    if width < 10:
        width = 10
    elif width > 100:
        width = 100

    fig.set_size_inches(width, height)


    # cleaner layout
    plt.tight_layout()

    return






#=================================
def _walk_tree(node, root):
#=================================
    """
    Walk the tree and count the branches. Add branches to the
    root/start of that branch.

    Arguments:
        node:   class Node object to check whether a new branch starts here
        root:   class Node object which is current source of the branch.

    returns:
        nothing
    """

    # mark node as walked.
    node.walked = True

    # First check out new branches to mark possible jumps
    # over multiple timesteps as "walked"

    for i, prog in enumerate(node.progs):
        # check only if node hasn't been walked over already:
        if not prog.walked:
            if not node.is_main_prog[i]:
                # a new branch starts!
                # this progenitor will be the root for the new branch.
                root.branches.append(prog)
                prog.start_of_branch = root

                _walk_tree(prog, prog)


    # then just resume where you left off
    for i, prog in enumerate(node.progs):
        if not prog.walked:
            if node.is_main_prog[i]:
                prog.start_of_branch = root
                _walk_tree(prog, root)

    return






#===============================================================================
#===============================================================================
#===============================================================================
#===============================================================================
#===============================================================================



#==============
def main():
#==============
    """
    Script entry point.

    Parses the command line and output info, then obtains the tree either
    from a backup dump (params.from_backup) or by reading the mergertree
    files and building it (optionally writing a backup afterwards). The tree
    is then plotted, and its particles as well if params.plotparticles is set.
    """

    # set up
    params.read_cmdlineargs()
    params.get_output_info()
    if params.verbose:
        params.print_params()

    # read in data / make tree
    if not params.from_backup:
        descendants, progenitors, progenitor_outputnrs, outputnrs, t = read_mergertree_data()
        tree = make_tree(progenitors, descendants, progenitor_outputnrs, outputnrs, t)
        # the raw per-output arrays are not needed once the tree exists
        del progenitors, descendants, progenitor_outputnrs

        if params.write_backup:
            dump_backup(tree, outputnrs, t)
    else:
        tree, outputnrs, t = read_backup()

    # plot the tree, then (optionally) its particles
    plot_tree(tree, outputnrs, t)
    if params.plotparticles:
        plot_treeparticles(tree, t)






#===============================
if __name__ == "__main__":
#===============================
    # run only when executed as a script, not when imported as a module
    main()

