import warnings
warnings.filterwarnings(action='ignore',
                        message='.*has API version.*',
                        category=RuntimeWarning)

import Constants
import Vector3DBlock
import PARReader
import PSFReader
import PDBReader
import EigenvectorReader
import _mathutilities
import topologyutilities
import Topology
import _topologyutilities
import _Topology
import numpy
import sys

class Atom:
   """
   Plain record describing one atom of the system.

   @ivar number: Atom number
   @ivar seg_id: Segment identifier
   @ivar residue_sequence: Residue sequence
   @ivar residue_name: Residue name
   @ivar atom_name: Atom name
   @ivar atom_type: Atom type
   @ivar charge: Charge [e]
   @ivar mass: Mass [amu]
   """
   def __init__(self, n, s, rs, rn, an, at, c, m):
      # Identification
      self.number, self.seg_id = n, s
      # Residue the atom belongs to
      self.residue_sequence, self.residue_name = rs, rn
      # Naming / typing
      self.atom_name, self.atom_type = an, at
      # Physical properties
      self.charge, self.mass = c, m

class Bond:
   """
   Plain record describing a bond between two atoms.

   @ivar number: Bond number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   """
   def __init__(self, n, a1, a2):
      self.number, self.atom1, self.atom2 = n, a1, a2

class Angle:
   """
   Plain record describing an angle spanned by three atoms.

   @ivar number: Angle number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   @ivar atom3: Atom index 3
   """
   def __init__(self, n, a1, a2, a3):
      self.number = n
      self.atom1, self.atom2, self.atom3 = a1, a2, a3

class Dihedral:
   """
   Plain record describing a dihedral over four atoms.

   @ivar number: Dihedral number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   @ivar atom3: Atom index 3
   @ivar atom4: Atom index 4
   """
   def __init__(self, n, a1, a2, a3, a4):
      self.number = n
      self.atom1, self.atom2 = a1, a2
      self.atom3, self.atom4 = a3, a4

class Improper:
   """
   Plain record describing an improper over four atoms.

   @ivar number: Improper number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   @ivar atom3: Atom index 3
   @ivar atom4: Atom index 4
   """
   def __init__(self, n, a1, a2, a3, a4):
      self.number = n
      self.atom1, self.atom2 = a1, a2
      self.atom3, self.atom4 = a3, a4

class HDonor:
   """
   Plain record describing an H+ donor.

   @ivar number: Donor number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   """
   def __init__(self, n, a1, a2):
      self.number, self.atom1, self.atom2 = n, a1, a2

class HAcceptor:
   """
   Plain record describing an H+ acceptor.

   @ivar number: Acceptor number
   @ivar atom1: Atom index 1
   @ivar atom2: Atom index 2
   """
   def __init__(self, n, a1, a2):
      self.number, self.atom1, self.atom2 = n, a1, a2

      

class Physical:
   """
   Defines a physical system: positions, velocities, temperature,
   boundary conditions, etc.
   """
   def __init__(self):
      """
      Create a physical system populated with default settings.

      Assignments below go through __setattr__.  The order matters:
      assigning bc installs the topology object, and the later
      assignment to time writes into that topology.
      """
      # ---- user-accessible settings --------------------------------
      self.seed = 1234               #: Random number seed
      self.exclude = "1-4"           #: Exclusion Pairs
      self.cellsize = 6.5            #: Cell size
      self.bc = "Periodic"           #: Boundary conditions
      self.remcom = "yes"            #: Remove COM motion?
      self.remang = "yes"            #: Remove angular momentum?
      self.defaultCBV = True         #: Use default cell basis vectors (PBC)
      self.time = 0                  #: Current time
      self.cB1 = numpy.zeros(3)      #: Cell basis vector 1
      self.cB2 = numpy.zeros(3)      #: Cell basis vector 2
      self.cB3 = numpy.zeros(3)      #: Cell basis vector 3
      self.cO = numpy.zeros(3)       #: Cell origin
      self.temperature = 300         #: Kelvin temperature
      self.masses = numpy.zeros(0)      #: Diagonal mass matrix
      self.invmasses = numpy.zeros(0)   #: Diagonal inverse mass matrix
      self.masssum = 0                  #: Sum over all atomic masses

      # ---- internal wrappers (not for direct user access) ----------
      # Written straight into the instance dictionary, bypassing
      # __setattr__.  Users read and write positions/velocities through
      # self.positions and self.velocities, which __getattr__ serves
      # from posvec/velvec.
      internals = self.__dict__
      internals['myTop'] = Topology.T_Periodic()
      internals['posvec'] = Vector3DBlock.Vector3DBlock()
      internals['velvec'] = Vector3DBlock.Vector3DBlock()
      internals['myPAR'] = PARReader.PAR()
      internals['myPSF'] = PSFReader.PSF()
      internals['myPDB'] = PDBReader.PDB()
      internals['myEig'] = EigenvectorReader.EigenvectorInfo()

      self.dirty = 1   #: Dirty bit


   # Copy which avoids object assignment
   def copy(self):
      """
      Return a new Physical carrying this object's settings and state.

      Scalar settings are reassigned, numpy arrays are duplicated with
      their own copy() and positions/velocities are copied element by
      element into the new object's wrappers.  The PAR/PSF/PDB/
      eigenvector reader objects are shared by reference.  The new
      object is rebuilt before it is returned.

      @rtype: Physical
      @return: The copied object
      """
      clone = Physical()

      # Settings are assigned with setattr so they still pass through
      # __setattr__, in the same order as the individual assignments.
      for attr in ('seed', 'exclude', 'cellsize', 'bc',
                   'remcom', 'remang', 'defaultCBV'):
         setattr(clone, attr, getattr(self, attr))

      # Fresh topology object matching the boundary conditions.
      if (clone.bc != "Periodic"):
         clone.myTop = Topology.T_Vacuum()
      else:
         clone.myTop = Topology.T_Periodic()

      # Cell description.
      for attr in ('cB1', 'cB2', 'cB3', 'cO'):
         setattr(clone, attr, getattr(self, attr).copy())

      clone.temperature = self.temperature

      # Mass data.
      for attr in ('masses', 'invmasses'):
         setattr(clone, attr, getattr(self, attr).copy())
      clone.masssum = self.masssum

      # Positions first, then velocities: size the wrapper, then copy
      # the data one element at a time.
      for vec_attr, data_attr in (('posvec', 'positions'),
                                  ('velvec', 'velocities')):
         getattr(clone, vec_attr).resize(getattr(self, vec_attr).size())
         for idx in range(len(getattr(self, data_attr))):
            getattr(clone, data_attr)[idx] = getattr(self, data_attr)[idx]

      # Reader objects are handed over as-is (shared, not duplicated).
      for attr in ('myPAR', 'myPSF', 'myPDB', 'myEig'):
         setattr(clone, attr, getattr(self, attr))

      clone.build()
      return clone
      
   # SPECIAL ACCESSOR FOR self.positions or self.velocities
   # TO GET DATA FROM WRAPPERS
   def __getattr__(self, name):
      if (name == 'positions'):
         return self.__dict__['posvec'].getData()
      elif (name == 'velocities'):
         return self.__dict__['velvec'].getData()
      elif (name == 'time'):
         return self.myTop.time
      else:
         return self.__dict__[name]

   # SPECIAL ASSIGNMENT FOR self.positions or self.velocities
   # TO SET DATA IN WRAPPERS
   def __setattr__(self, name, val):
      """
      Attribute assignment hook.

      - positions / velocities: forwarded to the posvec / velvec
        wrappers whenever the wrapper exists.  Neither name is ever a
        key of the instance dictionary (reads are served by
        __getattr__), so testing for a previous assignment of the name
        itself would store the first assignment as a plain attribute
        that shadows the wrapper.
      - bc: stores the value and installs a matching topology object;
        rebuilds unless this is the first assignment.
      - time: divided by Constants.invTimeFactor() and stored on the
        topology object.
      - seed, exclude, cellsize, remcom, remang: stored; rebuilds
        unless this is the first assignment.
      - anything else: stored as a plain attribute.

      @type name: str
      @param name: Attribute being assigned

      @param val: New value
      """
      state = self.__dict__
      # 'in' replaces dict.has_key(), which Python 3 no longer provides.
      firsttime = name not in state
      if (name == 'positions' and 'posvec' in state):
         state['posvec'].setData(val)
      elif (name == 'velocities' and 'velvec' in state):
         state['velvec'].setData(val)
      elif (name == 'bc'):
         state['bc'] = val
         if (val == "Vacuum"):
            self.myTop = Topology.T_Vacuum()
         else:
            self.myTop = Topology.T_Periodic()
         if (not firsttime):
            self.build()
      elif (name == 'time'):
         val /= Constants.invTimeFactor()
         self.myTop.time = val
      elif (name in ('seed', 'exclude', 'cellsize', 'remcom', 'remang')):
         state[name] = val
         if (not firsttime):
            self.build()
      else:
         state[name] = val

   # RESETS SIMULATION STATE TO DEFAULTS
   # ALSO RESETS TIME TO ZERO
   def reset(self):
      """
      Restore the simulation settings to their defaults and set the
      time back to zero.

      The settings are written directly into the instance dictionary,
      so no rebuild is triggered by __setattr__.
      """
      state = self.__dict__
      defaults = (
         ('seed', 1234),          # Random number seed
         ('exclude', "1-4"),      # Exclusion Pairs
         ('cellsize', 6.5),       # Cell size
         ('bc', "Periodic"),      # Boundary conditions
         ('remcom', "yes"),       # Remove COM motion?
         ('remang', "yes"),       # Remove angular momentum?
         ('defaultCBV', True),    # Use default cell basis vectors (PBC)
      )
      for key, value in defaults:
         state[key] = value

      # The topology must be in place before time is assigned, because
      # the time value is stored on the topology object.
      state['myTop'] = Topology.T_Periodic()
      self.time = 0

      # Zero the cell basis vectors and the cell origin in place.
      for attr in ('cB1', 'cB2', 'cB3', 'cO'):
         getattr(self, attr).fill(0)

      state['temperature'] = 300   # Kelvin temperature

   # SYSTEM PRESSURE.
   def pressure(self, forces):
      """
      Compute the system pressure from the energies held by a forces
      object and the current volume reported by the topology.

      @type forces: Forces
      @param forces: MDL Forces object

      @rtype: float
      @return: System pressure
      """
      current_volume = self.myTop.getVolume(self.posvec)
      return forces.energies.pressure(current_volume)

   # SYSTEM VOLUME (AA^3)
   def volume(self):
      """
      Volume of the system.

      @rtype: float
      @return: System volume
      """
      # Pass the position wrapper, exactly as pressure() does; the bare
      # name 'positions' is not defined in this scope.
      return self.myTop.getVolume(self.posvec)


   # SIZE OF THE SYSTEM
   def N(self):
      """
      Shorthand for numAtoms().

      @rtype: int
      @return: Number of atoms.
      """
      atom_count = self.numAtoms()
      return atom_count
   
   def numAtoms(self):
      """
      Atom count as reported by the PSF data.

      @rtype: int
      @return: Number of atoms.
      """
      psf = self.myPSF
      return psf.numAtoms()

   # NUMBER OF TWO-ATOM BONDS.
   def numBonds(self):
      """
      Bond count as reported by the PSF data.

      @rtype: int
      @return: Number of bonds
      """
      psf = self.myPSF
      return psf.numBonds()

   # NUMBER OF THREE-ATOM ANGLES.
   def numAngles(self):
      """
      Angle count as reported by the PSF data.

      @rtype: int
      @return: Number of angles
      """
      psf = self.myPSF
      return psf.numAngles()

   # NUMBER OF FOUR-ATOM DIHEDRALS.
   def numDihedrals(self):
      """
      Dihedral count as reported by the PSF data.

      @rtype: int
      @return: Number of dihedrals
      """
      psf = self.myPSF
      return psf.numDihedrals()

   # NUMBER OF FOUR-ATOM IMPROPERS.
   def numImpropers(self):
      """
      Improper count as reported by the PSF data.

      @rtype: int
      @return: Number of impropers
      """
      psf = self.myPSF
      return psf.numImpropers()

   # NUMBER OF H-BOND DONORS.
   def numDonors(self):
      """
      Hydrogen donor count (for H+ bonding) as reported by the PSF data.

      @rtype: int
      @return: Number of hydrogen donors
      """
      psf = self.myPSF
      return psf.numDonors()

   # NUMBER OF H-BOND ACCEPTORS.
   def numAcceptors(self):
      """
      Hydrogen acceptor count (for H+ bonding) as reported by the PSF
      data.

      @rtype: int
      @return: Number of hydrogen acceptors
      """
      psf = self.myPSF
      return psf.numAcceptors()

   # GET ATOM #(index)
   def atom(self, index):
      """
      Look up one atom and return it as an Atom record.

      The PSF entry at position index-1 is read, so the index passed
      in is one-based.

      @type index: int
      @param index: Atom index (1 to N)

      @rtype: Atom
      @return: The atom at the passed index.
      """
      rec = self.myPSF.getAtom(index-1)
      fields = (rec.number, rec.seg_id,
                rec.residue_sequence, rec.residue_name,
                rec.atom_name, rec.atom_type,
                rec.charge, rec.mass)
      return Atom(*fields)

   # GET BOND #(index)
   def bond(self, index):
      """
      Look up one bond and return it as a Bond record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: Bond index

      @rtype: Bond
      @return: The bond at the passed index.
      """
      rec = self.myPSF.getBond(index-1)
      fields = (rec.number, rec.atom1, rec.atom2)
      return Bond(*fields)

   # GET ANGLE #(index)
   def ang(self, index):
      """
      Look up one angle and return it as an Angle record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: Angle index

      @rtype: Angle
      @return: The angle at the passed index.
      """
      rec = self.myPSF.getAngle(index-1)
      fields = (rec.number, rec.atom1, rec.atom2, rec.atom3)
      return Angle(*fields)

   # GET DIHEDRAL #(index)
   def dihedral(self, index):
      """
      Look up one dihedral and return it as a Dihedral record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: Dihedral index

      @rtype: Dihedral
      @return: The dihedral at the passed index.
      """
      rec = self.myPSF.getDihedral(index-1)
      fields = (rec.number, rec.atom1, rec.atom2, rec.atom3, rec.atom4)
      return Dihedral(*fields)

   # GET IMPROPER #(index)
   def improper(self, index):
      """
      Look up one improper and return it as an Improper record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: Improper index

      @rtype: Improper
      @return: The improper at the passed index.
      """
      rec = self.myPSF.getImproper(index-1)
      fields = (rec.number, rec.atom1, rec.atom2, rec.atom3, rec.atom4)
      return Improper(*fields)

   # GET DONOR #(index)
   def donor(self, index):
      """
      Look up one H+ donor and return it as an HDonor record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: H+ donor index

      @rtype: HDonor
      @return: The H+ donor at the passed index.
      """
      rec = self.myPSF.getDonor(index-1)
      fields = (rec.number, rec.atom1, rec.atom2)
      return HDonor(*fields)

   # GET ACCEPTOR #(index)
   def acceptor(self, index):
      """
      Look up one H+ acceptor and return it as an HAcceptor record.

      The PSF entry at position index-1 is read.

      @type index: int
      @param index: H+ acceptor index

      @rtype: HAcceptor
      @return: The H+ acceptor at the passed index.
      """
      rec = self.myPSF.getAcceptor(index-1)
      fields = (rec.number, rec.atom1, rec.atom2)
      return HAcceptor(*fields)

   # GET THE MASS (AMU) OF A SPECIFIC ATOM
   def mass(self, atom):
      """
      Mass of a single atom, read from the PSF entry at position
      atom-1.

      @type atom: int
      @param atom: Atom index (1 to N)

      @rtype: float
      @return: Atom mass [amu]
      """
      rec = self.myPSF.getAtom(atom-1)
      return rec.mass

   def charge(self, atom):
      """
      Charge of a single atom, read from the PSF entry at position
      atom-1.

      @type atom: int
      @param atom: Atom index (1 to N)

      @rtype: float
      @return: Atom charge [e]
      """
      rec = self.myPSF.getAtom(atom-1)
      return rec.charge

   # SYSTEM TEMPERATURE.
   def getTemperature(self):
      """
      Current system temperature (K), computed by
      topologyutilities.temperature from the topology and the
      velocity wrapper.

      @rtype: float
      @return: Kelvin temperature
      """
      topology, velocity_block = self.myTop, self.velvec
      return topologyutilities.temperature(topology, velocity_block)

   def angle(self, index):
      """
      Dihedral angle (rad) for the dihedral at the passed index.

      The value computed for dihedral index-1 is shifted by one full
      turn when it lies above pi or below -pi.

      @type index: int
      @param index: Dihedral index
      
      @rtype: float
      @return: Dihedral angle in radians
      """
      phi = topologyutilities.computePhiDihedral(self.myTop, self.posvec, index-1)
      if (phi > numpy.pi):
         return phi - 2*numpy.pi
      if (phi < -numpy.pi):
         return phi + 2*numpy.pi
      return phi

   # CALCULATE BOND AND ANGLE HESSIANS.
   # EXPLICITLY USED IN MOLLY PROPAGATORS.
   # IF YOU ARE NOT DESIGNING A MOLLY PROPAGATOR,
   # YOU PROBABLY WON'T NEED THIS.
   def calculateHessians(self, mollypos, angleFilter):
      """
      Calculate bond and angle Hessians for MOLLY propagators

      The method has two passes.  The first builds anglelist, a list
      of [b1, b2] bond-index pairs.  The second evaluates each
      angleFilter entry on the mollified positions of the angle's
      three atoms and accumulates two bond Hessians into it.

      NOTE(review): as written this method does not look runnable;
      see the inline NOTE(review) comments.  Nothing here has been
      changed -- the notes only record what the code currently does.

      @type mollypos: numpy.ndarray
      @param mollypos: Mollified positions

      @type angleFilter: ReducedHessAngleList
      @param angleFilter: List of angle Hessians
      """
      anglelist = list()
      ii = 0
      b1 = 0
      b2 = 0
      # Pass 1: collect bond-index pairs.
      while (ii < self.numAngles()):
         # NOTE(review): self.angle() in this class returns the value
         # computed by computePhiDihedral, not an Angle record, so the
         # .atom1/.atom2/.atom3 reads below would fail.  self.ang()
         # looks like the intended accessor -- confirm.  Also note the
         # accessors subtract 1 from the index while ii starts at 0.
         a1 = self.angle(ii).atom1
         a2 = self.angle(ii).atom2
         a3 = self.angle(ii).atom3
         jj = 0
         while (jj < self.numBonds()):
            if (self.bond(jj).atom1 == a1 and
                self.bond(jj).atom2 == a2):
               b1 = jj
            else:
               # Every bond that does not match (a1, a2) overwrites b2;
               # a3 is never compared.
               b2 = jj
            jj = jj + 1
            # NOTE(review): the next two statements are inside the bond
            # loop, so a pair is appended and ii is advanced once per
            # bond rather than once per angle; with zero bonds ii never
            # advances.  Possibly an indentation slip -- confirm.
            anglelist.append([b1,b2])
            ii = ii + 1
                          
      # Pass 2: evaluate and accumulate.
      # NOTE(review): ii is not reset before this loop and is not
      # incremented inside it.
      while (ii < self.numAngles()):
         a1 = self.angle(ii).atom1
         a2 = self.angle(ii).atom2
         a3 = self.angle(ii).atom3
         # NOTE(review): the Angle and Bond classes in this file do not
         # define restAngle, forceConstant, restLength or springConstant.
         theta0 = self.angle(ii).restAngle
         k_t = self.angle(ii).forceConstant
         angleFilter[ii].evaluate(mollypos[a1], mollypos[a2], mollypos[a3], k_t, theta0)
               
         # NOTE(review): anglelist entries are plain [b1, b2] lists, which
         # have no bond1/bond2 attributes; reducedHessBond is not imported
         # in the visible part of this file.
         r_0 = self.bond(anglelist[ii].bond1).restLength
         k = self.bond(anglelist[ii].bond1).springConstant
         bondHess12 = reducedHessBond.reducedHessbond(mollypos[a1],mollypos[a2], mollypos[a3], k, r_0)

         r_0 = self.bond(anglelist[ii].bond2).restLength
         k = self.bond(anglelist[ii].bond2).springConstant
         bondHess23 = reducedHessBond.reducedHessbond(mollypos[a1],mollypos[a2], mollypos[a3], k, r_0)
                    
         # Add the first bond Hessian to blocks (0,0) and (1,1) and
         # its negative to blocks (1,0) and (0,1).
         angleFilter[ii].accumulateTo(0,0,bondHess12)
         angleFilter[ii].accumulateTo(1,1,bondHess12)
         angleFilter[ii].accumulateNegTo(1,0,bondHess12)
         angleFilter[ii].accumulateNegTo(0,1,bondHess12)
                    
         # NOTE(review): the second bond Hessian goes into the same
         # (0,0)/(1,1)/(1,0)/(0,1) blocks as the first -- confirm that
         # this is intended.
         angleFilter[ii].accumulateTo(0,0,bondHess23)
         angleFilter[ii].accumulateTo(1,1,bondHess23)
         angleFilter[ii].accumulateNegTo(1,0,bondHess23)
         angleFilter[ii].accumulateNegTo(0,1,bondHess23)


   def randomVelocity(self, T):
      """
      Fill the velocity wrapper through
      _topologyutilities.randomVelocity, using this object's topology
      and random number seed.

      @param T: Value forwarded as the first argument of
                _topologyutilities.randomVelocity (build() passes
                self.temperature in the same position)
      """
      topology, velocity_block = self.myTop, self.velvec
      _topologyutilities.randomVelocity(T, topology, velocity_block, self.seed)
      
   def updateCOM_Momenta(self):
      """
      Refresh the molecular center of mass (from the positions) and
      the molecular momentum (from the velocities) on the topology.
      """
      topology = self.myTop
      _topologyutilities.buildMolecularCenterOfMass(self.posvec, topology)
      _topologyutilities.buildMolecularMomentum(self.velvec, topology)

   def build(self):
      """
      Build the physical data.
      """
      tm = -1
      if (hasattr(self, "myTop")):
          tm = self.myTop.time
      if (self.bc == "Periodic"):
          if (self.defaultCBV):
             # TEMPORARY STRUCTURES USED TO COMPUTE
             # BOUNDING BOX
             v1 = numpy.ndarray(3)
             v2 = numpy.ndarray(3)
             v3 = numpy.ndarray(3)
             v4 = numpy.ndarray(3)
             v1.fill(sys.maxint)
             v2.fill(-sys.maxint)
             # BOUNDING BOX
             i = 0
             while (i < numpy.size(self.positions)):
                if (self.positions[i] < v1[0]):
                   v1[0] = float(self.positions[i])
                if (self.positions[i] > v2[0]):
                   v2[0] = float(self.positions[i])
                if (self.positions[i+1] < v1[1]):
                   v1[1] = float(self.positions[i+1])
                if (self.positions[i+1] > v2[1]):
                   v2[1] = float(self.positions[i+1])
                if (self.positions[i+2] < v1[2]):
                   v1[2] = float(self.positions[i+2])
                if (self.positions[i+2] > v2[2]):
                   v2[2] = float(self.positions[i+2])
                i += 3
             v4.fill(Constants.periodicBoundaryTolerance()/2.0)
             v1 = v1 - v4
             v2 = v2 + v4
             v3 = v2 - v1
             self.cB1[0] = v3[0]
             self.cB2[1] = v3[1]
             self.cB3[2] = v3[2]
             self.cO = v1 + v3 * 0.5
             self.myTop.setBC(self.cB1[0],self.cB1[1],self.cB1[2],self.cB2[0],self.cB2[1],self.cB2[2],self.cB3[0],self.cB3[1],self.cB3[2],self.cO[0],self.cO[1],self.cO[2])
      self.myTop.setExclusion(self.exclude)
      if (self.myPSF.numAtoms() > 0 and hasattr(self.myPAR, 'readFlag')):
         _Topology.buildTopology(self.myTop, self.myPSF, self.myPAR, 0)
         if (numpy.size(self.velocities) == 0):
            aaa = _mathutilities.randomNumberFirst(self.seed, 1)
            _topologyutilities.randomVelocity(self.temperature, self.myTop, self.velvec, self.seed)
         if (self.remcom == "yes"):
           _topologyutilities.removeLinearMomentum(self.velvec, self.myTop)
         if (self.remang == "yes"):
           _topologyutilities.removeAngularMomentum(self.posvec, self.velvec, self.myTop)
      if (self.bc == "Periodic"):
           self.myTop.setCellSize(self.cellsize)

      # COMPUTE INV MASS MATRIX
      temp = list()
      ii = 0
      while ii < self.numAtoms()*3:
          temp.append(0.0)
          ii += 1
      ii = 0
      self.invmasses.resize(self.numAtoms()*3)
      while ii < self.numAtoms()*3:
          self.invmasses[ii] = 1.0 / self.myPSF.getAtom(ii/3).mass
          self.invmasses[ii+1] = 1.0 / self.myPSF.getAtom(ii/3).mass
          self.invmasses[ii+2] = 1.0 / self.myPSF.getAtom(ii/3).mass
          ii += 3

      # COMPUTE MASS MATRIX
      temp = list()
      ii = 0
      while ii < self.numAtoms()*3:
          temp.append(0.0)
          ii += 1
      ii = 0
      self.masses.resize(self.numAtoms()*3)
      while ii < self.numAtoms()*3:
          temp[ii] = self.myPSF.getAtom(ii/3).mass
          self.masses[ii] = temp[ii]
          temp[ii] = 0.0
          temp[ii+1] = self.myPSF.getAtom(ii/3).mass
          self.masses[ii+1] = temp[ii+1]
          temp[ii+1] = 0.0
          temp[ii+2] = self.myPSF.getAtom(ii/3).mass
          self.masses[ii+2] = temp[ii+2]
          temp[ii+2] = 0.0
          ii += 3


      # COMPUTE MASS SUM
      self.masssum = 0
      ii = 0
      while ii < self.numAtoms():
         self.masssum += self.myPSF.getAtom(ii).mass
         ii += 1

      # SET SELFICAL TIME
      if (tm != -1):
          self.myTop.time = tm

      self.dirty = 0  # clean now!
