

from dmexpv import dmexpv
import numpy as np
from numpy.linalg import norm

"""
*-----Arguments--------------------------------------------------------|
*
*     n      : (input) order of the principal matrix A.
*                      
*     m      : (input) maximum size for the Krylov basis.
*                      
*     t      : (input) time at which the solution is needed (can be < 0).
*                      
*     v(n)   : (input) given operand vector.
*
*     w(n)   : (output) computed approximation of exp(t*A)*v.
*
*     tol    : (input/output) the requested accuracy tolerance on w.
*              If on input tol=0.0d0 or tol is too small (tol.le.eps)
*              the internal value sqrt(eps) is used, and tol is set to
*              sqrt(eps) on output (`eps' denotes the machine epsilon).
*              (`Happy breakdown' is assumed if h(j+1,j) .le. anorm*tol)
*
*     anorm  : (input) an approximation of some norm of A.
*
*   wsp(lwsp): (workspace) lwsp .ge. n*(m+1)+n+(m+2)^2+4*(m+2)^2+ideg+1
*                                   +---------+-------+---------------+
*              (actually, ideg=6)        V        H     wsp for PADE
*                   
* iwsp(liwsp): (workspace) liwsp .ge. m+2
*
*     matvec : external subroutine for matrix-vector multiplication.
*              synopsis: matvec( x, y )
*                        double precision x(*), y(*)
*              computes: y(1:n) <- A*x(1:n)
*                        where A is the principal matrix.
*
*              IMPORTANT: DMEXPV requires the product y = Ax = Q'x, i.e.
*              the TRANSPOSE of the transition rate matrix.
*
*     itrace : (input) running mode. 0=silent, 1=print step-by-step info
*
*     iflag  : (output) exit flag.
*              <0 - bad input arguments 
*               0 - no problem
*               1 - maximum number of steps reached without convergence
*               2 - requested tolerance was too high
"""

class FnMatrix:
    """Matrix-vector product for a symmetric banded matrix in compact form.

    ``A`` stores one diagonal per row: ``A[0, i]`` is the main-diagonal entry
    ``S[i, i]`` and, for ``j >= 1``, ``A[j, i]`` is the entry shared by
    ``S[i, i + j]`` and ``S[i + j, i]``.  The order of ``S`` is ``A.shape[1]``;
    the trailing ``j`` entries of row ``j`` are never read, and rows with
    ``j >= A.shape[1]`` are ignored.
    """

    def __init__(self, A):
        self.A = A

    def matvec(self, u):
        """Return ``S.dot(u)`` for the symmetric banded matrix held in ``self.A``.

        u : 1-D array with at least ``self.A.shape[1]`` entries.

        The result has ``len(u)`` entries; entries past the matrix order stay
        zero.  Its dtype is the common type of the band and ``u``, so an
        integer ``u`` does not truncate floating-point partial sums.
        """
        band = np.asarray(self.A)
        u = np.asarray(u)
        n = band.shape[1]

        result = np.zeros(u.shape[0], dtype=np.result_type(band.dtype, u.dtype))
        result[:n] = band[0, :n] * u[:n]

        # Off-diagonal bands; a band index >= n has no in-range entries.
        for j in range(1, min(band.shape[0], n)):
            diag = band[j, :n - j]
            # Lower-triangle contribution first, then upper, which matches the
            # accumulation order of an element-by-element loop over i.
            result[j:n] += diag * u[:n - j]
            result[:n - j] += diag * u[j:n]

        return result


def krylov_expm(A, v, t, debug=1, tol=0.0001):
    """Approximate exp(t*S).v with the FORTRAN routine dmexpv from expokit.

    S is the symmetric banded matrix stored in A.

    A     : symmetric banded matrix in the compact layout read by FnMatrix
            (A[0] is the main diagonal; A[j, i] is the entry at (i, i+j)
            and (i+j, i)).
    v     : vector on which to calculate the action; its length is the
            order n of S.
    t     : time (may be negative).
    debug : running mode passed to dmexpv as itrace (0 = silent,
            1 = step-by-step info). When non-zero, the exit flag and the
            result are also printed here.
    tol   : requested accuracy tolerance on the result; per the expokit
            documentation, sqrt(eps) is used when this is 0 or too small.

    Returns w, a float64 array of length n holding the approximation.
    """
    v = np.asarray(v, dtype=np.float64)
    n = v.shape[0]                       # order of the matrix
    m = min(30, n - 1)                   # Krylov basis size, at most n-1

    # Workspace sizes documented for dmexpv:
    #   lwsp  >= n*(m+1) + n + (m+2)^2 + 4*(m+2)^2 + ideg + 1   (ideg = 6)
    #   liwsp >= m + 2
    ideg = 6
    lwsp = max(n * (m + 1) + n + (m + 2) ** 2 + 4 * (m + 2) ** 2 + ideg + 1, 10)
    wsp = np.zeros(lwsp, dtype=np.float64)
    liwsp = max(m + 2, 7)
    # Integer workspace: Fortran default INTEGER corresponds to a C int in f2py.
    iwsp = np.zeros(liwsp, dtype=np.intc)

    # Output vector; w is never reassigned here, so the wrapper is presumably
    # expected to fill it in place (TODO confirm against the f2py signature).
    w = np.zeros(n)
    # NOTE(review): iflag is passed as an immutable Python int, so the wrapper
    # cannot write the exit flag back into `ret`; it always stays 0 here.
    ret = 0
    # Frobenius norm of the stored band (numpy's default for 2-D input);
    # dmexpv only needs an approximation of some norm of the matrix.
    anorm = norm(A)

    itrace = int(debug)

    Am = FnMatrix(A)

    # Wrapper call shape: dmexpv(m, t, v, w, tol, anorm, wsp, iwsp, matvec,
    #     itrace, iflag, n=len(v), lwsp=len(wsp), liwsp=len(iwsp),
    #     matvec_extra_args=())
    dmexpv(m, t, v, w, tol, anorm, wsp, iwsp, Am.matvec, itrace, ret)

    if itrace:
        print("RETURN: %s" % ret)
        print("W: %s" % (w,))

    return w


# Demo: the 3x3 matrix with -2 on the diagonal and 1 off the diagonal,
# written in the compact symmetric band layout that FnMatrix reads
# (row 0 = main diagonal, row j = j-th off-diagonal, unused tail entries 0).
# NOTE(review): this runs on import; consider an `if __name__ == "__main__":`
# guard.
A = np.array([[-2.0, -2.0, -2.0],
              [1.0, 1.0, 0.0],
              [1.0, 0.0, 0.0]])
v = np.zeros(3)
v[0] = 1.0

Am = FnMatrix(A)
print(Am.matvec(v))

for t in range(5):
    w = krylov_expm(A, v, t, debug=1)
    print(w)



