from simbios.simtk import *
import simbios.std as std
import math
import inspect
import sys

def lineno():
    """Return the line number of the statement that called this function.

    Paired with a print and sys.stdout.flush() throughout this file as a
    simple trace of how far execution has progressed.
    """
    callerFrame = inspect.currentframe().f_back
    return callerFrame.f_lineno

# Plain-float stand-ins for units. The boost.units expressions used by an
# earlier version of this example were elided to keep it portable (--cmb);
# these names are only a hollow echo of that design. Each value converts a
# quantity in the named unit to the internal unit: radians for angles,
# nanometers for lengths, daltons for mass (those three are 1.0).
degrees = math.pi/180.0  # exact pi rather than the truncated literal 3.14159
radians = 1.0
angstroms = 0.10  # 1 angstrom == 0.1 nanometer
nanometers = 1.0
daltons = 1.0

# Type names left over from the boost.units version of this example. They are
# now plain constructors: e.g. angle_t(x) is float(x), location_t(v) is Vec3(v).
mass_t = float
length_t = float
angle_t = float
location_t = Vec3

class ZeroFunction(Function.Constant):
    """Constant function whose value is always 0.

    Used by the function-based mobilizers in this file for the spatial
    functions that are paired with an empty coordinate-index list.
    """
    def __init__(self):
        print(lineno())
        sys.stdout.flush()
        value = 0
        nargs = 0  # presumably the argument count; matches the empty coordIndices entries
        super(ZeroFunction, self).__init__(value, nargs)

class UnityFunction(Function.Constant):
    """Constant function whose value is always 1.0.

    The second constructor argument (0) is passed exactly as ZeroFunction
    passes it.
    """
    def __init__(self):
        print(lineno()); sys.stdout.flush()
        # super() must name this class. Naming ZeroFunction (a sibling class,
        # not a base of UnityFunction) made every construction raise TypeError.
        super(UnityFunction, self).__init__(1.0, 0)

class IdentityFunction(Function):
    """Single-argument function f(x) = x.

    Its first derivative is 1.0. Every other derivative order requested
    through calcDerivative returns 0.0.
    """
    def __init__(self):
        print(lineno()); sys.stdout.flush()
        super(IdentityFunction, self).__init__()

    def calcValue(self, x):
        print(lineno()); sys.stdout.flush()
        only = x[0]
        return only

    def calcDerivative(self, derivComponents, x):
        print(lineno()); sys.stdout.flush()
        isFirstOrder = (len(derivComponents) == 1)
        return 1.0 if isFirstOrder else 0.0

    def getArgumentSize(self):
        print(lineno()); sys.stdout.flush()
        return 1

    def getMaxDerivativeOrder(self):
        print(lineno()); sys.stdout.flush()
        return 1000

class DifferenceFunction(Function.Linear):
    """Difference x - y of the two components of a two-component argument.

    Used as the target function of an equality coordinate-coupling
    constraint: it is zero exactly when the two coupled coordinates agree.
    """
    def __init__(self):
        print(lineno()); sys.stdout.flush()
        # a*x + b*y + c with a = 1, b = -1; the trailing 0.0 is presumably the
        # constant term of the Linear function (confirm against SimTK docs).
        coefficients = [1.0, -1.0, 0.0]
        super(DifferenceFunction, self).__init__(coefficients)

class ThreeDifferencesFunction(Function):
    """Function of three variables that is zero exactly when all three are equal.

        f(x, y, z) = (x - y)^2 + (x - z)^2 + (y - z)^2

    Partial derivatives:
        df/dx       = 2(x - y) + 2(x - z) = 4x - 2y - 2z
        d^2f/dx^2   = 4
        d^2f/dx dy  = -2
    Derivatives of order three and higher are zero.
    """
    def __init__(self):
        print(lineno()); sys.stdout.flush()
        super(ThreeDifferencesFunction, self).__init__()

    def calcValue(self, x):
        print(lineno()); sys.stdout.flush()
        print(x)
        assert( 3 == len(x) )
        xy = x[0] - x[1]
        xz = x[0] - x[2]
        yz = x[1] - x[2]
        return xy*xy + xz*xz + yz*yz

    def calcDerivative(self, derivComponents, x):
        print(lineno()); sys.stdout.flush()
        assert(3 == len(x))
        order = len(derivComponents)
        assert(1 <= order)
        if order == 1:
            # 4*x_k - 2*(other two) rewritten as 6*x_k - 2*(x0 + x1 + x2),
            # which works for whichever component k is requested.
            return 2 * (3 * x[derivComponents[0]] -x[0] -x[1] -x[2])
        if order == 2:
            sameComponent = (derivComponents[0] == derivComponents[1])
            return 4.0 if sameComponent else -2.0
        # third and higher derivatives vanish
        return 0.0

    def getArgumentSize(self):
        print(lineno()); sys.stdout.flush()
        return 3

    def getMaxDerivativeOrder(self):
        print(lineno()); sys.stdout.flush()
        return 1000

class SinusoidFunction(Function):
    """Single-argument function y = amplitude * sin(x - phase).

    Derivatives cycle with period four in the derivative order:
    sin, cos, -sin, -cos (each scaled by amplitude).
    """
    def __init__(self, amplitude=180.0*degrees, phase=0.0*degrees):
        print(lineno()); sys.stdout.flush()
        super(SinusoidFunction, self).__init__()
        self.amplitude = amplitude
        self.phase = phase

    def calcValue(self, x):
        print(lineno()); sys.stdout.flush()
        assert( 1 == len(x) )
        theta = x[0]*radians - self.phase
        return self.amplitude*math.sin(theta)

    def calcDerivative(self, derivComponents, x):
        print(lineno()); sys.stdout.flush()
        assert(1 == len(x))
        cycle = len(derivComponents) % 4
        theta = x[0]*radians - self.phase
        if cycle == 1:    # orders 1, 5, 9, ...
            return angle_t(self.amplitude*math.cos(theta))
        if cycle == 2:    # orders 2, 6, 10, ...
            return angle_t(-self.amplitude*math.sin(theta))
        if cycle == 3:    # orders 3, 7, 11, ...
            return angle_t(-self.amplitude*math.cos(theta))
        # orders 0, 4, 8, ...
        return angle_t(self.amplitude*math.sin(theta))

    def getArgumentSize(self):
        print(lineno()); sys.stdout.flush()
        return 1

    def getMaxDerivativeOrder(self):
        print(lineno()); sys.stdout.flush()
        return 1000

class TestPinMobilizer(MobilizedBody.FunctionBased):
    """Function-based re-creation of a pin joint.

    Its purpose is to prove that the basic syntax of Function-based mobilizers
    is understood. There is one mobility. The third of the six spatial
    functions (index 2) is the identity of that coordinate. The other five are
    zero. Index 2 is presumably rotation about z (confirm the SimTK ordering).
    """
    def __init__(self, parent, inbFrame, body, outbFrame, direction=MobilizedBody.Forward):
        print(lineno()); sys.stdout.flush()
        zero = ZeroFunction()
        identity = IdentityFunction()
        nmobilities = 1
        # NOTE(review): one ZeroFunction object fills five slots; confirm the
        # underlying mobilizer does not take ownership of each entry separately.
        functions = [zero, zero, identity, zero, zero, zero]
        coordIndices = [[], [], [0], [], [], []]
        super(TestPinMobilizer, self).__init__(
            parent, inbFrame, body, outbFrame, nmobilities,
            functions, coordIndices, direction)

class TestPinMobilizer2(MobilizedBody.Pin):
    """Trivial Python subclass of the built-in Pin mobilizer.

    Exists to check whether merely defining a mobilizer in Python is what
    causes problems.
    """
    def __init__(self, parent, inbFrame, body, outbFrame):
        base = super(TestPinMobilizer2, self)
        base.__init__(parent, inbFrame, body, outbFrame)

class PseudorotationMobilizer(MobilizedBody.FunctionBased):
    """One-mobility function-based mobilizer driven by a sinusoid.

    The third of the six spatial functions (index 2) is
    SinusoidFunction(amplitude, phase) of the single generalized coordinate.
    The other five are zero. testRiboseMobilizer chains these to build the
    C3, C2 and C1 bodies of a ribose ring.
    """
    def __init__(self, parent, inbFrame, body, outbFrame,
                 amplitude, phase, direction=MobilizedBody.Forward):
        print(lineno()); sys.stdout.flush()
        zero = ZeroFunction()
        sinusoid = SinusoidFunction(amplitude, phase)
        nmobilities = 1
        # NOTE(review): one ZeroFunction object fills five slots; confirm the
        # underlying mobilizer does not take ownership of each entry separately.
        functions = [zero, zero, sinusoid, zero, zero, zero]
        coordIndices = [[], [], [0], [], [], []]
        super(PseudorotationMobilizer, self).__init__(
            parent, inbFrame, body, outbFrame, nmobilities,
            functions, coordIndices, direction)

# Mobilizer choices accepted by testRiboseMobilizer() for the C3 body.
_C3_BODY_TYPES = ("Pin", "TestPin", "TestPin2", "RiboseMobilizer")


def _atomOffset(x, y, z):
    """Return the body-frame station (x, y, z), given in angstroms."""
    return location_t(Vec3(x, y, z)*angstroms)


def _addAtomSphere(decorations, body, placement, color=None):
    """Attach a 0.5 angstrom atom sphere to ``body``.

    ``placement`` is passed straight to addBodyFixedDecoration (a Transform,
    or a Vec3 station as elsewhere in this file). ``color`` is an optional
    Vec3 applied with setColor.
    """
    sphere = DecorativeSphere(length_t(0.5*angstroms))
    if color is not None:
        sphere = sphere.setColor(color)
    decorations.addBodyFixedDecoration(
        body.getMobilizedBodyIndex(), placement, sphere)


def _addBond(decorations, body1, station1, body2, station2):
    """Draw a black rubber-band line (thickness 6) between two body stations."""
    decorations.addRubberBandLine(
        body1.getMobilizedBodyIndex(), station1,
        body2.getMobilizedBodyIndex(), station2,
        DecorativeLine().setColor(Vec3(0,0,0)).setLineThickness(6))


def testRiboseMobilizer(bC3BodyType="TestPin", numConstraints=0):
    """Build and simulate a small ribose-ring test system.

    A C4 body is welded to ground. A C3 body hangs off it using the mobilizer
    selected by ``bC3BodyType``. With "RiboseMobilizer", the C3, C2 and C1
    bodies are chained with PseudorotationMobilizer and may be coupled by
    coordinate-coupler constraints. Atoms and bonds are added as decorations,
    and the C3 mobilizer gets a ConstantSpeed(0.5) constraint. The system is
    then stepped to t = 50.0 with a VerletIntegrator and a VTKEventReporter
    (second argument 0.10).

    Args:
        bC3BodyType: "Pin", "TestPin", "TestPin2" or "RiboseMobilizer".
            Defaults to "TestPin", the value previously hard-coded.
        numConstraints: 0, 1 or 2 coordinate couplers among the C1-C3
            mobilizers. Only used with "RiboseMobilizer". Defaults to 0.

    Raises:
        ValueError: if ``bC3BodyType`` or ``numConstraints`` is not one of
            the accepted values (previously an unknown type left c3Body
            unbound and failed later with NameError).
    """
    print(lineno()); sys.stdout.flush()
    if bC3BodyType not in _C3_BODY_TYPES:
        raise ValueError("unknown bC3BodyType %r; expected one of %s"
                         % (bC3BodyType, ", ".join(_C3_BODY_TYPES)))
    if numConstraints not in (0, 1, 2):
        raise ValueError("numConstraints must be 0, 1 or 2, not %r"
                         % (numConstraints,))

    system = MultibodySystem()
    matter = SimbodyMatterSubsystem(system)
    decorations = DecorationSubsystem(system)
    matter.setShowDefaultGeometry(False)
    # Put some hastily chosen mass there (doesn't help)
    rigidBody = Body.Rigid()
    rigidBody.setDefaultRigidBodyMassProperties(MassProperties(
        mass_t(20.0*daltons),
        location_t(Vec3(0,0,0)*nanometers),
        Inertia(20.0)
        ))

    # One body anchored at the C4 atom, carrying the C4 and C5 atoms.
    c4Body = MobilizedBody.Weld(
        matter.updGround(),
        Rotation(-120*degrees, XAxis),
        rigidBody,
        Transform())
    c5Offset = _atomOffset(-1.0, -1.0, 0.5)
    _addAtomSphere(decorations, c4Body, Transform())   # C4 atom
    _addAtomSphere(decorations, c4Body, c5Offset)      # C5 atom
    _addBond(decorations, c4Body, Vec3(0), c4Body, c5Offset)

    if bC3BodyType == "Pin":
        # Built-in pin version -- works
        c3Body = MobilizedBody.Pin(
            c4Body, Transform(), rigidBody, Transform(_atomOffset(0, 0, 1.5)))
    elif bC3BodyType == "TestPin":
        # Function based pin version -- works
        c3Body = TestPinMobilizer(
            c4Body, Transform(), rigidBody, Transform(_atomOffset(0, 0, 1.5)))
    elif bC3BodyType == "TestPin2":
        # Python subclass of the built-in pin -- works
        c3Body = TestPinMobilizer2(
            c4Body, Transform(), rigidBody, Transform(_atomOffset(0, 0, 1.5)))
    else:  # "RiboseMobilizer"
        c3Body = PseudorotationMobilizer(
            c4Body,
            Transform(),
            rigidBody,
            Transform(_atomOffset(0, 0, 1.5)),
            angle_t(36.4*degrees),   # amplitude
            angle_t(-161.8*degrees)  # phase
            )
        c2Body = PseudorotationMobilizer(
            c3Body,
            Rotation(angle_t(-80*degrees), YAxis),
            rigidBody,
            Transform(_atomOffset(0, 0, 1.5)),
            angle_t(35.8*degrees),   # amplitude
            angle_t(-91.3*degrees)   # phase
            )
        o2Offset = _atomOffset(-1.0, 1.0, -0.5)
        _addAtomSphere(decorations, c2Body, Transform())             # C2 atom
        _addAtomSphere(decorations, c2Body, o2Offset, Vec3(1,0,0))   # O2 atom
        _addBond(decorations, c2Body, Vec3(0), c2Body, o2Offset)
        _addBond(decorations, c3Body, Vec3(0), c2Body, Vec3(0))

        c1Body = PseudorotationMobilizer(
            c2Body,
            Rotation(angle_t(-80*degrees), YAxis),
            rigidBody,
            Transform(_atomOffset(0, 0, 1.5)),
            angle_t(37.6*degrees),   # amplitude
            angle_t(52.8*degrees)    # phase
            )
        n1Offset = _atomOffset(-1.0, -1.0, -0.5)
        o4Offset = _atomOffset(1.0, 0, -0.5)
        _addAtomSphere(decorations, c1Body, Transform())             # C1 atom
        _addAtomSphere(decorations, c1Body, n1Offset, Vec3(0,0,1))   # N1 atom
        _addAtomSphere(decorations, c1Body, o4Offset, Vec3(1,0,0))   # O4 atom
        _addBond(decorations, c2Body, Vec3(0), c1Body, Vec3(0))
        _addBond(decorations, c1Body, Vec3(0), c1Body, o4Offset)
        _addBond(decorations, c1Body, Vec3(0), c1Body, n1Offset)
        _addBond(decorations, c4Body, Vec3(0), c1Body, o4Offset)

        # Two constraint way works one constraint way does not.
        # NOTE(review): the Function and coupler objects stay in locals for
        # the rest of this function, possibly so that their Python wrappers
        # outlive the simulation. Keep them here rather than in a helper.
        if numConstraints == 2:
            # Constraints to make three generalized coordinates identical
            c32bodies = (c3Body.getMobilizedBodyIndex(),
                         c2Body.getMobilizedBodyIndex())
            coordinates = (MobilizerQIndex(0),) * 2
            differenceFunction1 = DifferenceFunction()
            Constraint.CoordinateCoupler(matter, differenceFunction1, c32bodies, coordinates)
            c21bodies = (c2Body.getMobilizedBodyIndex(),
                         c1Body.getMobilizedBodyIndex())
            differenceFunction2 = DifferenceFunction()
            coupler2 = Constraint.CoordinateCoupler(matter, differenceFunction2, c21bodies, coordinates)
        elif numConstraints == 1:  # trying to get single constraint way to work
            # Try one constraint for all three mobilizers
            c123Bodies = (c1Body.getMobilizedBodyIndex(),
                          c2Body.getMobilizedBodyIndex(),
                          c3Body.getMobilizedBodyIndex())
            coords3 = (MobilizerQIndex(0),) * 3
            # Same output as Python 2's ``print a, b``, and valid in Python 3.
            print("%s %s" % (c123Bodies, coords3))
            threeDifferencesFunction = ThreeDifferencesFunction()  # must be created out of constructor, below
            coupler = Constraint.CoordinateCoupler(matter, threeDifferencesFunction, c123Bodies, coords3)

    # The C3 and O3 atoms, plus the C4-C3 bond, exist for every body type.
    o3Offset = _atomOffset(-1.0, 1.0, -0.5)
    _addAtomSphere(decorations, c3Body, Transform())             # C3 atom
    _addAtomSphere(decorations, c3Body, o3Offset, Vec3(1,0,0))   # O3 atom
    _addBond(decorations, c3Body, Vec3(0), c3Body, o3Offset)
    _addBond(decorations, c4Body, Vec3(0), c3Body, Vec3(0))

    # Prescribed motion
    Constraint.ConstantSpeed(c3Body, 0.5)

    system.updDefaultSubsystem().addEventReporter(VTKEventReporter(system, 0.10))
    system.realizeTopology()
    state = system.getDefaultState()
    # Simulate it. (RungeKuttaMersonIntegrator(system) was also tried.)
    integ = VerletIntegrator(system)
    ts = TimeStepper(system, integ)
    print(lineno()); sys.stdout.flush()
    ts.initialize(state)  # dies here
    print(lineno()); sys.stdout.flush()
    ts.stepTo(50.0)

if __name__ == "__main__":
    # Run the ribose mobilizer demo only when this file is executed as a script.
    testRiboseMobilizer()
