model

Train and test neuromorphic models.

Models are networks endowed with the methods fit(), predict(), similarity(), load(), and save(). This class provides an ML-focused API for training sapicore networks and putting their output to practical use.

class model.Model(network: Network = None, **kwargs)

Model base class.

Implements and extends the scikit-learn sklearn.base.BaseEstimator interface.

Note

In a machine learning context, spiking networks can be used in diverse ways, e.g. as classifiers or as generative models. While engine is mostly about form (network architecture and information flow), model is all about function (how to fit the model to data and how to utilize it once trained).

draw(path: str, node_size: int = 750)

Saves an SVG networkx graph plot showing ensembles and their general connectivity patterns.

Parameters:
  • path (str) – Destination path for network figure.

  • node_size (int, optional) – Node size in network graph plot.

Note

May be extended and/or moved to a dedicated visualization package in future versions.

fit(data: Tensor | list[torch.Tensor], repetitions: int | list[int] | Tensor = 1)

Applies engine.network.Network.forward() sequentially to a block of buffer data, then turns off learning for the network.

The training buffer may be obtained, e.g., from a cross-validator object.

Parameters:
  • data (Tensor or list of Tensor) – 2D tensor(s) of data buffer to be fed to the root ensembles of this object’s network, formatted as samples × features.

  • repetitions (int or list of int or Tensor) – How many times to repeat each sample before moving on to the next one. Simulates duration of exposure to a particular input. If a list or a tensor is provided, the i-th row in the batch is repeated repetitions[i] times.

Warning

fit() does not return intermediate output. Users should register forward hooks to efficiently stream data to disk throughout the simulation (see add_data_hook()).
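The fitting loop described above can be sketched in plain Python. This is a minimal illustration, not sapicore's implementation: ToyNetwork, its learning flag, and the helper expand_repetitions() are hypothetical stand-ins for engine.network.Network and its learning switch, and samples are plain lists rather than tensors.

```python
from typing import List, Sequence, Union

Sample = List[float]


class ToyNetwork:
    """Hypothetical stand-in for engine.network.Network (not the sapicore API)."""

    def __init__(self) -> None:
        self.learning = True           # hypothetical learning switch
        self.steps: List[Sample] = []  # records every sample passed to forward()

    def forward(self, x: Sample) -> None:
        self.steps.append(x)


def expand_repetitions(n_samples: int, repetitions: Union[int, Sequence[int]]) -> List[int]:
    """Normalize `repetitions` to one count per sample (hypothetical helper)."""
    if isinstance(repetitions, int):
        return [repetitions] * n_samples
    counts = list(repetitions)
    if len(counts) != n_samples:
        raise ValueError("need one repetition count per sample")
    return counts


def fit(network: ToyNetwork, data: List[Sample],
        repetitions: Union[int, Sequence[int]] = 1) -> None:
    """Feed row i to the network repetitions[i] times, then switch learning off."""
    for row, count in zip(data, expand_repetitions(len(data), repetitions)):
        for _ in range(count):
            network.forward(row)
    network.learning = False


net = ToyNetwork()
fit(net, [[0.0, 1.0], [1.0, 0.0]], repetitions=[2, 1])
```

After this call, the first sample has been presented twice and the second once, and learning is disabled, matching the per-row semantics of repetitions.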

load(path: str) → Network

Loads an engine.network.Network from path and assigns it to this object’s network attribute.

Parameters:

path (str) – Path to the file containing the model.

Returns:

A reference to the loaded engine.network.Network object in case it is required by the calling function.

Return type:

Network

Note

The default implementation uses torch.load(), for the common case where files are used. Users may override this method when other formats are called for.

predict(data: Data | Tensor) → Tensor

Predicts the labels of data by feeding the buffer to a trained network and applying some procedure to the resulting population/readout layer response.

Parameters:

data (Data or Tensor) – Sapicore dataset or a standalone 2D tensor of data buffer, formatted as samples × features.

Returns:

Vector (1D tensor) of predicted labels.

Return type:

Tensor
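The base class leaves the readout procedure open. One common choice, assumed here purely for illustration, is to label each sample with the index of its most active readout unit. The responses argument is a hypothetical per-sample list of readout activity (e.g., spike counts), not a sapicore attribute.

```python
from typing import List, Sequence


def argmax_labels(responses: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each sample, the index of its most active readout unit.

    Ties resolve to the lowest index, because max() keeps the first maximum it sees.
    """
    labels = []
    for row in responses:
        labels.append(max(range(len(row)), key=lambda j: row[j]))
    return labels
```

Other procedures (nearest class centroid, a trained linear readout) fit the same slot; the choice belongs to the subclass.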

save(path: str)

Saves the engine.network.Network object owned by this model to path.

Parameters:

path (str) – Destination path, including the name of the file to which the network should be saved.

Note

The default implementation uses torch.save(), for the common case where files are used. Users may override this method when other formats are called for.
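Because the default save() and load() rely on torch.save() and torch.load(), a subclass that targets another format would typically override both together. The sketch below uses the standard-library pickle module as that alternative format; PickleModel and its bare network attribute are hypothetical, and the real Model base class is not imported.

```python
import pickle
from typing import Any


class PickleModel:
    """Hypothetical model that stores its network with pickle instead of torch.save()."""

    def __init__(self, network: Any = None) -> None:
        self.network = network

    def save(self, path: str) -> None:
        # Serialize the owned network to the file at `path`.
        with open(path, "wb") as f:
            pickle.dump(self.network, f)

    def load(self, path: str) -> Any:
        # Restore the network, attach it to this object, and return a reference,
        # mirroring the documented return value of load().
        with open(path, "rb") as f:
            self.network = pickle.load(f)
        return self.network
```

As with torch.load(), unpickling can execute arbitrary code, so this approach is only appropriate for files from trusted sources.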

set_fit_request(*, data: bool | None | str = '$UNCHANGED$', repetitions: bool | None | str = '$UNCHANGED$') → Model

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in scikit-learn version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:
  • data (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for data parameter in fit.

  • repetitions (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for repetitions parameter in fit.

Returns:

self – The updated object.

Return type:

object

set_predict_request(*, data: bool | None | str = '$UNCHANGED$') → Model

Request metadata passed to the predict method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in scikit-learn version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

data (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for data parameter in predict.

Returns:

self – The updated object.

Return type:

object

similarity(data: Tensor, metric: str | Callable) → Tensor

Performs rudimentary similarity analysis on the network population responses to data, obtaining a pairwise distance matrix reflecting sample separation.

Parameters:
  • data (Tensor) – 2D tensor of data buffer, formatted as samples × features.

  • metric (str or Callable) – Distance metric to be used. If a string value is provided, it should correspond to one of the available scipy.spatial.distance metrics. If a custom function is provided, it should accept two samples and return a scalar corresponding to their distance.
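Assuming a custom metric is applied to each pair of samples, as scipy's pdist() does with callables, the pairwise distance matrix can be built as in the plain-Python sketch below. The responses argument is a hypothetical list of per-sample population responses, math.dist (Euclidean distance) serves as the example callable, and the scipy string-metric path is not reproduced.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def pairwise_distances(responses: Sequence[Vector],
                       metric: Callable[[Vector, Vector], float]) -> List[List[float]]:
    """Return a symmetric n x n distance matrix over the given responses.

    The diagonal is left at 0.0, which assumes metric(x, x) == 0.
    """
    n = len(responses)
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d = float(metric(responses[i], responses[j]))
            out[i][j] = out[j][i] = d  # fill both halves from one evaluation
    return out


matrix = pairwise_distances([[0.0, 0.0], [3.0, 4.0], [0.0, 4.0]], math.dist)
```

Evaluating only the upper triangle halves the number of metric calls, which matters when the callable is expensive.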