data.sampling

Data sampling utility classes.

Leverages scikit-learn and pandas to implement versatile cross validation and sampling schemes.

class data.sampling.BalancedSampler(stratified: bool = False, replace: bool = False)

Sampler base class.

Leverages pandas’ group-by operation to parsimoniously perform (dis)proportionate random sampling, with or without replacement.

Parameters:
  • stratified (bool) – If False (the default), attempts to draw the same number of samples from each group, regardless of class imbalances. If True, uses stratified (proportionate) sampling, which preserves class imbalances.

  • replace (bool) – Whether to sample with replacement. Defaults to False.
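The difference between the two modes can be illustrated with plain pandas. The following is a minimal sketch, not the BalancedSampler implementation: the helper sample_groups, its parameters n and seed, and the toy data frame are hypothetical and only mirror the stratified and replace options described above.

```python
import pandas as pd


def sample_groups(df, by, n, stratified=False, replace=False, seed=0):
    """Hypothetical helper (not part of data.sampling): draw about n rows from df."""
    grouped = df.groupby(by)
    if stratified:
        # Proportionate: the same fraction from every group, so the
        # class imbalance of the full data frame is preserved.
        frac = n / len(df)
        return grouped.sample(frac=frac, replace=replace, random_state=seed)
    # Disproportionate (balanced): the same row count from every group.
    # Without replacement, pandas raises a ValueError if a group holds
    # fewer rows than requested.
    per_group = n // grouped.ngroups
    return grouped.sample(n=per_group, replace=replace, random_state=seed)


# Toy data with an 80/20 class imbalance (illustrative only).
df = pd.DataFrame({"label": ["a"] * 80 + ["b"] * 20, "value": range(100)})

balanced = sample_groups(df, "label", n=20)
proportionate = sample_groups(df, "label", n=20, stratified=True)

print(balanced["label"].value_counts())
print(proportionate["label"].value_counts())
```

In this sketch the balanced draw takes equally many rows from each label, while the stratified draw keeps the 80/20 ratio of the full data.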

Note

sample() accepts an instance of this class. Additional samplers may be implemented in future versions.

class data.sampling.CV(data, cross_validator: BaseCrossValidator, label_keys: str | list[str] = None, group_key: str = None, **kwargs)

Cross validation base class.

Parameters:
  • data (Data) – The data.Data object for which to generate the cross validator.

  • cross_validator (BaseCrossValidator) – A scikit-learn BaseCrossValidator object.

  • label_keys (str or list of str, optional) – One or more descriptor label names. Required for some CV schemes (e.g., for StratifiedKFold, the descriptor whose class proportions each fold should preserve relative to the full sample).

  • group_key (str, optional) – A grouping descriptor key. Required for some CV schemes (e.g., GroupKFold, which generates batch-like folds whose groups, as defined by some descriptor, do not overlap).
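To show why some schemes need labels or groups, here is a minimal sketch that calls the scikit-learn cross validators directly. It does not use the CV class, and how CV forwards label_keys and group_key internally is not shown here; the arrays X, labels, and groups are made-up stand-ins for a descriptor column.

```python
import numpy as np
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold

# Toy data (illustrative only): 12 samples, imbalanced labels, 6 groups of 2.
X = np.zeros((12, 1))
labels = np.array([0] * 8 + [1] * 4)
groups = np.repeat(np.arange(6), 2)

# StratifiedKFold needs labels so each test fold keeps the 2:1 class ratio.
skf = StratifiedKFold(n_splits=4)
stratified_folds = list(skf.split(X, y=labels))

# GroupKFold needs groups so no group appears in both train and test.
gkf = GroupKFold(n_splits=3)
group_folds = list(gkf.split(X, y=labels, groups=groups))

# Both are instances of BaseCrossValidator.
print(isinstance(skf, BaseCrossValidator), isinstance(gkf, BaseCrossValidator))

for train_idx, test_idx in group_folds:
    # The sets of train and test groups are disjoint in every fold.
    print(sorted(set(groups[train_idx])), sorted(set(groups[test_idx])))
```

Each split yields a pair of integer index arrays (train, test), which is the interface shared by all classes derived from BaseCrossValidator.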

Note

Some external libraries extend the cross validator offerings of scikit-learn. This implementation is compatible with any class derived from BaseCrossValidator.