Skip to content

GRT Reference Implementation & Project Template

The GRT reference implementation uses a hydra + pytorch lightning-based stack on top of the NRDK; use this reference implementation to get started on a new project.

Tip

The included reference implementation can be run out-of-the-box:

  1. Obtain a copy of the I/Q-1M dataset, and save it (or link it) to nrdk/grt/data/.
  2. Create a virtual environment in nrdk/grt with uv sync.
  3. Run with uv run train.py; see the hydra config files in nrdk/grt/config/ for options.

Quick Start

  1. Create a new repository, and copy the contents of the grt/ directory:

    example-project/
    ├── config/
    │   ├── aug/
    |   ...
    ├── grt/
    │   ├── __init__.py
    |   ...
    ├── pyproject.toml
    ├── train.py
    └── train_minimal.py
    

    Tip

    Don't forget to change the name, authors, and description!

  2. Set up the nrdk dependency.

    Required Extras

    Make sure you include the roverd extra when adding the nrdk dependency; it installs the packages required by the reference implementation.

    If using uv, uncomment one of the corresponding lines in the supplied pyproject.toml (and comment out the included nrdk = { path = "../" } line):

    [tool.uv.sources]
    nrdk = { git = "ssh://git@github.com/radarml/nrdk.git" }
    

    Alternatively, to vendor nrdk as a git submodule, run git submodule add git@github.com:RadarML/nrdk.git, then point uv at the local path:

    [tool.uv.sources]
    nrdk = { path = "./nrdk" }
    

Training Script

The GRT template includes reference training scripts which can be used for high level training and fine tuning control flow. You can use these scripts as-is, or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.

Reference Training Script
grt/train.py
"""GRT reference implementation training script."""

import logging
import os
from time import perf_counter
from typing import Any

import hydra
import torch
import yaml
from lightning.pytorch import callbacks


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg) -> None:
    """Train a model using the GRT reference implementation.

    Instantiates the transforms, datamodule, lightningmodule, and trainer
    from the hydra config, optionally warm-starts from base weights, runs
    ``trainer.fit``, and writes wall-clock duration plus checkpoint metadata
    to ``checkpoints.yaml`` in the trainer's log directory.

    Args:
        cfg: hydra config (see ``config/default.yaml``); ``meta.name`` and
            ``meta.version`` are required.
    """
    torch.set_float32_matmul_precision('high')

    def _inst(path, *args, **kwargs):
        # Instantiate the config subtree at `path`; `_convert_="all"` turns
        # omegaconf containers into plain python objects.
        return hydra.utils.instantiate(
            cfg[path], *args, _convert_="all", **kwargs)

    if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
        logging.error("Must set `meta.name` and `meta.version` in the config.")
        return

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")

    # Optional fine-tuning: load base weights, possibly renaming modules.
    if "base" in cfg:
        lightningmodule.load_weights(
            cfg['base']['path'], rename=cfg['base'].get('rename', {}))

    start = perf_counter()
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        # .get(): don't require `meta.resume` to be present in the config.
        ckpt_path=cfg['meta'].get('resume'))
    duration = perf_counter() - start

    # Collect best-checkpoint metadata from the first ModelCheckpoint
    # callback, if any is configured.
    meta: dict[str, Any] = {"duration": duration}
    for callback in trainer.callbacks:
        if isinstance(callback, callbacks.ModelCheckpoint):
            meta["best_k"] = {
                os.path.basename(k): v.item()
                for k, v in callback.best_k_models.items()}
            meta["best"] = os.path.basename(callback.best_model_path)
            break

    # Guard against a disabled logger (trainer.logger / log_dir may be None),
    # which would otherwise crash after a full training run.
    if trainer.logger is not None and trainer.logger.log_dir is not None:
        meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f, sort_keys=False)
    else:
        logging.warning(
            "Trainer has no logger/log_dir; skipping checkpoints.yaml.")


if __name__ == "__main__":
    train()
Minimal Training Script
grt/train_minimal.py
"""GRT Reference implementation training script."""

import hydra
import torch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('medium')

    def _build(key, *args, **kwargs):
        # Instantiate the config subtree at `key` as plain python objects.
        return hydra.utils.instantiate(cfg[key], _convert_="all", *args, **kwargs)

    transforms = _build("transforms")
    datamodule = _build("datamodule", transforms=transforms)
    lightningmodule = _build("lightningmodule", transforms=transforms)
    trainer = _build("trainer")

    trainer.fit(model=lightningmodule, datamodule=datamodule)


if __name__ == "__main__":
    train()