#!/usr/bin/env python -u
import os
import sys
import time
import argparse
from pprint import PrettyPrinter

import yaml
import numpy as np
from msmbuilder.Project import Project, Serializer
from euclid import metrics
from euclid.drift import drift

CURDIR = os.path.abspath(os.curdir)

# Detect whether we're running under MPI with the dtm task manager. If either
# import fails, or there is only a single MPI process, fall back to serial
# execution.
try:
    import mpi4py.MPI
    from deap import dtm
    PARALLEL_ENV = mpi4py.MPI.COMM_WORLD.size > 1
    get_rank = lambda: dtm.getWorkerId() if 'workerId' in dtm._dtmObj.__dict__ else 0
except ImportError:
    PARALLEL_ENV = False
    get_rank = lambda: 0


def verbose_wrap(f):
    """Wrap f so that each serial invocation prints a '# ' progress marker."""
    def g(*args):
        sys.stdout.write('# ')
        sys.stdout.flush()
        return f(*args)
    return g


def mymap(f, *args):
    """Map f over args with dtm when running under MPI, serially otherwise."""
    if PARALLEL_ENV:
        return dtm.map(f, *args)
    return map(verbose_wrap(f), *args)


pp = PrettyPrinter(indent=2)


def main():
    parser = argparse.ArgumentParser(description="""
    Compute the drift in a number of structural distance metrics (for a
    variety of lag times) across an msmbuilder project.

    This code is parallelized with mpi4py and the dtm task manager. It should
    scale well to a large number of cores and/or nodes.
    """)
    parser.add_argument('input_file', type=file, help='Input file in yaml format')
    parser.add_argument('output_dir', type=str, help='Directory to save output in')
    parser.add_argument('-n', '--dry-run', action='store_true',
                        help="parse the input file but don't actually perform the calculation")
    try:
        args = parser.parse_args()
        # Make the output directory an absolute path, expanding a leading '~'.
        if args.output_dir.startswith('~'):
            args.output_dir = os.path.expanduser(args.output_dir)
        else:
            args.output_dir = os.path.join(CURDIR, args.output_dir)
    except IOError as e:
        parser.error(e)

    try:
        entries, projectinfo, tau = parse_yaml(args.input_file, args.output_dir)
    except yaml.error.YAMLError as e:
        parser.error('Input file %s was not properly formatted yaml.\n%s'
                     % (args.input_file.name, e))

    # Check for filename conflicts in the output directory, and skip any
    # entry whose output file already exists.
    skip_indices = []
    for i, entry in enumerate(entries):
        if os.path.exists(entry['output_path']):
            print 'A metric file by the name "%s" already exists.' % entry['output_path']
            print 'Skipping computation.'
            skip_indices.append(i)
    entries = [e for i, e in enumerate(entries) if i not in skip_indices]
    if len(skip_indices) > 0:
        print '\n%d metric file(s) already exist and will not be recomputed. Remaining computations are:' % len(skip_indices)
        pp.pprint(entries)

    # Break out early for a dry run.
    if args.dry_run:
        return

    if not os.path.exists(args.output_dir):
        print 'Making directory %s' % args.output_dir
        os.mkdir(args.output_dir)

    run_entries(entries, projectinfo, tau)


def parse_yaml(input_file, output_dir):
    """Parse the two-document yaml input file.

    The first document holds the project info and the lag times; the second
    holds the list of metric entries to compute.
    """
    def project_constructor(loader, node):
        # !Project tag: load an msmbuilder ProjectInfo file, chdir-ing into
        # its directory so relative paths inside it resolve correctly.
        path = loader.construct_scalar(node)
        project_rootdir = os.path.dirname(path)
        if project_rootdir.startswith('~') or project_rootdir.startswith('$HOME'):
            # expandvars handles '$HOME', expanduser handles '~'
            project_rootdir = os.path.expanduser(os.path.expandvars(project_rootdir))
            path = os.path.expanduser(os.path.expandvars(path))
        os.chdir(project_rootdir)
        projectinfo = Project.LoadFromHDF(path)
        return project_rootdir, projectinfo

    def array_constructor(loader, node):
        # !array tag: a "path, dtype" scalar loaded via np.loadtxt
        value = loader.construct_scalar(node)
        path, dtype = map(unicode.strip, value.split(','))
        return np.loadtxt(path, dtype=dtype)

    def metric_loader(loader, node):
        # !metric tag: look the named metric class up in euclid.metrics
        value = loader.construct_scalar(node)
        return getattr(metrics, value)

    yaml.add_constructor('!Project', project_constructor)
    yaml.add_constructor('!array', array_constructor)
    yaml.add_constructor('!metric', metric_loader)

    documents = list(yaml.load_all(input_file))
    if len(documents) != 2:
        raise RuntimeError('The input file must contain exactly 2 yaml documents')

    project_rootdir, projectinfo = documents[0]['projectinfo']
    projectinfo['TrajFilePath'] = os.path.expanduser(projectinfo['TrajFilePath'])
    if not projectinfo['TrajFilePath'].startswith('/'):
        projectinfo['TrajFilePath'] = os.path.join(project_rootdir,
                                                   projectinfo['TrajFilePath'])
    tau = documents[0]['tau']

    entries = documents[1]
    for entry in entries:
        if 'type' not in entry:
            raise KeyError('Each entry must have a "type"')
        if 'output_fn' not in entry:
            raise KeyError('Each entry must have an "output_fn"')
        if 'init_kwargs' not in entry:
            entry['init_kwargs'] = {}
        # Instantiate the metric class with its keyword arguments.
        entry['type'] = entry['type'](**entry['init_kwargs'])
        entry['output_path'] = os.path.join(output_dir, entry['output_fn'])

    pp.pprint(entries)
    return entries, projectinfo, tau


def run_entries(entries, projectinfo, tau):
    """Run each of the calculations."""
    n = len(entries)
    mymap(_run_entry, entries, [projectinfo] * n, [tau] * n)


def _run_entry(entry, projectinfo, tau):
    """Run a single entry, in parallel over each of the trajectories."""
    n = projectinfo['NumTrajs']
    trajectories = [projectinfo.LoadTraj(i) for i in range(n)]
    print 'Dispatching entry'
    pp.pprint(entry)
    out = mymap(drift, trajectories, [tau] * n, [entry['type']] * n)

    # Concatenate the per-trajectory drift timeseries into one array with a
    # row per lag time. Rows are sized for tau = 1 (the longest case); for
    # larger lag times the unused trailing columns keep the sentinel value -1.
    rows = np.sum(projectinfo['TrajLengths']) - n
    data = -1 * np.ones((len(tau), rows))
    for i in xrange(len(tau)):
        h_pointer = 0
        for j in xrange(n):
            to = projectinfo['TrajLengths'][j] - tau[i]
            data[i, h_pointer:h_pointer + to] = out[j][i, 0:to]
            h_pointer += to

    # Save to disk.
    print 'Saving %s' % entry['output_fn']
    entry['Data'] = data
    entry['Taus'] = tau
    entry['type'] = str(entry['type'])
    entry['init_kwargs'] = str(entry['init_kwargs'])
    print entry
    Serializer.Serializer(entry).SaveToHDF(entry['output_path'])


if __name__ == '__main__':
    if PARALLEL_ENV:
        # dtm takes over execution and runs main() on the root worker.
        dtm.start(main)
    else:
        start = time.time()
        main()
        end = time.time()
        print 'Total time: %s s' % (end - start)
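
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of an input file this script could parse,
# inferred from the constructors registered in parse_yaml. The metric name
# ('RMSD'), the paths, and the filenames below are illustrative assumptions,
# not values required by this script:
#
#   ---
#   projectinfo: !Project /path/to/ProjectInfo.h5
#   tau: !array taus.dat, int
#   ---
#   - type: !metric RMSD
#     init_kwargs: {}
#     output_fn: rmsd_drift.h5
#
# The first document supplies the project and the array of lag times; the
# second is the list of metric entries, each saved to its own output file.
# ---------------------------------------------------------------------------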