// Java code for WCA dimer

import java.lang.Math;

/**
 * WCA dimer system biased for an instantaneous Monte Carlo move.
 *
 * <p>The move displaces atoms 0 and 1 symmetrically apart by {@code dr} along their
 * minimum-image bond vector (each atom moves {@code dr/2}). The biased potential is
 * {@code max(u0, lambda*u1 + (1-lambda)*u0)}. Here {@code u0} is the energy of the wrapped
 * {@link WCADimer} at the current configuration, and {@code u1} is its energy at the
 * displaced configuration.
 */
public class WCADimerBiasedInstantaneous implements Potential {

  private final WCADimer system; // underlying (unbiased) WCA dimer potential
  private final int natoms; // number of atoms
  private final double r0; // compact state separation, 2^(1/6) * sigma
  private final double h; // barrier height (stored; not read by this class)
  private final double w; // second minimum is at r0 + 2*w (stored; not read by this class)
  private final double dr; // instantaneous MC move along dimer bond vector

  private final double lambda; // strength of coupling

  private final double box_edge_length;
  private final double half_box_edge;

  /**
   * Create a biased WCA dimer system whose instantaneous move length equals the compact-state
   * separation {@code r0 = 2^(1/6) * sigma}.
   *
   * @param natoms the number of atoms in the system (atoms 0 and 1 form the dimer)
   * @param sigma WCA length parameter, forwarded to {@link WCADimer}
   * @param epsilon WCA energy parameter, forwarded to {@link WCADimer}
   * @param box_edge_length edge length of the periodic box
   * @param barrier_height barrier height, forwarded to {@link WCADimer}
   * @param lambda strength of coupling of the bias
   * @throws IllegalArgumentException if {@code natoms < 2}
   */
  public WCADimerBiasedInstantaneous(int natoms, double sigma, double epsilon, double box_edge_length, double barrier_height, double lambda) {
    this(natoms, sigma, epsilon, box_edge_length, barrier_height, lambda, Math.pow(2.0, (1.0/6.0)) * sigma);
  }

  /**
   * Create a biased WCA dimer system with an explicit instantaneous move length.
   *
   * @param natoms the number of atoms in the system (atoms 0 and 1 form the dimer)
   * @param sigma WCA length parameter, forwarded to {@link WCADimer}
   * @param epsilon WCA energy parameter, forwarded to {@link WCADimer}
   * @param box_edge_length edge length of the periodic box
   * @param barrier_height barrier height, forwarded to {@link WCADimer}
   * @param lambda strength of coupling of the bias
   * @param dr total increase in dimer separation produced by the instantaneous move
   * @throws IllegalArgumentException if {@code natoms < 2}
   */
  public WCADimerBiasedInstantaneous(int natoms, double sigma, double epsilon, double box_edge_length, double barrier_height, double lambda, double dr) {
    // The bias acts on the bond between atoms 0 and 1, so both must exist.
    if (natoms < 2)
      throw new IllegalArgumentException("natoms must be at least 2, got " + natoms);

    // Create dimer system.
    this.system = new WCADimer(natoms, sigma, epsilon, box_edge_length, barrier_height);

    this.box_edge_length = box_edge_length;
    this.half_box_edge = this.box_edge_length / 2.0;

    // Store parameters for potential.
    this.natoms = natoms;
    this.r0 = Math.pow(2.0, (1.0/6.0)) * sigma;
    this.w = 0.5 * this.r0;
    this.h = barrier_height;

    this.lambda = lambda;
    this.dr = dr;
  }

  /**
   * Compute the minimum-image separation vector r_j - r_i and its squared length.
   *
   * @param rij output array of length 3 receiving the imaged separation vector
   * @param positions flattened coordinates, atom a occupying indices 3*a .. 3*a+2
   * @param i index of the first atom
   * @param j index of the second atom
   * @return squared length of the imaged separation vector
   */
  private double imagedSeparationVector(double [] rij, double [] positions, int i, int j) {
    // Compute squared distance.
    double r2 = 0.0;

    for(int k = 0; k < 3; k++) {
      // Compute interparticle separation (named d to avoid shadowing the field dr).
      double d = positions[3*j+k] - positions[3*i+k];

      // Image into box.
      while (d < - this.half_box_edge)
        d += this.box_edge_length;
      while (d > this.half_box_edge)
        d -= this.box_edge_length;

      // Store interparticle separation.
      rij[k] = d;

      // Accumulate square distance.
      r2 += d*d;
    }

    return r2;
  }

  /**
   * Fill {@code nij} with the unit vector pointing from atom 0 to atom 1 (minimum image).
   *
   * <p>The direction is undefined when the atoms coincide; in that case the components are NaN.
   *
   * @param nij output array of length 3
   * @param positions flattened coordinates
   * @return the bond length r
   */
  private double bondUnitVector(double [] nij, double [] positions) {
    double [] rij = new double[3];
    double r = Math.sqrt(imagedSeparationVector(rij, positions, 0, 1));
    for(int k = 0; k < 3; k++)
      nij[k] = rij[k] / r;
    return r;
  }

  /**
   * Return a copy of {@code positions} with atoms 0 and 1 each moved {@code dr/2} apart along
   * {@code nij}. Working on a copy leaves the caller's array bit-for-bit untouched; an in-place
   * subtract-then-add round trip is not exact in floating point.
   */
  private double [] displacedPositions(double [] positions, double [] nij) {
    double [] displaced = positions.clone();
    double half = this.dr / 2.0;
    for(int k = 0; k < 3; k++) {
      displaced[3*0+k] -= nij[k]*half;
      displaced[3*1+k] += nij[k]*half;
    }
    return displaced;
  }

  /**
   * Convert {@code g}, the gradient of U evaluated at the displaced configuration x', in place
   * into the gradient of u1(x) = U(x'(x)) with respect to the undisplaced positions x.
   *
   * <p>With c = dr/2 and n = (x1 - x0)/r, the move is x0' = x0 - c n and x1' = x1 + c n.
   * Since dn/dx1 = -dn/dx0 = P = (I - n n^T)/r, the chain rule gives
   * grad_x0 u1 = g0 + c P (g0 - g1) and grad_x1 u1 = g1 - c P (g0 - g1).
   * Other atoms are not displaced, so their entries are unchanged.
   */
  private void applyDisplacementJacobian(double [] g, double [] nij, double r) {
    double c = this.dr / 2.0;
    double [] diff = new double[3];
    double dot = 0.0;
    for(int k = 0; k < 3; k++) {
      diff[k] = g[3*0+k] - g[3*1+k];
      dot += nij[k] * diff[k];
    }
    for(int k = 0; k < 3; k++) {
      double correction = c * (diff[k] - nij[k] * dot) / r;
      g[3*0+k] += correction;
      g[3*1+k] -= correction;
    }
  }

  /**
   * Compute potential energy.
   *
   * U(x) = U0(x) - kT ln min { 1, exp(-w(x)) }
   *
   * where w(x) = [U0(x') - U0(x)] / kT is the dimensionless work of instantaneously driving the
   * bond separation from x to the displaced configuration x'.
   *
   * In dimensionless units, with u1(x) = u0(x') and the work scaled by lambda:
   *
   * u(x) = u0(x) - ln min { 1, exp[lambda (-u1(x) + u0(x))] }
   *      = - ln [ exp(-u0(x)) min { 1, exp[lambda (-u1(x) + u0(x))] } ]
   *      = - ln min { exp(-u0(x)), exp(-lambda u1(x) - (1-lambda) u0(x)) }
   *      = max { u0(x), lambda u1(x) + (1-lambda) u0(x) }
   *
   * @param positions flattened coordinates; not modified
   * @return the biased potential energy
   */
  public double computePotential(double [] positions) {
    // Compute potential at initial position.
    double u0 = this.system.computePotential(positions);

    // Compute potential for configuration displaced along the bond vector.
    double [] nij = new double[3];
    bondUnitVector(nij, positions);
    double u1 = this.system.computePotential(displacedPositions(positions, nij));

    return Math.max(u0, this.lambda*u1 + (1.0-this.lambda)*u0);
  }

  /**
   * Compute the biased potential energy and its gradient with respect to {@code positions}.
   *
   * <p>In the biased branch, the gradient of u1 includes the dependence of the displaced
   * configuration on the bond direction. This keeps the gradient consistent with
   * {@link #computePotential(double[])}.
   *
   * @param gradient output array of length 3*natoms receiving the gradient
   * @param positions flattened coordinates; not modified
   * @return the biased potential energy
   */
  public double computeGradient(double [] gradient, double [] positions) {
    int ndof = 3*this.natoms;
    double [] gradient0 = new double[ndof];
    double [] gradient1 = new double[ndof];

    double u0 = this.system.computeGradient(gradient0, positions);

    // Compute potential and gradient at the displaced configuration.
    double [] nij = new double[3];
    double r = bondUnitVector(nij, positions);
    double u1 = this.system.computeGradient(gradient1, displacedPositions(positions, nij));

    double biased = this.lambda*u1 + (1.0-this.lambda)*u0;
    double potential = Math.max(u0, biased);

    // Compute gradient of whichever branch of the max is active.
    if (u0 > biased) {
      System.arraycopy(gradient0, 0, gradient, 0, ndof);
    } else {
      // Map the displaced-configuration gradient back to the undisplaced coordinates.
      applyDisplacementJacobian(gradient1, nij, r);
      for(int i = 0; i < ndof; i++)
        gradient[i] = this.lambda * gradient1[i] + (1.0-this.lambda)*gradient0[i];
    }

    return potential;
  }

}