Params

build_param_labels

build_param_labels(spec)

Generate human-friendly labels for flattened parameters.
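
Here is a minimal sketch of how such labels could be built. It is not the library's implementation. It assumes `spec` exposes `param_keys` and an optional `shapes` mapping (as described for `params_to_vector`), and that array parameters get one `name[i,j]` label per element in row-major order. That label format, `ToySpec`, and `build_param_labels_sketch` are all hypothetical.

```python
from dataclasses import dataclass
from itertools import product
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class ToySpec:
    # Hypothetical stand-in for a ModelSpec-like object.
    param_keys: Sequence[str]
    shapes: Optional[Dict[str, Tuple[int, ...]]] = None


def build_param_labels_sketch(spec) -> List[str]:
    """Return one label per flattened element, in param_keys order (assumed format)."""
    shapes = getattr(spec, "shapes", None) or {}
    labels: List[str] = []
    for key in spec.param_keys:
        shape = tuple(int(d) for d in (shapes.get(key) or ()))
        if not shape:
            # Scalar parameter: a single label with the bare name.
            labels.append(key)
            continue
        # Row-major (C-order) enumeration, matching NumPy's default ravel order.
        for idx in product(*(range(d) for d in shape)):
            labels.append(f"{key}[{','.join(map(str, idx))}]")
    return labels
```

With `ToySpec(["alpha", "w"], {"w": (2, 2)})` this sketch yields `alpha`, `w[0,0]`, `w[0,1]`, `w[1,0]`, `w[1,1]`, one label per entry of the flattened vector.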

flat_vector_to_python

flat_vector_to_python(vector, keys, shapes)

Convert a flat parameter vector back to a Python dict.

Parameters:

vector (ndarray, required): Flattened parameter vector.
keys (Sequence[str], required): Ordered parameter names.
shapes (Mapping[str, Sequence[int]], required): Target shapes keyed by parameter name.

Returns:

Dict[str, object]: Mapping from parameter name to Python scalars or lists.
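
One way to picture the conversion is to walk the vector in `keys` order and slice off `prod(shape)` entries per parameter. The sketch below assumes row-major layout and treats a key missing from `shapes` as a scalar. The function name `flat_vector_to_python_sketch` and its length checks are illustrative, not the documented behavior.

```python
from typing import Dict, Mapping, Sequence

import numpy as np


def flat_vector_to_python_sketch(
    vector: np.ndarray,
    keys: Sequence[str],
    shapes: Mapping[str, Sequence[int]],
) -> Dict[str, object]:
    """Split a flat vector into named pieces and convert them to Python types."""
    vector = np.asarray(vector).ravel()
    out: Dict[str, object] = {}
    offset = 0
    for key in keys:
        # Assumption: a key absent from `shapes` is a 0-dim (scalar) parameter.
        shape = tuple(int(d) for d in (shapes.get(key) or ()))
        size = int(np.prod(shape))  # np.prod(()) == 1, so scalars take one slot
        chunk = vector[offset:offset + size]
        if chunk.size != size:
            raise ValueError(f"vector too short for parameter {key!r}")
        offset += size
        if shape:
            # .tolist() yields nested Python lists of Python floats/ints.
            out[key] = chunk.reshape(shape).tolist()
        else:
            # .item() yields a plain Python scalar.
            out[key] = chunk[0].item()
    if offset != vector.size:
        raise ValueError("vector has more entries than the parameters need")
    return out
```

The `.tolist()` and `.item()` calls are what turn NumPy values into plain Python floats and nested lists, which keeps the result JSON-serializable.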

flatten_bounds

flatten_bounds(keys, bounds, shapes)

Flatten per-parameter bounds into two aligned vectors.

Parameters:

keys (Sequence[str], required): Parameter names.
bounds (Sequence[Sequence[object]], required): Lower/upper bounds for each parameter.
shapes (Mapping[str, Sequence[int]], required): Target shapes keyed by parameter name.

Returns:

Tuple[ndarray, ndarray]: Concatenated lower-bound and upper-bound vectors.
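
Here is a sketch of the flattening. It assumes each entry of `bounds` is a `(lower, upper)` pair aligned with `keys`, and that a scalar bound applies to every element of an array parameter through NumPy broadcasting. Both assumptions and the name `flatten_bounds_sketch` are hypothetical.

```python
from typing import Mapping, Sequence, Tuple

import numpy as np


def flatten_bounds_sketch(
    keys: Sequence[str],
    bounds: Sequence[Sequence[object]],
    shapes: Mapping[str, Sequence[int]],
) -> Tuple[np.ndarray, np.ndarray]:
    """Expand per-parameter (lower, upper) pairs into two flat, aligned vectors."""
    if len(keys) != len(bounds):
        raise ValueError("need exactly one (lower, upper) pair per key")
    lowers, uppers = [], []
    for key, (lower, upper) in zip(keys, bounds):
        shape = tuple(int(d) for d in (shapes.get(key) or ()))
        # np.broadcast_to lets a scalar bound cover every element of an array
        # parameter; an array bound must be broadcast-compatible with `shape`.
        lowers.append(np.broadcast_to(np.asarray(lower, dtype=float), shape).ravel())
        uppers.append(np.broadcast_to(np.asarray(upper, dtype=float), shape).ravel())
    if not lowers:
        return np.empty(0), np.empty(0)
    return np.concatenate(lowers), np.concatenate(uppers)
```

Because both vectors are built in the same key order and row-major element order, `lower[i]` and `upper[i]` would bound the same entry of a parameter vector that uses that same ordering.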

infer_hidden_sizes_from_params

infer_hidden_sizes_from_params(params_tree)

Infer dense layer widths from a Flax parameter tree.
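
Here is a sketch of one way to recover layer widths. It assumes a plain MLP whose layers use Flax's default `Dense_<i>` names and whose kernels have shape `(in_features, features)`. This page does not say whether the real function drops the output layer, so `drop_output_layer` is a hypothetical switch, and `infer_hidden_sizes_sketch` is a hypothetical name. NumPy arrays stand in for JAX arrays, since only the shape is read.

```python
import re
from collections.abc import Mapping
from typing import List

import numpy as np


def infer_hidden_sizes_sketch(params_tree: Mapping, drop_output_layer: bool = True) -> List[int]:
    """Read Dense-layer output widths from kernel shapes, in layer order."""
    # model.init(...) typically returns {"params": {...}}; accept either level.
    tree = params_tree.get("params", params_tree)
    layers = []
    for name, leaf in tree.items():
        match = re.fullmatch(r"Dense_(\d+)", name)
        if match and isinstance(leaf, Mapping) and "kernel" in leaf:
            # A Flax Dense kernel has shape (in_features, features).
            layers.append((int(match.group(1)), int(np.shape(leaf["kernel"])[-1])))
    # Sort numerically so "Dense_10" follows "Dense_9", not "Dense_1".
    widths = [width for _, width in sorted(layers)]
    # Assumption: the last Dense layer is the output layer, not a hidden one.
    return widths[:-1] if drop_output_layer and widths else widths
```

For a tree with kernels of shapes `(3, 64)`, `(64, 32)`, and `(32, 1)`, the sketch returns `[64, 32]`, or `[64, 32, 1]` with `drop_output_layer=False`.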

params_to_vector

params_to_vector(param_obj, spec)

Flatten a parameter dataclass into a single vector.

Parameters:

param_obj (object, required): Dataclass containing parameter fields.
spec (ModelSpec-like, required): Object exposing param_keys and optional shapes.

Returns:

ndarray: Concatenated parameter vector.
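
Here is a sketch of the flattening. It assumes each name in `spec.param_keys` is an attribute of `param_obj` and that values are raveled in row-major order. `ToyParams`, `ToySpec`, `params_to_vector_sketch`, and the optional shape check against `spec.shapes` are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple

import numpy as np


@dataclass
class ToyParams:
    # Hypothetical parameter dataclass.
    alpha: float
    w: np.ndarray


@dataclass
class ToySpec:
    # Hypothetical ModelSpec-like object.
    param_keys: Sequence[str]
    shapes: Optional[Dict[str, Tuple[int, ...]]] = None


def params_to_vector_sketch(param_obj, spec) -> np.ndarray:
    """Concatenate the fields named in spec.param_keys into one float vector."""
    shapes = getattr(spec, "shapes", None) or {}
    pieces = []
    for key in spec.param_keys:
        value = np.asarray(getattr(param_obj, key), dtype=float)
        expected = shapes.get(key)
        # Hypothetical validation: reject values whose shape disagrees with spec.
        if expected is not None and value.shape != tuple(expected):
            raise ValueError(f"{key!r} has shape {value.shape}, expected {tuple(expected)}")
        # ravel() uses C (row-major) order, so flat positions are predictable.
        pieces.append(value.ravel())
    return np.concatenate(pieces) if pieces else np.empty(0)
```

For `ToyParams(alpha=0.5, w=[[1, 2], [3, 4]])` with keys `["alpha", "w"]`, the sketch produces `[0.5, 1.0, 2.0, 3.0, 4.0]`.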

shape_size

shape_size(shape)

Return the total number of elements for a shape; a 0-dim shape () counts as 1.
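
The 0-dim rule follows from the empty product: multiplying no dimensions gives 1. Here is a sketch that assumes `None` is treated like `()`, mirroring `shape_tuple`. `shape_size_sketch` is a hypothetical name.

```python
import math
from typing import Optional, Sequence


def shape_size_sketch(shape: Optional[Sequence[int]]) -> int:
    """Number of elements in an array of the given shape."""
    # Assumption: None is treated like a 0-dim (scalar) shape.
    dims = tuple(int(d) for d in shape) if shape is not None else ()
    # math.prod of an empty tuple is 1, which is exactly the scalar case.
    return math.prod(dims)
```

For example, `(2, 3, 4)` gives 24, `()` gives 1, and any shape containing a 0 dimension gives 0.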

shape_tuple

shape_tuple(shape)

Normalize a possibly-None shape into a tuple of ints.
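
Here is a sketch of the normalization. It assumes `None` stands for a scalar (the empty shape) and that any other input is an iterable of integer-like values. `shape_tuple_sketch` is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def shape_tuple_sketch(shape: Optional[Sequence[int]]) -> Tuple[int, ...]:
    """Normalize None or a sequence of integer-like values to a tuple of ints."""
    if shape is None:
        # None means a scalar parameter: the empty shape ().
        return ()
    # int() accepts Python ints, NumPy integers, and integral floats like 2.0.
    return tuple(int(d) for d in shape)
```

Returning a tuple rather than a list makes the result hashable, so it can serve as a dict key or be passed directly to `np.reshape`.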