// Java code for WCA dimer

import java.lang.Math;
import java.util.Random;

/**
 * Driver for a WCA dimer simulation that alternates generalized hybrid Monte Carlo (GHMC)
 * equilibration with nonequilibrium switching moves of the dimer bond length.
 *
 * Coordinates and velocities are stored as flat arrays in which element 3*i+k belongs to
 * atom i, dimension k. Atoms 0 and 1 form the dimer.
 */
public class Simulation {

  Random random; // random number generator (created at the start of run())
  private Potential sampling_system;   // potential used for minimization and GHMC equilibration
  private Potential switching_system;  // potential used during nonequilibrium switching
  private int natoms;
  private double mass; // particle mass, amu
  private double volume;
  private double box_edge_length;

  private double r0; // 2^(1/6) * sigma
  private double w;  // r0 / 2; r0 + w is reported as the barrier position

  private double equilibration_timestep;
  private double switching_timestep;

  private double temperature;
  private double collision_rate;
  private int equilibration_nsteps;
  private double kT;

  /**
   * Set up the sampling and switching systems and derive the simulation parameters.
   *
   * @param natoms   number of particles
   * @param mass   mass shared by all particles
   * @param density   number density (particles per unit volume)
   * @param sigma   WCA length parameter
   * @param epsilon   WCA energy parameter
   */
  public Simulation(int natoms, double mass, double density, double sigma, double epsilon) {
    // Reduced temperature 0.824 converted with epsilon / kB.
    // NOTE(review): epsilon / kB already carries temperature units, so the extra
    // Units.kelvin factor is harmless only if Units.kelvin == 1 -- confirm against Units.
    this.temperature = 0.824 / (Constants.kB / epsilon) * Units.kelvin; // temperature, K
    System.out.printf("temperature = %.1f K\n", this.temperature);

    // Cubic box sized from the requested number density.
    this.volume = natoms / density;
    this.box_edge_length = Math.pow(volume, (1.0/3.0));
    this.mass = mass;
    this.natoms = natoms;

    this.r0 = Math.pow(2.0, (1.0/6.0)) * sigma;
    this.w = 0.5 * this.r0;

    this.kT = Constants.kB * this.temperature;
    double barrier_height = 7.0 * this.kT;

    // Alternatives that were tried previously:
    //   new WCADimerBiasedInstantaneous(natoms, sigma, epsilon, box_edge_length, barrier_height, 0.0) for sampling
    //   new WCADimerVacuum(sigma, epsilon, barrier_height) for both systems
    this.sampling_system = new WCADimer(natoms, sigma, epsilon, this.box_edge_length, barrier_height);
    this.switching_system = new WCADimer(natoms, sigma, epsilon, this.box_edge_length, barrier_height);

    // Compute characteristic timescale; collision rate and timesteps are expressed through it.
    double tau = Math.sqrt(sigma*sigma * mass / epsilon);
    this.collision_rate = 1.0 / tau;
    System.out.printf("collision_rate = %.3f / ps\n", this.collision_rate * Units.picoseconds);
    this.equilibration_nsteps = 5000;
    this.equilibration_timestep = 0.001 * tau;
    this.switching_timestep = 0.010 * tau;
    System.out.printf("tau = %.3f ps; equilibration timestep = %.3f fs\n", tau / Units.picoseconds, this.equilibration_timestep / Units.femtoseconds);
  }

  /**
   * Compute the kinetic energy, summing 0.5 * mass * v^2 over all 3*natoms components.
   *
   * @param velocities   velocities[3*i+k] is the velocity of atom i dimension k
   * @return the kinetic energy
   */
  private double compute_kinetic_energy(double [] velocities) {
    double kinetic_energy = 0.0;
    for (int idx = 0; idx < 3*this.natoms; idx++) {
      kinetic_energy += 0.5 * this.mass * (velocities[idx]*velocities[idx]);
    }
    return kinetic_energy;
  }

  /**
   * Evaluate the gradient of a system and convert it to accelerations (-gradient / mass).
   *
   * @param system   the potential to evaluate
   * @param positions   current positions
   * @param gradient   scratch array filled by system.computeGradient
   * @param accelerations   output array receiving the accelerations
   * @return the value returned by system.computeGradient (used by callers as the potential energy)
   */
  private double compute_accelerations(Potential system, double [] positions, double [] gradient, double [] accelerations) {
    double potential_energy = system.computeGradient(gradient, positions);
    for (int idx = 0; idx < 3*this.natoms; idx++) {
      accelerations[idx] = - gradient[idx] / this.mass;
    }
    return potential_energy;
  }

  /**
   * Velocity kick: velocities += accelerations * dt for atoms first_atom .. natoms-1.
   */
  private void kick_velocities(double [] velocities, double [] accelerations, double dt, int first_atom) {
    for (int idx = 3*first_atom; idx < 3*this.natoms; idx++) {
      velocities[idx] += accelerations[idx]*dt;
    }
  }

  /**
   * Position drift: positions += velocities * dt for atoms first_atom .. natoms-1.
   */
  private void drift_positions(double [] positions, double [] velocities, double dt, int first_atom) {
    for (int idx = 3*first_atom; idx < 3*this.natoms; idx++) {
      positions[idx] += velocities[idx] * dt;
    }
  }

  /**
   * Draw fresh GHMC coefficients and apply v <- alpha * v + beta to every velocity component.
   */
  private void thermalize_velocities(double temperature, double timestep, double collision_rate, double [] velocities, double [] alpha, double [] beta) {
    ghmcterm(temperature, timestep, collision_rate, alpha, beta);
    for (int idx = 0; idx < 3*this.natoms; idx++) {
      velocities[idx] = velocities[idx]*alpha[idx] + beta[idx];
    }
  }

  /**
   * Move atoms 0 and 1 apart (or together, for negative shift) along a fixed direction;
   * each atom moves by half of the shift.
   *
   * @param positions   positions to modify in place
   * @param nij   direction from atom 0 to atom 1
   * @param shift   total change applied to the separation along nij
   */
  private void shift_dimer(double [] positions, double [] nij, double shift) {
    for (int k = 0; k < 3; k++) {
      positions[3*0+k] -= (shift/2.0) * nij[k];
      positions[3*1+k] += (shift/2.0) * nij[k];
    }
  }

  /**
   * Compute stochastic noise terms for GHMC integration.
   *
   * Fills alpha with the deterministic velocity scaling factor and beta with a fresh
   * Gaussian increment for each of the 3*natoms velocity components. One Gaussian number
   * is drawn per component, in index order.
   *
   * @param temperature   temperature used for the noise magnitude
   * @param timestep   integration timestep
   * @param collision_rate   collision rate
   * @param alpha   output, velocity scaling factors
   * @param beta   output, random velocity increments
   */
  private void ghmcterm(double temperature, double timestep, double collision_rate, double [] alpha, double [] beta) {
    double gamma = collision_rate * this.mass;
    double kT = Constants.kB * temperature;
    double sigma = Math.sqrt(2.0 * kT * gamma);

    // Loop-invariant pieces; the per-component expression below keeps the original
    // left-to-right evaluation order.
    double quarter_step_rate = (timestep/4.0)*collision_rate;
    double scaling = (1.0 - quarter_step_rate) / (1.0 + quarter_step_rate);
    double noise_width = Math.sqrt(timestep/2.0);

    for (int idx = 0; idx < 3*this.natoms; idx++) {
      alpha[idx] = scaling;
      beta[idx] = this.random.nextGaussian() * noise_width*sigma / (1.0 + quarter_step_rate) / this.mass;
    }
  }

  /**
   * Generalized hybrid Monte Carlo (GHMC) for sampling from NVT ensemble.
   *
   * Each step applies a stochastic velocity modification, proposes one velocity Verlet
   * step, accepts or rejects it with a Metropolis test on the change in total energy,
   * and applies a second velocity modification. positions and velocities are updated
   * in place.
   *
   * On rejection the positions, velocities, accelerations and potential energy from
   * before the proposal are all restored, so that the next step starts from a
   * consistent state.
   *
   * NOTE(review): the GHMC algorithm in the references below reverses the momenta when
   * a proposal is rejected; this implementation restores them unreversed. Confirm
   * whether that is intentional.
   *
   * @param system   the potential to sample
   * @param positions   positions[3*i+k] is the position of atom i dimension k
   * @param velocities   velocities[3*i+k] is the velocity of atom i dimension k
   * @param temperature   the temperature
   * @param timestep   the integration timestep
   * @param collision_rate   the collision rate
   * @param nsteps number of steps to take
   * @return the fraction of steps that were accepted by the Monte Carlo scheme (0 if nsteps is 0)
   *
   * REFERENCES
   *
   * T. Lelievre, M. Rousset, and G. Stoltz, "Free Energy Computations:
   * A Mathematical Perspective." World Scientific, 2010. Algorithm 2.11.
   *
   * T. Lelievre, M. Rousset, and G. Stoltz, "Langevin dynamics with
   * constraints and computation of free energy differences." 2010.
   * Eqs. 61-63.
   */
  private double ghmc(Potential system, double [] positions, double [] velocities, double temperature, double timestep, double collision_rate, int nsteps) {
    final int ndof = 3*this.natoms;
    final double kT = Constants.kB * temperature;

    double [] gradient = new double[ndof];
    double [] accelerations = new double[ndof];
    double [] alpha = new double[ndof];
    double [] beta = new double[ndof];

    double [] old_positions = new double[ndof];
    double [] old_velocities = new double[ndof];
    double [] old_accelerations = new double[ndof];

    // Compute initial acceleration and potential energy.
    double potential_energy = compute_accelerations(system, positions, gradient, accelerations);

    // Main dynamics loop.
    int nsteps_accepted = 0;
    for (int step = 0; step < nsteps; step++) {

      // Velocity modification.
      thermalize_velocities(temperature, timestep, collision_rate, velocities, alpha, beta);

      //
      // Beginning of Metropolis step.
      //

      // Store the state the proposal starts from.
      double old_potential_energy = potential_energy;
      double old_total_energy = compute_kinetic_energy(velocities) + potential_energy;
      System.arraycopy(velocities, 0, old_velocities, 0, ndof);
      System.arraycopy(positions, 0, old_positions, 0, ndof);
      System.arraycopy(accelerations, 0, old_accelerations, 0, ndof);

      // Velocity Verlet proposal: half-kick, drift, new forces, half-kick.
      kick_velocities(velocities, accelerations, timestep/2.0, 0);
      drift_positions(positions, velocities, timestep, 0);
      potential_energy = compute_accelerations(system, positions, gradient, accelerations);
      kick_velocities(velocities, accelerations, timestep/2.0, 0);

      double total_energy = compute_kinetic_energy(velocities) + potential_energy;

      // Accept or reject. A random number is only drawn when logP <= 0.
      double logP = - (total_energy - old_total_energy) / kT;
      if ((logP > 0.0) || (this.random.nextDouble() < Math.exp(logP))) {
        // Accept.
        nsteps_accepted++;
      } else {
        // Reject: restore the complete pre-proposal state, including the potential
        // energy, which the next step uses to form its reference total energy.
        potential_energy = old_potential_energy;
        System.arraycopy(old_velocities, 0, velocities, 0, ndof);
        System.arraycopy(old_positions, 0, positions, 0, ndof);
        System.arraycopy(old_accelerations, 0, accelerations, 0, ndof);
      }

      //
      // End of Metropolis step.
      //

      // Velocity modification.
      thermalize_velocities(temperature, timestep, collision_rate, velocities, alpha, beta);
    }

    // Energies of the final state, for reporting.
    double kinetic_energy = compute_kinetic_energy(velocities);
    double total_energy = kinetic_energy + potential_energy;

    // Compute acceptance rate, avoiding 0/0 when no steps were requested.
    double acceptance_rate = (nsteps > 0) ? (double)nsteps_accepted / (double)nsteps : 0.0;

    // DEBUG
    System.out.printf("  GHMC acceptance rate %.3f%% over %d steps | kinetic %.3f kcal/mol | potential %.3f kcal/mol | total %.3f kcal/mol | temperature %.1f K\n", acceptance_rate * 100.0, nsteps, kinetic_energy / Units.kilocalories_per_mole, potential_energy / Units.kilocalories_per_mole, total_energy / Units.kilocalories_per_mole, kinetic_energy / (1.5 * this.natoms * Constants.kB));

    // Return fraction of steps accepted.
    return acceptance_rate;
  }

  /**
   * Simple energy minimization: a fixed number of fixed-length steps along the negative
   * normalized gradient. positions is modified in place.
   *
   * The loop stops early if the gradient norm is zero or NaN, because normalizing by it
   * would turn every coordinate into NaN.
   *
   * @param system   the potential to minimize
   * @param positions   positions[3*i+k] is the position of atom i dimension k
   */
  private void minimize(Potential system, double [] positions) {
    System.out.printf("Minimizing...\n");
    final int ndof = 3*this.natoms;
    final int minimization_steps = 2000;
    final double dx = 0.05 * Units.angstroms; // length of each displacement

    double [] gradient = new double[ndof];
    // Use the system being minimized for the initial gradient as well as for the updates.
    double potential_energy = system.computeGradient(gradient, positions);

    for (int iteration = 0; iteration < minimization_steps; iteration++) {
      double gradient_norm = 0.0;
      for (int i = 0; i < ndof; i++) {
        gradient_norm += gradient[i]*gradient[i];
      }
      gradient_norm = Math.sqrt(gradient_norm);

      // Nothing sensible to do without a usable direction.
      if (!(gradient_norm > 0.0)) {
        break;
      }

      for (int i = 0; i < ndof; i++) {
        positions[i] -= dx * gradient[i] / gradient_norm;
      }
      potential_energy = system.computeGradient(gradient, positions);
    }
    System.out.printf("potential energy %.3f kcal/mol\n", potential_energy / Units.kilocalories_per_mole);
  }

  /**
   * Compute the minimum-image separation between atoms 0 and 1.
   *
   * @param positions   positions[3*i+k] is the position of atom i dimension k
   * @param rij   output (length 3), minimum-image vector from atom 0 to atom 1
   * @return the length of rij
   */
  private double dimer_separation(double [] positions, double [] rij) {
    final double half_box = this.box_edge_length/2.;
    double r2 = 0.0;
    for (int k = 0; k < 3; k++) {
      // Raw interparticle separation along dimension k.
      double dr = positions[3*1+k] - positions[3*0+k];

      // Image into box.
      while (dr < - half_box) {
        dr += this.box_edge_length;
      }
      while (dr > half_box) {
        dr -= this.box_edge_length;
      }

      rij[k] = dr;
      r2 += dr*dr;
    }
    return Math.sqrt(r2);
  }

  /**
   * Perform nonequilibrium switching on system using velocity Verlet dynamics.
   *
   * The dimer (atoms 0 and 1) is displaced along its initial bond direction in nsteps
   * equal increments; after each increment all other atoms take one velocity Verlet
   * step. The dimer atoms are moved only by the prescribed displacement, and their
   * velocities are left untouched. positions and velocities are updated in place.
   *
   * NOTE(review): if the initial dimer separation is zero the bond direction is
   * undefined (0/0) and the resulting positions are NaN; callers are expected to pass
   * a separated dimer.
   *
   * @param system   the system to be switched
   * @param positions   positions[3*i+k] is the position of atom i dimension k
   * @param velocities   velocities[3*i+k] is the velocity of atom i dimension k
   * @param timestep   the timestep for integration
   * @param nsteps   the number of steps over which switching is to take place (switching is instantaneous if zero)
   * @param delta   the distance by which the interatomic distance between particles 0 and 1 is to be switched (positive or negative)
   * @return the Lechner work accumulated during integration (final minus initial total energy)
   */
  private double switching(Potential system, double [] positions, double [] velocities, double timestep, int nsteps, double delta) {
    final int ndof = 3*this.natoms;
    double [] gradient = new double[ndof];
    double [] accelerations = new double[ndof];

    // Compute distance.
    double [] rij = new double[3];
    double r = dimer_separation(positions, rij);
    System.out.printf("initial distance is %.1f A | delta = %+.1f A\n", r / Units.angstroms, delta / Units.angstroms);

    // Compute unit vector from i to j; it stays fixed for the whole switch.
    double [] nij = new double[3];
    for (int k = 0; k < 3; k++) {
      nij[k] = rij[k] / r;
    }

    // Initial accelerations and total energy.
    double potential_energy = compute_accelerations(system, positions, gradient, accelerations);
    double initial_total_energy = compute_kinetic_energy(velocities) + potential_energy;

    // Main dynamics loop.
    for (int step = 0; step < nsteps; step++) {

      // Modify positions of atoms 0 and 1 by this step's share of delta.
      double dr = (delta / (double)nsteps);
      shift_dimer(positions, nij, dr);

      // Forces change because the dimer moved.
      potential_energy = compute_accelerations(system, positions, gradient, accelerations);

      // Velocity Verlet step integrating all atoms but 0 and 1.
      kick_velocities(velocities, accelerations, timestep/2.0, 2);
      drift_positions(positions, velocities, timestep, 2);
      potential_energy = compute_accelerations(system, positions, gradient, accelerations);
      kick_velocities(velocities, accelerations, timestep/2.0, 2);
    }

    // Perform instantaneous move if no steps were taken.
    if (nsteps == 0) {
      shift_dimer(positions, nij, delta);

      // Update potential.
      potential_energy = system.computeGradient(gradient, positions);
    }

    // Compute final total energy.
    double final_total_energy = compute_kinetic_energy(velocities) + potential_energy;

    // Return total work.
    return final_total_energy - initial_total_energy;
  }

  /**
   * Run the simulation: build and minimize an initial configuration, then for each
   * iteration equilibrate with GHMC and attempt dimer switching moves at a range of
   * switching lengths. The move obtained with the longest switching length is the one
   * that is accepted or rejected.
   *
   * For every iteration a line starting with "*" lists the log acceptance probability
   * for each switching length.
   *
   * @param niterations   number of equilibrate-and-switch iterations
   */
  public void run(int niterations) {
    final int ndof = 3*this.natoms;

    // Initialize positions randomly.
    double [] positions = new double[ndof];
    this.random = new Random();
    for (int idx = 0; idx < ndof; idx++) {
      positions[idx] = this.random.nextDouble() * this.box_edge_length;
    }

    // Place particle 0 at the origin and particle 1 at distance r0 along x.
    for (int idx = 0; idx < 3*2; idx++) {
      positions[idx] = 0.0;
    }
    positions[3*1+0] = r0;

    // Minimize energy.
    minimize(sampling_system, positions);

    // Velocities start at zero (Java zero-initializes new arrays); GHMC thermalizes them.
    double [] velocities = new double[ndof];

    // A separate equilibration phase (5 x ghmc on the sampling system) used to run here;
    // equilibration now happens at the start of every iteration instead.

    double [] switching_positions = new double[ndof];
    double [] switching_velocities = new double[ndof];

    // Switching lengths: 0 (instantaneous MC), then 1, 2, 4, ... (NCMC).
    int max_nsteps_index = 13; // maximum number of steps is 1 << (max_nsteps_index-2)
    //int max_nsteps_index = 1; // DEBUG: Instantaneous switching only
    int [] nsteps_array = new int[max_nsteps_index];
    for (int nsteps_index = 1; nsteps_index < max_nsteps_index; nsteps_index++) {
      nsteps_array[nsteps_index] = (1 << (nsteps_index-1));
    }

    int naccepted = 0; // number of NCMC moves accepted
    for (int iteration = 0; iteration < niterations; iteration++) {
      System.out.printf("iteration %d / %d\n", iteration+1, niterations);
      double start_time = (double)System.nanoTime() * 1e-9;

      // Equilibrate.
      ghmc(this.sampling_system, positions, velocities, this.temperature, this.equilibration_timestep, this.collision_rate, this.equilibration_nsteps);

      // Report on difference with switching system.
      double du = switching_system.computePotential(positions) - sampling_system.computePotential(positions);
      System.out.printf("  switching system potential difference = %.3f kT\n", du / this.kT);

      // Choose switching perturbation: extend a short dimer by r0, shorten a
      // moderately extended one by r0, otherwise leave the separation alone.
      double [] rij = new double[3];
      double initial_distance = dimer_separation(positions, rij);
      double delta;
      if (initial_distance < 1.5*r0) {
        delta = + r0;
      } else if ((initial_distance > 1.5*r0) && (initial_distance < 3.0*r0)) {
        delta = - r0;
      } else {
        delta = 0.0;
      }
      double final_distance = initial_distance + delta;
      System.out.printf("Proposing %.1f A -> %.1f A (barrier at %.1f A)", initial_distance / Units.angstroms, final_distance / Units.angstroms, (r0+w)/Units.angstroms);

      double [] saved_work = new double[max_nsteps_index]; // saved work, in kT
      double [] saved_log_Paccept = new double[max_nsteps_index]; // saved log acceptance probability

      for (int nsteps_index = 0; nsteps_index < max_nsteps_index; nsteps_index++) {
        int nsteps = nsteps_array[nsteps_index]; // number of switching steps
        System.out.printf(" Switching in %d steps...\n", nsteps);

        // Every trial starts from the same equilibrated state.
        System.arraycopy(velocities, 0, switching_velocities, 0, ndof);
        System.arraycopy(positions, 0, switching_positions, 0, ndof);

        // Switch.
        double work = switching(this.switching_system, switching_positions, switching_velocities, this.switching_timestep, nsteps, delta);
        // The log of the squared distance ratio is added to the work term.
        double log_Paccept = -work/this.kT + 2.0*Math.log(final_distance/initial_distance);

        // Store.
        saved_work[nsteps_index] = work / this.kT;
        saved_log_Paccept[nsteps_index] = log_Paccept;

        System.out.printf("  work = %.3f kT, Paccept = %.3e\n", work / this.kT, Math.exp(log_Paccept) );
      }

      // Accept or reject based on slowest NCMC switching trial; the switching arrays
      // still hold that trial's end state because it ran last.
      double log_Paccept = saved_log_Paccept[max_nsteps_index-1];
      if ((log_Paccept > 0.0) || (random.nextDouble() < Math.exp(log_Paccept))) {
        // Accept.
        System.arraycopy(switching_velocities, 0, velocities, 0, ndof);
        System.arraycopy(switching_positions, 0, positions, 0, ndof);
        naccepted++;
        System.out.printf("accepted\n");
      } else {
        System.out.printf("rejected\n");
      }

      double end_time = (double)System.nanoTime() * 1e-9;
      double elapsed_time = (end_time - start_time);
      System.out.printf("%.3f s elapsed\n", elapsed_time);

      // Emit the per-switching-length log acceptance probabilities on a "*" line.
      System.out.printf("*");
      for (int nsteps_index = 0; nsteps_index < max_nsteps_index; nsteps_index++) {
        System.out.printf("%16.6e", saved_log_Paccept[nsteps_index]);
      }
      System.out.printf("\n\n");
    }
  }

  /**
   * Entry point: builds a 216-particle system with fixed parameters and runs
   * 10000 iterations.
   */
  public static void main(String [] args) {

    // PARAMETERS
    int natoms = 216;
    double mass = 39.9 * Units.amu; // amu
    double sigma = 3.4 * Units.angstroms; // angstroms
    double epsilon = 120.0 * Units.kelvin * Constants.kB; // J
    double density = 0.96 / Math.pow(sigma, 3); // number density
    int niterations = 10000; // number of iterations

    // Create a simulation object.
    System.out.printf("Initializing simulation...\n");
    Simulation simulation = new Simulation(natoms, mass, density, sigma, epsilon);
    System.out.printf("Initialized.\n");

    // Run simulation.
    System.out.printf("Running simulation...\n");
    simulation.run(niterations);
    System.out.printf("Simulation complete.\n");
  }

}