
NRDK CLI

Usage

The CLI tools are built with tyro:

  • Positional arguments (listed as "required" below) are passed as positional command-line arguments.
  • Named arguments are passed as command-line flags, e.g. --output.
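As a rough illustration of this mapping, the stdlib-only sketch below derives a usage line from a function signature. Positional-only parameters (those before the / in the signatures shown on this page) become positional arguments. Parameters with defaults become flags, with underscores turned into hyphens as in the --dry-run flag of nrdk upgrade-config. sketch_usage is a hypothetical helper that only approximates tyro's conventions; it is not part of nrdk or tyro.

```python
import inspect


def sketch_usage(fn, prog: str) -> str:
    """Build a rough usage string from a function signature.

    Hypothetical helper: approximates tyro's conventions only.
    """
    parts = [prog]
    for name, param in inspect.signature(fn).parameters.items():
        if param.kind is inspect.Parameter.POSITIONAL_ONLY:
            # Required positional argument, e.g. `path`.
            parts.append(name.upper())
        else:
            # Named argument: becomes a --flag with hyphens instead of underscores.
            flag = "--" + name.replace("_", "-")
            if param.annotation is bool:
                # Boolean options are switches that take no value.
                parts.append(f"[{flag}]")
            else:
                parts.append(f"[{flag} {name.upper()}]")
    return " ".join(parts)


def cli_export(path: str, /, output: str = "weights.pth") -> None:
    """Stand-in with the same signature as `nrdk export`."""


print(sketch_usage(cli_export, "nrdk export"))
# nrdk export PATH [--output OUTPUT]
```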

nrdk export

Export model weights from a full-service checkpoint.

Usage

Take the best checkpoint in results/experiment/version, and export the model to results/experiment/version/weights.pth:

nrdk export results/experiment/version

The checkpoint is assumed to be created by NRDKLightningModule (via PyTorch Lightning), so it has a state_dict entry that contains the model weights, where each key has a leading model. prefix. This prefix is stripped from the exported keys.
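The key rewriting done by the export code (the regular expression ^model\.) can be tried without torch. In this minimal sketch, plain floats stand in for tensors, and strip_model_prefix is a hypothetical name. Because the pattern is anchored with ^, a "model." segment elsewhere in a key is left untouched.

```python
import re

# Matches "model." only at the start of a parameter name.
PREFIX = re.compile(r"^model\.")


def strip_model_prefix(state_dict: dict) -> dict:
    """Return a copy of state_dict with the leading "model." removed from keys."""
    return {PREFIX.sub("", k): v for k, v in state_dict.items()}


# Plain floats stand in for tensors so the sketch runs without torch.
checkpoint = {
    "state_dict": {
        "model.encoder.weight": 0.1,
        "model.decoder.bias": 0.2,
    }
}
print(strip_model_prefix(checkpoint["state_dict"]))
# {'encoder.weight': 0.1, 'decoder.bias': 0.2}
```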

  • If the path points to a file, export that checkpoint.
  • If the path is a directory, it must contain a checkpoints.yaml file whose best key names the best checkpoint inside its checkpoints/ subdirectory; the output path is then interpreted relative to that directory.
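The path-resolution rules above can be sketched as follows. resolve_checkpoint is a hypothetical helper. To stay dependency-free, it reads only an unquoted "best: <filename>" line instead of parsing the YAML file, whereas the real command uses yaml.safe_load.

```python
import os
import tempfile


def resolve_checkpoint(path: str, output: str = "weights.pth") -> tuple[str, str]:
    """Sketch of how (checkpoint, output) are chosen from `path`.

    Hypothetical helper; reads only a "best: <filename>" line, not full YAML.
    """
    if not os.path.isdir(path):
        # A file path is exported as-is; output is used unchanged.
        return path, output

    # For a results directory, the output is saved relative to it ...
    output = os.path.join(path, output)
    # ... and the checkpoint is the one named by "best" in checkpoints.yaml.
    with open(os.path.join(path, "checkpoints.yaml")) as f:
        for line in f:
            key, _, value = line.partition(":")
            if key.strip() == "best":
                return os.path.join(path, "checkpoints", value.strip()), output
    raise KeyError(f"No 'best' entry in {path}/checkpoints.yaml")


with tempfile.TemporaryDirectory() as run_dir:
    with open(os.path.join(run_dir, "checkpoints.yaml"), "w") as f:
        f.write("best: example.ckpt\n")  # hypothetical checkpoint name
    ckpt, out = resolve_checkpoint(run_dir)
```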

Parameters:

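| Name   | Type | Description                        | Default       |
| ------ | ---- | ---------------------------------- | ------------- |
| path   | str  | path to the checkpoint.            | required      |
| output | str  | path to save the exported weights. | 'weights.pth' |
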
Source code in src/nrdk/_cli/export.py
def cli_export(
    path: str, /, output: str = "weights.pth"
) -> None:
    """Export model weights from a full-service checkpoint.

    !!! info "Usage"

        Take the best checkpoint in `results/experiment/version`, and export
        the model to `results/experiment/version/weights.pth`:
        ```sh
        nrdk export results/experiment/version
        ```

    The checkpoint is assumed to be created by
    [`NRDKLightningModule`][nrdk.framework.] (via PyTorch Lightning), so it
    has a `state_dict` entry that contains the model weights, where each key
    has a leading `model.` prefix.

    - If the `path` points to a file, export that checkpoint.
    - If the `path` is a directory, the directory should have a
        `checkpoints.yaml` with a `best` key which specifies the best
        checkpoint in a `checkpoints/` directory; the exported model is saved
        relative to the `path`.

    Args:
        path: path to the checkpoint.
        output: path to save the exported weights.
    """
    if os.path.isdir(path):
        # A results directory: save the output relative to it, and export
        # the best checkpoint listed in its checkpoints.yaml.
        output = os.path.join(path, output)
        try:
            with open(os.path.join(path, "checkpoints.yaml"), 'r') as f:
                contents = yaml.safe_load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"No 'checkpoints.yaml' found in {path}.") from e
        path = os.path.join(path, "checkpoints", contents['best'])

    contents = torch.load(path, map_location='cpu')
    if 'state_dict' not in contents:
        raise ValueError(f"Checkpoint {path} does not contain 'state_dict'.")

    print(f"Exporting: {path}")
    print(f"--> {output}")
    pattern = re.compile(r"^model\.")
    state_dict = {
        pattern.sub("", k): v for k, v in contents['state_dict'].items()}
    torch.save(state_dict, output)

nrdk inspect

Inspect a pytorch / pytorch lightning checkpoint.

Usage

Inspect a representative (most recent) checkpoint:

nrdk inspect results/experiment/version

If the path points to a file, inspect that checkpoint; if it points to a directory, inspect the most recent checkpoint (by modification time).
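The directory case is handled by a helper that this page does not show. Here is a minimal sketch of the described behavior. It assumes that checkpoints end in .ckpt and may sit in any subdirectory; both are assumptions, and latest_checkpoint is a hypothetical name.

```python
import glob
import os
import tempfile


def latest_checkpoint(path: str) -> str:
    """Return `path` if it is a file, else the newest *.ckpt under it."""
    if os.path.isfile(path):
        return path
    # "**" with recursive=True also matches files directly inside `path`.
    candidates = glob.glob(os.path.join(path, "**", "*.ckpt"), recursive=True)
    if not candidates:
        raise FileNotFoundError(f"No .ckpt files found under {path}.")
    # Most recent by modification time.
    return max(candidates, key=os.path.getmtime)


with tempfile.TemporaryDirectory() as run_dir:
    os.makedirs(os.path.join(run_dir, "checkpoints"))
    # Hypothetical checkpoint files with explicit modification times.
    for name, mtime in [("epoch=1.ckpt", 1_000), ("last.ckpt", 2_000)]:
        ckpt_path = os.path.join(run_dir, "checkpoints", name)
        open(ckpt_path, "w").close()
        os.utime(ckpt_path, (mtime, mtime))
    newest = os.path.basename(latest_checkpoint(run_dir))

print(newest)  # last.ckpt
```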

Parameters:

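| Name         | Type | Description                                                   | Default  |
| ------------ | ---- | ------------------------------------------------------------- | -------- |
| path         | str  | path to a checkpoint file, or a directory to search for one.  | required |
| depth        | int  | maximum depth to print in the module/parameter tree.          | 2        |
| weights_only | bool | if True, only load tensors and primitive containers; if False, allow loading pytorch checkpoints containing custom objects. Note that this allows arbitrary code execution! | False |
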
Source code in src/nrdk/_cli/inspect.py
def cli_inspect(
    path: str, /, depth: int = 2, weights_only: bool = False
) -> None:
    """Inspect a pytorch / pytorch lightning checkpoint.

    !!! info "Usage"

        Inspect a representative (most recent) checkpoint:
        ```sh
        nrdk inspect results/experiment/version
        ```

    If the `path` points to a file, inspect that checkpoint; if it points to a
    directory, inspect the most recent checkpoint (by modification time).

    Args:
        path: path to a checkpoint file, or a directory to search for one.
        depth: maximum depth to print in the module/parameter tree.
        weights_only: if `True`, only load tensors and primitive containers.
            If `False`, allow loading pytorch checkpoints containing custom
            objects. Note that this allows arbitrary code execution!
    """
    path = _get_model_path(path)
    contents = torch.load(path, map_location='cpu', weights_only=weights_only)

    if 'state_dict' in contents:
        contents = contents['state_dict']

    tree = {}
    for k, v in contents.items():
        name = k.split('.')
        subtree = tree
        for subpath in name[:-1]:
            subtree = subtree.setdefault(subpath, {})

        if isinstance(v, torch.Tensor):
            subtree[name[-1]] = np.prod(tuple(v.shape))

    tree = _collapse_singletons(tree)
    _print_tree(tree, 0, depth)
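The tree-building loop above nests dotted parameter names and stores each tensor's element count at the leaf. The sketch below reproduces that loop without torch, using hypothetical shape tuples in place of tensors. _collapse_singletons and _print_tree are not shown on this page, so print_tree here is a hypothetical stand-in that prints subtree totals and does not collapse single-child chains.

```python
import math


def param_tree(shapes: dict) -> dict:
    """Nest dotted parameter names into a tree of element counts."""
    tree: dict = {}
    for key, shape in shapes.items():
        *parents, leaf = key.split(".")
        subtree = tree
        for part in parents:
            subtree = subtree.setdefault(part, {})
        subtree[leaf] = math.prod(shape)
    return tree


def total(node) -> int:
    """Sum element counts below a node (leaves are ints)."""
    return node if isinstance(node, int) else sum(total(v) for v in node.values())


def print_tree(node: dict, depth: int, indent: int = 0) -> None:
    """Print subtree totals down to `depth` levels (hypothetical format)."""
    for name, child in node.items():
        print(" " * indent + f"{name}: {total(child)}")
        if isinstance(child, dict) and depth > 1:
            print_tree(child, depth - 1, indent + 2)


shapes = {  # hypothetical parameter shapes
    "encoder.conv.weight": (16, 3, 3, 3),
    "encoder.conv.bias": (16,),
    "head.weight": (10, 16),
}
print_tree(param_tree(shapes), depth=2)
# encoder: 448
#   conv: 448
# head: 160
#   weight: 160
```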

nrdk upgrade-config

Upgrade implementation references in hydra configs.

Usage

First test with a dry run:

nrdk upgrade-config <target> --path ./results --dry-run

If you're happy with what you see, you can then run the actual upgrade:

nrdk upgrade-config <target> <to> --path ./results

Danger

This is a potentially destructive operation! Always run with --dry-run first, and make sure that the new name (to) does not overlap with any other existing implementation in your configs.

You can also use the upgrade-config tool to check for this overlap first:

nrdk upgrade-config <to> --path ./results --dry-run
# Shouldn't return any of the config files you are planning to upgrade

For each valid results directory in the specified path, search for all _target_ fields in its Hydra config (.hydra/config.yaml), and replace every exact occurrence of target with to.
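The replacement step can be sketched on an in-memory YAML string with the same regular expression that cli_upgrade compiles. upgrade_targets and the package names are hypothetical. The lookahead (?=\s|$) is what keeps a target such as old_pkg.models.Net from also matching old_pkg.models.NetV2.

```python
import re


def upgrade_targets(config: str, target: str, to: str) -> str:
    """Replace `_target_: <target>` with `_target_: <to>` in YAML text."""
    # re.escape makes the dots literal; the lookahead requires the name to
    # end at whitespace or end of text, so longer names are not matched.
    pattern = re.compile(rf"_target_\s*:\s*{re.escape(target)}(?=\s|$)")
    # A function replacement avoids re.sub interpreting backslashes in `to`.
    return pattern.sub(lambda _: f"_target_: {to}", config)


config = """\
model:
  _target_: old_pkg.models.Net
  backbone:
    _target_: old_pkg.models.NetV2
"""
print(upgrade_targets(config, "old_pkg.models.Net", "new_pkg.models.Net"))
```

Only the first _target_ line changes; the NetV2 reference is left as-is.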

Parameters:

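| Name            | Type        | Description                                                                              | Default  |
| --------------- | ----------- | ---------------------------------------------------------------------------------------- | -------- |
| target          | str         | full path name of the implementation to replace.                                         | required |
| to              | str \| None | full path name of the implementation to replace with; required unless dry_run is True.  | None     |
| dry_run         | bool        | if True, only log the changes that would be made, and do not actually modify any files. | False    |
| path            | str         | path to search for results directories.                                                  | '.'      |
| follow_symlinks | bool        | whether to follow symlinks when searching for results.                                   | False    |
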
Source code in src/nrdk/_cli/upgrade.py
def cli_upgrade(
    target: str, /, to: str | None = None,
    dry_run: bool = False, path: str = ".", follow_symlinks: bool = False
) -> None:
    """Upgrade implementation references in hydra configs.

    !!! info "Usage"

        First test with a dry run:
        ```sh
        nrdk upgrade-config <target> --path ./results --dry-run
        ```
        If you're happy with what you see, you can then run the actual upgrade:
        ```sh
        nrdk upgrade-config <target> <to> --path ./results
        ```

    !!! danger

        This is a potentially destructive operation! Always run with
        `--dry-run` first, and make sure that `to` does not overlap with
        any other existing implementations in your configs.

        You can also use the `upgrade-config` tool to check for this overlap
        first:
        ```sh
        nrdk upgrade-config <to> --path ./results --dry-run
        # Shouldn't return any of the config files you are planning to upgrade
        ```

    For each valid [results directory][nrdk.framework.Result] in the specified
    `path`, search for all `_target_` fields in the hydra config, and replace
    any occurrences of `target` with `to`.

    Args:
        target: full path name of the implementation to replace.
        to: full path name of the implementation to replace with; required
            unless `dry_run` is set.
        dry_run: if `True`, only log the changes that would be made, and do not
            actually modify any files.
        path: path to search for results directories.
        follow_symlinks: whether to follow symlinks when searching for results.
    """
    pattern = re.compile(rf"_target_\s*:\s*{re.escape(target)}(?=\s|$)")
    results = Result.find(path, follow_symlinks=follow_symlinks)

    if dry_run:
        all_matches = {}
        for r in results:
            config_path = os.path.join(r, ".hydra", "config.yaml")
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    config = f.read()

                matches = _search(config, pattern, context_size=2)
                for line_num, context in matches:
                    if context not in all_matches:
                        all_matches[context] = []
                    all_matches[context].append((config_path, line_num))

        for k, v in all_matches.items():
            print(
                f"Found {len(v)} occurrence(s) of '{target}' "
                f"with this context:")
            print(Panel(k))
            print(Columns(
                f"{os.path.relpath(config_path, path)}:{line_num}"
                for config_path, line_num in v))
            print()

    else:
        if to is None:
            raise ValueError("Must specify `to` when not doing a dry run.")

        for r in results:
            config_path = os.path.join(r, ".hydra", "config.yaml")
            if os.path.exists(config_path):
                with open(config_path, "r") as f:
                    config = f.read()

                n = re.findall(pattern, config)
                if n:
                    print(f"Upgrading {len(n)} occurrence(s): {config_path}")
                    new_config = re.sub(pattern, f"_target_: {to}", config)
                    with open(config_path, "w") as f:
                        f.write(new_config)
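The dry run groups identical match contexts, so that many runs sharing one config layout are reported once. The _search helper is not shown on this page. search_with_context below is a hypothetical line-based stand-in, and the config contents are made up for illustration.

```python
import re
from collections import defaultdict
from typing import Iterator


def search_with_context(
    text: str, pattern: re.Pattern, context_size: int = 2
) -> Iterator[tuple[int, str]]:
    """Yield (1-based line number, surrounding lines) for each matching line."""
    lines = text.splitlines()
    for i, line in enumerate(lines):
        if pattern.search(line):
            start = max(0, i - context_size)
            yield i + 1, "\n".join(lines[start:i + context_size + 1])


# Two hypothetical runs that share the same config layout.
configs = {
    "run1/.hydra/config.yaml": "model:\n  _target_: pkg.Net\n  width: 64\n",
    "run2/.hydra/config.yaml": "model:\n  _target_: pkg.Net\n  width: 64\n",
}
pattern = re.compile(r"_target_\s*:\s*pkg\.Net(?=\s|$)")

# Group identical contexts so each distinct snippet is reported once.
grouped = defaultdict(list)
for config_path, text in configs.items():
    for line_num, context in search_with_context(text, pattern):
        grouped[context].append((config_path, line_num))

for context, where in grouped.items():
    print(f"Found {len(where)} occurrence(s) with this context:\n{context}")
    print(", ".join(f"{p}:{n}" for p, n in where))
```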

nrdk validate

Validate results directories.

Usage

nrdk validate <path> --follow-symlinks

For each valid results directory in the specified path, check that all expected files are present:

| File                  | Description                                        |
| --------------------- | -------------------------------------------------- |
| .hydra/config.yaml    | Hydra configuration used for the run.              |
| checkpoints/last.ckpt | Last model checkpoint saved during training.       |
| eval/                 | Directory containing evaluation outputs.           |
| checkpoints.yaml      | Checkpoint index; absence indicates a crashed run. |
| events.out.tfevents.* | Tensorboard log files.                             |
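The per-directory check can be sketched with the standard library alone. check_result is a hypothetical helper that mirrors the file list above. Unlike nrdk validate, it does not discover results directories (that is done by Result.find), and it does not render a table.

```python
import glob
import os
import tempfile

# Expected entries, relative to a results directory (from the table above).
EXPECTED = [".hydra/config.yaml", "checkpoints/last.ckpt", "eval", "checkpoints.yaml"]


def check_result(result_dir: str) -> dict:
    """Map each expected entry to whether it exists in result_dir."""
    status = {name: os.path.exists(os.path.join(result_dir, name)) for name in EXPECTED}
    # TensorBoard files have run-specific suffixes, so match by prefix.
    status["events.out.tfevents.*"] = bool(
        glob.glob(os.path.join(result_dir, "events.out.tfevents.*")))
    return status


with tempfile.TemporaryDirectory() as run_dir:
    os.makedirs(os.path.join(run_dir, ".hydra"))
    open(os.path.join(run_dir, ".hydra", "config.yaml"), "w").close()
    open(os.path.join(run_dir, "events.out.tfevents.123"), "w").close()
    status = check_result(run_dir)

missing = [name for name, ok in status.items() if not ok]
print(missing)  # ['checkpoints/last.ckpt', 'eval', 'checkpoints.yaml']
```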

Parameters:

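| Name            | Type | Description                                                  | Default  |
| --------------- | ---- | ------------------------------------------------------------ | -------- |
| path            | str  | path to search for results directories.                      | required |
| follow_symlinks | bool | whether to follow symlinks when searching for results.       | False    |
| show_all        | bool | show all results instead of just results with missing files. | False    |
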
Source code in src/nrdk/_cli/validate.py
def cli_validate(
    path: str, /, follow_symlinks: bool = False, show_all: bool = False,
) -> None:
    """Validate results directories.

    !!! info "Usage"

        ```sh
        nrdk validate <path> --follow-symlinks
        ```

    For each valid [results directory][nrdk.framework.Result] in the specified
    `path`, check that all expected files are present:

    | File                    | Description                                        |
    | ----------------------- | -------------------------------------------------- |
    | `.hydra/config.yaml`    | Hydra configuration used for the run.              |
    | `checkpoints/last.ckpt` | Last model checkpoint saved during training.       |
    | `eval/`                 | Directory containing evaluation outputs.           |
    | `checkpoints.yaml`      | Checkpoint index; absence indicates a crashed run. |
    | `events.out.tfevents.*` | Tensorboard log files.                             |

    Args:
        path: path to search for results directories.
        follow_symlinks: whether to follow symlinks when searching for results.
        show_all: show all results instead of just results with missing files.
    """
    results = Result.find(path, follow_symlinks=follow_symlinks, strict=False)

    _check_files = [
        ".hydra/config.yaml",
        "checkpoints/last.ckpt",
        "eval",
        "checkpoints.yaml",
    ]
    _status = {
        True: u'[green]\u2713[/green]',
        False: u'[bold red]\u2718[/bold red]',
    }

    missing = 0

    table = Table()
    table.add_column("path", justify="right", style="cyan")
    table.add_column("config.yaml", justify="left")
    table.add_column("last.ckpt", justify="left")
    table.add_column("eval", justify="left")
    table.add_column("checkpoints.yaml", justify="left")
    table.add_column("tfevents", justify="left")

    for r in results:
        row = [
            os.path.exists(os.path.join(r, file))
            for file in _check_files
        ] + [any(
            fname.startswith("events.out.tfevents.")
            for fname in os.listdir(r)
        )]
        if not all(row):
            missing += 1
        if show_all or not all(row):
            table.add_row(os.path.relpath(r, path), *[_status[x] for x in row])

    if missing > 0:
        print(
            f"Found {len(results)} results directories with {missing} "
            f"incomplete results.")
    else:
        print(f"All {len(results)} results directories are complete.")

    print(table)