JAX Backend#

The JAX backend excels at automatic differentiation and batching. Use it for gradient-based optimization and research prototypes.

Key Features#

  • JIT Compilation – The first call compiles the function; subsequent calls reuse the cached compilation

  • Automatic Differentiation – Compute gradients, Jacobians, Hessians

  • Native Batching – Process batches of configurations

  • GPU Support – Runs on GPU with a GPU-enabled JAX installation (see the device check after this list)

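To check which devices your installation will actually use, a quick sketch (the output depends on your setup; a CPU-only install lists only CPU devices):

import jax

# Lists the devices JAX dispatches to; a GPU-enabled build shows the available GPUs
print(jax.devices())
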
Basic Usage#

import numpy as np
import jax.numpy as jnp
import adam
from adam.jax import KinDynComputations
from jax import jit, grad
import icub_models

# Load model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints_name_list = [
    'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
    'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
    'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
    'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
    'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]

kinDyn = KinDynComputations(model_path, joints_name_list)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)

# Define state
w_H_b = jnp.eye(4)
joints = jnp.ones(len(joints_name_list)) * 0.1

# Compute (slow, no compilation)
M = kinDyn.mass_matrix(w_H_b, joints)  # runs eagerly, without JIT compilation
print(f"Mass matrix shape: {M.shape}")

JIT Compilation#

Wrap your functions with @jit for speed:

from jax import jit

@jit
def compute(w_H_b, joints):
    M = kinDyn.mass_matrix(w_H_b, joints)
    J = kinDyn.jacobian('l_sole', w_H_b, joints)
    return M, J

# First call: slow (compilation)
M, J = compute(w_H_b, joints)

# Subsequent calls: fast (cached)
M, J = compute(w_H_b, joints)
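
A minimal timing sketch to see the difference (wall-clock numbers depend on your machine; block_until_ready() is needed because JAX dispatches work asynchronously):

import time

start = time.perf_counter()
M, J = compute(w_H_b, joints)
M.block_until_ready()  # wait for the result; this call includes compilation
print(f"First call:  {time.perf_counter() - start:.4f} s")

start = time.perf_counter()
M, J = compute(w_H_b, joints)
M.block_until_ready()  # the compiled function is reused from the cache
print(f"Cached call: {time.perf_counter() - start:.4f} s")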

Warning

Frame names must remain as strings (not traced). Wrap them in a closure:

# ✅ Correct
def make_fk_fn(frame_name):
    @jit
    def fk(w_H_b, joints):
        return kinDyn.forward_kinematics(frame_name, w_H_b, joints)
    return fk

fk_l_sole = make_fk_fn('l_sole')
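
The returned closure is then called like any other jitted function. An equivalent pattern (a sketch using standard JAX machinery, not adam-specific) is to mark the frame name as a static argument:

from functools import partial

# The frame name is a Python string, so it must be static rather than traced;
# each distinct frame name triggers its own compilation.
@partial(jit, static_argnums=0)
def fk(frame_name, w_H_b, joints):
    return kinDyn.forward_kinematics(frame_name, w_H_b, joints)

H_l_sole = fk('l_sole', w_H_b, joints)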

Automatic Differentiation#

Compute gradients easily:

from jax import grad

# Gradient of mass matrix trace w.r.t. joint positions
def mass_matrix_trace(w_H_b, joints):
    M = kinDyn.mass_matrix(w_H_b, joints)
    return jnp.trace(M)

grad_fn = grad(mass_matrix_trace, argnums=1)  # Gradient w.r.t. joints
grad_joints = grad_fn(w_H_b, joints)
print(f"Gradient shape: {grad_joints.shape}")

Higher-order derivatives:

from jax import grad, hessian

hess_fn = hessian(mass_matrix_trace, argnums=1)
hess_joints = hess_fn(w_H_b, joints)
print(f"Hessian shape: {hess_joints.shape}")

Native Batching#

Batched inputs are handled natively; just pass arrays with a leading batch dimension:

# Batch size 1024
batch_size = 1024
w_H_b_batch = jnp.tile(jnp.eye(4), (batch_size, 1, 1))  # Shape: (1024, 4, 4)
joints_batch = jnp.tile(joints, (batch_size, 1))  # Shape: (1024, n_dof)

# Just pass batched tensors - JAX handles batching automatically
M_batch = kinDyn.mass_matrix(w_H_b_batch, joints_batch)  # Shape: (1024, 6+n, 6+n)
J_batch = kinDyn.jacobian('l_sole', w_H_b_batch, joints_batch)  # Shape: (1024, 6, 6+n)
print(f"Mass matrix shape: {M_batch.shape}")

Combine JIT with Native Batching:

# JIT for maximum speed
@jit
def jit_batched_compute(w_H_b_batch, joints_batch):
    M = kinDyn.mass_matrix(w_H_b_batch, joints_batch)
    J = kinDyn.jacobian('l_sole', w_H_b_batch, joints_batch)
    return M, J

# First call compiles, subsequent calls are fast
M_batch, J_batch = jit_batched_compute(w_H_b_batch, joints_batch)
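
As an alternative to passing batched tensors, you can map a per-sample function over the batch axis with jax.vmap (a sketch that reuses the compute function from the JIT section and assumes the adam computations trace cleanly under vmap):

from jax import vmap

# Map the single-sample function over the leading batch axis of both inputs
batched_compute = vmap(compute, in_axes=(0, 0))

M_batch, J_batch = batched_compute(w_H_b_batch, joints_batch)
print(f"Mass matrix batch shape: {M_batch.shape}")  # (1024, 6+n, 6+n)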

Optimization Example#

Use gradients for optimization:

import optax  # pip install optax
from jax import jit, value_and_grad

def objective(joints):
    """Minimize mass matrix trace"""
    M = kinDyn.mass_matrix(w_H_b, joints)
    return jnp.trace(M)

# Setup optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(joints)

# JIT the step
@jit
def step(joints, opt_state):
    loss, grads = value_and_grad(objective)(joints)
    updates, opt_state = optimizer.update(grads, opt_state)
    joints = optax.apply_updates(joints, updates)
    return joints, opt_state, loss

# Optimize
for i in range(100):
    joints, opt_state, loss = step(joints, opt_state)
    if i % 10 == 0:
        print(f"Step {i}, Loss: {loss:.6f}")

Tips and Tricks#

Enable 64-bit precision (recommended for numerical robustness; set it at startup, before creating any JAX arrays):

from jax import config
config.update("jax_enable_x64", True)

Disable JIT temporarily for debugging:

from jax import config
config.update("jax_disable_jit", True)

Loading from MuJoCo#

Load models from MuJoCo and leverage JAX’s JIT and autodiff:

import mujoco
from robot_descriptions.loaders.mujoco import load_robot_description
import adam
from adam.jax import KinDynComputations
import jax.numpy as jnp
from jax import jit, grad

# Load MuJoCo model
mj_model = load_robot_description("g1_mj_description")

# Create KinDynComputations from MuJoCo model
kinDyn = KinDynComputations.from_mujoco_model(mj_model)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)

# Use with JIT and autodiff
@jit
def compute_mass_trace(w_H_b, joints):
    M = kinDyn.mass_matrix(w_H_b, joints)
    return jnp.trace(M)

w_H_b = jnp.eye(4)
joints = jnp.zeros(kinDyn.NDoF)

trace_val = compute_mass_trace(w_H_b, joints)
grad_fn = grad(compute_mass_trace, argnums=1)
grad_val = grad_fn(w_H_b, joints)

See MuJoCo Integration for more details.

When to Use JAX#

Good for:

  • Gradient-based optimization

  • Computing Jacobians and Hessians

  • Processing batches with native batching

  • GPU acceleration

Not ideal for:

  • One-off computations (NumPy is faster)

  • Symbolic manipulation (use CasADi)

See also: Choosing a Backend