
Batching

batchify

batchify(x, agent_list, num_actors)

Stack agent data into a batched array.

Parameters:

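| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | dict | Mapping from agent id to arrays. | required |
| agent_list | Iterable | Ordered agent identifiers. | required |
| num_actors | int | Number of actors; sets the leading dimension of the returned array. | required |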

Returns:

| Type | Description |
| --- | --- |
| ndarray | Stacked array of shape (num_actors, -1), where -1 means all remaining dimensions are flattened into one. |
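A minimal NumPy sketch of the behavior described above. It assumes batchify stacks the per-agent arrays in agent_list order and then reshapes to (num_actors, -1). This is an illustration under that assumption, not the library's actual code.

```python
import numpy as np


def batchify(x, agent_list, num_actors):
    """Sketch (assumed behavior): stack per-agent arrays, then flatten
    everything after the leading dimension."""
    # Stack along a new leading axis, one slice per agent, in agent_list order.
    stacked = np.stack([x[agent] for agent in agent_list])
    # Leading dimension becomes num_actors; -1 lets NumPy infer the
    # flattened size of the remaining dimensions.
    return stacked.reshape((num_actors, -1))


obs = {"agent_0": np.zeros((4, 3)), "agent_1": np.ones((4, 3))}
batch = batchify(obs, ["agent_0", "agent_1"], num_actors=2)
print(batch.shape)  # (2, 12)
```

Because the agent order comes from agent_list rather than from dict iteration, row i of the result always corresponds to agent_list[i].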

create_batched_params

create_batched_params(
    base_params, batch_size, **param_overrides
)

Create batched parameters from a base Param object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| base_params | Param | Base Param object with scalar values. | required |
| batch_size | int | Number of samples in the batch. | required |
| **param_overrides | | Parameter overrides as arrays of shape (batch_size,). | {} |

Returns:

| Type | Description |
| --- | --- |
| Param | Batched parameter object. |
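A hedged sketch of one way create_batched_params could work. The Param dataclass and its gravity and friction fields are hypothetical stand-ins for the real class. The sketch also assumes that fields without an override are repeated to shape (batch_size,); the real function may instead leave them as scalars.

```python
from dataclasses import dataclass, fields, replace

import numpy as np


@dataclass(frozen=True)
class Param:
    # Hypothetical stand-in for the library's Param class.
    gravity: float = 9.81
    friction: float = 0.5


def create_batched_params(base_params, batch_size, **param_overrides):
    """Sketch: every field becomes an array of shape (batch_size,)."""
    names = {f.name for f in fields(base_params)}
    unknown = set(param_overrides) - names
    if unknown:
        raise TypeError(f"unknown parameter(s): {sorted(unknown)}")

    values = {}
    for name in names:
        if name in param_overrides:
            # Overridden fields must already be per-sample arrays.
            arr = np.asarray(param_overrides[name])
            if arr.shape != (batch_size,):
                raise ValueError(
                    f"{name} must have shape ({batch_size},), got {arr.shape}"
                )
        else:
            # Assumption: non-overridden scalars are repeated per sample.
            arr = np.full(batch_size, getattr(base_params, name))
        values[name] = arr
    return replace(base_params, **values)


batched = create_batched_params(Param(), 3, friction=np.array([0.1, 0.5, 0.9]))
```

Here batched.friction holds the three supplied values and batched.gravity holds 9.81 three times.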

extract_param_at_index

extract_param_at_index(batched_params, index)

Extract parameters for a single sample from batched parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batched_params | Param | Batched parameter object. | required |
| index | int | Index to extract. | required |

Returns:

| Type | Description |
| --- | --- |
| Param | Scalar parameter object for the given index. |
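A minimal sketch of extracting one sample. It uses a hypothetical Param dataclass and assumes every batched field is a 1-D array of shape (batch_size,), so that indexing yields a single value. Fields that are not ndarrays are passed through unchanged.

```python
from dataclasses import dataclass, fields, replace

import numpy as np


@dataclass(frozen=True)
class Param:
    # Hypothetical stand-in for the library's Param class.
    gravity: float = 9.81
    friction: float = 0.5


def extract_param_at_index(batched_params, index):
    """Sketch: index every ndarray field; leave scalar fields unchanged."""
    values = {}
    for f in fields(batched_params):
        value = getattr(batched_params, f.name)
        if isinstance(value, np.ndarray):
            # .item() turns the NumPy scalar back into a Python scalar.
            value = value[index].item()
        values[f.name] = value
    return replace(batched_params, **values)


batched = Param(gravity=np.array([9.81, 1.62]), friction=np.array([0.5, 0.2]))
single = extract_param_at_index(batched, 1)
print(single)  # Param(gravity=1.62, friction=0.2)
```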

is_batched_params

is_batched_params(params)

Check if a Param object contains batched parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | Param | Parameter object to test. | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True if any field is an ndarray, otherwise False. |
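The documented rule ("True if any field is an ndarray") can be sketched directly. This sketch assumes Param is a dataclass, and the Param class shown is a hypothetical stand-in.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass(frozen=True)
class Param:
    # Hypothetical stand-in for the library's Param class.
    gravity: float = 9.81
    friction: float = 0.5


def is_batched_params(params):
    """Return True if any field of params holds an ndarray."""
    return any(isinstance(getattr(params, f.name), np.ndarray) for f in fields(params))


print(is_batched_params(Param()))                               # False
print(is_batched_params(Param(friction=np.array([0.1, 0.9]))))  # True
```

A single array-valued field is enough for the object to count as batched, so partially batched objects also return True.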

unbatchify

unbatchify(x, agent_list, num_envs, num_actors)

Split a batched array back into a dict keyed by agent id.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | ndarray | Batched array of shape (num_actors * num_envs, ...). | required |
| agent_list | Iterable | Ordered agent identifiers. | required |
| num_envs | int | Number of environments. | required |
| num_actors | int | Number of actors. | required |

Returns:

| Type | Description |
| --- | --- |
| dict | Mapping from agent id to unbatched arrays. |
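A minimal NumPy sketch of the split. It rests on two assumptions. First, rows are grouped by actor: the first num_envs rows belong to agent_list[0], the next num_envs rows to agent_list[1], and so on. Second, any trailing dimensions are flattened into one. The real function may preserve the trailing shape instead.

```python
import numpy as np


def unbatchify(x, agent_list, num_envs, num_actors):
    """Sketch (assumed behavior): split the leading axis into
    (num_actors, num_envs) and map each actor slice to its agent id."""
    # -1 flattens any remaining feature dimensions.
    x = x.reshape((num_actors, num_envs, -1))
    return {agent: x[i] for i, agent in enumerate(agent_list)}


flat = np.arange(12).reshape(6, 2)  # num_actors=2, num_envs=3
out = unbatchify(flat, ["agent_0", "agent_1"], num_envs=3, num_actors=2)
print(out["agent_1"].shape)  # (3, 2)
```

Each value in the result has shape (num_envs, -1). With the example input, agent_1 receives the last three rows of flat.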