
Estimation

Estimation submodule providing models, metadata, and training utilities.

Notes

Contains parameter estimation model definitions, metadata helpers, and JAX/Flax training primitives for learning dynamics parameters.

DynamicsParameterMetadata dataclass

Metadata describing tunable parameter vectors for a dynamics model.

Attributes:

name (str): Identifier for the dynamics model.
parameter_keys (tuple of str): Ordered parameter names.
parameter_bounds (ndarray): Array of shape [K, 2] holding the (min, max) bounds for each of the K parameters.
default_parameters (Mapping[str, float]): Baseline parameter values.
base_param_builder (Callable[[int], object]): Factory that creates a parameter object for a given timestep scaling.
dynamics_fn (Callable[[ndarray, Any], ndarray]): Dynamics function mapping a state-control vector and a parameter object to derivatives.
loss_fn (Callable[..., ndarray]): Loss function compatible with simulated trajectories.
loss_kwargs (Mapping[str, Any] | None): Optional keyword arguments for the loss function.
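As a minimal sketch of how these fields fit together, the dataclass below mirrors the documented attributes and checks that parameter_bounds has one (min, max) row per parameter key. The class name, the validation rules, and the toy "damped_1d" model (its keys, bounds, builder, dynamics, and loss) are hypothetical illustrations, not the library's definitions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional, Tuple

import numpy as np


@dataclass(frozen=True)
class ParameterMetadataSketch:
    """Hypothetical stand-in mirroring DynamicsParameterMetadata's fields."""

    name: str
    parameter_keys: Tuple[str, ...]
    parameter_bounds: np.ndarray  # shape [K, 2]; column 0 = min, column 1 = max
    default_parameters: Mapping[str, float]
    base_param_builder: Callable[[int], object]
    dynamics_fn: Callable[[np.ndarray, Any], np.ndarray]
    loss_fn: Callable[..., np.ndarray]
    loss_kwargs: Optional[Mapping[str, Any]] = None

    def __post_init__(self) -> None:
        k = len(self.parameter_keys)
        if self.parameter_bounds.shape != (k, 2):
            raise ValueError(f"parameter_bounds must have shape ({k}, 2)")
        if np.any(self.parameter_bounds[:, 0] > self.parameter_bounds[:, 1]):
            raise ValueError("each bounds row must be ordered as (min, max)")
        missing = set(self.parameter_keys) - set(self.default_parameters)
        if missing:
            raise ValueError(f"missing default values for {sorted(missing)}")


# Hypothetical toy model: x' = -damping * x + gain * u, with u stored after x.
defaults = {"gain": 1.0, "damping": 0.5}
toy = ParameterMetadataSketch(
    name="damped_1d",
    parameter_keys=("gain", "damping"),
    parameter_bounds=np.array([[0.1, 5.0], [0.0, 2.0]]),
    default_parameters=defaults,
    # Assumption: the builder scales a base timestep by an integer factor.
    base_param_builder=lambda scale: {**defaults, "dt": 0.01 * scale},
    dynamics_fn=lambda xu, p: -p["damping"] * xu[:1] + p["gain"] * xu[1:],
    loss_fn=lambda sim, target: np.mean((sim - target) ** 2),
)
```

Validating at construction time means a mismatch between parameter_keys and parameter_bounds surfaces immediately rather than as a shape error deep inside a JIT-compiled loss.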

FeatureDataset dataclass

Dataset with per-sample features alongside simulation trajectories.

Attributes:

features (Array): Feature matrix of shape [B, F].
controls (Array): Control trajectories of shape [B, T, 2].
initial_states (Array): Initial state vectors of shape [B, 9].
targets (Array): Target trajectories of shape [B, T, 7].
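The shapes above imply that every field shares the batch axis B and that controls and targets share the time axis T. A minimal NumPy sketch of a container that enforces those relationships, plus a hypothetical take helper for drawing mini-batches, could look like this (the class and method names are illustrative, not the library's):

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FeatureDatasetSketch:
    """Hypothetical container mirroring FeatureDataset's documented shapes."""

    features: np.ndarray        # [B, F]
    controls: np.ndarray        # [B, T, 2]
    initial_states: np.ndarray  # [B, 9]
    targets: np.ndarray         # [B, T, 7]

    def __post_init__(self) -> None:
        b, f = self.features.shape
        t = self.controls.shape[1]
        expected = {
            "features": (b, f),
            "controls": (b, t, 2),
            "initial_states": (b, 9),
            "targets": (b, t, 7),
        }
        for field_name, shape in expected.items():
            actual = getattr(self, field_name).shape
            if actual != shape:
                raise ValueError(f"{field_name}: expected {shape}, got {actual}")

    def take(self, idx: np.ndarray) -> "FeatureDatasetSketch":
        """Select a mini-batch of samples along the shared batch axis."""
        return FeatureDatasetSketch(
            self.features[idx], self.controls[idx],
            self.initial_states[idx], self.targets[idx],
        )


B, F, T = 4, 3, 5
data = FeatureDatasetSketch(
    features=np.zeros((B, F)),
    controls=np.zeros((B, T, 2)),
    initial_states=np.zeros((B, 9)),
    targets=np.zeros((B, T, 7)),
)
mini = data.take(np.array([0, 2]))
```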

ParameterMLP

Bases: Module

Multi-layer perceptron that predicts dynamics parameter vectors.

Attributes:

hidden_sizes (Sequence[int]): Width of each hidden layer.
parameter_bounds (ndarray): Bounds array of shape [K, 2] used to scale outputs.
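The documentation says only that parameter_bounds is used to scale the outputs. One common way to do that, shown here as an assumption rather than the library's confirmed choice, is to squash the final layer with a sigmoid and map it affinely into [min, max]. The sketch uses plain JAX instead of a Flax Module, and the helper names, ReLU hidden activations, and scaled-normal initialization are likewise illustrative.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden_sizes, out_dim):
    """Create (weight, bias) pairs for each dense layer (hypothetical helper)."""
    sizes = [in_dim, *hidden_sizes, out_dim]
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (fan_in, fan_out)) / jnp.sqrt(fan_in)
        layers.append((w, jnp.zeros(fan_out)))
    return layers


def bounded_mlp_apply(layers, x, parameter_bounds):
    """Map features [B, F] to parameter vectors [B, K] within the bounds."""
    h = x
    for w, b in layers[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = layers[-1]
    logits = h @ w + b  # unconstrained, shape [B, K]
    lo, hi = parameter_bounds[:, 0], parameter_bounds[:, 1]
    # Assumed scaling: sigmoid lands in [0, 1], then each column maps to [lo, hi].
    return lo + (hi - lo) * jax.nn.sigmoid(logits)


bounds = jnp.array([[0.5, 2.0], [0.0, 0.3]])  # hypothetical K = 2 parameters
layers = init_mlp(jax.random.PRNGKey(0), in_dim=4, hidden_sizes=(16, 16),
                  out_dim=bounds.shape[0])
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
params_pred = bounded_mlp_apply(layers, x, bounds)
```

Bounding the output this way keeps every prediction physically admissible, so the downstream simulation never receives, for example, a negative mass, at the cost of vanishing gradients when the sigmoid saturates near a bound.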

available_dynamics_models

available_dynamics_models()

Return the registered dynamics model identifiers.

Returns:

tuple of str: Sorted tuple of available model names.

get_parameter_metadata

get_parameter_metadata(model_name)

Retrieve metadata for the requested dynamics model.

Parameters:

model_name (str, required): Dynamics model identifier.

Returns:

DynamicsParameterMetadata: Metadata describing the requested model.

Raises:

ValueError: If the model name is not registered.
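Both lookup functions behave like a name-keyed registry. A minimal sketch, assuming a module-level dict populated by a decorator, shows the sorted listing and the ValueError on unknown names. The register_model decorator, the dict, the list_models and lookup_metadata names, and the unicycle/bicycle entries are all hypothetical.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: model name -> zero-argument metadata builder.
_REGISTRY: Dict[str, Callable[[], dict]] = {}


def register_model(name: str):
    """Decorator that registers a metadata builder under `name`."""
    def decorator(builder: Callable[[], dict]) -> Callable[[], dict]:
        if name in _REGISTRY:
            raise ValueError(f"dynamics model {name!r} is already registered")
        _REGISTRY[name] = builder
        return builder
    return decorator


def list_models() -> Tuple[str, ...]:
    """Sorted tuple of registered names, analogous to available_dynamics_models."""
    return tuple(sorted(_REGISTRY))


def lookup_metadata(model_name: str) -> dict:
    """Build metadata for `model_name`, raising ValueError if it is unknown."""
    try:
        builder = _REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"unknown dynamics model {model_name!r}; available: {list_models()}"
        ) from None
    return builder()


@register_model("unicycle")
def _unicycle() -> dict:
    return {"name": "unicycle", "parameter_keys": ("max_speed",)}


@register_model("bicycle")
def _bicycle() -> dict:
    return {"name": "bicycle", "parameter_keys": ("wheelbase", "max_steer")}
```

Returning a sorted tuple keeps the listing deterministic and immutable regardless of registration order, and including the available names in the error message makes typos easy to diagnose.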

make_eval_loss_fn

make_eval_loss_fn(
    model_apply,
    base_params,
    dynamics_fn,
    integrator_fn,
    param_names,
    loss_fn,
    loss_kwargs=None,
    *,
    param_shapes=None,
)

Build a JIT-compiled evaluation loss function.

Parameters:

model_apply (Callable[[Any, Array], Array], required): Model apply function.
base_params (object, required): Baseline parameter object used for replacement.
dynamics_fn (Callable[[Array, Any], Array], required): Dynamics function to integrate.
integrator_fn (Callable[[Callable[[Array, Any], Array], Array, Any], Array], required): Integrator applied to the dynamics function.
param_names (Sequence[str], required): Names corresponding to entries in predicted vectors.
loss_fn (Callable[[Array, Array], Array], required): Loss function over simulations and targets.
loss_kwargs (Mapping[str, Any], default None): Extra kwargs passed to the loss function.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters.

Returns:

Callable[[Any, FeatureDataset], Array]: Evaluation loss function compiled with JAX.
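The returned function is a closure: everything except the network weights and the batch is fixed at build time, so jax.jit only traces over those two arguments. The sketch below assumes that (1) base_params is a NamedTuple whose fields are swapped with _replace, (2) the integrator receives the concatenated state-control vector and returns the next state, (3) the batch is a plain dict rather than a FeatureDataset, and (4) every parameter is a scalar, so param_shapes is omitted. The toy one-dimensional system, ToyParams, euler_step, and the linear model are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):  # hypothetical parameter object
    gain: float
    damping: float
    dt: float


def toy_dynamics(xu, p):
    # xu = [x, u]; x' = -damping * x + gain * u, returned with shape [1]
    return -p.damping * xu[:1] + p.gain * xu[1:]


def euler_step(dynamics_fn, xu, p):
    dx = dynamics_fn(xu, p)
    return xu[: dx.shape[0]] + p.dt * dx  # next state only


def make_eval_loss_sketch(model_apply, base_params, dynamics_fn, integrator_fn,
                          param_names, loss_fn, loss_kwargs=None):
    loss_kwargs = dict(loss_kwargs or {})

    def simulate_one(vec, x0, controls):
        # Substitute this sample's predicted values into the baseline parameters.
        p = base_params._replace(**{n: vec[i] for i, n in enumerate(param_names)})

        def step(x, u):
            x_next = integrator_fn(dynamics_fn, jnp.concatenate([x, u]), p)
            return x_next, x_next

        _, traj = jax.lax.scan(step, x0, controls)
        return traj  # [T, state_dim]

    @jax.jit
    def eval_loss(model_params, batch):
        vecs = model_apply(model_params, batch["features"])  # [B, P]
        sims = jax.vmap(simulate_one)(vecs, batch["initial_states"], batch["controls"])
        return loss_fn(sims, batch["targets"], **loss_kwargs)

    return eval_loss


base = ToyParams(gain=1.0, damping=0.5, dt=0.1)
B, T = 2, 5
x0 = [[0.0], [1.0]]
# Reference targets from an explicit Euler loop with the true parameters (u = 1).
targets = []
for b in range(B):
    x, rows = x0[b][0], []
    for _ in range(T):
        x = x + base.dt * (-base.damping * x + base.gain * 1.0)
        rows.append([x])
    targets.append(rows)
batch = {
    "features": jnp.ones((B, 3)),
    "controls": jnp.ones((B, T, 1)),
    "initial_states": jnp.array(x0),
    "targets": jnp.array(targets),
}
linear_apply = lambda mp, f: f @ mp["w"] + mp["b"]  # hypothetical model
mse = lambda sims, tgt, scale=1.0: scale * jnp.mean((sims - tgt) ** 2)
eval_loss = make_eval_loss_sketch(linear_apply, base, toy_dynamics, euler_step,
                                  ("gain", "damping"), mse, {"scale": 1.0})
loss_true = eval_loss({"w": jnp.zeros((3, 2)), "b": jnp.array([1.0, 0.5])}, batch)
loss_wrong = eval_loss({"w": jnp.zeros((3, 2)), "b": jnp.array([2.0, 0.5])}, batch)
```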

make_network_step_fn

make_network_step_fn(
    model_apply,
    loss_fn,
    param_names,
    base_params,
    dynamics_fn,
    integrator_fn,
    optimizer,
    *,
    loss_kwargs=None,
    param_shapes=None,
)

Create a JIT-compiled training step function for parameter networks.

Parameters:

model_apply (Callable[[Any, Array], Array], required): Model apply function returning parameter vectors.
loss_fn (Callable[[Array, Array], Array], required): Loss function over simulations and targets.
param_names (Sequence[str], required): Names describing each parameter element.
base_params (object, required): Baseline parameter object used for replacement.
dynamics_fn (Callable[[Array, Any], Array], required): Dynamics function to integrate.
integrator_fn (Callable[[Callable[[Array, Any], Array], Array, Any], Array], required): Integrator applied to the dynamics function.
optimizer (GradientTransformation, required): Optimizer for network parameters.
loss_kwargs (Mapping[str, Any], default None): Extra kwargs passed to the loss function.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters.

Returns:

Callable: JIT-compiled step function updating model parameters.
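Training wraps the same closure idea in jax.value_and_grad plus an optimizer update. The sketch assumes an optax-style optimizer exposing init(params) and update(grads, state, params) that returns (updates, new_state), which is the layout of optax.GradientTransformation; to stay dependency-free it defines a hypothetical plain-SGD stand-in rather than importing optax. It also folds base_params, dynamics_fn, integrator_fn, and param_names into a single hypothetical simulate_fn, and uses a toy scalar decay system whose only parameter is a damping coefficient.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class OptimizerLike(NamedTuple):
    """Hypothetical stand-in with the (init, update) layout of optax.GradientTransformation."""
    init: Callable[[Any], Any]
    update: Callable[..., Any]


def plain_sgd(learning_rate: float) -> OptimizerLike:
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, state, params=None):
        return jax.tree_util.tree_map(lambda g: -learning_rate * g, grads), state

    return OptimizerLike(init, update)


def make_step_fn_sketch(model_apply, simulate_fn, loss_fn, optimizer, loss_kwargs=None):
    loss_kwargs = dict(loss_kwargs or {})

    def objective(model_params, batch):
        vecs = model_apply(model_params, batch["features"])  # [B, P]
        sims = simulate_fn(vecs, batch)
        return loss_fn(sims, batch["targets"], **loss_kwargs)

    @jax.jit
    def step(model_params, opt_state, batch):
        loss, grads = jax.value_and_grad(objective)(model_params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, model_params)
        model_params = jax.tree_util.tree_map(jnp.add, model_params, updates)
        return model_params, opt_state, loss

    return step


def toy_simulate(vecs, batch, dt=0.1):
    """x_{t+1} = x_t + dt * (-k * x_t + u_t), with one damping k per sample."""
    k = vecs[:, 0]

    def advance(x, u):  # x, u: [B]
        x_next = x + dt * (-k * x + u)
        return x_next, x_next

    _, traj = jax.lax.scan(advance, batch["initial_states"], batch["controls"].T)
    return traj.T  # [B, T]


B, T = 4, 10
batch = {
    "features": jnp.ones((B, 1)),
    "controls": jnp.ones((B, T)),
    "initial_states": jnp.zeros(B),
}
batch["targets"] = toy_simulate(jnp.full((B, 1), 0.7), batch)  # true k = 0.7
apply_fn = lambda mp, f: f @ mp["w"]  # hypothetical linear model
mse = lambda sims, tgt: jnp.mean((sims - tgt) ** 2)

optimizer = plain_sgd(2.0)
step = make_step_fn_sketch(apply_fn, toy_simulate, mse, optimizer)
model_params = {"w": jnp.array([[0.2]])}
opt_state = optimizer.init(model_params)
losses = []
for _ in range(100):
    model_params, opt_state, loss = step(model_params, opt_state, batch)
    losses.append(float(loss))
```

Because the step only relies on init and update, swapping in a real optax optimizer should not require changes to the step function itself.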

simulate_with_param_vectors

simulate_with_param_vectors(
    param_vectors,
    dataset,
    base_params,
    dynamics_fn,
    integrator_fn,
    param_names,
    param_shapes=None,
)

Simulate each trajectory with its own parameter vector.

Parameters:

param_vectors (Array, required): Parameter vectors per sample, shape [B, P].
dataset (FeatureDataset, required): Batched features and trajectories.
base_params (object, required): Baseline parameter object that supports replace.
dynamics_fn (Callable[[Array, Any], Array], required): Dynamics function to integrate.
integrator_fn (Callable[[Callable[[Array, Any], Array], Array, Any], Array], required): Integrator applied to the dynamics function.
param_names (Sequence[str], required): Names corresponding to entries in param_vectors.
param_shapes (Mapping[str, Sequence[int]], default None): Optional target shapes for each parameter.

Returns:

Array: Simulated trajectories with shape [B, T, state_dim].
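Because each sample carries its own parameter vector, the simulation can be written once for a single trajectory and batched with jax.vmap. The sketch also shows one plausible reading of param_shapes, which is an assumption: each name consumes prod(shape) consecutive entries of the [P] vector (a single entry when no shape is given) and is reshaped before being substituted into base_params. ToyParams, its per-channel gain vector, euler_step, unflatten, and the plain arrays passed in place of a FeatureDataset are hypothetical.

```python
import math
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):  # hypothetical parameter object
    gain: jnp.ndarray  # shape [2]: one gain per control channel
    damping: float
    dt: float


def toy_dynamics(xu, p):
    x, u = xu[:1], xu[1:]  # 1-D state followed by a 2-D control
    return -p.damping * x + jnp.dot(p.gain, u)


def euler_step(dynamics_fn, xu, p):
    dx = dynamics_fn(xu, p)
    return xu[: dx.shape[0]] + p.dt * dx


def unflatten(vec, names, shapes=None):
    """Split a flat [P] vector into named pieces reshaped per `shapes` (assumed layout)."""
    shapes = shapes or {}
    out, offset = {}, 0
    for name in names:
        shape = tuple(shapes.get(name, ()))
        size = math.prod(shape)  # an empty shape () means one scalar entry
        out[name] = vec[offset:offset + size].reshape(shape)
        offset += size
    if offset != vec.shape[0]:
        raise ValueError(f"names/shapes consume {offset} entries, vector has {vec.shape[0]}")
    return out


def simulate_sketch(param_vectors, controls, initial_states, base_params,
                    dynamics_fn, integrator_fn, param_names, param_shapes=None):
    def one(vec, x0, us):
        p = base_params._replace(**unflatten(vec, param_names, param_shapes))

        def step(x, u):
            x_next = integrator_fn(dynamics_fn, jnp.concatenate([x, u]), p)
            return x_next, x_next

        _, traj = jax.lax.scan(step, x0, us)
        return traj  # [T, state_dim]

    return jax.vmap(one)(param_vectors, initial_states, controls)  # [B, T, state_dim]


base = ToyParams(gain=jnp.zeros(2), damping=0.0, dt=0.1)
vectors = jnp.array([[1.0, 0.0, 0.5],   # gain = [1, 0], damping = 0.5
                     [0.0, 1.0, 0.5],   # gain = [0, 1], damping = 0.5
                     [0.0, 0.0, 0.0]])  # no input, no damping: stays at 0
sims = simulate_sketch(vectors, jnp.ones((3, 4, 2)), jnp.zeros((3, 1)), base,
                       toy_dynamics, euler_step, ("gain", "damping"), {"gain": (2,)})
```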

summarize_predictions

summarize_predictions(
    model_apply,
    model_params,
    features,
    param_names,
    *,
    param_shapes=None,
)

Summarize model predictions by computing the mean parameter vector.

Parameters:

model_apply (Callable[[Any, Array], Array], required): Model apply function.
model_params (object, required): Trained model parameters.
features (Array, required): Feature matrix used for inference.
param_names (Sequence[str], required): Names corresponding to the predicted vector entries.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters.

Returns:

Dict[str, object]: Mean prediction mapped into structured NumPy scalars/arrays.
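A minimal NumPy sketch of this summary step averages the [B, P] predictions over the batch axis, then splits the mean vector back into named entries. It assumes, as an illustration, that each name consumes prod(shape) consecutive entries of the vector (one entry when param_shapes gives no shape) and that scalars come back as NumPy scalars; the function name, the linear model, and the mass/inertia parameter names are hypothetical.

```python
import math

import numpy as np


def summarize_mean_sketch(model_apply, model_params, features, param_names, param_shapes=None):
    preds = np.asarray(model_apply(model_params, features))  # [B, P]
    mean = preds.mean(axis=0)                                # [P]
    shapes = param_shapes or {}
    summary, offset = {}, 0
    for name in param_names:
        shape = tuple(shapes.get(name, ()))
        size = math.prod(shape)
        chunk = mean[offset:offset + size]
        # Scalars come back as NumPy scalars, structured entries as arrays.
        summary[name] = chunk[0] if shape == () else chunk.reshape(shape)
        offset += size
    return summary


linear_apply = lambda w, x: x @ w  # hypothetical model: features [B, 2] -> [B, 3]
weights = np.array([[1.0, 2.0, 3.0],
                    [3.0, 4.0, 5.0]])
features = np.eye(2)  # predictions are exactly the two rows of `weights`
summary = summarize_mean_sketch(linear_apply, weights, features,
                                ("mass", "inertia"), {"inertia": (2,)})
```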

summarize_predictions_chunked

summarize_predictions_chunked(
    model_apply,
    model_params,
    features,
    param_names,
    *,
    chunk_size,
    param_shapes=None,
)

Compute the mean prediction in chunks, without materializing the full prediction tensor.

Parameters:

model_apply (Callable[[Any, Array], Array], required): Model apply function.
model_params (object, required): Trained model parameters.
features (Array, required): Feature matrix used for inference.
param_names (Sequence[str], required): Names corresponding to the predicted vector entries.
chunk_size (int, required): Number of samples processed per chunk.
param_shapes (Mapping[str, Sequence[int]], default None): Optional shapes for structured parameters.

Returns:

Dict[str, object]: Mean prediction mapped into structured NumPy scalars/arrays.

Raises:

ValueError: If features is empty.
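Chunking bounds peak memory to one [chunk_size, P] block of predictions. A minimal sketch, assuming the mean is accumulated as a running sum and divided by the total sample count once at the end, is below. The function name and the toy linear model are hypothetical, and mapping the flat result to named parameters would follow the same split as in summarize_predictions.

```python
import numpy as np


def chunked_mean_sketch(model_apply, model_params, features, *, chunk_size):
    n = features.shape[0]
    if n == 0:
        raise ValueError("features is empty")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    total = None
    for start in range(0, n, chunk_size):
        # Only one chunk of predictions is held in memory at a time.
        preds = np.asarray(model_apply(model_params, features[start:start + chunk_size]),
                           dtype=np.float64)  # [<= chunk_size, P]
        chunk_sum = preds.sum(axis=0)
        total = chunk_sum if total is None else total + chunk_sum
    return total / n  # [P]


rng = np.random.default_rng(0)
features = rng.normal(size=(10, 4))
weights = rng.normal(size=(4, 3))
apply_fn = lambda w, x: x @ w  # hypothetical linear model
chunked = chunked_mean_sketch(apply_fn, weights, features, chunk_size=3)
full = (features @ weights).mean(axis=0)
```

Summing first and dividing once weights a short final chunk correctly (here 10 samples split as 3 + 3 + 3 + 1), whereas averaging the per-chunk means would over-weight that last sample.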

Model definitions and helpers for parameter estimation networks.

Notes

Defines metadata for dynamics parameter vectors and small MLP heads for predicting parameters within bounded ranges.

Training primitives for parameter estimation networks.
