/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */

package iterativeclustering;


import biocomp.moltools.discretization.TransitionMatrixSampling;
import biocomp.moltools.util.DoubleArrays;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.EigenvalueDecomposition;
import dataWrapper.DWListOfVectors;
import dataWrapper.DWMatrixOfScalars;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;


/**
 * Command-line driver that reads one count matrix per temperature from a text
 * file, draws transition-matrix samples from each with
 * {@link TransitionMatrixSampling}, and writes per-temperature statistics
 * (stationary distribution mean/variance, eigenvalue spectra and the sampled
 * matrices themselves) as TSV files below {@code files/}.
 *
 * @author jan-hendrikprinz
 */
public class TempTransProp
{
    private static final int PRIOR_TYPE_FLAT = 0;
    private static final int PRIOR_TYPE_ONECOUNT = 1;
    private static final int PRIOR_TYPE_ZERO = 2;

    /**
     * Runs the sampling for every temperature and writes the result files.
     *
     * @param args ignored; all settings are hard-coded below
     * @throws FileNotFoundException if the input file cannot be opened
     * @throws IOException if reading the input fails or the file is truncated
     */
    public static void main( String[] args ) throws FileNotFoundException, IOException
    {
        String transitionMatrixFile = "files/temperature_transition_matrices";

        int numberOfTemperatures = 40;
        int numberOfStates = 6;
        int numberOfSegments = 500;

        boolean sampleTMatrices = true;
        boolean sampleRMatrices = false;

        int priorType = PRIOR_TYPE_ONECOUNT;

        Algebra myAlgebra = new Algebra();
        DoubleFactory2D myMatrixFactory = DoubleFactory2D.dense;

        double temperatureTransitionMatrix[][][] = readCountMatrices(
                transitionMatrixFile, numberOfTemperatures, numberOfStates, numberOfSegments, priorType );

        // Row tt: [tt, mean_0, var_0, mean_1, var_1, ...] of the stationary distribution.
        double[][] outputArray = new double[numberOfTemperatures][1 + numberOfStates * 2];

        for (int tt = 0; tt < numberOfTemperatures; tt += 1) {
            System.out.println( "Temperature : " + tt );
            DoubleArrays.print( temperatureTransitionMatrix[tt] );

            System.out.println( "open TM Sampling" );
            TransitionMatrixSampling TMSampling = new TransitionMatrixSampling( temperatureTransitionMatrix[tt], true, null );

            int numberOfIterations = 10000;

            double[][] stationaryDistributionList = new double[numberOfStates][numberOfIterations];
            double[][] spectrumList = new double[numberOfStates][numberOfIterations];
            double[][] rateMatrixList = new double[numberOfIterations][];
            double[][] tMatrixList = new double[numberOfIterations][];

            int rMNum = 0;
            int tMNum = 0;
            int numberOfMCSteps = 200;

            System.out.println( "Start Iterating ..." );
            while (((rMNum < numberOfIterations) && (sampleRMatrices)) || ((tMNum < numberOfIterations) && (sampleTMatrices))) {
                // Discard numberOfMCSteps - 1 intermediate samples, keep the next one.
                for (int i = 1; i < numberOfMCSteps; i++) {
                    TMSampling.nextSample();
                }

                double[][] matrix = TransitionMatrixSampling.normalize( TMSampling.nextSample() );
                DoubleMatrix2D tMatrix = new DenseDoubleMatrix2D( matrix );

                // Decompose the transpose so that the eigenvectors are left eigenvectors of tMatrix.
                EigenvalueDecomposition evDecomp = new EigenvalueDecomposition( myAlgebra.transpose( tMatrix ) );

                DoubleMatrix1D eigenValues = evDecomp.getRealEigenvalues();
                DoubleMatrix1D eigenValuesSorted = eigenValues.viewSorted();
                // After transposing V, each row holds one eigenvector.
                DoubleMatrix2D eigenVectors = myAlgebra.transpose( evDecomp.getV() );

                double[][] tmEigenVectors = eigenVectors.toArray();
                double[] tmEigenValues = eigenValues.toArray();
                double[] stationaryDistribution = new double[numberOfStates];

                // Column under which this sample's statistics are stored. When rate
                // matrices are requested the lists track accepted rate matrices (a
                // rejected sample is overwritten by the next one); otherwise every
                // sample is kept and the transition-matrix counter is the index.
                // Previously rMNum was always used, so with sampleRMatrices == false
                // all samples were written to column 0.
                int sampleIndex = sampleRMatrices ? rMNum : tMNum;

                double largestEV = 0.0;
                boolean isPositive = true;
                for (int jj = 0; jj < numberOfStates; jj++) {
                    // The eigenvector of the largest eigenvalue is taken as the stationary distribution.
                    if (tmEigenValues[jj] > largestEV) {
                        stationaryDistribution = tmEigenVectors[jj];
                        largestEV = tmEigenValues[jj];
                    }
                    spectrumList[jj][sampleIndex] = eigenValuesSorted.get( jj );

                    // A negative eigenvalue has no real logarithm, so no rate matrix can be built.
                    if (spectrumList[jj][sampleIndex] < 0.0) {
                        isPositive = false;
                    }
                }

                // Normalize the stationary distribution so its entries sum to one.
                double sum = 0.0;
                for (int jj = 0; jj < numberOfStates; jj++) {
                    sum += stationaryDistribution[jj];
                }

                for (int jj = 0; jj < numberOfStates; jj++) {
                    stationaryDistribution[jj] /= sum;
                    stationaryDistributionList[jj][sampleIndex] = stationaryDistribution[jj];
                }

                // The bound check matters when both flags are set: the loop then keeps
                // running until enough rate matrices are accepted, which may take more
                // than numberOfIterations samples.
                if ((sampleTMatrices) && (tMNum < numberOfIterations)) {
                    tMatrixList[tMNum] = DoubleArrays.flatten( tMatrix.toArray() );
                    tMNum++;
                }

                if ((isPositive) && (sampleRMatrices)) {
                    DoubleMatrix1D eigenValuesLog = eigenValues.copy();
                    for (int i = 0; i < eigenValues.size(); i++) {
                        eigenValuesLog.set( i, Math.log( eigenValues.get( i ) ) );
                    }

                    // Matrix logarithm via the eigen-decomposition: E^-1 * diag(log(lambda)) * E.
                    DoubleMatrix2D eigenVectorsInverse = myAlgebra.inverse( eigenVectors );
                    DoubleMatrix2D eigenvalueMatrix = myMatrixFactory.diagonal( eigenValuesLog );
                    DoubleMatrix2D temp1 = myAlgebra.mult( eigenVectorsInverse, eigenvalueMatrix );
                    DoubleMatrix2D temp2 = myAlgebra.mult( temp1, eigenVectors );

                    // Accept only if every off-diagonal entry is non-negative.
                    boolean isRateMatrix = true;
                    for (int i = 0; i < numberOfStates; i++) {
                        for (int j = 0; j < numberOfStates; j++) {
                            if ((i != j) && (temp2.get( i, j ) < 0.0)) {
                                isRateMatrix = false;
                            }
                        }
                    }

                    if (isRateMatrix) {
                        rateMatrixList[rMNum] = DoubleArrays.flatten( temp2.toArray() );
                        rMNum++;
                        System.out.println( "Accept : " + rMNum );
                    }
                    // Rejected samples are dropped; the sampler state is not restored.
                }
            }

            outputArray[tt][0] = tt;
            for (int kk = 0; kk < numberOfStates; kk++) {
                outputArray[tt][1 + kk * 2] = DoubleArrays.mean( stationaryDistributionList[kk] );
                outputArray[tt][2 + kk * 2] = DoubleArrays.variance( stationaryDistributionList[kk] );
            }

            // Note: the whole outputArray is written each time; rows of temperatures
            // not yet processed are still zero in the earlier files.
            new DWListOfVectors( outputArray ).writeDataContainerAsTSV( "files/statDistSP" + Integer.toString( tt ) + ".tsv" );
            new DWMatrixOfScalars( spectrumList ).writeDataContainerAsTSV( "files/spectraListSP" + Integer.toString( tt ) + ".tsv" );
            if (sampleRMatrices) {
                new DWMatrixOfScalars( DoubleArrays.subarray( rateMatrixList, 0, rMNum ) ).writeDataContainerAsTSV( "files/rateMatrixListSP" + Integer.toString( tt ) + ".tsv" );
            }
            if (sampleTMatrices) {
                new DWMatrixOfScalars( DoubleArrays.subarray( tMatrixList, 0, tMNum ) ).writeDataContainerAsTSV( "files/tMatrixListSP" + Integer.toString( tt ) + ".tsv" );
            }
        }
    }

    /**
     * Reads {@code numberOfTemperatures * numberOfStates} lines from the given
     * file and converts them into count matrices, one per temperature.
     * <p>
     * Each line is split on single spaces; the first token is skipped
     * (presumably a row label -- verify against the file format) and the next
     * {@code numberOfStates} tokens are parsed as doubles and multiplied by
     * {@code numberOfSegments}. The prior correction is then applied to every
     * entry.
     *
     * @param fileName path of the input file
     * @param numberOfTemperatures number of matrices to read
     * @param numberOfStates rows and columns per matrix
     * @param numberOfSegments factor applied to every value read from the file
     * @param priorType one of the {@code PRIOR_TYPE_*} constants
     * @return the matrices indexed as {@code [temperature][row][column]}
     * @throws IOException if the file cannot be read or ends before all rows were read
     */
    private static double[][][] readCountMatrices( String fileName, int numberOfTemperatures, int numberOfStates,
            int numberOfSegments, int priorType ) throws IOException
    {
        double[][][] counts = new double[numberOfTemperatures][numberOfStates][numberOfStates];

        BufferedReader reader = new BufferedReader( new FileReader( fileName ) );
        try {
            for (int tt = 0; tt < numberOfTemperatures; tt++) {
                System.out.println( "Temperature : " + tt );
                for (int ii = 0; ii < numberOfStates; ii++) {
                    String line = reader.readLine();
                    if (line == null) {
                        throw new IOException( "Unexpected end of file in " + fileName
                                + " at temperature " + tt + ", row " + ii );
                    }
                    StringTokenizer tok = new StringTokenizer( line, " " );

                    // Skip the leading token of the row.
                    tok.nextToken();
                    for (int jj = 0; jj < numberOfStates; jj++) {
                        double value = Double.parseDouble( tok.nextToken() ) * numberOfSegments;
                        if ((priorType == PRIOR_TYPE_ONECOUNT) || (priorType == PRIOR_TYPE_ZERO)) {
                            // Remove 1 count from every entry.
                            value--;
                        }
                        if (priorType == PRIOR_TYPE_ONECOUNT) {
                            // Add 1/m to every entry. NOTE(review): the original comment said
                            // "on all diagonal", but the code applies it to all entries -- confirm intent.
                            value += 1.0 / (double) numberOfStates;
                        }
                        counts[tt][ii][jj] = value;
                    }
                }
            }
        }
        finally {
            // Always release the file handle, also when parsing fails.
            reader.close();
        }

        return counts;
    }
}
