classdef Tempotron < handle
% A class representing a tempotron, as described in
% Gutig & Sompolinsky (2006).
% The (subthreshold) membrane voltage of the tempotron
% is a weighted sum from all incoming spikes and the
% resting potential of the neuron. The contribution of
% each spike decays exponentially with time; the speed of
% this decay is determined by two parameters, tau and tau_s,
% denoting the decay time constants of membrane integration
% and synaptic currents, respectively.
% Default property values.
% NOTE(review): the constructor also assigns V_rest, tau, tau_s, log_tts,
% efficacies and V_norm, which are not declared in these lines - confirm
% they are declared in the class's properties block.
threshold = 1.0; % firing threshold; replaced when the constructor gets a 5th argument
t_spi = 10; % spike integration time (default value)
function obj = Tempotron(V_rest, tau, tau_s, synaptic_efficacies, threshold, t_spi)
% TEMPOTRON  Construct a tempotron neuron.
%   obj = Tempotron(V_rest, tau, tau_s, synaptic_efficacies)
%   obj = Tempotron(V_rest, tau, tau_s, synaptic_efficacies, threshold)
%   obj = Tempotron(V_rest, tau, tau_s, synaptic_efficacies, threshold, t_spi)
%
%   V_rest              : resting potential, added to the summed PSPs
%   tau                 : membrane integration time constant
%   tau_s               : synaptic current time constant; must differ from
%                         tau (usually smaller, recommended value tau/4)
%   synaptic_efficacies : one weight per afferent; stored as a column vector
%   threshold           : optional; keeps the class default when omitted
%   t_spi               : optional spike integration time; keeps the class
%                         default when omitted (previously this was always
%                         overwritten with a hard-coded 10)
%
%   Raises 'Tempotron:timeConstants' when tau == tau_s: the PSP kernel
%   exp(-x/tau) - exp(-x/tau_s) is identically zero then, so it cannot be
%   normalised.
    if tau == tau_s
        error('Tempotron:timeConstants', ...
            'tau and tau_s must differ: the PSP kernel vanishes when tau == tau_s.');
    end

    obj.V_rest = V_rest;
    obj.tau = tau;
    obj.tau_s = tau_s;
    obj.log_tts = log(obj.tau/obj.tau_s);
    if nargin > 4, obj.threshold = threshold; end
    if nargin > 5, obj.t_spi = t_spi; end
    obj.efficacies = synaptic_efficacies(:);

    % V_0 scales the kernel so that a unit-efficacy PSP peaks at 1
    obj.V_norm = obj.compute_norm_factor(tau, tau_s);
end
function V_0 = compute_norm_factor(obj, tau, tau_s) %#ok<INUSL> obj kept for the method signature
% COMPUTE_NORM_FACTOR  Normalisation factor V_0 of the PSP kernel.
%   The unnormalised kernel k(x) = exp(-x/tau) - exp(-x/tau_s), x = t - t_i,
%   has its extremum at
%       t_max = tau * tau_s * log(tau/tau_s) / (tau - tau_s).
%   Returning V_0 = 1 / k(t_max) makes the normalised kernel
%       K(t - t_i) = V_0 * (exp(-(t-t_i)/tau) - exp(-(t-t_i)/tau_s))
%   reach exactly 1 there, so unitary PSP amplitudes equal the synaptic
%   efficacies. V_0 is positive for tau > tau_s and negative for
%   tau < tau_s (where k has a valley instead of a peak).
%
%   Both t_max and k(t_max) are computed from the tau/tau_s ARGUMENTS, so
%   the result is self-consistent even if they differ from obj.tau/obj.tau_s.
%
%   Raises 'Tempotron:timeConstants' when tau == tau_s: k is identically
%   zero then and t_max is 0/0, so no normalisation exists.
    if tau == tau_s
        error('Tempotron:timeConstants', ...
            'compute_norm_factor: tau and tau_s must differ (got %g for both).', tau);
    end
    t_max = tau * tau_s * log(tau/tau_s) / (tau - tau_s);
    V_0 = 1 / (exp(-t_max/tau) - exp(-t_max/tau_s));
end
function value = K(obj, V_0, t, t_i)
% K  Post-synaptic potential kernel.
%   value = V_0 * (exp(-(t-t_i)/tau) - exp(-(t-t_i)/tau_s))   where t >= t_i
%   value = 0                                                 where t <  t_i
%   with tau = obj.tau and tau_s = obj.tau_s.
%
%   t and t_i may be scalars or arrays of compatible sizes (implicit
%   expansion); value has the size of t - t_i. V_0 is used as a scalar
%   scale factor.
%   For V_0 > 0 the kernel is identically zero when tau == tau_s, has a
%   positive peak when tau > tau_s and a negative valley when tau < tau_s.
%
%   The exponentials are evaluated only on causal entries. For t << t_i
%   they would overflow to Inf, and masking afterwards (0 * Inf) would make
%   a spike that has not arrived yet contribute NaN instead of 0.
    dt = t - t_i;
    value = zeros(size(dt));
    % ~(dt < 0) rather than dt >= 0 so NaN inputs still propagate as NaN
    causal = ~(dt < 0);
    value(causal) = V_0 * (exp(-dt(causal)/obj.tau) - exp(-dt(causal)/obj.tau_s));
end
function V = compute_membrane_potential(obj, t, spike_times)
% COMPUTE_MEMBRANE_POTENTIAL  Membrane voltage of the neuron at time t:
%
%     V(t) = sum_i w_i * sum_{t_i} K(t - t_i) + V_rest
%
%   w_i is the efficacy of synapse i (obj.efficacies) and t_i runs over
%   the spike times of the i-th afferent.
%
%   t           : evaluation time
%   spike_times : cell array whose i-th cell holds the spike times of
%                 afferent i
    % per-afferent kernel sums, each scaled by its synaptic efficacy
    weighted = obj.efficacies .* obj.compute_spike_contributions(t, spike_times);
    V = obj.V_rest + sum(weighted);
end
function deriv = compute_derivative(obj, t, spike_times)
% COMPUTE_DERIVATIVE  Time derivative of the membrane potential at time t.
%
%   V'(t) = V_0 * sum_i w_i * sum_{t_n <= t} ( exp(-(t-t_n)/tau_s)/tau_s
%                                            - exp(-(t-t_n)/tau)/tau )
%
%   w_i = obj.efficacies(i), V_0 = obj.V_norm, and t_n runs over the spike
%   times of afferent i. A spike arriving exactly at t is included
%   (t_n <= t), so at spike times this is the right-hand derivative.
%
%   t           : evaluation time (scalar)
%   spike_times : cell array; cell i holds the spike times of afferent i.
%                 obj.efficacies must have one entry per afferent (at least
%                 for every afferent that has spiked by time t).
%
%   Every exponential is evaluated on the elapsed time t - t_n >= 0, so it
%   lies in (0, 1]. The former factorisation exp(-t/tau) * exp(t_n/tau)
%   overflowed to 0 * Inf = NaN for large spike times. Summation order
%   does not matter, so no chronological sort is needed.
    total = 0;
    for synapse = 1:numel(spike_times)
        arrivals = spike_times{synapse};
        arrivals = arrivals(arrivals <= t);   % only spikes that have already arrived
        if isempty(arrivals)
            continue
        end
        elapsed = t - arrivals(:);
        slopes = exp(-elapsed/obj.tau_s)/obj.tau_s - exp(-elapsed/obj.tau)/obj.tau;
        total = total + obj.efficacies(synapse) * sum(slopes);
    end
    deriv = obj.V_norm * total;
end
function spike_contribs = compute_spike_contributions(obj, t, spike_times)
% COMPUTE_SPIKE_CONTRIBUTIONS  Unweighted input of each afferent at time t.
%   Returns a column vector with one entry per afferent (afferents counted
%   with length(spike_times)). Entry k is
%       sum(obj.K(obj.V_norm, t, spike_times{k})),
%   i.e. the normalised kernel summed over all spikes of afferent k.
%   Synaptic efficacies are NOT applied here.
    n_afferents = length(spike_times);
    kernel_sum = @(k) sum(obj.K(obj.V_norm, t, spike_times{k}));
    spike_contribs = arrayfun(kernel_sum, (1:n_afferents).');
end
function train(obj, io_pairs, steps, learning_rate)
% TRAIN  Adapt the synaptic weights to a set of labelled inputs.
%   io_pairs      : cell array; each element is itself a cell
%                   {spike_times, target}
%   steps         : number of passes (epochs) over io_pairs
%   learning_rate : step size passed on to adapt_weights
%
%   Every epoch visits all pairs in a fresh random order (randperm) and
%   calls obj.adapt_weights(spike_times, target, learning_rate) once per
%   pair. The actual gradient step is presumably done by adapt_weights
%   (not shown here).
%   NOTE: there is no early stopping; all `steps` epochs always run, even
%   if no weights change.
    n_pairs = length(io_pairs);
    for epoch = 1:steps
        io_pairs = io_pairs(randperm(n_pairs));   % reshuffle every epoch
        for k = 1:n_pairs
            pair = io_pairs{k};
            [spike_times, target] = deal(pair{:});
            obj.adapt_weights(spike_times, target, learning_rate);
        end
    end
end
function [t, membrane_potentials] = get_membrane_potentials(obj, t_start, t_end, spike_times, interval)
% GET_MEMBRANE_POTENTIALS  Sample the membrane potential on a regular grid.
%   [t, V] = get_membrane_potentials(obj, t_start, t_end, spike_times)
%   [t, V] = get_membrane_potentials(..., interval)
%
%   t is t_start:interval:(t_end - interval), so samples start at t_start
%   and t_end itself is not sampled. V(k) is
%   obj.compute_membrane_potential(t(k), spike_times), and V has the same
%   size as t. interval defaults to 0.1.
    if nargin < 5
        interval = 0.1;
    end
    t = t_start:interval:(t_end - interval);
    membrane_potentials = arrayfun(@(tk) obj.compute_membrane_potential(tk, spike_times), t);
end
function [t,derivatives] = get_derivatives(obj, t_start, t_end, spike_times, interval)
% Get a list of the derivative of the membrane potentials from
% t_start to t_end as a result of the inputted spike times.
if nargin<5, interval = 0.1; end
% create a vectorised version of derivative function
deriv_vect = @(t)obj.compute_derivative(t,spike_times);
% exclude spike times from being vectorised
% compute derivatives
t = t_start:interval:(t_end-interval
