
Training Loop Walkthrough

This tutorial walks through the training pipeline implemented in bin/optimize_network.py, showing how a run is configured, executed, and logged.

High-level flow

  1. Parse CLI arguments and seed RNG.
  2. Resolve the chosen dynamics model and flatten parameter bounds.
  3. Load and optionally cache training/validation datasets.
  4. Create a run directory for artifacts.
  5. Initialize logging (Weights & Biases optional).
  6. Build the parameter-predicting MLP and optimizer.
  7. Train with optional curriculum, validation, and learning-rate reduction.
  8. Summarize predictions, plot/save losses, and write checkpoints/best params.

Step-by-step

1) Configure and seed (parse_args, main, lines ~1050)

  • CLI arguments set datasets, model name, feature mode, batch/epoch counts, LR schedule, curriculum, caching, logging, and analysis options.
  • The PRNG key is seeded early to make shuffling and initialization reproducible.
    def main():
        args = parse_args()
        rng = jax.random.PRNGKey(args.seed)
        if args.summary_chunk_size <= 0:
            raise ValueError("--summary-chunk-size must be positive.")
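
  • As a quick illustration of why the early seeding matters (a standalone sketch, not code from the script): re-deriving keys from the same seed reproduces the same permutation, so shuffles and initializations are repeatable across runs.
    import jax

    # The same seed always yields the same derived keys, and therefore the same
    # shuffle order and weight initialization.
    _, perm_key = jax.random.split(jax.random.PRNGKey(0))
    perm_a = jax.random.permutation(perm_key, 8)

    _, perm_key_again = jax.random.split(jax.random.PRNGKey(0))
    perm_b = jax.random.permutation(perm_key_again, 8)
    assert (perm_a == perm_b).all()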
    

2) Pick dynamics spec (MODEL_SPECS, lines ~1080)

  • The script looks up the requested --model in MODEL_SPECS to get the parameter keys, shapes, bounds, and the loss and dynamics functions. See pct.dynamics.core for the registry and pct.dynamics.singletrack for details on the default model.
  • MLP dynamics are patched with the requested hidden size before training.
  • Bounds are flattened to constrain the network outputs.
    try:
        spec = MODEL_SPECS[args.model]
    except KeyError as exc:
        available = ", ".join(sorted(MODEL_SPECS.keys()))
        raise ValueError(f"Unknown model '{args.model}'. Available models: {available}") from exc
    
    if spec.name == "mlp-dynamics":
        hidden_size = args.dynamics_hidden_size
        spec = dataclasses.replace(
            spec,
            param_shapes=mlp_param_shapes(hidden_size),
            param_bounds=mlp_param_bounds(hidden_size),
            default_params=mlp_default_params(hidden_size),
        )
    
    flat_min_bounds, flat_max_bounds = flatten_bounds(
        spec.param_keys, spec.param_bounds, spec.full_shapes
    )
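
  • For intuition, flattening the bounds amounts to expanding each parameter's (min, max) pair to the flattened size of its shape and concatenating; a hedged sketch (scalar per-key bounds are an assumption):
    import numpy as np

    def flatten_bounds_sketch(param_keys, param_bounds, full_shapes):
        # Expand each key's (min, max) to the flattened size of its shape so the
        # network's output layer can be squashed into [flat_min, flat_max].
        mins, maxs = [], []
        for key in param_keys:
            lo, hi = param_bounds[key]
            size = int(np.prod(full_shapes[key]))
            mins.append(np.full(size, lo, dtype=np.float32))
            maxs.append(np.full(size, hi, dtype=np.float32))
        return np.concatenate(mins), np.concatenate(maxs)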
    

3) Load datasets (load_feature_dataset_group, lines ~1130)

  • Training/validation folders are resolved relative to the repo root and checked for existence.
  • Each dataset folder is converted to a FeatureDataset (see pct.optimization.network_training) of features, controls, initial states, and targets; optional caching (--dataset-cache, --val-dataset-cache) avoids recomputation.
  • Feature mode controls whether inputs are concatenated with state contexts; sample limits and timestep subsampling (--samples, --dt) are applied here.
    dataset = load_feature_dataset_group(
        dataset_paths, spec, dataset_cache, args.refresh_cache,
        args.samples, args.dt, args.feature_mode, "training",
    )
    val_dataset = load_feature_dataset_group(
        val_dataset_paths, spec, val_dataset_cache, args.refresh_cache,
        args.val_samples, args.dt, args.feature_mode, "validation",
    )
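
  • The caching behaviour amounts to "load the cache if it exists and --refresh-cache is not set, otherwise build and persist"; a minimal sketch assuming a pickle-based cache file (the real format and helper names may differ):
    import pickle
    from pathlib import Path

    def load_or_build_dataset(cache_path, build_fn, refresh=False):
        # Reuse a previously built FeatureDataset unless a refresh was requested.
        if cache_path is not None:
            cache_path = Path(cache_path)
            if cache_path.exists() and not refresh:
                with cache_path.open("rb") as f:
                    return pickle.load(f)
        dataset = build_fn()  # expensive feature extraction happens here
        if cache_path is not None:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with cache_path.open("wb") as f:
                pickle.dump(dataset, f)
        return dataset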
    

4) Create run directory (lines ~1170)

  • Artifacts live under results/<model>/run_<timestamp>[_ablation_*]; a numeric suffix avoids collisions.
    results_root = REPO_ROOT / "results" / spec.name
    run_timestamp = time.strftime("%Y%m%d-%H%M%S")
    run_basename = f"run_{run_timestamp}{feature_suffix}"
    suffix = 1  # numeric suffix used only if the timestamped name already exists
    candidate_dir = results_root / run_basename
    while candidate_dir.exists():
        candidate_dir = results_root / f"{run_basename}_{suffix:02d}"
        suffix += 1
    candidate_dir.mkdir(parents=True, exist_ok=False)
    output_dir = candidate_dir.resolve()
    

5) Initialize logging (lines ~1205)

  • If --wandb-project is provided and wandb is installed, a run is created with config metadata and metrics are declared (train/loss, val/loss, LR, etc.).
    if enable_wandb:
        init_kwargs = {"project": args.wandb_project, "config": config_payload}
        wandb_run = wandb.init(**init_kwargs)
        wandb_run.define_metric("train/loss", step_metric="step")
        wandb_run.define_metric("val/loss", step_metric="epoch")
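
  • The logged keys then line up with these declarations inside the training loop; a hypothetical helper showing the pairing (the script may simply log inline):
    def log_metrics(wandb_run, loss, global_step, val_loss=None, epoch=None):
        # Per-batch training loss is logged against "step"; per-epoch validation
        # loss, when available, is logged against "epoch".
        if wandb_run is None:
            return
        wandb_run.log({"train/loss": float(loss), "step": global_step})
        if val_loss is not None:
            wandb_run.log({"val/loss": float(val_loss), "epoch": epoch})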
    

6) Build model and optimizer (lines ~1230)

  • Base dynamics params come from the spec; a ParameterMLP maps features to bounded parameter vectors. The training helpers live in pct.optimization.network_training.
  • Parameters are initialized with a dummy input matching the feature width; the AdamW optimizer is created at the initial learning rate.
  • Training and eval loss functions are derived from the spec; make_network_step_fn wires the model, loss, integrator (integrate_rk4), and optimizer.
    base_params = spec.build_base_params(args.dt)
    model = ParameterMLP(
        hidden_sizes=args.model_hidden,
        min_bounds=flat_min_bounds,
        max_bounds=flat_max_bounds,
    )
    params = model.init(rng, jnp.zeros((1, dataset.features.shape[1])))
    optimizer = optax.adamw(args.lr)
    opt_state = optimizer.init(params)
    
    step_fn = optnet.make_network_step_fn(
        model.apply, train_loss_fn, spec.param_keys, base_params,
        spec.dynamics_fn, integrate_rk4, optimizer, param_shapes=spec.full_shapes,
    )
    eval_loss_fn = make_eval_loss_fn(model, base_params, spec)
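
  • Conceptually, the step function closes over the model, loss, integrator, and optimizer and performs one gradient update per call; a simplified sketch (signatures and the unflatten helper are assumptions, the real implementation lives in pct.optimization.network_training):
    import jax
    import optax

    def make_step_fn_sketch(apply_fn, loss_fn, unflatten_fn, dynamics_fn, integrator, optimizer):
        # apply_fn maps (params, features) -> bounded flat parameter vectors;
        # unflatten_fn reshapes them into the spec's parameter structure.
        def batch_loss(params, features, batch):
            predicted = unflatten_fn(apply_fn(params, features))
            rollout = integrator(dynamics_fn, batch.initial_states, batch.controls, predicted)
            return loss_fn(rollout, batch.targets)

        def step(params, opt_state, features, batch):
            loss, grads = jax.value_and_grad(batch_loss)(params, features, batch)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss

        return step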
    

7) Training loop with curriculum and validation (lines ~1270)

  • Data is shuffled each epoch; mini-batches are sliced and optionally truncated in time if curriculum is active (--curriculum-batches, --curriculum-timesteps). Mini-batch helpers are in pct.optimization.batching.
  • step_fn performs forward sim, loss, backprop, and optimizer update; losses and horizons are tracked.
  • Validation loss is computed each epoch when provided; the best params are tracked separately.
  • LR plateau handling: a patience counter triggers LR decay (--lr-patience, --lr-factor, --min-lr), and the optimizer and step function are recreated.
  • Metrics are printed and optionally logged to Weights & Biases every step/epoch.
    for epoch in range(args.epochs):
        rng, perm_key = jax.random.split(rng)
        perm = jax.random.permutation(perm_key, num_samples)
        shuffled_features = dataset.features[perm]
        ...
        for start in range(0, num_samples, args.batch_size):
            batch_features = shuffled_features[start:end]
            batch_dataset = FeatureDataset(...slice...)
            effective_dataset = (
                truncate_dataset_timesteps(batch_dataset, args.curriculum_timesteps)
                if curriculum_enabled and global_step < args.curriculum_batches
                else batch_dataset
            )
            params, opt_state, loss, _ = step_fn(
                params, opt_state, batch_features, effective_dataset,
            )
            loss_history.append(float(loss) / (end - start))
    
    # LR reduction on plateau
    if patience_counter >= args.lr_patience and current_lr > args.min_lr:
        current_lr = max(current_lr * args.lr_factor, args.min_lr)
        optimizer = optax.adamw(current_lr)
        opt_state = optimizer.init(params)
        step_fn = optnet.make_network_step_fn(..., optimizer, ...)
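
  • Curriculum truncation simply shortens the simulated horizon for early batches; a sketch of what truncate_dataset_timesteps might do (field names and time-axis layout are assumptions):
    import dataclasses

    def truncate_dataset_timesteps_sketch(batch, max_timesteps):
        # Slice the time axis of controls and targets so that early curriculum
        # batches only roll the dynamics out over a short horizon.
        if max_timesteps is None or batch.controls.shape[1] <= max_timesteps:
            return batch
        return dataclasses.replace(
            batch,
            controls=batch.controls[:, :max_timesteps],
            targets=batch.targets[:, :max_timesteps],
        )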
    

8) Summaries, plots, and checkpoints (lines ~1370)

  • Predictions are summarized over the training set in chunks to avoid OOM and printed. This uses utilities from pct.optimization.network_training.
  • Analysis artifacts (scenario plots) are generated for train/val splits under analysis/.
  • Loss curves are saved to PNG and JSON; model params, optimizer state, and (if validation) best params and predicted params are serialized to msgpack/JSON.
  • Final predicted parameters are written to pct_predicted_params.json (or --output); WandB run summaries include best losses and curriculum settings.
    predicted_params_mean = summarize_predictions_chunked(
        model.apply, params_for_prediction, dataset.features,
        spec.param_keys, chunk_size=args.summary_chunk_size,
        param_shapes=spec.full_shapes,
    )
    generate_analysis_artifacts("training", dataset, spec, model, params_for_prediction, ...)
    plt.plot(loss_history); plt.savefig(output_dir / "pct_network_loss.png")
    with model_param_path.open("wb") as f:
        f.write(serialization.to_bytes(params))
    with output_path.open("w") as f:
        json.dump(predicted_params_mean, f, indent=2)
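
  • Chunked summarization keeps peak memory bounded by averaging predictions chunk by chunk; roughly (a sketch, not the library helper):
    import numpy as np

    def summarize_predictions_chunked_sketch(apply_fn, params, features, chunk_size):
        # Average predicted parameter vectors one chunk at a time so the full
        # prediction matrix never has to fit in memory at once.
        totals, count = None, 0
        for start in range(0, features.shape[0], chunk_size):
            preds = np.asarray(apply_fn(params, features[start:start + chunk_size]))
            chunk_sum = preds.sum(axis=0)
            totals = chunk_sum if totals is None else totals + chunk_sum
            count += preds.shape[0]
        return totals / count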
    

Try it yourself

Run a small experiment and tail the logs:

UV_CACHE_DIR=.uv-cache-local uv run --project PCT python PCT/bin/optimize_network.py \
  --dataset datasets/june_training_data/train_redp_1 \
  --epochs 2 --batch-size 64 --lr 1e-3 --model singletrack-augmented

Artifacts will appear under results/singletrack-augmented/run_*; open the loss PNG and analysis plots to inspect the run.