Estimation
Estimation submodule providing models, metadata, and training utilities.
Notes
Contains parameter estimation model definitions, metadata helpers, and JAX/Flax training primitives for learning dynamics parameters.
DynamicsParameterMetadata (dataclass)
Metadata describing the tunable parameter vector of a dynamics model.
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | `str` | Identifier of the dynamics model. |
| `parameter_keys` | `tuple of str` | Ordered parameter names. |
| `parameter_bounds` | `ndarray` | Array of shape … |
| `default_parameters` | `Mapping[str, float]` | Baseline parameter values. |
| `base_param_builder` | `Callable[[int], object]` | Factory that creates a parameter object for a given timestep scaling. |
| `dynamics_fn` | `Callable[[ndarray, Any], ndarray]` | Dynamics function mapping a state-control input to state derivatives. |
| `loss_fn` | `Callable[..., ndarray]` | Loss function compatible with simulated trajectories. |
| `loss_kwargs` | `Mapping[str, Any] \| None` | Optional keyword arguments for the loss function. |

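Because `parameter_keys` fixes the ordering, a flat parameter vector can be assembled from `default_parameters` and checked against `parameter_bounds`. The sketch below assumes the bounds hold one `(lower, upper)` row per key (the exact shape is not stated above). The model keys `mass`, `drag`, and `stiffness` and the helpers `default_vector` and `within_bounds` are hypothetical.

```python
import numpy as np

# Hypothetical metadata contents; a real model registers its own keys, defaults, and bounds.
parameter_keys = ("mass", "drag", "stiffness")
default_parameters = {"mass": 1.0, "drag": 0.1, "stiffness": 5.0}
# Assumed layout: one (lower, upper) row per key, in parameter_keys order.
parameter_bounds = np.array([[0.5, 2.0], [0.0, 1.0], [1.0, 10.0]])


def default_vector(keys, defaults):
    """Flatten named default values into a vector ordered by `keys`."""
    return np.array([defaults[k] for k in keys], dtype=float)


def within_bounds(vector, bounds):
    """Return True if every entry lies inside its (lower, upper) row."""
    return bool(np.all((vector >= bounds[:, 0]) & (vector <= bounds[:, 1])))


vec = default_vector(parameter_keys, default_parameters)
print(within_bounds(vec, parameter_bounds))  # True
```

Keeping the vector order tied to `parameter_keys` is what lets a predicted vector be mapped back to named parameters later.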
FeatureDataset (dataclass)
Dataset that pairs per-sample features with simulation trajectories.
Attributes:

| Name | Type | Description |
|---|---|---|
| `features` | `Array` | Feature matrix of shape … |
| `controls` | `Array` | Control trajectories of shape … |
| `initial_states` | `Array` | Initial state vectors of shape … |
| `targets` | `Array` | Target trajectories of shape … |

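The exact array shapes are not preserved above. The hypothetical stand-in below assumes `features` is `(N, F)`, `controls` is `(N, T, C)`, `initial_states` is `(N, S)`, and `targets` is `(N, T, S)`, and it checks that the four arrays agree on the sample count. `ToyFeatureDataset` is an illustration and is not part of the module.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyFeatureDataset:
    """Hypothetical stand-in with the same four fields as FeatureDataset."""
    features: np.ndarray        # assumed (N, F)
    controls: np.ndarray        # assumed (N, T, C)
    initial_states: np.ndarray  # assumed (N, S)
    targets: np.ndarray         # assumed (N, T, S)

    def __post_init__(self):
        # Every field must describe the same N samples.
        n = self.features.shape[0]
        for field_name in ("controls", "initial_states", "targets"):
            if getattr(self, field_name).shape[0] != n:
                raise ValueError(f"{field_name} does not have {n} samples")
        # Assumed: targets are state trajectories, so their last axis matches the state size.
        if self.targets.shape[-1] != self.initial_states.shape[-1]:
            raise ValueError("targets and initial_states disagree on the state size")


N, F, T, C, S = 8, 4, 50, 1, 2
rng = np.random.default_rng(0)
ds = ToyFeatureDataset(
    features=rng.normal(size=(N, F)),
    controls=rng.normal(size=(N, T, C)),
    initial_states=rng.normal(size=(N, S)),
    targets=rng.normal(size=(N, T, S)),
)
```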
ParameterMLP
Bases: Module
Multi-layer perceptron that predicts dynamics parameter vectors.
Attributes:

| Name | Type | Description |
|---|---|---|
| `hidden_sizes` | `Sequence[int]` | Width of each hidden layer. |
| `parameter_bounds` | `ndarray` | Bounds array of shape … |

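The module notes say the MLP predicts parameters within bounded ranges. One common way to do that, assumed here rather than confirmed for `ParameterMLP`, is to squash each output through a sigmoid and rescale it into its `(lower, upper)` bounds row. This NumPy forward pass is a sketch: `init_mlp` and `bounded_mlp` are hypothetical helpers, and the ReLU activations and `(P, 2)` bounds layout are assumptions.

```python
import numpy as np


def init_mlp(rng, in_dim, hidden_sizes, out_dim):
    """Random weights for a dense stack: in_dim -> hidden_sizes... -> out_dim."""
    sizes = [in_dim, *hidden_sizes, out_dim]
    return [
        (rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def bounded_mlp(layers, x, bounds):
    """ReLU hidden layers, then sigmoid-squash each output into [lower, upper]."""
    h = x
    for w, b in layers[:-1]:
        h = np.maximum(h @ w + b, 0.0)
    w, b = layers[-1]
    logits = h @ w + b
    unit = 1.0 / (1.0 + np.exp(-logits))        # each entry in [0, 1]
    lower, upper = bounds[:, 0], bounds[:, 1]    # assumed (P, 2) layout
    return lower + unit * (upper - lower)


rng = np.random.default_rng(0)
bounds = np.array([[0.5, 2.0], [0.0, 1.0], [1.0, 10.0]])
layers = init_mlp(rng, in_dim=4, hidden_sizes=(16, 16), out_dim=bounds.shape[0])
preds = bounded_mlp(layers, rng.normal(size=(8, 4)), bounds)
```

Squashing guarantees every prediction is a valid parameter, so the simulator never sees out-of-range values during training.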
available_dynamics_models
available_dynamics_models()
Return the registered dynamics model identifiers.
Returns:

| Type | Description |
|---|---|
| `tuple of str` | Sorted tuple of available model names. |

get_parameter_metadata
get_parameter_metadata(model_name)
Retrieve metadata for the requested dynamics model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name` | `str` | Dynamics model identifier. | *required* |

Returns:

| Type | Description |
|---|---|
| `DynamicsParameterMetadata` | Metadata describing the requested model. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the model name is not registered. |

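Together, `available_dynamics_models` and `get_parameter_metadata` act like a name-keyed registry. The sketch below mirrors their documented behavior: it returns a sorted tuple of names and raises `ValueError` for an unknown name. It uses a hypothetical `_REGISTRY` dict holding plain dicts instead of `DynamicsParameterMetadata` objects, and the model names are made up.

```python
from typing import Dict

# Hypothetical registry; the real module keeps its own mapping of names to metadata.
_REGISTRY: Dict[str, dict] = {
    "pendulum": {"parameter_keys": ("length", "damping")},
    "cartpole": {"parameter_keys": ("pole_mass", "cart_mass")},
}


def list_models() -> tuple:
    """Mirror of available_dynamics_models: sorted tuple of registered names."""
    return tuple(sorted(_REGISTRY))


def lookup(model_name: str) -> dict:
    """Mirror of get_parameter_metadata: unknown names raise ValueError."""
    try:
        return _REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown dynamics model {model_name!r}; available: {list_models()}"
        ) from None


print(list_models())  # ('cartpole', 'pendulum')
```

Converting the `KeyError` into a `ValueError` that lists the valid names makes a typo in a model name easy to diagnose.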
make_eval_loss_fn
make_eval_loss_fn(
    model_apply,
    base_params,
    dynamics_fn,
    integrator_fn,
    param_names,
    loss_fn,
    loss_kwargs=None,
    *,
    param_shapes=None,
)
Build a JIT-compiled evaluation loss function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Any, Array], Array]` | Model apply function. | *required* |
| `base_params` | `object` | Baseline parameter object whose entries are replaced by the predicted values. | *required* |
| `dynamics_fn` | `Callable[[Array, Any], Array]` | Dynamics function to integrate. | *required* |
| `integrator_fn` | `Callable[[Callable[[Array, Any], Array], Array, Any], Array]` | Integrator applied to the dynamics function. | *required* |
| `param_names` | `Sequence[str]` | Names corresponding to entries in the predicted vectors. | *required* |
| `loss_fn` | `Callable[[Array, Array], Array]` | Loss function over simulations and targets. | *required* |
| `loss_kwargs` | `Mapping[str, Any]` | Extra keyword arguments passed to the loss function. | `None` |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |

Returns:

| Type | Description |
|---|---|
| `Callable[[Any, FeatureDataset], Array]` | Evaluation loss function, JIT-compiled with JAX. |

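The evaluation loss chains three stages: predict a parameter vector per sample with `model_apply`, simulate each sample with its own parameters, and score the simulations against the targets. The JAX sketch below reproduces that pattern under `jax.jit` with a hypothetical scalar system `x' = -k x + u`, an explicit Euler integrator, and an MSE loss. To avoid registering a pytree, it passes the dataset fields as separate arrays rather than a `FeatureDataset`. `euler_rollout`, `model_apply`, and `make_eval_loss` are illustrative names, not the module's API.

```python
import jax
import jax.numpy as jnp


def euler_rollout(k, x0, controls, dt=0.1):
    """Hypothetical integrator: explicit Euler for x' = -k * x + u, one row per step."""
    def step(x, u):
        x_next = x + dt * (-k * x + u)
        return x_next, x_next
    _, traj = jax.lax.scan(step, x0, controls)
    return traj  # (T, S)


def model_apply(params, features):
    """Hypothetical model: one positive decay rate per sample, shape (N, 1)."""
    return jax.nn.softplus(features @ params["w"] + params["b"])


def make_eval_loss(model_apply):
    """Predict parameters, simulate each sample with its own rate, score with MSE."""
    @jax.jit
    def eval_loss(model_params, features, initial_states, controls, targets):
        rates = model_apply(model_params, features)[:, 0]                # (N,)
        sims = jax.vmap(euler_rollout)(rates, initial_states, controls)  # (N, T, S)
        return jnp.mean((sims - targets) ** 2)
    return eval_loss


features = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
params = {"w": jnp.full((3, 1), 0.1), "b": jnp.zeros(1)}
x0 = jnp.ones((4, 1))
controls = jnp.zeros((4, 25, 1))
# Targets produced by the same model, so the loss should be zero up to rounding.
targets = jax.vmap(euler_rollout)(model_apply(params, features)[:, 0], x0, controls)
eval_loss = make_eval_loss(model_apply)
loss = eval_loss(params, features, x0, controls, targets)
```

Wrapping the whole predict, simulate, and score pipeline in one `jax.jit` lets XLA compile it as a single program, so repeated evaluations do not pay Python overhead per sample.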
make_network_step_fn
make_network_step_fn(
    model_apply,
    loss_fn,
    param_names,
    base_params,
    dynamics_fn,
    integrator_fn,
    optimizer,
    *,
    loss_kwargs=None,
    param_shapes=None,
)
Create a JIT-compiled training step function for parameter networks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Any, Array], Array]` | Model apply function returning parameter vectors. | *required* |
| `loss_fn` | `Callable[[Array, Array], Array]` | Loss function over simulations and targets. | *required* |
| `param_names` | `Sequence[str]` | Names describing each parameter element. | *required* |
| `base_params` | `object` | Baseline parameter object whose entries are replaced by the predicted values. | *required* |
| `dynamics_fn` | `Callable[[Array, Any], Array]` | Dynamics function to integrate. | *required* |
| `integrator_fn` | `Callable[[Callable[[Array, Any], Array], Array, Any], Array]` | Integrator applied to the dynamics function. | *required* |
| `optimizer` | `GradientTransformation` | Optimizer for the network parameters. | *required* |
| `loss_kwargs` | `Mapping[str, Any]` | Extra keyword arguments passed to the loss function. | `None` |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |

Returns:

| Type | Description |
|---|---|
| `Callable` | JIT-compiled step function that updates the model parameters. |

simulate_with_param_vectors
simulate_with_param_vectors(
    param_vectors,
    dataset,
    base_params,
    dynamics_fn,
    integrator_fn,
    param_names,
    param_shapes=None,
)
Simulate each trajectory with its own parameter vector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `param_vectors` | `Array` | Per-sample parameter vectors, shape … | *required* |
| `dataset` | `FeatureDataset` | Batched features and trajectories. | *required* |
| `base_params` | `object` | Baseline parameters supporting … | *required* |
| `dynamics_fn` | `Callable[[Array, Any], Array]` | Dynamics function to integrate. | *required* |
| `integrator_fn` | `Callable[[Callable[[Array, Any], Array], Array, Any], Array]` | Integrator applied to the dynamics function. | *required* |
| `param_names` | `Sequence[str]` | Names corresponding to entries in … | *required* |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional target shapes for each parameter. | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Simulated trajectories with shape … |

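A natural JAX formulation, assumed here, is to `jax.vmap` a single-trajectory rollout over the batch, swapping each sample's predicted values into the baseline parameters before integrating. In this sketch the baseline is a hypothetical `NamedTuple` updated with `_replace`. The one-step integrator signature `integrator(dynamics_fn, state_control, params)` is an interpretation of the documented type, not a confirmed contract, and `dynamics`, `euler_step`, and `simulate_per_sample` are illustrative names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    """Hypothetical baseline parameter container with a `_replace` method."""
    k: float
    dt: float


def dynamics(xu, p):
    """Hypothetical scalar system x' = -k x + u on a concatenated [x, u] vector."""
    x, u = xu[:1], xu[1:]
    return -p.k * x + u


def euler_step(dynamics_fn, xu, p):
    """Hypothetical one-step explicit Euler integrator returning the next state."""
    return xu[:1] + p.dt * dynamics_fn(xu, p)


def simulate_per_sample(param_vectors, initial_states, controls, base, names):
    """Roll out every sample with its own parameter vector via jax.vmap."""
    def rollout(vec, x0, us):
        # Swap this sample's predicted values into the baseline; other fields are kept.
        p = base._replace(**{name: vec[i] for i, name in enumerate(names)})

        def step(x, u):
            x_next = euler_step(dynamics, jnp.concatenate([x, u]), p)
            return x_next, x_next

        return jax.lax.scan(step, x0, us)[1]  # (T, S)

    return jax.vmap(rollout)(param_vectors, initial_states, controls)  # (N, T, S)


traj = simulate_per_sample(
    param_vectors=jnp.array([[0.5], [2.0]]),  # one decay rate per sample
    initial_states=jnp.ones((2, 1)),
    controls=jnp.zeros((2, 30, 1)),
    base=Params(k=0.0, dt=0.1),
    names=("k",),
)
```

With zero control, each Euler step multiplies the state by `1 - dt * k`, which makes the rollout easy to check by hand.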
summarize_predictions
summarize_predictions(
    model_apply,
    model_params,
    features,
    param_names,
    *,
    param_shapes=None,
)
Summarize model predictions by computing the mean parameter vector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Any, Array], Array]` | Model apply function. | *required* |
| `model_params` | `object` | Trained model parameters. | *required* |
| `features` | `Array` | Feature matrix used for inference. | *required* |
| `param_names` | `Sequence[str]` | Names corresponding to the predicted vector entries. | *required* |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |

Returns:

| Type | Description |
|---|---|
| `Dict[str, object]` | Mean prediction mapped to structured NumPy scalars and arrays. |

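This function reduces the per-sample predictions to their mean and maps that mean back onto named parameters. The NumPy sketch below assumes each name in `param_names` occupies a consecutive block of the vector whose size is the product of its entry in `param_shapes` (one element when no shape is given). `summarize` is a hypothetical helper that takes precomputed predictions instead of calling a model.

```python
import numpy as np


def summarize(pred_vectors, names, shapes=None):
    """Average predictions over samples, then split the mean into named entries."""
    shapes = dict(shapes or {})
    mean = np.asarray(pred_vectors, dtype=float).mean(axis=0)
    out, offset = {}, 0
    for name in names:
        shape = tuple(shapes.get(name, ()))
        size = int(np.prod(shape))  # np.prod(()) == 1.0, so scalars take one slot
        block = mean[offset:offset + size].reshape(shape)
        out[name] = block[()]  # numpy scalar for shape (), the array itself otherwise
        offset += size
    if offset != mean.shape[0]:
        raise ValueError("names and shapes do not cover every predicted entry")
    return out


preds = np.array([[1.0, 0.0, 2.0, 4.0, 6.0],
                  [3.0, 2.0, 4.0, 6.0, 8.0]])
summary = summarize(preds, ("mass", "gains"), {"gains": (2, 2)})
```

Here the column means are `[2, 1, 3, 5, 7]`, so `mass` becomes `2.0` and `gains` becomes the 2x2 array `[[1, 3], [5, 7]]`.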
summarize_predictions_chunked
summarize_predictions_chunked(
    model_apply,
    model_params,
    features,
    param_names,
    *,
    chunk_size,
    param_shapes=None,
)
Compute the mean prediction chunk by chunk, without materializing the full prediction tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Any, Array], Array]` | Model apply function. | *required* |
| `model_params` | `object` | Trained model parameters. | *required* |
| `features` | `Array` | Feature matrix used for inference. | *required* |
| `param_names` | `Sequence[str]` | Names corresponding to the predicted vector entries. | *required* |
| `chunk_size` | `int` | Number of samples processed per chunk. | *required* |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |

Returns:

| Type | Description |
|---|---|
| `Dict[str, object]` | Mean prediction mapped to structured NumPy scalars and arrays. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If … |

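The chunked variant can keep memory bounded by accumulating a running sum over fixed-size slices of `features` and dividing by the sample count at the end, so the full prediction array never exists at once. The condition under which the real function raises `ValueError` is cut off above; the `chunk_size` check in this sketch is a guess. `chunked_mean` is a hypothetical helper and the linear "model" is only for illustration.

```python
import numpy as np


def chunked_mean(predict, features, chunk_size):
    """Mean of predict(features) over axis 0, evaluated one chunk at a time."""
    if chunk_size <= 0:  # assumed validation; the documented condition is truncated
        raise ValueError("chunk_size must be positive")
    n = features.shape[0]
    if n == 0:
        raise ValueError("features must contain at least one sample")
    total = None
    for start in range(0, n, chunk_size):
        part = np.asarray(predict(features[start:start + chunk_size]))
        partial = part.sum(axis=0, dtype=np.float64)  # float64 running sum
        total = partial if total is None else total + partial
    return total / n


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 2))  # hypothetical linear "model"
X = rng.normal(size=(10, 3))
mean = chunked_mean(lambda x: x @ W, X, chunk_size=4)
```

Summing first and dividing once by the total count gives the same result as the unchunked mean even when the last chunk is smaller; accumulating in float64 limits rounding drift across many chunks.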
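Model definitions and helpers for parameter estimation networks.
Notes
Defines metadata for dynamics parameter vectors and small MLP heads that predict parameters within bounded ranges. Members of this submodule (`DynamicsParameterMetadata`, `ParameterMLP`, `available_dynamics_models`, and `get_parameter_metadata`) are documented above.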
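Training primitives for parameter estimation networks. Members of this submodule (`FeatureDataset`, `make_eval_loss_fn`, `make_network_step_fn`, `simulate_with_param_vectors`, `summarize_predictions`, and `summarize_predictions_chunked`) are documented above.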