# GRT Reference Implementation & Project Template
The GRT reference implementation uses a Hydra + PyTorch Lightning-based stack on top of the NRDK; use it as a template to get started on a new project.
**Tip**

The included reference implementation can be run out-of-the-box:

- Obtain a copy of the I/Q-1M dataset, and save it (or link it) to `nrdk/grt/data/`.
- Create a virtual environment in `nrdk/grt` with `uv sync`.
- Run with `uv run train.py`; see the hydra config files in `nrdk/grt/config/` for options.
## Pre-Trained Checkpoints
Pre-trained model checkpoints for the GRT reference implementation on the I/Q-1M dataset can also be found here.
With a single GPU, these checkpoints can be reproduced with the following:
The 3D polar occupancy base model is provided as the default configuration (i.e., `sensors=[radar,lidar]`, `transforms@transforms.sample=[radar,lidar3d]`, `objective=lidar3d`, `model/decoder=lidar3d`).
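If you want to compose this configuration programmatically (rather than through command-line overrides), hydra's compose API can be used. The following is a minimal sketch, assuming it is run from the `grt/` directory; the group names are taken from the defaults listed above.

```python
# Minimal sketch: materialize the default (3D polar occupancy) configuration
# with hydra's compose API. Group names follow the defaults quoted above;
# swap the overrides to select a different objective or decoder.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="config"):
    cfg = compose(
        config_name="default",
        overrides=[
            "objective=lidar3d",      # documented default
            "model/decoder=lidar3d",  # documented default
        ],
    )

print(OmegaConf.to_yaml(cfg))  # inspect the fully-composed configuration
```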
**Tip**

If you're not running in a "managed" environment (e.g., Slurm, LSF, AzureML), `nq` is a lightweight way to run jobs in a queue. Just `sudo apt-get install -y nq`, and run with `nq uv run train.py ...`.
## Quick Start
-   Create a new repository, and copy the contents of the `grt/` directory:

        example-project/
        ├── config/
        │   ├── aug/
        │   ...
        ├── grt/
        │   ├── __init__.py
        │   ...
        ├── pyproject.toml
        ├── train.py
        └── train_minimal.py

    **Tip**

    Don't forget to change the `name`, `authors`, and `description`!

-   Set up the `nrdk` dependency (`nrdk[roverd] >= 0.1.5`).

    **Required Extras**

    Make sure you include the `roverd` extra, which installs the following:

    - `roverd`: a dataloader for data collected by the red-rover system
    - `xwr`: radar signal processing for TI mmWave radars

    If using `uv`, uncomment one of the corresponding lines in the supplied `pyproject.toml` (and comment out the included `nrdk = { path = "../" }` line); a quick import sanity check is sketched just after this list.
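To confirm the dependency resolved correctly, you can check that the installed packages import. This is a hedged sketch: it assumes `roverd` and `xwr` are importable as top-level modules, matching the package names listed above.

```python
# Hedged sanity check: confirm the nrdk[roverd] extra installed the expected
# packages. Top-level module names are assumed to match the names listed above.
import importlib.util

for mod in ("nrdk", "roverd", "xwr"):
    found = importlib.util.find_spec(mod) is not None
    print(f"{mod}: {'ok' if found else 'MISSING'}")
```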
## Training Script

The GRT template includes reference training scripts which implement the high-level training and fine-tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.
**Reference Training Script**

```python
"""GRT Reference implementation training script."""

import logging
import os
from collections.abc import Mapping
from time import perf_counter
from typing import Any

import hydra
import torch
import yaml
from lightning.pytorch import callbacks
from omegaconf import DictConfig

from nrdk.config import configure_rich_logging
from nrdk.framework import Result

logger = logging.getLogger("train")


def _load_weights(lightningmodule, path: str, rename: Mapping = {}) -> None:
    weights = Result(path).best if os.path.isdir(path) else path
    lightningmodule.load_weights(weights, rename=rename)


def _get_best(trainer) -> dict[str, Any]:
    for callback in trainer.callbacks:
        if isinstance(callback, callbacks.ModelCheckpoint):
            return {
                "best_k": {
                    os.path.basename(k): v.item()
                    for k, v in callback.best_k_models.items()},
                "best": os.path.basename(callback.best_model_path)
            }
    return {}


def _autoscale_batch_size(batch: int) -> int:
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        batch_new = batch // n_gpus
        logger.info(
            f"Auto-scaling batch size by n_gpus={n_gpus}: "
            f"{batch} -> {batch_new}")
        return batch_new
    return batch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg: DictConfig) -> None:
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('high')
    _log_level = configure_rich_logging(cfg.meta.get("verbose", logging.INFO))
    logger.debug(f"Configured with log level: {_log_level}")

    if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
        logger.error("Must set `meta.name` and `meta.version` in the config.")
        return

    cfg["datamodule"]["batch_size"] = _autoscale_batch_size(
        cfg["datamodule"]["batch_size"])

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")

    if "base" in cfg:
        _load_weights(lightningmodule, **cfg['base'])
    if cfg["meta"]["compile"]:
        lightningmodule.compile()

    start = perf_counter()
    logger.info(
        f"Start training @ {cfg['meta']['results']}/{cfg['meta']['name']}/"
        f"{cfg['meta']['version']} [t={start:.3f}]")
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        ckpt_path=cfg['meta']['resume'])
    duration = perf_counter() - start
    logger.info(
        f"Training completed in {duration / 60 / 60:.2f}h (={duration:.3f}s).")

    meta: dict[str, Any] = {"duration": duration}
    meta.update(_get_best(trainer))
    logger.info(f"Best checkpoint: {meta.get('best')}")

    meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
    with open(meta_path, 'w') as f:
        yaml.dump(meta, f, sort_keys=False)


if __name__ == "__main__":
    train()
```
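After training finishes, the reference script writes a small `checkpoints.yaml` into the trainer's log directory, recording the run duration and the best checkpoint(s). Below is a hedged sketch of reading it back; the run directory and `checkpoints/` layout are assumptions, so adjust them to your logger configuration.

```python
# Hedged sketch: read the checkpoints.yaml written by the reference training
# script and resolve the best checkpoint. Paths here are hypothetical.
import os
import yaml

log_dir = "results/example/version_0"  # hypothetical run directory
with open(os.path.join(log_dir, "checkpoints.yaml")) as f:
    meta = yaml.safe_load(f)

print(f"training took {meta['duration'] / 3600:.2f}h")
if "best" in meta:
    # "best" is stored as a basename; join it with the checkpoint directory.
    print("best checkpoint:", os.path.join(log_dir, "checkpoints", meta["best"]))
```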
**Minimal Training Script**

```python
"""GRT Reference implementation training script."""

import hydra
import torch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('medium')

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")
    trainer.fit(model=lightningmodule, datamodule=datamodule)


if __name__ == "__main__":
    train()
```
**Compile the Model**

You can invoke the pytorch JIT compiler by setting `meta.compile=True`. Since the pytorch compiler is kind of janky and causes issues with type checkers, you will also need to set
## Evaluation Script
Evaluate a trained model.
Supports three evaluation modes, in order of precedence:
- Enumerated traces: evaluate all traces specified by `--trace`, relative to the `--data-root`.
- Filtered evaluation: evaluate all traces in the configuration (`datamodule/traces/test`) that match the provided `--filter` regex.
- Sample evaluation: evaluate a pseudo-random `--sample` taken from the test set specified in the configuration.

If none of `--trace`, `--filter`, or `--sample` are provided, defaults to evaluating all traces specified in the configuration.
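To illustrate the filtered mode: the `--filter` value is treated as a regular expression over the test trace names from the configuration. A minimal sketch, with made-up trace names and assuming search-style matching:

```python
# Illustration only: how a --filter regex selects traces from the configured
# test split. Trace names are invented; search-style matching is assumed.
import re

test_traces = ["bike/2024-01-10", "night/2024-01-11", "night/2024-02-01"]
pattern = re.compile(r"night/.*")  # e.g. passed as --filter "night/.*"
selected = [t for t in test_traces if pattern.search(t)]
print(selected)  # ['night/2024-01-11', 'night/2024-02-01']
```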
**Tip**

See `Result` for details about the expected structure of the results directory.
**Warning**

Evaluation only supports a single GPU; if multiple GPUs are available, use parallel evaluation instead.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | path to results directory. | *required* |
| `output` | `str \| None` | if specified, write results to this directory instead. | `None` |
| `sample` | `int \| None` | number of samples to evaluate. | `None` |
| `traces` | `list[str] \| None` | explicit list of traces to evaluate. | `None` |
| `filter` | `str \| None` | evaluate all traces matching this regex. | `None` |
| `data_root` | `str \| None` | root dataset directory; if | `None` |
| `device` | `str` | device to use for evaluation. | `'cuda:0'` |
| `compile` | `bool` | whether to compile the model using | `False` |
| `batch` | `int` | batch size. | `32` |
| `workers` | `int` | number of workers for data loading. | `32` |
| `prefetch` | `int` | number of batches to prefetch per worker. | `2` |
| `verbose` | `int` | logging verbosity level. | `INFO` |
Source code in `grt/evaluate.py`
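For reference, here is a hedged sketch of calling the evaluation entry point from Python instead of the command line. The import path and the function name `evaluate` are assumptions; the parameter names mirror the table above.

```python
# Hedged sketch: programmatic use of the evaluation entry point. The function
# name and import path are assumptions; parameters follow the table above.
from grt.evaluate import evaluate  # hypothetical import

evaluate(
    path="results/example/version_0",  # results directory (required)
    filter="night/.*",                 # filtered evaluation over the test split
    device="cuda:0",
    batch=16,
    workers=8,
)
```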