Skip to content

roverp.graphics

GPU-accelerated 2D graphics using JAX.

Warning

This module is not automatically imported; you will need to explicitly import it:

from roverp import graphics
You will also need to have the graphics extra installed.

roverp.graphics.Dilate

Dilate an N-dimensional image.

Takes the elementwise maximum of an N-D image with offset versions of itself.

Parameters:

Name Type Description Default
radius float

dilation radius; a neighborhood with size floor(radius) * 2 + 1 is masked to points which are within radius of the center, and applied as the dilation mask.

3.1
dims int

number of dimensions.

2
Source code in processing/src/roverp/graphics/pointcloud.py
class Dilate:
    """Dilate an N-dimensional image.

    Takes the elementwise maximum of an N-D image with offset versions of
    itself, one for every offset in a ball-shaped neighborhood.

    Args:
        radius: dilation radius; a neighborhood with size
            `floor(radius) * 2 + 1` is masked to points which are within
            `radius` of the center, and applied as the dilation mask.
        dims: number of dimensions.
    """

    def __init__(self, radius: float = 3.1, dims: int = 2) -> None:
        half_width = int(radius)
        axis = np.arange(-half_width, half_width + 1)
        grid = np.meshgrid(*([axis] * dims))
        squared_distance = sum(component**2 for component in grid)

        # `np.where` yields one index array per axis; subtracting the
        # half-width re-centers those indices so that (0, ..., 0) is the
        # window center. Points exactly `radius` away are excluded (`<`).
        inside = np.where(squared_distance < radius**2)
        self.mask: list[Integer[np.ndarray, "N"]] = [
            index - half_width for index in inside]

    def __call__(self, image: Num[Array, "*dims"]) -> Num[Array, "*dims"]:
        """Apply dilation to an image.

        Folds `_offset_maximum` over every offset in `self.mask`; each step
        combines the running result with the original, undilated `image`.
        """
        result = image
        # zip(*mask) transposes the per-axis offset arrays into one offset
        # tuple per neighborhood point.
        for offset in zip(*self.mask):
            result = _offset_maximum(result, image, offset)
        return result

__call__

__call__(image: Num[Array, '*dims']) -> Num[Array, '*dims']

Apply dilation to an image.

Source code in processing/src/roverp/graphics/pointcloud.py
def __call__(self, image: Num[Array, "*dims"]) -> Num[Array, "*dims"]:
    """Apply dilation to an image.

    Folds `_offset_maximum` over every offset in `self.mask`; each step
    combines the running result with the original, undilated `image`.
    """
    result = image
    # zip(*mask) transposes per-axis offset arrays into per-point tuples.
    offsets = zip(*self.mask)
    for offset in offsets:
        result = _offset_maximum(result, image, offset)
    return result

roverp.graphics.JaxFont

GPU-accelerated vectorizable monospace text rendering.

Warning

@cached_property doesn't seem to play well with jax, so JaxFont pre-computes the font LUT. Don't initialize this class until needed!

Usage
  1. Initialize: font = JaxFont(font_name, size).
  2. Convert text to array(s): arr = font.encode("Hello World!").
  3. Render onto canvas: canvas = font(arr, canvas, color, x, y).
  4. Wrap any render calls into a JIT-compiled function to guarantee in-place editing.

Parameters:

Name Type Description Default
font str | None

Font file; must be monospace (or will be treated like one!). If None, load the included roboto.ttf file.

None
size int

Font size; is static to allow pre-computing the font.

18
Source code in processing/src/roverp/graphics/font.py
class JaxFont:
    """GPU-accelerated vectorizable monospace text rendering.

    !!! warning

        `@cached_property` doesn't seem to play well with jax, so `JaxFont`
        pre-computes the font LUT. Don't intialize this class until needed!

    Usage:
        1. Initialize: `font = JaxFont(font_name, size)`.
        2. Convert text to array(s): `arr = font.encode("Hello World!")`.
        3. Render onto canvas: `canvas = font.render(arr, canvas, color, x, y)`.
        4. Wrap any `render` calls into a JIT-compiled function to guarantee
           in-place editing.

    Args:
        font: Font file; must be monospace (or will be treated like one!). If
            `None`, load the included `roboto.ttf` file.
        size: Font size; is static to allow pre-computing the font.
    """

    def __init__(self, font: str | None = None, size: int = 18) -> None:
        if font is None:
            font = os.path.join(os.path.dirname(__file__), "roboto.ttf")

        ttf = ImageFont.truetype(font, size, encoding="ascii")
        chars = bytes(list(range(32, 127))).decode('ascii')

        width, height = 0, 0
        for char in chars:
            _, _, w, h = ttf.getbbox(char)
            width = max(width, int(w))
            height = max(height, int(h))

        stack = []
        for char in chars:
            canvas = Image.new('L', (width, height), "white")
            ImageDraw.Draw(canvas).text((0, 0), char, 'black', ttf)
            stack.append(np.array(canvas))
        self.raster = jnp.stack(stack)

    def __call__(
        self, text: UInt8[Array, "len"],
        canvas: UInt8[Array, "width height channels"],
        color: UInt8[Array, "channels"], x: int = 0, y: int = 0
    ) -> UInt8[Array, "width height channels"]:
        """Render text on canvas.

        !!! warning

            `x` and `y` must be constant!

        Args:
            text: character bytes (ASCII-encoded)
            canvas: array to write to. Must be a jax array.
            color: color to apply, with the same number of channels as `canvas`.
            x: vertical position to write text at; `+x` is down.
            y: horizontal position to write text at, `+y` is right.

        Returns:
            Rendered canvas. If the original is no longer used (e.g. all
                subsequent computation uses only the return here), `render`
                will not cause a copy as long as it is jit-compiled.
        """
        indices = jnp.clip(text - 32, 0, self.raster.shape[0] - 1)
        mask = jnp.concatenate(self.raster[indices], axis=1)
        b = x + mask.shape[0]
        r = y + mask.shape[1]
        mask = mask[
            :min(canvas.shape[0] - x, mask.shape[0]),
            :min(canvas.shape[1] - y, mask.shape[1])
        ] / 255

        return canvas.at[x:b, y:r].set(
            ((1 - mask)[:, :, None] * color[None, None, :]).astype(jnp.uint8)
            + (mask[:, :, None] * canvas[x: b, y: r]).astype(jnp.uint8))

    def encode(self, text: str | Iterable[str]) -> UInt8[Array, "..."]:
        """Convert a string or list of strings to an array of ASCII indices.

        !!! warning

            The inputs must all have the same length. This function is not
            jit-compilable.
        """
        if isinstance(text, str):
            return jnp.frombuffer(
                bytes(text, encoding='ascii'), dtype=np.uint8)
        else:
            return jnp.stack([
                np.frombuffer(bytes(s, encoding='ascii'), dtype=np.uint8)
                for s in text])

__call__

__call__(
    text: UInt8[Array, len],
    canvas: UInt8[Array, "width height channels"],
    color: UInt8[Array, channels],
    x: int = 0,
    y: int = 0,
) -> UInt8[Array, "width height channels"]

Render text on canvas.

Warning

x and y must be constant!

Parameters:

Name Type Description Default
text UInt8[Array, len]

character bytes (ASCII-encoded)

required
canvas UInt8[Array, 'width height channels']

array to write to. Must be a jax array.

required
color UInt8[Array, channels]

color to apply, with the same number of channels as canvas.

required
x int

vertical position to write text at; +x is down.

0
y int

horizontal position to write text at, +y is right.

0

Returns:

Type Description
UInt8[Array, 'width height channels']

Rendered canvas. If the original is no longer used (e.g. all subsequent computation uses only the return here), render will not cause a copy as long as it is jit-compiled.

Source code in processing/src/roverp/graphics/font.py
def __call__(
    self, text: UInt8[Array, "len"],
    canvas: UInt8[Array, "width height channels"],
    color: UInt8[Array, "channels"], x: int = 0, y: int = 0
) -> UInt8[Array, "width height channels"]:
    """Render text on canvas.

    !!! warning

        `x` and `y` must be constant!

    Codes outside printable ASCII are clamped to the nearest glyph: control
    characters (below 32) render as a space, and codes above 126 as `~`.

    Args:
        text: character bytes (ASCII-encoded)
        canvas: array to write to. Must be a jax array.
        color: color to apply, with the same number of channels as `canvas`.
        x: vertical position to write text at; `+x` is down.
        y: horizontal position to write text at, `+y` is right.

    Returns:
        Rendered canvas. If the original is no longer used (e.g. all
            subsequent computation uses only the return here), `render`
            will not cause a copy as long as it is jit-compiled.
    """
    # Widen before subtracting: `text` is uint8, so `text - 32` would wrap
    # control characters (e.g. tab = 9) around to ~233 and clip them to the
    # '~' glyph instead of a blank.
    indices = jnp.clip(
        text.astype(jnp.int32) - 32, 0, self.raster.shape[0] - 1)
    # Lay the glyph cells out side by side along the horizontal axis.
    mask = jnp.concatenate(self.raster[indices], axis=1)
    b = x + mask.shape[0]
    r = y + mask.shape[1]
    # Crop text running off the bottom/right edge; `canvas[x:b, y:r]` is
    # clamped to its bounds the same way. After scaling, 1 keeps the canvas
    # and 0 paints `color`.
    mask = mask[
        :min(canvas.shape[0] - x, mask.shape[0]),
        :min(canvas.shape[1] - y, mask.shape[1])
    ] / 255

    return canvas.at[x:b, y:r].set(
        ((1 - mask)[:, :, None] * color[None, None, :]).astype(jnp.uint8)
        + (mask[:, :, None] * canvas[x: b, y: r]).astype(jnp.uint8))

encode

encode(text: str | Iterable[str]) -> UInt8[Array, ...]

Convert a string or list of strings to an array of ASCII indices.

Warning

The inputs must all have the same length. This function is not jit-compilable.

Source code in processing/src/roverp/graphics/font.py
def encode(self, text: str | Iterable[str]) -> UInt8[Array, "..."]:
    """Convert a string or list of strings to an array of ASCII indices.

    !!! warning

        The inputs must all have the same length. This function is not
        jit-compilable.
    """
    if isinstance(text, str):
        return jnp.frombuffer(
            bytes(text, encoding='ascii'), dtype=np.uint8)
    else:
        return jnp.stack([
            np.frombuffer(bytes(s, encoding='ascii'), dtype=np.uint8)
            for s in text])

roverp.graphics.Render

2D renderer to combine data channels in a fixed format.

Warning

If the strings being rendered ever change length, this will trigger recompilation!

Parameters:

Name Type Description Default
size tuple[int, int]

total frame size as (height, width).

required
channels dict[tuple[int, int, int, int], str]

which data channels to render. Specify as (xmin, xmax, ymin, ymax), where x is the vertical axis (lower is higher) and y is the horizontal axis (lower is left).

required
transforms dict[str, Callable[[Shaped[Array, ...]], UInt8[Array, '?h ?w 3']]]

dict of jax jit-compatible transforms to apply to each data channel (organized by channel name); must output RGB images.

required
text dict[tuple[int, int], str]

text to render; each key is a (x, y) coordinate, and each value is a format string.

required
font JaxFont

rendering font.

required
textcolor tuple[int, int, int]

text rendering configuration.

(255, 255, 255)
Source code in processing/src/roverp/graphics/render.py
class Render:
    """2D renderer to combine data channels in a fixed format.

    !!! warning

        If the strings being rendered ever change length, this will trigger
        recompilation!

    Args:
        size: total frame size as `(height, width)`.
        channels: which data channels to render. Each key is a placement
            `(xmin, xmax, ymin, ymax)`, where `x` is the vertical axis
            (lower is higher) and `y` is the horizontal axis (lower is left);
            each value is the name of the data channel drawn there.
        transforms: dict of jax jit-compatible transforms to apply to each data
            channel (organized by channel name); must output RGB images. A
            `"*"` entry is used for channels without their own transform;
            channels with neither are passed through unchanged.
        text: text to render; each key is a `(x, y)` coordinate, and each
            value is a format string filled in from the call-time `meta`.
        font: rendering font.
        textcolor: text rendering configuration.
    """

    def __init__(
        self, size: tuple[int, int],
        channels: dict[tuple[int, int, int, int], str],
        transforms: dict[str, Callable[
            [Shaped[Array, "..."]], UInt8[Array, "?h ?w 3"]]],
        text: dict[tuple[int, int], str],
        font: JaxFont,
        textcolor: tuple[int, int, int] = (255, 255, 255),
    ) -> None:
        self.size = size
        self.font = font
        self.text = text

        def _identity(x):
            return x

        def _get_transform(
            name: str
        ) -> Callable[[Shaped[Array, "..."]], UInt8[Array, "?h ?w 3"]]:
            # Exact-name transform, then the "*" wildcard, then identity.
            tf = transforms.get(name)
            if tf is None:
                tf = transforms.get("*")
            if tf is None:
                tf = _identity
            return tf

        def _render_func(
            data: dict[str, Shaped[Array, "..."]],
            encoded_text: dict[tuple[int, int], UInt8[Array, "..."]]
        ) -> UInt8[Array, "h w 3"]:
            frame = jnp.zeros((*size, 3), dtype=jnp.uint8)
            data = {k: _get_transform(k)(v) for k, v in data.items()}
            for k, v in channels.items():
                try:
                    frame = frame.at[k[0]:k[1], k[2]:k[3]].set(data[v])
                except ValueError as e:
                    # Re-raise with the channel context instead of printing
                    # it to stdout; callers still see a ValueError.
                    raise ValueError(
                        f"Incompatible shapes for channel {v}: "
                        f"{data[v].shape} into {k}") from e

            _textcolor = jnp.array(textcolor, dtype=jnp.uint8)
            for k2, v2 in encoded_text.items():
                frame = font(v2, frame, _textcolor, x=k2[0], y=k2[1])

            return frame

        self._render_func = jax.jit(_render_func)
        self._vrender_func = jax.jit(jax.vmap(_render_func))

    def __call__(
        self, data: dict[str, Shaped[Array, "..."]],
        meta: dict[str, Any] | list[dict[str, Any]]
    ) -> UInt8[Array, "*batch h w 3"]:
        """Render (possibly batched) frame.

        Args:
            data: input data, organized into channels by name. Must have
                fixed dimensions (with a leading batch axis when `meta` is a
                list).
            meta: metadata for text captions/labels; a list selects batched
                rendering, with one entry per frame.

        Returns:
            Rendered RGB frame (or batch of frames).
        """
        # Batched
        if isinstance(meta, list):
            encoded_text = {
                k: jnp.stack([self.font.encode(v.format(**m)) for m in meta])
                for k, v in self.text.items()}
            return self._vrender_func(data, encoded_text)
        # Non-batched
        else:
            encoded_text = {
                k: self.font.encode(v.format(**meta))
                for k, v in self.text.items()}
            return self._render_func(data, encoded_text)

__call__

__call__(
    data: dict[str, Shaped[Array, ...]],
    meta: dict[str, Any] | list[dict[str, Any]],
) -> UInt8[Array, "*batch h w 3"]

Render (possibly batched) frame.

Parameters:

Name Type Description Default
data dict[str, Shaped[Array, ...]]

input data, organized into channels by name. Must have fixed dimensions.

required
meta dict[str, Any] | list[dict[str, Any]]

metadata for text captions/labels.

required

Returns:

Type Description
UInt8[Array, '*batch h w 3']

Rendered RGB frame (or batch of frames).

Source code in processing/src/roverp/graphics/render.py
def __call__(
    self, data: dict[str, Shaped[Array, "..."]],
    meta: dict[str, Any] | list[dict[str, Any]]
) -> UInt8[Array, "*batch h w 3"]:
    """Render (possibly batched) frame.

    Args:
        data: input data, organized into channels by name. Must have
            fixed dimensions.
        meta: metadata for text captions/labels.

    Returns:
        Rendered RGB frame (or batch of frames).
    """
    # Batched
    if isinstance(meta, list):
        encoded_text = {
            k: jnp.stack([self.font.encode(v.format(**m)) for m in meta])
            for k, v in self.text.items()}
        return self._vrender_func(data, encoded_text)
    # Non-batched
    else:
        encoded_text = {
            k: self.font.encode(v.format(**meta))
            for k, v in self.text.items()}
        return self._render_func(data, encoded_text)

roverp.graphics.Scatter

Render a scatter plot.

Parameters:

Name Type Description Default
radius float

point radius.

3.1
resolution tuple[int, int]

image resolution.

(320, 640)
Source code in processing/src/roverp/graphics/pointcloud.py
class Scatter:
    """Render a scatter plot.

    Args:
        radius: point radius (passed to `Dilate`).
        resolution: image resolution as `(height, width)`.
    """

    def __init__(
        self, radius: float = 3.1,
        resolution: tuple[int, int] = (320, 640)
    ) -> None:
        self.dilate = Dilate(radius=radius, dims=2)
        self.resolution = resolution

    def __call__(
        self, x: Num[Array, "N"], y: Num[Array, "N"], c: Num[Array, "N"]
    ) -> Num[Array, "height width"]:
        """Render scatter plot image, with each point as a circle.

        Args:
            x: x-coordinate, with +x facing right. `x` should be
                normalized to [0, 1] as the width of the image.
            y: y-coordinate, with +y facing up. `y` should be normalized to
                [0, 1] as the height of the image.
            c: point intensity with arbitrary type. The intensity is sorted
                in increasing order to maximize the chances that the higher
                value is taken in case multiple points map to the same bin.

        Returns:
            Rendered scatter plot. Points outside `[0, 1]` (or NaN) are
                dropped. Note that there is some nondeterminism in case
                multiple points with the same intensity fall in the same
                initial pixel.
        """
        height, width = self.resolution
        img = jnp.zeros(self.resolution, dtype=c.dtype)

        # Closed interval: x == 1 and y == 0 fold into the last column/row
        # instead of landing one pixel past the edge.
        iy = jnp.minimum((height * (1 - y)).astype(jnp.int32), height - 1)
        ix = jnp.minimum((x * width).astype(jnp.int32), width - 1)
        # Route out-of-range points to an out-of-bounds row so they are
        # dropped; left alone, negative indices would wrap around to the
        # opposite edge of the image.
        inside = (x >= 0) & (x <= 1) & (y >= 0) & (y <= 1)
        iy = jnp.where(inside, iy, height)

        order = jnp.argsort(c)
        img = img.at[iy[order], ix[order]].set(c[order], mode="drop")
        return self.dilate(img)

__call__

__call__(
    x: Num[Array, N], y: Num[Array, N], c: Num[Array, N]
) -> Num[Array, "height width"]

Render scatter plot image, with each point as a circle.

Parameters:

Name Type Description Default
x Num[Array, N]

x-coordinate, with +x facing right. x should be normalized to [0, 1] as the width of the image.

required
y Num[Array, N]

y-coordinate, with +y facing up. y should be normalized to [0, 1] as the height of the image.

required
c Num[Array, N]

point intensity with arbitrary type. The intensity is sorted in increasing order to maximize the chances that the higher value is taken in case multiple points map to the same bin.

required

Returns:

Type Description
Num[Array, 'height width']

Rendered scatter plot. Note that there is some indeterminism in case multiple points with the same intensity fall in the same initial pixel.

Source code in processing/src/roverp/graphics/pointcloud.py
def __call__(
    self, x: Num[Array, "N"], y: Num[Array, "N"], c: Num[Array, "N"]
) -> Num[Array, "height width"]:
    """Render scatter plot image, with each point as a circle.

    Args:
        x: x-coordinate, with +x facing right. `x` should be
            normalized to [0, 1] as the width of the image.
        y: y-coordinate, with +y facing up. `y` should be normalized to
            [0, 1] as the height of the image.
        c: point intensity with arbitrary type. The intensity is sorted
            in increasing order to maximize the chances that the higher
            value is taken in case multiple points map to the same bin.

    Returns:
        Rendered scatter plot. Points outside `[0, 1]` (or NaN) are
            dropped. Note that there is some nondeterminism in case
            multiple points with the same intensity fall in the same
            initial pixel.
    """
    height, width = self.resolution
    img = jnp.zeros(self.resolution, dtype=c.dtype)

    # Closed interval: x == 1 and y == 0 fold into the last column/row
    # instead of landing one pixel past the edge.
    iy = jnp.minimum((height * (1 - y)).astype(jnp.int32), height - 1)
    ix = jnp.minimum((x * width).astype(jnp.int32), width - 1)
    # Route out-of-range points to an out-of-bounds row so they are dropped;
    # left alone, negative indices would wrap to the opposite image edge.
    inside = (x >= 0) & (x <= 1) & (y >= 0) & (y <= 1)
    iy = jnp.where(inside, iy, height)

    order = jnp.argsort(c)
    img = img.at[iy[order], ix[order]].set(c[order], mode="drop")
    return self.dilate(img)

roverp.graphics.hsv_to_rgb

hsv_to_rgb(hsv: Float[Array, '... 3']) -> Float[Array, '... 3']

Convert hsv values to rgb.

Copied from matplotlib, modified for vectorization, and converted to jax.

Parameters:

Name Type Description Default
hsv Float[Array, '... 3']

HSV colors.

required

Returns:

Type Description
Float[Array, '... 3']

RGB colors as floats in [0, 1], with the same shape as the input.

Source code in processing/src/roverp/graphics/colors.py
def hsv_to_rgb(
    hsv: Float[Array, "... 3"]
) -> Float[Array, "... 3"]:
    """Convert hsv values to rgb.

    Copied [from matplotlib](
    https://matplotlib.org/3.1.1/_modules/matplotlib/colors.html#hsv_to_rgb),
    modified for vectorization, and converted to jax.

    Args:
        hsv: HSV colors; hue, saturation, and value (each expected in
            `[0, 1]`) along the last axis.

    Returns:
        RGB colors as floats in `[0, 1]`, with `3` channels on the last axis
            and the same leading shape as `hsv`.
    """
    h = hsv[..., 0]
    s = hsv[..., 1]
    v = hsv[..., 2]

    i = (h * 6.0).astype(int)
    f = (h * 6.0) - i
    p = v * (1.0 - s)
    q = v * (1.0 - s * f)
    t = v * (1.0 - s * (1.0 - f))

    # Pick the (r, g, b) triple of the active hue sector (0-5). matplotlib
    # special-cases s == 0, but then p == q == t == v, so every sector already
    # yields gray and no extra table entry is needed.
    sector = i % 6
    r = sum((sector == j) * x for j, x in enumerate([v, q, p, p, t, v]))
    g = sum((sector == j) * x for j, x in enumerate([t, v, v, q, p, p]))
    b = sum((sector == j) * x for j, x in enumerate([p, p, t, v, v, q]))

    return jnp.stack([r, g, b], axis=-1)

roverp.graphics.lut

lut(colors: Num[Array, 'n d'], data: Float[Array, ...]) -> Num[Array, '... d']

Apply a discrete lookup table (e.g. colormap).

Parameters:

Name Type Description Default
colors Num[Array, 'n d']

list of discrete colors to apply (e.g. 0-255 RGB values). Can be an arbitrary number of channels, not just RGB.

required
data Float[Array, ...]

input data to index (0 <= data <= 1).

required

Returns:

Type Description
Num[Array, '... d']

An array with the same shape as data, with an extra dimension appended.

Source code in processing/src/roverp/graphics/colors.py
def lut(
    colors: Num[Array, "n d"], data: Float[Array, "..."]
) -> Num[Array, "... d"]:
    """Look up each value of `data` in a discrete color table.

    `data` is clamped to `[0, 1]` and scaled onto the table's index range,
    then truncated; `0` selects the first entry and only `1` selects the last.

    Args:
        colors: list of discrete colors to apply (e.g. 0-255 RGB values). Can
            be an arbitrary number of channels, not just RGB.
        data: input data to index (`0 <= data <= 1`).

    Returns:
        An array with the same shape as `data`, with an extra dimension
            appended.
    """
    last_index = colors.shape[0] - 1
    clamped = jnp.clip(data, 0.0, 1.0)
    index = (clamped * last_index).astype(int)
    return jnp.take(colors, index, axis=0)

roverp.graphics.mpl_colormap

mpl_colormap(cmap: str = 'viridis') -> UInt8[Array, 'n 3']

Get color LUT from matplotlib colormap.

Use with lut.

Source code in processing/src/roverp/graphics/colors.py
def mpl_colormap(cmap: str = "viridis") -> UInt8[Array, "n 3"]:
    """Get color LUT from matplotlib colormap.

    Use with [`lut`][^.].

    Listed colormaps (e.g. `viridis`) return their stored color list. Other
    colormaps, which have no `.colors` attribute (e.g. a
    `LinearSegmentedColormap`), are sampled at each of their `N` entries.

    Args:
        cmap: name of a colormap registered in `matplotlib.colormaps`.

    Returns:
        Colors scaled by 255 and cast to `uint8`, one row per LUT entry.
    """
    colormap = matplotlib.colormaps[cmap]
    # For some reason, mypy does not recognize `colors` as an attribute of mpl.
    if isinstance(colormap, matplotlib.colors.ListedColormap):  # type: ignore
        colors = colormap.colors
    else:
        # Integer inputs index the colormap's LUT directly; drop alpha.
        colors = colormap(list(range(colormap.N)))[:, :3]
    return (jnp.array(colors) * 255).astype(jnp.uint8)

roverp.graphics.render_image

render_image(
    data: Num[Array, "h w"],
    colors: Num[Array, "n d"] | None = None,
    resize: tuple[int, int] | None = None,
    scale: float | int | None = None,
    pmin: float | None = None,
    pmax: float | None = None,
) -> UInt8[Array, "h2 w2 d"]

Apply colormap with specified scaling, clipping, and sizing.

Parameters:

Name Type Description Default
colors Num[Array, 'n d'] | None

colormap (e.g. output of mpl_colormap).

None
data Num[Array, 'h w']

input data to map.

required
resize tuple[int, int] | None

resize inputs to specified size.

None
scale float | int | None

if specified, use this exact scale to normalize the data to [0, 1], with clipping applied.

None
pmin float | None

if specified, use this percentile as the minimum for normalization instead of the actual min.

None
pmax float | None

if specified, use this percentile as the maximum for normalization instead of the actual max.

None

Returns:

Type Description
UInt8[Array, 'h2 w2 d']

Rendered RGB image.

Source code in processing/src/roverp/graphics/colors.py
def render_image(
    data: Num[Array, "h w"],
    colors: Num[Array, "n d"] | None = None,
    resize: tuple[int, int] | None = None,
    scale: float | int | None = None,
    pmin: float | None = None, pmax: float | None = None
) -> UInt8[Array, "h2 w2 d"]:
    """Apply colormap with specified scaling, clipping, and sizing.

    Args:
        data: input data to map.
        colors: colormap (e.g. output of [`mpl_colormap`][^.]). Required,
            despite the `None` default.
        resize: resize inputs to specified size, as `(height, width)`.
        scale: if specified, use this exact scale to normalize the data to
            `[0, 1]` (i.e. `data / scale`), with clipping applied.
        pmin: if specified, use this percentile (0-100) as the minimum for
            normalization instead of the actual min. Ignored if `scale` is set.
        pmax: if specified, use this percentile (0-100) as the maximum for
            normalization instead of the actual max. Ignored if `scale` is set.

    Returns:
        Rendered image with one colormap entry (e.g. RGB) per pixel.

    Raises:
        ValueError: if `colors` is not provided.
    """
    if colors is None:
        raise ValueError("Must specify input colormap.")

    if scale is not None:
        data = jnp.clip(data / scale, 0.0, 1.0)
    else:
        left = (
            jnp.percentile(data, pmin) if pmin is not None else jnp.min(data))
        right = (
            jnp.percentile(data, pmax) if pmax is not None else jnp.max(data))
        # A constant image (or coinciding percentiles) would divide by zero
        # and feed NaNs into the LUT; map it to the bottom of the colormap.
        span = right - left
        span = jnp.where(span == 0, 1, span)
        data = jnp.clip((data - left) / span, 0.0, 1.0)

    if resize is not None:
        data = resize_func(data, height=resize[0], width=resize[1])

    return lut(colors, data)

roverp.graphics.synchronize

synchronize(
    streams: dict[str, Iterator[Any]],
    timestamps: dict[str, Float[ndarray, "?T"]],
    period: float = 1 / 30.0,
    round: float | None = None,
    duplicate: dict[str, str] = {},
    batch: int = 0,
    stop_at: float = 0.0,
) -> Iterator[
    tuple[float, dict[str, int], Any] | list[tuple[float, dict[str, int], Any]]
]

Synchronize asynchronous video/data streams.

Parameters:

Name Type Description Default
streams dict[str, Iterator[Any]]

input iterator streams to synchronize.

required
timestamps dict[str, Float[ndarray, '?T']]

timestamp arrays for each stream.

required
period float

query period, in seconds.

1 / 30.0
round float | None

if specified, round the start time up to the nearest round seconds.

None
duplicate dict[str, str]

duplicate selected streams (values) into the specified keys.

{}
batch int

batch size for parallelized pipelines. If batch=0 (default), no batching is applied.

0
stop_at float

terminate early after this many seconds. If 0.0 (default), plays back the full provided streams.

0.0

Yields:

Type Description
tuple[float, dict[str, int], Any] | list[tuple[float, dict[str, int], Any]]

A tuple with the current timestamp relative to the start time, the index of each synchronized frame, and references to the active values at that timestamp. The active values are given by reference only, and should not be modified.

Source code in processing/src/roverp/graphics/sync.py
def synchronize(
    streams: dict[str, Iterator[Any]],
    timestamps: dict[str, Float[np.ndarray, "?T"]],
    period: float = 1 / 30.0,
    round: float | None = None,
    duplicate: dict[str, str] = {},
    batch: int = 0,
    stop_at: float = 0.0
) -> Iterator[
    tuple[float, dict[str, int], Any]
    | list[tuple[float, dict[str, int], Any]]
]:
    """Sychronize asynchronous video/data streams.

    Args:
        streams: input iterator streams to synchronize.
        timestamps: timestamp arrays for each stream.
        period: query period, in seconds.
        round: if specified, round the start time up to the nearest `round`
            seconds.
        duplicate: duplicate selected streams (values) into the specified keys.
        batch: batch size for parallelized pipelines. If `batch=0` (default),
            no batching is applied.
        stop_at: terminate early after this many seconds. If `0.0` (default),
            plays back the full provided streams.

    Yields:
        A tuple with the current timestamp relative to the start time, the
            index of each synchronized frame, and references to the active
            values at that timestamp. The active values are given by reference
            only, and should not be modified.
    """
    def handle_duplicates(t, ii, active):
        for k, v in duplicate.items():
            ii[k] = ii[v]
            active[k] = active[v]
        return t, ii, active

    if set(streams.keys()) != set(timestamps.keys()):
        raise ValueError("Streams and timestamps do not have matching keys.")

    ii = {k: 0 for k in timestamps}
    active = {k: next(v) for k, v in streams.items()}

    start_time = max(v[0] for v in timestamps.values())
    if round is not None:
        start_time = start_time // round + round
    ts = start_time

    _batch: list = []
    try:
        while True:
            for k in timestamps:
                while timestamps[k][ii[k]] < ts:
                    ii[k] += 1
                    if ii[k] >= timestamps[k].shape[0]:
                        raise StopIteration

                    try:
                        active[k] = next(streams[k])
                    except StopIteration:
                        print(f"Exhausted: {k}")
                        raise StopIteration

            if batch == 0:
                yield handle_duplicates(ts - start_time, ii, active)
            else:
                # Need to make sure we get a copy of ii and active!
                _batch.append(handle_duplicates(
                    ts - start_time, dict(**ii), dict(**active)))
                if len(_batch) == batch:
                    yield _batch
                    _batch = []
            ts += period

            if stop_at > 0 and ts > start_time + stop_at:
                print(f"Stopping early: t=+{ts - start_time:.3f}s")
                raise StopIteration

    except StopIteration:
        pass

roverp.graphics.write_buffered

write_buffered(
    queue: Queue[UInt8[ndarray | Array, "*batch H W 3"] | None],
    out: str,
    fps: float = 30.0,
    codec: str = "h264",
) -> None

Write video from a queue of optionally batched images.

Parameters:

Name Type Description Default
queue Queue[UInt8[ndarray | Array, '*batch H W 3'] | None]

input queue; None indicates the end of the stream.

required
out str

output file path.

required
fps float

frames per second.

30.0
codec str

video codec to use; see supported codecs.

'h264'
Source code in processing/src/roverp/graphics/writer.py
def write_buffered(
    queue: Queue[UInt8[np.ndarray | Array, "*batch H W 3"] | None],
    out: str, fps: float = 30.0, codec: str = "h264",
) -> Thread:
    """Write video from a queue of optionally batched images.

    Frames are consumed and encoded on a background daemon thread; this
    function returns immediately after starting it.

    Args:
        queue: input queue; `None` indicates the end of the stream.
        out: output file path.
        fps: frames per second.
        codec: video codec to use; see [supported codecs](
            https://imageio.readthedocs.io/en/stable/format_gif.html#supported-codecs).

    Returns:
        The started worker thread. It is a daemon thread, so it is stopped
            if the interpreter exits first; `join()` it after sending `None`
            to make sure the video file is fully written and closed.
    """

    def worker() -> None:
        writer = imageio.get_writer(out, fps=fps, codec=codec)
        try:
            while (frame := queue.get()) is not None:
                # A 4-D input is a batch of frames; write them in order.
                if len(frame.shape) == 4:
                    for x in frame:
                        writer.append_data(np.array(x))
                else:
                    writer.append_data(np.array(frame))
        finally:
            # Close even if encoding fails, so the file handle is released.
            writer.close()

    thread = Thread(target=worker, daemon=True)
    thread.start()
    return thread

roverp.graphics.write_consume

write_consume(
    iter: Iterator[UInt8[ndarray | Array, "*batch H W 3"]],
    out: str,
    fps: float = 30.0,
    codec: str = "h264",
    queue_size: int = 32,
) -> None

Write video from an iterator of optionally batched images.

Parameters:

Name Type Description Default
iter Iterator[UInt8[ndarray | Array, '*batch H W 3']]

input iterator.

required
out str

output file path.

required
fps float

frames per second.

30.0
codec str

video codec to use; see supported codecs.

'h264'
queue_size int

maximum size of the internal queue to use for buffering.

32
Source code in processing/src/roverp/graphics/writer.py
def write_consume(
    iter: Iterator[UInt8[np.ndarray | Array, "*batch H W 3"]], out: str,
    fps: float = 30.0, codec: str = "h264", queue_size: int = 32
) -> None:
    """Write video from an iterator of optionally batched images.

    Blocks until `iter` is exhausted and, when `write_buffered` hands back
    its worker thread, until that thread has finished writing.

    Args:
        iter: input iterator.
        out: output file path.
        fps: frames per second.
        codec: video codec to use; see [supported codecs](
            https://imageio.readthedocs.io/en/stable/format_gif.html#supported-codecs).
        queue_size: maximum size of the internal queue to use for buffering.
    """
    queue: Queue[UInt8[np.ndarray | Array, "*batch H W 3"] | None]
    queue = Queue(maxsize=queue_size)
    worker = write_buffered(queue, out=out, fps=fps, codec=codec)

    try:
        for item in iter:
            queue.put(item)
    finally:
        # Always send the end-of-stream marker, even if `iter` raises, so the
        # writer stops waiting for frames and closes the output.
        queue.put(None)

    # The writer runs on a daemon thread, which would be killed (leaving a
    # truncated file) if the program exits right after this call; wait for
    # it when a thread handle is available.
    if isinstance(worker, Thread):
        worker.join()