#!/usr/bin/env python
import argparse
from glob import glob
from msmbuilder.Serializer import Serializer
import os
import sys
from pprint import PrettyPrinter
import numpy as np
from euclid.leastsq import nnls, ols, ols_R, nnls_allcomb
from euclid.hist3d import hist3d
import matplotlib.pyplot as plt

pp = PrettyPrinter(indent=2)
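
# Example invocation (sketch; the script filename here is hypothetical):
#   python regress_drift.py --directory Observables --normalize --subsample 10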


def main():
    parser = argparse.ArgumentParser(description='''Run non-negative least
squares regression''')
    parser.add_argument('-d', '--directory', default='Observables',
                      help='Directory containing files with the drift under different metrics. Each file should end in the suffix ".h5"')
    parser.add_argument('-n', '--normalize', action='store_true', help='Normalize each metric by its standard deviation to put them on comparable unit scales')
    parser.add_argument('-v', '--visualize', help='Visualize the results (currently unused)')
    parser.add_argument('-s', '--subsample', default=1, help='Subsample before running the regression. Default is 1 (no subsampling). Supply an integer > 1 to subsample.', type=int)
    args = parser.parse_args()
    if not os.path.isdir(args.directory):
        parser.error('Could not open directory %s' % args.directory)
    
    metric_fns = glob('%s/*.h5' % args.directory)
    num_metrics = len(metric_fns)
    if num_metrics == 0:
        parser.error('No metrics found')
    print 'Found %d metric files: %s' % (num_metrics, str([os.path.split(fn)[1] for fn in metric_fns]))
    
    # make a list of each of the serializer objects
    metrics = [None] * num_metrics
    
    # compute the length of each row in the Data field of each
    # of the serializer objects
    for i, metric_fn in enumerate(metric_fns):
        s = Serializer.LoadFromHDF(metric_fn)
        n, m = s['Data'].shape
        s['row_length'] = np.zeros(n, dtype='int')
        assert len(s['Taus']) == n, 'The shapes of Taus and Data do not match'
        for j, tau in enumerate(s['Taus']):
            # find the true length of the jth row in Data, since it might be right-padded with -1s
            # (ideally the row lengths would be saved in the h5 file instead)
            minus_ones = np.nonzero(s['Data'][j] == -1)[0]
            if len(minus_ones) == 0:
                row_length = s['Data'].shape[1]
            else:
                row_length = minus_ones[0]
            s['row_length'][j] = row_length
            #print metric_fn, tau, row_length
        metrics[i] = s
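
        # Sanity check (sketch): a vectorized equivalent of the row-length loop
        # above. The first -1 in each padded row is located with argmax over a
        # boolean mask; rows containing no -1 keep the full width m.
        padded = (s['Data'] == -1)
        vec_lengths = np.where(padded.any(axis=1), padded.argmax(axis=1), m)
        assert np.all(vec_lengths == s['row_length']), 'Vectorized row lengths disagree with the loop'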
    
    # make sure that all the metrics represent the same data
    for metric in metrics:
        assert metric['Data'].shape == metrics[0]['Data'].shape, 'Shapes need to be the same'
        assert np.all(metric['Taus'] == metrics[0]['Taus']), 'Taus need to be the same'
        assert np.all(metric['row_length'] == metrics[0]['row_length']), 'Row lengths need to be the same'
    
    # total number of data points
    num_pts = np.sum(metrics[0]['row_length'])
    
    # reshape all of this data into two arrays: output and predictors

    # This is tricky because each data point has num_metrics dimensions, and its
    # value under each metric lives in a different file, so be careful not to
    # mangle the ordering and combine values that belong to different points.
    # BE CAREFUL
    output = np.empty(num_pts)
    v_position = 0
    for i, tau in enumerate(metrics[0]['Taus']):
        row_length = metrics[0]['row_length'][i]
        output[v_position:v_position + row_length] = tau
        v_position += row_length
        
    # we want to predict the square root of the time separation
    output = np.sqrt(output)
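
    # Sanity check (sketch): the flattened output vector should now be completely
    # filled, with one sqrt(tau) entry per valid data point.
    assert v_position == num_pts, 'Output vector was not filled completely'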
    
    # BE REALLY CAREFUL WITH THIS CODE
    predictors = np.empty((num_pts, num_metrics))
    for j, metric in enumerate(metrics):
        v_position = 0  # pointer that moves down through the rows of the flattened arrays
        for i, tau in enumerate(metric['Taus']):
            row_length = metric['row_length'][i]
            predictors[v_position:v_position + row_length, j] = metric['Data'][i, 0:row_length]
            v_position += row_length #increment pointer
        #print 'vp', v_position
    
    # uncomment these to test: if predictors and output have been constructed
    # correctly, predictors[-1, :] should match the last valid Data entry of each
    # metric, and output[-1] should be the square root of the last tau
    #print output[-1]
    #print predictors[-1, :]
    #last_tau = len(metrics[0]['Taus']) - 1
    #for i in range(num_metrics):
    #    print metrics[i]['Data'][last_tau, metrics[i]['row_length'][last_tau] - 1]
    #    print '   ', metrics[i]['Taus'][last_tau]
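
    # Automated version of the check above (sketch): verify that the last row of
    # predictors matches the last valid Data entry of each metric, and that the
    # last output entry is the square root of the last tau.
    last_tau_idx = len(metrics[0]['Taus']) - 1
    last_len = metrics[0]['row_length'][last_tau_idx]
    for i in range(num_metrics):
        assert np.allclose(predictors[-1, i], metrics[i]['Data'][last_tau_idx, last_len - 1]), 'Predictor columns were mangled'
    assert np.allclose(output[-1], np.sqrt(metrics[0]['Taus'][last_tau_idx])), 'Output vector was mangled'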
    
    # Scale down each column of the predictors matrix by its standard deviation
    if args.normalize:
        print 'Rescaling each metric by its standard deviation'
        std = np.std(predictors, axis=0)
        for col in range(num_metrics):
            predictors[:, col] /= std[col]
        means = np.mean(predictors, axis=0)  # note: these means are not used below
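
        # Sanity check (sketch): after rescaling, every column should have unit
        # standard deviation (up to floating point error).
        assert np.allclose(np.std(predictors, axis=0), 1.0), 'Normalization failed'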
    else:
        print 'No normalization applied'
    
    if args.subsample != 1:
        print 'Subsampling at freq: %s' % args.subsample
        ind = np.random.permutation(num_pts)[::args.subsample]
        ind.sort()
        num_pts = len(ind)
        predictors = np.copy(predictors[ind])
        output = np.copy(output[ind])
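
        # Note (sketch): slicing a random permutation with step args.subsample
        # keeps roughly num_pts / args.subsample points drawn without
        # replacement; sorting the indices keeps the retained points in their
        # original order.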
        
    
    print 'Regressing:'
    print 'Number of points: %d' % num_pts
    print 'Number of attributes: %d' % num_metrics
    
    for dtype in ['float32', 'float64']:
        print 'using type:', dtype
        typed_predictors = np.array(predictors, dtype=dtype)
        typed_output = np.array(output, dtype=dtype)
        print 'non-negative least-squares:'
        nnls_result = nnls_allcomb(typed_predictors, typed_output)
        pp.pprint(nnls_result)
        print 'ordinary least-squares:'
        ols_result = ols(typed_predictors, typed_output)
        pp.pprint(ols_result)
        print ''
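
    # Rough cross-check (sketch): scipy.optimize.nnls solves the same constrained
    # problem, min ||Ax - b||_2 subject to x >= 0, over the full set of metrics,
    # so its coefficients should be comparable to the full-model nnls result above.
    # This assumes scipy is installed; it is not required by the rest of the script.
    try:
        from scipy.optimize import nnls as scipy_nnls
        coefs, rnorm = scipy_nnls(np.asarray(predictors, dtype='float64'),
                                  np.asarray(output, dtype='float64'))
        print 'scipy.optimize.nnls cross-check: coefficients=%s, residual norm=%f' % (str(coefs), rnorm)
    except ImportError:
        print 'scipy not available; skipping nnls cross-check'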
    
if __name__ == '__main__':
    main()