options( warn = 1 )
library(KernSmooth)
library(RColorBrewer)
library(fBasics)


# Simulation inputs ------------------------------------------------------------
# One simulation output per growth medium, listed in the same order as
# condLabels: G = glucose, S = succinate, A = acetate, T = tryptone.
fileNames <- file.path('simulationsText',
                       c('glucoseInfectedSimulation.txt',
                         'succinateInfectedSimulation.txt',
                         'acetateInfectedSimulation.txt',
                         'tryptoneInfectedSimulation.txt'))
condLabels <- c('G', 'S', 'A', 'T')

# Per-flux metabolic class annotation (taken from the tryptone run's file).
classFile <- file.path('simulationsText',
                       'metClass_tryptoneInfectedSimulation.txt')

# Metabolic classes of interest.
classPlotList <- c('Metabolism',
                   'Membrane Lipid Metabolism',
                   'Nucleotide Salvage Pathways',
                   'Purine and Pyrimidine Biosynthesis',
                   'Cell envelope biosynthesis',
                   'Amino Acid',
                   'Cofactor and Prosthetic Group Biosynthesis',
                   'viralReactions')
nClassPlot <- length(classPlotList)

# Tolerance below which a flux magnitude or a standard deviation is treated
# as zero.
zeroThresh <- 1e-12
		
			
# Load the different data sets -------------------------------------------------
# For every simulation file keep only the host rows (hostFlag == 1) and only
# the flux columns: concentration columns (names starting with CONC) and the
# simulation bookkeeping columns listed below are dropped.
nSim <- length(fileNames)
hostList <- lapply(fileNames, function(simFile) {
  simTable <- read.table(simFile, sep = '\t', header = TRUE)
  hostRows <- subset(simTable, hostFlag == 1)
  colNamesAll <- colnames(hostRows)
  dropCols <- c(grep('^CONC', colNamesAll, value = TRUE),
                't', 'biomass', 'hostFlag', 'status', 'growth', 'aBIOMASS')
  keepCols <- colNamesAll[!(colNamesAll %in% dropCols)]
  hostRows[keepCols]
})

# Find the non-zero fluxes of each simulation ---------------------------------
# A flux counts as non-zero when its largest absolute value over the time
# course exceeds zeroThresh. nonzeroFluxes[[i]] holds the names for hostList[[i]].
nonzeroFluxes <- lapply(hostList, function(hostTable) {
  # vapply keeps the column names and avoids apply()'s data.frame -> matrix coercion.
  fluxMaxAbs <- vapply(hostTable, function(flux) max(abs(flux)), numeric(1))
  # Take names straight from the mask. The old names(df[, mask]) returned NULL
  # when exactly one column passed, because the data frame collapsed to a
  # vector, so that flux was silently lost.
  names(fluxMaxAbs)[fluxMaxAbs > zeroThresh]
})
# Union of the non-zero flux names over all simulations (first-seen order).
fluxesAnyPairNonzero <- unique(unlist(nonzeroFluxes))

# Metabolic class annotation ---------------------------------------------------
# The class file is read as a character matrix with one column per flux. It
# must contain exactly two rows (metClass, metSubClass); a third row holding
# the flux names (the column names) is appended below.
classTable <- as.matrix(read.table(classFile, sep = '\t', header = TRUE))
if (nrow(classTable) != 2L) {
  stop(classFile, ': expected 2 rows (metClass, metSubClass), found ',
       nrow(classTable), call. = FALSE)
}
classTable <- rbind(classTable, colnames(classTable))
rownames(classTable) <- c('metClass', 'metSubClass', 'fluxName')
# Restrict to fluxes that are non-zero in at least one simulation. drop = FALSE
# keeps a matrix even when a single flux is selected, so that lookups such as
# classTable['metClass', <flux>] keep working.
classTable <- classTable[, fluxesAnyPairNonzero, drop = FALSE]


# 
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 


# Figure device, panel colours and layout --------------------------------------
# Density grid for each media pair; entries are added while building the maps.
mapsAll <- list()

# Scores used when a correlation is undefined because a time course is
# constant: both constant -> dSame, only one constant -> dDiff.
dSame <- 1
dDiff <- 0

postscript(file.path('figureOutput', 'figureSupCorrelations_8o7cm_width.eps'),
           width = 3.25, height = 3.25, horizontal = FALSE, onefile = FALSE,
           paper = 'special', pointsize = 8)

# Panel outline colours; layoutCol[i, j] indexes into colHighlight.
colSame <- 'darkolivegreen3'
colTryp <- 'orangered1'
colMid <- 'cornflowerblue'
colDiag <- 'black'
colHighlight <- c(colSame, colMid, colTryp, colDiag)
layoutCol <- matrix(c(4, 2, 2, 3,
                      2, 4, 1, 3,
                      2, 1, 4, 3,
                      3, 3, 3, 4), nrow = 4, byrow = TRUE)

# 4 x 4 grid of panels, filled row by row.
layoutMat <- matrix(as.numeric(1:16), nrow = 4, byrow = TRUE)
par(ps = 8, oma = c(3.5, 3.5, 0, 0))
layout(layoutMat)



# Similarity density maps for every ordered pair of media ----------------------
# For panel (iRow, jCol), each flux that is non-zero on either medium adds
# points (x, y):
#   x = similarity of that flux's time course on medium iRow vs medium jCol;
#   y = similarity, on medium iRow, between that flux and each other flux of
#       the same metabolic class (one point per classmate).
# A 2-D kernel density of the points is stored in mapsAll under the name
# '<label iRow>vs<label iRow>and<label jCol>'.

# Kernel-density grid settings, hoisted out of the loop (they are loop
# invariant; distScale, distRange and nBin are also read by later code).
bw <- 0.25
nBin <- 80
distScale <- c(-1.5, 1.5)
distRange <- max(distScale) - min(distScale)
ptSpace <- distRange / nBin
# Grid points sit at cell centres, half a cell inside distScale.
ptRange <- c(min(distScale) + ptSpace / 2, max(distScale) - ptSpace / 2)

# Metabolic class of every annotated flux, keyed by flux name.
metClass <- setNames(classTable['metClass', ], colnames(classTable))

# Similarity of one flux's time courses on two media: Pearson correlation when
# both vary, `same` when both are constant, `different` when only one is.
pairSimilarity <- function(media1, media2, thresh = zeroThresh,
                           same = dSame, different = dDiff) {
  const1 <- sd(media1) < thresh
  const2 <- sd(media2) < thresh
  if (const1 && const2) {
    same
  } else if (const1 || const2) {
    different
  } else {
    cor(media1, media2)
  }
}

# Similarity, within one medium, between flux `fluxName` (time course `media1`)
# and every other flux of its metabolic class. Varying classmates get the
# Pearson correlation (or `different` if media1 itself is constant). Constant
# classmates get `same` if media1 is also constant and `different` otherwise,
# which is the same rule pairSimilarity() applies.
classmateSimilarity <- function(hostData, fluxName, media1, classOf = metClass,
                                thresh = zeroThresh, same = dSame,
                                different = dDiff) {
  classmates <- names(classOf)[which(classOf == classOf[[fluxName]])]
  classmates <- setdiff(classmates, fluxName)  # don't double count this flux
  # Select by name. The old logical mask was built over classTable's columns
  # but applied to hostData's columns, so it was recycled onto the wrong
  # fluxes whenever the two column sets differed in length or order.
  classData <- hostData[, classmates, drop = FALSE]
  varies <- vapply(classData, sd, numeric(1)) > thresh
  nVary <- sum(varies)
  nConst <- length(varies) - nVary

  if (sd(media1) < thresh) {
    return(c(rep(different, nVary), rep(same, nConst)))
  }
  corVary <- if (nVary > 0) {
    as.vector(cor(as.matrix(classData[, varies, drop = FALSE]), media1))
  } else {
    numeric(0)
  }
  # Varying flux vs constant classmate -> `different`. The old code used 1
  # here, a leftover from an earlier "1 - cor" distance version in which
  # 1 meant "uncorrelated"; in correlation space 1 means "identical".
  c(corVary, rep(different, nConst))
}

for (iRow in seq_len(nSim)) {
  for (jCol in seq_len(nSim)) {
    fluxesAnyNonzero <- unique(c(nonzeroFluxes[[iRow]], nonzeroFluxes[[jCol]]))

    # Collect per-flux pieces in preallocated lists and flatten once. The old
    # code seeded x and y with 0, which put a spurious point at the origin
    # into every map.
    xParts <- vector('list', length(fluxesAnyNonzero))
    yParts <- vector('list', length(fluxesAnyNonzero))
    for (iFlux in seq_along(fluxesAnyNonzero)) {
      fluxName <- fluxesAnyNonzero[iFlux]
      media1 <- hostList[[iRow]][[fluxName]]
      media2 <- hostList[[jCol]][[fluxName]]
      yThis <- classmateSimilarity(hostList[[iRow]], fluxName, media1)
      xParts[[iFlux]] <- rep(pairSimilarity(media1, media2), length(yThis))
      yParts[[iFlux]] <- yThis
    }
    x <- as.numeric(unlist(xParts))
    y <- as.numeric(unlist(yParts))
    if (length(x) == 0L) {
      stop('no flux/classmate pairs for ', condLabels[iRow], ' and ',
           condLabels[jCol], '; cannot estimate a density', call. = FALSE)
    }

    # Store each map so that the plotting stage can choose its colour scale.
    nameThisMap <- paste0(condLabels[iRow], 'vs', condLabels[iRow], 'and',
                          condLabels[jCol])
    thisMap <- bkde2D(cbind(x, y), bandwidth = c(bw, bw),
                      gridsize = c(nBin, nBin),
                      range.x = list(x = ptRange, y = ptRange))
    mapsAll[[nameThisMap]] <- thisMap$fhat
  }
}

# 
# 
# Grey colour ramp and plotting grid -------------------------------------------
nScale <- 8
colScaleAlt <- brewer.pal(nScale, 'Greys')

# Linear map of density onto colour indices using the extremes over all
# panels: density * m + b sends minAllDensity -> 1 and maxAllDensity -> nScale.
# NOTE(review): the plotting loop derives its own per-panel scale, so this
# global m / b is not what the panels are drawn with.
allDensity <- unlist(mapsAll)
minAllDensity <- min(allDensity)
maxAllDensity <- max(allDensity)
m <- (1 - nScale) / (minAllDensity - maxAllDensity)
b <- 1 - (1 - nScale) * minAllDensity / (minAllDensity - maxAllDensity)

# Centres of the nBin cells spanning distScale (the same points as the density
# grid). Despite the name, yRectBot holds cell centres, not bottom edges.
sRect <- distRange / nBin
yRectBot <- min(distScale) + (seq_len(nBin) - 1) * sRect + sRect / 2

# Point coordinates for drawing one dot per grid cell (fBasics::gridVector;
# presumably all x/y combinations with x varying fastest - confirm).
pts <- gridVector(yRectBot, yRectBot)
	
	
# Draw the nSim x nSim panel grid ----------------------------------------------
# Each panel shows one density map quantised to nScale grey levels on its own
# (per-panel) colour scale, with contours, and is framed in the colour chosen
# by layoutCol. Axis annotations go on the bottom-left panel and axis titles
# on the corner panels.

# Map densities linearly onto integer colour indices 1..nLevels
# (min -> 1, max -> nLevels), keeping the matrix shape. A flat map
# (max == min) would divide by zero and give NaN colours, so it maps to 1.
densityToColorIndex <- function(fhat, nLevels = nScale) {
  densityRange <- range(fhat)
  if (isTRUE(densityRange[2] == densityRange[1])) {
    return(array(1, dim = dim(fhat)))
  }
  round(1 + (nLevels - 1) * (fhat - densityRange[1]) / diff(densityRange))
}

for (iRow in seq_len(nSim)) {
  for (jCol in seq_len(nSim)) {
    nameThisMap <- paste0(condLabels[iRow], 'vs', condLabels[iRow], 'and',
                          condLabels[jCol])
    namePlot <- paste0(condLabels[iRow], ' vs. ', condLabels[iRow], '-',
                       condLabels[jCol])
    thisMap <- densityToColorIndex(mapsAll[[nameThisMap]])

    par(mar = c(0, 0, 1, 1) + 0.1, xpd = NA, cex = 1, ps = 8)
    plot(x = NULL, y = NULL, xlim = distScale, ylim = distScale,
         xlab = '', ylab = '', axes = FALSE)
    points(pts[[1]], pts[[2]], col = colScaleAlt[thisMap], pch = '.', cex = 1)
    contour(x = yRectBot, y = yRectBot, z = thisMap, add = TRUE,
            drawlabels = FALSE, lwd = 0.25, nlevels = 8)
    rect(xleft = min(distScale), ybottom = min(distScale),
         xright = max(distScale), ytop = max(distScale),
         border = colHighlight[layoutCol[iRow, jCol]], col = NA, lwd = 2)
    text(x = 1.5, y = 1.55, labels = namePlot, adj = c(1, 0), cex = 1)

    # Bottom-left panel: tick leaders and labels for both axes.
    if (iRow == 4 && jCol == 1) {
      lines(x = c(-1, -1, -1.15), y = c(-1.5, -2, -2))
      lines(x = c(-1.5, -1.75, -1.75), y = c(-1, -1, -1.65))
      text(x = -1.25, y = -1.75, labels = 'Inversely\nCorrelated',
           adj = c(1, 1), cex = 1)
      lines(x = c(0, 0), y = c(-1.5, -2.2))
      text(x = 0, y = -2.4, labels = 'Uncorrelated', adj = c(0.25, 1), cex = 1)
      lines(x = c(1, 1), y = c(-1.5, -1.65))
      text(x = 1, y = -1.75, labels = 'Correlated', adj = c(0, 1), cex = 1)

      lines(x = c(-1.5, -2.2), y = c(0, 0))
      text(x = -2.4, y = 0, labels = 'Uncorrelated', adj = c(0.25, 0),
           cex = 1, srt = 90)
      lines(x = c(-1.5, -1.65), y = c(1, 1))
      text(x = -1.75, y = 1, labels = 'Correlated', adj = c(0, 0),
           cex = 1, srt = 90)
    }

    # Bottom-right panel: x-axis title.
    if (iRow == 4 && jCol == 4) {
      text(x = 1.5, y = -3.5,
           labels = 'Similarity of Flux Timecourse Between Media',
           adj = c(1, 1), cex = 1)
    }

    # Top-left panel: y-axis title.
    if (iRow == 1 && jCol == 1) {
      text(x = -3.6, y = 1.75,
           labels = 'Similarity of Flux Timecourse Within Metabolic Class',
           adj = c(1, 0), cex = 1, srt = 90)
    }
  }
}

dev.off()