/**
 * Time-series analysis helper functions for Matlab.
 *
 * @author John D. Chodera.
 */

import java.lang.Exception;

public class timeseries {

  /**
   * Estimate the (unnormalized, non-mean-subtracted) time-correlation function of an observable.
   *
   * For each lag tau in 1..max_tau, the estimate is the average of A_t[t0] * A_t[t0+tau]
   * over all t0 with t0+tau inside the trajectory.
   *
   * @param A_t   observable trajectory
   * @param max_tau   maximum lag time for accumulating statistics
   * @return array of length max_tau where element tau-1 is the estimate for lag time tau;
   *         the entry for tau == A_t.length is NaN (0/0), and entries for larger tau have
   *         no contributing pairs and are zero-valued
   */
  public static double [] compute_correlation_function(double [] A_t, int max_tau) {
    // Determine timeseries length.
    int T = A_t.length;

    // Storage for the time-correlation function (Java zero-initializes new arrays).
    double [] C_t = new double[max_tau];

    // Accumulate products over all time origins, then normalize by the number of pairs.
    for (int tau = 1; tau <= max_tau; tau++) {
      for (int t0 = 0; t0 < T-tau; t0++)
        C_t[tau-1] += A_t[t0] * A_t[t0+tau];
      C_t[tau-1] /= (double)(T-tau);
    }

    return C_t;
  }

  /**
   * Compute the mean squared displacement from a bin center as a function of lag time,
   * using only time origins at which the trajectory is inside the bin.
   *
   * A frame t0 is a time origin when xbin - bin_width/2 <= x_t[t0] < xbin + bin_width/2.
   * For each such origin and each lag tau, (x_t[t0+tau] - xbin)^2 is accumulated.
   *
   * @param x_t   trajectory
   * @param xbin   center of the bin
   * @param bin_width   full width of the bin
   * @param max_tau   maximum lag time for accumulating statistics
   * @return array of length max_tau where element tau-1 is the average for lag time tau;
   *         NaN where no time origin contributed to that lag
   */
  public static double [] diffusion(double [] x_t, double xbin, double bin_width, int max_tau) {
    // Determine timeseries length.
    int T = x_t.length;

    // Half-open bin [lower, upper).
    double lower = xbin - bin_width/2.0;
    double upper = xbin + bin_width/2.0;

    // Numerator (sum of squared displacements) and denominator (sample counts) per lag.
    double [] C_t = new double[max_tau];
    double [] D_t = new double[max_tau];

    for (int t0 = 0; t0 < T; t0++) {
      if ((lower <= x_t[t0]) && (x_t[t0] < upper)) {
        for (int tau = 1; (tau <= max_tau) && (t0+tau < T); tau++) {
          double dx = x_t[t0+tau] - xbin;
          C_t[tau-1] += dx * dx;
          D_t[tau-1] += 1.0;
        }
      }
    }

    // Normalize; lags with no samples become NaN (0/0).
    for (int tau = 1; tau <= max_tau; tau++)
      C_t[tau-1] /= D_t[tau-1];

    return C_t;
  }

  /**
   * Compute a mean-subtracted position correlation as a function of lag time, using only
   * time origins (taken every tskip frames) at which the trajectory is inside the bin.
   *
   * The offset subtracted from every sample is the mean of the whole trajectory.
   * A frame t0 is a time origin when xbin - bin_width/2 <= x_t[t0] < xbin + bin_width/2.
   * For each such origin and each lag tau, (x_t[t0]-offset)*(x_t[t0+tau]-offset) is accumulated.
   *
   * @param x_t   trajectory
   * @param xbin   center of the bin
   * @param bin_width   full width of the bin
   * @param max_tau   maximum lag time for accumulating statistics
   * @param tskip   stride between candidate time origins; must be at least 1
   * @return array of length max_tau where element tau-1 is the average for lag time tau;
   *         NaN where no time origin contributed to that lag
   * @throws IllegalArgumentException if tskip is less than 1
   */
  public static double [] diffusion_pande(double [] x_t, double xbin, double bin_width, int max_tau, int tskip) {
    // A non-positive stride would never advance (or would walk off the front of the array).
    if (tskip < 1)
      throw new IllegalArgumentException("tskip must be >= 1, got " + tskip);

    // Determine timeseries length.
    int T = x_t.length;

    // Compute offset as the trajectory mean (sum divided by the number of samples).
    double sum = 0.0;
    for (int t = 0; t < T; t++)
      sum += x_t[t];
    double offset = (T > 0) ? sum / (double)T : 0.0;

    // Half-open bin [lower, upper).
    double lower = xbin - bin_width/2.0;
    double upper = xbin + bin_width/2.0;

    // Numerator (sum of products) and denominator (sample counts) per lag.
    double [] C_t = new double[max_tau];
    double [] D_t = new double[max_tau];

    for (int t0 = 0; t0 < T; t0 += tskip) {
      if ((lower <= x_t[t0]) && (x_t[t0] < upper)) {
        for (int tau = 1; (tau <= max_tau) && (t0+tau < T); tau++) {
          C_t[tau-1] += (x_t[t0]-offset)*(x_t[t0+tau]-offset);
          D_t[tau-1] += 1.0;
        }
      }
    }

    // Normalize; lags with no samples become NaN (0/0).
    for (int tau = 1; tau <= max_tau; tau++)
      C_t[tau-1] /= D_t[tau-1];

    return C_t;
  }

  /**
   * Estimate, per bin, the observed fraction of frames that next commit to A rather than B.
   *
   * A frame is "committed" when x_t[t] <= x_A (state A) or x_t[t] >= x_B (state B).
   * Every uncommitted frame that falls inside a bin is assigned to the state reached at the
   * next committed frame. Trailing uncommitted frames are discarded so that every counted
   * frame has a subsequent commitment.
   *
   * @param x_t   trajectory
   * @param bin_edges   bin boundaries; bin i covers [bin_edges[i], bin_edges[i+1])
   * @param x_A   upper boundary of state A
   * @param x_B   lower boundary of state B
   * @return array of length bin_edges.length-1; element i is 1.0 for bins entirely inside A,
   *         0.0 for bins entirely inside B, NA/(NA+NB) for bins with observations, and 0.0
   *         for bins with no observations
   * @throws Exception if the trajectory contains no committed frame
   */
  public static double [] observed_splitting(double [] x_t, double [] bin_edges, double x_A, double x_B) throws Exception {
    // Determine timeseries length.
    int T = x_t.length;

    // Truncate trajectory length so all events end in commitment.
    // The T > 0 test must come first so x_t[T-1] is never read with T == 0.
    while ( (T > 0) && (x_A < x_t[T-1]) && (x_t[T-1] < x_B) )
      T--;
    if (T == 0)
      throw new Exception("No commitment events.");

    // Determine number of bins.
    int nbins = bin_edges.length - 1;

    // Accumulate commitment event statistics.
    double [] NA_i = new double[nbins];
    double [] NB_i = new double[nbins];
    int t = 0; // current marker
    int tcommit = 0; // scans ahead of t to the next committed frame
    while (t < T) {
      if ( (x_t[t] <= x_A) || (x_B <= x_t[t]) ) {
        // Current frame is itself committed: step past it and restart the scan from there.
        t++;
        tcommit = t;
      } else if ( (x_A < x_t[tcommit]) && (x_t[tcommit] < x_B) ) {
        // Commit pointer is still on an uncommitted frame: keep scanning forward.
        // Truncation above guarantees a committed frame exists before index T.
        tcommit++;
      } else {
        // Determine current bin; frames outside every bin are skipped, not counted.
        int i = find_bin(x_t[t], bin_edges, nbins);
        if (i >= 0) {
          // Determine commitment direction, using the same boundary convention (<= x_A)
          // that defines commitment to A above.
          if (x_t[tcommit] <= x_A)
            NA_i[i] += 1.0;
          else
            NB_i[i] += 1.0;
        }

        // Advance current marker whether or not the frame was counted.
        t++;
      }
    }

    // Compute committor to A.
    double [] pA_i = new double[nbins];
    for (int i = 0; i < nbins; i++) {
      if (bin_edges[i+1] <= x_A)
        pA_i[i] = 1.0;
      else if (bin_edges[i] >= x_B)
        pA_i[i] = 0.0;
      else if (NA_i[i] + NB_i[i] > 0.0)
        pA_i[i] = NA_i[i] / (NA_i[i] + NB_i[i]);
    }

    return pA_i;
  }

  /**
   * Return the index of the first bin [bin_edges[i], bin_edges[i+1]) containing x, or -1 if none does.
   */
  private static int find_bin(double x, double [] bin_edges, int nbins) {
    for (int i = 0; i < nbins; i++) {
      if ((bin_edges[i] <= x) && (x < bin_edges[i+1]))
        return i;
    }
    return -1;
  }

}
