xwr.nn

Radar preprocessing toolkit for neural network training.

Warning

This submodule requires torch and torchvision to be installed, e.g. via the nn extra.

Tip

This submodule supports both np.ndarray and torch.Tensor inputs directly, out of the box (with no additional dependency overhead, since torch is already required for resizing).

When converting a complex spectrum to a real-valued representation, we can apply a range of data augmentations. The supported augmentations, named according to the abstract_dataloader.ext.augment conventions, are:

Augmentation key   Description
azimuth_flip       Flip the spectrum along the azimuth axis.
doppler_flip       Flip the spectrum along the Doppler axis.
range_scale        Rescale the range axis by a random factor (crops if > 1, zero-pads far ranges if < 1).
speed_scale        Rescale the Doppler axis by a random factor (wraps if > 1, zero-pads high velocities if < 1).
radar_scale        Multiply the spectrum magnitude by a random factor.
radar_phase        Apply a random global phase shift to the whole frame (read but ignored by Magnitude).

Sample Hydra configuration for abstract_dataloader.ext.augment:
_target_: abstract_dataloader.ext.augment.Augmentations
azimuth_flip:
  _target_: abstract_dataloader.ext.augment.Bernoulli
  p: 0.5
doppler_flip:
  _target_: abstract_dataloader.ext.augment.Bernoulli
  p: 0.5
radar_scale:
  _target_: abstract_dataloader.ext.augment.TruncatedLogNormal
  std: 0.2
  clip: 2.0
radar_phase:
  _target_: abstract_dataloader.ext.augment.Uniform
  lower: -3.14159265
  upper: 3.14159265
range_scale:
  _target_: abstract_dataloader.ext.augment.Uniform
  lower: 1.0
  upper: 2.0
speed_scale:
  _target_: abstract_dataloader.ext.augment.TruncatedLogNormal
  std: 0.2
  clip: 2.0
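Outside Hydra, the mapping consumed by the representations is simply a dict keyed by the names above. The sketch below is a hypothetical numpy-only stand-in for abstract_dataloader.ext.augment.Augmentations (sample_aug is not part of either library). The Bernoulli and Uniform entries mirror the configuration above; for the two TruncatedLogNormal entries it assumes exp(z) with z ~ Normal(0, std) clipped to +/- clip * std, which may not match how the library actually defines that distribution.

```python
import math

import numpy as np


def sample_aug(rng: np.random.Generator) -> dict:
    """Hypothetical sampler producing the keys read by xwr.nn representations.

    ASSUMPTION: TruncatedLogNormal(std, clip) is modeled as exp(z), where
    z ~ Normal(0, std) is clipped to [-clip * std, clip * std]; the real
    abstract_dataloader distribution may be defined differently.
    """
    def truncated_lognormal(std: float, clip: float) -> float:
        z = float(np.clip(rng.normal(0.0, std), -clip * std, clip * std))
        return math.exp(z)

    return {
        "azimuth_flip": bool(rng.random() < 0.5),  # Bernoulli(p=0.5)
        "doppler_flip": bool(rng.random() < 0.5),  # Bernoulli(p=0.5)
        "radar_scale": truncated_lognormal(std=0.2, clip=2.0),
        "radar_phase": float(rng.uniform(-math.pi, math.pi)),
        "range_scale": float(rng.uniform(1.0, 2.0)),
        "speed_scale": truncated_lognormal(std=0.2, clip=2.0),
    }


aug = sample_aug(np.random.default_rng(0))
```

The representations read every key with aug.get(key, default) and a no-op default, so passing the default {} disables all augmentations.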

xwr.nn.Magnitude

Bases: Representation

Real spectrum magnitude with phase discarded.

Parameters:

  • scale (float, default 1e-06): scale factor applied to the magnitude before the transform.
  • transform (Literal['log', 'sqrt', 'linear'], default 'sqrt'): transformation applied to the scaled magnitude.
  • eps (float, default 1e-06): small offset added to the scaled magnitude in the log transform to avoid log(0).
Source code in src/xwr/nn/representations.py
class Magnitude(Representation):
    """Real spectrum magnitude with phase discarded.

    Args:
        scale: scale factor to apply to the magnitude.
        transform: transformation to apply to the magnitude.
        eps: small value to avoid log(0) in log transform.
    """

    def __call__(
        self, spectrum: Complex64[TArray, "batch doppler el az rng"],
        aug: Mapping[str, Any] = {}
    ) -> Float32[TArray, "batch doppler el az rng 1"]:
        """Get spectrum amplitude.

        Type Parameters:
            - `TArray`: array type; `np.ndarray` or `torch.Tensor`.

        Args:
            spectrum: complex spectrum as output by one of the
                [`xwr.rsp.numpy`][xwr.rsp.numpy] classes.
            aug: augmentations to apply.

        Returns:
            Real 4D spectrum with a leading batch axis and trailing
                `[magnitude]` channel axis.
        """
        spectrum = self._flip(spectrum, aug)

        magnitude = backend.abs(spectrum)
        if aug.get("radar_scale", 1.0) != 1.0:
            magnitude *= aug["radar_scale"]
        magnitude = self._scale(magnitude)

        # Phase is unused; explicitly fetch it here to "touch" it
        _ = aug.get("radar_phase", 0.0)

        resized = resize(
            magnitude, range_scale=aug.get("range_scale", 1.0),
            speed_scale=aug.get("speed_scale", 1.0))

        return cast(TArray, resized[..., None])

__call__

__call__(
    spectrum: Complex64[TArray, "batch doppler el az rng"],
    aug: Mapping[str, Any] = {},
) -> Float32[TArray, "batch doppler el az rng 1"]

Get spectrum amplitude.

Type Parameters
  • TArray: array type; np.ndarray or torch.Tensor.

Parameters:

  • spectrum (Complex64[TArray, 'batch doppler el az rng'], required): complex spectrum as output by one of the xwr.rsp.numpy classes.
  • aug (Mapping[str, Any], default {}): augmentations to apply, keyed by the augmentation names listed at the top of this page; any missing key leaves that augmentation disabled.

Returns:

  • Float32[TArray, 'batch doppler el az rng 1']: real 4D (doppler, el, az, rng) spectrum with a leading batch axis and a trailing [magnitude] channel axis.
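A minimal numpy-only sketch of what Magnitude.__call__ computes, assuming the default transform='sqrt' and scale=1e-6, and assuming range_scale and speed_scale stay at 1.0 so that resize() is a no-op; flips are omitted. magnitude_sketch is a hypothetical helper, not part of xwr.

```python
import numpy as np


def magnitude_sketch(spectrum: np.ndarray, aug: dict, scale: float = 1e-6,
                     transform: str = "sqrt", eps: float = 1e-6) -> np.ndarray:
    """Hypothetical re-implementation of Magnitude without flips or resizing.

    spectrum: complex array with axes (batch, doppler, el, az, rng).
    """
    magnitude = np.abs(spectrum)                         # phase is discarded
    magnitude = magnitude * aug.get("radar_scale", 1.0)  # radar_scale augmentation
    magnitude = magnitude * scale                        # scale before the transform
    if transform == "log":
        magnitude = np.log(magnitude + eps)
    elif transform == "sqrt":
        magnitude = np.sqrt(magnitude)
    return magnitude[..., None]                          # trailing [magnitude] channel


spectrum = np.full((2, 8, 2, 4, 16), 3e6 + 4e6j, dtype=np.complex64)  # |x| = 5e6
# Every element becomes sqrt(5e6 * 0.8 * 1e-6) = sqrt(4) = 2, up to float32 rounding.
out = magnitude_sketch(spectrum, {"radar_scale": 0.8})
```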

xwr.nn.PhaseAngle

Bases: Representation

Represents a complex spectrum as real magnitude and phase-angle channels.

Parameters:

  • scale (float, default 1e-06): scale factor applied to the magnitude before the transform.
  • transform (Literal['log', 'sqrt', 'linear'], default 'sqrt'): transformation applied to the scaled magnitude.
  • eps (float, default 1e-06): small offset added to the scaled magnitude in the log transform to avoid log(0).
Source code in src/xwr/nn/representations.py
class PhaseAngle(Representation):
    """Complex spectrum with magnitude and phase angle.

    Args:
        scale: scale factor to apply to the magnitude.
        transform: transformation to apply to the magnitude.
        eps: small value to avoid log(0) in log transform.
    """

    def __call__(
        self, spectrum: Complex64[TArray, "batch doppler el az rng"],
        aug: Mapping[str, Any] = {}
    ) -> Float32[TArray, "batch doppler el az rng 2"]:
        """Get complex spectrum representation.

        Type Parameters:
            - `TArray`: array type; `np.ndarray` or `torch.Tensor`.

        Args:
            spectrum: complex spectrum as output by one of the
                [`xwr.rsp.numpy`][xwr.rsp.numpy] classes.
            aug: augmentations to apply.

        Returns:
            Real 4D spectrum with a leading batch axis and trailing
                `[magnitude, phase]` channel axis.
        """
        spectrum = self._flip(spectrum, aug)

        magnitude = backend.abs(spectrum)
        phase = backend.angle(spectrum)

        if aug.get("radar_scale", 1.0) != 1.0:
            magnitude *= aug["radar_scale"]
        if aug.get("radar_phase", 0.0) != 0.0:
            phase += aug["radar_phase"]
        magnitude = self._scale(magnitude)

        range_scale = aug.get("range_scale", 1.0)
        speed_scale = aug.get("speed_scale", 1.0)
        return cast(TArray, backend.stack([
            resize(magnitude, range_scale=range_scale, speed_scale=speed_scale),
            resize(phase, range_scale, speed_scale) % (2 * np.pi)
        ], axis=-1))

__call__

__call__(
    spectrum: Complex64[TArray, "batch doppler el az rng"],
    aug: Mapping[str, Any] = {},
) -> Float32[TArray, "batch doppler el az rng 2"]

Get complex spectrum representation.

Type Parameters
  • TArray: array type; np.ndarray or torch.Tensor.

Parameters:

  • spectrum (Complex64[TArray, 'batch doppler el az rng'], required): complex spectrum as output by one of the xwr.rsp.numpy classes.
  • aug (Mapping[str, Any], default {}): augmentations to apply, keyed by the augmentation names listed at the top of this page; any missing key leaves that augmentation disabled.

Returns:

  • Float32[TArray, 'batch doppler el az rng 2']: real 4D (doppler, el, az, rng) spectrum with a leading batch axis and a trailing [magnitude, phase] channel axis; the phase is wrapped into [0, 2π).
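The phase channel can be checked with a few lines of numpy. This hypothetical sketch (phase_angle_sketch is not part of xwr) skips flips, magnitude scaling, and resizing, and shows that adding radar_phase to the angle and wrapping into [0, 2π) gives the same result as rotating each complex value by exp(1j * radar_phase).

```python
import numpy as np


def phase_angle_sketch(spectrum: np.ndarray, radar_phase: float = 0.0) -> np.ndarray:
    """Hypothetical stand-in for PhaseAngle without flips, scaling, or resizing."""
    magnitude = np.abs(spectrum)
    phase = (np.angle(spectrum) + radar_phase) % (2 * np.pi)  # wrap into [0, 2*pi)
    return np.stack([magnitude, phase], axis=-1)             # [magnitude, phase]


x = np.array([1 + 1j, -2 + 0j, -3j])
theta = 2.5
out = phase_angle_sketch(x, theta)
# Equivalent view: rotate the complex values first, then take the wrapped angle.
rotated = np.angle(x * np.exp(1j * theta)) % (2 * np.pi)
```

PhaseVec, by contrast, multiplies by exp(-1j * radar_phase), i.e. it rotates in the opposite direction; with radar_phase drawn from a symmetric range such as Uniform(-π, π), both conventions yield the same distribution of shifts.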

xwr.nn.PhaseVec

Bases: Representation

Represents a complex spectrum as a real magnitude channel plus a normalized (re, im) phase-vector pair.

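Parameters:

  • scale (float, default 1e-06): scale factor applied to the magnitude before the transform.
  • transform (Literal['log', 'sqrt', 'linear'], default 'sqrt'): transformation applied to the scaled magnitude.
  • eps (float, default 1e-06): small offset added to the scaled magnitude in the log transform to avoid log(0); PhaseVec also uses it as a lower bound on the magnitude when normalizing the phase vector.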
Source code in src/xwr/nn/representations.py
class PhaseVec(Representation):
    """Complex spectrum with magnitude and re/im phase vector.

    Args:
        scale: scale factor to apply to the magnitude.
        transform: transformation to apply to the magnitude.
        eps: small value to avoid log(0) in log transform.
    """

    def __call__(
        self, spectrum: Complex64[TArray, "batch doppler el az rng"],
        aug: Mapping[str, Any] = {}
    ) -> Float32[TArray, "batch doppler el az rng 3"]:
        """Get magnitude and phase-vector representation.

        Type Parameters:
            - `TArray`: array type; `np.ndarray` or `torch.Tensor`.

        Args:
            spectrum: complex spectrum as output by one of the
                [`xwr.rsp.numpy`][xwr.rsp.numpy] classes.
            aug: augmentations to apply.

        Returns:
            Real 4D spectrum with a leading batch axis and trailing
                `[magnitude, re, im]` channel axis.
        """
        spectrum = self._flip(spectrum, aug)

        magnitude = backend.abs(spectrum)
        normed = spectrum / backend.maximum(magnitude, self.eps)
        if aug.get("radar_phase", 0.0) != 0.0:
            # aug["radar_phase"] is a python scalar, so we always use cmath
            normed *= cmath.exp(-1j * aug["radar_phase"])
        re = backend.real(normed)
        im = backend.imag(normed)

        if aug.get("radar_scale", 1.0) != 1.0:
            magnitude *= aug["radar_scale"]
        magnitude = self._scale(magnitude)

        range_scale = aug.get("range_scale", 1.0)
        speed_scale = aug.get("speed_scale", 1.0)
        return cast(TArray, backend.stack([
            resize(magnitude, range_scale=range_scale, speed_scale=speed_scale),
            resize(re, range_scale=range_scale, speed_scale=speed_scale),
            resize(im, range_scale=range_scale, speed_scale=speed_scale)
        ], axis=-1))

__call__

__call__(
    spectrum: Complex64[TArray, "batch doppler el az rng"],
    aug: Mapping[str, Any] = {},
) -> Float32[TArray, "batch doppler el az rng 3"]

Get the [magnitude, re, im] representation of a complex spectrum.

Type Parameters
  • TArray: array type; np.ndarray or torch.Tensor.

Parameters:

  • spectrum (Complex64[TArray, 'batch doppler el az rng'], required): complex spectrum as output by one of the xwr.rsp.numpy classes.
  • aug (Mapping[str, Any], default {}): augmentations to apply, keyed by the augmentation names listed at the top of this page; any missing key leaves that augmentation disabled.

Returns:

  • Float32[TArray, 'batch doppler el az rng 3']: real 4D (doppler, el, az, rng) spectrum with a leading batch axis and a trailing [magnitude, re, im] channel axis.
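A hypothetical numpy sketch of the PhaseVec channels (phase_vec_sketch is not part of xwr; flips, magnitude scaling, and resizing are omitted). It illustrates the role of eps: bins whose magnitude is at least eps get a unit-length (re, im) vector, while the floor keeps exactly-zero bins at (0, 0) instead of producing NaN.

```python
import cmath

import numpy as np


def phase_vec_sketch(spectrum: np.ndarray, radar_phase: float = 0.0,
                     eps: float = 1e-6) -> np.ndarray:
    """Hypothetical stand-in for PhaseVec without flips, scaling, or resizing."""
    magnitude = np.abs(spectrum)
    normed = spectrum / np.maximum(magnitude, eps)  # unit length where |x| >= eps
    if radar_phase != 0.0:
        normed = normed * cmath.exp(-1j * radar_phase)  # same sign as PhaseVec
    return np.stack([magnitude, normed.real, normed.imag], axis=-1)


x = np.array([3 + 4j, 0j, -1e-7 + 0j])
out = phase_vec_sketch(x)                        # rows: [magnitude, re, im]
rotated = phase_vec_sketch(x, radar_phase=1.0)   # magnitudes are unchanged
```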

xwr.nn.Representation

Bases: ABC

Abstract base class for representations that map a complex spectrum to real-valued channels.

Parameters:

  • scale (float, default 1e-06): scale factor applied to the magnitude before the transform.
  • transform (Literal['log', 'sqrt', 'linear'], default 'sqrt'): transformation applied to the scaled magnitude.
  • eps (float, default 1e-06): small offset added to the scaled magnitude in the log transform to avoid log(0).
Source code in src/xwr/nn/representations.py
class Representation(ABC):
    """Generic representation which maps complex spectrum to real channels.

    Args:
        scale: scale factor to apply to the magnitude.
        transform: transformation to apply to the magnitude.
        eps: small value to avoid log(0) in log transform.
    """

    def __init__(
        self, scale: float = 1e-6,
        transform: Literal["log", "sqrt", "linear"] = "sqrt",
        eps: float = 1e-6
    ) -> None:
        self.scale = scale
        self.eps = eps
        self._magnitude_transform = transform

    def _flip(
        self, spectrum: Shaped[TArray, "batch doppler el az rng"],
        aug: Mapping[str, Any] = {}
    ) -> Shaped[TArray, "batch doppler el az rng"]:
        if aug.get("azimuth_flip", False):
            spectrum = backend.flip(spectrum, axis=-2)
        if aug.get("doppler_flip", False):
            # Doppler is axis -4 in "batch doppler el az rng"; -3 would be el
            spectrum = backend.flip(spectrum, axis=-4)
        return spectrum

    def _scale(
        self, data: Float32[TArray, "..."]
    ) -> Float32[TArray, "..."]:
        data = cast(TArray, data * self.scale)
        if self._magnitude_transform == "log":
            return backend.log(cast(TArray, data + self.eps))
        elif self._magnitude_transform == "sqrt":
            return backend.sqrt(data)
        else:
            return data

    @abstractmethod
    def __call__(
        self, spectrum: Complex64[TArray, "batch doppler el az rng"],
        aug: Mapping[str, Any] = {}
    ) -> Float32[TArray, "batch doppler el az rng c"]:
        """Get spectrum representation.

        Type Parameters:
            - `TArray`: array type; `np.ndarray` or `torch.Tensor`.

        Args:
            spectrum: complex spectrum as output by one of the
                [`xwr.rsp.numpy`][xwr.rsp.numpy] classes.
            aug: dictionary of augmentations to apply.

        Returns:
            Real 4D spectrum with a leading batch axis and trailing channel
                axis.
        """
        ...

    def __repr__(self) -> str:  # noqa: D105
        return (
            f"{self.__class__.__name__}({self._magnitude_transform} * "
            f"{self.scale})")

__call__ abstractmethod

__call__(
    spectrum: Complex64[TArray, "batch doppler el az rng"],
    aug: Mapping[str, Any] = {},
) -> Float32[TArray, "batch doppler el az rng c"]

Get spectrum representation.

Type Parameters
  • TArray: array type; np.ndarray or torch.Tensor.

Parameters:

  • spectrum (Complex64[TArray, 'batch doppler el az rng'], required): complex spectrum as output by one of the xwr.rsp.numpy classes.
  • aug (Mapping[str, Any], default {}): dictionary of augmentations to apply, keyed by the augmentation names listed at the top of this page; any missing key leaves that augmentation disabled.

Returns:

  • Float32[TArray, 'batch doppler el az rng c']: real 4D (doppler, el, az, rng) spectrum with a leading batch axis and a trailing channel axis of size c, which depends on the subclass.
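The two shared helpers are easy to mirror in numpy, e.g. when testing a new subclass. In this hypothetical sketch, flip_sketch and scale_sketch are not xwr functions; they assume the documented batch doppler el az rng layout, under which azimuth is axis 3 (equivalently -2) and Doppler is axis 1 (equivalently -4).

```python
import numpy as np

AXES = ("batch", "doppler", "el", "az", "rng")  # documented spectrum layout


def flip_sketch(spectrum: np.ndarray, aug: dict) -> np.ndarray:
    """Hypothetical mirror of Representation._flip."""
    if aug.get("azimuth_flip", False):
        spectrum = np.flip(spectrum, axis=AXES.index("az"))       # axis 3
    if aug.get("doppler_flip", False):
        spectrum = np.flip(spectrum, axis=AXES.index("doppler"))  # axis 1
    return spectrum


def scale_sketch(data: np.ndarray, scale: float = 1e-6,
                 transform: str = "sqrt", eps: float = 1e-6) -> np.ndarray:
    """Hypothetical mirror of Representation._scale: scale, then transform."""
    data = data * scale
    if transform == "log":
        return np.log(data + eps)  # eps avoids log(0)
    if transform == "sqrt":
        return np.sqrt(data)
    return data                    # "linear": scaling only


x = np.arange(2 * 3 * 1 * 4 * 5).reshape(2, 3, 1, 4, 5)
```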

xwr.nn.resize

resize(
    spectrum: Float32[TArray, "T D *A R"],
    range_scale: float = 1.0,
    speed_scale: float = 1.0,
) -> Float32[TArray, "T D *A R"]

Resize range-Doppler spectrum.

Note

We use torchvision.transforms.Resize, which requires a round trip through a (CPU) Tensor for numpy arrays. In some limited testing, this was the fastest image-resizing method that supports antialiasing; skimage.transform.resize was particularly slow.

Type Parameters
  • TArray: array type; np.ndarray or torch.Tensor.

Parameters:

  • spectrum (Float32[TArray, 'T D *A R'], required): real-valued input channel, e.g. the magnitude of a spectrum output by one of the xwr.rsp.numpy classes, with a leading batch axis T, Doppler axis D, any antenna axes *A, and range axis R.
  • range_scale (float, default 1.0): scale factor for the range axis; crops if greater than 1.0 and zero-pads the far ranges (high indices) if less than 1.0.
  • speed_scale (float, default 1.0): scale factor for the Doppler axis; wraps if greater than 1.0 and zero-pads the highest velocities (both ends of the axis) if less than 1.0.
Source code in src/xwr/nn/utils.py
def resize(
    spectrum: Float32[TArray, "T D *A R"],
    range_scale: float = 1.0, speed_scale: float = 1.0,
) -> Float32[TArray, "T D *A R"]:
    """Resize range-Doppler spectrum.

    !!! note

        We use `torchvision.transforms.Resize`, which requires a
        round-trip through a (cpu) `Tensor` for numpy arrays. From some limited
        testing, this appears to be the most performant image resizing which
        supports antialiasing, with `skimage.transform.resize` being
        particularly slow.

    Type Parameters:
        - `TArray`: array type; `np.ndarray` or `torch.Tensor`.

    Args:
        spectrum: input spectrum as a real channel; should be output by one of
            the [`xwr.rsp.numpy`][xwr.rsp.numpy] classes.
        range_scale: scale factor for the range dimension; crops if greater
            than 1.0, and zero-pads if less than 1.0.
        speed_scale: scale factor for the Doppler dimension; wraps if greater
            than 1.0, and zero-pads if less than 1.0.
    """
    T, Nd, *A, Nr = spectrum.shape
    range_out_dim = int(range_scale * Nr)
    speed_out_dim = 2 * (int(speed_scale * Nd) // 2)

    if range_out_dim != Nr or speed_out_dim != Nd:
        resized = _resize(spectrum, nd=speed_out_dim, nr=range_out_dim)

        # Upsample -> crop
        if range_out_dim >= Nr:
            resized = resized[..., :Nr]
        # Downsample -> zero pad far ranges (high indices)
        else:
            # Use resized's shape: its Doppler axis may already differ from Nd
            pad = _zeros(
                (*resized.shape[:-1], Nr - range_out_dim), like=resized)
            resized = backend.concatenate([resized, pad], axis=-1)

        # Upsample -> wrap
        if speed_out_dim > spectrum.shape[1]:
            resized = _wrap(resized, Nd)
        # Downsample -> zero pad high velocities (low and high indices)
        else:
            pad = _zeros(
                (T, (Nd - speed_out_dim) // 2, *spectrum.shape[2:]),
                like=resized)
            resized = backend.concatenate([pad, resized, pad], axis=1)

        spectrum = cast(TArray, resized)

    return spectrum
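To see the crop and zero-pad behavior on the range axis without torch, the following hypothetical sketch (output_dims and resize_range_sketch are not xwr functions) substitutes plain linear interpolation via np.interp for torchvision's antialiased resize, so its values will not match resize() exactly. It reproduces the output-size arithmetic, including rounding the Doppler size down to an even number, and the upsample-then-crop and downsample-then-zero-pad logic for the range axis. In resize() the Doppler axis is handled analogously, except that upsampling wraps (through the internal _wrap helper, not shown) and downsampling zero-pads both ends.

```python
import numpy as np


def output_dims(Nd: int, Nr: int, range_scale: float, speed_scale: float) -> tuple[int, int]:
    """Intermediate (Doppler, range) sizes, computed as in resize()."""
    speed_out_dim = 2 * (int(speed_scale * Nd) // 2)  # rounded down to even
    range_out_dim = int(range_scale * Nr)
    return speed_out_dim, range_out_dim


def resize_range_sketch(x: np.ndarray, range_scale: float) -> np.ndarray:
    """Range-only sketch: interpolate, then crop or zero-pad back to Nr bins."""
    Nr = x.shape[-1]
    out_dim = int(range_scale * Nr)
    positions = np.linspace(0.0, Nr - 1, out_dim)  # sample points in input bins
    resized = np.apply_along_axis(
        lambda row: np.interp(positions, np.arange(Nr), row), -1, x)
    if out_dim >= Nr:
        return resized[..., :Nr]                    # upsample -> crop
    pad = np.zeros((*resized.shape[:-1], Nr - out_dim), dtype=resized.dtype)
    return np.concatenate([resized, pad], axis=-1)  # downsample -> zero far ranges


ramp = np.arange(10.0)[None, None, :]  # axes (T, D, R)
up = resize_range_sketch(ramp, 2.0)
down = resize_range_sketch(ramp, 0.5)
```

For example, with Nd = 64 and Nr = 256, speed_scale = 0.7 and range_scale = 1.5 give intermediate sizes of 44 Doppler bins (zero-padded by 10 on each side) and 384 range bins (cropped back to 256).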