#!/usr/bin/python

from ChangePointTools import * 
from cpDetect import *
from mpmath import log
from numpy import fromfile
from time import time
import cPickle
import sys


# 0. Settings: the log-odds cutoff for splitting and lumping, plus file names.
# log( 10, 10 ) is mpmath's two-argument log, i.e. 10 in base 10.
log_odds_cutoff = log(10, 10)

# The input trajectory is the first command-line argument.  Earlier hard-coded
# inputs, kept for reference:
#trajFileName = "data.txt"
#trajFileName = "../simulation/traj.txt"
if len(sys.argv) < 2:
    # Stop with a usage line rather than an IndexError traceback.
    sys.exit("usage: %s TRAJECTORY_FILE" % sys.argv[0])
trajFileName = sys.argv[1]

# Output file names.
# NOTE(review): in the code visible in this file these names are referenced
# only from the steps that are currently disabled (wrapped in a string).
cpTxtFile = "changepoints.txt"
cpFileName = "changepoints.dat"
seglevelsFileName = "segment_level_trajectory.txt"
statesFileName = "states.dat"
statelevelsFileName = "state_partitions.txt"
statesTrajectory = "state_level_trajectory.txt"
stateLevelsFile = "state_levels.txt"

# Load the trajectory: the file is parsed as text, with dtype "int" and "\n"
# given as the item separator.
datapoints = fromfile(trajFileName, dtype="int", sep="\n")
n_loaded = len(datapoints)
print("loaded %d points" % n_loaded)

# NOTE: the triple-quoted string below is disabled scratch code, not
# documentation.  If re-enabled it would replace `datapoints` with three 3s
# followed by three 10s, collect element [2] of PoissonPointCalc() for
# positions 1..5, and print those weights, their mean and log10 of the mean.
"""
datapoints = []
for i in range(3): datapoints.append( 3 )
for i in range(3): datapoints.append( 10 )
datapoints = array(datapoints)
#print datapoints.sum()
#print datapoints.mean()
weights=[]
weights.append(  PoissonPointCalc( datapoints, 1 )[2] )
weights.append(  PoissonPointCalc( datapoints, 2 )[2] )
weights.append(  PoissonPointCalc( datapoints, 3 )[2] )
weights.append(  PoissonPointCalc( datapoints, 4 )[2] )
weights.append(  PoissonPointCalc( datapoints, 5 )[2] )
print weights
weights=array(weights)
print weights.mean()
print log( weights.mean(), 10 )
print
print "enough messing around"
"""

# 1. Find every change point in the trajectory.
# The Gaussian detector is the active one; the Poisson call is kept for
# reference:
#cpd = ChangePointDetector( datapoints,findPoissonChangePoint, log_odds_cutoff )
cpd = ChangePointDetector(
    datapoints,
    findGaussianChangePoint,
    log_odds_cutoff,
)

print("splitting")
cpd.split_init(verbose=False)
cpd.sort()
cpd.showall()
n_changepoints = cpd.nchangepoints()
print("found %d change points" % n_changepoints)

# NOTE(review): everything from here to the closing triple quote is a single
# string literal, so saving the change points and steps 2-4 are currently
# disabled.  Points to check before re-enabling it:
#   * in step 2 the final segment is written with "%d" while the loop over the
#     other segments writes "%f";
#   * in the step-2 loop the `lam` lines are tab-indented while their
#     neighbours use eight spaces;
#   * step 3 passes findPoissonState although the active detector above is
#     findGaussianChangePoint -- confirm that combination is intended;
#   * step 3 reloads segments from cpFileName, which is written only by the
#     disabled code itself.
"""
FILE=open(cpTxtFile, "w" )
for cp in cpd.changepoints: FILE.write( "%d\n" % cp )
FILE.close()
print "saving '%s'" % cpFileName
FILE=open( cpFileName, "w" )
p = cPickle.Pickler( FILE )
p.dump( cpd )
FILE.close()

# 2. Re-write the trajectory according to segment level rather than data level.
print "writing segment levels to '%s'" % seglevelsFileName
FILE=open( seglevelsFileName, "w" )
lastcp=0
for cp in cpd.changepoints :
        segment = cpd.data[ lastcp : cp ]
        N = len(segment)
        C = 1.0*segment.sum()
	lam = C/N
	for i in range(N): FILE.write( "%f\n" % lam )
        lastcp = cp

# get the last segment
segment = cpd.data[ lastcp : ]
N = len(segment)
C = 1.0*segment.sum()
lam = C/N
for i in range(N): FILE.write( "%d\n" % lam )
FILE.close()

# 3.  Lump the detected segments into states
# with similar emission statistics.
# build a trajectory of segments rather than of points
trajectory = loadSegments(cpFileName) # is this necessary?
print "sorting segments by MLE of average emission levels"
trajectory.lambdasort()

print "splitting segments into states of like emission"
stateFinder = StateDetector( trajectory, findPoissonState, log_odds_cutoff ) 

print "splitting"
stateFinder.split_init(verbose=False)
stateFinder.sort()
print "found %d partition values" % stateFinder.npartitions()

print "saving states to '%s'" % statesFileName 
FILE=open( statesFileName, "w" )
p = cPickle.Pickler( FILE )
p.dump( stateFinder )
FILE.close()

if stateFinder.npartitions() == 0: sys.exit()

# 4. Re-write the trajectory in terms of emission levels of states rather than segments.
# partitions look like "[2.42, 3.99, 5.11, 9.02]"; some data is below the smallest 
# partition and above the largest.  

# go through and assign the segments, tallying the N's and C's into the states. This
# then allows calculation of the average emission for each state (Ctot/Ntot)
statetrajectory=[]
pooledsegments=[]
for i in range( len( stateFinder.statePartitions ) +1 ) :
	pooledsegments.append( Segment() )

trajectory.timesort()
for segment in trajectory :
	i=0
	currpartition=stateFinder.statePartitions[i]
	while segment.lam >= stateFinder.statePartitions[i] :
		i = i+1
		try: currpartition=stateFinder.statePartitions[i]
		except IndexError: break 
	try: cutoff = stateFinder.statePartitions[i]
	except IndexError: cutoff = 100000
	print "segment with lam=", segment.lam, "added to state", i+1, "cutoff", cutoff
	statetrajectory.extend( [i] * segment.N )
	pooledsegments[i] += segment

FILE=open( statesTrajectory, "w" )
for state in statetrajectory : 
	FILE.write( "%f\n" % pooledsegments[state].lam )
FILE.close()

FILE2=open( stateLevelsFile, "w" )
print "state levels:"
for state in pooledsegments :
	print state.lam
	FILE2.write( "%f\n" % state.lam )
FILE2.close()

print; print "state partitions:"
for partition in stateFinder.statePartitions :
	print partition
"""
