/*
 * Stand-alone driver: samples transition matrices (and optionally rate matrices)
 * starting from a transition count matrix, and writes summary statistics as TSV files.
 */

package iterativeclustering;


import biocomp.moltools.discretization.TransitionMatrixSampling;
import biocomp.moltools.util.DoubleArrays;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.EigenvalueDecomposition;
import dataWrapper.DWListOfScalars;
import dataWrapper.DWListOfVectors;
import dataWrapper.DWMatrixOfScalars;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;


/**
 * Command-line driver that reads a tab-separated transition count matrix, draws
 * transition matrices from it with {@link TransitionMatrixSampling} and,
 * optionally, derives rate matrices via the matrix logarithm. It writes summary
 * statistics of the samples as TSV files into the data directory.
 *
 * @author jan-hendrikprinz
 */
public class TempTransPropSingleFile
{
    /** Data directory used when no first command-line argument is given. */
    private static final String DEFAULT_DATA_PATH = "/Users/jan-hendrikprinz/Studium/Projekte/dynamical-reweighting/datasets/alanine-dipeptide/shoot10ps302K/";

    /** Count-matrix file name used when no second command-line argument is given. */
    private static final String DEFAULT_TRANSITION_MATRIX_FILE = "transitionCountMatrix.dat";

    /**
     * Runs the sampler and writes {@code referenceDist.tsv}, {@code spectraList.tsv} and,
     * depending on the sampling flags, {@code tMatrixList.tsv} / {@code rateMatrixList.tsv}.
     *
     * @param args optional overrides: {@code args[0]} is the data directory (file names are
     *             appended by plain string concatenation, so it must end with a separator),
     *             {@code args[1]} is the count-matrix file name inside that directory
     * @throws FileNotFoundException if the count-matrix file cannot be opened
     * @throws IOException if the count matrix is malformed/too short or reading fails
     */
    public static void main( String[] args ) throws FileNotFoundException, IOException
    {
        String dataPath = ( args != null && args.length > 0 ) ? args[0] : DEFAULT_DATA_PATH;
        String transitionMatrixFile = ( args != null && args.length > 1 ) ? args[1] : DEFAULT_TRANSITION_MATRIX_FILE;

        int numberOfStates = 6;

        boolean sampleTMatrices = true;
        boolean sampleRMatrices = false;

        Algebra myAlgebra = new Algebra();
        DoubleFactory2D myMatrixFactory = DoubleFactory2D.dense;

        double[][] temperatureTransitionMatrix = readCountMatrix( dataPath + transitionMatrixFile, numberOfStates );

        DoubleArrays.print( temperatureTransitionMatrix );

        System.out.println( "open TM Sampling" );
        TransitionMatrixSampling tmSampling = new TransitionMatrixSampling( temperatureTransitionMatrix, true, null );

        int numberOfIterations = 10000;
        int numberOfMCSteps = 200;

        // Per-state rows, one column per retained sample.
        double[][] stationaryDistributionList = new double[numberOfStates][numberOfIterations];
        double[][] spectrumList = new double[numberOfStates][numberOfIterations];
        double[][] rateMatrixList = new double[numberOfIterations][];
        double[][] tMatrixList = new double[numberOfIterations][];

        int rMNum = 0;
        int tMNum = 0;

        System.out.println( "Start Iterating ..." );
        while (((rMNum < numberOfIterations) && sampleRMatrices) || ((tMNum < numberOfIterations) && sampleTMatrices)) {
            // Advance the sampler numberOfMCSteps times in total; only the last sample is used.
            for (int i = 1; i < numberOfMCSteps; i++) {
                tmSampling.nextSample();
            }

            double[][] matrix = TransitionMatrixSampling.normalize( tmSampling.nextSample() );
            DoubleMatrix2D tMatrix = new DenseDoubleMatrix2D( matrix );

            // Decompose T^T: its right eigenvectors are the left eigenvectors of T.
            EigenvalueDecomposition evDecomp = new EigenvalueDecomposition( myAlgebra.transpose( tMatrix ) );
            DoubleMatrix1D eigenValues = evDecomp.getRealEigenvalues();
            DoubleMatrix1D eigenValuesSorted = eigenValues.viewSorted();
            // After transposing V, row k holds the eigenvector belonging to eigenValues.get(k).
            DoubleMatrix2D eigenVectors = myAlgebra.transpose( evDecomp.getV() );

            double[] stationaryDistribution = stationaryDistribution( eigenValues.toArray(), eigenVectors.toArray(), numberOfStates );

            // Strictly positive spectrum is required, otherwise Math.log yields -Infinity/NaN.
            boolean isPositive = true;
            for (int jj = 0; jj < numberOfStates; jj++) {
                if (eigenValuesSorted.get( jj ) <= 0.0) {
                    isPositive = false;
                }
            }

            // Column for the per-sample statistics: the counter that actually drives retention.
            // (Indexing by rMNum alone left every T-sample in column 0 when R-sampling was off.)
            // In R-only mode the slot is overwritten until a rate matrix is accepted.
            int sampleIndex = sampleTMatrices ? tMNum : rMNum;
            if (sampleIndex < numberOfIterations) {
                for (int jj = 0; jj < numberOfStates; jj++) {
                    spectrumList[jj][sampleIndex] = eigenValuesSorted.get( jj );
                    stationaryDistributionList[jj][sampleIndex] = stationaryDistribution[jj];
                }
            }

            // Guard: when both modes are on, the loop keeps running after tMNum hits the limit.
            if (sampleTMatrices && (tMNum < numberOfIterations)) {
                tMatrixList[tMNum] = DoubleArrays.flatten( tMatrix.toArray() );
                tMNum++;
            }

            if (isPositive && sampleRMatrices && (rMNum < numberOfIterations)) {
                DoubleMatrix2D rateMatrix = matrixLogarithm( eigenValues, eigenVectors, myAlgebra, myMatrixFactory );

                // Keep only candidates whose off-diagonal entries are all non-negative.
                if (hasNonNegativeOffDiagonal( rateMatrix )) {
                    rateMatrixList[rMNum] = DoubleArrays.flatten( rateMatrix.toArray() );
                    rMNum++;
                    System.out.println( "Accept : " + rMNum );
                }
            }
        }

        // Layout: [0] = 0 (meaning not evident from this file — TODO confirm),
        // then (mean, variance) of the stationary probability for each state.
        double[] outputArray = new double[1 + numberOfStates * 2];
        outputArray[0] = 0;
        for (int kk = 0; kk < numberOfStates; kk++) {
            outputArray[1 + kk * 2] = DoubleArrays.mean( stationaryDistributionList[kk] );
            outputArray[2 + kk * 2] = DoubleArrays.variance( stationaryDistributionList[kk] );
        }

        new DWListOfScalars( outputArray ).writeDataContainerAsTSV( dataPath + "referenceDist" + ".tsv" );
        new DWMatrixOfScalars( spectrumList ).writeDataContainerAsTSV( dataPath + "spectraList" + ".tsv" );
        if (sampleRMatrices) {
            new DWMatrixOfScalars( DoubleArrays.subarray( rateMatrixList, 0, rMNum ) ).writeDataContainerAsTSV( dataPath + "rateMatrixList" + ".tsv" );
        }
        if (sampleTMatrices) {
            new DWMatrixOfScalars( DoubleArrays.subarray( tMatrixList, 0, tMNum ) ).writeDataContainerAsTSV( dataPath + "tMatrixList" + ".tsv" );
        }
    }

    /**
     * Reads the first {@code numberOfStates} rows of a tab-separated square matrix.
     *
     * @param fileName       path of the file to read
     * @param numberOfStates number of rows and of leading columns to read
     * @return the parsed matrix
     * @throws IOException if the file cannot be read or has too few rows/columns
     */
    private static double[][] readCountMatrix( String fileName, int numberOfStates ) throws IOException
    {
        double[][] counts = new double[numberOfStates][numberOfStates];
        BufferedReader reader = new BufferedReader( new FileReader( fileName ) );
        try {
            for (int ii = 0; ii < numberOfStates; ii++) {
                String line = reader.readLine();
                if (line == null) {
                    throw new IOException( "Expected " + numberOfStates + " rows in " + fileName + " but found " + ii );
                }
                StringTokenizer tok = new StringTokenizer( line, "\t" );
                for (int jj = 0; jj < numberOfStates; jj++) {
                    if (!tok.hasMoreTokens()) {
                        throw new IOException( "Row " + ii + " of " + fileName + " has only " + jj + " columns, expected " + numberOfStates );
                    }
                    counts[ii][jj] = Double.parseDouble( tok.nextToken() );
                }
            }
        }
        finally {
            reader.close();
        }
        return counts;
    }

    /**
     * Picks the eigenvector of the largest strictly positive eigenvalue and scales it to sum to one.
     *
     * @param eigenValues    real eigenvalues; entry k belongs to row k of {@code eigenVectors}
     * @param eigenVectors   eigenvectors stored as rows
     * @param numberOfStates number of entries to use
     * @return the normalized vector (all zeros divided by zero, i.e. NaN, if no eigenvalue is positive)
     */
    private static double[] stationaryDistribution( double[] eigenValues, double[][] eigenVectors, int numberOfStates )
    {
        double[] distribution = new double[numberOfStates];
        double largestEV = 0.0;
        for (int k = 0; k < numberOfStates; k++) {
            if (eigenValues[k] > largestEV) {
                largestEV = eigenValues[k];
                distribution = eigenVectors[k].clone();
            }
        }

        // Dividing by the sum also fixes the arbitrary sign of the eigenvector.
        double sum = 0.0;
        for (int k = 0; k < numberOfStates; k++) {
            sum += distribution[k];
        }
        for (int k = 0; k < numberOfStates; k++) {
            distribution[k] /= sum;
        }
        return distribution;
    }

    /**
     * Computes W^{-1} · diag(log λ) · W, where W holds the eigenvectors as rows.
     * This is the matrix logarithm of the decomposed transition matrix.
     * Callers must ensure every eigenvalue is strictly positive.
     */
    private static DoubleMatrix2D matrixLogarithm( DoubleMatrix1D eigenValues, DoubleMatrix2D eigenVectors,
                                                   Algebra algebra, DoubleFactory2D factory )
    {
        DoubleMatrix1D logEigenValues = eigenValues.copy();
        for (int i = 0; i < logEigenValues.size(); i++) {
            logEigenValues.set( i, Math.log( eigenValues.get( i ) ) );
        }
        DoubleMatrix2D left = algebra.mult( algebra.inverse( eigenVectors ), factory.diagonal( logEigenValues ) );
        return algebra.mult( left, eigenVectors );
    }

    /** Returns true if every off-diagonal entry of {@code matrix} is {@code >= 0}. */
    private static boolean hasNonNegativeOffDiagonal( DoubleMatrix2D matrix )
    {
        for (int i = 0; i < matrix.rows(); i++) {
            for (int j = 0; j < matrix.columns(); j++) {
                if (i != j && matrix.get( i, j ) < 0.0) {
                    return false;
                }
            }
        }
        return true;
    }
}
