import mathutilities
import ReducedHessAngle
from PropagatorFactory import *
import time
import Constants
import topologyutilities
import numpy


class Propagator:
   """
   Drives MDL propagators over a physical system.

   Holds the Physical, Forces and IO objects and runs either a single
   timestep (STS) scheme or a multiple timestep (MTS) hierarchy of schemes
   via propagate().  The nested-level helpers (initNext, runNext,
   finishNext) temporarily switch myPropagator/myLevel to the next level of
   the hierarchy and restore them afterwards, even if the inner call raises.
   """

   def __init__(self, phys, forces, io):
      """
      Build the physical system and store the simulation objects.

      @type phys: Physical
      @param phys: MDL Physical object; phys.build() is called here.

      @type forces: Forces
      @param forces: MDL Forces object.

      @type io: IO
      @param io: MDL IO object; its phys attribute is set to phys.
      """
      #####################################################################################
      # USER-ACCESSIBLE STRUCTURES
      self.myPropagator = 0                  #: PROPAGATOR OBJECT (0 IF METHOD)
      self.myStep = 0                        #: CURRENT SIMULATION STEP
      self.myTimestep = 0                    #: PROPAGATOR TIMESTEP (fs)
      self.myLevel = 0                       #: CURRENT PROPAGATOR HIERARCHY LEVEL (0 IS STS)
      #####################################################################################

      phys.build()

      self.phys = phys
      self.forces = forces
      self.io = io
      io.phys = phys

   def reset(self):
      """
      Reset the state of the Propagator object (propagator, step,
      timestep and level all back to 0).
      """
      self.myPropagator = 0
      self.myStep = 0
      self.myTimestep = 0
      self.myLevel = 0

   def isMDL(self, integ):
      """
      Determine whether or not a propagator has been coded in MDL.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @rtype: boolean
      @return: True if the passed propagator has a prerunmodifiers
               attribute, which is how MDL-coded propagators are
               recognised here.
      """
      return hasattr(integ, "prerunmodifiers")

   def addPreInitModifier(self, integ, modifier):
      """
      Add a modifier to execute before propagator initialization.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.preinitmodifiers.append(modifier)

   def addPostInitModifier(self, integ, modifier):
      """
      Add a modifier to execute after propagator initialization.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.postinitmodifiers.append(modifier)

   def addPreRunModifier(self, integ, modifier):
      """
      Add a modifier to execute before propagator execution.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.prerunmodifiers.append(modifier)

   def addPostRunModifier(self, integ, modifier):
      """
      Add a modifier to execute after propagator execution.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.postrunmodifiers.append(modifier)

   def addPreForceModifier(self, integ, modifier):
      """
      Add a modifier to execute before force calculation.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.preforcemodifiers.append(modifier)

   def addPostForceModifier(self, integ, modifier):
      """
      Add a modifier to execute after force calculation.

      @type integ: STS/MTS
      @param integ: MDL propagator object (STS or MTS)

      @type modifier: function
      @param modifier: Routine which alters propagator behavior.
      """
      integ.postforcemodifiers.append(modifier)

   def runModifiers(self, modifiers, phys, forces, prop, integ):
      """
      Run a list of modifiers on a propagator, in list order.

      @type modifiers: list of functions
      @param modifiers: Routines which alter propagator behavior; each is
                        called as modifier(phys, forces, prop, integ).

      @type phys: Physical
      @param phys: MDL Physical object

      @type forces: Forces
      @param forces: MDL Forces object

      @type prop: Propagator
      @param prop: The Propagator driving the run.

      @type integ: object
      @param integ: MDL propagator object (STS or MTS)
      """
      # Iterate over a snapshot: a modifier that appends to the list must
      # not cause the new entries to run in this same pass.
      for modifier in list(modifiers):
         modifier(phys, forces, prop, integ)

   def timestep(self, integ):
      """
      Return the timestep of a propagator, scaled by
      Constants.invTimeFactor().

      @type integ: object
      @param integ: MDL propagator object (STS or MTS)

      @rtype: float
      @return: The timestep (dt) of a propagator
      """
      return integ.getTimestep() * Constants.invTimeFactor()

   def calculateForces(self, forces):
      """
      Calculate forces with the current propagator and store them.

      Runs the current propagator's pre-force modifiers, calls its
      calculateForces(), assigns its myForces to forces.forcevec, then
      runs its post-force modifiers.

      @type forces: Forces
      @param forces: MDL Forces object; its forcevec attribute is assigned.
      """
      self.runModifiers(self.myPropagator.preforcemodifiers, self.phys, forces, self, self.myPropagator)
      self.myPropagator.calculateForces()
      forces.forcevec = self.myPropagator.myForces
      self.runModifiers(self.myPropagator.postforcemodifiers, self.phys, forces, self, self.myPropagator)

   def initNext(self, phys, forces):
      """
      For multiple timestepping, initialize the next propagator
      in the chain (one hierarchy level down).

      @type phys: Physical
      @param phys: MDL Physical object

      @type forces: Forces
      @param forces: MDL Forces object
      """
      tempI = self.myPropagator
      self.myLevel += 1
      try:
         setPropagator(self, phys, forces, self.myPropagator.next, levelswitch=True)
      finally:
         # Restore the caller's propagator and level even if setup fails.
         self.myPropagator = tempI
         self.myLevel -= 1

   def runNext(self, phys, forces, cL):
      """
      For multiple timestepping, execute the next propagator
      in the chain (one hierarchy level down), using self.io.

      @type phys: Physical
      @param phys: MDL Physical object

      @type forces: Forces
      @param forces: MDL Forces object

      @type cL: integer
      @param cL: Cycle length (number of times to execute the
                 inner propagator)
      """
      tempI = self.myPropagator
      self.myPropagator = self.myPropagator.next
      self.myLevel = self.myLevel + 1
      try:
         executePropagator(self, phys, forces, self.io, cL)
      finally:
         # Restore the caller's level and propagator even if execution fails.
         self.myLevel = self.myLevel - 1
         self.myPropagator = tempI

   def finishNext(self, phys, forces, prop):
      """
      For multiple timestepping, finish the next propagator in the chain.
      Only MDL-coded propagators (see isMDL) have finish() called.

      @type phys: Physical
      @param phys: MDL Physical object

      @type forces: Forces
      @param forces: MDL Forces object

      @type prop: Propagator
      @param prop: Passed through to the next propagator's finish().
      """
      tempI = self.myPropagator
      self.myPropagator = self.myPropagator.next
      try:
         if self.isMDL(self.myPropagator):
            self.myPropagator.finish(phys, forces, prop)
      finally:
         # Restore the caller's propagator even if finish() raises.
         self.myPropagator = tempI

   @staticmethod
   def _buildChain(scheme, cyclelength, dt, forcefield, params):
      """
      Resolve the outermost propagator settings and the extra-argument
      chain that propagate() forwards to propFactory.create().

      For STS (cyclelength == -1) the chain is just (params,).  For MTS the
      chain starts with the outer level's parameters and then holds, for
      each inner level i: the scheme name (only when scheme is a list/tuple
      long enough to name level i), cyclelength[i] (or dt when absent),
      forcefield[i], and that level's parameter dict ({} when none applies).

      @rtype: tuple
      @return: (outertime, outerscheme, outerforcefield, chain)
      """
      if cyclelength == -1:  # STS
         return dt, scheme, forcefield, (params,)

      cycleIsSeq = isinstance(cyclelength, (list, tuple))
      schemeIsSeq = isinstance(scheme, (list, tuple))

      if cycleIsSeq:  # LIST, MANY LEVELS
         levels = len(cyclelength) + 1
         outertime = cyclelength[0]
      else:  # ONE CYCLELENGTH = 2 LEVELS OF PROPAGATION
         levels = 2
         outertime = cyclelength
      outerscheme = scheme[0] if schemeIsSeq else scheme

      # THE NUMBER OF FORCEFIELDS PROVIDED MUST EQUAL THE NUMBER
      # OF PROPAGATION LEVELS (reported on stdout; not enforced here).
      if len(forcefield) != levels:
         print("[MDL] Error in propagate():  %s  levels of propagation with  %s  force fields." % (levels, len(forcefield)))
      outerforcefield = forcefield[0]

      if schemeIsSeq:
         chain = (params.get(outerscheme, {}),)
      else:
         chain = (params,)

      for i in range(1, levels):
         namesLevel = schemeIsSeq and i < len(scheme)
         if namesLevel:
            chain += (scheme[i],)
         if cycleIsSeq and i < len(cyclelength):
            chain += (cyclelength[i],)
         else:
            chain += (dt,)
         chain += (forcefield[i],)
         # Per-level parameters are keyed by that level's scheme name, so
         # only look them up when this level was given a name: indexing a
         # plain string scheme would pick out a single character, and a
         # too-short scheme list would raise IndexError.
         if namesLevel:
            chain += (params.get(scheme[i], {}),)
         else:
            chain += ({},)

      return outertime, outerscheme, outerforcefield, chain

   def propagate(self, scheme="Leapfrog", steps=0, cyclelength=-1, dt=0.1, forcefield=None, params=None):
      """
      Propagate the system.

      @type scheme: string or list of strings
      @param scheme: Name of the propagator to use.  For multiple
                     timestepping (MTS), a list with one name per level,
                     outermost first.

      @type steps: integer
      @param steps: Number of steps for execution.

      @type cyclelength: number or list
      @param cyclelength: -1 (the default) selects single timestepping
                          (STS).  Otherwise MTS: a scalar gives two levels,
                          a list of length k gives k+1 levels.  Element 0
                          (or the scalar) is used as the outermost
                          propagator's time argument; inner level i gets
                          cyclelength[i] when present, otherwise dt.

      @type dt: float
      @param dt: Timestep (fs).

      @type forcefield: ForceField or list of ForceField
      @param forcefield: MDL ForceField object (STS), or one per level (MTS).
                         A count that does not match the number of levels
                         is reported on stdout.

      @type params: dict
      @param params: Extra parameters for the scheme (may be empty).  For
                     STS or a single scheme name the dict is passed through
                     as-is; for a list of schemes it maps scheme name to
                     that level's parameter dict.  For object-style
                     propagators, 'shake': 'on' / 'rattle': 'on' attach the
                     modifiers from createShakeModifier/createRattleModifier
                     for the duration of this run.
      """
      # Use fresh containers per call instead of shared mutable defaults:
      # params is placed in the argument chain handed to propFactory.create(),
      # so a shared default dict could carry state from one call to the next.
      if forcefield is None:
         forcefield = []
      if params is None:
         params = {}

      self.myTimestep = dt
      outertime, outerscheme, outerforcefield, chain = \
         self._buildChain(scheme, cyclelength, dt, forcefield, params)

      if self.forces.dirty():
         self.forces.build()
      # NOTE(review): forces.dirty() is called as a method but io.dirty is
      # read as an attribute; if IO.dirty is also a method this test is
      # always true.  Confirm against the IO class.
      if self.io.dirty:
         self.io.build()

      if propFactory.getType(outerscheme) == "method":
         # Method-style scheme: compute initial forces, then invoke the
         # scheme through propFactory.create() once per step, advancing
         # phys.time and running IO after each step.
         outerforcefield.calculateForces(self.phys, self.forces)
         self.phys.updateCOM_Momenta()
         self.io.run(self.phys, self.forces, 0, outertime)
         self.io.myProp = self
         for ii in range(1, steps + 1):
            propFactory.create(outerscheme, self.phys, self.forces, self.io, 1, outertime * Constants.invTimeFactor(), outerforcefield, *chain)
            self.phys.time = ii * outertime * Constants.invTimeFactor()
            self.io.run(self.phys, self.forces, ii, outertime)
            self.phys.updateCOM_Momenta()
      else:
         # Object-style scheme: install the created propagator, optionally
         # attach SHAKE/RATTLE modifiers, and execute it for all steps.
         setPropagator(self, self.phys, self.forces, propFactory.applyModifiers(propFactory.create(outerscheme, outertime, outerforcefield, *chain), outerscheme))
         attached = []
         try:
            # NOTE(review): the numeric arguments look like (tolerance,
            # max iterations) -- confirm against the propagator API.
            if params.get('shake') == 'on':
               shakeMod = self.myPropagator.createShakeModifier(0.00001, 30)
               self.myPropagator.adoptPostDriftOrNextModifier(shakeMod)
               attached.append(shakeMod)
            if params.get('rattle') == 'on':
               rattleMod = self.myPropagator.createRattleModifier(0.02, 30)
               self.myPropagator.adoptPostDriftOrNextModifier(rattleMod)
               attached.append(rattleMod)
            executePropagator(self, self.phys, self.forces, self.io, steps)
         finally:
            # Always detach the temporary modifiers, even if execution
            # raised, so they are not left attached to the propagator.
            for modifier in attached:
               self.myPropagator.removeModifier(modifier)
      self.phys.updateCOM_Momenta()


           
