GRT Reference Implementation & Project Template¶
The GRT reference implementation uses a hydra + pytorch lightning-based stack on top of the NRDK; use this reference implementation to get started on a new project.
Tip
The included reference implementation can be run out-of-the-box:
- Obtain a copy of the I/Q-1M dataset, and save it (or link it) to `nrdk/grt/data/`.
- Create a virtual environment in `nrdk/grt` with `uv sync`.
- Run with `uv run train.py`; see the hydra config files in `nrdk/grt/config/` for options.
Quick Start¶
1. Create a new repository, and copy the contents of the `grt/` directory:

       example-project/
       ├── config/
       │   ├── aug/
       │   ...
       ├── grt/
       │   ├── __init__.py
       │   ...
       ├── pyproject.toml
       ├── train.py
       └── train_minimal.py
Tip
Don't forget to change the `name`, `authors`, and `description`!

2. Set up the `nrdk` dependency.

Required Extras
Make sure you include the `roverd` extra, which installs the following:

- `roverd`: a dataloader for data collected by the red-rover system
- `xwr`: radar signal processing for TI mmWave radars
If using `uv`, uncomment one of the corresponding lines in the supplied `pyproject.toml` (and comment out the included `nrdk = { path = "../" }` line):
Training Script¶
The GRT template includes reference training scripts which can be used for high level training and fine tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.
Reference Training Script
"""GRT reference implementation training script."""
import logging
import os
from time import perf_counter
from typing import Any
import hydra
import torch
import yaml
from lightning.pytorch import callbacks
@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation.

    Builds the transforms, datamodule, lightning module, and trainer from
    the hydra config, optionally warm-starts from base weights, runs
    ``trainer.fit``, and writes checkpoint metadata next to the run logs.
    """
    torch.set_float32_matmul_precision('high')

    def _build(key, *args, **kwargs):
        # Instantiate whatever object the `key` config subtree describes.
        return hydra.utils.instantiate(
            cfg[key], *args, _convert_="all", **kwargs)

    # A run must be explicitly named and versioned before training starts.
    meta_cfg = cfg["meta"]
    if meta_cfg["name"] is None or meta_cfg["version"] is None:
        logging.error("Must set `meta.name` and `meta.version` in the config.")
        return

    transforms = _build("transforms")
    datamodule = _build("datamodule", transforms=transforms)
    lightningmodule = _build("lightningmodule", transforms=transforms)
    trainer = _build("trainer")

    # Optional fine-tuning: load base weights, renaming keys as configured.
    if "base" in cfg:
        base = cfg['base']
        lightningmodule.load_weights(
            base['path'], rename=base.get('rename', {}))

    t0 = perf_counter()
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        ckpt_path=cfg['meta']['resume'])
    elapsed = perf_counter() - t0

    # Record wall-clock duration plus the best checkpoints (if a
    # ModelCheckpoint callback is attached) for later lookup.
    summary: dict[str, Any] = {"duration": elapsed}
    checkpoint = next(
        (cb for cb in trainer.callbacks
         if isinstance(cb, callbacks.ModelCheckpoint)),
        None)
    if checkpoint is not None:
        summary["best_k"] = {
            os.path.basename(k): v.item()
            for k, v in checkpoint.best_k_models.items()}
        summary["best"] = os.path.basename(checkpoint.best_model_path)

    summary_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
    with open(summary_path, 'w') as f:
        yaml.dump(summary, f, sort_keys=False)


if __name__ == "__main__":
    train()
Minimal Training Script
"""GRT Reference implementation training script."""
import hydra
import torch
@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation.

    Minimal control flow: instantiate each component from the hydra
    config, then hand everything to ``trainer.fit``.
    """
    torch.set_float32_matmul_precision('medium')

    # Instantiate each component directly from its config subtree.
    instantiate = hydra.utils.instantiate
    transforms = instantiate(cfg["transforms"], _convert_="all")
    datamodule = instantiate(
        cfg["datamodule"], _convert_="all", transforms=transforms)
    lightningmodule = instantiate(
        cfg["lightningmodule"], _convert_="all", transforms=transforms)
    trainer = instantiate(cfg["trainer"], _convert_="all")

    trainer.fit(model=lightningmodule, datamodule=datamodule)


if __name__ == "__main__":
    train()