NRDK CLI

Usage

The CLI tools use tyro:

  • Positional arguments ("required") are passed as positional command line arguments
  • Named arguments are passed as flagged command line arguments
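
As a rough illustration of the rule above, the sketch below classifies the parameters of a function by how they would be passed on the command line. It does not call tyro; it only inspects a signature with the standard library. The stand-in function copies the `cli_inspect` signature documented further down, and `classify_parameters` is a hypothetical helper that is not part of NRDK.

```python
import inspect


def cli_inspect(
    path: str, /, depth: int = 2, weights_only: bool = False
) -> None:
    """Stand-in with the same signature as the documented command."""


def classify_parameters(func):
    """Split parameters into positional and flagged command line arguments."""
    positional, flagged = [], []
    for name, param in inspect.signature(func).parameters.items():
        # Parameters before the `/` marker are positional-only.
        if param.kind is inspect.Parameter.POSITIONAL_ONLY:
            positional.append(name)
        else:
            flagged.append(name)
    return positional, flagged


positional, flagged = classify_parameters(cli_inspect)
print("positional:", positional)  # positional: ['path']
print("flagged:", flagged)        # flagged: ['depth', 'weights_only']
```

Under this reading, `path` is given directly after the command name, while `depth` and `weights_only` are given as flags.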

nrdk export

Export model weights from a full-service checkpoint.

Usage

Take the best checkpoint in results/experiment/version, and export the model to results/experiment/version/weights.pth:

nrdk export results/experiment/version

The checkpoint is assumed to be created by NRDKLightningModule (via PyTorch Lightning), so it contains a state_dict entry that holds the model weights, where each key has a leading model. prefix. This prefix is stripped on export.
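
The key renaming can be sketched without PyTorch. The example below applies the same anchored pattern that the source code uses (`^model\.`) to a plain dictionary. The keys and values are made up for illustration, and `strip_model_prefix` is a hypothetical helper name.

```python
import re

# Made-up keys shaped like the entries of a Lightning `state_dict`.
checkpoint_state = {
    "model.encoder.weight": [0.1, 0.2],
    "model.encoder.bias": [0.0],
    "model.head.model.weight": [1.0],
}

# Anchored at the start of the key, so only the leading prefix is removed.
PREFIX = re.compile(r"^model\.")


def strip_model_prefix(state_dict):
    """Return a copy of `state_dict` with the leading `model.` removed."""
    return {PREFIX.sub("", key): value for key, value in state_dict.items()}


exported = strip_model_prefix(checkpoint_state)
print(list(exported))
# ['encoder.weight', 'encoder.bias', 'head.model.weight']
```

Because the pattern is anchored, a `model.` that appears later in a key (as in the third entry) is left untouched.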

  • If the path points to a file, export that checkpoint.
  • If the path is a directory, the directory should have a checkpoints.yaml with a best key which specifies the best checkpoint in a checkpoints/ directory; the exported model is saved relative to the path.
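
The two cases above can be sketched as a small path-resolution function. This is a simplified illustration rather than the NRDK implementation: `resolve_export_paths` is a hypothetical helper, the `best` argument stands in for the value read from `checkpoints.yaml`, and `epoch3.ckpt` is a made-up file name.

```python
import os
import tempfile


def resolve_export_paths(path, best, output="weights.pth"):
    """Return (checkpoint_path, output_path) for an export.

    `best` stands in for the `best` value from `checkpoints.yaml`; it is
    only used when `path` is a directory.
    """
    if os.path.isdir(path):
        checkpoint = os.path.join(path, "checkpoints", best)
        # The exported weights are saved relative to the directory.
        return checkpoint, os.path.join(path, output)
    # A file path is exported directly; `output` is left unchanged.
    return path, output


with tempfile.TemporaryDirectory() as version_dir:
    checkpoint, output = resolve_export_paths(version_dir, best="epoch3.ckpt")
    print(os.path.relpath(checkpoint, version_dir))
    print(os.path.relpath(output, version_dir))
```

Note that when `path` is a file, the output path is not joined to anything, so a relative `output` is resolved against the current working directory.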

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| path | str | path to the checkpoint. | required |
| output | str | path to save the exported weights. | 'weights.pth' |
Source code in src/nrdk/_cli/export.py
import os
import re

import torch
import yaml


def cli_export(
    path: str, /, output: str = "weights.pth"
) -> None:
    """Export model weights from a full-service checkpoint.

    !!! info "Usage"

        Take the best checkpoint in `results/experiment/version`, and export
        the model to `results/experiment/version/weights.pth`:
        ```sh
        nrdk export results/experiment/version
        ```

    The checkpoint is assumed to be created by
    [`NRDKLightningModule`][nrdk.framework.] (via PyTorch Lightning), so it
    contains a `state_dict` entry that holds the model weights, where each
    key has a leading `model.` prefix.

    - If the `path` points to a file, export that checkpoint.
    - If the `path` is a directory, the directory should have a
        `checkpoints.yaml` with a `best` key which specifies the best
        checkpoint in a `checkpoints/` directory; the exported model is saved
        relative to the `path`.

    Args:
        path: path to the checkpoint.
        output: path to save the exported weights.
    """
    if os.path.isdir(path):
        # Save the exported weights relative to the given directory.
        output = os.path.join(path, output)

        # Look up the best checkpoint listed in `checkpoints.yaml`.
        try:
            with open(os.path.join(path, "checkpoints.yaml"), 'r') as f:
                contents = yaml.safe_load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"No 'checkpoints.yaml' found in {path}.") from e
        path = os.path.join(path, "checkpoints", contents['best'])

    contents = torch.load(path, map_location='cpu')
    if 'state_dict' not in contents:
        raise ValueError(f"Checkpoint {path} does not contain 'state_dict'.")

    print(f"Exporting: {path}")
    print(f"--> {output}")
    pattern = re.compile(r"^model\.")
    state_dict = {
        pattern.sub("", k): v for k, v in contents['state_dict'].items()}
    torch.save(state_dict, output)

nrdk inspect

Inspect a pytorch / pytorch lightning checkpoint.

Usage

Inspect a representative (most recent) checkpoint:

nrdk inspect results/experiment/version

If the path points to a file, inspect that checkpoint; if it points to a directory, inspect the most recent checkpoint (by modification time).
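
A minimal sketch of picking the most recent checkpoint by modification time is shown below. The actual lookup is done by a helper that is not shown on this page, so the details here are assumptions: `latest_checkpoint` is a hypothetical name, and the recursive search for files ending in `.ckpt` is a guess at which files count as checkpoints.

```python
import os
import tempfile
from pathlib import Path


def latest_checkpoint(path, suffix=".ckpt"):
    """Return `path` if it is a file, else the newest matching file under it."""
    path = Path(path)
    if path.is_file():
        return path
    # Assumption: checkpoints are files ending in `suffix`, at any depth.
    candidates = [c for c in path.rglob(f"*{suffix}") if c.is_file()]
    if not candidates:
        raise FileNotFoundError(f"No '*{suffix}' files found under {path}.")
    return max(candidates, key=lambda c: c.stat().st_mtime)


with tempfile.TemporaryDirectory() as version_dir:
    older = Path(version_dir) / "checkpoints" / "older.ckpt"
    newer = Path(version_dir) / "checkpoints" / "newer.ckpt"
    older.parent.mkdir()
    older.write_bytes(b"")
    newer.write_bytes(b"")
    # Set explicit modification times so the ordering is unambiguous.
    os.utime(older, (1000, 1000))
    os.utime(newer, (2000, 2000))
    print(latest_checkpoint(version_dir).name)  # newer.ckpt
```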

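Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| path | str | path to a checkpoint file, or to a directory containing checkpoints. | required |
| depth | int | maximum depth to print in the module/parameter tree. | 2 |
| weights_only | bool | if True, restrict loading to tensors and primitive types; if False (the default), allow loading pytorch checkpoints containing custom objects. Note that False allows arbitrary code execution! | False |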
Source code in src/nrdk/_cli/inspect.py
import numpy as np
import torch


def cli_inspect(
    path: str, /, depth: int = 2, weights_only: bool = False
) -> None:
    """Inspect a pytorch / pytorch lightning checkpoint.

    !!! info "Usage"

        Inspect a representative (most recent) checkpoint:
        ```sh
        nrdk inspect results/experiment/version
        ```

    If the `path` points to a file, inspect that checkpoint; if it points to a
    directory, inspect the most recent checkpoint (by modification time).

    Args:
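        path: path to a checkpoint file, or to a directory containing
            checkpoints.
        depth: maximum depth to print in the module/parameter tree.
        weights_only: if `True`, restrict loading to tensors and primitive
            types; if `False` (the default), allow loading pytorch
            checkpoints containing custom objects. Note that `False` allows
            arbitrary code execution!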
    """
    path = _get_model_path(path)
    contents = torch.load(path, map_location='cpu', weights_only=weights_only)

    if 'state_dict' in contents:
        contents = contents['state_dict']

    tree = {}
    for k, v in contents.items():
        name = k.split('.')
        subtree = tree
        for subpath in name[:-1]:
            subtree = subtree.setdefault(subpath, {})

        if isinstance(v, torch.Tensor):
            # Number of elements; `int` keeps scalar tensors from yielding 1.0.
            subtree[name[-1]] = int(np.prod(tuple(v.shape)))

    tree = _collapse_singletons(tree)
    _print_tree(tree, 0, depth)
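
The tree-building loop above can be tried without PyTorch by replacing tensors with shape tuples. The sketch below is only an approximation: the `_collapse_singletons` and `_print_tree` helpers are not shown on this page, so `build_tree`, `total`, and `format_tree` are hypothetical stand-ins, and the printed layout is an assumption rather than the real output format. The keys and shapes are made up.

```python
import math

# Made-up parameter names and shapes, standing in for a `state_dict`.
shapes = {
    "model.encoder.layer.weight": (64, 32),
    "model.encoder.layer.bias": (64,),
    "model.head.weight": (10, 64),
}


def build_tree(shapes):
    """Nest dotted keys into dictionaries; leaves hold element counts."""
    tree = {}
    for key, shape in shapes.items():
        *parents, leaf = key.split(".")
        subtree = tree
        for name in parents:
            subtree = subtree.setdefault(name, {})
        subtree[leaf] = math.prod(shape)
    return tree


def total(node):
    """Sum the element counts below a node."""
    if isinstance(node, dict):
        return sum(total(child) for child in node.values())
    return node


def format_tree(node, depth, level=0):
    """Render one line per node, descending at most `depth` levels."""
    lines = []
    for name, child in node.items():
        lines.append(f"{'  ' * level}{name}: {total(child)}")
        if isinstance(child, dict) and level + 1 < depth:
            lines.extend(format_tree(child, depth, level + 1))
    return lines


print("\n".join(format_tree(build_tree(shapes), depth=2)))
# model: 2752
#   encoder: 2112
#   head: 640
```

With `depth=2`, nodes below the second level are summarized by their totals instead of being listed individually.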