model
Train and test neuromorphic models.
Models are networks endowed with the methods fit(), predict(), similarity(), load(), and save(). This class provides an ML-focused API for training sapicore Network objects and using their output for practical purposes.
- class model.Model(network: Network = None, **kwargs)
Model base class.
Implements and extends the scikit-learn sklearn.base.BaseEstimator interface.
Note
In a machine learning context, spiking networks can be used in diverse ways, e.g. as classifiers or as generative models. While engine is mostly about form (network architecture and information flow), model is all about function (how to fit the model to data and how to use it once trained).
- draw(path: str, node_size: int = 750)
Saves an SVG networkx graph plot showing ensembles and their general connectivity patterns.
- Parameters:
path (str) – Destination path for network figure.
node_size (int, optional) – Node size in network graph plot.
Note
May be extended and/or moved to a dedicated visualization package in future versions.
- fit(data: Tensor | list[torch.Tensor], repetitions: int | list[int] | Tensor = 1)
Applies engine.network.Network.forward() sequentially to a block of buffer data, then turns off learning for the network.
The training buffer may be obtained, e.g., from a CV cross-validator object.
- Parameters:
data (Tensor or list of Tensor) – 2D tensor(s) of data buffer to be fed to the root ensembles of this object’s network, formatted sample X feature.
repetitions (int or list of int or Tensor) – How many times to repeat each sample before moving on to the next one. Simulates duration of exposure to a particular input. If a list or a tensor is provided, the i-th row in the batch is repeated repetitions[i] times.
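The presentation order implied by repetitions can be illustrated with plain Python lists standing in for tensors. This is a minimal sketch that assumes fit() presents samples in row order. The helper expand_repetitions is hypothetical and not part of sapicore. With PyTorch tensors, torch.repeat_interleave(data, repeats, dim=0) produces the same row expansion.

```python
from typing import List, Sequence, Union

Row = List[float]


def expand_repetitions(
    data: Sequence[Row], repetitions: Union[int, Sequence[int]] = 1
) -> List[Row]:
    """Hypothetical helper (not sapicore API): list the rows in the order they
    would be presented, one entry per simulation step."""
    if isinstance(repetitions, int):
        # A scalar applies the same exposure duration to every sample.
        repetitions = [repetitions] * len(data)
    if len(repetitions) != len(data):
        raise ValueError("repetitions needs exactly one entry per sample (row)")
    steps: List[Row] = []
    for row, count in zip(data, repetitions):
        # Row i is presented repetitions[i] consecutive times before moving on.
        steps.extend([row] * count)
    return steps


buffer = [[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]]
print(len(expand_repetitions(buffer, 2)))  # 6
print(expand_repetitions(buffer, [1, 3, 0]))
# [[0.1, 0.9], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
```

A repetition count of 0 skips a sample entirely. This follows directly from the "repeat row i repetitions[i] times" rule.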
Warning
fit() does not return intermediate output. Users should register forward hooks to efficiently stream data to disk throughout the simulation (see add_data_hook()).
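A hook of the kind the warning describes might look like the following minimal sketch. The (module, inputs, output) signature mirrors PyTorch forward hooks, which are registered with torch.nn.Module.register_forward_hook(). Whether add_data_hook() accepts a callable of this shape is an assumption. The name make_streaming_hook is hypothetical. Each call writes one JSON line, so memory use stays flat no matter how long the simulation runs.

```python
import io
import json
from typing import IO, Any, Callable


def make_streaming_hook(stream: IO[str]) -> Callable[[Any, Any, Any], None]:
    """Hypothetical helper (not sapicore API): build a forward-hook-style
    callable that appends each step's output to `stream` as one JSON line."""
    step = 0

    def hook(module: Any, inputs: Any, output: Any) -> None:
        nonlocal step
        # Tensors and NumPy arrays expose tolist(); plain sequences are copied.
        values = output.tolist() if hasattr(output, "tolist") else list(output)
        stream.write(json.dumps({"step": step, "output": values}) + "\n")
        stream.flush()  # hand each record to the file promptly
        step += 1

    return hook


# Demo: io.StringIO stands in for a file opened with open(path, "w").
buf = io.StringIO()
hook = make_streaming_hook(buf)
for t in range(3):
    hook(None, None, [t, 2 * t])  # simulated per-step network output
print(buf.getvalue().splitlines()[-1])  # {"step": 2, "output": [2, 4]}
```

In practice, stream would be an open file on disk. The in-memory buffer is used here only so the example runs without writing files.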
- load(path: str) → Network
Loads an engine.network.Network from path and assigns it to this object’s network attribute.
- Parameters:
path (str) – Path to the file containing the model.
- Returns:
A reference to the loaded engine.network.Network object, in case the calling function requires it.
- Return type:
Network
Note
The default implementation uses torch.load(), for the common case where files are used. Users may override this method when other formats are called for.
- predict(data: Data | Tensor) → Tensor
Predicts the labels of data by feeding the buffer to a trained network and applying some procedure to the resulting population/readout layer response.
- Parameters:
data (Data or Tensor) – Sapicore dataset or a standalone 2D tensor of data buffer, formatted sample X feature.
- Returns:
Vector (1D tensor) of predicted labels.
- Return type:
Tensor
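The procedure applied to the readout response is model-specific. One common choice, assumed here purely for illustration, is a winner-take-all rule: the predicted label is the index of the most active readout unit. argmax_readout is a hypothetical helper that operates on plain lists. On a 2D tensor of responses, torch.argmax(responses, dim=1) computes the corresponding per-row indices.

```python
from typing import List, Sequence


def argmax_readout(responses: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper (not sapicore API): map each sample's readout
    response (formatted sample X readout unit) to its most active unit."""
    labels = []
    for row in responses:
        # max() keeps the first maximal index it meets, so ties go to the lowest index.
        labels.append(max(range(len(row)), key=row.__getitem__))
    return labels


# e.g. spike counts accumulated while each sample was presented
spike_counts = [[3, 7, 1], [5, 0, 5], [0, 2, 9]]
print(argmax_readout(spike_counts))  # [1, 0, 2]
```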
- save(path: str)
Saves the engine.network.Network object owned by this model to path.
- Parameters:
path (str) – Destination path, including the name of the file to which the network should be saved.
Note
The default implementation uses torch.save(), for the common case where files are used. Users may override this method when other formats are called for.
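The notes on load() and save() invite overriding both methods when torch files are not the desired format. The sketch below shows that pattern under stated assumptions:
- BaseModel is a hypothetical stand-in for model.Model, reduced to its network attribute.
- A plain dict stands in for an engine.network.Network.
- pickle is used only as an example of an alternative format.

```python
import os
import pickle
import tempfile
from typing import Any


class BaseModel:
    """Hypothetical stand-in for model.Model, reduced to its network attribute."""

    def __init__(self, network: Any = None) -> None:
        self.network = network


class PickleModel(BaseModel):
    """Overrides save()/load() to use pickle instead of torch.save()/torch.load()."""

    def save(self, path: str) -> None:
        # Persist only the network owned by this model.
        with open(path, "wb") as f:
            pickle.dump(self.network, f)

    def load(self, path: str) -> Any:
        # Assign the loaded network to this object's attribute, then return it.
        with open(path, "rb") as f:
            self.network = pickle.load(f)
        return self.network


# Round trip, with a plain dict standing in for an engine.network.Network.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "network.pkl")
    PickleModel({"ensembles": ["input", "readout"]}).save(path)
    restored = PickleModel()
    net = restored.load(path)
print(net is restored.network)  # True
```

The overridden load() both assigns the network attribute and returns the network, keeping the documented contract. pickle should only be used on files from trusted sources.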
- set_fit_request(*, data: bool | None | str = '$UNCHANGED$', repetitions: bool | None | str = '$UNCHANGED$') → Model
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in scikit-learn version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters:
data (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for data parameter in fit.
repetitions (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for repetitions parameter in fit.
- Returns:
self – The updated object.
- Return type:
object
- set_predict_request(*, data: bool | None | str = '$UNCHANGED$') → Model
Request metadata passed to the predict method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to predict.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in scikit-learn version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters:
data (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for data parameter in predict.
- Returns:
self – The updated object.
- Return type:
object
- similarity(data: Tensor, metric: str | Callable) → Tensor
Performs rudimentary similarity analysis on the network population responses to data, obtaining a pairwise distance matrix reflecting sample separation.
- Parameters:
data (Tensor) – 2D tensor of data buffer, formatted sample X feature.
metric (str or Callable) – Distance metric to be used. If a string value is provided, it should correspond to one of the available scipy.spatial.distance metrics. If a custom function is provided, it should accept two vectors (e.g., the population responses to two samples) and return a scalar corresponding to their distance.
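The pairwise distance matrix can be sketched in plain Python as follows. The sketch assumes a custom metric is a callable that takes two response vectors and returns a scalar, which is the convention scipy.spatial.distance.pdist uses for callables. pairwise_distances is a hypothetical helper. With SciPy available, squareform(pdist(X, metric)) yields the equivalent square matrix. math.dist, the standard-library Euclidean distance, serves as the example metric.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def pairwise_distances(
    responses: Sequence[Vector], metric: Callable[[Vector, Vector], float]
) -> List[List[float]]:
    """Hypothetical helper (not sapicore API): n x n matrix whose (i, j)
    entry is metric(responses[i], responses[j])."""
    n = len(responses)
    # The diagonal stays 0.0, assuming metric(u, u) == 0 as for a true distance.
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            # The metric is assumed symmetric, so each pair is computed once.
            d = metric(responses[i], responses[j])
            dist[i][j] = dist[j][i] = d
    return dist


responses = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
for row in pairwise_distances(responses, math.dist):
    print(row)
# [0.0, 5.0, 10.0]
# [5.0, 0.0, 5.0]
# [10.0, 5.0, 0.0]
```

Well-separated samples show up as large off-diagonal entries. Samples the network responds to similarly show up as entries near zero.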