# Write a New Dynamics Model
This guide shows how to add a new vehicle dynamics module that plugs into the training stack.
## Steps
- Create the module skeleton under `src/pct/dynamics/<your_model>/__init__.py`. Follow the pattern used by existing models (e.g., `pct.dynamics.singletrack`).
- Define parameter and state dataclasses (Flax `struct.dataclass` is used for JAX compatibility). Include bounds and any non-trainable fields.
- Implement the dynamics function that maps the state and controls to state derivatives. Keep it pure and JIT-friendly.
- Provide a loss function for training and evaluation (e.g., L1 on wrapped heading errors). Export it in `__all__`.
- Register the model spec by importing your module in `src/pct/dynamics/__init__.py`, so the import runs your module's `register_model_spec(...)` call.
- Expose parameter shapes, bounds, and defaults to support network training and optimizer constraints.
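For the loss step, a plain L1 penalty overstates heading errors that straddle ±π: headings of 3.1 and -3.1 rad are only about 0.08 rad apart, not 6.2. The sketch below shows one way to wrap the heading error before taking the L1 sum. It assumes the heading sits at index 2 of the state vector (as in the `("x", "y", "heading")` layout used later on this page); `HEADING_INDEX`, `wrap_angle`, and `wrapped_l1_loss` are hypothetical names, not part of `pct`.

```python
import jax
import jax.numpy as jnp

# Assumption: states are ordered (x, y, heading), so heading is at index 2.
HEADING_INDEX = 2


def wrap_angle(angle: jax.Array) -> jax.Array:
    """Map angles into [-pi, pi) so differences near +/-pi stay small."""
    return (angle + jnp.pi) % (2 * jnp.pi) - jnp.pi


@jax.jit
def wrapped_l1_loss(output_states: jax.Array, target_states: jax.Array) -> jax.Array:
    """Sum of absolute errors, with the heading error wrapped first."""
    err = output_states - target_states
    # Works for a single state vector or a batch: heading is indexed on the last axis.
    err = err.at[..., HEADING_INDEX].set(wrap_angle(err[..., HEADING_INDEX]))
    return jnp.sum(jnp.abs(err))
```

With a heading in the state, this can stand in for the plain L1 `loss_fn` in the template.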
## Minimal template
```python
# src/pct/dynamics/my_model/__init__.py
from flax import struct
import jax
import jax.numpy as jnp
import chex


@struct.dataclass
class Param:
    # Trainable or fixed parameters with defaults
    some_param: float = 1.0
    timestep: float = 0.01


@jax.jit
def dynamics(x_and_u: chex.Array, params: Param) -> chex.Array:
    # Compute state derivatives from the concatenated state/control vector.
    # Use only jnp ops (no Python side effects or value-dependent Python branching)
    # so the function stays pure and JIT-friendly.
    ...
    return jnp.array([...], dtype=x_and_u.dtype)  # placeholder: list your derivative terms


@jax.jit
def loss_fn(output_states: chex.Array, target_states: chex.Array) -> chex.Array:
    # Plain L1; wrap heading differences into [-pi, pi) if your state includes angles
    return jnp.sum(jnp.abs(output_states - target_states))


# Export for mkdocstrings and importers
__all__ = ["Param", "dynamics", "loss_fn"]
```
## Registering the model
Add an import in `src/pct/dynamics/__init__.py` so the module registers its `ModelSpec`. Inside your module, build and register the spec:
```python
# inside my_model/__init__.py, after Param, dynamics, and loss_fn are defined
from pct.dynamics import ModelSpec, register_model_spec

param_keys = ("some_param",)
param_shapes = {"some_param": ()}
param_bounds = [(-10.0, 10.0)]  # (low, high) for some_param
default_params = {"some_param": 1.0}

register_model_spec(
    ModelSpec(
        name="my-model",
        param_keys=param_keys,
        param_shapes=param_shapes,
        param_bounds=param_bounds,
        default_params=default_params,
        state_indices=(0, 1, 2),  # choose states exposed to training loss
        dynamics_module=__name__,
        dynamics_fn=dynamics,
        loss_fn=loss_fn,
        # e.g. timestep_hundredths=5 gives Param(timestep=0.05)
        build_base_params=lambda timestep_hundredths: Param(timestep=timestep_hundredths / 100.0),
        control_dim=2,
        state_names=("x", "y", "heading"),
        control_names=("u0", "u1"),
    )
)
```
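The registered bounds are what optimizer constraints consume. One common way to respect them during network training (an assumed approach here, not necessarily what `pct` does) is to squash unconstrained network outputs into each `(low, high)` interval with a sigmoid; a simpler alternative is to clip parameters after each optimizer step. A minimal sketch for scalar parameters, reusing the `param_keys` and `param_bounds` values from the example above; `to_bounded_params` and `project_to_bounds` are hypothetical helpers:

```python
import jax
import jax.numpy as jnp

# Values mirrored from the registration example (scalar parameters only).
param_keys = ("some_param",)
param_bounds = [(-10.0, 10.0)]


def to_bounded_params(raw: jax.Array) -> dict:
    """Map unconstrained outputs (one per key) into their (low, high) bounds."""
    low = jnp.array([lo for lo, _ in param_bounds])
    high = jnp.array([hi for _, hi in param_bounds])
    bounded = low + (high - low) * jax.nn.sigmoid(raw)
    return {key: bounded[i] for i, key in enumerate(param_keys)}


def project_to_bounds(values: jax.Array) -> jax.Array:
    """Alternative: clip parameters back into bounds after an optimizer step."""
    low = jnp.array([lo for lo, _ in param_bounds])
    high = jnp.array([hi for _, hi in param_bounds])
    return jnp.clip(values, low, high)
```

The sigmoid keeps outputs inside the bounds and smooth everywhere; clipping is simpler but gives zero gradient for any value pushed outside a bound.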
## Testing the integration
- Import via the registry and run a forward dynamics call:

    ```python
    import jax.numpy as jnp
    from pct.dynamics import MODEL_SPECS

    spec = MODEL_SPECS["my-model"]
    params = spec.build_base_params(timestep_hundredths=5)
    # Input length = states + controls: 3 (x, y, heading) + 2 for the example spec
    d_next = spec.dynamics_fn(jnp.zeros(3 + 2), params)
    ```

- Add a quick unit test to ensure shapes, bounds, and loss calls work.
## Documentation
- Add API docs under `docs/reference/api/dynamics/<your_model>.md` using mkdocstrings (mirror the existing dynamics pages).
- Briefly mention the model in the dynamics overview page so it appears in the nav.