data

Data operations.

class data.AxisDescriptor(name: str = '', labels: list | ndarray[Any, dtype[_ScalarType_co]] | Tensor = None, axis: int = 0)

Metadata vector of labels describing a particular axis of an N-dimensional array.

Parameters:
  • name (str, optional) – Name of the descriptor (i.e., the variable whose value is given by the labels).

  • labels (list or NDarray or Tensor, optional) – Label values given as a list, NumPy array, or tensor. Regardless of the input type, labels are converted internally to a NumPy array (rather than a tensor, since tensors do not support string labels).

  • axis (int, optional) – The data axis described by this variable. Defaults to 0 (rows).
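The conversion rule for labels can be sketched with a small stand-in class. AxisDescriptorSketch is a hypothetical name used only for illustration, not the module's implementation; it assumes that tensor-like inputs expose detach() (as torch.Tensor does) and that labels=None becomes an empty array.

```python
from typing import Any
from dataclasses import dataclass

import numpy as np


@dataclass
class AxisDescriptorSketch:
    """Hypothetical stand-in mirroring the documented AxisDescriptor fields."""

    name: str = ""
    labels: Any = None
    axis: int = 0

    def __post_init__(self):
        # Assumption: missing labels become an empty array.
        labels = [] if self.labels is None else self.labels
        # Tensor-like inputs (e.g. torch.Tensor) are detached, moved to CPU,
        # and converted; plain lists and arrays skip this step.
        if hasattr(labels, "detach"):
            labels = labels.detach().cpu().numpy()
        # Everything ends up as a NumPy array, which can also hold strings.
        self.labels = np.asarray(labels)
```

Storing labels as NumPy arrays keeps string labels such as animal IDs or sensor names usable, which a tensor could not hold.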

Example

Generate a dummy dataset and describe its axes with multiple lists of labels:

>>> import torch
>>> from data import AxisDescriptor, Data, Metadata
>>> study = AxisDescriptor(name="study", labels=[1, 1, 2, 2, 1, 1, 2, 2], axis=0)
>>> animal = AxisDescriptor(name="animal", labels=["A", "A", "A", "A", "B", "B", "B", "B"], axis=0)
>>> sensor = AxisDescriptor(name="sensor", labels=["Occipital", "Parietal", "Temporal", "Frontal"], axis=1)
>>> data = Data(buffer=torch.rand(8, 4), metadata=Metadata(study, animal, sensor))

Here, we have eight samples (rows) of four measured dimensions (columns). The third AxisDescriptor indicates that the columns correspond to sensor locations. The first two describe the rows and record the study and animal each sample was obtained from.
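To make the row/column correspondence concrete, the snippet below uses plain NumPy (not the data module) to show how the labels from the example line up with the buffer: row labels build a mask over axis 0 and column labels build a mask over axis 1. The random buffer is a placeholder for real measurements.

```python
import numpy as np

# Same labels as the example, held as NumPy arrays (as AxisDescriptor stores them).
study = np.asarray([1, 1, 2, 2, 1, 1, 2, 2])
animal = np.asarray(["A", "A", "A", "A", "B", "B", "B", "B"])
sensor = np.asarray(["Occipital", "Parietal", "Temporal", "Frontal"])

rng = np.random.default_rng(0)
buffer = rng.random((8, 4))  # 8 samples (rows) x 4 sensors (columns)

# Row descriptors (axis=0) select samples; column descriptors (axis=1) select sensors.
rows = (animal == "B") & (study == 2)
cols = np.isin(sensor, ["Occipital", "Frontal"])
subset = buffer[rows][:, cols]
```

Here subset holds the occipital and frontal readings of animal B in study 2, i.e. rows 6 and 7, columns 0 and 3.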

class data.Data(identifier: str = '', buffer: Tensor = None, metadata: Metadata = None, root: str | None = None, remote_urls: str | list[str] = '', download: bool = False, overwrite: bool = False)

Dataset base class.

Provides an interface and basic default implementations for fetching, organizing, and representing external datasets. Designed to be incrementally extended.

Parameters:
  • identifier (str, optional) – Name for this dataset.

  • metadata (Metadata, optional) – Metadata object describing this dataset. Consists of multiple AxisDescriptor references.

  • root (str, optional) – Local root directory for dataset file(s), if applicable.

  • remote_urls (str or list of str, optional) – Remote URL(s) from which to fetch this dataset, if applicable.

  • buffer (Tensor, optional) – A buffer holding data samples in tensor form. Useful for initializing objects on the fly, e.g. during data synthesis. Users may override access() to manage disk I/O, e.g. using torch.storage.Storage, memory-mapped arrays/tensors, or HDF lazy loading.

  • download (bool, optional) – Whether to download the set. Defaults to False. If True, the download only commences if the root doesn’t exist or is empty.

  • overwrite (bool, optional) – Whether to download the set even if root already contains cached or previously downloaded data. Defaults to False.
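One way the download and overwrite flags could combine is sketched below. should_download is a hypothetical helper, not part of this module, and it assumes that overwrite only takes effect when download is True and that no download happens when root is None.

```python
import os


def should_download(root, download=False, overwrite=False):
    """Hypothetical helper: decide whether a fetch should start.

    Assumptions: overwrite only matters when download is True, and a
    missing root (None) means there is nowhere to download to.
    """
    if not download or root is None:
        return False
    if overwrite:
        # Ignore any cached or pre-downloaded files.
        return True
    # Otherwise download only if root does not exist or is empty.
    return not os.path.isdir(root) or not os.listdir(root)
```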

access(index: slice, axis: int = None) → Tensor

Specifies how to access data by mapping indices to actual samples (e.g., from file(s) in root).

The default implementation slices into self.buffer, which covers the trivial cases where the user either initialized this Data object directly with a buffer tensor or loaded its values from a file that fits in memory (the latter case is handled by load()).

More sophisticated use cases may require lazy loading or navigating HDF files. Subclasses should implement that kind of logic here.

Parameters:
  • index (slice) – Index(es) to slice into.

  • axis (int, optional) – Optionally, a specific axis along which to apply index selection.

Note

Where file data are concerned (e.g., audio/image, each being a labeled “sample”), use this method to read and potentially transform them, returning the finished product.
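As an illustration of lazy loading, the sketch below keeps samples on disk in a memory-mapped .npy file and reads only the indexed region. MemmapData is a hypothetical class standing in for a Data subclass; it returns a NumPy array rather than a Tensor so it runs without torch (torch.from_numpy(chunk) would convert it).

```python
import numpy as np


class MemmapData:
    """Hypothetical stand-in for a Data subclass whose access() reads lazily."""

    def __init__(self, path):
        self.path = path
        # mmap_mode="r" maps the file without reading it all into memory.
        self._array = np.load(path, mmap_mode="r")

    def access(self, index, axis=None):
        if axis is None:
            # No axis given: index is applied directly, as in buffer[index].
            chunk = self._array[index]
        else:
            # Build e.g. (slice(None), index) so only `axis` is restricted.
            key = [slice(None)] * self._array.ndim
            key[axis] = index
            chunk = self._array[tuple(key)]
        # Copy into a regular in-memory array, detached from the mapped file.
        return np.array(chunk)
```

Only the requested slice is read from disk, so the full dataset never has to fit in memory.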

get_metadata(keys: bool = True) → list[Any]

Returns metadata keys by default, or their values (AxisDescriptor references) if keys is False.

load(indices: slice = None)

Populates the buffer tensor and/or the descriptors attribute table by loading one or more files into memory, optionally selecting only indices.

Since different datasets and pipelines call for different formats, implementation is left to the user.

Parameters:
  • indices (slice, optional) – Specific indices to include, one for each file.

Returns:

Self reference. For use in single-line initialization and loading.

Return type:

Data

Warning

Populating the buffer with the entire dataset should only be done when it fits in memory. For large sets, the buffer should not be used naively; access() should be overridden to implement some form of lazy loading.
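A possible user-side load() is sketched below for a folder of equally shaped .npy files. NpyFolderData is hypothetical, and it assumes that indices selects which files to read (one index per file) and that each file becomes one row of the buffer. It keeps a NumPy buffer; torch.from_numpy would turn it into a tensor.

```python
import glob
import os

import numpy as np


class NpyFolderData:
    """Hypothetical Data-like class whose load() stacks .npy files from root."""

    def __init__(self, root):
        self.root = root
        self.buffer = None

    def load(self, indices=None):
        files = sorted(glob.glob(os.path.join(self.root, "*.npy")))
        if indices is not None:
            # Assumption: indices picks files, one index per file.
            files = files[indices]
        # One file per row; assumes all files share the same shape.
        self.buffer = np.stack([np.load(f) for f in files])
        return self  # allows NpyFolderData(root).load() in one line
```

Returning self matches the documented return value and enables single-line initialization and loading.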

modify(index: slice, values: Tensor)

Sets or modifies the data values at index to values.

The default implementation edits the buffer field of this Data object. Users may wish to override it in cases where the buffer is not used directly.

Parameters:
  • index (slice) – Indices to modify.

  • values (Tensor) – Values to write at index.
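For the case where the buffer is not used directly, the sketch below overrides modify() to write through to a memory-mapped .npy file. MemmapWritableData is a hypothetical class, not the module's default implementation, which edits self.buffer in place.

```python
import numpy as np


class MemmapWritableData:
    """Hypothetical sketch: modify() writes to a .npy file, not to self.buffer."""

    def __init__(self, path):
        # mmap_mode="r+" maps the existing file for reading and writing.
        self._array = np.load(path, mmap_mode="r+")

    def modify(self, index, values):
        # values must broadcast to the selected region.
        self._array[index] = values
        # Push the change to disk so later readers see it.
        self._array.flush()
```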

sample(method: Callable, axis: int = 0, **kwargs)

Applies method to sample from this dataset once without mutating it, returning a copy of the object containing only the data and labels at the sampled indices.

The method can be a reference to a BaseCrossValidator. In that case, kwargs should include any applicable options, e.g. shuffle, label_key, or group_key (see also CV).

If method is not a BaseCrossValidator, the keyword arguments are passed to it directly.

Parameters:
  • method (Callable or BaseCrossValidator) – Method used to sample from this dataset.

  • axis (int, optional) – Axis along which selection is performed. Defaults to 0 (rows/samples).

  • kwargs

    retain: int or float

    The number or proportion of samples to retain.

    shuffle: bool

    If using a sklearn.model_selection.BaseCrossValidator to sample, whether to toggle the shuffle parameter on.

    label_keys: str or list of str

    Label key(s) by which to stratify sampling, if applicable.

    group_key: str

    Label key by which to group sampling, if applicable.

Returns:

A subset of this dataset.

Return type:

Data
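The non-cross-validator path of sample() can be sketched as follows: a plain callable returns indices, and copies of the data and labels at those indices are returned. Both random_retain and sample_sketch are hypothetical names, and the calling convention (method receives the axis length plus kwargs) is an assumption rather than the documented contract.

```python
import numpy as np


def random_retain(n, retain=0.5, seed=None):
    """Hypothetical method: keep `retain` items (int) or a proportion (float) of n."""
    k = retain if isinstance(retain, int) else int(round(retain * n))
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(n, size=k, replace=False))


def sample_sketch(buffer, labels, method, axis=0, **kwargs):
    """Apply `method` once; return copies of data and labels at the sampled indices.

    `labels` maps descriptor names to 1-D arrays aligned with `axis`.
    The inputs are not mutated.
    """
    idx = method(buffer.shape[axis], **kwargs)
    sub_buffer = np.take(buffer, idx, axis=axis)  # np.take returns a new array
    sub_labels = {name: np.asarray(vals)[idx] for name, vals in labels.items()}
    return sub_buffer, sub_labels
```

A stratified or grouped split (e.g. a scikit-learn splitter) would replace random_retain in the cross-validator case; it is omitted here to keep the sketch dependency-free.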

save()

Dumps the buffer contents and metadata to disk at root.

Since different datasets and pipelines call for different formats, implementation is left to the user.
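One possible format choice is sketched below: the buffer goes to a .npy file and the labels to JSON. save_sketch and the file names buffer.npy and metadata.json are hypothetical; any format that load() can read back would work equally well.

```python
import json
import os

import numpy as np


def save_sketch(root, buffer, labels):
    """Hypothetical save(): buffer -> root/buffer.npy, labels -> root/metadata.json."""
    os.makedirs(root, exist_ok=True)
    np.save(os.path.join(root, "buffer.npy"), buffer)
    # tolist() turns NumPy values into JSON-serializable Python values.
    meta = {name: np.asarray(vals).tolist() for name, vals in labels.items()}
    with open(os.path.join(root, "metadata.json"), "w", encoding="utf-8") as fh:
        json.dump(meta, fh)
```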

scan(pattern: str = '*', recursive: bool = False) → list[str]

Scans the root directory and returns a list of files found that match the glob pattern.

Parameters:
  • pattern (str, optional) – Pattern to look for in file names that should be included, glob-formatted. Defaults to “*” (any).

  • recursive (bool, optional) – Whether glob should search recursively. Defaults to False.
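The glob behaviour described above can be approximated with the standard library, as in the sketch below. scan_sketch is a hypothetical function; it assumes that only regular files (not directories) should be returned and that results are sorted.

```python
import glob
import os


def scan_sketch(root, pattern="*", recursive=False):
    """Hypothetical scan(): list files under root matching a glob pattern."""
    if recursive:
        # With recursive=True, "**" matches zero or more subdirectories.
        full_pattern = os.path.join(root, "**", pattern)
    else:
        full_pattern = os.path.join(root, pattern)
    matches = glob.glob(full_pattern, recursive=recursive)
    # Keep regular files only, in a stable order.
    return sorted(p for p in matches if os.path.isfile(p))
```

Note that glob's recursive flag only has an effect on "**" in the pattern, which is why the sketch inserts it explicitly.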

trim(index: slice, axis: int = None)

Returns a trimmed copy of this instance by selecting index, optionally along axis. The subset covers both buffer entries and labels/descriptors. Does not mutate the underlying object.

Parameters:
  • index (slice) – Index(es) to retain.

  • axis (int, optional) – Single axis to apply index selection to. Defaults to None, in which case index is used directly within access().
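The key idea of trim(), slicing the buffer and every descriptor on the affected axis with the same indexer, is sketched below. trim_sketch is hypothetical and keys labels by (name, axis) pairs as an illustrative assumption; when axis is None, index is applied as buffer[index], so a plain slice acts on axis 0.

```python
import numpy as np


def trim_sketch(buffer, labels, index, axis=None):
    """Hypothetical trim(): subset buffer and aligned labels without mutation.

    `labels` maps (name, axis) pairs to 1-D label arrays.
    """
    # Build one indexer per axis: the caller's index as-is, or on a single axis.
    if axis is None:
        key = index if isinstance(index, tuple) else (index,)
    else:
        key = (slice(None),) * axis + (index,)
    key = key + (slice(None),) * (buffer.ndim - len(key))
    sub_buffer = buffer[key].copy()  # copy so the original stays untouched
    # Each descriptor is sliced with the indexer for the axis it describes.
    sub_labels = {
        (name, ax): np.asarray(vals)[key[ax]]
        for (name, ax), vals in labels.items()
    }
    return sub_buffer, sub_labels
```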

class data.Metadata(*args: AxisDescriptor)

Collection of AxisDescriptor objects indexed in a table.

add_descriptors(*args: AxisDescriptor)

Adds one or more AxisDescriptor objects to this metadata instance.

to_dataframe(axis: int = 0) → DataFrame

Condenses the AxisDescriptor objects registered in this Metadata table into a pandas DataFrame.

Performed for row (sample) descriptors by default, but can also be used to aggregate AxisDescriptor objects describing any other axis.

Note

If descriptors for a particular axis vary in length, NaN values will be filled in as necessary.
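The NaN-filling behaviour follows naturally from how pandas aligns Series by index, as the sketch below shows. MetadataSketch is a hypothetical stand-in that uses (name, labels, axis) tuples in place of AxisDescriptor objects; it is not the module's implementation.

```python
import numpy as np
import pandas as pd


class MetadataSketch:
    """Hypothetical stand-in for Metadata: a name -> (axis, labels) table."""

    def __init__(self, *descriptors):
        self.table = {}
        self.add_descriptors(*descriptors)

    def add_descriptors(self, *descriptors):
        # Each descriptor is a (name, labels, axis) tuple in this sketch.
        for name, labels, axis in descriptors:
            self.table[name] = (axis, np.asarray(labels))

    def to_dataframe(self, axis=0):
        columns = {
            name: pd.Series(labels)
            for name, (ax, labels) in self.table.items()
            if ax == axis
        }
        # pd.DataFrame aligns the Series on their index; shorter ones get NaN.
        return pd.DataFrame(columns)
```

With the labels from the AxisDescriptor example plus a two-element row descriptor, the row DataFrame has eight rows and the short column is padded with six NaN values.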

Modules

data.external

External dataset processing.

data.sampling

Data sampling utility classes.