
Parameter Tuning

Reusable utilities for tuning single-track vehicle parameters.

SimulationDataset dataclass

Container for batched simulation inputs and their target trajectories.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `initial_states` | `Array` | Starting states of shape `[B, state_dim_with_padding]`. |
| `controls` | `Array` | Control trajectories of shape `[B, T, control_dim]`. |
| `targets` | `Array` | Target trajectories of shape `[B, T, state_dim]`. |

Here `B` is the batch size and `T` is the number of time steps.
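A minimal sketch of how such a container might be declared and filled. NumPy arrays stand in for JAX arrays so the snippet runs without JAX, and the shape checks in `__post_init__` are an illustrative assumption, not documented behavior.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class SimulationDataset:
    """Sketch of the batched container; NumPy arrays stand in for JAX arrays."""

    initial_states: np.ndarray  # [B, state_dim_with_padding]
    controls: np.ndarray        # [B, T, control_dim]
    targets: np.ndarray         # [B, T, state_dim]

    def __post_init__(self) -> None:
        # Hypothetical shape checks (not part of the documented API).
        if self.initial_states.ndim != 2:
            raise ValueError("initial_states must have shape [B, state_dim_with_padding]")
        if self.controls.ndim != 3 or self.targets.ndim != 3:
            raise ValueError("controls and targets must be rank-3 [B, T, dim] arrays")
        batch = self.initial_states.shape[0]
        if self.controls.shape[0] != batch or self.targets.shape[0] != batch:
            raise ValueError("batch dimension B must match across fields")
        if self.controls.shape[1] != self.targets.shape[1]:
            raise ValueError("controls and targets must share the horizon T")


# Example values: 4 trajectories, 50 steps, 2 controls, 5 padded state entries.
dataset = SimulationDataset(
    initial_states=np.zeros((4, 5)),
    controls=np.zeros((4, 50, 2)),
    targets=np.zeros((4, 50, 3)),
)
```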

clamp_vector

clamp_vector(vector, param_names, bounds)

Project a parameter vector into provided box constraints.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector` | `Array` | Parameter vector to clamp. | required |
| `param_names` | `Sequence[str]` | Names describing the order of parameters in `vector`. | required |
| `bounds` | `Mapping[str, Tuple[float, float]]` | Lower and upper bounds keyed by parameter name. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Clamped parameter vector with the same shape as `vector`. |
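The projection amounts to an elementwise clip against per-parameter limits gathered in `param_names` order. The sketch below assumes that, uses `np.clip` in place of `jnp.clip`, and additionally assumes (hypothetically) that names absent from `bounds` stay unconstrained; the parameter names in the usage lines are made up.

```python
from typing import Mapping, Sequence, Tuple

import numpy as np


def clamp_vector(
    vector: np.ndarray,
    param_names: Sequence[str],
    bounds: Mapping[str, Tuple[float, float]],
) -> np.ndarray:
    """Elementwise projection onto the box defined by `bounds`."""
    # Assumption: names missing from `bounds` are left unconstrained.
    unbounded = (-np.inf, np.inf)
    lower = np.array([bounds.get(n, unbounded)[0] for n in param_names], dtype=float)
    upper = np.array([bounds.get(n, unbounded)[1] for n in param_names], dtype=float)
    # np.clip (jnp.clip in JAX) broadcasts the per-entry limits against `vector`.
    return np.clip(vector, lower, upper)


clamped = clamp_vector(
    np.array([5.0, -1.0, 0.5]),
    ["mass_scale", "stiffness_scale", "friction"],
    {"mass_scale": (0.0, 1.0), "stiffness_scale": (0.0, 2.0)},
)
```

Here `friction` has no entry in `bounds`, so under the stated assumption it passes through unchanged while the other two entries are pulled back to their limits.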

create_optimizer

create_optimizer(name, learning_rate, momentum=0.0)

Factory for Optax optimizers used during tuning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | Optimizer name (`"sgd"`, `"adam"`, `"adamw"`, or `"rmsprop"`). | required |
| `learning_rate` | `float` | Learning rate to apply. | required |
| `momentum` | `float` | Momentum factor for SGD- or RMSProp-based optimizers. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `GradientTransformation` | Configured optimizer instance. |
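The factory presumably dispatches on `name` to Optax constructors such as `optax.sgd`, `optax.adam`, `optax.adamw`, and `optax.rmsprop`. So that the sketch runs without Optax, it hand-rolls SGD with momentum and Adam as `(init, update)` pairs shaped like Optax's `GradientTransformation`; `Transform`, `_sgd`, and `_adam` are hypothetical names, and `adamw` and `rmsprop` are left out for brevity.

```python
from typing import Any, Callable, NamedTuple, Tuple

import numpy as np


class Transform(NamedTuple):
    """Hypothetical stand-in for optax.GradientTransformation (init/update pair)."""

    init: Callable[[np.ndarray], Any]
    update: Callable[..., Tuple[np.ndarray, Any]]


def _sgd(learning_rate: float, momentum: float) -> Transform:
    def init(params):
        return np.zeros_like(params)  # velocity buffer

    def update(grads, velocity, params=None):
        velocity = momentum * velocity + grads  # heavy-ball momentum; 0.0 gives plain SGD
        return -learning_rate * velocity, velocity

    return Transform(init, update)


def _adam(learning_rate: float, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8) -> Transform:
    def init(params):
        return np.zeros_like(params), np.zeros_like(params), 0

    def update(grads, state, params=None):
        m, v, t = state
        t += 1
        m = b1 * m + (1 - b1) * grads
        v = b2 * v + (1 - b2) * grads**2
        m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)  # bias correction
        return -learning_rate * m_hat / (np.sqrt(v_hat) + eps), (m, v, t)

    return Transform(init, update)


def create_optimizer(name: str, learning_rate: float, momentum: float = 0.0) -> Transform:
    key = name.lower()
    if key == "sgd":
        return _sgd(learning_rate, momentum)
    if key == "adam":
        return _adam(learning_rate)  # Adam ignores `momentum`
    # "adamw" and "rmsprop" are documented but omitted from this sketch.
    raise ValueError(f"Unsupported optimizer: {name!r}")
```

Calling `update(grads, state)` returns the step to add to the parameters plus the new state, the same contract as Optax's `optimizer.update(grads, opt_state, params)`, whose output Optax applies with `optax.apply_updates`.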

make_loss_and_grad_fn

make_loss_and_grad_fn(
    dataset,
    base_params,
    param_names,
    integrator_fn,
    dynamics_fn,
    loss_fn,
    per_element_loss_fn,
)

Build a loss function (and gradient) for parameter tuning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `SimulationDataset` | Batched training trajectories. | required |
| `base_params` | `object` | Baseline parameter object exposing a `replace` method. | required |
| `param_names` | `Sequence[str]` | Parameter field names to tune. | required |
| `integrator_fn` | `Callable[[Callable, Array, object], Array]` | Integrator used for rollouts. | required |
| `dynamics_fn` | `Callable[[Array, object], Array]` | Dynamics function used for simulation. | required |
| `loss_fn` | `Callable[[Array, Array], Array]` | Loss function over predicted and target trajectories. | required |
| `per_element_loss_fn` | `Callable[[Array, Array], Array]` | Per-element loss for auxiliary logging. | required |

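Returns:

| Type | Description |
| --- | --- |
| `Callable` | JIT-compiled function that maps a parameter vector to `((loss, aux), grad)`, the tuple format `optimize_parameters` expects from `loss_and_grad_fn`. |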
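Because the returned function is described as JIT-compiled, the real implementation likely relies on JAX autodiff (for example `jax.value_and_grad(..., has_aux=True)` wrapped in `jax.jit`). The sketch below swaps in central finite differences so it runs with NumPy alone, and collapses `params_from_vector` plus `simulate_batch` into a single hypothetical `predict_fn`. It shows only the shape of the contract: a parameter vector goes in, `((loss, aux), grad)` comes out.

```python
from typing import Callable

import numpy as np


def make_loss_and_grad_fn_sketch(
    predict_fn: Callable[[np.ndarray], np.ndarray],
    targets: np.ndarray,
    loss_fn: Callable[[np.ndarray, np.ndarray], float],
    per_element_loss_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
    eps: float = 1e-6,
):
    """Return f(vector) -> ((loss, aux), grad), with grad from central differences."""

    def loss_with_aux(vector):
        predictions = predict_fn(vector)  # stands in for params_from_vector + simulate_batch
        return loss_fn(predictions, targets), per_element_loss_fn(predictions, targets)

    def loss_and_grad(vector):
        vector = np.asarray(vector, dtype=float)
        loss, aux = loss_with_aux(vector)
        grad = np.zeros_like(vector)
        for i in range(vector.size):
            step = np.zeros_like(vector)
            step[i] = eps
            loss_plus, _ = loss_with_aux(vector + step)
            loss_minus, _ = loss_with_aux(vector - step)
            grad[i] = (loss_plus - loss_minus) / (2.0 * eps)
        return (loss, aux), grad

    return loss_and_grad


# Toy problem: predictions scale a fixed signal; targets are twice that signal.
signal = np.array([1.0, 2.0, 3.0])
fn = make_loss_and_grad_fn_sketch(
    predict_fn=lambda v: v[0] * signal,
    targets=2.0 * signal,
    loss_fn=lambda p, t: float(np.mean((p - t) ** 2)),
    per_element_loss_fn=lambda p, t: (p - t) ** 2,
)
(loss, aux), grad = fn(np.array([1.0]))
```

Finite differences cost two extra rollouts per tuned parameter, which is one reason automatic differentiation is the natural choice when the dynamics are written in JAX.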

optimize_parameters

optimize_parameters(
    initial_vector,
    loss_and_grad_fn,
    steps,
    optimizer,
    param_names,
    bounds,
    initial_eval=None,
    progress_fn=None,
)

Optimize parameters while recording diagnostic history.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `initial_vector` | `Array` | Initial parameter vector. | required |
| `loss_and_grad_fn` | `Callable` | Function returning `((loss, aux), grad)` for a parameter vector. | required |
| `steps` | `int` | Number of optimization iterations. | required |
| `optimizer` | `GradientTransformation` | Optax optimizer configured for the problem. | required |
| `param_names` | `Sequence[str]` | Ordered names corresponding to entries of the parameter vector. | required |
| `bounds` | `Mapping[str, Tuple[float, float]]` | Lower and upper limits for each parameter. | required |
| `initial_eval` | `tuple` | Optional precomputed `((loss, aux), grad)` tuple. | `None` |
| `progress_fn` | `Callable[[int, float, ndarray], None]` | Callback invoked at each step with the iteration index, loss, and gradient. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Array, Dict[str, ndarray]]` | Best-found parameter vector, and a history dictionary containing the loss, per-element loss, gradients, and best metrics. |
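One plausible shape for the loop: evaluate, record history, keep the best evaluated iterate, apply the optimizer update, then project back into the bounds. This is a sketch under assumptions: `GradientDescent` is a hypothetical Optax-like transform, the history keys are illustrative, and the real function may order these steps differently.

```python
from typing import NamedTuple

import numpy as np


class GradientDescent(NamedTuple):
    """Hypothetical Optax-like transform: plain gradient descent."""

    learning_rate: float

    def init(self, params):
        return None  # stateless

    def update(self, grads, state, params=None):
        return -self.learning_rate * grads, state


def optimize_parameters(initial_vector, loss_and_grad_fn, steps, optimizer,
                        param_names, bounds, initial_eval=None, progress_fn=None):
    lower = np.array([bounds[n][0] for n in param_names], dtype=float)
    upper = np.array([bounds[n][1] for n in param_names], dtype=float)
    vector = np.asarray(initial_vector, dtype=float)
    opt_state = optimizer.init(vector)
    history = {"loss": [], "per_element_loss": [], "grad": []}  # hypothetical keys
    best_vector, best_loss, best_step = vector, np.inf, -1

    for step in range(steps):
        # Reuse the precomputed evaluation for the first iterate, if given.
        if step == 0 and initial_eval is not None:
            (loss, aux), grad = initial_eval
        else:
            (loss, aux), grad = loss_and_grad_fn(vector)
        loss, grad = float(loss), np.asarray(grad, dtype=float)
        history["loss"].append(loss)
        history["per_element_loss"].append(np.asarray(aux))
        history["grad"].append(grad)
        if loss < best_loss:  # track the best evaluated iterate, not the last one
            best_vector, best_loss, best_step = vector, loss, step
        if progress_fn is not None:
            progress_fn(step, loss, grad)
        updates, opt_state = optimizer.update(grad, opt_state, vector)
        vector = np.clip(vector + updates, lower, upper)  # projected gradient step

    result = {key: np.asarray(values) for key, values in history.items()}
    result["best_loss"] = np.asarray(best_loss)
    result["best_step"] = np.asarray(best_step)
    return best_vector, result


def quadratic_loss_and_grad(v):
    residual = v - 3.0  # unconstrained minimum at 3.0, outside the bounds below
    return (float(np.sum(residual**2)), residual**2), 2.0 * residual


best, history = optimize_parameters(
    np.array([0.0]), quadratic_loss_and_grad, steps=3,
    optimizer=GradientDescent(0.25), param_names=["a"], bounds={"a": (0.0, 2.0)},
)
```

With this ordering the last update is never evaluated, so the returned vector is the best among the `steps` evaluated iterates; in the toy run the bound at 2.0 stops the iterate short of the unconstrained minimum.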

params_from_vector

params_from_vector(base_params, vector, param_names)

Create a new Param object from a dense vector of tunable parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `base_params` | `object` | Baseline parameter dataclass providing the `replace` method. | required |
| `vector` | `Array` | Dense vector containing parameter values. | required |
| `param_names` | `Sequence[str]` | Names of the fields corresponding to entries in `vector`. | required |

Returns:

| Type | Description |
| --- | --- |
| `object` | A new parameter dataclass with updated values. |
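A minimal sketch, assuming a frozen dataclass whose `replace` method wraps `dataclasses.replace` (Flax's `struct.dataclass` provides a similar method). `VehicleParams` and its fields are hypothetical single-track parameters chosen for illustration.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class VehicleParams:
    """Hypothetical single-track parameter set with a `replace` method."""

    mass: float = 1500.0         # kg
    yaw_inertia: float = 2500.0  # kg m^2
    c_front: float = 8.0e4       # front cornering stiffness, N/rad
    c_rear: float = 9.0e4        # rear cornering stiffness, N/rad

    def replace(self, **changes):
        return replace(self, **changes)  # module-level dataclasses.replace


def params_from_vector(base_params, vector, param_names):
    # Index rather than call float() so the same code works on traced JAX values.
    updates = {name: vector[i] for i, name in enumerate(param_names)}
    return base_params.replace(**updates)  # untouched fields keep their baseline values


tuned = params_from_vector(VehicleParams(), np.array([1620.0, 7.5e4]), ["mass", "c_front"])
```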

simulate_batch

simulate_batch(params, dataset, integrator_fn, dynamics_fn)

Roll out trajectories for a batch of initial states and controls.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `object` | Dynamics parameter object passed to `integrator_fn` and `dynamics_fn`. | required |
| `dataset` | `SimulationDataset` | Batched trajectories of initial states, controls, and targets. | required |
| `integrator_fn` | `Callable[[Callable, Array, object], Array]` | Integrator that advances the dynamics. | required |
| `dynamics_fn` | `Callable[[Array, object], Array]` | Dynamics function mapping a state-control vector and parameters to its time derivative. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Simulated trajectories with shape `[B, T, state_dim]`. |
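The rollout can be sketched as a loop over time steps. Everything about the state layout is an assumption here: the sketch treats the trailing `control_dim` entries of the padded state as control slots that are overwritten at each step, and slices the first `state_dim` entries into the output. `euler_step` and `double_integrator` are hypothetical stand-ins for `integrator_fn` and `dynamics_fn`. In JAX one would more likely express the loop with `jax.lax.scan` and replace the in-place write with `state.at[...].set(...)`.

```python
from types import SimpleNamespace

import numpy as np


def euler_step(dynamics_fn, state, params, dt=0.1):
    """Hypothetical integrator_fn: one explicit-Euler step."""
    return state + dt * dynamics_fn(state, params)


def double_integrator(state, params):
    """Hypothetical dynamics on a batch: columns are [position, velocity, accel_command]."""
    velocity, accel = state[:, 1], state[:, 2]
    return np.stack(
        [velocity, accel - params["drag"] * velocity, np.zeros_like(accel)], axis=-1
    )


def simulate_batch(params, dataset, integrator_fn, dynamics_fn):
    controls = np.asarray(dataset.controls, dtype=float)  # [B, T, control_dim]
    control_dim = controls.shape[-1]
    state_dim = dataset.targets.shape[-1]
    state = np.array(dataset.initial_states, dtype=float)  # copy of [B, padded_dim]
    outputs = []
    for t in range(controls.shape[1]):
        state[:, -control_dim:] = controls[:, t]  # assumed layout: controls in trailing slots
        state = integrator_fn(dynamics_fn, state, params)
        outputs.append(state[:, :state_dim].copy())
    return np.stack(outputs, axis=1)  # [B, T, state_dim]


dataset = SimpleNamespace(
    initial_states=np.zeros((1, 3)),  # state_dim = 2 plus one control slot
    controls=np.ones((1, 3, 1)),      # constant unit acceleration command
    targets=np.zeros((1, 3, 2)),
)
trajectory = simulate_batch({"drag": 0.0}, dataset, euler_step, double_integrator)
```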

vector_from_params

vector_from_params(param_obj, param_names)

Pack selected fields from a Param dataclass into a dense vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `param_obj` | `object` | A dataclass instance containing tunable parameters. | required |
| `param_names` | `Sequence[str]` | Names of the fields to extract. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Vector of selected parameters ordered according to `param_names`. |
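Packing is a lookup of each named field in `param_names` order. The sketch reuses a hypothetical `VehicleParams` dataclass; together with `params_from_vector`, it should form a round trip for the selected fields.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class VehicleParams:
    """Hypothetical single-track parameter set."""

    mass: float = 1500.0
    yaw_inertia: float = 2500.0
    c_front: float = 8.0e4
    c_rear: float = 9.0e4


def vector_from_params(param_obj, param_names):
    # Order follows param_names, not the dataclass field order.
    return np.array([getattr(param_obj, name) for name in param_names], dtype=float)


vec = vector_from_params(VehicleParams(), ["c_rear", "mass"])
```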