
GRT Reference Implementation & Project Template

The GRT reference implementation uses a Hydra + PyTorch Lightning based stack on top of the NRDK; use it as a starting point for new projects.

Tip

The included reference implementation can be run out-of-the-box:

  1. Obtain a copy of the I/Q-1M dataset, and save it (or link it) to nrdk/grt/data/.
  2. Create a virtual environment in nrdk/grt with uv sync.
  3. Run with uv run train.py; see the Hydra config files in nrdk/grt/config/ for options (example invocation below).
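
For example, the required experiment name and version can be set on the command line using Hydra's key=value override syntax (the values below are placeholders; any other key in the config files can be overridden the same way):

    uv run train.py meta.name=example meta.version=v0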

Quick Start

  1. Create a new repository, and copy the contents of the grt/ directory:

    example-project/
    ├── config/
    │   ├── aug/
    │   ...
    ├── grt/
    │   ├── __init__.py
    │   ...
    ├── pyproject.toml
    ├── train.py
    └── train_minimal.py

    Tip

    Don't forget to change the name, authors, and description in pyproject.toml!

  2. Set up the nrdk dependency (nrdk[roverd] >= 0.1.5).

    Required Extras

    Make sure you include the roverd extra.

    If using uv, uncomment the line corresponding to your chosen setup in the supplied pyproject.toml (and comment out the included nrdk = { path = "../" } line):

    [tool.uv.sources]
    nrdk = { git = "ssh://git@github.com/radarml/nrdk.git" }
    

    Alternatively, after running git submodule add git@github.com:RadarML/nrdk.git:

    [tool.uv.sources]
    nrdk = { path = "./nrdk" }
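
    For reference, the nrdk dependency itself is declared under [project] dependencies; a minimal sketch consistent with the version constraint above (the supplied pyproject.toml already contains an equivalent entry):

    [project]
    dependencies = [
        "nrdk[roverd]>=0.1.5",
    ]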
    

Training Script

The GRT template includes reference training scripts which implement the high-level training and fine-tuning control flow. You can use these scripts as-is or modify them to suit your needs; where possible, stick to the same general structure to maintain compatibility.

Reference Training Script
grt/train.py
"""GRT Reference implementation training script."""

import logging
import os
from collections.abc import Mapping
from time import perf_counter
from typing import Any

import hydra
import torch
import yaml
from lightning.pytorch import callbacks
from omegaconf import DictConfig
from rich.logging import RichHandler

from nrdk.framework import Result

logger = logging.getLogger("train")

def _configure_logging(cfg: DictConfig) -> None:
    log_level = cfg.meta.get("verbose", logging.INFO)
    root = logging.getLogger()
    root.setLevel(log_level)
    root.handlers.clear()

    rich_handler = RichHandler(markup=True)
    rich_handler.setFormatter(logging.Formatter(
        "[orange1]%(name)s:[/orange1] %(message)s"))
    root.addHandler(rich_handler)

    logger.debug(f"Configured with log level: {log_level}")


def _load_weights(lightningmodule, path: str, rename: Mapping = {}) -> None:
    weights = Result(path).best if os.path.isdir(path) else path
    lightningmodule.load_weights(weights, rename=rename)


def _get_best(trainer) -> dict[str, Any]:
    for callback in trainer.callbacks:
        if isinstance(callback, callbacks.ModelCheckpoint):
            return {
                "best_k": {
                    os.path.basename(k): v.item()
                    for k, v in callback.best_k_models.items()},
                "best": os.path.basename(callback.best_model_path)
            }
    return {}


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg: DictConfig) -> None:
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('high')
    _configure_logging(cfg)

    if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
        logger.error("Must set `meta.name` and `meta.version` in the config.")
        return

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")
    if "base" in cfg:
        _load_weights(lightningmodule, **cfg['base'])

    start = perf_counter()
    logger.info(
        f"Start training @ {cfg['meta']['results']}/{cfg['meta']['name']}/"
        f"{cfg['meta']['version']} [t={start:.3f}]")
    trainer.fit(
        model=lightningmodule, datamodule=datamodule,
        ckpt_path=cfg['meta']['resume'])
    duration = perf_counter() - start
    logger.info(
        f"Training completed in {duration / 60 / 60:.2f}h (={duration:.3f}s).")

    meta: dict[str, Any] = {"duration": duration}
    meta.update(_get_best(trainer))
    logger.info(f"Best checkpoint: {meta.get('best')}")
    meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
    with open(meta_path, 'w') as f:
        yaml.dump(meta, f, sort_keys=False)


if __name__ == "__main__":
    train()
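
Fine-tuning is driven by the optional base entry in the config: when present, _load_weights is called with its contents before training starts, and path may point either to a previous results directory (in which case its best checkpoint is used) or directly to a checkpoint file. A hypothetical sketch, with a placeholder path:

base:
  path: results/pretrained/v0   # results directory or checkpoint file (placeholder)
  rename: {}                    # optional key remapping forwarded to load_weights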
Minimal Training Script
grt/train_minimal.py
"""GRT Reference implementation training script."""

import hydra
import torch


@hydra.main(version_base=None, config_path="./config", config_name="default")
def train(cfg):
    """Train a model using the GRT reference implementation."""
    torch.set_float32_matmul_precision('medium')

    def _inst(path, *args, **kwargs):
        return hydra.utils.instantiate(
            cfg[path], _convert_="all", *args, **kwargs)

    transforms = _inst("transforms")
    datamodule = _inst("datamodule", transforms=transforms)
    lightningmodule = _inst("lightningmodule", transforms=transforms)
    trainer = _inst("trainer")
    trainer.fit(model=lightningmodule, datamodule=datamodule)

if __name__ == "__main__":
    train()
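
Both training scripts instantiate the transforms, datamodule, lightningmodule, and trainer config groups with hydra.utils.instantiate (passing the instantiated transforms to the datamodule and lightningmodule); the full train.py additionally reads the meta keys shown below. This is a structural sketch only, with placeholder values; the supplied config/ directory is authoritative:

meta:
  name: null       # required: experiment name
  version: null    # required: experiment version
  results: ???     # results root referenced in logging
  resume: null     # optional checkpoint path to resume from
  dataset: ???     # dataset root (also read by the evaluation script)
  verbose: INFO    # optional log level
transforms: ...       # instantiated via hydra.utils.instantiate
datamodule: ...       # receives transforms=...
lightningmodule: ...  # receives transforms=...
trainer: ...          # e.g. lightning.pytorch.Trainer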

Evaluation Script

Evaluate a trained model.

Supports three evaluation modes, in order of precedence:

  1. Enumerated traces: evaluate all traces specified by --trace, relative to the --data-root.
  2. Filtered evaluation: evaluate all traces in the configuration (datamodule/traces/test) that match the provided --filter regex.
  3. Sample evaluation: evaluate a pseudo-random --sample taken from the test set specified in the configuration.

If none of --trace, --filter, or --sample are provided, defaults to evaluating all traces specified in the configuration.
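
For illustration, the three modes correspond to the following direct calls to evaluate(); this is a sketch only, where the results path, trace name, and dataset root are placeholders and the grt package is assumed to be importable:

# Hypothetical invocations; all paths and trace names are placeholders.
from grt.evaluate import evaluate

# 1. Enumerated traces, resolved relative to an explicit data root.
evaluate("results/example/v0", traces=["indoor/t001"], data_root="/data/iq1m")

# 2. All configured test traces whose names match a regex.
evaluate("results/example/v0", filter=r"^indoor/")

# 3. A pseudo-random sample drawn from the configured test set.
evaluate("results/example/v0", sample=1024)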

Tip

See Result for details about the expected structure of the results directory.

Warning

Only supports using a single GPU; if multiple GPUs are available, use parallel evaluation instead.

Parameters:

Name       Type               Description                                               Default
path       str                path to results directory.                                required
output     str | None         if specified, write results to this directory instead.    None
sample     int | None         number of samples to evaluate.                             None
traces     list[str] | None   explicit list of traces to evaluate.                       None
filter     str | None         evaluate all traces matching this regex.                   None
data_root  str | None         root dataset directory; if None, use the path specified   None
                              in meta/dataset in the config.
device     str                device to use for evaluation.                              'cuda:0'
batch      int                batch size.                                                32
workers    int                number of workers for data loading.                        32
prefetch   int                number of batches to prefetch per worker.                  2
Source code in grt/evaluate.py
def evaluate(
    path: str, /, output: str | None = None, sample: int | None = None,
    traces: list[str] | None = None, filter: str | None = None,
    data_root: str | None = None,
    device: str = "cuda:0",
    batch: int = 32, workers: int = 32, prefetch: int = 2
) -> None:
    """Evaluate a trained model.

    Supports three evaluation modes, in order of precedence:

    1. Enumerated traces: evaluate all traces specified by `--trace`, relative
        to the `--data-root`.
    2. Filtered evaluation: evaluate all traces in the configuration
        (`datamodule/traces/test`) that match the provided `--filter` regex.
    3. Sample evaluation: evaluate a pseudo-random `--sample` taken from
        the test set specified in the configuration.

    If none of `--trace`, `--filter`, or `--sample` are provided, defaults to
    evaluating all traces specified in the configuration.

    !!! tip

        See [`Result`][nrdk.framework.Result] for details about the expected
        structure of the results directory.

    !!! warning

        Only supports using a single GPU; if multiple GPUs are available,
        use parallel evaluation instead.

    Args:
        path: path to results directory.
        output: if specified, write results to this directory instead.
        sample: number of samples to evaluate.
        traces: explicit list of traces to evaluate.
        filter: evaluate all traces matching this regex.
        data_root: root dataset directory; if `None`, use the path specified
            in `meta/dataset` in the config.
        device: device to use for evaluation.
        batch: batch size.
        workers: number of workers for data loading.
        prefetch: number of batches to prefetch per worker.
    """
    torch.set_float32_matmul_precision('high')

    result = Result(path)
    cfg = result.config()
    if sample is not None:
        cfg["datamodule"]["subsample"]["test"] = sample

    if output is None:
        output = os.path.join(path, "eval")

    if data_root is None:
        data_root = cfg["meta"]["dataset"]
        if data_root is None:
            raise ValueError(
                "`--data_root` must be specified if `meta/dataset` is not set "
                "in the config.")
    else:
        cfg["meta"]["dataset"] = data_root

    cfg["datamodule"]["batch_size"] = batch
    cfg["datamodule"]["num_workers"] = workers
    cfg["datamodule"]["prefetch_factor"] = prefetch
    cfg["lightningmodule"]["compile"] = False

    transforms = hydra.utils.instantiate(cfg["transforms"])
    lightningmodule = hydra.utils.instantiate(
        cfg["lightningmodule"], transforms=transforms).to(device)
    lightningmodule.load_weights(result.best)

    dataloaders = _get_dataloaders(
        cfg, data_root, transforms,
        traces=traces, filter=filter, sample=sample)

    def collect_metadata(y_true):
        return {
            f"meta/{k}/ts": getattr(v, "timestamps")
            for k, v in y_true.items() if hasattr(v, "timestamps")
        }

    for trace, dl_constructor in dataloaders.items():
        dataloader = dl_constructor()
        eval_stream = tqdm(
            Prefetch(lightningmodule.evaluate(
                dataloader, metadata=collect_metadata, device=device)),
            total=len(dataloader), desc=trace)

        output_container = DynamicSensor(
            os.path.join(output, trace), create=True, exist_ok=True)
        metrics = []
        outputs = {}
        for batch_metrics, vis in eval_stream:
            if len(outputs) == 0:
                for k, v in vis.items():
                    outputs[k] = Queue()
                    output_container.create(
                        k.split("/")[-1], meta={
                            "format": "lzmaf",
                            "type": f"{v.dtype.kind}{v.dtype.itemsize}",
                            "shape": v.shape[1:],
                            "desc": f"eval_render:{k}"
                        }
                    ).consume(outputs[k], thread=True)

            for k, v in vis.items():
                for sample in v:
                    outputs[k].put(sample)
            metrics.append(batch_metrics)

        for q in outputs.values():
            q.put(None)

        metrics = {
            k: np.concatenate([m[k] for m in metrics], axis=0)
            for k in metrics[0]}
        np.savez_compressed(
            os.path.join(output, trace, "metrics.npz"),
            **metrics, allow_pickle=False)

        output_container.create("ts", meta={
            "format": "raw", "type": "f8", "shape": (),
            "desc": "reference timestamps"}
        ).write(metrics["meta/spectrum/ts"])
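
Each evaluated trace produces a directory of rendered outputs plus a metrics.npz archive under the output directory, which defaults to the results path followed by eval. A hypothetical way to inspect the saved per-sample metrics afterwards (the path is a placeholder, and the metric keys depend on the lightning module's evaluation outputs):

import numpy as np

# Placeholder path of the form <results path>/eval/<trace>/metrics.npz.
metrics = np.load("results/example/v0/eval/indoor/t001/metrics.npz")
print({k: metrics[k].shape for k in metrics.files})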