"""
Output the electric field vectors and positions of selected atoms during an OpenMM simulation.

DESCRIPTION

This version uses the difference between the forces of the complete system and a modified system
with zeroed charges to compute electric field vectors on each atom.  Only NonbondedForce is currently
supported.

Note that the corresponding electric field vectors include intramolecular contributions as well.

AUTHORS

original version by Kyle Beauchamp
modifications by John D. Chodera 2012-06-23

"""
# Module metadata: credited authors and the version string of this reporter module.
__author__ = "Kyle Beauchamp and John D. Chodera"
__version__ = "1.0"

import simtk.openmm as mm
import simtk.unit as units
import numpy
    
class EFieldReporter(object):
    """
    Report the electric field and position of specified atoms.

    The field on each atom is estimated as (F_full - F_zero_charge) / q, where F_full is the
    force reported by the simulation, F_zero_charge is the force evaluated on a copy of the
    System whose charges have been zeroed, and q is the partial charge of the atom.

    NOTES

    This reporter requires transfer of coordinates and forces back and forth to the GPU each report interval.
    This can significantly slow down simulations at high reporting frequencies.

    Atoms with zero partial charge have no defined field in this scheme; NaN is reported for them.

    The 'text' and 'hdf5' formats store one row per selected atom per sample (atoms in the order
    given by atom_indices) with seven columns: time [ps], x y z [nm], Ex Ey Ez [MV/cm].

    """

    # Formats accepted by the 'format' constructor argument.
    _SUPPORTED_FORMATS = ('text', 'netcdf', 'hdf5')

    def __init__(self, system, interval=500, filename='efield.nc', format='netcdf', atom_indices='all', debug=False):
        """
        Create an EFieldReporter.

        ARGUMENTS
         - system (simtk.openmm.System) - the System object being simulated
         - interval (int) - the interval (in time steps) at which to write frames (default: 500)
         - filename (string) - the file to write coordinates and fields to (default: efield.nc)
         - format (string) - the format for writing data ['text', 'netcdf', 'hdf5'] (default: 'netcdf')
         - atom_indices (list of int) - indices of atoms to compute E field for, or all atoms if set to 'all' (default: 'all')
         - debug (boolean) - if True, debug output will be printed

        RAISES
         - ValueError if interval is not positive
         - Exception if format is not supported or the System has no NonbondedForce

        """

        # A non-positive interval would otherwise fail later inside describeNextReport.
        if interval < 1:
            raise ValueError("interval must be a positive number of time steps (got %r)." % (interval,))

        # Store arguments.
        self._name = 'EFieldReporter'
        self._reportInterval = interval # interval at which Efield is computed
        self._system = system
        self._filename = filename
        self._format = format
        self._atom_indices = self._expandAtomIndices(atom_indices, system.getNumParticles())
        self._natoms = len(self._atom_indices)
        self._debug = debug
        self._efield_units = (units.mega*units.volts) / units.centimeters # units to write efield in

        # Extract partial charges.
        self._charges = self._extractPartialCharges(system)

        # Open file for writing.
        if format == 'text':
            self._openTextFile(filename)
        elif format == 'netcdf':
            self._openNetCDFFile(filename)
        elif format == 'hdf5':
            self._openHDF5File(filename)
        else:
            raise Exception("Format '%s' not supported.  Choose one of %s." % (format, list(self._SUPPORTED_FORMATS)))

        # Create modified version of system with all charges zeroed.
        self._modified_system = self._createZeroChargeSystem(system)

        # Create Context to use to evaluate forces.
        timestep = 1.0 * units.femtoseconds
        self._integrator = mm.VerletIntegrator(timestep)
        self._context = mm.Context(self._modified_system, self._integrator)

        # Zero sample counter.
        self._sample = 0

    def _debugPrint(self, message):
        """
        Print 'message' if debug output was requested.

        A single-argument print call behaves identically on Python 2 and Python 3.

        """
        if self._debug:
            print(message)

    @staticmethod
    def _expandAtomIndices(atom_indices, nparticles):
        """
        Normalize the atom selection to a list of int.

        ARGUMENTS
         - atom_indices - the string 'all', or any iterable of integer atom indices (list, numpy array, ...)
         - nparticles (int) - number of particles in the System, used to expand 'all'

        RETURNS
         - list of int

        """
        try:
            string_types = basestring # Python 2: matches both str and unicode
        except NameError:
            string_types = str
        # Only compare against 'all' for real strings; comparing a numpy array to a
        # string is element-wise and cannot be used as a truth value.
        if isinstance(atom_indices, string_types) and atom_indices == 'all':
            return list(range(nparticles))
        return [int(index) for index in atom_indices]

    @staticmethod
    def _divideByCharge(force_values, charge_values):
        """
        Divide unitless force components by unitless charges.

        ARGUMENTS
         - force_values (array-like) - force components
         - charge_values (array-like, broadcastable to force_values) - partial charges

        RETURNS
         - numpy array of force/charge, with NaN wherever the charge is exactly zero

        """
        force_values = numpy.asarray(force_values, dtype=numpy.float64)
        charge_values = numpy.asarray(charge_values, dtype=numpy.float64)
        # Suppress the warnings for the zero-charge entries; they are overwritten below.
        with numpy.errstate(divide='ignore', invalid='ignore'):
            ratio = force_values / charge_values
        return numpy.where(charge_values == 0.0, numpy.nan, ratio)

    @staticmethod
    def _buildRows(time_ps, positions_nm, efield_values):
        """
        Assemble the per-atom rows shared by the 'text' and 'hdf5' formats.

        ARGUMENTS
         - time_ps (float) - sample time
         - positions_nm (natoms x 3 array-like) - unitless positions
         - efield_values (natoms x 3 array-like) - unitless field vectors

        RETURNS
         - natoms x 7 numpy array with columns: time, x, y, z, Ex, Ey, Ez

        """
        positions_nm = numpy.asarray(positions_nm, dtype=numpy.float64)
        efield_values = numpy.asarray(efield_values, dtype=numpy.float64)
        time_column = numpy.ones((positions_nm.shape[0], 1), numpy.float64) * float(time_ps)
        return numpy.hstack([time_column, positions_nm, efield_values])

    def _openTextFile(self, filename):
        """
        Open the text output file and write its header line.

        """
        self._debugPrint("%s: Opening text output file '%s'..." % (self._name, filename))

        self._txtfile = open(filename, 'w')
        # The field is written in the same units as every other format (self._efield_units).
        self._txtfile.write("#Timestep [ps] x y z [nm] Ex Ey Ez [MV/cm]\n")

    def _openNetCDFFile(self, filename):
        """
        Create the NetCDF output file with its dimensions, variables and metadata.

        """
        self._debugPrint("%s: Opening NetCDF output file '%s'..." % (self._name, filename))

        import netCDF4 as netcdf
        ncfile = netcdf.Dataset(filename, 'w', format='NETCDF4')

        # Create dimensions.
        ncfile.createDimension('sample', 0) # unlimited number of samples
        ncfile.createDimension('atom', self._natoms) # number of atoms in system
        ncfile.createDimension('spatial', 3) # number of spatial dimensions

        # Create variables.
        ncvar_indices = ncfile.createVariable('atom_indices', 'i', ('atom',))
        ncvar_time = ncfile.createVariable('time', 'f', ('sample', ))
        ncvar_positions = ncfile.createVariable('position', 'f', ('sample', 'atom','spatial'))
        ncvar_efield = ncfile.createVariable('efield', 'f', ('sample', 'atom','spatial'))

        # Define units for variables.
        setattr(ncvar_indices, 'units', 'none')
        setattr(ncvar_time, 'units', 'ps')
        setattr(ncvar_positions, 'units', 'nm')
        setattr(ncvar_efield, 'units', 'megavolts/centimeter')

        # Define long (human-readable) names for variables.
        setattr(ncvar_indices, "long_name", "atom_indices[index] is the System atom index corresponding to the reported E field 'index'.")
        setattr(ncvar_time, "long_name", "timestep[sample] is simulation time of sample 'sample'.")
        setattr(ncvar_positions, "long_name", "positions[sample,atom,spatial] is position of coordinate 'spatial' of atom 'atom' at sample 'sample'.")
        setattr(ncvar_efield, "long_name", "efield[sample,atom,spatial] is 'spatial' component of electric field vector at atom 'atom' at sample 'sample'.")

        # Fill in atom indices.  Assign through the variable; rebinding the entry of
        # ncfile.variables would only replace a dictionary value and write nothing.
        ncvar_indices[:] = numpy.array(self._atom_indices, dtype=numpy.int32)

        # Sync.
        ncfile.sync()

        self._ncfile = ncfile

    def _openHDF5File(self, filename):
        """
        Open (or create) the pytables output file and its extendable 'Data' array.

        Both the PyTables 3 (snake_case) and PyTables 2 (camelCase) API names are accepted.

        """
        self._debugPrint("%s: Opening pytables (HDF5) output file '%s'..." % (self._name, filename))

        import tables
        compression = tables.Filters(complevel=9, complib='blosc', shuffle=True)
        open_file = getattr(tables, 'open_file', None) or getattr(tables, 'openFile')
        self._h5file = open_file(filename, 'a')
        if '/Data' in self._h5file:
            # The file is opened in append mode: keep extending an existing array.
            self._data = self._h5file.root.Data
        else:
            create_earray = getattr(self._h5file, 'create_earray', None) or getattr(self._h5file, 'createEArray')
            self._data = create_earray("/", "Data", tables.Float32Atom(), (0,7), filters=compression)

    def _extractPartialCharges(self, system):
        """
        Extract partial charges from the first NonbondedForce of the System.

        ARGUMENTS
         - system (simtk.openmm.System) - the System to read charges from

        RETURNS
        - charges (simtk.unit.Quantity of numpy array of charges) partial atomic charges,
          replicated into an (nparticles x 3) array so it can divide force vectors directly

        RAISES
        - Exception if the System contains no NonbondedForce

        TODO
        - merge this with _createZeroChargeSystem?

        """
        nparticles = system.getNumParticles()
        charges = units.Quantity(numpy.zeros([nparticles], numpy.float64), units.elementary_charge)

        # Find NonbondedForce.
        nonbonded_force = None
        for force_index in range(system.getNumForces()):
            candidate = system.getForce(force_index)
            if isinstance(candidate, mm.NonbondedForce):
                nonbonded_force = candidate
                break
        if nonbonded_force is None:
            raise Exception("%s: No NonbondedForce force object found in system." % (self._name))

        for particle_index in range(nparticles):
            [charge, sigma, epsilon] = nonbonded_force.getParticleParameters(particle_index)
            charges[particle_index] = charge

        # Modify shape of charges array.
        charges = units.Quantity(numpy.tile(charges / charges.unit, (3,1)).transpose(), charges.unit)

        return charges

    def _createZeroChargeSystem(self, reference_system):
        """
        Create a copy of the reference system containing no electrostatic components.

        Charges (and exception charge products) are zeroed in every NonbondedForce and
        GBSAOBCForce; all other forces are copied unchanged.

        ARGUMENTS

        reference_system (simtk.openmm.System) - the reference system

        RETURNS

        system (simtk.openmm.System) - the modified copy; the reference system is untouched

        TODO

        * Handle Custom*Force terms

        """

        # Copy the System object.
        import copy
        system = copy.deepcopy(reference_system)

        # Set charges to zero.
        for force_index in range(system.getNumForces()):
            force = system.getForce(force_index)

            if isinstance(force, mm.NonbondedForce):
                # Zero charges (multiplying by 0.0 keeps the unit attached).
                for particle_index in range(force.getNumParticles()):
                    [charge, sigma, epsilon] = force.getParticleParameters(particle_index)
                    force.setParticleParameters(particle_index, charge*0.0, sigma, epsilon)
                for exception_index in range(force.getNumExceptions()):
                    [iatom, jatom, chargeprod, sigma, epsilon] = force.getExceptionParameters(exception_index)
                    force.setExceptionParameters(exception_index, iatom, jatom, chargeprod*0.0, sigma, epsilon)

            elif isinstance(force, mm.GBSAOBCForce):
                # Zero charges.
                for particle_index in range(force.getNumParticles()):
                    [charge, radius, scaling_factor] = force.getParticleParameters(particle_index)
                    force.setParticleParameters(particle_index, charge*0.0, radius, scaling_factor)

        return system

    def describeNextReport(self, simulation):
        """
        Get information about the next report this object will generate.

        Parameters:
         - simulation (Simulation) The Simulation to generate a report for

         Returns: A five element tuple.  The first element is the number of steps until the
         next report.  The remaining elements specify whether that report will require
         positions, velocities, forces, and energies respectively.
         """
        steps = self._reportInterval - simulation.currentStep%self._reportInterval
        # Positions and forces are needed; velocities and energies are not.
        return (steps, True, False, True, False)

    def report(self, simulation, state):
        """
        Generate a report.

        Parameters:
         - simulation (Simulation) The Simulation to generate a report for
         - state (State) The current state of the simulation

        TODO:
        - only perform computation for requested atoms

        """

        self._debugPrint("%s: Generating report..." % (self._name))

        # Get current simulation time.
        current_time = simulation.integrator.getStepSize() * simulation.currentStep
        self._debugPrint("current_time = %s" % str(current_time))

        # Get positions.
        positions = state.getPositions(asNumpy=True)

        # Get forces on all atoms of the complete system.
        self._debugPrint("%s: Computing original forces..." % (self._name))
        f0 = state.getForces(asNumpy=True)
        self._debugPrint(f0)

        # Compute forces for charge-modified system.
        self._debugPrint("%s: Computing zero-charge forces..." % (self._name))
        # Use the simulation's current box; it differs from the System default whenever
        # the box changes during the run (e.g. with a barostat).
        self._context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
        self._context.setPositions(positions)
        zero_charge_state = self._context.getState(getForces=True)
        f1 = zero_charge_state.getForces(asNumpy=True)
        self._debugPrint(f1)

        # Compute electric field on all atoms.
        self._debugPrint("%s: Computing electric field vectors..." % (self._name))
        f_elec = f0 - f1 # electrostatic forces on all atoms
        self._debugPrint(f_elec)
        charges = self._charges
        field_values = self._divideByCharge(f_elec / f_elec.unit, charges / charges.unit)
        # Forces are molar quantities; dividing by Avogadro's number gives the per-particle field.
        efield = units.Quantity(field_values, (f_elec.unit/charges.unit)) / units.AVOGADRO_CONSTANT_NA
        self._debugPrint(efield)

        # Get current sample counter.
        sample = self._sample

        # Store positions and electric fields.
        if self._format == 'text':
            time_ps = current_time / units.picoseconds
            positions_nm = positions[self._atom_indices,:] / units.nanometers
            efield_out = efield[self._atom_indices,:] / self._efield_units
            self._writeTextSample(self._buildRows(time_ps, positions_nm, efield_out))
        elif self._format == 'netcdf':
            self._debugPrint(current_time / units.picoseconds)
            self._ncfile.variables['time'][sample] = current_time / units.picoseconds
            self._ncfile.variables['position'][sample,:,:] = positions[self._atom_indices,:] / units.nanometers
            self._ncfile.variables['efield'][sample,:,:] = efield[self._atom_indices,:] / self._efield_units
            self._ncfile.sync()
        elif self._format == 'hdf5':
            time_ps = current_time / units.picoseconds
            positions_nm = positions[self._atom_indices,:] / units.nanometers
            efield_out = efield[self._atom_indices,:] / self._efield_units
            self._data.append(self._buildRows(time_ps, positions_nm, efield_out))
            self._h5file.flush()
        else:
            raise Exception('self._format has changed to an invalid format.')

        # Increment sample counter.
        self._sample += 1

    def _writeTextSample(self, rows):
        """
        Append one sample to the text file, one line per row of 'rows'.

        ARGUMENTS
         - rows (natoms x 7 array) - columns: time, x, y, z, Ex, Ey, Ez (see _buildRows)

        """
        for row in rows:
            self._txtfile.write("%e %e %e %e %e %e %e\n" % tuple(row))
        self._txtfile.flush()

    def __del__(self):
        """
        Clean up safely.

        Works even if __init__ failed part-way (attributes may be missing) and may be
        called more than once.

        """

        # Close whichever output handle was opened.
        for handle_name in ('_txtfile', '_ncfile', '_h5file'):
            handle = getattr(self, handle_name, None)
            if handle is not None:
                handle.close()
                setattr(self, handle_name, None) # never close the same handle twice

        # Clean up context before its integrator, then drop the System reference.
        for attribute_name in ('_context', '_integrator', '_system'):
            if hasattr(self, attribute_name):
                delattr(self, attribute_name)


            
