GRT Reference Implementation & Project Template¶
The GRT reference implementation uses a hydra + pytorch lightning-based stack on top of the NRDK; use this reference implementation to get started on a new project.
Tip: The included reference implementation can be run out-of-the-box:
- Obtain a copy of the I/Q-1M, and save it (or link it) to `nrdk/grt/data/`.
- Create a virtual environment in `nrdk/grt` with `uv sync`.
- Run with `uv run train.py`; see the hydra config files in `nrdk/grt/config/` for options.
Quick Start¶
- Create a new repository, and copy the contents of the `grt/` directory:

  ```text
  example-project/
  ├── config/
  │   ├── aug/
  │   ...
  ├── grt/
  │   ├── __init__.py
  │   ...
  ├── pyproject.toml
  ├── train.py
  └── train_minimal.py
  ```

  Tip: Don't forget to change the `name`, `authors`, and `description`!
- Set up the `nrdk` dependency (`nrdk[roverd] >= 0.1.5`).

  Required extras: make sure you include the `roverd` extra, which installs the following:

  - `roverd`: a dataloader for data collected by the red-rover system
  - `xwr`: radar signal processing for TI mmWave radars
If using `uv`, uncomment one of the corresponding lines in the supplied `pyproject.toml` (and comment out the included `nrdk = { path = "../" }` line):
Training Script¶
The GRT template includes reference training scripts which can be used for high level training and fine tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.
Reference Training Script
"""GRT Reference implementation training script."""
import logging
import os
from collections.abc import Mapping
from time import perf_counter
from typing import Any
import hydra
import torch
import yaml
from lightning.pytorch import callbacks
from omegaconf import DictConfig
from rich.logging import RichHandler
from nrdk.framework import Result
logger = logging.getLogger("train")
def _configure_logging(cfg: DictConfig) -> None:
    """Route all logging through a single rich console handler.

    The log level is read from ``cfg.meta.verbose``, falling back to
    ``logging.INFO`` when unset. Any previously-installed handlers on the
    root logger are removed first so output is not duplicated.
    """
    level = cfg.meta.get("verbose", logging.INFO)

    handler = RichHandler(markup=True)
    handler.setFormatter(logging.Formatter(
        "[orange1]%(name)s:[/orange1] %(message)s"))

    root = logging.getLogger()
    root.setLevel(level)
    root.handlers.clear()
    root.addHandler(handler)

    logger.debug(f"Configured with log level: {level}")
def _load_weights(lightningmodule, path: str, rename: Mapping = {}) -> None:
weights = Result(path).best if os.path.isdir(path) else path
lightningmodule.load_weights(weights, rename=rename)
def _get_best(trainer) -> dict[str, Any]:
for callback in trainer.callbacks:
if isinstance(callback, callbacks.ModelCheckpoint):
return {
"best_k": {
os.path.basename(k): v.item()
for k, v in callback.best_k_models.items()},
"best": os.path.basename(callback.best_model_path)
}
return {}
@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg: DictConfig) -> None:
    """Train a model using the GRT reference implementation.

    Instantiates transforms, datamodule, lightning module, and trainer from
    the hydra config, optionally warm-starts from base weights, runs
    ``trainer.fit``, and records run metadata (duration, best checkpoints)
    to ``checkpoints.yaml`` in the trainer's log directory.
    """
    torch.set_float32_matmul_precision('high')
    _configure_logging(cfg)

    # `meta.name` / `meta.version` determine the results location; refuse to
    # run without them rather than writing to an anonymous directory.
    if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
        logger.error("Must set `meta.name` and `meta.version` in the config.")
        return

    def _inst(path, *args, **kwargs):
        # Instantiate a config subtree; `_convert_="all"` turns omegaconf
        # containers into plain dicts/lists for downstream code.
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")

    # Optionally warm-start from pretrained weights.
    if "base" in cfg:
        _load_weights(lightningmodule, **cfg["base"])

    start = perf_counter()
    # Inner quotes must differ from the outer f-string quotes: same-quote
    # nesting is a syntax error before Python 3.12 (PEP 701).
    logger.info(
        f"Start training @ {cfg['meta']['results']}/{cfg['meta']['name']}/"
        f"{cfg['meta']['version']} [t={start:.3f}]")
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        ckpt_path=cfg["meta"]["resume"])
    duration = perf_counter() - start
    logger.info(
        f"Training completed in {duration / 60 / 60:.2f}h (={duration:.3f}s).")

    # Persist run metadata next to the training logs.
    meta: dict[str, Any] = {"duration": duration}
    meta.update(_get_best(trainer))
    logger.info(f"Best checkpoint: {meta.get('best')}")
    meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
    with open(meta_path, 'w') as f:
        yaml.dump(meta, f, sort_keys=False)
# Hydra handles CLI parsing; `train()` receives the composed config.
if __name__ == "__main__":
    train()
Minimal Training Script
"""GRT Reference implementation training script."""
import hydra
import torch
@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation."""
    # Trade a little float32 matmul precision for speed on tensor cores.
    torch.set_float32_matmul_precision('medium')

    def _instantiate(key, *args, **kwargs):
        # Build the object described by the `key` subtree of the config;
        # `_convert_="all"` yields plain python containers.
        return hydra.utils.instantiate(
            cfg[key], _convert_="all", *args, **kwargs)

    transforms = _instantiate("transforms")
    datamodule = _instantiate("datamodule", transforms=transforms)
    lightningmodule = _instantiate("lightningmodule", transforms=transforms)
    trainer = _instantiate("trainer")

    trainer.fit(model=lightningmodule, datamodule=datamodule)
# Hydra handles CLI parsing; `train()` receives the composed config.
if __name__ == "__main__":
    train()
Evaluation Script¶
Evaluate a trained model.
Supports three evaluation modes, in order of precedence:
- Enumerated traces: evaluate all traces specified by `--trace`, relative to the `--data-root`.
- Filtered evaluation: evaluate all traces in the configuration (`datamodule/traces/test`) that match the provided `--filter` regex.
- Sample evaluation: evaluate a pseudo-random `--sample` taken from the test set specified in the configuration.
If none of --trace, --filter, or --sample are provided, defaults to
evaluating all traces specified in the configuration.
Tip
See Result for details about the expected
structure of the results directory.
Warning
Only supports using a single GPU; if multiple GPUs are available, use parallel evaluation instead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | path to results directory. | required |
| `output` | `str \| None` | if specified, write results to this directory instead. | `None` |
| `sample` | `int \| None` | number of samples to evaluate. | `None` |
| `traces` | `list[str] \| None` | explicit list of traces to evaluate. | `None` |
| `filter` | `str \| None` | evaluate all traces matching this regex. | `None` |
| `data_root` | `str \| None` | root dataset directory; if `None`, the path from the configuration is used. | `None` |
| `device` | `str` | device to use for evaluation. | `'cuda:0'` |
| `batch` | `int` | batch size. | `32` |
| `workers` | `int` | number of workers for data loading. | `32` |
| `prefetch` | `int` | number of batches to prefetch per worker. | `2` |
Source code in `grt/evaluate.py` (lines 75–200).