
Network Training

Utilities for training neural networks that predict simulation parameters from per-sample features.

FeatureDataset (dataclass)

Dataset with per-sample features alongside simulation data.

Attributes:

features (Array): Feature matrix of shape [B, F].
controls (Array): Control trajectories of shape [B, T, 2].
initial_states (Array): Initial state vectors of shape [B, 9].
targets (Array): Target trajectories of shape [B, T, 7].

Here B is the batch size, F the number of features, and T the number of time steps.
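As a minimal sketch of how the documented shapes fit together, the stand-in below mirrors FeatureDataset's four fields and checks that they share the batch size B and horizon T. FeatureDatasetSketch is a hypothetical name, not the library's class, and the validation is an assumption; the real dataclass may not check shapes at all.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class FeatureDatasetSketch:
    """Hypothetical stand-in mirroring FeatureDataset's documented fields."""

    features: jnp.ndarray        # [B, F]
    controls: jnp.ndarray        # [B, T, 2]
    initial_states: jnp.ndarray  # [B, 9]
    targets: jnp.ndarray         # [B, T, 7]

    def __post_init__(self):
        # Assumed validation: every field must agree on B (and T where present).
        if self.features.ndim != 2:
            raise ValueError("features must have shape [B, F]")
        b = self.features.shape[0]
        if self.controls.ndim != 3 or self.controls.shape[0] != b or self.controls.shape[2] != 2:
            raise ValueError("controls must have shape [B, T, 2]")
        t = self.controls.shape[1]
        if self.initial_states.shape != (b, 9):
            raise ValueError("initial_states must have shape [B, 9]")
        if self.targets.shape != (b, t, 7):
            raise ValueError("targets must have shape [B, T, 7]")


# A batch of 4 samples, 3 features each, over 10 time steps.
B, F, T = 4, 3, 10
ds = FeatureDatasetSketch(
    features=jnp.zeros((B, F)),
    controls=jnp.zeros((B, T, 2)),
    initial_states=jnp.zeros((B, 9)),
    targets=jnp.zeros((B, T, 7)),
)
```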

make_network_step_fn

make_network_step_fn(
    model_apply,
    loss_fn,
    param_names,
    base_params,
    dynamics_fn,
    integrator_fn,
    optimizer,
    *,
    param_shapes=None,
)

Create a JIT-compiled training step for networks that predict simulation parameters from features.

Parameters:

model_apply (Callable[[Array, Array], Array], required): Model apply function that returns parameter vectors.
loss_fn (Callable[[Array, Array], Array], required): Loss function comparing simulated trajectories with targets.
param_names (Sequence[str], required): Names of the entries in each predicted parameter vector.
base_params (object, required): Baseline parameter object whose named entries are replaced with the predicted values.
dynamics_fn (Callable[[Array, object], Array], required): Dynamics function to integrate.
integrator_fn (Callable[[Callable[[Array, object], Array], Array, object], Array], required): Integrator applied to the dynamics function.
optimizer (GradientTransformation, required): Optimizer for the network parameters.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters, keyed by name.

Returns:

Callable: JIT-compiled step function that updates the model parameters.
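The work such a step performs can be sketched as: predict a parameter vector per sample, simulate, score against the targets, and update the network. The version below is a hypothetical, self-contained illustration, not the library's implementation: make_step_sketch, the toy linear model, the scalar Euler rollout, and the plain SGD update (standing in for the optax GradientTransformation) are all assumptions, and dynamics_fn and integrator_fn are folded into a single simulate callable.

```python
import jax
import jax.numpy as jnp


def make_step_sketch(model_apply, loss_fn, simulate, learning_rate=0.1):
    """Hypothetical sketch of a JIT-compiled parameter-network training step."""

    def total_loss(model_params, features, x0, controls, targets):
        param_vectors = model_apply(model_params, features)  # [B, P]
        sims = simulate(param_vectors, x0, controls)         # [B, T]
        return loss_fn(sims, targets)

    @jax.jit
    def step(model_params, features, x0, controls, targets):
        loss, grads = jax.value_and_grad(total_loss)(
            model_params, features, x0, controls, targets
        )
        # Plain SGD stands in for the optax GradientTransformation.
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, model_params, grads
        )
        return new_params, loss  # loss is measured before the update

    return step


def model_apply(params, features):
    # Toy linear layer; softplus keeps the predicted decay rate positive (P = 1).
    return jax.nn.softplus(features @ params["w"] + params["b"])


def rollout(k, x0, u, dt=0.1):
    # Toy scalar dynamics dx/dt = -k * x + u, integrated with forward Euler.
    def body(x, u_t):
        x_next = x + dt * (-k * x + u_t)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, u)
    return xs  # [T]


def simulate(param_vectors, x0, controls):
    # One rollout per sample, each with its own predicted parameter.
    return jax.vmap(rollout)(param_vectors[:, 0], x0, controls)


def loss_fn(sims, targets):
    return jnp.mean((sims - targets) ** 2)


# Synthetic data: targets come from a "true" feature-dependent decay rate.
features = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
x0 = jnp.ones(8)
controls = jnp.zeros((8, 20))
true_k = jax.nn.softplus(features @ jnp.array([1.0, -0.5, 0.25]) + 0.5)
targets = simulate(true_k[:, None], x0, controls)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
step = make_step_sketch(model_apply, loss_fn, simulate)
losses = []
for _ in range(100):
    params, loss = step(params, features, x0, controls, targets)
    losses.append(float(loss))
```

Putting the whole predict, simulate, and loss chain inside one jitted function lets JAX differentiate through the integrator and compile the update as a single program.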

simulate_with_param_vectors

simulate_with_param_vectors(
    param_vectors,
    dataset,
    base_params,
    dynamics_fn,
    integrator_fn,
    param_names,
    param_shapes=None,
)

Simulate each trajectory with its own parameter vector.

Parameters:

param_vectors (Array, required): Parameter vectors, one per sample, of shape [B, P].
dataset (FeatureDataset, required): Batched features and trajectories.
base_params (object, required): Baseline parameter object that supports replace; the predicted entries overwrite its fields.
dynamics_fn (Callable[[Array, object], Array], required): Dynamics function to integrate.
integrator_fn (Callable[[Callable[[Array, object], Array], Array, object], Array], required): Integrator for the dynamics function.
param_names (Sequence[str], required): Names corresponding to the entries in param_vectors.
param_shapes (Mapping[str, Sequence[int]], default None): Optional target shapes for each parameter, keyed by name.

Returns:

Array: Simulated trajectories of shape [B, T, state_dim].
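One plausible way to give each trajectory its own parameters is to jax.vmap over the batch, split each flat vector into named entries, and build a per-sample parameter object with dataclasses.replace. This sketch assumes param_names lists one name per parameter in vector order, with param_shapes supplying the shape of non-scalar entries (scalars otherwise); ToyParams, unpack, rollout, and simulate_sketch are hypothetical, and the toy Euler rollout is not the library's integrator contract.

```python
import dataclasses
import math

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyParams:
    # Hypothetical parameter object; field names are illustrative only.
    damping: jnp.ndarray
    gain: jnp.ndarray  # structured parameter of shape (2,)
    dt: float = 0.1


def unpack(vector, param_names, param_shapes=None):
    # Split a flat [P] vector into named entries. Assumed layout: names in
    # order, each entry sized by its shape in param_shapes, scalar by default.
    param_shapes = param_shapes or {}
    out, offset = {}, 0
    for name in param_names:
        shape = tuple(param_shapes.get(name, ()))
        size = math.prod(shape)
        out[name] = vector[offset:offset + size].reshape(shape)
        offset += size
    return out


def rollout(params, x0, controls):
    # Toy 2-D dynamics dx/dt = -damping * x + gain * u, forward Euler.
    def body(x, u_t):
        x_next = x + params.dt * (-params.damping * x + params.gain * u_t)
        return x_next, x_next

    _, xs = jax.lax.scan(body, x0, controls)
    return xs  # [T, 2]


def simulate_sketch(param_vectors, initial_states, controls, base_params,
                    param_names, param_shapes=None):
    def one_sample(vector, x0, u):
        fields = unpack(vector, param_names, param_shapes)
        params = dataclasses.replace(base_params, **fields)
        return rollout(params, x0, u)

    # Map over the batch axis of vectors, initial states, and controls.
    return jax.vmap(one_sample)(param_vectors, initial_states, controls)


base = ToyParams(damping=jnp.array(1.0), gain=jnp.ones(2))
vectors = jnp.array([[0.5, 1.0, 1.0],
                     [2.0, 0.0, 0.0]])  # [B=2, P=3]: damping, then gain (2,)
sims = simulate_sketch(vectors, jnp.ones((2, 2)), jnp.zeros((2, 5, 2)),
                       base, ["damping", "gain"], {"gain": (2,)})
```

With zero controls each state decays by a factor of (1 - dt * damping) per step, so the two samples diverge purely because of their different predicted damping values.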

summarize_predictions

summarize_predictions(
    model_apply,
    model_params,
    features,
    param_names,
    *,
    param_shapes=None,
)

Summarize model predictions by computing the mean parameter vector.

Parameters:

model_apply (Callable[[Array, Array], Array], required): Model apply function.
model_params (object, required): Trained model parameters.
features (Array, required): Feature matrix used for inference.
param_names (Sequence[str], required): Names corresponding to the entries of the predicted vector.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters.

Returns:

Dict[str, object]: Mean predicted parameters keyed by name, as NumPy scalars or arrays shaped according to param_shapes.
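A minimal sketch of the summary step, assuming model_apply takes (params, features) and that vector entries are laid out in param_names order and sized by param_shapes: average the predictions over the batch, then slice and reshape each named entry into NumPy values. summarize_sketch and the toy linear model_apply are hypothetical.

```python
import math

import jax.numpy as jnp
import numpy as np


def summarize_sketch(model_apply, model_params, features, param_names, param_shapes=None):
    """Hypothetical sketch: average predicted vectors over the batch, then name them."""
    param_shapes = param_shapes or {}
    # Mean over the batch axis, converted to a host-side NumPy array.
    mean_vector = np.asarray(jnp.mean(model_apply(model_params, features), axis=0))
    summary, offset = {}, 0
    for name in param_names:
        shape = tuple(param_shapes.get(name, ()))
        size = math.prod(shape)
        chunk = mean_vector[offset:offset + size]
        # Unshaped entries become NumPy scalars; shaped entries stay arrays.
        summary[name] = chunk[0] if shape == () else chunk.reshape(shape)
        offset += size
    return summary


def model_apply(params, features):
    # Toy linear model: [B, F] @ [F, P] -> [B, P].
    return features @ params


features = jnp.array([[1.0, 0.0], [0.0, 1.0]])
model_params = jnp.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])  # [F=2, P=3]
summary = summarize_sketch(model_apply, model_params, features,
                           ["damping", "gain"], param_shapes={"gain": (2,)})
```

Here the two predicted rows are [1, 2, 3] and [3, 4, 5], so the mean vector [2, 3, 4] yields damping 2.0 and gain [3.0, 4.0].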