Strategy_spatiotemporalRaisedCosine
- class simrun.modular_reduced_model_inference.strategy.Strategy_spatiotemporalRaisedCosine(name, RaisedCosineBasis_spatial, RaisedCosineBasis_temporal)
Spatiotemporal raised cosine strategy.
Uses the RaisedCosineBasis class to create a set of basis functions.
Attention
The input data must contain the following keys:
  - spatiotemporalSa: the spatiotemporal synaptic activation patterns, of shape (n_spatial_bins, n_temporal_bins, n_trials).
  - st: the spike times.
  - y: the labels.
  - ISI: the inter-spike intervals.
- Parameters:
name (str) – The name of the strategy.
RaisedCosineBasis_spatial (RaisedCosineBasis) – The spatial basis functions \(\mathbf{g}(z)\).
RaisedCosineBasis_temporal (RaisedCosineBasis) – The temporal basis functions \(\mathbf{f}(\tau)\).
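Both parameters are raised cosine bases. As a rough illustration of what such a basis looks like, the sketch below builds linearly spaced raised cosine bumps with NumPy. The helper raised_cosine_basis and its arguments are hypothetical and are not the API of RaisedCosineBasis, which may space or parameterize its bumps differently (for example on a stretched time axis).

```python
import numpy as np


def raised_cosine_basis(x, n_basis, x_min, x_max):
    """Hypothetical helper, not the library's RaisedCosineBasis.

    Returns an array of shape (n_basis, len(x)) whose rows are raised cosine
    bumps with linearly spaced centers between x_min and x_max (n_basis >= 2).
    """
    x = np.asarray(x, dtype=float)
    centers = np.linspace(x_min, x_max, n_basis)
    spacing = centers[1] - centers[0]  # half-width of each bump equals the center spacing
    u = (x[None, :] - centers[:, None]) / spacing  # distance to each center, in spacings
    basis = 0.5 * (1.0 + np.cos(np.pi * u))
    basis[np.abs(u) > 1.0] = 0.0  # each bump is zero beyond one spacing from its center
    return basis


# Example: 6 spatial bumps over 0..1000 (units of z assumed), 5 temporal bumps over 80 bins.
z = np.linspace(0, 1000, 101)
g = raised_cosine_basis(z, 6, 0, 1000)
f = raised_cosine_basis(np.arange(80), 5, 0, 79)
```

Because each bump's half-width equals the center spacing, neighboring bumps overlap such that the basis sums to 1 everywhere between the first and last centers, i.e. it forms a smooth partition of the axis.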
- Attributes:
- base_vectors_arrays_dict
The basis vectors for each group, of shape (n_trials, \(N_{\tau}\), \(N_{z}\)). These basis vectors are used by the optimizer and have already been multiplied with the data. Do not confuse them with the basis vectors of RaisedCosineBasis_spatial and RaisedCosineBasis_temporal, which are not multiplied with the synaptic activation data.
- Type:
dict
- convert_x
Function that converts the 1D learnable weight vector \(\mathbf{x}\) into a structured dictionary.
- Type:
callable
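base_vectors_arrays_dict holds basis vectors that are already multiplied with the data. A minimal sketch of how such arrays could be formed, assuming a separable spatiotemporal kernel and assuming that spatiotemporalSa maps each group name to an array of shape (n_spatial_bins, n_temporal_bins, n_trials): each group's activation is projected onto every pair \((f_j(\tau), g_i(z))\), giving an array of shape (n_trials, \(N_\tau\), \(N_z\)). The group names EXC and INH, the random data, the helpers compute_base_vectors and convert_x_static, and the flat layout of \(\mathbf{x}\) are all hypothetical. The static signature (groups, len_z, x) suggests that convert_x could be a partial application binding groups and len_z; that too is an assumption.

```python
import functools

import numpy as np

rng = np.random.default_rng(0)
n_spatial_bins, n_temporal_bins, n_trials = 50, 80, 200
N_z, N_tau = 4, 6

# Hypothetical stand-ins for the spatial basis g(z) and the temporal basis f(tau):
# each row is one basis function sampled on the spatial or temporal bins.
g = rng.random((N_z, n_spatial_bins))
f = rng.random((N_tau, n_temporal_bins))

# Assumed layout: one activation array per synapse group, each of shape
# (n_spatial_bins, n_temporal_bins, n_trials). Group names are made up.
data = {
    "spatiotemporalSa": {
        group: rng.poisson(0.1, (n_spatial_bins, n_temporal_bins, n_trials))
        for group in ("EXC", "INH")
    }
}


def compute_base_vectors(sa_by_group, g, f):
    """Return {group: array of shape (n_trials, N_tau, N_z)} where entry [n, j, i]
    is the sum over z and t of sa[z, t, n] * f[j, t] * g[i, z]."""
    return {
        group: np.einsum("ztn,jt,iz->nji", sa, f, g)
        for group, sa in sa_by_group.items()
    }


def convert_x_static(groups, len_z, x):
    """Split a flat weight vector into per-group spatial and temporal weights.

    Hypothetical layout: x = [spatial(group0), temporal(group0), spatial(group1), ...],
    with len_z spatial weights per group."""
    x = np.asarray(x, dtype=float)
    if len(x) % len(groups):
        raise ValueError("x cannot be split evenly across groups")
    chunk = len(x) // len(groups)
    weights = {}
    for k, group in enumerate(groups):
        block = x[k * chunk:(k + 1) * chunk]
        weights[group] = {"spatial": block[:len_z], "temporal": block[len_z:]}
    return weights


base_vectors_arrays_dict = compute_base_vectors(data["spatiotemporalSa"], g, f)
# Bind groups and len_z so that convert_x only takes the flat vector x.
convert_x = functools.partial(convert_x_static, list(base_vectors_arrays_dict), N_z)
```

Projecting the data once, outside the optimization loop, means that each evaluation of the objective only needs a small contraction over (\(N_\tau\), \(N_z\)) per trial rather than a pass over the full spatiotemporal activation array.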
- Methods:
_setup()
Compute the strategy's basis vectors for the dataset and set up the objective function.
_get_x0()
Get an initial guess for the learnable weights \(\mathbf{x}\) and \(\mathbf{y}\) of the basis functions \(\mathbf{f}(\tau)\) and \(\mathbf{g}(z)\).
_convert_x_static(groups, len_z, x)
Static method. Convert the input array \(\mathbf{x}\) into a dictionary of basis vectors.
_get_score_static(convert_x, base_vectors_arrays_dict, x)
Static method. Calculate the weighted net input \(WNI(t)\) for the given weights \(\mathbf{x}\).
normalize(x, flipkey)
Normalize the kernel basis functions such that the sum of the absolute values of all kernels is 1.
get_color_by_group(group)
Map a group to a color.
visualize(optimizer_output, only_successful, normalize)
Plot the basis functions.
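_get_score_static computes the weighted net input from the base vectors and the weights, and normalize rescales the kernels. A sketch under a separable-kernel assumption: for each trial \(n\), \(WNI_n = \sum_{\text{groups}} \mathbf{w}_\tau^\top B_n \mathbf{w}_z\), where \(B_n\) is that group's (\(N_\tau\), \(N_z\)) slice of base_vectors_arrays_dict. This bilinear form is unchanged when the spatial weights are multiplied by a factor \(c\) and the temporal weights divided by \(c\), so kernels are only defined up to such a scale, which is one reason to normalize them. The helpers weighted_net_input and normalize_spatial are hypothetical: the real _get_score_static receives the flat vector \(\mathbf{x}\) together with convert_x, whereas this sketch takes already-converted weights, and normalize_spatial rescales only the spatial kernels without modeling the flipkey argument.

```python
import numpy as np


def weighted_net_input(weights, base_vectors_arrays_dict):
    """WNI per trial: sum over groups of w_tau^T B[n] w_z, with B of shape
    (n_trials, N_tau, N_z) and weights[group] = {"spatial": w_z, "temporal": w_tau}."""
    n_trials = next(iter(base_vectors_arrays_dict.values())).shape[0]
    wni = np.zeros(n_trials)
    for group, B in base_vectors_arrays_dict.items():
        w = weights[group]
        wni += np.einsum("nji,j,i->n", B, w["temporal"], w["spatial"])
    return wni


def normalize_spatial(weights, g):
    """Scale all spatial weights by one common factor so that the absolute values of
    the spatial kernels (w_z @ g), summed over all groups, equal 1. The temporal
    weights are divided by the same factor, so the WNI does not change."""
    total_abs = sum(np.abs(w["spatial"] @ g).sum() for w in weights.values())
    scale = 1.0 / total_abs
    return {
        group: {"spatial": w["spatial"] * scale, "temporal": w["temporal"] / scale}
        for group, w in weights.items()
    }


# Usage with random stand-in data (group names and sizes are made up).
rng = np.random.default_rng(1)
n_trials, N_tau, N_z, n_spatial_bins = 100, 6, 4, 50
g = rng.random((N_z, n_spatial_bins))  # stand-in spatial basis g(z)
B = {grp: rng.normal(size=(n_trials, N_tau, N_z)) for grp in ("EXC", "INH")}
weights = {
    grp: {"spatial": rng.normal(size=N_z), "temporal": rng.normal(size=N_tau)}
    for grp in B
}
normalized = normalize_spatial(weights, g)
wni = weighted_net_input(weights, B)
wni_normalized = weighted_net_input(normalized, B)  # equal to wni up to rounding
```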