Params
build_param_labels
build_param_labels(spec)
Generate human-friendly labels for flattened parameters.
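The page does not say what the labels look like. The sketch below is hypothetical: it assumes the spec exposes ordered `keys` and a `shapes` mapping (stood in for here by `SpecSketch`), and it emits one `name[i,j]` label per element in row-major order so that labels line up with a C-order `ravel`. The label format and the helper name `build_param_labels_sketch` are illustrative assumptions, not this module's API.

```python
from dataclasses import dataclass
from itertools import product
from typing import Dict, List, Sequence, Tuple


@dataclass
class SpecSketch:
    """Hypothetical stand-in for a ModelSpec-like object."""
    keys: Sequence[str]
    shapes: Dict[str, Tuple[int, ...]]


def build_param_labels_sketch(spec: SpecSketch) -> List[str]:
    """Return one label per flattened entry, in row-major (C) order."""
    labels: List[str] = []
    for key in spec.keys:
        shape = tuple(spec.shapes[key])
        if shape == ():
            labels.append(key)  # scalar parameter: the bare name is enough
            continue
        # product(range(a), range(b), ...) walks indices in row-major order,
        # matching numpy's default ravel(order="C").
        for idx in product(*(range(n) for n in shape)):
            labels.append(f"{key}[{','.join(map(str, idx))}]")
    return labels
```

For `keys=["alpha", "w"]` with `alpha` scalar and `w` of shape `(2, 2)`, this yields `["alpha", "w[0,0]", "w[0,1]", "w[1,0]", "w[1,1]"]`.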
flat_vector_to_python
flat_vector_to_python(vector, keys, shapes)
Convert a flat parameter vector back to a Python dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vector | ndarray | Flattened parameter vector. | required |
| keys | Sequence[str] | Ordered parameter names. | required |
| shapes | Mapping[str, Sequence[int]] | Target shapes keyed by parameter name. | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, object] | Mapping from parameter name to Python scalars/lists. |
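One plausible way to implement this contract is sketched below. It assumes the vector is laid out key by key in the order of `keys`, each chunk in row-major order, and that a scalar parameter has shape `()`. The name `flat_vector_to_python_sketch` and the length check are illustrative additions, not documented behavior.

```python
from typing import Dict, Mapping, Sequence

import numpy as np


def flat_vector_to_python_sketch(
    vector: np.ndarray,
    keys: Sequence[str],
    shapes: Mapping[str, Sequence[int]],
) -> Dict[str, object]:
    """Split `vector` according to `keys`/`shapes` and convert chunks to Python types."""
    vector = np.asarray(vector)
    out: Dict[str, object] = {}
    offset = 0
    for key in keys:
        shape = tuple(shapes[key])
        size = int(np.prod(shape))  # np.prod(()) is 1, so a scalar takes one slot
        chunk = vector[offset : offset + size].reshape(shape)
        # tolist() gives a Python scalar for 0-d arrays and nested lists otherwise
        out[key] = chunk.tolist()
        offset += size
    if offset != vector.size:
        raise ValueError(f"vector has {vector.size} entries; shapes account for {offset}")
    return out
```

Raising on a length mismatch (rather than silently ignoring trailing entries) is a design choice of this sketch; it catches a vector built against a different spec.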
flatten_bounds
flatten_bounds(keys, bounds, shapes)
Flatten per-parameter bounds into two aligned vectors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| keys | Sequence[str] | Parameter names. | required |
| bounds | Sequence[Sequence[object]] | Lower/upper bounds for each parameter. | required |
| shapes | Mapping[str, Sequence[int]] | Target shapes keyed by parameter name. | required |
Returns:
| Type | Description |
|---|---|
| Tuple[ndarray, ndarray] | Concatenated lower and upper bounds. |
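The sketch below shows one way to produce two vectors aligned element for element with a flattened parameter vector. It assumes each entry of `bounds` is a `(lower, upper)` pair whose members are either scalars or array-likes broadcastable to that parameter's shape; how the real function treats other inputs (for example `None`) is not documented here, so the sketch does not handle them. `flatten_bounds_sketch` is a hypothetical name.

```python
from typing import Mapping, Sequence, Tuple

import numpy as np


def flatten_bounds_sketch(
    keys: Sequence[str],
    bounds: Sequence[Sequence[object]],
    shapes: Mapping[str, Sequence[int]],
) -> Tuple[np.ndarray, np.ndarray]:
    """Broadcast each (lower, upper) pair to its parameter's shape, then concatenate."""
    if len(keys) != len(bounds):
        raise ValueError("expected one (lower, upper) pair per key")
    lowers, uppers = [], []
    for key, (lo, hi) in zip(keys, bounds):
        shape = tuple(shapes[key])
        # broadcast_to lets a scalar bound apply to every element of the parameter
        lowers.append(np.broadcast_to(np.asarray(lo, dtype=float), shape).ravel())
        uppers.append(np.broadcast_to(np.asarray(hi, dtype=float), shape).ravel())
    return np.concatenate(lowers), np.concatenate(uppers)
```

Ravelling in the same row-major order as the parameters keeps index `i` of each bound vector paired with index `i` of the flattened parameter vector, which is what optimizers with box constraints typically expect.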
infer_hidden_sizes_from_params
infer_hidden_sizes_from_params(params_tree)
Infer dense layer widths from a Flax parameter tree.
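A hedged sketch of the idea follows. It relies on two facts about Flax `nn.Dense`: auto-named submodules appear as `Dense_0`, `Dense_1`, and so on, and each kernel has shape `(in_features, out_features)`. It further assumes (an assumption, not documented here) that the last Dense layer is the output head and therefore not a hidden layer. Plain dicts of NumPy arrays stand in for the Flax tree, and `infer_hidden_sizes_sketch` is a hypothetical name.

```python
import re
from typing import Any, List, Mapping

import numpy as np

_DENSE_NAME = re.compile(r"^Dense_(\d+)$")


def infer_hidden_sizes_sketch(params_tree: Mapping[str, Any]) -> List[int]:
    """Read layer widths from the kernels of auto-named Dense_<i> modules."""
    # Accept either the full variables dict ({"params": {...}}) or the inner params.
    tree = params_tree.get("params", params_tree)
    layers = []
    for name, leaf in tree.items():
        match = _DENSE_NAME.match(name)
        if match and "kernel" in leaf:
            # Dense kernels are (in_features, out_features); the width is the last axis.
            layers.append((int(match.group(1)), int(np.shape(leaf["kernel"])[-1])))
    layers.sort()  # sort by numeric index, so Dense_10 comes after Dense_2
    widths = [width for _, width in layers]
    # Assumption: the final Dense layer is the output head, not a hidden layer.
    return widths[:-1]
```

Sorting on the parsed integer rather than the key string avoids the lexicographic trap where `Dense_10` would sort before `Dense_2`.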
params_to_vector
params_to_vector(param_obj, spec)
Flatten a parameter dataclass into a single vector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| param_obj | object | Dataclass containing parameter fields. | required |
| spec | ModelSpec-like | Object exposing | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Concatenated parameter vector. |
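The `spec` description above is cut off, so the sketch below assumes (hypothetically) that a ModelSpec-like object exposes ordered `keys` and a `shapes` mapping, represented here by `SpecSketch`. It reads each named field off the dataclass with `getattr`, checks its shape, and concatenates the row-major ravels. `params_to_vector_sketch` and the shape check are illustrative, not the documented implementation.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

import numpy as np


@dataclass
class SpecSketch:
    """Hypothetical ModelSpec-like object: ordered names plus target shapes."""
    keys: Sequence[str]
    shapes: Dict[str, Tuple[int, ...]]


def params_to_vector_sketch(param_obj: object, spec: SpecSketch) -> np.ndarray:
    """Concatenate the fields named in spec.keys, in that order, into one 1-D vector."""
    chunks = []
    for key in spec.keys:
        value = np.asarray(getattr(param_obj, key), dtype=float)
        expected = tuple(spec.shapes[key])
        if value.shape != expected:
            raise ValueError(f"{key}: expected shape {expected}, got {value.shape}")
        chunks.append(value.ravel())  # row-major, so reshape(expected) inverts it
    return np.concatenate(chunks)
```

Iterating over `spec.keys` rather than `dataclasses.fields(param_obj)` makes the spec, not the dataclass's field order, the single source of truth for the vector layout.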