Skip to content

NRDK CLI

Usage

The CLI tools use tyro:

  • Positional arguments ("required") are passed as positional command line arguments
  • Named arguments are passed as flagged command line arguments

nrdk export

Export model weights from a full-service checkpoint.

Usage

Take the best checkpoint in results/experiment/version, and export the model to results/experiment/version/weights.pth and config to results/experiment/version/model.yaml:

nrdk export results/experiment/version

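The model is assumed to be created by NRDKLightningModule (via pytorch lightning), so has a state_dict attribute that contains the model weights, where each key has a leading model. prefix (optionally followed by _orig_mod. when the model was wrapped by torch.compile); these prefixes are stripped on export.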

  • If the path points to a file, export that checkpoint.
  • If the path is a directory, the directory should have a checkpoints.yaml with a best key which specifies the best checkpoint in a checkpoints/ directory; the exported model and config are saved relative to the path.

Warning

If path is a file (i.e., a specific checkpoint), the model config will not be exported (since its location is not defined).

Parameters:

Name Type Description Default
path str

path to the checkpoint.

required
output str

path to save the exported weights.

'weights.pth'
config str

path to save the exported model config.

'model.yaml'
Source code in src/nrdk/_cli/export.py
def cli_export(
    path: str, /, output: str = "weights.pth", config: str = "model.yaml"
) -> None:
    """Export model weights from a full-service checkpoint.

    !!! info "Usage"

        Take the best checkpoint in `results/experiment/version`, and export
        the model to `results/experiment/version/weights.pth` and config to
        `results/experiment/version/model.yaml`:
        ```sh
        nrdk export results/experiment/version
        ```

    The model is assumed to be created by
    [`NRDKLightningModule`][nrdk.framework.] (via pytorch lightning), so has a
    `state_dict` attribute that contains the model weights, where each key
    has a leading `model.` prefix (optionally followed by `_orig_mod.` when
    the model was wrapped by `torch.compile`); these prefixes are stripped
    from the exported keys.

    - If the `path` points to a file, export that checkpoint; `output` is
        then interpreted relative to the current working directory.
    - If the `path` is a directory, the directory should have a
        `checkpoints.yaml` with a `best` key which specifies the best
        checkpoint in a `checkpoints/` directory; the exported model and config
        are saved relative to the `path`. The exported config contains the
        `model` and `transforms` sections of `.hydra/config.yaml`.

    Nothing is written until the checkpoint has been located and loaded, so
    a missing `checkpoints.yaml` or a checkpoint without a `state_dict` leaves
    the results directory untouched.

    !!! warning

        If `path` is a file (i.e., a specific checkpoint), the model config
        will not be exported (since its location is not defined).

    Args:
        path: path to the checkpoint, or to a results directory.
        output: path to save the exported weights.
        config: path to save the exported model config.

    Raises:
        FileNotFoundError: if `path` is a directory without a
            `checkpoints.yaml`.
        ValueError: if the checkpoint does not contain a `state_dict`.
    """
    results_dir = path if os.path.isdir(path) else None

    # Resolve the checkpoint (and output location) before touching the disk,
    # so failures here do not leave a partially-exported directory behind.
    if results_dir is not None:
        try:
            with open(os.path.join(results_dir, "checkpoints.yaml"), 'r') as f:
                index = yaml.safe_load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"No 'checkpoints.yaml' found in {results_dir}.") from err
        path = os.path.join(results_dir, "checkpoints", index['best'])
        output = os.path.join(results_dir, output)

    contents = torch.load(path, map_location='cpu')
    if 'state_dict' not in contents:
        raise ValueError(f"Checkpoint {path} does not contain 'state_dict'.")

    # Only a results directory has a hydra config to export alongside.
    if results_dir is not None:
        hydra_cfg = OmegaConf.load(
            os.path.join(results_dir, ".hydra", "config.yaml"))
        resolved = OmegaConf.to_container(hydra_cfg, resolve=True)
        assert isinstance(resolved, dict)
        with open(os.path.join(results_dir, config), 'w') as f:
            yaml.dump({
                "model": resolved["model"],
                "transforms": resolved["transforms"],
            }, f, default_flow_style=False, sort_keys=False)

    print(f"Exporting: {path}")
    print(f"--> {output}")
    # Strip the lightning `model.` prefix and the optional `_orig_mod.`
    # prefix added by `torch.compile`.
    pattern = re.compile(r"^(model\.)(_orig_mod\.)?")
    state_dict = {
        pattern.sub("", k): v for k, v in contents['state_dict'].items()}
    torch.save(state_dict, output)

nrdk inspect

Inspect a pytorch / pytorch lightning checkpoint.

Usage

Inspect a representative (most recent) checkpoint:

nrdk inspect results/experiment/version

If the path points to a file, inspect that checkpoint; if it points to a directory, inspect the most recent checkpoint (by modification time).

Parameters:

Name Type Description Default
path str

path to checkpoint file.

required
depth int

maximum depth to print in the module/parameter tree.

2
weights_only bool

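passed to torch.load. If True, only tensors and primitive containers can be loaded; if False (the default), checkpoints containing custom objects can also be loaded. Note that False allows arbitrary code execution, so only inspect checkpoints you trust!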

False
Source code in src/nrdk/_cli/inspect.py
def cli_inspect(
    path: str, /, depth: int = 2, weights_only: bool = False
) -> None:
    """Inspect a pytorch / pytorch lightning checkpoint.

    !!! info "Usage"

        Inspect a representative (most recent) checkpoint:
        ```sh
        nrdk inspect results/experiment/version
        ```

    If the `path` points to a file, inspect that checkpoint; if it points to a
    directory, inspect the most recent checkpoint (by modification time).

    Parameters are grouped into a tree by splitting their names on `.`, and
    each leaf shows the number of elements in that tensor. Non-tensor entries
    are skipped.

    Args:
        path: path to a checkpoint file, or to a directory of checkpoints.
        depth: maximum depth to print in the module/parameter tree.
        weights_only: passed to `torch.load`. If `True`, only tensors and
            primitive containers can be loaded; if `False` (the default),
            checkpoints containing custom objects can also be loaded. Note
            that `False` allows arbitrary code execution, so only inspect
            checkpoints you trust!
    """
    path = _get_model_path(path)
    contents = torch.load(path, map_location='cpu', weights_only=weights_only)

    # Lightning checkpoints nest the weights under `state_dict`; anything
    # else is treated as a state dict directly.
    if 'state_dict' in contents:
        contents = contents['state_dict']

    tree = {}
    for key, value in contents.items():
        # Skip non-tensors before creating intermediate nodes, so they do not
        # leave empty branches in the printed tree.
        if not isinstance(value, torch.Tensor):
            continue
        *parents, leaf = key.split('.')
        subtree = tree
        for part in parents:
            subtree = subtree.setdefault(part, {})
        # `numel` is an exact int (np.prod of a 0-d shape returns float 1.0).
        subtree[leaf] = value.numel()

    tree = _collapse_singletons(tree)
    _print_tree(tree, 0, depth)

nrdk upgrade-config

Upgrade implementation references in hydra configs.

Usage

First test with a dry run:

nrdk upgrade-config <target> --path ./results --dry-run
If you're happy with what you see, you can then run the actual upgrade:
nrdk upgrade-config <target> --to <to> --path ./results

Danger

This is a potentially destructive operation! Always run with --dry-run first, and make sure that to does not overlap with any other existing implementations in your configs.

You can also use the upgrade-config tool to check for this overlap first:

nrdk upgrade-config <to> --path ./results --dry-run
# Shouldn't return any of the config files you are planning to upgrade

For each valid results directory in the specified path, search for all _target_ fields in the hydra config, and replace any occurrences of target with to.

Parameters:

Name Type Description Default
target str

full path name of the implementation to replace.

required
to str | None

full path name of the implementation to replace with.

None
dry_run bool

if True, only log the changes that would be made, and do not actually modify any files.

False
path str

path to search for results directories.

'.'
follow_symlinks bool

whether to follow symlinks when searching for results.

False
Source code in src/nrdk/_cli/upgrade.py
def cli_upgrade(
    target: str, /, to: str | None = None,
    dry_run: bool = False, path: str = ".", follow_symlinks: bool = False
) -> None:
    """Upgrade implementation references in hydra configs.

    !!! info "Usage"

        First test with a dry run:
        ```sh
        nrdk upgrade-config <target> --path ./results --dry-run
        ```
        If you're happy with what you see, you can then run the actual upgrade:
        ```sh
        nrdk upgrade-config <target> --to <to> --path ./results
        ```

    !!! danger

        This is a potentially destructive operation! Always run with
        `--dry-run` first, and make sure that `to` does not overlap with
        any other existing implementations in your configs.

        You can also use the `upgrade-config` tool to check for this overlap
        first:
        ```sh
        nrdk upgrade-config <to> --path ./results --dry-run
        # Shouldn't return any of the config files you are planning to upgrade
        ```

    For each valid [results directory][nrdk.framework.Result] in the specified
    `path`, search for all `_target_` fields in the hydra config, and replace
    any occurrences of `target` with `to`. A match requires `target` to be
    followed by whitespace or the end of the file, so e.g. `foo.Bar` does not
    match `foo.BarBaz`.

    Args:
        target: full path name of the implementation to replace.
        to: full path name of the implementation to replace with.
        dry_run: if `True`, only log the changes that would be made, and do not
            actually modify any files.
        path: path to search for results directories.
        follow_symlinks: whether to follow symlinks when searching for results.

    Raises:
        ValueError: if `to` is not given and `dry_run` is `False`.
    """
    # Validate arguments before walking the (possibly large) results tree.
    if not dry_run and to is None:
        raise ValueError("Must specify `to` when not doing a dry run.")

    pattern = re.compile(rf"_target_\s*:\s*{re.escape(target)}(?=\s|$)")
    results = Result.find(path, follow_symlinks=follow_symlinks)

    if dry_run:
        # Group hits by their surrounding context so identical snippets
        # across many configs are shown once, followed by every location.
        all_matches = {}
        for r in results:
            config_path = os.path.join(r, ".hydra", "config.yaml")
            if not os.path.exists(config_path):
                continue
            with open(config_path, "r") as f:
                config = f.read()
            for line_num, context in _search(config, pattern, context_size=2):
                all_matches.setdefault(context, []).append(
                    (config_path, line_num))

        for context, locations in all_matches.items():
            print(
                f"Found {len(locations)} occurrence(s) of '{target}' "
                f"with this context:")
            print(Panel(context))
            print(Columns(
                f"{os.path.relpath(loc, path)}:{line}"
                for loc, line in locations))
            print()
        return

    # Substitute via a callable so `to` is inserted literally instead of
    # being parsed as a regex replacement template (backslashes, `\g<...>`).
    replacement = f"_target_: {to}"
    for r in results:
        config_path = os.path.join(r, ".hydra", "config.yaml")
        if not os.path.exists(config_path):
            continue
        with open(config_path, "r") as f:
            config = f.read()

        new_config, count = pattern.subn(lambda _: replacement, config)
        if count:
            print(f"Upgrading {count} occurrence(s): {config_path}")
            with open(config_path, "w") as f:
                f.write(new_config)

nrdk validate

Validate results directories.

Usage

nrdk validate <path> --follow-symlinks

For each valid results directory in the specified path, check that all expected files are present:

File Description
.hydra/config.yaml Hydra configuration used for the run.
checkpoints/last.ckpt Last model checkpoint saved during training.
eval/ Directory containing evaluation outputs.
checkpoints.yaml Checkpoint index; absence indicates a crashed run.
events.out.tfevents.* Tensorboard log files.

Parameters:

Name Type Description Default
path str

path to search for results directories.

required
follow_symlinks bool

whether to follow symlinks when searching for results.

False
show_all bool

show all results instead of just results with missing files.

False
Source code in src/nrdk/_cli/validate.py
def cli_validate(
    path: str, /, follow_symlinks: bool = False, show_all: bool = False,
) -> None:
    """Check results directories for missing files.

    !!! info "Usage"

        ```sh
        nrdk validate <path> --follow-symlinks
        ```

    Every valid [results directory][nrdk.framework.Result] found under `path`
    is checked for the following expected files:

    | File                    | Description                                        |
    | ----------------------- | -------------------------------------------------- |
    | `.hydra/config.yaml`    | Hydra configuration used for the run.              |
    | `checkpoints/last.ckpt` | Last model checkpoint saved during training.       |
    | `eval/`                 | Directory containing evaluation outputs.           |
    | `checkpoints.yaml`      | Checkpoint index; absence indicates a crashed run. |
    | `events.out.tfevents.*` | Tensorboard log files.                             |

    A one-line summary is printed, followed by a table with one row per
    incomplete directory (or one row per directory when `show_all` is set).

    Args:
        path: path to search for results directories.
        follow_symlinks: whether to follow symlinks when searching for results.
        show_all: show all results instead of just results with missing files.
    """
    results = Result.find(path, follow_symlinks=follow_symlinks, strict=False)

    expected = (
        ".hydra/config.yaml",
        "checkpoints/last.ckpt",
        "eval",
        "checkpoints.yaml",
    )
    marks = {
        True: u'[green]\u2713[/green]',
        False: u'[bold red]\u2718[/bold red]',
    }

    table = Table()
    table.add_column("path", justify="right", style="cyan")
    for title in (
        "config.yaml", "last.ckpt", "eval", "checkpoints.yaml", "tfevents"
    ):
        table.add_column(title, justify="left")

    incomplete = 0
    for result_dir in results:
        present = [
            os.path.exists(os.path.join(result_dir, name)) for name in expected
        ]
        # Tensorboard files have run-specific suffixes, so match by prefix.
        present.append(any(
            entry.startswith("events.out.tfevents.")
            for entry in os.listdir(result_dir)
        ))
        complete = all(present)
        if not complete:
            incomplete += 1
        if show_all or not complete:
            table.add_row(
                os.path.relpath(result_dir, path),
                *(marks[flag] for flag in present))

    if incomplete:
        print(
            f"Found {len(results)} results directories with {incomplete} "
            f"incomplete results.")
    else:
        print(f"All {len(results)} results directories are complete.")

    print(table)