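r"""
Output the computed Stark shift during a simulation.

DESCRIPTION

This version uses the difference between the forces of the complete system and those of a modified copy of the system
with zeroed charges to obtain the electrostatic force on each atom in the specified atom group; dividing each force by
the atom's partial charge gives the electric field vector at that atom.

The Stark vibrational shift is computed from the dot product of this field with the normalized dipole moment of the atom group:

\Delta \tilde{\nu} = - \alpha \, \mathbf{E} \cdot \hat{\mathbf{n}}

where \Delta \tilde{\nu} is the shift in wavenumbers, \alpha is the linear Stark tuning rate (which must be specified by the user,
in units compatible with cm^{-1} / (MV/cm)), \mathbf{E} is the computed (mass-weighted) electric field impinging on the atom group,
and \hat{\mathbf{n}} = \mathbf{d} / |\mathbf{d}| is the normalized dipole moment of the group, with \mathbf{d} the dipole moment
(computed about the group's center of mass).

NOTE(review): StarkShiftReporter.report() computes alpha * (E . n) without the leading minus sign; confirm which sign convention is intended.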

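Partial charges for the atom group are currently read only from the first NonbondedForce in the System. When building the
zero-charge copy of the System, charges are also zeroed in GBSAOBCForce, in the per-particle 'charge' parameter of
CustomGBForce and CustomNonbondedForce, and in the per-bond 'chargeProd' parameter of CustomBondForce; all other forces are
left unmodified.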

Note that the corresponding electric field vectors include intramolecular contributions as well.

AUTHORS

original version by Kyle Beauchamp
modifications by John D. Chodera 2012-06-23

TODO

* Test HDF5 writing.
* Add ability to exclude certain molecules from contributing to the computed Stark field?  For example, waters hydrogen bonding to the Stark probe might be excluded.

"""
# Module metadata: authorship and release version.
__author__, __version__ = ("Kyle Beauchamp and John D. Chodera", "1.0")

import simtk.openmm as mm
import simtk.unit as units
import numpy
    
class StarkShiftReporter(object):
    """
    Report the Stark shift, electric field, dipole moment, and center-of-mass position of a specified atom group.

    At each report, the electrostatic force on each group atom is taken as the difference between the forces of the
    simulated system and those of a copy of the system with all charges zeroed.  Dividing each force by the atom's
    partial charge gives the electric field at that atom.  The mass-weighted average of these fields is dotted with
    the unit vector along the group dipole moment (computed about the group center of mass) and multiplied by the
    user-supplied linear Stark tuning rate alpha, i.e. shift = alpha * (E . n).

    NOTE(review): the module docstring states this relation with a leading minus sign, which is not applied here;
    confirm the intended sign convention.

    NOTES

    This reporter requires transfer of coordinates and forces back and forth to the GPU each report interval.
    This can significantly slow down simulations at high reporting frequencies.

    EXAMPLES

    Run a simulation writing computed Stark shifts to a NetCDF file.
    
    >>> # Load a test system using the OpenMM app.
    >>> import simtk.openmm as openmm
    >>> import simtk.openmm.app as app
    >>> inpcrd = app.AmberInpcrdFile('bosutinib.crd')
    >>> prmtop = app.AmberPrmtopFile('bosutinib.prmtop')
    >>> system = prmtop.createSystem(nonbondedMethod=app.CutoffPeriodic, constraints=app.HBonds)
    >>> positions = inpcrd.getPositions()
    >>> topology = prmtop.topology
    >>> # Create a Simulation object using the system, positions, and topology.
    >>> temperature = 300.0 * units.kelvin
    >>> collision_rate = 9.1 / units.picosecond
    >>> timestep = 2.0 * units.femtosecond
    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> simulation = app.Simulation(topology, system, integrator)
    >>> simulation.context.setPositions(positions)
    >>> # Add a Stark shift reporter.
    >>> atom_indices = [3-1, 19-1] # atoms defining Stark vibrational group;
    >>>                            # be sure to use OpenMM System numbering (starts at 0) rather than PDB numbering (starts at 1)
    >>> alpha = 0.87 * (units.centimeters**-1) / (units.mega*units.volts/units.centimeter) # linear Stark tuning rate of bosutinib
    >>> simulation.reporters.append( StarkShiftReporter(system, atom_indices, alpha, interval=10, filename='stark.nc', format='netcdf') )
    >>> # Run the simulation.
    >>> simulation.step(100)
    >>> # Clean up.
    >>> del simulation, integrator

    Write to a text file instead.
    
    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> simulation = app.Simulation(topology, system, integrator)
    >>> simulation.context.setPositions(positions)
    >>> # Add a Stark shift reporter.
    >>> simulation.reporters.append( StarkShiftReporter(system, atom_indices, alpha, interval=10, filename='stark.txt', format='text') )
    >>> # Run the simulation.
    >>> simulation.step(100)
    >>> # Clean up.
    >>> del simulation, integrator

    Write to an HDF5 file.

    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> simulation = app.Simulation(topology, system, integrator)
    >>> simulation.context.setPositions(positions)
    >>> # Add a Stark shift reporter.
    >>> simulation.reporters.append( StarkShiftReporter(system, atom_indices, alpha, interval=10, filename='stark.h5', format='hdf5') )
    >>> # Run the simulation.
    >>> simulation.step(100)
    >>> # Clean up.
    >>> del simulation, integrator

    """
    
    def __init__(self, system, atom_indices, alpha, interval=500, filename='stark.nc', format='netcdf', debug=False, append=False):
        """
        Create a StarkShiftReporter.
    
        ARGUMENTS
         - system (simtk.openmm.System) - the System object being simulated
         - atom_indices (list of int) - the atom indices defining the vibrational group for which the Stark shift is to be computed
         - alpha (simtk.unit.Quantity in units compatible with cm^{-1} / (MV/cm)) - linear Stark tuning rate

        OPTIONAL ARGUMENTS
         - interval (int) - the interval (in time steps) at which to write frames (default: 500)
         - filename (string) - the file to write coordinates and fields to (default: stark.nc)
         - format (string) - the format for writing data ['text', 'netcdf', 'hdf5'] (default: 'netcdf')
         - debug (boolean) - if True, debug output will be printed (default: False)
         - append (boolean) - if True, will append if file exists (default: False)

        RAISES
         - Exception if fewer than two atom indices are given, if the System contains no NonbondedForce,
           if any group atom has zero partial charge (its electric field cannot be recovered from its force),
           or if 'format' is not one of the supported formats.

        """

        # Ensure at least two atom indices are specified (otherwise, there is no dipole moment).
        if (len(atom_indices) < 2):
            raise Exception("At least two atom indices must be used to define Stark vibrational group.  User specified: %s" % str(atom_indices))

        # Store arguments.
        self._name = 'StarkShiftReporter'
        self._atom_indices = list(atom_indices) # plain list, so fancy indexing of per-atom arrays behaves the same for any input sequence
        self._natoms = len(self._atom_indices) # number of atoms in Stark group
        self._alpha = alpha # linear Stark tuning rate
        self._reportInterval = interval # interval at which Efield is computed
        self._system = system
        self._filename = filename
        self._format = format
        self._debug = debug

        # Set defaults for unit reporting.
        self.time_unit = units.picoseconds
        self.position_unit = units.nanometer
        self.dipole_unit = units.elementary_charge * units.nanometer
        self.efield_unit = units.mega*units.volt/units.centimeter
        self.shift_unit = units.dimensionless / units.centimeter

        # Extract partial charges for specified atom indices in System.
        self._charges = self._extractPartialCharges(system, self._atom_indices) # self._charges[i,j] is the charge for atom atom_indices[i] replicated for j in range(0,3)

        # The field on atom i is recovered as F_i / q_i in report(), which is undefined for uncharged atoms.
        charge_values = self._charges.value_in_unit(units.elementary_charge)[:,0]
        zero_charge_atoms = [int(atom_index) for (atom_index, charge) in zip(self._atom_indices, charge_values) if charge == 0.0]
        if zero_charge_atoms:
            raise Exception("%s: Atoms %s of the Stark vibrational group have zero partial charge; the electric field on them cannot be computed from forces." % (self._name, str(zero_charge_atoms)))

        # Extract masses for specified atom indices.
        self._masses = self._extractMasses(system, self._atom_indices) # self._masses[i,j] is the mass for atom atom_indices[i] replicated for j in range(0,3)
        
        # Set file mode: append only if requested and the file already exists.
        import os.path
        file_mode = 'a' if (append and os.path.exists(filename)) else 'w'

        # Open file for writing.
        if format == 'text':
            if self._debug: print("%s: Opening text output file '%s'..." % (self._name, filename))
            self._txtfile = open(filename, file_mode)
            if file_mode == 'w':
                self._txtfile.write("#Timestep [%s] x y z [%s] Dx Dy Dz [%s] Ex Ey Ez [%s] shift [%s]\n" % (str(self.time_unit), str(self.position_unit), str(self.dipole_unit), str(self.efield_unit), str(self.shift_unit)))
            # TODO: When appending, determine what sample we're resuming from.

        elif format == 'netcdf':
            if self._debug: print("%s: Opening NetCDF output file '%s'..." % (self._name, filename))

            import netCDF4 as netcdf
            ncfile = netcdf.Dataset(filename, file_mode, format='NETCDF4')
        
            if file_mode == 'w':
                if self._debug: print("Initializing NetCDF file...")

                # Create dimensions.
                ncfile.createDimension('sample', 0) # unlimited number of samples
                ncfile.createDimension('spatial', 3) # number of spatial dimensions
                ncfile.createDimension('atom', self._natoms) # number of atoms in the Stark group
                
                # Create variables.
                ncvar_indices = ncfile.createVariable('atom_indices', 'i', ('atom',))
                ncvar_time = ncfile.createVariable('time', 'f', ('sample', ))
                ncvar_position = ncfile.createVariable('position', 'f', ('sample','spatial'))
                ncvar_dipole = ncfile.createVariable('dipole', 'f', ('sample','spatial'))
                ncvar_efield = ncfile.createVariable('efield', 'f', ('sample','spatial'))
                ncvar_shift = ncfile.createVariable('shift', 'f', ('sample',))

                # Define units for variables.
                setattr(ncvar_indices, 'units', 'none')
                setattr(ncvar_time, 'units', str(self.time_unit))
                setattr(ncvar_position, 'units', str(self.position_unit))
                setattr(ncvar_dipole, 'units', str(self.dipole_unit))
                setattr(ncvar_efield, 'units', str(self.efield_unit))
                setattr(ncvar_shift, 'units', str(self.shift_unit))

                # TODO: Add machine-readable units using repr(units) instead of str(units).

                # Define long (human-readable) names for variables.
                setattr(ncvar_indices, "long_name", "atom_indices[index] is the System atom index of atom 'index' within the Stark vibrational group.")
                setattr(ncvar_time, "long_name", "time[sample] is simulation time of sample 'sample'.")
                setattr(ncvar_position, "long_name", "position[sample,spatial] is the spatial component 'spatial'  of the center of mass of the Stark vibrational group at sample 'sample'.")
                setattr(ncvar_dipole, "long_name", "dipole[sample,spatial] is the spatial component 'spatial' of the dipole moment of the Stark group at sample 'sample'.")
                setattr(ncvar_efield, "long_name", "efield[sample,spatial] is the spatial component 'spatial' of the mass-averaged electric field vector impinging on the Stark group for sample 'sample'.")
                setattr(ncvar_shift, "long_name", "shift[sample] is the computed Stark shift for sample 'sample'.")

                # Fill in atom indices.
                ncfile.variables['atom_indices'][:] = numpy.array(self._atom_indices[:])

                # Sync.
                ncfile.sync()

            self._ncfile = ncfile

        elif format == 'hdf5':
            if self._debug: print("%s: Opening pytables (HDF5) output file '%s'..." % (self._name, filename))
            import tables
            self._h5file = tables.File(filename, file_mode)
            if (file_mode == 'a') and ('Data' in self._h5file.root):
                # Resuming: append new rows to the existing array instead of trying to create it a second time.
                self._data = self._h5file.root.Data
            else:
                compression_filter = tables.Filters(complevel=9, complib='blosc', shuffle=True) # enable compression filter
                self._data = self._h5file.createEArray("/", "Data", tables.Float32Atom(), (0,11), filters=compression_filter)

        else:
            raise Exception("Format '%s' not supported.  Choose one of ['text', 'netcdf', 'hdf5']." % format)

        # Create modified version of system with all charges zeroed.
        self._modified_system = self._createZeroChargeSystem(system)

        # Create Context to use to evaluate forces.
        timestep = 1.0 * units.femtoseconds
        self._integrator = mm.VerletIntegrator(timestep)
        self._context = mm.Context(self._modified_system, self._integrator)

        return

    def _extractMasses(self, system, atom_indices=None):
        """
        Extract particle masses of specified group.

        ARGUMENTS
          system (simtk.openmm.System) - OpenMM system to extract masses from
        
        OPTIONAL ARGUMENTS
          atom_indices (list of int) - if specified, only masses for these atom indices will be extracted (default: None)

        RETURNS
          masses (simtk.unit.Quantity of Nx3 numpy array compatible with simtk.unit.amu) - masses[i,j] is the mass of particle atom_indices[i] for all j in range(0,3)

        """
        if atom_indices is None:
            atom_indices = range(0, system.getNumParticles())
        masses = units.Quantity(numpy.zeros([len(atom_indices),3]), units.amu)
        for (store_index, particle_index) in enumerate(atom_indices):
            masses[store_index,:] = system.getParticleMass(particle_index)

        return masses

    def _extractPartialCharges(self, system, atom_indices=None):
        """
        Extract partial charges for specified atom indices.

        ARGUMENTS
          system (simtk.openmm.System) - OpenMM system to extract charges from
        
        OPTIONAL ARGUMENTS
          atom_indices (list of int) - if specified, only charges for these atom indices will be extracted (default: None)

        RETURNS
          charges (simtk.unit.Quantity of Nx3 numpy array compatible with simtk.unit.elementary_charge) - charges[i,j] is the charge of particle atom_indices[i] for all j in range(0,3)

        RAISES
          Exception if the System contains no NonbondedForce.

        NOTES

        Currently, charges are only extracted from the first NonbondedForce object.

        """
        if atom_indices is None:
            atom_indices = range(0, system.getNumParticles())
        charges = units.Quantity(numpy.zeros([len(atom_indices),3], numpy.float64), units.elementary_charge)

        # Find the first NonbondedForce.
        nonbonded_force = None
        for force_index in range(system.getNumForces()):
            force = system.getForce(force_index)
            if isinstance(force, mm.NonbondedForce):
                nonbonded_force = force
                break
        if nonbonded_force is None:
            raise Exception("%s: No NonbondedForce force object found in system." % (self._name))

        # Store partial charges.
        for (store_index, particle_index) in enumerate(atom_indices):
            [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(particle_index)
            charges[store_index,:] = charge

        return charges
        
    def _createZeroChargeSystem(self, reference_system):
        """
        Create a copy of the reference system containing no electrostatic components.

        ARGUMENTS

        reference_system (simtk.openmm.System) - the reference system (not modified)

        RETURNS

        system (simtk.openmm.System) - a deep copy of reference_system with all recognized charges set to zero

        NOTES

        NonbondedForce particle charges and exception charge products, and GBSAOBCForce charges, are zeroed.
        CustomGBForce and CustomNonbondedForce are searched for a per-particle 'charge' parameter, and CustomBondForce
        for a per-bond 'chargeProd' parameter; those parameters are zeroed.  All other forces are left unmodified.
        
        """

        # Copy the System object.
        import copy
        system = copy.deepcopy(reference_system)

        # Set charges to zero.
        for force_index in range(system.getNumForces()):
            force = system.getForce(force_index)

            if isinstance(force, mm.NonbondedForce):
                # Zero particle charges and exception charge products.
                for particle_index in range(force.getNumParticles()):
                    [charge, sigma, epsilon] = force.getParticleParameters(particle_index)
                    force.setParticleParameters(particle_index, charge*0.0, sigma, epsilon)
                for exception_index in range(force.getNumExceptions()):
                    [iatom, jatom, chargeprod, sigma, epsilon] = force.getExceptionParameters(exception_index)
                    force.setExceptionParameters(exception_index, iatom, jatom, chargeprod*0.0, sigma, epsilon)

            elif isinstance(force, mm.GBSAOBCForce):
                # Zero charges.
                for particle_index in range(force.getNumParticles()):
                    [charge, radius, scaling_factor] = force.getParticleParameters(particle_index)
                    force.setParticleParameters(particle_index, charge*0.0, radius, scaling_factor)

            elif isinstance(force, (mm.CustomGBForce, mm.CustomNonbondedForce)):
                # Zero per-particle parameters named 'charge'.
                parameter_names = [force.getPerParticleParameterName(index) for index in range(force.getNumPerParticleParameters())]
                if 'charge' in parameter_names:
                    parameter_index = parameter_names.index('charge')
                    for particle_index in range(force.getNumParticles()):
                        # Parameters may come back as an immutable tuple; copy into a list before modifying.
                        parameters = list(force.getParticleParameters(particle_index))
                        parameters[parameter_index] *= 0.0
                        force.setParticleParameters(particle_index, parameters)
                    
            elif isinstance(force, mm.CustomBondForce):
                # Zero per-bond parameters named 'chargeProd' (CustomBondForce has per-bond, not per-particle, parameters).
                parameter_names = [force.getPerBondParameterName(index) for index in range(force.getNumPerBondParameters())]
                if 'chargeProd' in parameter_names:
                    parameter_index = parameter_names.index('chargeProd')
                    for bond_index in range(force.getNumBonds()):
                        # getBondParameters returns (particle1, particle2, parameters); all three are needed to set them back.
                        [particle1, particle2, parameters] = force.getBondParameters(bond_index)
                        parameters = list(parameters)
                        parameters[parameter_index] *= 0.0
                        force.setBondParameters(bond_index, particle1, particle2, parameters)

            # Any other force is left unmodified.
        
        return system

    @staticmethod
    def _weightedAverage(values, weights):
        """
        Return the weighted average over atoms (rows) of a per-atom vector array.

        ARGUMENTS
          values (numpy array of shape (natoms, 3)) - per-atom vectors
          weights (numpy array of shape (natoms, 3) or (natoms, 1)) - per-atom weights, e.g. masses replicated along columns

        RETURNS
          average (numpy array of shape (3,)) - sum_i w_i v_i / sum_i w_i, computed independently for each column

        """
        values = numpy.asarray(values, numpy.float64)
        weights = numpy.asarray(weights, numpy.float64)
        # Explicit axis=0: reduce over atoms only, keeping the spatial components separate.
        return (values * weights).sum(axis=0) / weights.sum(axis=0)

    @staticmethod
    def _dipoleMoment(positions, charges, origin):
        """
        Return the dipole moment sum_i q_i (r_i - origin) of a set of point charges.

        ARGUMENTS
          positions (numpy array of shape (natoms, 3)) - atom positions
          charges (numpy array of shape (natoms, 3) or (natoms, 1)) - partial charges replicated along columns
          origin (numpy array of shape (3,)) - reference point about which the dipole is computed

        RETURNS
          dipole (numpy array of shape (3,)) - in units of (position units) * (charge units)

        """
        positions = numpy.asarray(positions, numpy.float64)
        charges = numpy.asarray(charges, numpy.float64)
        return ((positions - origin) * charges).sum(axis=0)
    
    def describeNextReport(self, simulation):
        """
        Get information about the next report this object will generate.
        
        Parameters:
         - simulation (Simulation) The Simulation to generate a report for
        
         Returns: A five element tuple.  The first element is the number of steps until the
         next report.  The remaining elements specify whether that report will require
         positions, velocities, forces, and energies respectively.
         """
        steps = self._reportInterval - simulation.currentStep%self._reportInterval
        return (steps, True, False, True, False)
    
    def report(self, simulation, state):
        """
        Generate a report.
        
        Parameters:
         - simulation (Simulation) The Simulation to generate a report for
         - state (State) The current state of the simulation; must contain positions and forces

        TODO:
        - only perform computation for requested atoms

        """

        if self._debug: print("%s: Generating report..." % (self._name))
            
        # Get current simulation time.
        current_time = simulation.integrator.getStepSize() * simulation.currentStep
        if self._debug: print("current_time = %s" % str(current_time))
        
        # Get positions.
        positions = state.getPositions(asNumpy=True)

        # Get forces on selected atoms
        if self._debug: print("%s: Computing original forces..." % (self._name))
        f0 = state.getForces(asNumpy=True)[self._atom_indices,:]
        if self._debug: print(f0)

        # Compute forces on selected atoms for charge-modified system.
        # The box is copied too, so a periodic cell changed during the simulation (e.g. by a barostat) is also used here.
        if self._debug: print("%s: Computing zero-charge forces..." % (self._name))
        self._context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
        self._context.setPositions(positions)
        modified_state = self._context.getState(getForces=True)
        f1 = modified_state.getForces(asNumpy=True)[self._atom_indices,:]
        if self._debug: print(f1)
        
        # Compute electric field on specified atoms from forces and charges: E_i = F_i / q_i, with the per-mole force
        # converted to a per-particle force by dividing by Avogadro's number.
        if self._debug: print("%s: Computing electric field vectors..." % (self._name))
        f_elec = f0 - f1 # electrostatic forces on selected atoms
        if self._debug: print(f_elec)
        charges = self._charges
        efield = units.Quantity( (f_elec / f_elec.unit) / (charges / charges.unit), (f_elec.unit/charges.unit)) / units.AVOGADRO_CONSTANT_NA
        efield = efield.in_units_of(self.efield_unit)
        if self._debug: print(efield)

        # Strip units so every reduction over atoms below is an explicit numpy axis=0 sum.
        masses = self._masses.value_in_unit(units.amu)
        charge_values = self._charges.value_in_unit(units.elementary_charge)
        atom_positions = positions[self._atom_indices,:].value_in_unit(self.position_unit)
        atom_efields = efield.value_in_unit(self.efield_unit)

        # Compute Stark group center of mass position.
        center_of_mass = self._weightedAverage(atom_positions, masses)
        group_position = units.Quantity(center_of_mass, self.position_unit)

        # Compute group dipole moment about the center of mass.
        # Note that the dipole moment is independent of reference point only for net neutral groups; we choose the center of mass to eliminate this ambiguity.
        dipole = self._dipoleMoment(atom_positions, charge_values, center_of_mass)
        group_dipole = units.Quantity(dipole, self.position_unit * units.elementary_charge).in_units_of(self.dipole_unit)

        # Compute (dimensionless) unit vector aligned with group dipole.
        dipole_values = group_dipole.value_in_unit(self.dipole_unit)
        group_dipole_unit_vector = dipole_values / numpy.sqrt(numpy.dot(dipole_values, dipole_values))

        # Compute mass-weighted electric field vector.
        group_efield = units.Quantity(self._weightedAverage(atom_efields, masses), self.efield_unit)

        # Compute the effective linear Stark shift by dotting the effective electric field impinging on the group with the dipole unit vector.
        projected_efield = units.Quantity(numpy.dot(group_efield.value_in_unit(self.efield_unit), group_dipole_unit_vector), self.efield_unit)
        stark_shift = (self._alpha * projected_efield).in_units_of(self.shift_unit)

        # Get current sample counter from simulation step counter.
        # This will automatically back up if we've resumed a simulation.
        sample = simulation.currentStep // self._reportInterval - 1

        # Write snapshot data to file.
        if self._format == 'text':
            output_string="%e %e %e %e %e %e %e %e %e %e %e\n"%(current_time / self.time_unit, 
                                                                group_position[0]/self.position_unit, group_position[1]/self.position_unit, group_position[2]/self.position_unit,
                                                                group_dipole[0]/self.dipole_unit, group_dipole[1]/self.dipole_unit, group_dipole[2]/self.dipole_unit, 
                                                                group_efield[0]/self.efield_unit, group_efield[1]/self.efield_unit, group_efield[2]/self.efield_unit, 
                                                                stark_shift/self.shift_unit)
            self._txtfile.write(output_string)
            self._txtfile.flush()
            
        elif self._format == 'netcdf':
            self._ncfile.variables['time'][sample] = current_time / self.time_unit
            self._ncfile.variables['position'][sample,:] = group_position[:] / self.position_unit
            self._ncfile.variables['dipole'][sample,:] = group_dipole[:] / self.dipole_unit
            self._ncfile.variables['efield'][sample,:] = group_efield[:] / self.efield_unit
            self._ncfile.variables['shift'][sample] = stark_shift / self.shift_unit
            self._ncfile.sync()
            
        elif self._format == 'hdf5':
            # Build data line for HDF5 file: time, position (3), dipole (3), efield (3), shift.
            dataline = list()
            dataline.append(current_time/self.time_unit)
            for k in range(3): dataline.append(group_position[k]/self.position_unit)
            for k in range(3): dataline.append(group_dipole[k]/self.dipole_unit)
            for k in range(3): dataline.append(group_efield[k]/self.efield_unit)
            dataline.append(stark_shift/self.shift_unit)
            # Write dataline to HDF5 table.
            self._data.append([dataline])
            self._h5file.flush()
        
        else:
            raise Exception('self._format has changed to an invalid format.')

    def __del__(self):
        """
        Clean up safely.

        Closes any output file that was opened and releases the Context, Integrator, and System references.
        Every attribute is looked up defensively, so this is safe when __init__ raised part way through
        (e.g. for an unsupported format) and when __del__ runs more than once.

        """

        # Close any open files; drop each handle afterwards so a repeated call does not close it twice.
        for handle_name in ('_txtfile', '_ncfile', '_h5file'):
            handle = getattr(self, handle_name, None)
            if handle is not None:
                handle.close()
                delattr(self, handle_name)

        # Clean up context before the integrator it was created with, then the System reference.
        for attribute_name in ('_context', '_integrator', '_system'):
            if hasattr(self, attribute_name):
                delattr(self, attribute_name)


if __name__ == '__main__':
    # When executed as a script, run the doctest examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()
