#! /usr/bin/python

import os
import re
import math
import copy

# ----------------------------------------------------------------------------------------------------

import GridSimTk
import UtilitiesSimTk
import LogSimTk 

# sample APBS potential file (referenced by the commented-out example at the end of this file)
fileName = '/home/SIMTK/friedrim/src/apbs/examples/test1/pot.dx'

# physical constants
eCharge = 1.6021773e-19             # elementary charge (C)
Na      = 6.0221367e+23             # Avogadro constant (1/mol)

# eCharge*Na*1e-3 (~96.485); presumably converts a potential in volts per unit
# charge to kJ/mol -- NOTE(review): confirm the intended units
PotentialConstant = eCharge * Na * 1.0e-3

endOfLine      = '\n'
machineEpsilon = 1.0e-5             # small tolerance for float-to-int truncation

# ----------------------------------------------------------------------------------------------------

class ApbsGridSimTk(GridSimTk.GridSimTk): # {
   """APBS electrostatic potential stored on a regular 3-D grid.

   readDx() fills in the grid geometry and the potential values:
     * _xGrid/_yGrid/_zGrid: number of grid points along each axis
     * _xOrigin/_yOrigin/_zOrigin: coordinates of grid point (0, 0, 0)
     * _xDelta/_yDelta/_zDelta: grid spacing along each axis
     * _xOffset/_yOffset/_zOffset: strides used to build a flat grid index
     * _potentialList: one value per grid point, in flat-index order

   Grid points are addressed by a flat index in which z varies fastest and
   x slowest.
   """

   # -------------------------------------------------------------------------------------------------

   def __init__( self ): # {
      """Constructor for ApbsGridSimTk; all initialisation is delegated to GridSimTk."""

      GridSimTk.GridSimTk.__init__( self )

   # } end of __init__

   # -------------------------------------------------------------------------------------------------

   def getGridIndexGivenOffsets( self, xOffset, yOffset, zOffset ): # {
      """Get the flat grid index for integer grid offsets.

           gridIndex = xOffset*_xOffset + yOffset*_yOffset + zOffset*_zOffset

      Returns -1 if any offset lies outside [0, number of grid points) in its
      dimension.
      """

      inside = 0 <= xOffset < self._xGrid and 0 <= yOffset < self._yGrid and 0 <= zOffset < self._zGrid
      if not inside:
         return -1

      return xOffset*self._xOffset + yOffset*self._yOffset + zOffset*self._zOffset

   # } end of getGridIndexGivenOffsets

   # -------------------------------------------------------------------------------------------------

   def getGridIndicesWithinPointAndRadius( self, point, radius ): # {
      """Get flat grid indices of the grid points in a box around point.

      The box is centred on getRawGridIndicesGivenPoint( point ). It extends
      getRawGridOffsetsGivenDistance( radius ) grid points in each direction.
      Every grid point in the box is returned (there is no spherical distance
      test). Points outside the grid are skipped. Indices are produced with x
      varying slowest and z fastest.
      """

      centerIndices = self.getRawGridIndicesGivenPoint( point )
      offsets       = self.getRawGridOffsetsGivenDistance( radius )

      logReference = self.getLogReference()
      logReference.info( 'getGridIndicesWithinPointAndRadius: centerIndices: ' + str( centerIndices ) )
      logReference.info( 'getGridIndicesWithinPointAndRadius: offsets: ' + str( offsets ) )

      # clip the box to the grid once instead of testing every candidate offset
      gridSizes = ( self._xGrid, self._yGrid, self._zGrid )
      lower     = [ max( centerIndices[dd] - offsets[dd], 0 ) for dd in range( 3 ) ]
      upper     = [ min( centerIndices[dd] + offsets[dd], gridSizes[dd] - 1 ) for dd in range( 3 ) ]

      gridList = []
      for ii in range( lower[0], upper[0] + 1 ):
         for jj in range( lower[1], upper[1] + 1 ):
            for kk in range( lower[2], upper[2] + 1 ):
               gridIndex = self.getGridIndexGivenOffsets( ii, jj, kk )
               if gridIndex > -1:
                  gridList.append( gridIndex )

      return gridList

   # } end of getGridIndicesWithinPointAndRadius

   # -------------------------------------------------------------------------------------------------

   def readDx( self, dxFileName, hasHeader=1, onlyHeader=0 ): # {
      """Read APBS OpenDX-formatted potential data from dxFileName.

      WARNING! This method assumes the data are ordered such that z increases
      most quickly, y next most quickly, and x most slowly.

      hasHeader  -- if true, parse the header (grid counts, origin, deltas).
                    Otherwise reuse the grid geometry already on this object.
      onlyHeader -- if true (and hasHeader is true), stop after the header.

      Returns None. If the file cannot be opened, logs an error and returns
      None.
      """

      logReference = self.getLogReference()

      methodName   = 'ApbsGridSimTk::readDx'
      logReference.info( methodName + '\n   reading file=<' + dxFileName + '>' )

      try:
         dxFile = open( dxFileName, "r" )
      except IOError:
         logReference.error( methodName + ' file=<' + dxFileName + '> could not be opened.' )
         return None

      # close the file on every exit path, including parse errors
      try:
         self.setGridPotentialDxFileName( dxFileName, 1 )

         if hasHeader:
            self._readDxHeader( dxFile, logReference )
            if onlyHeader:
               return None

         self._readDxValues( dxFile, logReference )
      finally:
         dxFile.close()

      logReference.info( self.getBoxStats() )
      return None

   # } end of readDx

   # -------------------------------------------------------------------------------------------------

   def _readDxHeader( self, dxFile, logReference ): # {
      """Parse the OpenDX header from dxFile and store the grid geometry on self.

      Reads, in order:
        * the title line
        * any '#' comment lines (saved in _commentList)
        * the grid-count line
        * the origin line
        * three delta lines
        * two further 'object' lines, which are skipped
      """

      titleLine = dxFile.readline()
      logReference.info( 'Title: ' + titleLine )

      # collect '#' comment lines; the first non-comment line holds the grid counts
      self._commentList = []
      nextLine = dxFile.readline()
      while nextLine.startswith( '#' ):
         self._commentList.append( nextLine )
         nextLine = dxFile.readline()

      # strip only the line terminator so a final line without '\n' keeps its last character
      nextLine = nextLine.rstrip( '\r\n' )
      logReference.info( 'First noncomment line=<' + nextLine + '>' )

      # grid points in each dimension: last three tokens of the line
      lineTokens = nextLine.split()
      self._xGrid = int( lineTokens[-3] )
      self._yGrid = int( lineTokens[-2] )
      self._zGrid = int( lineTokens[-1] )

      message = 'No. grid pts each dimension: [ ' + str(self._xGrid) + ', ' + str(self._yGrid) + ', ' + str(self._zGrid) + ']'
      logReference.info( message )

      # strides for z-fastest / x-slowest flat indexing
      self._xOffset = self._yGrid*self._zGrid
      self._yOffset = self._zGrid
      self._zOffset = 1

      # origin: last three tokens of the next line
      lineTokens    = dxFile.readline().split()
      self._xOrigin = float( lineTokens[-3] )
      self._yOrigin = float( lineTokens[-2] )
      self._zOrigin = float( lineTokens[-1] )

      message = 'Origin: [ ' + str(self._xOrigin) + ', ' + str(self._yOrigin) + ', ' + str(self._zOrigin) + ']'
      logReference.info( message )

      # deltas: one line per axis; the diagonal entry (token ii+1 on line ii) is the spacing
      deltas = []
      for ii in range( 3 ):
         lineTokens = dxFile.readline().split()
         deltas.append( float( lineTokens[ii+1] ) )
      self._xDelta, self._yDelta, self._zDelta = deltas

      message = 'Deltas: [ ' + str(self._xDelta) + ', ' + str(self._yDelta) + ', ' + str(self._zDelta) + ']'
      logReference.info( message )

      # NOTE(review): _yVec and _zVec put their extent in the x component, which looks like a
      # copy/paste slip (expected [0, yExtent, 0] and [0, 0, zExtent]). Left unchanged because
      # code outside this file may depend on the current layout -- confirm before changing.
      self._xVec = [ self._xDelta*self._xGrid, 0, 0 ]
      self._yVec = [ self._yDelta*self._yGrid, 0, 0 ]
      self._zVec = [ self._zDelta*self._zGrid, 0, 0 ]

      self.setTotalGridPoints( self._xGrid*self._yGrid*self._zGrid )

      # skip the next two 'object' lines
      dxFile.readline()
      dxFile.readline()

   # } end of _readDxHeader

   # -------------------------------------------------------------------------------------------------

   def _readDxValues( self, dxFile, logReference ): # {
      """Read getTotalGridPoints() values from dxFile into _potentialList.

      Values are read three per line. A final partial line holds the remaining
      one or two values.
      """

      totalPoints = int( self.getTotalGridPoints() )
      self._potentialList = []
      logReference.info( 'Reading in ' + str( self.getTotalGridPoints()) + ' grid values' )

      fullLines, leftover = divmod( totalPoints, 3 )
      for lineNumber in range( fullLines ):
         lineTokens = dxFile.readline().split()
         # explicit indices so a short (malformed) line raises IndexError instead of being skipped
         self._potentialList.extend( [ float( lineTokens[kk] ) for kk in range( 3 ) ] )

      if leftover:
         lineTokens = dxFile.readline().split()
         self._potentialList.extend( [ float( lineTokens[kk] ) for kk in range( leftover ) ] )

   # } end of _readDxValues

   # -------------------------------------------------------------------------------------------------

   def copyDxHeader( self, selfToBeCopiesTo ): # {
      """Copy this object's Dx grid geometry onto selfToBeCopiesTo.

      Copies the grid counts, index strides, origin and deltas. The
      _xVec/_yVec/_zVec lists are deep-copied so the two objects do not share
      them. The potential values and the total grid-point count are not copied.
      """

      for attributeName in ( '_xGrid', '_yGrid', '_zGrid',
                             '_xOffset', '_yOffset', '_zOffset',
                             '_xOrigin', '_yOrigin', '_zOrigin',
                             '_xDelta', '_yDelta', '_zDelta' ):
         setattr( selfToBeCopiesTo, attributeName, getattr( self, attributeName ) )

      for vectorName in ( '_xVec', '_yVec', '_zVec' ):
         setattr( selfToBeCopiesTo, vectorName, copy.deepcopy( getattr( self, vectorName ) ) )

   # } end of copyDxHeader

   # -------------------------------------------------------------------------------------------------

   def writeGridPotentialFile( self, gridPotentialFileName ): # {
      """Write the potential to gridPotentialFileName in ISIM grid format.

      Each grid point becomes one line "x y z potential", formatted as four
      '%13.7f' fields, in flat grid-index order. If the file cannot be opened,
      logs an error and returns None.

      Sample output:
        -0.02316    -0.00121    -0.88553    -111.29441
        -0.06941    -0.00364    -0.88310    -109.71898
      """

      logReference = self.getLogReference()

      parsedPath   = UtilitiesSimTk.parseFileName( __file__ )
      methodName   = parsedPath[1] + '::writeGridPotentialFile'
      logReference.info( methodName + '\n   writing file=<' + gridPotentialFileName + '>' )

      try:
         gridPotentialFile = open( gridPotentialFileName, "w" )
      except IOError:
         logReference.error( methodName + ' file=<' + gridPotentialFileName + '> could not be opened.' )
         return None

      # close the file even if a coordinate/potential lookup fails mid-write
      try:
         for ii in range( self.getTotalGridPoints() ):
            potential   = self._potentialList[ii]
            coordinates = self.getGridCoordinates( ii )
            gridPotentialFile.write( '%13.7f %13.7f %13.7f %13.7f\n' % (
                                     coordinates[0], coordinates[1], coordinates[2], potential ) )
      finally:
         gridPotentialFile.close()

      logReference.info( 'Write grid-formatted file to ' + gridPotentialFileName + '\n' )

   # } end of writeGridPotentialFile

   # -------------------------------------------------------------------------------------------------

   def getGridCoordinates( self, gridIndex ): # {
      """Convert a flat grid index to grid coordinates.

      Returns [ xCoord, yCoord, zCoord, xOffset, yOffset, zOffset ]. The offsets
      are the grid offsets along each axis, and each coordinate is
      origin + offset*delta. If gridIndex is outside
      [0, _xGrid*_yGrid*_zGrid), logs an error and returns None.
      """

      diagnostics  = 0
      logReference = self.getLogReference()

      methodName = 'getGridCoordinates: '
      if gridIndex < 0:
         message = methodName + 'grid index ' + str( gridIndex ) + ' is negative.'
         logReference.error( message )
         return None

      # valid indices are 0 .. maxGridIndex-1
      maxGridIndex = self._xGrid*self._yGrid*self._zGrid
      if gridIndex >= maxGridIndex:
         message = methodName + 'grid index ' + str( gridIndex ) + ' is too big: max=' + str( maxGridIndex ) + \
                   ' grid=[%4d %4d %4d ]' % ( self._xGrid, self._yGrid, self._zGrid )
         logReference.error( message )
         return None

      # exact integer decomposition (no float division + epsilon fudge)
      xOffset, remainder = divmod( gridIndex, self._xOffset )
      yOffset, zOffset   = divmod( remainder, self._yOffset )

      xCoord = self._xOrigin + xOffset*self._xDelta
      yCoord = self._yOrigin + yOffset*self._yDelta
      zCoord = self._zOrigin + zOffset*self._zDelta

      if diagnostics and (zOffset == 0 or zOffset == (self._zGrid - 1)):
         message   = 'GridIndex ' + str( gridIndex ) + ' Offsets: '
         message  += '[ %4d %4d %4d ]' % ( xOffset, yOffset, zOffset )
         message  += '[ %.3f %.3f %.3f ]' % ( xCoord, yCoord, zCoord )
         logReference.info( message )

      return [ xCoord, yCoord, zCoord, xOffset, yOffset, zOffset ]

   # } end of method getGridCoordinates

   # -------------------------------------------------------------------------------------------------

   def getGridIndicesxGivenCenterAndRadius( self, center, radius ): # {
      """Get flat grid indices of the grid points in a box around center.

      Alias for getGridIndicesWithinPointAndRadius( center, radius ).
      """

      return self.getGridIndicesWithinPointAndRadius( center, radius )

   # } end of method getGridIndicesxGivenCenterAndRadius

   # -------------------------------------------------------------------------------------------------

   def testGridConsistency( self ): # {
      """Check that index -> coordinates -> offsets -> index round-trips for every grid point.

      Each index is mapped through getGridCoordinates, then
      getRawGridIndicesGivenPoint, then getGridIndexGivenOffsets. The first
      maxPrintErrors mismatches are logged.

      Returns the total number of grid indices that do not round-trip
      (0 means the grid is consistent).
      """

      diagnostics  = 1
      logReference = self.getLogReference()
      logReference.info( "testGridConsistency begin" )

      methodName     = 'testGridConsistency: '
      errors         = 0
      maxPrintErrors = 10
      for ii in range( self.getTotalGridPoints() ):
         gridCoordinates = self.getGridCoordinates( ii )
         gridOffsets     = self.getRawGridIndicesGivenPoint( gridCoordinates )
         gridIndex       = self.getGridIndexGivenOffsets( gridOffsets[0], gridOffsets[1], gridOffsets[2] )
         if gridIndex == ii:
            continue

         # count every failure; only the first few are logged
         errors += 1
         if diagnostics and errors <= maxPrintErrors:
            message   = methodName + ' Error at gridIndex ' + str( ii ) + ' Calculated=' + str( gridIndex ) + ' Coord: '
            message  += '[ %.3f %.3f %.3f ]' % ( gridCoordinates[0], gridCoordinates[1], gridCoordinates[2] )
            message  += ' Offsets=[ %d %d %d ]' % ( gridCoordinates[3], gridCoordinates[4], gridCoordinates[5] )
            message  += '[ %4d %4d %4d ]' % ( gridOffsets[0], gridOffsets[1], gridOffsets[2] )
            logReference.info( message )

      logReference.info( "testGridConsistency end" )
      return errors

   # } end of method testGridConsistency

# ----------------------------------------------------------------------------------------------------

# potential = ApbsGridSimTk()
# potential.readDx( fileName )
