xwr.rsp.jax

Radar Signal Processing in Jax.

Info

In addition to mirroring the functionality of xwr.rsp.numpy, this module also provides a range of point cloud processing algorithms.

Warning

This module is not automatically imported; you will need to explicitly import it:

from xwr.rsp import jax as rsp

Since jax is not declared as a required dependency, you will also need to install jax yourself (or install the jax extra with pip install xwr[jax]).

xwr.rsp.jax.AWR1642Boost

Bases: RSPJax

Radar Signal Processing for the AWR1642 or AWR1843 with TX2 disabled.

Antenna Array

The TI AWR1642Boost (or AWR1843Boost with TX2 disabled) has a 1x8 linear MIMO array:

1-1 1-2 1-3 1-4 2-1 2-2 2-3 2-4
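
As a rough illustration of how the 2x4 TX-RX measurements map onto this 1x8 array, the sketch below mirrors the reshape performed by `mimo_virtual_array`, using plain NumPy rather than JAX. The helper name `virtual_array_1642boost` and the labelled input are hypothetical and exist only for this example.

```python
import numpy as np


def virtual_array_1642boost(rd: np.ndarray) -> np.ndarray:
    """Flatten (tx, rx) into a single azimuth axis of 8 virtual elements."""
    batch, doppler, tx, rx, n_range = rd.shape
    if tx == 3:
        # AWR1843Boost data: drop the middle TX and keep indices 0 and 2.
        if rx != 4:
            raise ValueError(f"Expected (tx, rx)=3x4, got tx={tx} and rx={rx}.")
        rd = rd[:, :, [0, 2], :, :]
    elif tx != 2 or rx != 4:
        raise ValueError(f"Expected (tx, rx)=2x4, got tx={tx} and rx={rx}.")
    return rd.reshape(batch, doppler, 1, -1, n_range)


# Label each TX-RX pair as 10 * tx + rx so the ordering is easy to read.
labels = 10 * np.arange(1, 3)[:, None] + np.arange(1, 5)[None, :]
rd = labels.astype(np.complex64)[None, None, :, :, None]  # (1, 1, 2, 4, 1)
azimuth_order = virtual_array_1642boost(rd)[0, 0, 0, :, 0].real.astype(int)
print(azimuth_order.tolist())  # [11, 12, 13, 14, 21, 22, 23, 24]
```

The TX axis varies slowest, so the four RX channels of the first TX come first, followed by those of the second TX.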

Parameters:

Name Type Description Default
window bool | dict[Literal['range', 'doppler', 'azimuth', 'elevation'], bool]

whether to apply a hanning window. If bool, the same option is applied to all axes. If dict, specify per axis with keys "range", "doppler", "azimuth", and "elevation".

False
size dict[Literal['range', 'doppler', 'azimuth', 'elevation'], int]

target size for each axis after zero-padding, specified by axis. If an axis is not specified, it is not padded.

{}
Source code in src/xwr/rsp/jax/rsp.py
class AWR1642Boost(RSPJax):
    """Radar Signal Processing for the AWR1642 or AWR1843 with TX2 disabled.

    !!! info "Antenna Array"

        The TI AWR1642Boost (or AWR1843Boost with TX2 disabled) has a
        1x8 linear MIMO array:
        ```
        1-1 1-2 1-3 1-4 2-1 2-2 2-3 2-4
        ```

    Args:
        window: whether to apply a hanning window. If `bool`, the same option
            is applied to all axes. If `dict`, specify per axis with keys
            "range", "doppler", "azimuth", and "elevation".
        size: target size for each axis after zero-padding, specified by axis.
            If an axis is not specified, it is not padded.
    """

    def mimo_virtual_array(
        self, rd: Complex64[Array, "#batch doppler tx rx range"]
    ) -> Complex64[Array, "#batch doppler el az range"]:
        batch, doppler, tx, rx, range = rd.shape
        # 1843Boost cast as 1642Boost
        if tx == 3:
            if rx != 4:
                raise ValueError(
                    f"Expected (tx, rx)=3x4 in 1843Boost -> 1642Boost "
                    f"emulation, got tx={tx} and rx={rx}.")
            rd = rd[:, :, [0, 2], :, :]
        else:
            if tx != 2 or rx != 4:
                raise ValueError(
                    f"Expected (tx, rx)=2x4, got tx={tx} and rx={rx}.")

        return rd.reshape(batch, doppler, 1, -1, range)

xwr.rsp.jax.AWR1843AOP

Bases: RSPJax

Radar Signal Processing for AWR1843AOP.

Antenna Array

In the TI AWR1843AOP, the MIMO virtual array is arranged in a 2D grid:

1-1 2-1 3-1   ^
1-2 2-2 3-2   | Up
1-3 2-3 3-3
1-4 2-4 3-4 (TX-RX pairs)

Parameters:

Name Type Description Default
window bool | dict[Literal['range', 'doppler', 'azimuth', 'elevation'], bool]

whether to apply a hanning window. If bool, the same option is applied to all axes. If dict, specify per axis with keys "range", "doppler", "azimuth", and "elevation".

False
size dict[Literal['range', 'doppler', 'azimuth', 'elevation'], int]

target size for each axis after zero-padding, specified by axis. If an axis is not specified, it is not padded.

{}
Source code in src/xwr/rsp/jax/rsp.py
class AWR1843AOP(RSPJax):
    """Radar Signal Processing for AWR1843AOP.

    !!! info "Antenna Array"

        In the TI AWR1843AOP, the MIMO virtual array is arranged in a 2D grid:
            ```
            1-1 2-1 3-1   ^
            1-2 2-2 3-2   | Up
            1-3 2-3 3-3
            1-4 2-4 3-4 (TX-RX pairs)
            ```

    Args:
        window: whether to apply a hanning window. If `bool`, the same option
            is applied to all axes. If `dict`, specify per axis with keys
            "range", "doppler", "azimuth", and "elevation".
        size: target size for each axis after zero-padding, specified by axis.
            If an axis is not specified, it is not padded.
    """

    def mimo_virtual_array(
        self, rd: Complex64[Array, "#batch doppler tx rx range"]
    ) -> Complex64[Array, "#batch doppler el az range"]:
        _, _, tx, rx, _ = rd.shape
        if tx != 3 or rx != 4:
            raise ValueError(
                f"Expected (tx, rx)=3x4, got tx={tx} and rx={rx}.")

        return jnp.swapaxes(rd, 2, 3)

xwr.rsp.jax.AWR1843Boost

Bases: RSPJax

Radar Signal Processing for AWR1843Boost.

Antenna Array

In the TI AWR1843Boost, the MIMO virtual array has resolution 2x8, with a single 1/2-wavelength elevated middle antenna element:

TX-RX:  2-1 2-2 2-3 2-4           ^
1-1 1-2 1-3 1-4 3-1 3-2 3-3 3-4   | Up
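
The placement shown above can be sketched in a few lines of NumPy. This mirrors the indexing used in `mimo_virtual_array` but is not the library's JAX implementation; the helper name `virtual_array_1843boost` and the labelled input are hypothetical. Unoccupied positions in the elevated row are left as zeros.

```python
import numpy as np


def virtual_array_1843boost(rd: np.ndarray) -> np.ndarray:
    """Place 3x4 TX-RX measurements on a 2x8 (elevation, azimuth) grid."""
    batch, doppler, tx, rx, n_range = rd.shape
    if tx != 3 or rx != 4:
        raise ValueError(f"Expected (tx, rx)=3x4, got tx={tx} and rx={rx}.")
    mimo = np.zeros((batch, doppler, 2, 8, n_range), dtype=rd.dtype)
    mimo[:, :, 0, 2:6] = rd[:, :, 1]  # elevated row: TX 2, centered
    mimo[:, :, 1, 0:4] = rd[:, :, 0]  # lower row, left half: TX 1
    mimo[:, :, 1, 4:8] = rd[:, :, 2]  # lower row, right half: TX 3
    return mimo


# Label each TX-RX pair as 10 * tx + rx so the layout is easy to read.
labels = 10 * np.arange(1, 4)[:, None] + np.arange(1, 5)[None, :]
rd = labels.astype(np.complex64)[None, None, :, :, None]  # (1, 1, 3, 4, 1)
mimo = virtual_array_1843boost(rd)
top = mimo[0, 0, 0, :, 0].real.astype(int).tolist()
bottom = mimo[0, 0, 1, :, 0].real.astype(int).tolist()
print(top)     # [0, 0, 21, 22, 23, 24, 0, 0]
print(bottom)  # [11, 12, 13, 14, 31, 32, 33, 34]
```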

Parameters:

Name Type Description Default
window bool | dict[Literal['range', 'doppler', 'azimuth', 'elevation'], bool]

whether to apply a hanning window. If bool, the same option is applied to all axes. If dict, specify per axis with keys "range", "doppler", "azimuth", and "elevation".

False
size dict[Literal['range', 'doppler', 'azimuth', 'elevation'], int]

target size for each axis after zero-padding, specified by axis. If an axis is not specified, it is not padded.

{}
Source code in src/xwr/rsp/jax/rsp.py
class AWR1843Boost(RSPJax):
    """Radar Signal Processing for AWR1843Boost.

    !!! info "Antenna Array"

        In the TI AWR1843Boost, the MIMO virtual array has resolution 2x8, with
        a single 1/2-wavelength elevated middle antenna element:
        ```
        TX-RX:  2-1 2-2 2-3 2-4           ^
        1-1 1-2 1-3 1-4 3-1 3-2 3-3 3-4   | Up
        ```

    Args:
        window: whether to apply a hanning window. If `bool`, the same option
            is applied to all axes. If `dict`, specify per axis with keys
            "range", "doppler", "azimuth", and "elevation".
        size: target size for each axis after zero-padding, specified by axis.
            If an axis is not specified, it is not padded.
    """

    def mimo_virtual_array(
        self, rd: Complex64[Array, "#batch doppler tx rx range"]
    ) -> Complex64[Array, "#batch doppler el az range"]:
        batch, doppler, tx, rx, range = rd.shape
        if tx != 3 or rx != 4:
            raise ValueError(
                f"Expected (tx, rx)=3x4, got tx={tx} and rx={rx}.")

        mimo = jnp.zeros(
            (batch, doppler, 2, 8, range), dtype=jnp.complex64
        ).at[:, :, 0, 2:6, :].set(rd[:, :, 1, :, :]
        ).at[:, :, 1, 0:4, :].set(rd[:, :, 0, :, :]
        ).at[:, :, 1, 4:8, :].set(rd[:, :, 2, :, :])
        return mimo

    def elevation_aoa(
        self, iq: Complex64[Array, "batch slow tx rx fast"]
        | Int16[Array, "batch slow tx rx fast*2"]
    ) -> Float32[Array, "batch doppler range"]:
        """Estimate elevation angle of arrival (AoA).

        Args:
            iq: raw IQ data.

        Returns:
            Estimated elevation angle of arrival (AoA) in radians for each
                range-Doppler bin.
        """
        iq = iq_from_iiqq(iq)
        rd = self.doppler_range(iq)
        mimo = self.mimo_virtual_array(rd)[:, :, :, 2:-2]

        angle = jnp.angle(mimo)
        phase_diff: Float32[Array, "batch doppler range"] = jnp.median(
            angle[:, :, 0] - angle[:, :, 1], axis=3)
        el_angle = jnp.arcsin((phase_diff / jnp.pi + 1) % 2 - 1)
        return el_angle

elevation_aoa

elevation_aoa(
    iq: Complex64[Array, "batch slow tx rx fast"]
    | Int16[Array, "batch slow tx rx fast*2"],
) -> Float32[Array, "batch doppler range"]

Estimate elevation angle of arrival (AoA).

Parameters:

Name Type Description Default
iq Complex64[Array, 'batch slow tx rx fast'] | Int16[Array, 'batch slow tx rx fast*2']

raw IQ data.

required

Returns:

Type Description
Float32[Array, 'batch doppler range']

Estimated elevation angle of arrival (AoA) in radians for each range-Doppler bin.
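
For intuition, the final phase-to-angle step can be reproduced on its own. The sketch below assumes the two rows of the virtual array are offset by half a wavelength in elevation (as described for the AWR1843Boost above), so that a plane wave arriving at elevation `theta` produces a row-to-row phase difference of `pi * sin(theta)`. The sign convention and the helper name `elevation_from_phase` are assumptions made for illustration.

```python
import numpy as np


def elevation_from_phase(phase_diff):
    """Map a row-to-row phase difference (radians) to an elevation angle."""
    # Wrap phase_diff / pi into [-1, 1) to remove 2*pi ambiguities, then
    # invert the relation phase_diff = pi * sin(theta).
    return np.arcsin((np.asarray(phase_diff) / np.pi + 1) % 2 - 1)


theta = 0.3  # true elevation in radians
phase = np.pi * np.sin(theta)
# The same angle is recovered even if the measured phase has wrapped.
recovered = elevation_from_phase([phase, phase + 2 * np.pi, phase - 2 * np.pi])
```

Because the wrapped ratio is confined to [-1, 1), the `arcsin` argument is always valid and the estimate covers the full range of -90 to +90 degrees.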

Source code in src/xwr/rsp/jax/rsp.py
def elevation_aoa(
    self, iq: Complex64[Array, "batch slow tx rx fast"]
    | Int16[Array, "batch slow tx rx fast*2"]
) -> Float32[Array, "batch doppler range"]:
    """Estimate elevation angle of arrival (AoA).

    Args:
        iq: raw IQ data.

    Returns:
        Estimated elevation angle of arrival (AoA) in radians for each
            range-Doppler bin.
    """
    iq = iq_from_iiqq(iq)
    rd = self.doppler_range(iq)
    mimo = self.mimo_virtual_array(rd)[:, :, :, 2:-2]

    angle = jnp.angle(mimo)
    phase_diff: Float32[Array, "batch doppler range"] = jnp.median(
        angle[:, :, 0] - angle[:, :, 1], axis=3)
    el_angle = jnp.arcsin((phase_diff / jnp.pi + 1) % 2 - 1)
    return el_angle

xwr.rsp.jax.CFAR

Cell-averaging CFAR.

Expects a 2d input, with the guard and window sizes corresponding to the respective input axes.

    ┌─────────────────┐ ▲ window[0]
    │    ┌───────┐    │ │
    │    │  ┌─┐  │    │ ▼
    │    │  └─┘  │    │ ▲ guard[0]
    │    └───────┘    │ ▼
    └─────────────────┘
guard[1] ◄──► ◄───────► window[1]

Note

The user is responsible for applying the desired thresholding. For example, when using a Gaussian model, the threshold should be calculated using an inverse normal CDF (e.g. scipy.stats.norm.isf):

cfar = CFAR(guard=(2, 2), window=(4, 4))
thresholds = cfar(image)
mask = (thresholds > scipy.stats.norm.isf(0.01))

Parameters:

Name Type Description Default
guard tuple[int, int]

size of guard cells (excluded from noise estimation).

(2, 2)
window tuple[int, int]

total CFAR window size.

(4, 4)
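
To make the roles of `guard` and `window` concrete, the sketch below rebuilds the training-cell mask with NumPy and derives a threshold for a 1% false-alarm rate. The mask construction follows the constructor shown in the source listing; the helper name `cfar_training_mask` is hypothetical, and the standard-library `statistics.NormalDist` is used here as a stand-in for `scipy.stats.norm.isf`. Since `CFAR.__call__` returns `(x - mu) / sigma`, its output can be compared directly against such a normal quantile.

```python
from statistics import NormalDist

import numpy as np


def cfar_training_mask(guard=(2, 2), window=(4, 4)) -> np.ndarray:
    """1 for cells used to estimate noise, 0 for guard cells and the center."""
    w0, w1 = window
    g0, g1 = guard
    mask = np.ones((2 * w0 + 1, 2 * w1 + 1), dtype=np.float32)
    mask[w0 - g0 : w0 + g0 + 1, w1 - g1 : w1 + g1 + 1] = 0.0
    return mask


mask = cfar_training_mask()
n_training = int(mask.sum())  # 9 * 9 window cells minus 5 * 5 guard cells

# Inverse survival function of the standard normal: isf(p) = inv_cdf(1 - p).
p_false_alarm = 0.01
z_threshold = NormalDist().inv_cdf(1.0 - p_false_alarm)
```

With the default sizes, 56 of the 81 cells in the window contribute to the noise estimate, and the 1% threshold is roughly 2.33 standard deviations.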
Source code in src/xwr/rsp/jax/spectrum.py
class CFAR:
    """Cell-averaging CFAR.

    Expects a 2d input, with the `guard` and `window` sizes corresponding to
    the respective input axes.

    ```
        ┌─────────────────┐ ▲ window[0]
        │    ┌───────┐    │ │
        │    │  ┌─┐  │    │ ▼
        │    │  └─┘  │    │ ▲ guard[0]
        │    └───────┘    │ ▼
        └─────────────────┘
    guard[1] ◄──► ◄───────► window[1]
    ```

    !!! note

        The user is responsible for applying the desired thresholding.
        For example, when using a gaussian model, the threshold should be
        calculated using an inverse normal CDF (e.g. `scipy.stats.norm.isf`):

        ```python
        cfar = CFAR(guard=(2, 2), window=(4, 4))
        thresholds = cfar(image)
        mask = (thresholds > scipy.stats.norm.isf(0.01))
        ```

    Args:
        guard: size of guard cells (excluded from noise estimation).
        window: total CFAR window size.
    """

    def __init__(
        self, guard: tuple[int, int] = (2, 2), window: tuple[int, int] = (4, 4)
    ) -> None:
        w0, w1 = window
        g0, g1 = guard

        mask = np.ones((2 * w0 + 1, 2 * w1 + 1), dtype=np.float32)
        mask[w0 - g0 : w0 + g0 + 1, w1 - g1 : w1 + g1 + 1] = 0.0
        self.mask: Array = jnp.array(mask)

    def __call__(self, x: Float[Array, "d r ..."]) -> Float[Array, "d r"]:
        """Get CFAR thresholds.

        !!! note

            Boundary cells are zero-padded.

        Args:
            x: input. If more than 2 axes are present, the additional axes
                are averaged before running CFAR.

        Returns:
            CFAR threshold values for this input.
        """
        # Collapse additional axes if required
        if x.ndim > 2:
            x = jnp.mean(x.reshape(x.shape[0], x.shape[1], -1), -1)

        # Jax currently only supports 'fill', but this should be changed to
        # 'wrap' if they ever decide to add support.
        valid = convolve2d(jnp.ones_like(x), self.mask, mode="same")
        mu = convolve2d(x, self.mask, mode="same") / valid
        second_moment = convolve2d(x**2, self.mask, mode="same") / valid
        sigma = jnp.sqrt(second_moment - mu**2)

        return (x - mu) / sigma

__call__

__call__(x: Float[Array, 'd r ...']) -> Float[Array, 'd r']

Get CFAR thresholds.

Note

Boundary cells are zero-padded.

Parameters:

Name Type Description Default
x Float[Array, 'd r ...']

input. If more than 2 axes are present, the additional axes are averaged before running CFAR.

required

Returns:

Type Description
Float[Array, 'd r']

CFAR threshold values for this input.

Source code in src/xwr/rsp/jax/spectrum.py
def __call__(self, x: Float[Array, "d r ..."]) -> Float[Array, "d r"]:
    """Get CFAR thresholds.

    !!! note

        Boundary cells are zero-padded.

    Args:
        x: input. If more than 2 axes are present, the additional axes
            are averaged before running CFAR.

    Returns:
        CFAR threshold values for this input.
    """
    # Collapse additional axes if required
    if x.ndim > 2:
        x = jnp.mean(x.reshape(x.shape[0], x.shape[1], -1), -1)

    # Jax currently only supports 'fill', but this should be changed to
    # 'wrap' if they ever decide to add support.
    valid = convolve2d(jnp.ones_like(x), self.mask, mode="same")
    mu = convolve2d(x, self.mask, mode="same") / valid
    second_moment = convolve2d(x**2, self.mask, mode="same") / valid
    sigma = jnp.sqrt(second_moment - mu**2)

    return (x - mu) / sigma

xwr.rsp.jax.CFARCASO

Cell-averaging Smallest of CFAR.

Expects a 2d input, with the guard and window sizes corresponding to the respective input axes.

Info

Instead of the 2D kernel used in Cell-averaging CFAR, CASO uses a separate 1D kernel for the range and doppler axes; detection occurs only if the SNR exceeds the specified threshold on both axes.

           ┌─┐       ▲ window[0]
           │ │       │
           ├─┤       ▼
     ┌───┬─┼─┼─┬───┐
     └───┴─┼─┼─┴───┘ ▲ guard[0]
           ├─┤       ▼
           │ │
           └─┘
guard[1] ◄─►   ◄───► window[1]

Parameters:

Name Type Description Default
train_window Sequence[int]

training window size for (range, doppler).

(8, 4)
guard_window Sequence[int]

guard window size for (range, doppler).

(8, 0)
snr_thresh Sequence[float]

signal to noise ratio threshold for (range, doppler).

(5.0, 3.0)
discard_range Sequence[int]

range bins (close, far) to discard around DC.

(10, 20)
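
The one-dimensional building block can be sketched with NumPy as follows. It follows the same idea as the `_caso` helper in the source listing (two one-sided averaging kernels, with the smaller of the two averages taken as the noise level), but the function names, the toy signal, and the parameter values are hypothetical. In the class itself, this 1D detector is applied along range and along doppler, and the two masks are combined with a logical AND.

```python
import numpy as np


def make_caso_kernels(train: int, pad: int):
    """Leading and trailing averaging kernels, each covering `train` cells."""
    ker_a = np.zeros(2 * pad + 1, dtype=np.float32)
    ker_b = np.zeros(2 * pad + 1, dtype=np.float32)
    ker_a[:train] = 1.0 / train
    ker_b[-train:] = 1.0 / train
    return ker_a, ker_b


def caso_1d(signal, ker_a, ker_b, snr, pad):
    """Detect cells exceeding `snr` times the smaller one-sided noise mean."""
    noise = np.minimum(
        np.correlate(signal, ker_a, mode="valid"),
        np.correlate(signal, ker_b, mode="valid"),
    )
    return signal[pad:-pad] > snr * noise, noise


train, guard = 2, 1
pad = train + guard
signal = np.ones(20, dtype=np.float32)
signal[10] = 100.0  # a single strong target on a flat noise floor
detect, noise = caso_1d(signal, *make_caso_kernels(train, pad), snr=5.0, pad=pad)
hits = np.flatnonzero(detect) + pad  # indices into the original signal
print(hits.tolist())  # [10]
```

Taking the minimum of the two sides keeps a strong neighbour on one side from raising the noise estimate, at the cost of a somewhat higher false-alarm rate near clutter edges.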
Source code in src/xwr/rsp/jax/spectrum.py
class CFARCASO:
    """Cell-averaging Smallest of CFAR.

    Expects a 2d input, with the `guard` and `window` sizes corresponding to
    the respective input axes.

    !!! info

        Instead of the 2D kernel used in Cell-averaging CFAR, CASO uses
        a separate 1D kernel for the range and doppler axes; detection occurs
        only if the SNR exceeds the specified threshold on both axes.

    ```
               ┌─┐       ▲ window[0]
               │ │       │
               ├─┤       ▼
         ┌───┬─┼─┼─┬───┐
         └───┴─┼─┼─┴───┘ ▲ guard[0]
               ├─┤       ▼
               │ │
               └─┘
    guard[1] ◄─►   ◄───► window[1]
    ```

    Args:
        train_window: training window size for (range, doppler).
        guard_window: guard window size for (range, doppler).
        snr_thresh: signal to noise ratio threshold for (range, doppler).
        discard_range: range bins (close, far) to discard around DC.
    """

    def __init__(
        self,
        train_window: Sequence[int] = (8, 4),
        guard_window: Sequence[int] = (8, 0),
        snr_thresh: Sequence[float] = (5.0, 3.0),
        discard_range: Sequence[int] = (10, 20),
    ):
        if len(train_window) != 2:
            raise ValueError(f"Train window {train_window} must be length 2.")
        if len(guard_window) != 2:
            raise ValueError(f"Guard window {guard_window} must be length 2.")
        if len(discard_range) != 2:
            raise ValueError(
                f"Discard range {discard_range} must be length 2.")
        if len(snr_thresh) != 2:
            raise ValueError(f"SNR thresh {snr_thresh} must be length 2.")

        # discard detect object around DC
        self.discard_r = discard_range
        self.snr_r, self.snr_d = snr_thresh

        self.pad_r = train_window[0] + guard_window[0]
        self.pad_d = train_window[1] + guard_window[1]

        # caso
        def make_caso_kernels(train, pad):
            ker = np.zeros((2 * pad + 1), dtype=np.float32)
            ker_a, ker_b = ker.copy(), ker.copy()
            ker_a[:train], ker_b[-train:] = 1, 1
            ker_a /= ker_a.sum()
            ker_b /= ker_b.sum()
            return jnp.asarray(ker_a), jnp.asarray(ker_b)

        self.r_ker_a, self.r_ker_b = make_caso_kernels(
            train_window[0], self.pad_r)
        self.d_ker_a, self.d_ker_b = make_caso_kernels(
            train_window[1], self.pad_d)

    @staticmethod
    def _caso(
        signal: Float32[Array, "n"],
        ker_a: Float32[Array, "w"],
        ker_b: Float32[Array, "w"],
        snr: float,
        pad: int,
    ) -> tuple[Bool[Array, "m"], Float32[Array, "m"]]:
        """1D CFAR CASO.

        Args:
            signal: 1D frequency spectrum.
            ker_a: one side cfar kernel.
            ker_b: one side cfar kernel.
            snr: signal to noise ratio threshold.
            pad: padding number of the input signal.

        Returns:
            detection mask.
            noise level.
        """
        cor_a = jnp.correlate(signal, ker_a, mode="valid")
        cor_b = jnp.correlate(signal, ker_b, mode="valid")
        noise = jnp.minimum(cor_a, cor_b)
        detect = signal[pad:-pad] > snr * noise
        return detect, noise

    def __call__(
        self, signal_cube: Float32[Array, "doppler Rx Tx range"]
    ) -> tuple[
        Bool[Array, "range doppler"],
        Float32[Array, "range doppler"],
        Float32[Array, "range doppler"],
    ]:
        """Run 2D CFAR CASO.

        Args:
            signal_cube: post range doppler FFT radar cube in amplitude.

        Returns:
            cfar detected object mask.
            range doppler spectrum for detection.
            signal to noise ratio.
        """
        signal_cube = signal_cube.transpose(3, 0, 1, 2)
        s_r, s_d, _, _ = signal_cube.shape
        range_dopp = signal_cube.reshape(s_r, s_d, -1)

        # non-coherent signal combination along the antenna array
        signal = jnp.sum(range_dopp**2, axis=-1) + 1
        sig_discard = signal[self.discard_r[0] : -self.discard_r[1]]
        sig_pad_r = jnp.concat(
            (
                sig_discard[: self.pad_r],
                sig_discard,
                sig_discard[-self.pad_r :],
            ),
            axis=0,
        )
        sig_pad_d = jnp.pad(
            signal, ((0, 0), (self.pad_d, self.pad_d)), mode="wrap"
        )

        # detection
        detect_r, noise = jax.vmap(
            self._caso, in_axes=(1, None, None, None, None)
        )(sig_pad_r, self.r_ker_a, self.r_ker_b, self.snr_r, self.pad_r)
        detect_r, noise = detect_r.swapaxes(0, 1), noise.swapaxes(0, 1)
        detect_r = jnp.pad(
            detect_r, ((self.discard_r[0], self.discard_r[1]), (0, 0))
        )
        noise = jnp.pad(
            noise,
            ((self.discard_r[0], self.discard_r[1]), (0, 0)),
            constant_values=1,
        )
        detect_d, _ = jax.vmap(self._caso, in_axes=(0, None, None, None, None))(
            sig_pad_d, self.d_ker_a, self.d_ker_b, self.snr_d, self.pad_d
        )

        snr = signal / noise
        obj_mask = jnp.logical_and(detect_r, detect_d)

        return obj_mask, signal, snr

__call__

__call__(
    signal_cube: Float32[Array, "doppler Rx Tx range"],
) -> tuple[
    Bool[Array, "range doppler"],
    Float32[Array, "range doppler"],
    Float32[Array, "range doppler"],
]

Run 2D CFAR CASO.

Parameters:

Name Type Description Default
signal_cube Float32[Array, 'doppler Rx Tx range']

post range doppler FFT radar cube in amplitude.

required

Returns:

Type Description
Bool[Array, 'range doppler']

cfar detected object mask.

Float32[Array, 'range doppler']

range doppler spectrum for detection.

Float32[Array, 'range doppler']

signal to noise ratio.

Source code in src/xwr/rsp/jax/spectrum.py
def __call__(
    self, signal_cube: Float32[Array, "doppler Rx Tx range"]
) -> tuple[
    Bool[Array, "range doppler"],
    Float32[Array, "range doppler"],
    Float32[Array, "range doppler"],
]:
    """Run 2D CFAR CASO.

    Args:
        signal_cube: post range doppler FFT radar cube in amplitude.

    Returns:
        cfar detected object mask.
        range doppler spectrum for detection.
        signal to noise ratio.
    """
    signal_cube = signal_cube.transpose(3, 0, 1, 2)
    s_r, s_d, _, _ = signal_cube.shape
    range_dopp = signal_cube.reshape(s_r, s_d, -1)

    # non-coherent signal combination along the antenna array
    signal = jnp.sum(range_dopp**2, axis=-1) + 1
    sig_discard = signal[self.discard_r[0] : -self.discard_r[1]]
    sig_pad_r = jnp.concat(
        (
            sig_discard[: self.pad_r],
            sig_discard,
            sig_discard[-self.pad_r :],
        ),
        axis=0,
    )
    sig_pad_d = jnp.pad(
        signal, ((0, 0), (self.pad_d, self.pad_d)), mode="wrap"
    )

    # detection
    detect_r, noise = jax.vmap(
        self._caso, in_axes=(1, None, None, None, None)
    )(sig_pad_r, self.r_ker_a, self.r_ker_b, self.snr_r, self.pad_r)
    detect_r, noise = detect_r.swapaxes(0, 1), noise.swapaxes(0, 1)
    detect_r = jnp.pad(
        detect_r, ((self.discard_r[0], self.discard_r[1]), (0, 0))
    )
    noise = jnp.pad(
        noise,
        ((self.discard_r[0], self.discard_r[1]), (0, 0)),
        constant_values=1,
    )
    detect_d, _ = jax.vmap(self._caso, in_axes=(0, None, None, None, None))(
        sig_pad_d, self.d_ker_a, self.d_ker_b, self.snr_d, self.pad_d
    )

    snr = signal / noise
    obj_mask = jnp.logical_and(detect_r, detect_d)

    return obj_mask, signal, snr

xwr.rsp.jax.CalibratedSpectrum

Bases: Generic[TRSP]

Radar processing with zero-doppler calibration.

Zero Doppler Calibration

Due to the antenna geometry and radar returns from the data collection rig which is mounted rigidly to the radar, the radar spectrum has a substantial constant offset in the zero-doppler bins.

  • We assume that the range-Doppler plots are sparse, and take the median across a number of sample frames for the zero-doppler bin to estimate this offset.
  • If a hanning window is applied, we instead calculate the offset across doppler bins [-1, 1] to account for doppler bleed.
  • This calculated offset is subtracted from the calculated spectrum.
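
The three steps above can be sketched with NumPy as follows. This is a simplified illustration rather than the library's implementation: it assumes the median is taken per cell across frames (so the patch keeps the doppler, elevation, azimuth, and range axes) and that zero doppler sits at index `n_doppler // 2`. The helper names `zero_doppler_patch` and `apply_patch` are hypothetical.

```python
import numpy as np


def zero_doppler_patch(spectra: np.ndarray, hann: bool = False):
    """Median zero-doppler magnitude over frames.

    spectra: real spectrum with shape (frames, doppler, el, az, range).
    """
    zero = spectra.shape[1] // 2
    # Widen to bins [-1, 1] around zero doppler when a Hann window is used.
    bins = slice(zero - 1, zero + 2) if hann else slice(zero, zero + 1)
    return np.median(spectra[:, bins], axis=0), bins


def apply_patch(spectrum: np.ndarray, patch: np.ndarray, bins: slice):
    """Subtract the patch from the selected doppler bins, clipping at zero."""
    out = spectrum.copy()
    out[:, bins] = np.maximum(spectrum[:, bins] - patch, 0.0)
    return out


frames = np.zeros((5, 8, 1, 2, 4), dtype=np.float32)
frames[:, 4] += 3.0            # constant zero-doppler offset in every frame
frames[2, 4, 0, 0, 1] += 10.0  # one sparse target in a single frame
patch, bins = zero_doppler_patch(frames)
corrected = apply_patch(frames, patch, bins)
```

Because the target appears in only one of the five frames, the median ignores it: the patch captures just the constant offset, and the target survives the subtraction.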

Parameters:

Name Type Description Default
rsp TRSP

RSP pipeline to use.

required
Source code in src/xwr/rsp/jax/spectrum.py
class CalibratedSpectrum(Generic[TRSP]):
    """Radar processing with zero-doppler calibration.

    !!! info "Zero Doppler Calibration"

        Due to the antenna geometry and radar returns from the data collection
        rig which is mounted rigidly to the radar, the radar spectrum has a
        substantial constant offset in the zero-doppler bins.

        - We assume that the range-Doppler plots are sparse, and take the
          median across a number of sample frames for the zero-doppler bin to
          estimate this offset.
        - If a hanning window is applied, we instead calculate the offset
          across doppler bins `[-1, 1]` to account for doppler bleed.
        - This calculated offset is subtracted from the calculated spectrum.

    Args:
        rsp: RSP pipeline to use.
    """

    def __init__(
        self,
        rsp: TRSP,
    ) -> None:
        self.rsp = rsp

    def calibration_patch(
        self,
        sample: Complex64[Array, "n slow tx rx fast"]
        | Int16[Array, "n slow tx rx fast2"],
        batch: int = 1,
    ) -> Float32[Array, "doppler el az range"]:
        """Create a calibration patch for zero-doppler correction.

        Args:
            sample: sample IQ data to use for calibration.
            batch: sample size for RSP processing. Uses batch size `1` by
                default; should evenly divide the number of samples.

        Returns:
            Patch of the doppler-range-azimuth image which should be subtracted
                from the zero-doppler bins of the range-doppler-angle spectrum.
        """
        sample = iq_from_iiqq(sample)

        s0 = self.rsp(sample[:batch])
        shape = s0.shape[1:]

        zero = shape[0] // 2
        start, stop = zero, zero + 1
        if "doppler" in self.rsp.window:
            start -= 1
            stop += 1
        self.slice = (slice(None), slice(start, stop))

        @jax.jit
        def _calib(frames) -> Float32[Array, "batch slice az el range"]:
            return jnp.abs(self.rsp(frames))[self.slice]

        batched = sample.reshape(-1, batch, *sample.shape[1:])
        slices = [s0[self.slice]] + [_calib(batch) for batch in batched]
        return jnp.median(jnp.concatenate(slices, axis=0))

    def __call__(
        self,
        iq: Complex64[Array, "#batch doppler tx rx range"]
        | Int16[Array, "#batch doppler tx rx range2"],
        calib: Float32[Array, "doppler el az range"],
    ) -> Float32[Array, "batch doppler el az range"]:
        """Run radar spectrum processing pipeline.

        !!! note

            After subtracting the calibration patch, any negative values are
            clipped to zero.

        Args:
            iq: batch of IQ data to run.
            calib: calibration patch to apply.

        Returns:
            Doppler-elevation-azimuth-range real spectrum, with zero doppler
                correction applied.
        """
        raw = jnp.abs(self.rsp(iq))
        return raw.at[self.slice].set(jnp.maximum(raw[self.slice] - calib, 0.0))

__call__

__call__(
    iq: Complex64[Array, "#batch doppler tx rx range"]
    | Int16[Array, "#batch doppler tx rx range2"],
    calib: Float32[Array, "doppler el az range"],
) -> Float32[Array, "batch doppler el az range"]

Run radar spectrum processing pipeline.

Note

After subtracting the calibration patch, any negative values are clipped to zero.

Parameters:

Name Type Description Default
iq Complex64[Array, '#batch doppler tx rx range'] | Int16[Array, '#batch doppler tx rx range2']

batch of IQ data to run.

required
calib Float32[Array, 'doppler el az range']

calibration patch to apply.

required

Returns:

Type Description
Float32[Array, 'batch doppler el az range']

Doppler-elevation-azimuth-range real spectrum, with zero doppler correction applied.

Source code in src/xwr/rsp/jax/spectrum.py
def __call__(
    self,
    iq: Complex64[Array, "#batch doppler tx rx range"]
    | Int16[Array, "#batch doppler tx rx range2"],
    calib: Float32[Array, "doppler el az range"],
) -> Float32[Array, "batch doppler el az range"]:
    """Run radar spectrum processing pipeline.

    !!! note

        After subtracting the calibration patch, any negative values are
        clipped to zero.

    Args:
        iq: batch of IQ data to run.
        calib: calibration patch to apply.

    Returns:
        Doppler-elevation-azimuth-range real spectrum, with zero doppler
            correction applied.
    """
    raw = jnp.abs(self.rsp(iq))
    return raw.at[self.slice].set(jnp.maximum(raw[self.slice] - calib, 0.0))

calibration_patch

calibration_patch(
    sample: Complex64[Array, "n slow tx rx fast"]
    | Int16[Array, "n slow tx rx fast2"],
    batch: int = 1,
) -> Float32[Array, "doppler el az range"]

Create a calibration patch for zero-doppler correction.

Parameters:

Name Type Description Default
sample Complex64[Array, 'n slow tx rx fast'] | Int16[Array, 'n slow tx rx fast2']

sample IQ data to use for calibration.

required
batch int

sample size for RSP processing. Uses batch size 1 by default; should evenly divide the number of samples.

1

Returns:

Type Description
Float32[Array, 'doppler el az range']

Patch of the doppler-range-azimuth image which should be subtracted from the zero-doppler bins of the range-doppler-angle spectrum.

Source code in src/xwr/rsp/jax/spectrum.py
def calibration_patch(
    self,
    sample: Complex64[Array, "n slow tx rx fast"]
    | Int16[Array, "n slow tx rx fast2"],
    batch: int = 1,
) -> Float32[Array, "doppler el az range"]:
    """Create a calibration patch for zero-doppler correction.

    Args:
        sample: sample IQ data to use for calibration.
        batch: sample size for RSP processing. Uses batch size `1` by
            default; should evenly divide the number of samples.

    Returns:
        Patch of the doppler-range-azimuth image which should be subtracted
            from the zero-doppler bins of the range-doppler-angle spectrum.
    """
    sample = iq_from_iiqq(sample)

    s0 = self.rsp(sample[:batch])
    shape = s0.shape[1:]

    zero = shape[0] // 2
    start, stop = zero, zero + 1
    if "doppler" in self.rsp.window:
        start -= 1
        stop += 1
    self.slice = (slice(None), slice(start, stop))

    @jax.jit
    def _calib(frames) -> Float32[Array, "batch slice az el range"]:
        return jnp.abs(self.rsp(frames))[self.slice]

    batched = sample.reshape(-1, batch, *sample.shape[1:])
    slices = [s0[self.slice]] + [_calib(batch) for batch in batched]
    return jnp.median(jnp.concatenate(slices, axis=0))

xwr.rsp.jax.PointCloud

Get radar point cloud from post FFT cube.

To convert azimuth-elevation bin indices to azimuth-elevation angles, we use the property that the angle bin indices correspond to the sine of the angle:

angles = jnp.arcsin(
    jnp.linspace(-jnp.pi, jnp.pi, bin_size)
    / (2 * jnp.pi * antenna_spacing)
)
where the corrected antenna spacing is calculated by
0.5 * chirp_center_frequency / antenna_design_frequency

Info

The antenna design frequency here refers to the grid alignment of the antenna array, whose elements are typically spaced 0.5 wavelengths apart at some nominal design frequency. Thus, you must correct by a corresponding scale factor when the chirp center frequency differs.
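
A small NumPy sketch of the bin-to-angle conversion and the spacing correction is given below. The formula is the one stated above; the helper name `bin_angles` and the two frequencies are hypothetical values chosen only to show the effect of the correction.

```python
import numpy as np


def bin_angles(n_bins: int, antenna_spacing: float = 0.5) -> np.ndarray:
    """Angle (radians) associated with each angle-FFT bin."""
    return np.arcsin(
        np.linspace(-np.pi, np.pi, n_bins) / (2 * np.pi * antenna_spacing)
    )


# Hypothetical frequencies, chosen only to illustrate the correction.
chirp_center_frequency = 78.0e9
antenna_design_frequency = 76.0e9
corrected_spacing = 0.5 * chirp_center_frequency / antenna_design_frequency

nominal = bin_angles(5)                       # spans -90 to +90 degrees
corrected = bin_angles(5, corrected_spacing)  # slightly narrower span
print(np.rad2deg(nominal).round(1).tolist())  # [-90.0, -30.0, 0.0, 30.0, 90.0]
```

With a spacing above 0.5 wavelengths the unambiguous field of view shrinks. With a spacing below 0.5, the outermost bins fall outside the valid `arcsin` domain and evaluate to NaN in this sketch.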

Parameters:

Name Type Description Default
range_resolution float

range fft resolution

required
doppler_resolution float

doppler fft resolution

required
angle_fov Sequence[float]

angle field of view in degrees for (elevation, azimuth).

(20.0, 80.0)
angle_size Sequence[int]

angle fft size for (elevation, azimuth).

(128, 128)
antenna_spacing float

antenna spacing in terms of wavelength (default 0.5).

0.5
Source code in src/xwr/rsp/jax/aoa.py
class PointCloud:
    """Get radar point cloud from post FFT cube.

    To convert azimuth-elevation bin indices to azimuth-elevation angles,
    we use the property that the azimuth bin indices correspond to the sin of
    the angle
    ```
    angles = jnp.arcsin(
        jnp.linspace(-jnp.pi, jnp.pi, bin_size)
        / (2 * jnp.pi * antenna_spacing)
    )
    ```
    where the *corrected* antenna spacing is calculated by
    ```
    0.5 * chirp_center_frequency / antenna_design_frequency
    ```

    !!! info

        The antenna design frequency here refers to the grid alignment of the
        antenna array, whose elements are typically spaced 0.5 wavelengths
        apart at some nominal design frequency. Thus, you must correct by a
        corresponding scale factor when the chirp center frequency differs.

    Args:
        range_resolution: range fft resolution
        doppler_resolution: doppler fft resolution
        angle_fov: angle field of view in degrees for (elevation, azimuth).
        angle_size: angle fft size for (elevation, azimuth).
        antenna_spacing: antenna spacing in terms of wavelength (default 0.5).
    """

    def __init__(
        self,
        range_resolution: float,
        doppler_resolution: float,
        angle_fov: Sequence[float] = (20.0, 80.0),
        angle_size: Sequence[int] = (128, 128),
        antenna_spacing: float = 0.5,
    ) -> None:
        self.range_res = range_resolution
        self.doppler_res = doppler_resolution

        assert len(angle_fov) == 2 and len(angle_size) == 2, (
            "angle_fov and angle_size must be a sequence of length 2."
        )
        self.ele_fov = jnp.deg2rad(angle_fov[0])
        self.azi_fov = jnp.deg2rad(angle_fov[1])
        self.ele_angles = jnp.arcsin(
            jnp.linspace(-jnp.pi, jnp.pi, angle_size[0])
            / (2 * jnp.pi * antenna_spacing)
        )
        self.azi_angles = jnp.arcsin(
            jnp.linspace(-jnp.pi, jnp.pi, angle_size[1])
            / (2 * jnp.pi * antenna_spacing)
        )

    @staticmethod
    def _argmax_aoa(ang_sptr: Float32[Array, "ele azi"]) -> tuple[Array, ...]:
        """Argmax for angle of arrival estimation.

        Args:
            ang_sptr: post fft angle spectrum amplitude in 2D.

        Returns:
            detected angle index (elevation, azimuth).
        """
        idx = jnp.argmax(ang_sptr)
        idx2d = jnp.unravel_index(idx, ang_sptr.shape)
        return idx2d

    def aoa(
        self, cube: Float32[Array, "range doppler ele azi"]
    ) -> Int[Array, "range doppler 2"]:
        """Angle of arrival estimation.

        Args:
            cube: post fft spectrum amplitude.

        Returns:
            ang: detected angle index for every range doppler bin.
        """
        idxs = jax.vmap(jax.vmap(self._argmax_aoa))(cube)
        ang = jnp.stack((idxs), axis=-1)
        return ang

    def __call__(
        self,
        cube: Float32[Array, "doppler ele azi range"],
        mask: Bool[Array, "range doppler"],
    ) -> tuple[Bool[Array, "range doppler"], Float32[Array, "range doppler 4"]]:
        """Get point cloud from radar cube and detection mask.

        Args:
            cube: post fft spectrum amplitude.
            mask: CFAR detection mask.

        Returns:
            mask of valid points (given the specified angular bounds)
            all possible radar points
        """
        r_size, d_size = mask.shape
        range_v = jnp.arange(r_size) * self.range_res
        doppler_v = (jnp.arange(d_size) - d_size // 2) * self.doppler_res
        r_grid, d_grid = jnp.meshgrid(range_v, doppler_v, indexing="ij")

        angle_idx = self.aoa(cube.transpose(3, 0, 1, 2))
        ang_e = self.ele_angles[angle_idx[:, :, 0]]
        ang_a = self.azi_angles[angle_idx[:, :, 1]]
        mask_e = jnp.logical_and(ang_e < self.ele_fov, ang_e > -self.ele_fov)
        mask_a = jnp.logical_and(ang_a < self.azi_fov, ang_a > -self.azi_fov)
        mask_ang = jnp.logical_and(mask_a, mask_e)

        x = r_grid * jnp.cos(-ang_a) * jnp.cos(ang_e)
        y = r_grid * jnp.sin(-ang_a) * jnp.cos(ang_e)
        z = r_grid * jnp.sin(-ang_e)
        v = d_grid

        pc_mask = jnp.logical_and(mask, mask_ang)
        pc = jnp.stack((x, y, z, v), axis=-1)

        return pc_mask, pc

__call__

__call__(
    cube: Float32[Array, "doppler ele azi range"],
    mask: Bool[Array, "range doppler"],
) -> tuple[Bool[Array, "range doppler"], Float32[Array, "range doppler 4"]]

Get point cloud from radar cube and detection mask.

Parameters:

Name Type Description Default
cube Float32[Array, 'doppler ele azi range']

post fft spectrum amplitude.

required
mask Bool[Array, 'range doppler']

CFAR detection mask.

required

Returns:

Type Description
Bool[Array, 'range doppler']

mask of valid points (given the specified angular bounds)

Float32[Array, 'range doppler 4']

all possible radar points
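
Because the returned point array keeps the full range-Doppler grid, the valid points still have to be selected with the returned mask. A minimal sketch follows, with synthetic arrays standing in for the two return values; the names `pc` and `pc_mask` follow the source, the data is made up.

```python
import jax.numpy as jnp

# Stand-ins for the outputs of PointCloud.__call__ on a 2x3 range-Doppler grid.
pc = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
pc_mask = jnp.array([[True, False, False], [False, True, True]])

# Boolean indexing keeps one (x, y, z, v) row per valid detection.
points = pc[pc_mask]
```

Boolean-mask indexing yields a data-dependent shape, so this selection works on concrete arrays but not inside `jax.jit`.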

Source code in src/xwr/rsp/jax/aoa.py
def __call__(
    self,
    cube: Float32[Array, "doppler ele azi range"],
    mask: Bool[Array, "range doppler"],
) -> tuple[Bool[Array, "range doppler"], Float32[Array, "range doppler 4"]]:
    """Get point cloud from radar cube and detection mask.

    Args:
        cube: post fft spectrum amplitude.
        mask: CFAR detection mask.

    Returns:
        mask of valid points (given the specified angular bounds)
        all possible radar points
    """
    r_size, d_size = mask.shape
    range_v = jnp.arange(r_size) * self.range_res
    doppler_v = (jnp.arange(d_size) - d_size // 2) * self.doppler_res
    r_grid, d_grid = jnp.meshgrid(range_v, doppler_v, indexing="ij")

    angle_idx = self.aoa(cube.transpose(3, 0, 1, 2))
    ang_e = self.ele_angles[angle_idx[:, :, 0]]
    ang_a = self.azi_angles[angle_idx[:, :, 1]]
    mask_e = jnp.logical_and(ang_e < self.ele_fov, ang_e > -self.ele_fov)
    mask_a = jnp.logical_and(ang_a < self.azi_fov, ang_a > -self.azi_fov)
    mask_ang = jnp.logical_and(mask_a, mask_e)

    x = r_grid * jnp.cos(-ang_a) * jnp.cos(ang_e)
    y = r_grid * jnp.sin(-ang_a) * jnp.cos(ang_e)
    z = r_grid * jnp.sin(-ang_e)
    v = d_grid

    pc_mask = jnp.logical_and(mask, mask_ang)
    pc = jnp.stack((x, y, z, v), axis=-1)

    return pc_mask, pc
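
The coordinate conversion in the listing above can be tried in isolation. The sketch below copies the same three expressions into a standalone function; the name `to_cartesian` is hypothetical and not part of xwr.

```python
import jax.numpy as jnp


def to_cartesian(r, azimuth, elevation):
    """Convert range and angles (radians) to x, y, z (hypothetical helper)."""
    # Same expressions as in __call__: both angles enter with a negative sign.
    x = r * jnp.cos(-azimuth) * jnp.cos(elevation)
    y = r * jnp.sin(-azimuth) * jnp.cos(elevation)
    z = r * jnp.sin(-elevation)
    return jnp.stack((x, y, z), axis=-1)


# A target 2 m away at zero azimuth and elevation lies on the +x axis.
boresight = to_cartesian(jnp.array(2.0), jnp.array(0.0), jnp.array(0.0))
```

Under this convention a positive azimuth gives a negative `y` and a positive elevation gives a negative `z`, while the Euclidean norm of the result equals the range.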

aoa

aoa(
    cube: Float32[Array, "range doppler ele azi"],
) -> Int[Array, "range doppler 2"]

Angle of arrival estimation.

Parameters:

Name Type Description Default
cube Float32[Array, 'range doppler ele azi']

post fft spectrum amplitude.

required

Returns:

Name Type Description
ang Int[Array, 'range doppler 2']

detected angle index for every range doppler bin.

Source code in src/xwr/rsp/jax/aoa.py
def aoa(
    self, cube: Float32[Array, "range doppler ele azi"]
) -> Int[Array, "range doppler 2"]:
    """Angle of arrival estimation.

    Args:
        cube: post fft spectrum amplitude.

    Returns:
        ang: detected angle index for every range doppler bin.
    """
    idxs = jax.vmap(jax.vmap(self._argmax_aoa))(cube)
    ang = jnp.stack((idxs), axis=-1)
    return ang
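
The nested `jax.vmap` pattern above applies a 2D argmax independently to every range-Doppler cell. A small sketch on a synthetic cube, where `argmax_2d` is a stand-in for the private `_argmax_aoa` helper:

```python
import jax
import jax.numpy as jnp


def argmax_2d(spectrum):
    """Return the (row, column) index of the largest entry of a 2D array."""
    flat_idx = jnp.argmax(spectrum)
    return jnp.unravel_index(flat_idx, spectrum.shape)


# Synthetic cube with axes (range, doppler, elevation, azimuth).
cube = jnp.zeros((2, 3, 4, 5))
cube = cube.at[0, 0, 1, 2].set(1.0)
cube = cube.at[1, 2, 3, 4].set(1.0)

# The outer vmap maps over range, the inner one over doppler.
ele_idx, azi_idx = jax.vmap(jax.vmap(argmax_2d))(cube)
ang = jnp.stack((ele_idx, azi_idx), axis=-1)
```

Cells without a clear peak (all zeros here) return index `(0, 0)`, which is one reason the result is normally combined with a detection mask.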

xwr.rsp.jax.RSPJax

Bases: RSP[Array], ABC

Base Radar Signal Processing with common functionality.

Parameters:

Name Type Description Default
window bool | dict[Literal['range', 'doppler', 'azimuth', 'elevation'], bool]

whether to apply a hanning window. If bool, the same option is applied to all axes. If dict, specify per axis with keys "range", "doppler", "azimuth", and "elevation".

False
size dict[Literal['range', 'doppler', 'azimuth', 'elevation'], int]

target size for each axis after zero-padding, specified by axis. If an axis is not specified, it is not padded.

{}
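
The `fft` helper in the source passes its `size` argument to `jnp.fft.fftn` as `s`, which zero-pads the input along the transformed axes. The sketch below, using a synthetic tone rather than radar data, checks that this matches padding by hand:

```python
import jax.numpy as jnp

# Complex tone at 0.25 cycles per sample, 8 samples long.
x = jnp.exp(2j * jnp.pi * 0.25 * jnp.arange(8))

# Explicit zero-padding to 16 samples, then a plain FFT.
padded = jnp.concatenate([x, jnp.zeros(8, dtype=x.dtype)])
spec_manual = jnp.fft.fftn(padded, axes=(0,))

# Equivalent: let fftn pad to the target size through `s`.
spec_sized = jnp.fft.fftn(x, s=(16,), axes=(0,))

# The peak moves from bin 2 of 8 to bin 4 of 16: same frequency, finer grid.
peak = int(jnp.argmax(jnp.abs(spec_sized)))
```

Zero-padding interpolates the spectrum onto a finer bin grid; it does not add resolution beyond what the original samples provide.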
Source code in src/xwr/rsp/jax/rsp.py
class RSPJax(RSP[Array], ABC):
    """Base Radar Signal Processing with common functionality.

    Args:
        window: whether to apply a hanning window. If `bool`, the same option
            is applied to all axes. If `dict`, specify per axis with keys
            "range", "doppler", "azimuth", and "elevation".
        size: target size for each axis after zero-padding, specified by axis.
            If an axis is not specified, it is not padded.
    """

    def fft(
        self, array: Complex64[Array, "..."], axes: tuple[int, ...],
        size: tuple[int, ...] | None = None,
        shift: tuple[int, ...] | None = None
    ) -> Complex64[Array, "..."]:
        fftd = jnp.fft.fftn(array, s=size, axes=axes)
        if shift is None:
            return fftd
        else:
            return jnp.fft.fftshift(fftd, axes=shift)

    @staticmethod
    def pad(
        x: Shaped[Array, "..."], axis: int, size: int
    ) -> Shaped[Array, "..."]:
        if size <= x.shape[axis]:
            raise ValueError(
                f"Cannot zero-pad axis {axis} to target size {size}, which is "
                f"less than or equal to the current size {x.shape[axis]}.")

        shape = list(x.shape)
        shape[axis] = size - x.shape[axis]
        zeros = jnp.zeros(shape, dtype=x.dtype)

        return jnp.concatenate([x, zeros], axis=axis)

    @staticmethod
    def hann(
        iq: Complex64[Array, "..."], axis: int
    ) -> Complex64[Array, "..."]:
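        # Take a Hann window two samples longer and drop its zero endpoints,
        # so every input sample keeps a nonzero weight.
        hann = jnp.hanning(iq.shape[axis] + 2)[1:-1]
        # Index that broadcasts the 1D window along `axis` only.
        broadcast: list[None | slice] = [None] * iq.ndim
        broadcast[axis] = slice(None)
        # Normalize to unit mean so windowing keeps the average gain.
        return iq * (hann / jnp.mean(hann))[tuple(broadcast)]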

    def azimuth_aoa(
        self, iq: Complex64[Array, "batch slow tx rx fast"]
        | Int16[Array, "batch slow tx rx fast*2"]
    ) -> Int[Array, "batch doppler range"]:
        """Estimate angle of arrival (AoA).

        !!! note

            The AoA bin resolution is determined by the number of bins this
            RSP instance is configured with.

        Args:
            iq: raw IQ data.

        Returns:
            Estimated angle of arrival (AoA) index for each range-Doppler bin.
        """
        spec: Complex64[Array, "batch doppler el az range"] = self(iq)
        az_spec: Float32[Array, "batch doppler az range"] = (
            jnp.mean(jnp.abs(spec), axis=2))
        return jnp.argmax(az_spec, axis=2)
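
The windowing used by `hann` can be checked on a small array. This sketch reimplements the same steps as a free function (the name `hann_window` is hypothetical) and applies it to an all-ones input:

```python
import jax.numpy as jnp


def hann_window(x, axis: int):
    """Apply a unit-mean Hann window along one axis (hypothetical helper)."""
    # n + 2 window points with the two zero endpoints removed.
    w = jnp.hanning(x.shape[axis] + 2)[1:-1]
    # Broadcast the 1D window along `axis` only.
    index = [None] * x.ndim
    index[axis] = slice(None)
    return x * (w / jnp.mean(w))[tuple(index)]


x = jnp.ones((2, 8))
windowed = hann_window(x, axis=1)
```

On an all-ones input the output is the normalized window itself: strictly positive at both ends and averaging to 1 along the windowed axis.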

azimuth_aoa

azimuth_aoa(
    iq: Complex64[Array, "batch slow tx rx fast"]
    | Int16[Array, "batch slow tx rx fast*2"],
) -> Int[Array, "batch doppler range"]

Estimate angle of arrival (AoA).

Note

The AoA bin resolution is determined by the number of bins this RSP instance is configured with.

Parameters:

Name Type Description Default
iq Complex64[Array, 'batch slow tx rx fast'] | Int16[Array, 'batch slow tx rx fast*2']

raw IQ data.

required

Returns:

Type Description
Int[Array, 'batch doppler range']

Estimated angle of arrival (AoA) index for each range-Doppler bin.

Source code in src/xwr/rsp/jax/rsp.py
def azimuth_aoa(
    self, iq: Complex64[Array, "batch slow tx rx fast"]
    | Int16[Array, "batch slow tx rx fast*2"]
) -> Int[Array, "batch doppler range"]:
    """Estimate angle of arrival (AoA).

    !!! note

        The AoA bin resolution is determined by the number of bins this
        RSP instance is configured with.

    Args:
        iq: raw IQ data.

    Returns:
        Estimated angle of arrival (AoA) index for each range-Doppler bin.
    """
    spec: Complex64[Array, "batch doppler el az range"] = self(iq)
    az_spec: Float32[Array, "batch doppler az range"] = (
        jnp.mean(jnp.abs(spec), axis=2))
    return jnp.argmax(az_spec, axis=2)
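
The reduction in `azimuth_aoa` can be followed on a synthetic spectrum. The sketch below skips the RSP call and starts from a made-up complex cube with the axis order given in the source annotations:

```python
import jax.numpy as jnp

# Synthetic spectrum with axes (batch, doppler, elevation, azimuth, range).
spec = jnp.zeros((1, 2, 2, 8, 4), dtype=jnp.complex64)
# One target in doppler bin 1, azimuth bin 5, range bin 3.
spec = spec.at[0, 1, :, 5, 3].set(1.0 + 1.0j)

# Average the magnitude over elevation: (batch, doppler, azimuth, range).
az_spec = jnp.mean(jnp.abs(spec), axis=2)
# Strongest azimuth bin per cell: (batch, doppler, range).
az_idx = jnp.argmax(az_spec, axis=2)
```

The result is a bin index, not an angle; converting it to radians needs a bin-to-angle table such as the one `PointCloud` builds from `antenna_spacing`.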