import cPickle
import numpy
import sys
from mpmath import log, gamma, mpf, pi # arbitrary float precision!
from scipy import array, zeros

BASE = 10            # logarithm base used for every log-odds value in this module
twoOverPi = 2.0 / pi # 2/pi; pi is the mpmath constant, so this is an mpmath number
mpf1 = mpf( 1.0 )    # arbitrary-precision one, used to seed weight products

def load_testdata( filename = "trajectory.dat" ):
	"""Load and return the single object pickled in `filename`.

	The file is opened in binary mode, because pickles are byte streams and
	text mode can corrupt them on platforms that translate line endings. The
	file is closed even if unpickling fails.

	NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
	"""
	with open( filename, "rb" ) as FILE:
		return cPickle.Unpickler( FILE ).load()

def PoissonPointCalc( data, point ):
	"""Return the base-BASE log odds that `data` changes Poisson rate at `point`.

	The data are split into data[0:point] and data[point:]. The return value is
	log(numerator) - log(denominator), where
	  numerator   = 2 * Gamma(C1+1) * Gamma(C2+1) * N**C
	  denominator = pi * Gamma(C) * N1**(C1+1) * N2**(C2+1) * (lam1**2 + lam2**2)
	with C/N the total counts/length, Ci/Ni those of each half and lami = Ci/Ni.

	Sentinel returns: -999.999 if the data have fewer than 2 points, and
	-9999.9999 if either side of the split is empty.
	"""
	C = float( data.sum() )
	N = mpf( len( data ) )
	if N < 2:
		return -999.999

	left, right = data[0:point], data[point:]
	C1, N1 = float( left.sum() ), mpf( len( left ) )
	C2, N2 = float( right.sum() ), mpf( len( right ) )
	if N1 < 1 or N2 < 1:
		return -9999.9999

	rate_term = ( C1 / N1 )**2 + ( C2 / N2 )**2
	denominator = pi * gamma( C ) * N1**( C1 + 1 ) * N2**( C2 + 1 ) * rate_term
	numerator = 2.0 * gamma( C1 + 1 ) * gamma( C2 + 1 ) * N**C
	return log( numerator, BASE ) - log( denominator, BASE )

def findPoissonChangePoint( data, log_odds_cutoff=log(10, BASE) ):
	# data is a list of counts in each time period, uniformly spaced
	C = float(data.sum())
	N = mpf(len(data))

	# can't split data less than 2 points!
	if N < 2: return None
	
	# the denominator (including both P(D|H1) and constant parts of P(D|H2) )
	denominator = gamma(C) * pi / ( 2 * N**C )

	# the numerator (trickier)
	# this needs to be averaged over the possible change points 
	weights = zeros(N,dtype=object)
	CA = 0
	CB = C
	for i in range(1,N) :
		# points up through i are in data set A; the rest are in B
		datapoint = data[i-1]	
		NA = mpf(i)   ; CA += datapoint
		NB = mpf(N-i) ; CB -= datapoint
	
		try:	fraction_num = gamma(CA+1.0) * gamma(CB+1.0)
		except ValueError: 
			print CA, CB
			raise
		fraction_den = NA**(CA+1) * NB**(CB+1) * ( (CA/NA)**2 + (CB/NB)**2 )
		weights[i-1] = mpf(fraction_num)/fraction_den

	numerator = weights.mean()
	#except ZeroDivisionError: return None

	if numerator==0:
		print "weights:", weights
		print "data:", data
		print "n points:", len(data)
		print "numerator = 0 (impossible result)"
		sys.exit()

	lognum= log( numerator, BASE )
	logden= log( denominator, BASE )
	logodds = lognum - logden
	print "num:",numerator, "log num:", lognum, "| denom:", denominator, "log denom:", logden, "|| log odds:", logodds 

	# if log odds exceeds the log odds cutoff, there is a cp; if not, there isn't
	if logodds < log_odds_cutoff : return None
	maxwtarg = weights.argmax()
	# the first point can't be the change point (same as "no change")
	if maxwtarg==0: return None
	return ( maxwtarg, logodds ) 

def findGaussianChangePoint( data, gammatable ):
	print "Warning! Old version of findGaussianChangePoint"
	N = len( data )
	if N<6 : return None # can't find a cp in data this small

	# the denominator. This is the easy part.
	denom = (pi**1.5) * mpf(( N*data.var() ))**( -N/2.0 + 0.5 ) * gammatable[N]

	# BEGIN weight calculation
	# the numerator. A little trickier.
	weights=[0,0,0] # the change cannot have occurred in the last 3 points
	data2=data**2

	#initialize
	dataA=data[0:3] ; dataA2=data2[0:3] ; NA = len(dataA)
	dataB=data[3:] ; dataB2=data2[3:] ;  NB = len(dataB)
	sumA=dataA.sum() ; sumsqA=dataA2.sum()
	sumB=dataB.sum()  ; sumsqB=dataB2.sum()

	# first data point--this could be done in the loop but it's okay here
	meanA=sumA/NA ; meansumsqA = sumsqA/NA ; meanA2 = meanA**2 ; sA2=meansumsqA-meanA2
	meanB=sumB/NB ; meansumsqB = sumsqB/NB ; meanB2 = meanB**2 ; sB2=meansumsqB-meanB2

	wnumf1 = mpf(NA)**(-0.5*NA + 0.5 ) * mpf(sA2)**(-0.5*NA + 1) * gammatable[NA]
	wnumf2 = mpf(NB)**(-0.5*NB + 0.5 ) * mpf(sB2)**(-0.5*NB + 1) * gammatable[NB]
	wdenom = (sA2 + sB2) * (meanA2*meanB2)
	weights.append( (wnumf1*wnumf2)/wdenom ) 

	for i in range( 3, N-3 ):
		NA += 1	; NB -= 1
		next = data[i]
		sumA += next	; sumB -= next
		nextsq = data2[i]
		sumsqA += nextsq; sumsqB -= nextsq
		meanA=sumA/NA ; meansumsqA = sumsqA/NA ; meanA2 = meanA**2 ; sA2=meansumsqA-meanA2
		meanB=sumB/NB ; meansumsqB = sumsqB/NB ; meanB2 = meanB**2 ; sB2=meansumsqB-meanB2
		wnumf1 = mpf(NA)**(-0.5*NA + 0.5 ) * mpf(sA2)**(-0.5*NA + 1) * gammatable[NA]
		wnumf2 = mpf(NB)**(-0.5*NB + 0.5 ) * mpf(sB2)**(-0.5*NB + 1) * gammatable[NB]
		wdenom = (sA2 + sB2) * (meanA2*meanB2)
		weights.append( (wnumf1*wnumf2)/wdenom) 
	weights.extend( [0,0] ) # the change cannot have occurred at the last 2 points
	weights=array(weights)
	# END weight calculation

	num = 2.0**2.5 * abs(data.mean()) * weights.mean()
	lognum=log(num, BASE)
	logden=log(den, BASE)
	logodds = lognum - logden
	print "num:", num, "log num:", lognum, "| denom:", denom, "log denom:", logden, "|| log odds:", logodds 
	
	# If there is a change point, then logodds will be greater than 0
	if logodds < 0 : return None
	return ( weights.argmax(), logodds ) 

class ChangePointDetector:
	"""Recursively bisect a data series at detected change points.

	`function(segment, log_odds_cutoff)` is called on slices of `data`. It
	must return None (no change point), a bare index relative to the slice, or
	an (index, logodds) pair such as findPoissonChangePoint returns. After a
	change point at absolute index cp, the ranges [start, cp) and [cp+1, end)
	are searched in turn.
	"""
	def __init__( self, data, function, log_odds_cutoff=log(10, BASE) ):
		self.data = data
		self.datalen = len( self.data )
		self.function = function
		self.log_odds_cutoff = log_odds_cutoff
		self.changepoints = []  # absolute indices, in discovery order
		self.logodds = {}       # absolute index -> log odds (when reported)
		self.niter = 0
		self.maxiter = 1000000 # just in case

	def nchangepoints( self ):
		"""Return the number of change points found so far."""
		return len( self.changepoints )

	def split_init( self, verbose=False ):
		"""Start the recursive search over the whole data series."""
		self.split( 0, self.datalen, verbose )

	def split( self, start, end, verbose=False ):
		"""Search data[start:end] for a change point and recurse on both sides."""
		if self.niter > self.maxiter :
			print "Change point detection error: number of iterations exceeded"
			print "If this is the right result, you may need to increase"
			print "ChangePointDetector.maxiter (currently %d)" % self.maxiter
			return
		self.niter += 1
		if verbose:
			print "\nIteration %d" % self.niter
			print "Trying to split the segment:", self.data[start:end], "(data from %d to %d)" % ( start, end)
			print self.data[start:end]

		# try to find a change point in the data segment 
		try:
			result = self.function( self.data[ start: end ], self.log_odds_cutoff )
		except TypeError: 
			print "trying to test data from %d to %d failed" % ( start,end )
			raise

		if result is None:
			return

		# Store the change point and recurse on the two ends. A bare index
		# is not subscriptable: plain ints raise TypeError, numpy integer
		# scalars raise IndexError.
		try:
			index, logodds = result[0], result[1]
		except ( TypeError, IndexError ):
			result += start
		else:
			result = start + index
			self.logodds[ result ] = logodds

		if verbose: print "!! change point detected at %d !!" % result
		self.changepoints.append( result )
		self.split( start, result, verbose )
		self.split( result+1, end, verbose )

	def sort( self ):
		"""Sort the change points in place (they are stored in discovery order)."""
		self.changepoints.sort()

	# display the change points
	def show( self ): print self.changepoints

	# show the change points along with the log odds
	def showall( self ):
		for changepoint in self.changepoints:
			logodds = self.logodds.get( changepoint )
			if logodds is None:
				# "%f" cannot format None; this used to raise TypeError
				print "%d (None)" % changepoint
			else:
				print "%d (%f)" % ( changepoint, logodds )

	def largest_logodds( self ):
		"""Return the largest recorded log odds (raises ValueError if none were recorded)."""
		# list(): a dict view (Python 3) would become a 0-d object array
		return array( list( self.logodds.values() ) ).max()


# BEGIN: lumping segments into states with similar emission levels

def findPoissonState( data, log_odds_cutoff=log(10, BASE) ):
	# data is a list of trajectory segments, in order of the maximum
	# likelihood estimate of the Poisson counts.

	# can't split data less than 2 points!
	nSegments  = len( data )
	if nSegments < 2: return None

	#print "\ndata:"
	#for segment in data : print segment 
	#print 

	oneStateSegments = Segment()
	for segment in data: oneStateSegments += segment
	N = mpf( oneStateSegments.N*1.0 )
	C = mpf( oneStateSegments.C*1.0 )

	# begin building the Bayes factor
	denominator = gamma(C)/mpf(N*1.0)**C
	
	# the numerator (trickier)
	# this needs to be averaged over the possible change points 
	weights = zeros(nSegments-1, dtype=object)
	secondState = oneStateSegments 
	firstState = Segment() # empty

	for i in range( 0,nSegments-1 ): 
		currentSegment = data[i]
		#print "transferring segment", currentSegment, "to first state"
		secondState -= currentSegment
		firstState += currentSegment

		try:	weights[i] = mpf1
		except IndexError :
			print weights
			print i
			raise	

		weights[i] *= firstState.ffactor() * secondState.ffactor()
		weights[i] /= firstState.lamsqr() + secondState.lamsqr()			

		#print "state1", firstState, "ffactor", firstState.ffactor(), "lamsqr", firstState.lamsqr()
		#print "state2", secondState, "ffactor", secondState.ffactor(), "lamsqr", secondState.lamsqr()
		#print weights[i]
		#print

	numerator = twoOverPi * weights.mean()

	if not numerator > 0:
		print "weights:", weights
		print "data:", data
		print "n points:", len(data)
		print "numerator is not > 0"
		return None

	lognum= log( numerator, BASE )
	logden= log( denominator, BASE )
	logodds = lognum - logden
	print "num:",numerator, "log num:", lognum, "| denom:", denominator, "log denom:", logden, "|| log odds:", logodds 

	if logodds < log_odds_cutoff : return None
	maxwtarg = weights.argmax() + 1
	if maxwtarg==0: 
		print "got maxwtarg 0"
		return None

	#print "corresponds to lambda index", maxwtarg
	#print data[maxwtarg]
	#sys.exit()

	return ( maxwtarg, logodds ) 

class StateDetector:
	def __init__( self, data, function, log_odds_cutoff=log(10, BASE) ):
		self.data = data
		self.datalen = len( self.data )
		self.function = function
		self.log_odds_cutoff=log_odds_cutoff

		self.statePartitions = []
		self.logodds = {}
		self.niter = 0
		self.maxiter = 1000000 # just in case

	def npartitions( self ):
		return len( self.statePartitions )

	def split_init( self, verbose=False ):
		self.split( 0, self.datalen, verbose )

	def split( self, start, end, verbose=False ):
		if self.niter > self.maxiter :
			print "number of iterations exceeded (%d)" % self.maxiter
			return None
		self.niter += 1
		if verbose:
			print "\nIteration %d" % self.niter
			print "Trying to detect a partition:", self.data[start:end], "(data from %d to %d)" % ( start, end)

		try:
			result = self.function( self.data[ start: end ], self.log_odds_cutoff )
		except TypeError: 
			print "trying to test data from %d to %d failed" % ( start,end )
			#print self.data
			raise

		if result is not None :
			"""
			try: # fails if only one value is returned
				logodds = result[1]
				self.logodds[ start+result[0] ] = logodds
				result = start+result[0]
			except TypeError: # must mean it's one number?
				result += start
			"""
			resultIndex, logodds=result
			resultIndex += start
			result = self.data[resultIndex].lamsqr()**.5
			self.logodds[result]=logodds
			print "!! change of state detected at lambda = ", result, "!!"
			self.statePartitions.append( result )
			self.split( start, resultIndex, verbose )
			self.split( resultIndex+1, end, verbose )

	def sort( self ): self.statePartitions.sort()

	# display the change points
	def show( self ): print self.statePartitions

	# show the change points along with the log odds
	def showall( self ):
		s = "{"
		for i in range( len( self.statePartitions ) ) :
			partition = self.statePartitions[i]
			try: 
				logodds = self.logodds[ partition ]
			except KeyError:
				logodds = None
			print "partition at", partition, "log odds", logodds
			s += "%f," % partition
		print s[:-1] + "}"

	def largest_logodds( self ):
		return array( self.logodds.values() ).max()

def loadSegments( filename, nToReturn = None ):
	"""Build a Trajectory of Segments from a pickled change-point detector.

	filename  -- pickle of an object with `.changepoints` (indices into
	             `.data`) and `.data` (an array of counts supporting .sum()).
	nToReturn -- stop after this many change-point-delimited segments; a
	             falsy value means "all of them".

	The segments cover the data without overlap: [0, cp1), [cp1, cp2), ...
	The remainder after the last change point used forms one final segment.
	Assumes `.changepoints` is sorted ascending (e.g. ChangePointDetector.sort()
	was called before pickling) -- TODO confirm with the producer.

	NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
	"""
	from cPickle import Unpickler
	with open( filename, "rb" ) as pickleFile:
		cpd = Unpickler( pickleFile ).load()

	if not nToReturn: nToReturn = len(cpd.changepoints)+10 
	segments=[]
	lastcp=0
	idx = 0
	for cp in cpd.changepoints :
		segment = cpd.data[ lastcp : cp ]
		segments.append( Segment( len(segment), segment.sum(), idx ) )
		idx += 1
		# Advance before the early exit. Otherwise the final segment below
		# would start at the old lastcp and re-cover data[lastcp:cp].
		lastcp = cp
		if idx >= nToReturn : break

	# get the last segment
	segment = cpd.data[ lastcp : ]
	segments.append( Segment( len(segment), segment.sum(), idx ) )
	return Trajectory( segments )

class Segment :
	"""A run of N time bins holding C counts in total.

	Attributes:
	  N     -- number of bins
	  C     -- total counts
	  idx   -- position of the segment in its trajectory, or -1 if unknown
	  lam   -- maximum-likelihood Poisson rate C/N, or None when N == 0
	  state -- initialised to None; not otherwise used in this class
	"""
	def __init__( self, N=0, C=0, idx=-1 ):
		self.N = N
		self.C = C
		self.idx = idx
		if N>0: self.lam = float(C)/float(N)
		else: self.lam = None
		self.state = None

	def __add__( self, other ):
		"""Return the merged segment (lengths and counts summed; idx is not carried over)."""
		N = self.N + other.N
		C = self.C + other.C
		return Segment( N, C )

	def __sub__( self, other ):
		"""Return self with `other` removed; raises NegativeSegmentLength if N would go negative."""
		N = self.N - other.N
		C = self.C - other.C
		if N < 0 : raise NegativeSegmentLength
		return Segment( N, C )

	def __str__( self ):
		# Compute the rate locally. Printing must not rewrite self.lam, and
		# `text` avoids shadowing the builtin str.
		if self.N > 0 :
			lam = float(self.C)/float(self.N)
			text = "N = %d, C = %d, lambda = %3.3f" % ( self.N, self.C, lam )
		else : text = "N = 0, C = %d" % self.C
		if self.idx>=0: text = ( "%d: " % self.idx ) + text
		return text
	
	def __repr__( self ): return "trajectory segment %s" % self

	def lamsqr( self ):
		"""Return lam**2 (raises TypeError for an empty segment, whose lam is None)."""
		return self.lam**2

	def ffactor( self ):
		"""Return Gamma(C+1) * N**(-C-1), computed with mpmath."""
		return gamma(self.C+1.0) * mpf(self.N)**(-self.C-1)

class Trajectory( list ) :
	"""A list of Segment objects that can be reordered by rate or by time.

	Uses list's own constructor, e.g. Trajectory( segments ).
	"""

	def lambdasort( self ):
		"""Sort in place by ascending rate (seg.lam); the sort is stable."""
		# key= instead of a positional cmp function: same order, and works on
		# Python 3, where list.sort() accepts no positional arguments
		self.sort( key=lambda seg: seg.lam )

	def compareSegments( self, seg1, seg2 ):
		"""cmp-style comparison of two segments by rate (kept for external callers)."""
		if seg1.lam > seg2.lam : return 1
		elif seg1.lam == seg2.lam: return 0
		else: return -1

	def timesort( self ):
		"""Sort in place by ascending trajectory position (seg.idx); the sort is stable."""
		self.sort( key=lambda seg: seg.idx )

	def compareIndices( self, seg1, seg2 ):
		"""cmp-style comparison of two segments by idx (kept for external callers)."""
		if seg1.idx > seg2.idx: return 1
		elif seg1.idx == seg2.idx: return 0
		else: return -1

class NegativeSegmentLength( Exception ):
	"""Raised by Segment.__sub__ when the result would have fewer than zero bins."""
