#!/usr/bin/env python -u
import os
import sys
import time
import argparse
from pprint import PrettyPrinter

import yaml
import numpy as np
from msmbuilder.Project import Project, Serializer
from euclid import metrics
from euclid.drift import drift

CURDIR = os.path.abspath(os.curdir)

# Detect whether we're running under MPI with the dtm task manager. If either
# import fails, or there is only a single MPI process, fall back to serial
# execution.
try:
    import mpi4py.MPI
    from deap import dtm
    PARALLEL_ENV = mpi4py.MPI.COMM_WORLD.size > 1
    get_rank = lambda: dtm.getWorkerId() if 'workerId' in dtm._dtmObj.__dict__ else 0
except ImportError:
    PARALLEL_ENV = False
    get_rank = lambda: 0


def verbose_wrap(f):
    """Wrap f so that each serial invocation prints a '# ' progress marker."""
    def g(*args):
        sys.stdout.write('# ')
        sys.stdout.flush()
        return f(*args)
    return g


def mymap(f, *args):
    """Map f over args with dtm when running under MPI, serially otherwise."""
    if PARALLEL_ENV:
        return dtm.map(f, *args)
    return map(verbose_wrap(f), *args)


pp = PrettyPrinter(indent=2)


def main():
    parser = argparse.ArgumentParser(description="""
    Compute the drift in a number of structural distance metrics (for a
    variety of lag times) across an msmbuilder project.

    This code is parallelized with mpi4py and the dtm task manager. It should
    scale well to a large number of cores and/or nodes.
    """)
    parser.add_argument('input_file', type=file, help='Input file in yaml format')
    parser.add_argument('output_dir', type=str, help='Directory to save output in')
    parser.add_argument('-n', '--dry-run', action='store_true',
                        help="parse the input file but don't actually perform the calculation")
    try:
        args = parser.parse_args()
        # Make the output directory an absolute path, expanding a leading '~'.
        if args.output_dir.startswith('~'):
            args.output_dir = os.path.expanduser(args.output_dir)
        else:
            args.output_dir = os.path.join(CURDIR, args.output_dir)
    except IOError as e:
        parser.error(e)

    try:
        entries, projectinfo, tau = parse_yaml(args.input_file, args.output_dir)
    except yaml.error.YAMLError as e:
        parser.error('Input file %s was not properly formatted yaml.\n%s'
                     % (args.input_file.name, e))

    # Check for filename conflicts in the output directory, and skip any
    # entry whose output file already exists.
    skip_indices = []
    for i, entry in enumerate(entries):
        if os.path.exists(entry['output_path']):
            print 'A metric file by the name "%s" already exists.' % entry['output_path']
            print 'Skipping computation.'
            skip_indices.append(i)
    entries = [e for i, e in enumerate(entries) if i not in skip_indices]
    if len(skip_indices) > 0:
        print '\n%d metric file(s) already exist and will not be recomputed. Remaining computations are:' % len(skip_indices)
        pp.pprint(entries)

    # Break out early for a dry run.
    if args.dry_run:
        return

    if not os.path.exists(args.output_dir):
        print 'Making directory %s' % args.output_dir
        os.mkdir(args.output_dir)

    run_entries(entries, projectinfo, tau)


def parse_yaml(input_file, output_dir):
    """Parse the two-document yaml input file.

    The first document holds the project info and the lag times; the second
    holds the list of metric entries to compute.
    """
    def project_constructor(loader, node):
        # !Project tag: load an msmbuilder ProjectInfo file, chdir-ing into
        # its directory so relative paths inside it resolve correctly.
        path = loader.construct_scalar(node)
        project_rootdir = os.path.dirname(path)
        if project_rootdir.startswith('~') or project_rootdir.startswith('$HOME'):
            # expandvars handles '$HOME', expanduser handles '~'
            project_rootdir = os.path.expanduser(os.path.expandvars(project_rootdir))
            path = os.path.expanduser(os.path.expandvars(path))
        os.chdir(project_rootdir)
        projectinfo = Project.LoadFromHDF(path)
        return project_rootdir, projectinfo

    def array_constructor(loader, node):
        # !array tag: a "path, dtype" scalar loaded via np.loadtxt
        value = loader.construct_scalar(node)
        path, dtype = map(unicode.strip, value.split(','))
        return np.loadtxt(path, dtype=dtype)

    def metric_loader(loader, node):
        # !metric tag: look the named metric class up in euclid.metrics
        value = loader.construct_scalar(node)
        return getattr(metrics, value)

    yaml.add_constructor('!Project', project_constructor)
    yaml.add_constructor('!array', array_constructor)
    yaml.add_constructor('!metric', metric_loader)

    documents = list(yaml.load_all(input_file))
    if len(documents) != 2:
        raise RuntimeError('The input file must contain exactly 2 yaml documents')

    project_rootdir, projectinfo = documents[0]['projectinfo']
    projectinfo['TrajFilePath'] = os.path.expanduser(projectinfo['TrajFilePath'])
    if not projectinfo['TrajFilePath'].startswith('/'):
        projectinfo['TrajFilePath'] = os.path.join(project_rootdir,
                                                   projectinfo['TrajFilePath'])
    tau = documents[0]['tau']

    entries = documents[1]
    for entry in entries:
        if 'type' not in entry:
            raise KeyError('Each entry must have a "type"')
        if 'output_fn' not in entry:
            raise KeyError('Each entry must have an "output_fn"')
        if 'init_kwargs' not in entry:
            entry['init_kwargs'] = {}
        # Instantiate the metric class with its keyword arguments.
        entry['type'] = entry['type'](**entry['init_kwargs'])
        entry['output_path'] = os.path.join(output_dir, entry['output_fn'])

    pp.pprint(entries)
    return entries, projectinfo, tau


def run_entries(entries, projectinfo, tau):
    """Run each of the calculations."""
    n = len(entries)
    mymap(_run_entry, entries, [projectinfo] * n, [tau] * n)


def _run_entry(entry, projectinfo, tau):
    """Run a single entry, in parallel over each of the trajectories."""
    n = projectinfo['NumTrajs']
    trajectories = [projectinfo.LoadTraj(i) for i in range(n)]
    print 'Dispatching entry'
    pp.pprint(entry)
    out = mymap(drift, trajectories, [tau] * n, [entry['type']] * n)

    # Concatenate the per-trajectory drift timeseries into one array with a
    # row per lag time. Rows are sized for tau = 1 (the longest case); for
    # larger lag times the unused trailing columns keep the sentinel value -1.
    rows = np.sum(projectinfo['TrajLengths']) - n
    data = -1 * np.ones((len(tau), rows))
    for i in xrange(len(tau)):
        h_pointer = 0
        for j in xrange(n):
            to = projectinfo['TrajLengths'][j] - tau[i]
            data[i, h_pointer:h_pointer + to] = out[j][i, 0:to]
            h_pointer += to

    # Save to disk.
    print 'Saving %s' % entry['output_fn']
    entry['Data'] = data
    entry['Taus'] = tau
    entry['type'] = str(entry['type'])
    entry['init_kwargs'] = str(entry['init_kwargs'])
    print entry
    Serializer.Serializer(entry).SaveToHDF(entry['output_path'])


if __name__ == '__main__':
    if PARALLEL_ENV:
        # dtm takes over execution and runs main() on the root worker.
        dtm.start(main)
    else:
        start = time.time()
        main()
        end = time.time()
        print 'Total time: %s s' % (end - start)
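
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of an input file this script could parse,
# inferred from the constructors registered in parse_yaml. The metric name
# ('RMSD'), the paths, and the filenames below are illustrative assumptions,
# not values required by this script:
#
#   ---
#   projectinfo: !Project /path/to/ProjectInfo.h5
#   tau: !array taus.dat, int
#   ---
#   - type: !metric RMSD
#     init_kwargs: {}
#     output_fn: rmsd_drift.h5
#
# The first document supplies the project and the array of lag times; the
# second is the list of metric entries, each saved to its own output file.
# ---------------------------------------------------------------------------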