model
Train and test neuromorphic models.
Models are networks endowed with the methods fit(), predict(), similarity(), load(), and save(). This class provides an ML-focused API for training sapicore Network objects and using their output for practical purposes.
- class model.Model(network: Network = None, **kwargs)
Model base class.
Loosely follows the design of scikit-learn’s sklearn.base.BaseEstimator interface.
Note
In a machine learning context, spiking networks can be used in diverse ways, e.g. as classifiers or as generative models. While engine is mostly about form (network architecture and information flow), model is all about function (how to fit the model to data and how to utilize it once trained).
- fit(data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0, **kwargs)
Serves a batch of data, then turns off learning for all synapses.
Warning
fit() does not return intermediate output. Users should register forward hooks to efficiently stream data to memory or disk throughout the simulation (see add_data_hook()).
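The fit() contract above (serve the batch, then turn off learning, return nothing) can be sketched with toy stand-ins. ToySynapse, ToyNetwork, ToyModel, and the `learning` flag are all hypothetical names invented for this illustration, not sapicore classes or attributes, and the simplified serve() here accepts only scalar duration and rinse values.

```python
from dataclasses import dataclass, field
from typing import List, Sequence

# Hypothetical stand-ins: sapicore's real Network and synapse classes are richer,
# and the `learning` flag name is an assumption made for this sketch.
@dataclass
class ToySynapse:
    learning: bool = True

@dataclass
class ToyNetwork:
    synapses: List[ToySynapse] = field(default_factory=list)
    steps_run: int = 0

    def forward(self, sample: Sequence[float]) -> None:
        # A real network would propagate activity and apply plasticity here.
        self.steps_run += 1

class ToyModel:
    def __init__(self, network: ToyNetwork):
        self.network = network

    def serve(self, data: Sequence[Sequence[float]], duration: int, rinse: int = 0) -> None:
        # Simplified: scalar duration/rinse only.
        for sample in data:
            for _ in range(duration):
                self.network.forward(sample)
            for _ in range(rinse):
                self.network.forward([0.0] * len(sample))

    def fit(self, data: Sequence[Sequence[float]], duration: int, rinse: int = 0) -> None:
        # Present the batch while plasticity is enabled...
        self.serve(data, duration, rinse)
        # ...then turn off learning for all synapses; nothing is returned.
        for synapse in self.network.synapses:
            synapse.learning = False
```

Because nothing is returned, any per-step output has to be captured as a side effect, which is why the warning above points to hooks.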
- load(path: str) → Network
Loads an engine.network.Network from path and assigns it to this object’s network attribute.
- Parameters:
path (str) – Path to the file containing the model.
- Returns:
A reference to the loaded engine.network.Network object, in case the calling function requires it.
- Return type:
Network
Note
The default implementation uses torch.load() for the common case where files are used. Users may override this method when other formats are called for.
- predict(data: Tensor, labels: Sequence, **kwargs) → Sequence
Predicts the labels of data.
- Parameters:
data (Tensor) – Standalone 2D tensor of data buffer, formatted sample X feature.
labels (Sequence) – Label values corresponding to classification layer cell indices.
- Returns:
Vector of predicted labels.
- Return type:
Sequence
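One plausible reading of "label values corresponding to classification layer cell indices" is a winner-take-all readout: each sample gets the label of its most active classification cell. The sketch below assumes that readout and assumes per-cell response counts (e.g. spike counts) are already available. predict_from_counts is a hypothetical helper, and the library's actual decision rule may differ.

```python
from typing import List, Sequence

def predict_from_counts(counts: Sequence[Sequence[float]], labels: Sequence) -> List:
    """Map each sample's most active classification cell to its label.

    counts[i][j] is an assumed response measure (e.g. spike count) of
    classification-layer cell j to sample i; labels[j] is that cell's label.
    """
    predictions = []
    for row in counts:
        # Index of the most active cell; max() keeps the first index on ties.
        winner = max(range(len(row)), key=lambda j: row[j])
        predictions.append(labels[winner])
    return predictions
```

For example, with counts [[0, 5, 1], [7, 2, 2]] and labels ["a", "b", "c"], the sketch predicts ["b", "a"].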
- save(path: str)
Saves the engine.network.Network object owned by this model to path.
- Parameters:
path (str) – Destination path, including the name of the file to which the network should be saved.
Note
The default implementation uses torch.save() for the common case where files are used. Users may override this method when other formats are called for.
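Overriding save() and load() for another format might look like the following sketch, which swaps in the standard-library pickle module. ToyModel and PickleModel are hypothetical names. ToyModel only mimics the `network` attribute of model.Model so the example runs without sapicore installed, and a plain dict stands in for a Network.

```python
import os
import pickle
import tempfile

class ToyModel:
    """Hypothetical stand-in for model.Model; only the `network` attribute is mimicked."""

    def __init__(self, network=None):
        self.network = network

class PickleModel(ToyModel):
    """Swaps torch.save()/torch.load() for the standard-library pickle module."""

    def save(self, path: str) -> None:
        # Write the owned network object to `path`.
        with open(path, "wb") as f:
            pickle.dump(self.network, f)

    def load(self, path: str):
        # Read the object back, assign it to self.network, and return it.
        with open(path, "rb") as f:
            self.network = pickle.load(f)
        return self.network

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "network.pkl")
    PickleModel({"weights": [0.1, 0.2]}).save(path)  # a dict stands in for a Network
    restored = PickleModel().load(path)
```

Note that pickle.load() can execute arbitrary code embedded in the file, so only load files from trusted sources.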
- serve(data: Tensor | Sequence[Tensor], duration: int | Sequence[int], rinse: int | Sequence[int] = 0, **kwargs)
Applies engine.network.Network.forward() sequentially to a batch of buffer data.
- Parameters:
data (Tensor or Sequence of Tensor) – 2D tensor(s) of data buffer to be fed to the root ensemble(s) of this object’s network, formatted sample X feature.
duration (int or Sequence of int) – Duration of sample presentation. Simulates duration of exposure to a particular input. If a list or a tensor is provided, the i-th sample in the batch is maintained for duration[i] steps.
rinse (int or Sequence of int) – Null stimulation steps (0s in-between samples). If a list or a tensor is provided, the i-th sample is followed by rinse[i] rinse steps.
Each sample i is presented for duration[i] steps, followed by all-0s stimulation for rinse[i] steps.
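The presentation schedule described by duration and rinse can be sketched as an expansion of a sample X feature batch into one input vector per simulation step. build_schedule is a hypothetical helper written for illustration, using plain Python lists in place of tensors. It is not how serve() is implemented internally.

```python
from typing import List, Sequence, Union

def build_schedule(
    data: Sequence[Sequence[float]],
    duration: Union[int, Sequence[int]],
    rinse: Union[int, Sequence[int]] = 0,
) -> List[List[float]]:
    """Expand a sample X feature batch into a step-by-step stimulus sequence."""
    n = len(data)
    # Broadcast scalar durations/rinses to one value per sample.
    durations = [duration] * n if isinstance(duration, int) else list(duration)
    rinses = [rinse] * n if isinstance(rinse, int) else list(rinse)
    if len(durations) != n or len(rinses) != n:
        raise ValueError("duration and rinse must be scalars or have one entry per sample")
    steps: List[List[float]] = []
    for sample, d, r in zip(data, durations, rinses):
        # Hold sample i for duration[i] steps (fresh copies avoid aliasing).
        steps.extend(list(sample) for _ in range(d))
        # Then apply rinse[i] all-zero steps.
        steps.extend([0.0] * len(sample) for _ in range(r))
    return steps
```

For instance, build_schedule([[1.0, 2.0], [3.0, 4.0]], duration=[2, 1], rinse=1) yields five steps: the first sample twice, one zero step, the second sample once, and a final zero step.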
- similarity(data: Tensor, metric: str | Callable, **kwargs) → Tensor
Performs a similarity analysis on network responses to data, yielding a pairwise distance matrix.
- Parameters:
data (Tensor) – 2D tensor of data buffer, formatted sample X feature.
metric (str or Callable) – Distance metric to be used. If a string value is provided, it should correspond to one of the available scipy.spatial.distance metrics. If a custom function is provided, it should accept two rows of the data tensor and return a scalar corresponding to their distance.
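The pairwise distance matrix for a custom callable metric can be sketched as below. The sketch assumes the network responses are already collected as one vector per sample, and it assumes the metric is symmetric and returns 0 for identical inputs, so each pair is computed once and the diagonal is left at 0. pairwise_distances is a hypothetical helper. For string metrics, scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(X, "euclidean")) produces the equivalent square matrix.

```python
import math
from typing import Callable, List, Sequence

def pairwise_distances(
    responses: Sequence[Sequence[float]],
    metric: Callable[[Sequence[float], Sequence[float]], float],
) -> List[List[float]]:
    """Return the n X n matrix D with D[i][j] = metric(responses[i], responses[j])."""
    n = len(responses)
    matrix = [[0.0] * n for _ in range(n)]  # diagonal stays 0 (assumes metric(x, x) == 0)
    for i in range(n):
        for j in range(i + 1, n):
            # Assumes a symmetric metric, so each pair is computed once and mirrored.
            d = metric(responses[i], responses[j])
            matrix[i][j] = matrix[j][i] = d
    return matrix

# Euclidean distance via math.dist (Python 3.8+) as the custom callable.
D = pairwise_distances([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]], math.dist)
```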