
GRT Reference Implementation & Project Template

The GRT reference implementation uses a Hydra + PyTorch Lightning-based stack on top of the NRDK; use this reference implementation to get started on a new project.

Tip

The included reference implementation can be run out-of-the-box:

  1. Obtain a copy of the I/Q-1M dataset, and save it (or link it) to nrdk/grt/data/.
  2. Create a virtual environment in nrdk/grt with uv sync.
  3. Run with uv run train.py; see the Hydra config files in nrdk/grt/config/ for options.

Pre-Trained Checkpoints

Pre-trained model checkpoints for the GRT reference implementation on the I/Q-1M dataset can also be found here.

With a single GPU, these checkpoints can be reproduced with the following:

The 3D polar occupancy base model is provided as the default configuration (i.e., sensors=[radar,lidar], transforms@transforms.sample=[radar,lidar3d], objective=lidar3d, model/decoder=lidar3d).

uv run train.py meta.name=base meta.version=small size=small

uv run train.py meta.name=occ2d meta.version=small size=small \
    +base=occ3d_to_occ2d \
    sensors=[radar,lidar] \
    transforms@transforms.sample=[radar,lidar2d] \
    objective=lidar2d \
    model/decoder=lidar2d

uv run train.py meta.name=semseg meta.version=small size=small \
    +base=occ3d_to_semseg \
    sensors=[radar,semseg] \
    transforms@transforms.sample=[radar,semseg] \
    objective=semseg \
    model/decoder=semseg

uv run train.py meta.name=vel meta.version=small size=small \
    +base=occ3d_to_vel \
    sensors=[radar,pose] \
    transforms@transforms.sample=[radar,vel] \
    objective=vel \
    model/decoder=vel

Tip

If you're not running in a "managed" environment (e.g., Slurm, LSF, AzureML), nq is a lightweight way to run jobs in a queue. Just sudo apt-get install -y nq, and run with nq uv run train.py ....

Quick Start

  1. Create a new repository, and copy the contents of the grt/ directory:

    example-project/
    ├── config/
    │   ├── aug/
    │   ...
    ├── grt/
    │   ├── __init__.py
    │   ...
    ├── pyproject.toml
    ├── train.py
    └── train_minimal.py
    

    Tip

    Don't forget to change the name, authors, and description!

  2. Set up the nrdk dependency (nrdk[roverd] >= 0.1.5).

    Required Extras

    Make sure you include the roverd extra when setting up the nrdk dependency.

    If using uv, uncomment one of the corresponding lines in the supplied pyproject.toml (and comment out the included nrdk = { path = "../" } line):

    [tool.uv.sources]
    nrdk = { git = "ssh://git@github.com/radarml/nrdk.git" }
    

    Alternatively, after adding nrdk as a submodule with git submodule add git@github.com:RadarML/nrdk.git:

    [tool.uv.sources]
    nrdk = { path = "./nrdk" }
    

Training Script

The GRT template includes reference training scripts that implement the high-level training and fine-tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.

Reference Training Script
grt/train.py
"""GRT Reference implementation training script."""

import logging
import os
from collections.abc import Mapping
from time import perf_counter
from typing import Any

import hydra
import torch
import yaml
from lightning.pytorch import callbacks
from omegaconf import DictConfig

from nrdk.config import configure_rich_logging
from nrdk.framework import Result

logger = logging.getLogger("train")


def _load_weights(lightningmodule, path: str, rename: Mapping = {}) -> None:
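    # Accept either a results directory (in which case its best checkpoint is
    # used) or a path to a specific checkpoint file.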
    weights = Result(path).best if os.path.isdir(path) else path
    lightningmodule.load_weights(weights, rename=rename)


def _get_best(trainer) -> dict[str, Any]:
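    # Summarize the best checkpoint(s) tracked by the ModelCheckpoint callback,
    # if one is attached to the trainer.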
    for callback in trainer.callbacks:
        if isinstance(callback, callbacks.ModelCheckpoint):
            return {
                "best_k": {
                    os.path.basename(k): v.item()
                    for k, v in callback.best_k_models.items()},
                "best": os.path.basename(callback.best_model_path)
            }
    return {}


def _autoscale_batch_size(batch: int) -> int:
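    # Interpret the configured batch size as a global batch size, and split it
    # evenly across GPUs (i.e., a per-device batch size) when more than one is
    # available.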
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        batch_new = batch // n_gpus
        logger.info(
            f"Auto-scaling batch size by n_gpus={n_gpus}: "
            f"{batch} -> {batch_new}")
        return batch_new
    return batch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg: DictConfig) -> None:
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('high')

    _log_level = configure_rich_logging(cfg.meta.get("verbose", logging.INFO))
    logger.debug(f"Configured with log level: {_log_level}")

    if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
        logger.error("Must set `meta.name` and `meta.version` in the config.")
        return
    cfg["datamodule"]["batch_size"] = _autoscale_batch_size(
        cfg["datamodule"]["batch_size"])

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")
    if "base" in cfg:
        _load_weights(lightningmodule, **cfg['base'])

    if cfg["meta"]["compile"]:
        lightningmodule.compile()

    start = perf_counter()
    logger.info(
        f"Start training @ {cfg['meta']['results']}/{cfg['meta']['name']}/"
        f"{cfg['meta']['version']} [t={start:.3f}]")
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        ckpt_path=cfg['meta']['resume'])
    duration = perf_counter() - start
    logger.info(
        f"Training completed in {duration / 60 / 60:.2f}h (={duration:.3f}s).")

    meta: dict[str, Any] = {"duration": duration}
    meta.update(_get_best(trainer))
    logger.info(f"Best checkpoint: {meta.get('best')}")
    meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
    with open(meta_path, 'w') as f:
        yaml.dump(meta, f, sort_keys=False)


if __name__ == "__main__":
    train()
Minimal Training Script
grt/train_minimal.py
"""GRT Reference implementation training script."""

import hydra
import torch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('medium')

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")
    trainer.fit(model=lightningmodule, datamodule=datamodule)

if __name__ == "__main__":
    train()

Compile the Model

You can invoke the PyTorch JIT compiler by setting meta.compile=True. Because the PyTorch compiler does not interact well with runtime type checkers, you will also need to set

export JAXTYPING_DISABLE=1
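
If you would rather set this from Python than from the shell, one option is a minimal sketch like the following (assuming it runs before any jaxtyping-instrumented modules are imported):

import os

# Set before importing the training code so the jaxtyping runtime checks are
# disabled from the start.
os.environ["JAXTYPING_DISABLE"] = "1"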

Evaluation Script

Evaluate a trained model.

Supports three evaluation modes, in order of precedence:

  1. Enumerated traces: evaluate all traces specified by --trace, relative to the --data-root.
  2. Filtered evaluation: evaluate all traces in the configuration (datamodule/traces/test) that match the provided --filter regex.
  3. Sample evaluation: evaluate a pseudo-random --sample taken from the test set specified in the configuration.

If none of --trace, --filter, or --sample are provided, defaults to evaluating all traces specified in the configuration.

Tip

See Result for details about the expected structure of the results directory.

Warning

Only supports using a single GPU; if multiple GPUs are available, use parallel evaluation instead.

Parameters:

  path (str): path to results directory. Required.
  output (str | None): if specified, write results to this directory instead. Default: None.
  sample (int | None): number of samples to evaluate. Default: None.
  traces (list[str] | None): explicit list of traces to evaluate. Default: None.
  filter (str | None): evaluate all traces matching this regex. Default: None.
  data_root (str | None): root dataset directory; if None, use the path specified in meta/dataset in the config. Default: None.
  device (str): device to use for evaluation. Default: 'cuda:0'.
  compile (bool): whether to compile the model using torch.compile. Default: False.
  batch (int): batch size. Default: 32.
  workers (int): number of workers for data loading. Default: 32.
  prefetch (int): number of batches to prefetch per worker. Default: 2.
  verbose (int): logging verbosity level. Default: INFO.
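
Beyond the CLI flags above, the function can also be called directly from Python. A minimal sketch (the results path is hypothetical, and the import assumes the working directory is grt/):

from evaluate import evaluate

# Evaluate a pseudo-random sample of 100 taken from the configured test set;
# results are written under <path>/eval by default.
evaluate("results/base/small", sample=100, device="cuda:0", batch=32)
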
Source code in grt/evaluate.py
def evaluate(
    path: str, /, output: str | None = None, sample: int | None = None,
    traces: list[str] | None = None, filter: str | None = None,
    data_root: str | None = None,
    device: str = "cuda:0", compile: bool = False,
    batch: int = 32, workers: int = 32, prefetch: int = 2,
    verbose: int = logging.INFO
) -> None:
    """Evaluate a trained model.

    Supports three evaluation modes, in order of precedence:

    1. Enumerated traces: evaluate all traces specified by `--trace`, relative
        to the `--data-root`.
    2. Filtered evaluation: evaluate all traces in the configuration
        (`datamodule/traces/test`) that match the provided `--filter` regex.
    3. Sample evaluation: evaluate a pseudo-random `--sample` taken from
        the test set specified in the configuration.

    If none of `--trace`, `--filter`, or `--sample` are provided, defaults to
    evaluating all traces specified in the configuration.

    !!! tip

        See [`Result`][nrdk.framework.Result] for details about the expected
        structure of the results directory.

    !!! warning

        Only supports using a single GPU; if multiple GPUs are available,
        use parallel evaluation instead.

    Args:
        path: path to results directory.
        output: if specified, write results to this directory instead.
        sample: number of samples to evaluate.
        traces: explicit list of traces to evaluate.
        filter: evaluate all traces matching this regex.
        data_root: root dataset directory; if `None`, use the path specified
            in `meta/dataset` in the config.
        device: device to use for evaluation.
        compile: whether to compile the model using `torch.compile`.
        batch: batch size.
        workers: number of workers for data loading.
        prefetch: number of batches to prefetch per worker.
        verbose: logging verbosity level.
    """
    torch.set_float32_matmul_precision('high')

    configure_rich_logging(verbose)
    logger.debug(f"Configured with log level: {verbose}")

    result = Result(path)
    cfg = result.config()
    if sample is not None:
        cfg["datamodule"]["subsample"]["test"] = sample

    if output is None:
        output = os.path.join(path, "eval")

    if data_root is None:
        data_root = cfg["meta"]["dataset"]
        if data_root is None:
            raise ValueError(
                "`--data_root` must be specified if `meta/dataset` is not set "
                "in the config.")
    else:
        cfg["meta"]["dataset"] = data_root

    cfg["datamodule"]["batch_size"] = batch
    cfg["datamodule"]["num_workers"] = workers
    cfg["datamodule"]["prefetch_factor"] = prefetch

    transforms = hydra.utils.instantiate(cfg["transforms"])
    lightningmodule = hydra.utils.instantiate(
        cfg["lightningmodule"], transforms=transforms).to(device)
    lightningmodule.load_weights(result.best)

    if compile:
        lightningmodule.compile()

    dataloaders = _get_dataloaders(
        cfg, data_root, transforms,
        traces=traces, filter=filter, sample=sample)

    start = time.perf_counter()
    for trace, dl_constructor in dataloaders.items():
        dataloader = dl_constructor()
        _evaluate_trace(dataloader, lightningmodule, device, trace, output)
    logger.info(f"Evaluation completed in {time.perf_counter() - start:.3f}s.")