from numpy import *
from math import sqrt

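# Estimates the rigid transform between two sets of corresponding points
# (Kabsch / SVD method).
# Input: A, B = Nx3 matrices of corresponding points, one point per row
# Returns the rotation matrix R and the translation vector t such that
# R * a_i + t ~= b_i (least squares) for every corresponding row pair
# R = 3x3 rotation matrix
# t = 3x1 column vector
# Uses numpy.linalg.svd.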

def SVD_3D(A, B):
    """Estimate the rigid transform (R, t) that maps point set A onto B.

    Uses the SVD-based Kabsch method to find the rotation R and
    translation t minimising sum_i ||R . a_i + t - b_i||^2, i.e.
    ``B.T ~= R . A.T + t``.

    Parameters
    ----------
    A, B : array_like or numpy.matrix, shape (N, D)
        Corresponding points, one point per row. Written for D == 3, but any
        dimension D works because nothing below is 3-D specific.

    Returns
    -------
    R : (D, D) proper rotation matrix (det(R) == +1).
    t : (D, 1) column-vector translation.
        Both are ``numpy.matrix`` when ``A`` is a ``numpy.matrix`` (the
        historical API); plain ``ndarray`` otherwise.

    Raises
    ------
    ValueError
        If A and B are not 2-D with identical (N, D) shapes, or N == 0.
    """
    # Work on plain ndarrays so the arithmetic below means the same thing
    # regardless of whether the caller passed ndarray or numpy.matrix.
    A_pts = asarray(A, dtype=float)
    B_pts = asarray(B, dtype=float)
    if A_pts.ndim != 2 or A_pts.shape != B_pts.shape:
        raise ValueError(
            "A and B must be 2-D arrays of the same shape, got %r and %r"
            % (A_pts.shape, B_pts.shape))
    if A_pts.shape[0] == 0:
        raise ValueError("at least one point is required")

    centroid_A = A_pts.mean(axis=0)
    centroid_B = B_pts.mean(axis=0)

    # Center the points; broadcasting subtracts the centroid from every row.
    AA = A_pts - centroid_A
    BB = B_pts - centroid_B

    # Cross-covariance matrix. Use dot() explicitly: on ndarrays `*` is
    # element-wise, not a matrix product.
    H = dot(AA.T, BB)

    U, S, Vt = linalg.svd(H)

    R = dot(Vt.T, U.T)

    # Special reflection case: det(R) == -1 means the best orthogonal fit is a
    # reflection. Flip the singular vector with the smallest singular value
    # (last row of Vt) to obtain a proper rotation.
    if linalg.det(R) < 0:
        print("Multiply by -1 due to reflection")
        Vt[-1, :] *= -1
        R = dot(Vt.T, U.T)

    # t maps the rotated centroid of A onto the centroid of B; returned as a
    # column vector.
    t = centroid_B.reshape(-1, 1) - dot(R, centroid_A.reshape(-1, 1))

    # Preserve the historical return type for numpy.matrix callers.
    if isinstance(A, matrix):
        return matrix(R), matrix(t)
    return R, t

