// Weeks-Chandler-Andersen (WCA) pair potential, used for the WCA dimer system.

import java.lang.Math;
import java.lang.Exception;

/**
 * Weeks-Chandler-Andersen (WCA) pair potential for particles in a cubic periodic box.
 *
 * <p>The WCA potential is the Lennard-Jones potential truncated at its minimum,
 * r_min = 2^(1/6) sigma, and shifted up by epsilon, so it is purely repulsive and
 * goes continuously to zero at the cutoff:
 *
 * <pre>
 *   U(r) = 4 epsilon [ (sigma/r)^12 - (sigma/r)^6 ] + epsilon   for r &lt; r_min
 *   U(r) = 0                                                   otherwise
 * </pre>
 *
 * <p>Positions and gradients are flat arrays laid out as {x0, y0, z0, x1, y1, z1, ...}.
 * Pair separations use the minimum-image convention with a single edge length for all
 * three dimensions. All fields are final, and the compute methods allocate their own
 * scratch space.
 */
public class WCAPotential implements Potential {

  private final int natoms; //< the number of particles in the system
  private final double sigma;
  private final double epsilon;
  private final double box_edge_length;

  private final double rmin; //< distance of the Lennard-Jones minimum, 2^(1/6) sigma
  private final double cutoff; //< interaction cutoff (equal to rmin for WCA)
  private final double cutoff2; //< squared cutoff distance
  private final double sigma2; //< sigma^2
  private final double half_box_edge;
  private final double cutoff_energy; //< unshifted LJ energy at the cutoff (equals -epsilon)

  /**
   * Creates a WCA system.
   *
   * @param natoms the number of atoms in the system
   * @param sigma the Lennard-Jones length parameter
   * @param epsilon the Lennard-Jones energy (well-depth) parameter
   * @param box_edge_length edge length of the cubic periodic box; must be positive.
   *     {@code Double.POSITIVE_INFINITY} effectively disables periodic imaging.
   * @throws IllegalArgumentException if {@code box_edge_length} is zero, negative, or NaN
   */
  public WCAPotential(int natoms, double sigma, double epsilon, double box_edge_length) {
    // A zero or negative edge makes minimum-image wrapping meaningless (the old imaging
    // loops never terminated for it), and NaN would silently disable periodicity.
    if (!(box_edge_length > 0.0)) {
      throw new IllegalArgumentException(
          "box_edge_length must be positive, got " + box_edge_length);
    }

    this.natoms = natoms;
    this.sigma = sigma;
    this.epsilon = epsilon;
    this.box_edge_length = box_edge_length;

    // WCA truncates the LJ potential at its minimum.
    this.rmin = Math.pow(2.0, 1.0 / 6.0) * sigma;
    this.cutoff = this.rmin;
    this.cutoff2 = Math.pow(this.cutoff, 2.0);
    this.sigma2 = sigma * sigma;
    this.half_box_edge = this.box_edge_length / 2.0;
    // Subtracting this from every pair energy shifts U(cutoff) to zero.
    this.cutoff_energy =
        4.0 * epsilon * (Math.pow(this.cutoff / sigma, -12.0) - Math.pow(this.cutoff / sigma, -6.0));
  }

  /**
   * Computes the minimum-image separation vector r_j - r_i and its squared length.
   *
   * @param rij output array of length at least 3; receives r_j - r_i imaged into the box
   * @param positions flat coordinate array {x0, y0, z0, x1, ...}
   * @param i index of the first particle
   * @param j index of the second particle
   * @return the squared minimum-image distance |r_j - r_i|^2
   */
  private double imagedSeparationVector(double[] rij, double[] positions, int i, int j) {
    double r2 = 0.0;

    for (int k = 0; k < 3; k++) {
      double dr = positions[3 * j + k] - positions[3 * i + k];

      // Minimum-image convention: when dr lies outside [-L/2, L/2], shift it by the
      // nearest whole number of box lengths. This is done in one step rather than by
      // repeated add/subtract so that huge or infinite separations cannot loop forever.
      // An infinite separation becomes NaN and therefore never passes the cutoff test.
      if (dr < -half_box_edge || dr > half_box_edge) {
        dr -= box_edge_length * Math.rint(dr / box_edge_length);
      }

      rij[k] = dr;
      r2 += dr * dr;
    }

    return r2;
  }

  /**
   * Returns the shifted pair energy 4 epsilon (x^-12 - x^-6) - cutoff_energy, where
   * x = r / sigma.
   *
   * @param x6 (r / sigma)^6
   * @param x12 (r / sigma)^12
   * @return the pair energy, which is zero at the cutoff
   */
  private double shiftedPairEnergy(double x6, double x12) {
    return 4.0 * epsilon * (1.0 / x12 - 1.0 / x6) - cutoff_energy;
  }

  /**
   * Computes the total WCA potential energy of a configuration.
   *
   * @param positions flat coordinate array of length at least 3 * natoms
   * @return the sum of shifted pair energies over all pairs closer than the cutoff
   */
  public double computePotential(double[] positions) {
    double potential = 0.0;
    double[] rij = new double[3];

    for (int i = 0; i < natoms; i++) {
      for (int j = 0; j < i; j++) {
        double r2 = imagedSeparationVector(rij, positions, i, j);
        if (r2 < cutoff2) {
          double x2 = r2 / sigma2;
          double x6 = x2 * x2 * x2;
          potential += shiftedPairEnergy(x6, x6 * x6);
        }
      }
    }

    return potential;
  }

  /**
   * Computes the gradient of the total potential with respect to every particle
   * coordinate, and also returns the total potential energy.
   *
   * @param gradient output array of length at least 3 * natoms. Its first 3 * natoms
   *     entries are overwritten with dU/dx, using the same layout as {@code positions}.
   * @param positions flat coordinate array of length at least 3 * natoms
   * @return the total potential energy (the same quantity as {@link #computePotential})
   */
  public double computeGradient(double[] gradient, double[] positions) {
    for (int n = 0; n < 3 * natoms; n++) {
      gradient[n] = 0.0;
    }

    double potential = 0.0;
    double[] rij = new double[3];
    double dxdr = 1.0 / sigma;

    for (int i = 0; i < natoms; i++) {
      for (int j = 0; j < i; j++) {
        double r2 = imagedSeparationVector(rij, positions, i, j);
        if (r2 < cutoff2) {
          double x2 = r2 / sigma2;
          double x6 = x2 * x2 * x2;
          double x12 = x6 * x6;
          potential += shiftedPairEnergy(x6, x12);

          // dU/dr = 4 epsilon (-12 x^-13 + 6 x^-7) dx/dr, with x = r / sigma.
          double r = Math.sqrt(r2);
          double x = r / sigma;
          double dUdr = 4.0 * epsilon * (-12.0 / (x12 * x) + 6.0 / (x6 * x)) * dxdr;

          // rij = r_j - r_i, so dr/dr_i = -rij/r and dr/dr_j = +rij/r.
          for (int k = 0; k < 3; k++) {
            double unit = rij[k] / r;
            gradient[3 * i + k] -= dUdr * unit;
            gradient[3 * j + k] += dUdr * unit;
          }
        }
      }
    }

    return potential;
  }

}