Network Training
Utilities for learning parameter models via neural networks.
FeatureDataset (dataclass)
Dataset with per-sample features alongside simulation data.
Attributes:
| Name | Type | Description |
|---|---|---|
| `features` | `Array` | Feature matrix of shape |
| `controls` | `Array` | Control trajectories of shape |
| `initial_states` | `Array` | Initial state vectors of shape |
| `targets` | `Array` | Target trajectories of shape |
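The field layout above can be illustrated with a small stand-in. This is a minimal sketch, not the library's class: `FeatureDatasetSketch`, its `num_samples` helper, and every array shape used below are assumptions made for illustration, since the shapes are not stated in the table. The only property it relies on is that all four arrays share a leading per-sample axis.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class FeatureDatasetSketch:
    """Hypothetical stand-in mirroring the four documented fields."""

    features: jnp.ndarray
    controls: jnp.ndarray
    initial_states: jnp.ndarray
    targets: jnp.ndarray

    def num_samples(self) -> int:
        # Assumption: every field is batched along axis 0.
        arrays = (self.features, self.controls, self.initial_states, self.targets)
        sizes = {a.shape[0] for a in arrays}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent sample counts: {sorted(sizes)}")
        return sizes.pop()


# Assumed shapes, chosen only for this example.
num_samples, num_steps = 8, 20
dataset = FeatureDatasetSketch(
    features=jnp.zeros((num_samples, 3)),
    controls=jnp.zeros((num_samples, num_steps, 1)),
    initial_states=jnp.zeros((num_samples, 2)),
    targets=jnp.zeros((num_samples, num_steps, 2)),
)
print(dataset.num_samples())
```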
make_network_step_fn
make_network_step_fn(
model_apply,
loss_fn,
param_names,
base_params,
dynamics_fn,
integrator_fn,
optimizer,
*,
param_shapes=None,
)
Create a JIT-compiled step function for training parameter networks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Array, Array], Array]` | Model apply function returning parameter vectors. | required |
| `loss_fn` | `Callable[[Array, Array], Array]` | Loss function over simulations and targets. | required |
| `param_names` | `Sequence[str]` | Names describing each parameter element. | required |
| `base_params` | `object` | Baseline parameter object used for replacement. | required |
| `dynamics_fn` | `Callable[[Array, object], Array]` | Dynamics function to integrate. | required |
| `integrator_fn` | `Callable[[Callable[[Array, object], Array], Array, object], Array]` | Integrator applied to the dynamics function. | required |
| `optimizer` | `GradientTransformation` | Optimizer for network parameters. | required |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |
Returns:
| Type | Description |
|---|---|
| `Callable` | JIT-compiled step function updating model parameters. |
simulate_with_param_vectors
simulate_with_param_vectors(
param_vectors,
dataset,
base_params,
dynamics_fn,
integrator_fn,
param_names,
param_shapes=None,
)
Simulate each trajectory with its own parameter vector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `param_vectors` | `Array` | Parameter vectors per sample, shape | required |
| `dataset` | `FeatureDataset` | Batched features and trajectories. | required |
| `base_params` | `object` | Baseline parameters supporting | required |
| `dynamics_fn` | `Callable[[Array, object], Array]` | Dynamics function to integrate. | required |
| `integrator_fn` | `Callable[[Callable[[Array, object], Array], Array, object], Array]` | Integrator for the dynamics function. | required |
| `param_names` | `Sequence[str]` | Names corresponding to entries in | required |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional target shapes for each parameter. | `None` |
Returns:
| Type | Description |
|---|---|
| `Array` | Simulated trajectories with shape |
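The idea of giving each trajectory its own parameter vector can be sketched with `jax.vmap`. Everything concrete here is an assumption for illustration: `DecayParams`, `simulate_one`, the use of `NamedTuple._replace` as the replacement mechanism, the explicit Euler integrator, and the calling conventions `dynamics_fn(state, params)` and `integrator_fn(dynamics_fn, initial_state, params)`. How controls enter the dynamics is not specified above, so the sketch leaves them out, and it handles scalar parameters only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DecayParams(NamedTuple):
    """Hypothetical parameter object; the real base_params type is not shown."""

    rate: jnp.ndarray
    gain: jnp.ndarray


def dynamics_fn(state, params):
    # Toy dynamics: dx/dt = -rate * x + gain.
    return -params.rate * state + params.gain


def integrator_fn(dynamics, initial_state, params, num_steps=5, dt=0.1):
    # Explicit Euler rollout returning the states after each step.
    def euler_step(state, _):
        new_state = state + dt * dynamics(state, params)
        return new_state, new_state

    _, trajectory = jax.lax.scan(euler_step, initial_state, None, length=num_steps)
    return trajectory


def simulate_one(param_vector, initial_state, base_params, param_names):
    # Overwrite the named fields of base_params with this sample's values.
    updates = {name: param_vector[i] for i, name in enumerate(param_names)}
    params = base_params._replace(**updates)
    return integrator_fn(dynamics_fn, initial_state, params)


base_params = DecayParams(rate=jnp.asarray(1.0), gain=jnp.asarray(0.0))
param_names = ("rate", "gain")
param_vectors = jnp.array([[1.0, 0.0], [2.0, 0.5], [0.5, 1.0]])
initial_states = jnp.ones((3, 2))

# vmap maps over the leading (sample) axis of both arguments.
simulate_batch = jax.vmap(
    lambda p, x0: simulate_one(p, x0, base_params, param_names)
)
trajectories = simulate_batch(param_vectors, initial_states)
print(trajectories.shape)
```

Using `vmap` instead of a Python loop keeps the whole batch inside one traced computation, which is what makes it practical to differentiate and JIT-compile.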
summarize_predictions
summarize_predictions(
model_apply,
model_params,
features,
param_names,
*,
param_shapes=None,
)
Summarize model predictions by computing the mean parameter vector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_apply` | `Callable[[Array, Array], Array]` | Model apply function. | required |
| `model_params` | `object` | Trained model parameters. | required |
| `features` | `Array` | Feature matrix used for inference. | required |
| `param_names` | `Sequence[str]` | Names corresponding to the predicted vector entries. | required |
| `param_shapes` | `Mapping[str, Sequence[int]]` | Optional shapes for structured parameters. | `None` |
Returns:
| Type | Description |
|---|---|
| `Dict[str, object]` | Mean prediction mapped into structured numpy scalars/arrays. |
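One plausible reading of this summary step is sketched below. `summarize_sketch` is a hypothetical re-implementation, not the library function. It assumes `model_apply(model_params, features)` returns one parameter vector per row, that structured parameters occupy consecutive entries of that vector in `param_names` order, and that names missing from `param_shapes` are scalars.

```python
import numpy as np


def summarize_sketch(model_apply, model_params, features, param_names,
                     *, param_shapes=None):
    """Average predicted parameter vectors and split them by name."""
    param_shapes = param_shapes or {}
    predictions = np.asarray(model_apply(model_params, features))
    mean_vector = predictions.mean(axis=0)

    summary, offset = {}, 0
    for name in param_names:
        shape = tuple(param_shapes.get(name, ()))
        size = int(np.prod(shape))  # the empty shape () has size 1
        chunk = mean_vector[offset:offset + size]
        # Scalars come back as numpy scalars, structured entries as arrays.
        summary[name] = chunk.reshape(shape) if shape else chunk[0]
        offset += size
    return summary


# Toy linear "network" with three outputs: one scalar and one 2-vector.
weights = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
features = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])

summary = summarize_sketch(
    lambda w, x: x @ w,
    weights,
    features,
    param_names=("mass", "damping"),
    param_shapes={"damping": (2,)},
)
print(summary)
```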