abstract_dataloader.ext.torch
PyTorch interfaces and compatibility wrappers.
abstract_dataloader.ext.torch.Collate
Bases: Collate[TTransformed, TCollated]
Generic NumPy-to-PyTorch collation.
Info
This collator uses optree.tree_map to
recursively traverse the input data structure. Python primitive
containers will work out-of-the-box, while dataclasses must be
registered with optree.
| Input | Behavior |
|---|---|
| torch.Tensor | Either stacked or concatenated, depending on mode. |
| numpy.ndarray | Converted to Tensor, then stacked/concatenated. |
| int \| float \| bool, convert_scalars=True | Converted to Tensor. |
| All other types | Passed through as a list. |
Type Parameters
- TTransformed: input sample type.
- TCollated: output collated type.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mode | Literal['stack', 'concat'] | whether to | 'concat' |
| convert_scalars | bool | whether to convert python scalars to pytorch tensors. | True |
Source code in src/abstract_dataloader/ext/torch.py
__call__
__call__(data: Sequence[TTransformed]) -> TCollated
Apply collation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Sequence[TTransformed] | sequence of samples to collate (i.e., list of objects). Must have an identical structure. | required |
Returns:
| Type | Description |
|---|---|
| TCollated | Collated batch (i.e., object of lists). |
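The "list of objects" to "object of lists" conversion described above can be sketched in plain Python. This is only an illustration of the structure change, not the library's implementation: the name `collate_nested` is hypothetical, it avoids `optree` and `torch` entirely, and it leaves every leaf as a Python list instead of stacking or concatenating tensors.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def collate_nested(samples: Sequence[Any]) -> Any:
    """Turn a list of identically structured samples into one structure of lists.

    Hypothetical stand-in: handles dicts, plain lists, and plain tuples only.
    """
    first = samples[0]
    if isinstance(first, Mapping):
        # Recurse per key; every sample is assumed to share the same keys.
        return {key: collate_nested([s[key] for s in samples]) for key in first}
    if type(first) in (list, tuple):
        # Recurse per position, then rebuild the same container type.
        columns = [collate_nested(list(col)) for col in zip(*samples)]
        return type(first)(columns)
    # Leaf: gather the values from all samples into a list.
    return list(samples)


samples = [
    {"x": 1.0, "meta": ("a", 10)},
    {"x": 2.0, "meta": ("b", 20)},
]
batch = collate_nested(samples)
print(batch)
```

The real collator additionally converts each gathered leaf according to the behavior table for this class; the sketch stops at the gathering step.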
abstract_dataloader.ext.torch.Pipeline
Bases: Module, Pipeline[TRaw, TTransformed, TCollated, TProcessed]
Dataloader transform pipeline.
This pytorch-compatible pipeline extends
torch.nn.Module. It recursively searches its inputs
for a .children() -> Iterator | Iterable method, and checks the children
for any nn.Module objects, which are registered as submodules.
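One plausible reading of this search, sketched without `torch` so that it runs anywhere: `StandInModule` takes the place of `nn.Module`, and `Compose` and `collect_modules` are hypothetical names introduced for the illustration. The actual traversal order and registration details in the library may differ.

```python
from typing import Any, Iterable, List


class StandInModule:
    """Placeholder for torch.nn.Module (assumption, to keep the sketch torch-free)."""


class Compose:
    """Hypothetical container that exposes its parts through .children()."""

    def __init__(self, *parts: Any) -> None:
        self.parts = parts

    def children(self) -> Iterable[Any]:
        return iter(self.parts)


def collect_modules(obj: Any, module_type: type) -> List[Any]:
    """Recursively gather every module_type instance reachable via .children()."""
    if isinstance(obj, module_type):
        return [obj]
    children = getattr(obj, "children", None)
    if not callable(children):
        # Objects without a .children() method are not searched further.
        return []
    found: List[Any] = []
    for child in children():
        found.extend(collect_modules(child, module_type))
    return found


tree = Compose(StandInModule(), Compose(StandInModule(), object()), 3)
modules = collect_modules(tree, StandInModule)
print(len(modules))
```

In the real pipeline, the objects found this way would be registered as submodules, so that their parameters and buffers are visible through the enclosing module.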
Type Parameters
- TRaw: Input data format.
- TTransformed: Data after the first transform step.
- TCollated: Data after the second collate step.
- TProcessed: Output data format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample | Transform[TRaw, TTransformed] \| None | sample transform; if | None |
| collate | Collate[TTransformed, TCollated] \| None | sample collation; if | None |
| batch | Transform[TCollated, TProcessed] \| None | batch collation; if | None |
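The data flow implied by the four type parameters can be sketched with plain callables. This is a minimal illustration, not the library's code: `run_pipeline` is a hypothetical helper, all three stages are passed explicitly (the behavior for `None` is not covered here), and no `torch` objects are involved.

```python
from typing import Any, Callable, Sequence


def run_pipeline(
    raw: Sequence[Any],
    sample: Callable[[Any], Any],
    collate: Callable[[Sequence[Any]], Any],
    batch: Callable[[Any], Any],
) -> Any:
    """Apply the three stages in order: sample, then collate, then batch."""
    transformed = [sample(item) for item in raw]  # TRaw -> TTransformed, per sample
    collated = collate(transformed)               # Sequence[TTransformed] -> TCollated
    return batch(collated)                        # TCollated -> TProcessed


result = run_pipeline(
    raw=[1, 2, 3],
    sample=lambda x: {"x": float(x)},
    collate=lambda items: {"x": [item["x"] for item in items]},
    batch=lambda b: {"x": b["x"], "mean": sum(b["x"]) / len(b["x"])},
)
print(result)
```

Splitting the work this way lets the per-sample stage run where individual samples are loaded, while the batch stage operates once on the collated data.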
abstract_dataloader.ext.torch.TransformedDataset
Bases: Dataset[TTransformed], Generic[TRaw, TTransformed]
Pytorch-compatible dataset with transformation applied.
Extends torch.utils.data.Dataset,
implementing a torch "map-style" dataset.
Type Parameters
- TRaw: raw data type from the dataloader.
- TTransformed: output data type from the provided transform function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Dataset[TRaw] | source dataset. | required |
| transform | Transform[TRaw, TTransformed] | transformation to apply to each sample when loading (note that | required |
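A map-style dataset only needs `__getitem__` and, usually, `__len__`. The sketch below shows the wrapping idea under that assumption; `SketchTransformedDataset` is a hypothetical stand-in that does not subclass `torch.utils.data.Dataset`, and the real class may do more than this.

```python
from typing import Any, Callable, Sequence


class SketchTransformedDataset:
    """Hypothetical wrapper: applies a transform each time a sample is fetched."""

    def __init__(self, dataset: Sequence[Any], transform: Callable[[Any], Any]) -> None:
        self.dataset = dataset
        self.transform = transform

    def __len__(self) -> int:
        # The wrapper has exactly as many samples as the source dataset.
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        # Load the raw sample, then transform it lazily on access.
        return self.transform(self.dataset[index])


wrapped = SketchTransformedDataset([1, 2, 3], lambda x: x * 10)
print(len(wrapped), wrapped[1])
```

Because the transform runs inside `__getitem__`, it is applied per sample at load time, before any collation happens.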