# Training Loop Walkthrough
This tutorial walks through the training pipeline implemented in bin/optimize_network.py, showing how a run is configured, executed, and logged.
## High-level flow
- Parse CLI arguments and seed RNG.
- Resolve the chosen dynamics model and flatten parameter bounds.
- Load and optionally cache training/validation datasets.
- Create a run directory for artifacts.
- Initialize logging (Weights & Biases optional).
- Build the parameter-predicting MLP and optimizer.
- Train with optional curriculum, validation, and learning-rate reduction.
- Summarize predictions, plot/save losses, and write checkpoints/best params.
## Step-by-step

### 1) Configure and seed (parse_args, main, lines ~1050)
- CLI arguments set datasets, model name, feature mode, batch/epoch counts, LR schedule, curriculum, caching, logging, and analysis options.
- The PRNG key is seeded early to make shuffling and initialization reproducible.
```python
def main():
    args = parse_args()
    rng = jax.random.PRNGKey(args.seed)
    if args.summary_chunk_size <= 0:
        raise ValueError("--summary-chunk-size must be positive.")
```
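Because the key is created once and split as needed, reruns with the same `--seed` produce the same shuffles and initial weights. A small standalone illustration of that splitting pattern (not taken from the script itself):

```python
import jax

# The same seed yields the same sequence of per-epoch permutations,
# which is what makes the shuffling below deterministic.
rng = jax.random.PRNGKey(42)
for epoch in range(3):
    rng, perm_key = jax.random.split(rng)
    print(epoch, jax.random.permutation(perm_key, 8))
```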
### 2) Pick dynamics spec (MODEL_SPECS, lines ~1080)
- The script looks up `MODEL_SPECS[--model]` to get parameter keys, shapes, bounds, loss, and dynamics functions. See pct.dynamics.core for the registry and pct.dynamics.singletrack for the default model details.
- MLP dynamics are patched with the requested hidden size before training.
- Bounds are flattened to constrain the network outputs.
```python
try:
    spec = MODEL_SPECS[args.model]
except KeyError as exc:
    available = ", ".join(sorted(MODEL_SPECS.keys()))
    raise ValueError(f"Unknown model '{args.model}'. Available models: {available}") from exc

if spec.name == "mlp-dynamics":
    hidden_size = args.dynamics_hidden_size
    spec = dataclasses.replace(
        spec,
        param_shapes=mlp_param_shapes(hidden_size),
        param_bounds=mlp_param_bounds(hidden_size),
        default_params=mlp_default_params(hidden_size),
    )

flat_min_bounds, flat_max_bounds = flatten_bounds(
    spec.param_keys, spec.param_bounds, spec.full_shapes
)
```
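The walkthrough relies on `flatten_bounds` returning one flat min vector and one flat max vector matching the network's output layout. The sketch below shows one plausible shape of such a helper (illustrative only, using a hypothetical `flatten_bounds_sketch`; the real helper may differ):

```python
import numpy as np

# Hypothetical stand-in for flatten_bounds: broadcast each parameter's
# (min, max) to its flattened size and concatenate in param_keys order.
def flatten_bounds_sketch(param_keys, param_bounds, full_shapes):
    mins, maxs = [], []
    for key in param_keys:
        lo, hi = param_bounds[key]
        size = int(np.prod(full_shapes[key]))
        mins.append(np.full(size, lo, dtype=np.float32))
        maxs.append(np.full(size, hi, dtype=np.float32))
    return np.concatenate(mins), np.concatenate(maxs)

flat_min, flat_max = flatten_bounds_sketch(
    ["mass", "inertia"],
    {"mass": (0.5, 5.0), "inertia": (0.01, 0.5)},
    {"mass": (1,), "inertia": (1,)},
)
```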
### 3) Load datasets (load_feature_dataset_group, lines ~1130)
- Training/validation folders are resolved relative to the repo root and checked for existence.
- Each dataset folder is converted to a `FeatureDataset` (see pct.optimization.network_training) of features, controls, initial states, and targets; optional caching (`--dataset-cache`, `--val-dataset-cache`) avoids recomputation.
- Feature mode controls whether inputs are concatenated with state contexts or not; sample limits and timestep subsampling (`--samples`, `--dt`) are applied here.

```python
dataset = load_feature_dataset_group(
    dataset_paths,
    spec,
    dataset_cache,
    args.refresh_cache,
    args.samples,
    args.dt,
    args.feature_mode,
    "training",
)
val_dataset = load_feature_dataset_group(
    val_dataset_paths,
    spec,
    val_dataset_cache,
    args.refresh_cache,
    args.val_samples,
    args.dt,
    args.feature_mode,
    "validation",
)
```
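The actual `FeatureDataset` is defined in pct.optimization.network_training; as a mental model, it can be pictured as a container of aligned arrays like the sketch below (the field names follow the description above, but the exact shapes and any extra fields are assumptions):

```python
from typing import NamedTuple
import jax.numpy as jnp

# Illustrative stand-in only; the real FeatureDataset may carry extra
# fields/metadata and different shape conventions.
class FeatureDatasetSketch(NamedTuple):
    features: jnp.ndarray        # per-sample inputs fed to the parameter MLP
    controls: jnp.ndarray        # control sequences used in the forward simulation
    initial_states: jnp.ndarray  # rollout start states
    targets: jnp.ndarray         # reference trajectories the rollouts are compared against
```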
### 4) Create run directory (lines ~1170)
- Artifacts live under `results/<model>/run_<timestamp>[_ablation_*]`; a numeric suffix avoids collisions.

```python
results_root = REPO_ROOT / "results" / spec.name
run_timestamp = time.strftime("%Y%m%d-%H%M%S")
run_basename = f"run_{run_timestamp}{feature_suffix}"
candidate_dir = results_root / run_basename
suffix = 1  # initialization omitted in the excerpt; the suffix only appears on collisions
while candidate_dir.exists():
    candidate_dir = results_root / f"{run_basename}_{suffix:02d}"
    suffix += 1
candidate_dir.mkdir(parents=True, exist_ok=False)
output_dir = candidate_dir.resolve()
```
### 5) Initialize logging (lines ~1205)
- If `--wandb-project` is provided and `wandb` is installed, a run is created with config metadata and metrics are declared (`train/loss`, `val/loss`, LR, etc.).

```python
if enable_wandb:
    init_kwargs = {"project": args.wandb_project, "config": config_payload}
    wandb_run = wandb.init(**init_kwargs)
    wandb_run.define_metric("train/loss", step_metric="step")
    wandb_run.define_metric("val/loss", step_metric="epoch")
```
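A minimal, self-contained way to see `define_metric` and per-step logging in action, assuming `wandb` is installed (the project name here is a placeholder, and offline mode avoids needing an account):

```python
import wandb

run = wandb.init(project="pct-demo", mode="offline")  # placeholder project, offline run
run.define_metric("train/loss", step_metric="step")
run.define_metric("val/loss", step_metric="epoch")

# Log a few dummy values against the declared step axes.
for step, loss in enumerate([1.0, 0.7, 0.5]):
    run.log({"train/loss": loss, "step": step})
run.log({"val/loss": 0.6, "epoch": 0})
run.finish()
```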
### 6) Build model and optimizer (lines ~1230)
- Base dynamics params come from the spec; a `ParameterMLP` maps features to bounded parameter vectors. The training helpers live in pct.optimization.network_training.
- Parameters are initialized with a dummy input; AdamW is created at the initial LR.
- Training and eval loss functions are derived from the spec; `make_network_step_fn` wires the model, loss, integrator (`integrate_rk4`), and optimizer.

```python
base_params = spec.build_base_params(args.dt)
model = ParameterMLP(
    hidden_sizes=args.model_hidden,
    min_bounds=flat_min_bounds,
    max_bounds=flat_max_bounds,
)
params = model.init(rng, jnp.zeros((1, dataset.features.shape[1])))
optimizer = optax.adamw(args.lr)
opt_state = optimizer.init(params)
step_fn = optnet.make_network_step_fn(
    model.apply,
    train_loss_fn,
    spec.param_keys,
    base_params,
    spec.dynamics_fn,
    integrate_rk4,
    optimizer,
    param_shapes=spec.full_shapes,
)
eval_loss_fn = make_eval_loss_fn(model, base_params, spec)
```
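`ParameterMLP`'s internals live in pct.optimization.network_training; one common way to keep a network's outputs inside `[min_bounds, max_bounds]` is to squash the final layer with a sigmoid and rescale into the bounds, as in this Flax sketch (an assumption about the approach, not the actual module):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical bounded parameter head; the real ParameterMLP may differ.
class BoundedParamMLPSketch(nn.Module):
    hidden_sizes: tuple
    min_bounds: jnp.ndarray  # shape (param_dim,)
    max_bounds: jnp.ndarray  # shape (param_dim,)

    @nn.compact
    def __call__(self, x):
        for size in self.hidden_sizes:
            x = nn.relu(nn.Dense(size)(x))
        x = nn.Dense(self.min_bounds.shape[0])(x)
        # Sigmoid squashes to (0, 1); rescale into the per-parameter bounds.
        return self.min_bounds + (self.max_bounds - self.min_bounds) * nn.sigmoid(x)

sketch_model = BoundedParamMLPSketch(
    hidden_sizes=(64, 64),
    min_bounds=jnp.array([0.5, 0.01]),
    max_bounds=jnp.array([5.0, 0.5]),
)
sketch_params = sketch_model.init(jax.random.PRNGKey(0), jnp.zeros((1, 16)))
```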
### 7) Training loop with curriculum and validation (lines ~1270)
- Data is shuffled each epoch; mini-batches are sliced and optionally truncated in time if curriculum is active (`--curriculum-batches`, `--curriculum-timesteps`). Mini-batch helpers are in pct.optimization.batching.
- `step_fn` performs the forward simulation, loss, backprop, and optimizer update; losses and horizons are tracked.
- Validation loss is computed each epoch when provided; the best params are tracked separately.
- LR plateau handling: a patience counter triggers LR decay (`--lr-patience`, `--lr-factor`, `--min-lr`) and recreates the optimizer/step function.
- Metrics are printed and optionally logged to Weights & Biases every step/epoch.
```python
for epoch in range(args.epochs):
    rng, perm_key = jax.random.split(rng)
    perm = jax.random.permutation(perm_key, num_samples)
    shuffled_features = dataset.features[perm]
    ...
    for start in range(0, num_samples, args.batch_size):
        end = min(start + args.batch_size, num_samples)
        batch_features = shuffled_features[start:end]
        batch_dataset = FeatureDataset(...)  # each field sliced to [start:end]
        effective_dataset = (
            truncate_dataset_timesteps(batch_dataset, args.curriculum_timesteps)
            if curriculum_enabled and global_step < args.curriculum_batches
            else batch_dataset
        )
        params, opt_state, loss, _ = step_fn(
            params,
            opt_state,
            batch_features,
            effective_dataset,
        )
        loss_history.append(float(loss) / (end - start))

    # LR reduction on plateau
    if patience_counter >= args.lr_patience and current_lr > args.min_lr:
        current_lr = max(current_lr * args.lr_factor, args.min_lr)
        optimizer = optax.adamw(current_lr)
        opt_state = optimizer.init(params)
        step_fn = optnet.make_network_step_fn(..., optimizer, ...)
```
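The excerpt shows the decay itself; what advances `patience_counter` is the usual plateau bookkeeping driven by the validation loss, sketched below (the pattern and names are assumptions consistent with the description above, not verbatim from the script):

```python
import math

def update_plateau_state(val_loss, best_val_loss, patience_counter):
    """Sketch: reset patience when validation improves, otherwise count
    epochs without improvement until the LR reduction above fires."""
    if val_loss < best_val_loss:
        return val_loss, 0, True   # improvement: also snapshot the best params here
    return best_val_loss, patience_counter + 1, False

best_val, patience = math.inf, 0
for epoch_val_loss in [1.0, 0.8, 0.85, 0.9, 0.82]:
    best_val, patience, improved = update_plateau_state(epoch_val_loss, best_val, patience)
    print(epoch_val_loss, best_val, patience, improved)
```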
### 8) Summaries, plots, and checkpoints (lines ~1370)
- Predictions are summarized over the training set in chunks to avoid OOM and printed. This uses utilities from pct.optimization.network_training.
- Analysis artifacts (scenario plots) are generated for train/val splits under `analysis/`.
- Loss curves are saved to PNG and JSON; model params, optimizer state, and (if validation is used) best params and predicted params are serialized to msgpack/JSON.
- Final predicted parameters are written to `pct_predicted_params.json` (or `--output`); WandB run summaries include best losses and curriculum settings.

```python
predicted_params_mean = summarize_predictions_chunked(
    model.apply,
    params_for_prediction,
    dataset.features,
    spec.param_keys,
    chunk_size=args.summary_chunk_size,
    param_shapes=spec.full_shapes,
)
generate_analysis_artifacts("training", dataset, spec, model, params_for_prediction, ...)

plt.plot(loss_history)
plt.savefig(output_dir / "pct_network_loss.png")

with model_param_path.open("wb") as f:
    f.write(serialization.to_bytes(params))
with output_path.open("w") as f:
    json.dump(predicted_params_mean, f, indent=2)
```
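To use the saved checkpoint later, `flax.serialization.from_bytes` restores the parameters into a template pytree of the same structure; a usage sketch, reusing `model`, `dataset`, and `model_param_path` from the snippet above:

```python
import jax
import jax.numpy as jnp
from flax import serialization

# Rebuild a template pytree with the same structure as the saved params,
# then overwrite its values from the msgpack checkpoint written above.
feature_dim = dataset.features.shape[1]  # same feature width used at training time
template = model.init(jax.random.PRNGKey(0), jnp.zeros((1, feature_dim)))
with model_param_path.open("rb") as f:
    restored_params = serialization.from_bytes(template, f.read())
```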
## Try it yourself
Run a small experiment and tail the logs:
```bash
UV_CACHE_DIR=.uv-cache-local uv run --project PCT python PCT/bin/optimize_network.py \
  --dataset datasets/june_training_data/train_redp_1 \
  --epochs 2 --batch-size 64 --lr 1e-3 --model singletrack-augmented
```
Artifacts are written under `results/singletrack-augmented/run_*`; open the loss PNG and analysis plots to inspect the run.
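Once the run finishes, the predicted parameters can be inspected directly from the JSON named above; a small sketch, assuming the default output location inside the newest run directory (use your `--output` path instead if you overrode it):

```python
import json
from pathlib import Path

# Pick the most recent run directory and print the predicted parameters.
runs = sorted(Path("results/singletrack-augmented").glob("run_*"), key=lambda p: p.stat().st_mtime)
with (runs[-1] / "pct_predicted_params.json").open() as f:
    predicted = json.load(f)
for key, value in predicted.items():
    print(key, value)
```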