
Write a New Dynamics Model

This guide shows how to add a new vehicle dynamics module that plugs into the training stack.

Steps

  1. Create the module skeleton under src/pct/dynamics/<your_model>/__init__.py. Follow the pattern used by existing models (e.g., pct.dynamics.singletrack).
  2. Define parameter and state dataclasses. Use Flax's struct.dataclass so instances are JAX pytrees and can pass through jax.jit and jax.grad. Include bounds and any non-trainable fields.
  3. Implement the dynamics function that maps the concatenated state and control vector to state derivatives. Keep it pure and JIT-friendly.
  4. Provide a loss function for training and evaluation (e.g., an L1 loss with heading errors wrapped to [-π, π]). Export it in __all__.
  5. Register the model spec by importing your module in src/pct/dynamics/__init__.py, so that the module-level register_model_spec(...) call runs on import.
  6. Expose parameter shapes/bounds/defaults to support network training and optimizer constraints.

Minimal template

```python
# src/pct/dynamics/my_model/__init__.py
from flax import struct
import jax
import jax.numpy as jnp
import chex


@struct.dataclass
class Param:
    # Trainable or fixed parameters, with defaults
    some_param: float = 1.0
    timestep: float = 0.01


@jax.jit
def dynamics(x_and_u: chex.Array, params: Param) -> chex.Array:
    # x_and_u is the state vector concatenated with the controls.
    # Compute the state derivatives here; keep every op pure and JIT-friendly.
    ...
    return jnp.array([...], dtype=x_and_u.dtype)  # placeholder: one entry per state


@jax.jit
def loss_fn(output_states: chex.Array, target_states: chex.Array) -> chex.Array:
    # Plain L1 loss; wrap angular states (e.g. heading) before taking abs if needed.
    return jnp.sum(jnp.abs(output_states - target_states))


# Export for mkdocstrings and importers
__all__ = ["Param", "dynamics", "loss_fn"]
```

Registering the model

Importing your module from src/pct/dynamics/__init__.py is what triggers registration, so add that import there. Inside your module, after dynamics and loss_fn are defined, build and register the spec:

```python
# inside my_model/__init__.py, after dynamics and loss_fn are defined
from pct.dynamics import ModelSpec, register_model_spec

param_keys = ("some_param",)
param_shapes = {"some_param": ()}  # () means a scalar
param_bounds = [(-10.0, 10.0)]  # (low, high) bounds for some_param
default_params = {"some_param": 1.0}

register_model_spec(
    ModelSpec(
        name="my-model",
        param_keys=param_keys,
        param_shapes=param_shapes,
        param_bounds=param_bounds,
        default_params=default_params,
        state_indices=(0, 1, 2),  # indices of the states compared by the training loss
        dynamics_module=__name__,
        dynamics_fn=dynamics,
        loss_fn=loss_fn,
        # Takes the timestep in hundredths, matching build_base_params(timestep_hundredths=...)
        build_base_params=lambda timestep_hundredths: Param(timestep=timestep_hundredths / 100.0),
        control_dim=2,
        state_names=("x", "y", "heading"),
        control_names=("u0", "u1"),
    )
)
```
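The spec spreads each parameter's metadata across four containers (`param_keys`, `param_shapes`, `param_bounds`, `default_params`), so they can drift apart as a model grows. A hypothetical helper like the one below can catch that before registration. It assumes, as the example above suggests, that `param_bounds` is ordered like `param_keys` with one `(low, high)` pair per key; that is our reading, so check how pct actually consumes the bounds (for instance, for non-scalar parameters).

```python
import numpy as np


def check_param_metadata(param_keys, param_shapes, param_bounds, default_params):
    """Hypothetical helper: return a list of problems (empty means consistent)."""
    problems = []
    # Assumption: one (low, high) pair per key, in the same order as param_keys.
    if len(param_bounds) != len(param_keys):
        problems.append(f"{len(param_keys)} keys but {len(param_bounds)} bound pairs")
    for i, key in enumerate(param_keys):
        if key not in param_shapes:
            problems.append(f"{key!r} missing from param_shapes")
        if key not in default_params:
            problems.append(f"{key!r} missing from default_params")
            continue
        default = np.asarray(default_params[key])
        # The default value must have the declared shape.
        if key in param_shapes and default.shape != tuple(param_shapes[key]):
            problems.append(
                f"{key!r}: default shape {default.shape} != declared {param_shapes[key]}"
            )
        # The default value must lie inside its bounds.
        if i < len(param_bounds):
            lo, hi = param_bounds[i]
            if lo > hi:
                problems.append(f"{key!r}: lower bound {lo} > upper bound {hi}")
            elif np.any(default < lo) or np.any(default > hi):
                problems.append(f"{key!r}: default {default} outside [{lo}, {hi}]")
    return problems
```

Called with the same values passed to `ModelSpec`, e.g. `assert not check_param_metadata(param_keys, param_shapes, param_bounds, default_params)`, it returns an empty list for the example above.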

Testing the integration

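  • Import via the registry and run a forward dynamics call:

```python
import jax.numpy as jnp
from pct.dynamics import MODEL_SPECS

spec = MODEL_SPECS["my-model"]
params = spec.build_base_params(timestep_hundredths=5)  # timestep = 5 / 100 = 0.05
# The input is the state vector concatenated with the controls; size it for your model.
dx = spec.dynamics_fn(jnp.zeros(6), params)  # state derivatives, not the next state
```
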
  • Add a quick unit test to ensure shapes, bounds, and loss calls work.
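For that unit test, a generic smoke check can be pointed at any registered model. `smoke_test_model` is a hypothetical helper (not part of pct), and the two-state `toy_dynamics` model exists only so the sketch runs on its own. It checks the output shape, that the loss is a finite scalar, and that `jax.grad` can differentiate the loss with respect to every parameter leaf; bounds are not checked here.

```python
import jax
import jax.numpy as jnp


def smoke_test_model(dynamics_fn, loss_fn, params, input_dim, state_dim, state_indices):
    """Hypothetical helper: check shapes, finiteness, and differentiability."""
    x_and_u = jnp.full((input_dim,), 0.1)
    idx = jnp.asarray(state_indices)

    # 1. Forward call returns one derivative per state, all finite.
    out = dynamics_fn(x_and_u, params)
    assert out.shape == (state_dim,), out.shape
    assert bool(jnp.all(jnp.isfinite(out)))

    # 2. Loss on the selected entries (a batch of one) is a finite scalar.
    #    Derivatives are fed in purely as a callability check.
    def loss_of(p):
        pred = dynamics_fn(x_and_u, p)[idx][None, :]
        return loss_fn(pred, jnp.zeros_like(pred))

    loss = loss_of(params)
    assert jnp.shape(loss) == (), jnp.shape(loss)
    assert bool(jnp.isfinite(loss))

    # 3. Gradients w.r.t. every parameter leaf exist and are finite.
    grads = jax.grad(loss_of)(params)
    for leaf in jax.tree_util.tree_leaves(grads):
        assert bool(jnp.all(jnp.isfinite(leaf)))
    return float(loss)


# Toy model for demonstration: 2 states, 1 control, d[s0, s1]/dt = [s1, gain * u].
def toy_dynamics(x_and_u, params):
    return jnp.stack([x_and_u[1], params["gain"] * x_and_u[2]])


def toy_loss(pred, target):
    return jnp.sum(jnp.abs(pred - target))


smoke_test_model(toy_dynamics, toy_loss, {"gain": 2.0},
                 input_dim=3, state_dim=2, state_indices=(0, 1))
```

To exercise a real model, pass `spec.dynamics_fn`, `spec.loss_fn`, the params returned by `spec.build_base_params`, your model's input length, the number of derivatives it returns, and the same `state_indices` you gave `ModelSpec`.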

Documentation

  • Add API docs under docs/reference/api/dynamics/<your_model>.md using mkdocstrings (mirror the existing dynamics pages).
  • Briefly mention the model in the dynamics overview page so it appears in the nav.