"""
Output simulation trajectory state data to a NetCDF file.

DESCRIPTION

This Reporter outputs state data from a simulation to a NetCDF file, roughly following the
Amber NetCDF trajectory convention (so that trajectory files can be read with AmberTools utilities)
but with additional information about energies, volumes, etc.

AUTHORS

John D. Chodera 2012-08-03

TODO

"""
# Module metadata.
__author__ = "John D. Chodera"
__version__ = "1.0"

import simtk.openmm as mm
import simtk.unit as units

import time
import numpy
    
class NetCDFReporter(object):
    """
    Periodically output state data from a simulation to a NetCDF file for use with analyzing or resuming simulations.

    NOTES

    This reporter requires transfer of coordinates and forces back and forth to the GPU each report interval.
    This can significantly slow down simulations at high reporting frequencies.
    The resulting output file roughly follows the AMBER NetCDF Convention 1.0: http://ambermd.org/netcdf/nctraj.html

    EXAMPLES

    Run a simulation, periodically writing state data to a NetCDF file.

    >>> # Load a test system using the OpenMM app.
    >>> import simtk.openmm as openmm
    >>> import simtk.openmm.app as app
    >>> inpcrd = app.AmberInpcrdFile('bosutinib.crd')
    >>> prmtop = app.AmberPrmtopFile('bosutinib.prmtop')
    >>> system = prmtop.createSystem(nonbondedMethod=app.CutoffPeriodic, constraints=app.HBonds)
    >>> positions = inpcrd.getPositions()
    >>> topology = prmtop.topology
    >>> # Create a Simulation object using the system, positions, and topology.
    >>> temperature = 300.0 * units.kelvin
    >>> collision_rate = 9.1 / units.picosecond
    >>> timestep = 2.0 * units.femtosecond
    >>> integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep)
    >>> simulation = app.Simulation(topology, system, integrator)
    >>> simulation.context.setPositions(positions)
    >>> # Add a NetCDF state data reporter to write positions but not velocities.
    >>> filename = 'trajectory.nc' # NetCDF output filename
    >>> reportInterval = 10 # write interval, in steps
    >>> simulation.reporters.append( NetCDFReporter(system, filename, reportInterval, writePositions=True, writeVelocities=False, writeEnergies=True) )
    >>> # Run the simulation.
    >>> simulation.step(100)
    >>> # Clean up.
    >>> del simulation, integrator

    """

    def __init__(self, system, filename, reportInterval, debug=False, append=False, writePositions=True, writeVelocities=False, writeForces=False, writeEnergies=True, title='', amberCompatible=True, enableCompression=False):
        """
        Create a NetCDFReporter.

        Parameters:
         - system (simtk.openmm.System) - the System about which information will be written
         - filename (string) - the filename of the file to write to
         - reportInterval (int) - the interval (in time steps) at which to write frames
         - debug (boolean=False) - if True, debug information is written to the terminal
         - append (boolean=False) - if True, the output file will be appended to, if it exists
         - writePositions (boolean=True) - if True, positions will be written to the file
         - writeVelocities (boolean=False) - if True, velocities will be written to the file
         - writeForces (boolean=False) - if True, forces will be written to the file
         - writeEnergies (boolean=True) - if True, energies will be written to the file
         - title (string='') - title for trajectory
         - amberCompatible (boolean=True) - if False, will use NETCDF4; if True, will use NETCDF3 64-bit for AMBER Trajectory Convention compatibility
         - enableCompression (boolean=False) - if True, requests zlib compression for the numeric variables that are created

        Raises:
         - ImportError if the netCDF4 Python wrapper cannot be imported.
         - Any error raised by netCDF4 while opening the file is propagated unchanged.

        """

        # Store arguments.
        self._name = 'NetCDFReporter'
        self._filename = filename
        self._reportInterval = reportInterval # interval (in steps) between reports
        self._debug = debug
        self._natoms = system.getNumParticles()
        self._writePositions = writePositions
        self._writeVelocities = writeVelocities
        self._writeForces = writeForces
        self._writeEnergies = writeEnergies
        self._enableCompression = bool(enableCompression)

        # Set defaults for unit reporting.
        # These are AMBER defaults; changing them will break the AMBER NetCDF convention adherence.
        # NOTE(review): force_unit is kJ/mol/nm rather than AMBER's kcal/mol/angstrom; it is left
        # unchanged so that files appended to keep consistent units -- confirm before changing.
        self.time_unit = units.picoseconds
        self.energy_unit = units.kilocalories_per_mole
        self.position_unit = units.angstrom
        self.angular_unit = units.degrees
        self.velocity_unit = units.angstrom / units.picosecond
        self.force_unit = units.kilojoules_per_mole / units.nanometer

        # Set file version.
        self._netcdf_format = 'NETCDF4' # Use NETCDF4 format by default
        if amberCompatible:
            self._netcdf_format = 'NETCDF3_64BIT' # AMBER Trajectory Convention requires NETCDF3 64-bit

        # Set file mode: append only if requested AND the file already exists.
        import os.path
        file_mode = 'a' if (append and os.path.exists(filename)) else 'w'

        # Import the NetCDF wrapper.  Only an import failure is reported as a missing wrapper;
        # errors while opening the file are allowed to propagate so they are not misreported.
        # (Scientific.IO.NetCDF is not supported because of difficulties with variable assignment.)
        if self._debug: print("%s: Opening NetCDF output file '%s'..." % (self._name, filename))
        self._netcdf_wrapper = None
        try:
            import netCDF4 as netcdf
        except ImportError:
            if self._debug: print("netcdf4-python not found.")
            message = "No NetCDF Python wrapper could be imported.\n"
            message += "A NetCDF Python wrapper is required to use NetCDFReporter.\n"
            message += "Please install a NetCDF4 wrapper such as:\n"
            message += " * netcdf4-python - available at http://code.google.com/p/netcdf4-python/\n"
            message += "   or as part of the Enthought Python Distribution: http://www.enthought.com/products/epd.php"
            raise ImportError(message)
        self._netcdf_wrapper = 'netCDF4' # Store information about which wrapper is in use

        # Open the file.  Compression is requested per variable (see _defineVariable), not here.
        ncfile = netcdf.Dataset(filename, file_mode, format=self._netcdf_format)
        # Store the handle immediately so __del__ can close it even if header writing fails.
        self._ncfile = ncfile

        if file_mode == 'w':
            # Write NetCDF header and start counting frames from zero.
            self._writeHeader(ncfile, title)
            self._frame = 0
        else:
            # Determine what sample count to resume from.
            # TODO: Check that all data for last frame is consistent.
            self._frame = len(ncfile.variables['step'])

    def _defineVariable(self, ncfile, name, datatype, dimensions, long_name, unit=None):
        """
        Create a NetCDF variable and attach its 'long_name', 'units', and 'simtk_units' attributes.

        Parameters:
         - ncfile - the open NetCDF dataset
         - name (string) - variable name
         - datatype (string) - NetCDF datatype code (e.g. 'f4')
         - dimensions (tuple of string) - names of the variable's dimensions
         - long_name (string) - human-readable description
         - unit (simtk.unit.Unit=None) - unit of the stored values; None marks a dimensionless variable

        Returns: the created variable.
        """
        # Only pass the zlib keyword when compression was requested, so the default
        # call is identical to a plain createVariable().
        options = {'zlib': True} if self._enableCompression else {}
        ncvar = ncfile.createVariable(name, datatype, dimensions, **options)
        setattr(ncvar, 'long_name', long_name)
        if unit is None:
            setattr(ncvar, 'units', 'none')
            setattr(ncvar, 'simtk_units', 'dimensionless')
        else:
            setattr(ncvar, 'units', str(unit))
            setattr(ncvar, 'simtk_units', repr(unit))
        return ncvar

    def _writeHeader(self, ncfile, title):
        """
        Write global attributes, dimensions, and variable definitions to a newly created file.

        Parameters:
         - ncfile - the open NetCDF dataset (opened in 'w' mode)
         - title (string) - title for trajectory
        """
        # Set global attributes to comply with AMBER NetCDF convention.
        setattr(ncfile, 'title', title)
        setattr(ncfile, 'application', 'OpenMM app')
        setattr(ncfile, 'program', 'NetCDFReporter')
        setattr(ncfile, 'programVersion', mm.__version__)
        setattr(ncfile, 'Conventions', 'AMBER')
        setattr(ncfile, 'ConventionVersion', '1.0')

        # Create dimensions.
        ncfile.createDimension('frame', 0) # unlimited number of samples
        ncfile.createDimension('spatial', 3) # number of spatial dimensions
        ncfile.createDimension('atom', self._natoms) # number of atoms
        ncfile.createDimension('cell_spatial', 3) # lengths that define size of unit cell
        ncfile.createDimension('cell_angular', 3) # angles that define shape of unit cell

        # Create timestamp variable for storing when frames were written.
        # Variable-length strings are only available in NETCDF4 files.
        if self._netcdf_format == 'NETCDF4':
            ncfile.createVariable('timestamp', str, ('frame',))

        # Create label variables.
        if self._netcdf_format == 'NETCDF3_64BIT':
            ncfile.createDimension('label', 5) # length of the longest angle label
            ncvar = ncfile.createVariable('spatial', 'c', ('spatial',))
            ncvar[:] = ['x', 'y', 'z']
            ncvar = ncfile.createVariable('cell_spatial', 'c', ('cell_spatial',))
            ncvar[:] = ['a', 'b', 'c']
            ncvar = ncfile.createVariable('cell_angular', 'c', ('cell_angular', 'label'))
            # Assign one character per element so each label is stored in full.
            ncvar[0,0:5] = list('alpha')
            ncvar[1,0:4] = list('beta')
            ncvar[2,0:5] = list('gamma')

        # Create variables, defining units and human-readable names.
        self._defineVariable(ncfile, 'step', 'i4', ('frame',),
                             "step[sample] is number of steps that have elapsed at sample 'sample'.")
        self._defineVariable(ncfile, 'time', 'f4', ('frame',),
                             "time[sample] is simulation time of sample 'sample'.", self.time_unit)
        self._defineVariable(ncfile, 'box_vectors', 'f4', ('frame','spatial','spatial'),
                             "box_vectors[i,j] is component j of box vector i", self.position_unit)
        self._defineVariable(ncfile, 'cell_lengths', 'f4', ('frame','cell_spatial'),
                             "cell_lengths[i] is the simulation cell length along dimension i", self.position_unit)
        self._defineVariable(ncfile, 'cell_angles', 'f4', ('frame','cell_angular'),
                             "cell_angles[i] is simulation cell angle alpha (i=0), beta (i=1), or gamma (i=2)", self.angular_unit)

        if self._writePositions:
            self._defineVariable(ncfile, 'coordinates', 'f4', ('frame','atom','spatial'),
                                 "coordinates[frame,atom,spatial] is coordinate 'spatial' of atom 'atom' from simulation frame 'frame'", self.position_unit)
        if self._writeVelocities:
            self._defineVariable(ncfile, 'velocities', 'f4', ('frame','atom','spatial'),
                                 "velocities[frame,atom,spatial] is velocity component 'spatial' of atom 'atom' from simulation frame 'frame'", self.velocity_unit)
        if self._writeForces:
            self._defineVariable(ncfile, 'forces', 'f4', ('frame','atom','spatial'),
                                 "forces[frame,atom,spatial] is force component 'spatial' of atom 'atom' from simulation frame 'frame'", self.force_unit)
        if self._writeEnergies:
            self._defineVariable(ncfile, 'potential_energy', 'f4', ('frame',),
                                 "potential_energy[frame] is the potential energy from frame 'frame'", self.energy_unit)
            self._defineVariable(ncfile, 'kinetic_energy', 'f4', ('frame',),
                                 "kinetic_energy[frame] is the kinetic energy from frame 'frame'", self.energy_unit)
            self._defineVariable(ncfile, 'total_energy', 'f4', ('frame',),
                                 "total_energy[frame] is the total (kinetic + potential) energy from frame 'frame'", self.energy_unit)

        # Sync.
        ncfile.sync()

    @staticmethod
    def _computeCellParameters(box_vectors):
        """
        Compute unit cell edge lengths and angles from box vectors.

        Parameters:
         - box_vectors (3x3 array-like of plain numbers) - row i is box vector i

        Returns: (lengths, angles), two length-3 numpy arrays.  lengths[i] is the norm of
        box vector i (same length unit as the input).  angles are in degrees:
        alpha (between vectors 1 and 2), beta (between 0 and 2), gamma (between 0 and 1).
        """
        vectors = numpy.asarray(box_vectors, dtype=float)
        lengths = numpy.sqrt((vectors ** 2).sum(axis=1))
        angles = numpy.empty(3)
        for index, (j, k) in enumerate(((1, 2), (0, 2), (0, 1))):
            norm_product = lengths[j] * lengths[k]
            if norm_product == 0.0:
                # Degenerate (zero-length) vector: the angle is undefined, report 90 degrees.
                angles[index] = 90.0
                continue
            cosine = numpy.dot(vectors[j], vectors[k]) / norm_product
            # Clip to guard against round-off pushing the cosine just outside [-1, 1].
            angles[index] = numpy.degrees(numpy.arccos(numpy.clip(cosine, -1.0, 1.0)))
        return lengths, angles

    def describeNextReport(self, simulation):
        """
        Get information about the next report this object will generate.

        Parameters:
         - simulation (Simulation) The Simulation to generate a report for

        Returns: A five element tuple.  The first element is the number of steps until the
        next report.  The remaining elements specify whether that report will require
        positions, velocities, forces, and energies respectively.
        """
        steps = self._reportInterval - simulation.currentStep%self._reportInterval
        return (steps, self._writePositions, self._writeVelocities, self._writeForces, self._writeEnergies)

    def report(self, simulation, state):
        """
        Generate a report.

        Parameters:
         - simulation (Simulation) The Simulation to generate a report for
         - state (State) The current state of the simulation

        TODO:
        - only perform computation for requested atoms

        """

        if self._debug: print("%s: Generating report..." % (self._name))

        ncfile = self._ncfile # NetCDF file handle
        frame = self._frame # current frame index
        variables = ncfile.variables

        # Store timestamp of report.
        # Keyed on the variable actually being present, so appending to a file created
        # in a different format does not fail.
        if 'timestamp' in variables:
            variables['timestamp'][frame] = time.ctime()

        # Store simulation step and time.
        current_time = simulation.integrator.getStepSize() * simulation.currentStep
        if self._debug: print("current_time = %s" % str(current_time))
        variables['time'][frame] = current_time / self.time_unit
        variables['step'][frame] = simulation.currentStep

        # Store box information; cell lengths and angles are derived from the box vectors.
        box_vectors = state.getPeriodicBoxVectors(asNumpy=True)
        box_array = numpy.asarray(box_vectors / self.position_unit, dtype=float)
        cell_lengths, cell_angles = self._computeCellParameters(box_array)
        # Scalar factor converting degrees to the reporting angular unit.
        degree_factor = (1.0 * units.degrees) / self.angular_unit
        variables['box_vectors'][frame,:,:] = box_array
        variables['cell_lengths'][frame,:] = cell_lengths
        variables['cell_angles'][frame,:] = cell_angles * degree_factor

        if self._writePositions:
            positions = state.getPositions(asNumpy=True)
            variables['coordinates'][frame,:,:] = positions[:,:] / self.position_unit

        if self._writeVelocities:
            velocities = state.getVelocities(asNumpy=True)
            variables['velocities'][frame,:,:] = velocities[:,:] / self.velocity_unit

        if self._writeForces:
            forces = state.getForces(asNumpy=True)
            variables['forces'][frame,:,:] = forces[:,:] / self.force_unit

        if self._writeEnergies:
            potential_energy = state.getPotentialEnergy()
            kinetic_energy = state.getKineticEnergy()
            variables['potential_energy'][frame] = potential_energy / self.energy_unit
            variables['kinetic_energy'][frame] = kinetic_energy / self.energy_unit
            variables['total_energy'][frame] = (potential_energy + kinetic_energy) / self.energy_unit

        # Make sure data is safely written to disk.
        ncfile.sync()

        # Increment frame counter.
        self._frame += 1

    def __del__(self):
        """
        Clean up safely.

        """

        # Close any open files.  The handle is cleared first so it is closed at most once.
        ncfile = getattr(self, '_ncfile', None)
        if ncfile is not None:
            self._ncfile = None
            ncfile.close()

    @classmethod
    def getLastFrame(cls, filename):
        """
        Read the final frame stored in a NetCDF trajectory file.

        ARGUMENTS

        filename (string) - filename of the NetCDF file to try to retrieve last frame from

        RETURNS

        A three-element list:
        positions - positions from last frame
        box_vectors - box vectors from last frame
        last_step (int) - last step index

        """

        position_unit = units.angstroms # TODO: Have this read from file.

        import netCDF4 as netcdf
        ncfile = netcdf.Dataset(filename, 'r')
        try:
            positions = units.Quantity(numpy.array(ncfile.variables['coordinates'][-1,:,:]), position_unit)
            box_vectors = units.Quantity(numpy.array(ncfile.variables['box_vectors'][-1,:,:]), position_unit)
            last_step = int(ncfile.variables['step'][-1])
        finally:
            # Always release the file handle, even if a variable is missing or empty.
            ncfile.close()
        return [positions, box_vectors, last_step]

if __name__ == '__main__':
    # When executed as a script, run the doctests embedded in this module's docstrings.
    from doctest import testmod
    testmod()
