Matlab Code for Spike Time Distances Between Labeled (Multineuronal) Spike Trains, Parallel in q and k
main cost-based metrics page
algorithm page for cost-based metrics
Matlab code for Multineuronal Spike Time Metric, Parallel in q and k
There are four modules: spkdallqk_recur, spkdallqk_final, spkdallqk_dist, spkdallqk_trunc,
with the following relationships:
- spkdallqk_recur
carries out the core dynamic programming algorithm and produces
two output arrays, clast and c. clast applies to the entire spike trains; c can be
used (via spkdallqk_trunc) to calculate distances from truncated spike trains.
- spkdallqk_final
carries out the final step of the algorithm to calculate distances
from the array clast produced by spkdallqk_recur.
- spkdallqk_dist
calls spkdallqk_recur and then spkdallqk_final.
- spkdallqk_trunc
recalculates the array clast for the spike trains truncated at a given time tmax,
using the array c (or, for the space-optimized version, the structure c) calculated by spkdallqk_recur.
spkdallqk_recur
function [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin)
% [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin) does the recursion for the Aronov
% multiunit algorithm, extended to simultaneous calculation for all values of q and k.
%
% sa, sb: spike times on the two spike trains
% la, lb: spike labels (positive integers)
% optsin: available for various debugging and optimizing options
% optsin.recurdebug: show some debugging (1 or 2 levels) within the recursion
% optsin.switchab: exchange a and b if necessary to reduce total number of inner loops
% 0->disable, 1->enable (default), 2->enable and show, 3->force swap
% optsin.disjoint: merge labels in one train if the other train does not contain them
% 0->disable, 1->enable (default), 2->enable and show
% optsin.sizeonly: 1->clast returns as a scalar, indicating the size of the working matrix;
% c returns as a vector, indicating the dimensions of the working matrix
% optsin.spaceopt: 0-> original version (default), 1->space-optimized version (after Ifije Ohiorhenuan)
%
% sa,la,sb,lb are potentially swapped and otherwise optimized before proceeding.
% original values are stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb.
% after optimization, they are stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
%
% size(c)=[length(sa)+1 length(sb_1)+1 ... length(sb_L)+1 rmax+1 smax+1]
%
% With optsin.spaceopt==1, c is instead a structure with fields time, na, bmulti,
% rmax, smax and clast (filled in by the space-optimized section below); that
% structure is what spkdallqk_trunc accepts in place of the full array.
%
% NOTE(review): this copy of the source is incomplete. Text was lost during
% extraction at three places (each marked NOTE(review) below), so the function
% will not parse as-is; restore those lines from the original distribution.
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_FINAL, SPKDALLQK_DIST, SPKD, LABDIST_FAST.
%
% --- fill in defaults for any option the caller did not supply ---
if (nargin<=4) optsin=[]; end
opts=optsin;
if (~isfield(opts,'recurdebug')) opts.recurdebug=0; end
if (~isfield(opts,'switchab')) opts.switchab=1; end
if (~isfield(opts,'disjoint')) opts.disjoint=1; end
if (~isfield(opts,'sizeonly')) opts.sizeonly=0; end
if (~isfield(opts,'spaceopt')) opts.spaceopt=0; end
%save input values
opts.orig.sa=sa;opts.orig.sb=sb;opts.orig.la=la;opts.orig.lb=lb;
%
% Labels present in only one of the two trains can never take part in a matched
% link (a match needs equal labels), so within each train all such labels are
% collapsed onto a single label; this reduces the number of labels and hence
% the size of the working array.
if opts.disjoint>=1
if (opts.disjoint>=2);disp('before disjoint optimization, la, lb=');disp(la);disp(lb);end
lbnota=setdiff(unique(lb),unique(la));
lanotb=setdiff(unique(la),unique(lb));
if (length(lbnota)>1);lb(find(ismember(lb,lbnota)))=lbnota(1);end
if (length(lanotb)>1);la(find(ismember(la,lanotb)))=lanotb(1);end
if (opts.disjoint>=2);disp(' after disjoint optimization, la, lb=');disp(la);disp(lb);end
end
%
% The third output of unique gives, for each entry of [la lb], the index of its
% value in the sorted list of distinct labels, i.e. a relabeling to 1..labs.
% NOTE(review): in recent MATLAB releases that third output is always a column
% vector, which would make la/lb columns and break [la lb] below when
% length(la)~=length(lb) -- confirm on the target MATLAB version.
[lurelabel,index,relabel] = unique([la lb]);
labs = prod(size(lurelabel));
%replace each label by 1, ..., length(lunique)
%disp('after replacement')
la=relabel(1:length(la));
lb=relabel(length(la)+1:end);
lunique=unique([la lb]);
% split each train by label: sua{k}/sub{k} are the spike times with label k,
% nua(k)/nub(k) their counts
nua=[];nub=[];
for k = 1:labs
sua{k}=sa(find(la==lunique(k)));
nua(k)=length(find(la==lunique(k)));
sub{k}=sb(find(lb==lunique(k)));
nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
% number of multi-indices (bv(1),...,bv(labs)) with 0<=bv(k)<=nub(k)
bsize=prod(nub+1);
% Train a is traversed spike by spike, train b is subdivided by label; the
% working array has (na+1)*prod(nub+1) cells per (r,s). Swap a and b when the
% other orientation gives a smaller array (rmax, smax are symmetric in a and b).
if (opts.switchab>=1)
asize=prod(nua+1);
absize=(na+1)*bsize;
basize=(nb+1)*asize;
if (opts.switchab>=2)
disp(' na+1 asize, nb+1 bsize, absize, basize')
disp([na+1 asize nb+1 bsize absize basize]);
end
if ((absize>basize) | (opts.switchab>=3))
if (opts.switchab>=2) disp('a and b swapped'); end
lt=la;la=lb;lb=lt;
st=sa;sa=sb;sb=st;
nt=na;na=nb;nb=nt;
nut=nua;nua=nub;nub=nut;
tsize=asize;asize=bsize;bsize=tsize;
for k=1:labs
sut{k}=sua{k};sua{k}=sub{k};sub{k}=sut{k};
end
else
if (opts.switchab>=2) disp(' a and b not swapped.'); end
end
end
cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
csdims=[2 bsize rmax+1 smax+1]; %working dimensions for space optimized version
%save values used for later use
opts.used.sa=sa;opts.used.sb=sb;opts.used.la=la;opts.used.lb=lb;opts.used.cdims=cdims;opts.used.cwdims=cwdims;
%
if (opts.recurdebug)
disp('cwdims');disp(cwdims);
end
%if only sizes are requested, do this and return
if (opts.sizeonly)
if (opts.spaceopt==1)
c=csdims;
else
c=cwdims;
end
clast=prod(c);
return;
end
%initialize c for spaceopt
if (opts.spaceopt==1)
c.time=[];
c.na=[];
c.bmulti=[];
c.rmax=[];
c.smax=[];
c.clast=cell(0);
end
%check for trivial situations
% NOTE(review): the "return" below is inside the spaceopt branch only; with
% spaceopt==0, execution continues into the full recursion -- confirm intended.
if (min(na,nb)==0)
clast=0;
rmax=0;
smax=0;
if (opts.spaceopt==0);
c=zeros(cdims);
else
c.time=unique([sa sb]);
if (na==0)
c.na=zeros(1,nb);
c.bmulti=[1:nb];
end
if (nb==0)
c.na=[1:na];
c.bmulti=zeros(1,na);
end
c.rmax=zeros(1,max(na,nb));
c.smax=zeros(1,max(na,nb));
for j=1:max(na,nb)
c.clast{j}=0;
end
return
end
end
%set up divisors for de-indexing and re-indexing
% bdiv(k) is the place value of label k in the mixed-radix index
% jb = sum(bv.*bdiv), where digit k ranges over 0..nub(k)
bdiv=[];
for k=1:labs
bdiv(k)=prod(nub(1:k-1)+1);
end
if (opts.recurdebug)
disp('bdiv');disp(bdiv);
end
if (opts.spaceopt==0)
if (opts.recurdebug>=1) disp('original (non-space-optimized)'); end
%the original recursion (non space-optimized)
% cw(ia+1,jb+1,r+1,s+1): shortest linkage between the first ia spikes of a and
% the sub-trains of b described by multi-index jb, using r matched and s
% mismatched links; NaN marks entries not (yet) filled in
cw=repmat(NaN,cwdims); % working array
cw(:,:,1,1)=0; %initialize: no length needed for no links
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%find the b-multi-index
bv=[];
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
% row k of bvm is bv with one fewer spike of label k, so jbrec(k) is the
% (0-based) index of the state "last label-k spike of b removed"
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based the indices needed for the recursion
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
disp('jbrec');disp(jbrec);
end
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
if (opts.recurdebug>=3)
disp('ia');disp(ia);
disp('av,bv');disp([av;bv]);
disp('maxij,rmaxij,smaxij');disp([maxij,rmaxij,smaxij]);
end
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0) cwtent=min(cwtent,cw(ia,jb+1,r+1,s+1)); end
%matched link case
if ((ia>0) & (bv(u)>0) & (r>0)) cwtent=min(cwtent,cw(ia,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end
%unmatched link case
if ((ia>0) & (s>0))
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cw(ia,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end
end
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cw(ia+1,jbrec(k)+1,r+1,s+1));
end
% NOTE(review): extraction damage -- the next line fuses the end of the
% non-space-optimized recursion (presumably storing cwtent into cw when it is
% finite, closing the r/rs/ia/jb loops and extracting clast and c from cw) with
% the start of the space-optimized branch (presumably "else" followed by
% "if (opts.recurdebug>=1) disp('space-optimized'); end"). The lost code cannot
% be recovered from this copy.
if (cwtent=1) disp('space-optimized'); end
%the space-optimized recursion
% precompute, for every b-multi-index jb, its digits bv and the indices jbrec
bv_list=zeros(bsize,labs);
jbrec_list=zeros(bsize,labs);
for jb=0:(bsize-1)
%find the b-multi-index
if (labs==0)
bv=0;
jbrec=0;
else
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based the indices needed for the recursion
end
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
disp('labs');disp(labs);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
disp('jbrec');disp(jbrec);
end
bv_list(jb+1,:)=bv;
jbrec_list(jb+1,:)=jbrec;
end
%set up table to pull out cseq if needed, using logic of spkdallqk_trunc
% for each distinct spike time c.time(j): number of a spikes at or before it
% (c.na), the matching b multi-index (c.bmulti) and the truncated rmax/smax
c.time=unique([sa sb]);
for j=1:length(c.time);
tmax=c.time(j);
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k=1:labs
% NOTE(review): "length(x>0)" is presumably meant as "length(x)>0"; both are
% nonzero exactly when x is nonempty, so the result is the same.
if length(la_trunc>0) nua_trunc(k)=length(find(la_trunc==lunique(k))); end
if length(lb_trunc>0) nub_trunc(k)=length(find(lb_trunc==lunique(k))); end
end
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
bmulti=sum(nub_trunc.*bdiv);
c.na(j)=na_trunc;
c.bmulti(j)=bmulti;
c.rmax(j)=rmax_trunc;
c.smax(j)=smax_trunc;
c.clast=[]; %make sure that there is something in this field
end
nclast=0; %number of clast structures calculated
%
% only two a-rows are kept: row 1 holds ia-1, row 2 is being filled for ia
cs=repmat(NaN,csdims); % working array
cs(:,:,1,1)=0; %initialize: no length needed for no links
% cs(1,:,:,:): values already computed
% cs(2,:,:,:): values to be computed
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%look up the b-multi-index
bv=bv_list(jb+1,:);
jbrec=jbrec_list(jb+1,:);
%
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0) cwtent=min(cwtent,cs(1,jb+1,r+1,s+1)); end
%matched link case
if ((ia>0) & (bv(u)>0) & (r>0)) cwtent=min(cwtent,cs(1,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end
%unmatched link case
if ((ia>0) & (s>0))
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cs(1,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end
end
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cs(2,jbrec(k)+1,r+1,s+1));
end
% NOTE(review): extraction damage -- the next line presumably fused storing
% cwtent into cs(2,...) when finite, closing the r/rs/jb loops, and the
% condition guarding the saving of clast slices for this ia.
if (cwtent0)
% save the clast slice for every truncation time whose a-count equals ia
for jval=find(c.na==ia)
nclast=nclast+1;
if (opts.recurdebug>=1);disp(sprintf(' making cseq with ia, jval,nclast= %d %d %d',ia,jval,nclast));end
rmax_trunc=c.rmax(jval);
smax_trunc=c.smax(jval);
c.clast{nclast}=reshape(...
cs(2,c.bmulti(jval)+1,1+[0:rmax_trunc],1+[0:smax_trunc]),[1+rmax_trunc 1+smax_trunc]);
end
end
%
% NOTE(review): extraction damage -- the next line presumably fused rolling the
% buffer (row 2 -> row 1) before the next ia, closing the ia loop, forming
% clast, and the start of a debug display loop over k that tests
% length(sua{k})>0.
if (ia0) disp(sua{k}); else disp(' '); end;
if (length(sub{k})>0) disp(sub{k}); else disp(' '); end;
end
disp(' rmax smax');disp(sprintf(' %4d %4d',rmax,smax));
disp('cdims');disp(cdims);
end
return
spkdallqk_final
function [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin)
% [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin) does the
% final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains (only their lengths are used here)
% clast, rmax, smax: calculated by spkdallqk_recur; clast(r+1,s+1) is the shortest
%   total linkage length using r matched and s mismatched links (presumably NaN
%   where no such linkage exists; min below ignores NaN entries)
% optsin: available for various debugging and optimizing options; returned unchanged as opts
%
% dists: size(qklist,1)-by-1 column; for each row (q,k) of qklist,
%   dists = min over (r,s) of  q*clast(r+1,s+1) + k*s + (na+nb-2r-2s),
%   i.e. shift cost + label-mismatch cost + one unit per unlinked spike.
%
% q=Inf and k=Inf are allowed: q*0 and k*0 are taken as 0 (not the IEEE NaN),
% so zero-length linkages and linkages without mismatches are still considered.
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_DIST, SPKD, LABDIST_FAST.
%
if (nargin<=6) optsin=[]; end
opts=optsin;
%
na=length(sa);
nb=length(sb);
nqk=size(qklist,1);
%
%rvals(r+1,s+1)=r and svals(r+1,s+1)=s
rvals=repmat([0:rmax]',1,1+smax);
svals=repmat([0:smax],1+rmax,1);
%
%make matrix of na+nb-2r-2s, indicating costs of deletions and insertions
%(each link uses one spike from each train; every unlinked spike costs 1)
nanbrs=(na+nb)-2*rvals-2*svals;
%
% find the best strategy (all r (matched links) and all s (mismatched links))
% for each (q,k) pair; preallocating also makes an empty qklist return zeros(0,1)
dists=zeros(nqk,1);
for iqk=1:nqk
  q=qklist(iqk,1);
  k=qklist(iqk,2);
  %Inf*0 is NaN, which min would silently discard; zero-cost terms must stay 0
  shiftcost=q*clast;
  shiftcost(clast==0)=0;
  labelcost=k*svals;
  labelcost(svals==0)=0;
  posscosts=shiftcost+labelcost+nanbrs;
  dists(iqk,1)=min(posscosts(:));
end
return
spkdallqk_dist
function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin)
% function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin) does the
% recursion and final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% Inputs:
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains
% la, lb: labels on the two spike trains
% optsin: available for various debugging and optimizing options (passed to spkdallqk_recur)
%
% Outputs:
% dists: one distance per row of qklist, from spkdallqk_final
% clast, rmax, smax: calculated by spkdallqk_recur
% opts: the options structure produced by spkdallqk_recur (defaults filled in,
%   plus opts.orig and opts.used)
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL.
%
if (nargin<=5) optsin=[]; end
%do the recursion (the working array c is not needed here, so it is not requested)
[clast,rmax,smax,opts]=spkdallqk_recur(sa,la,sb,lb,optsin);
%do the final stage; spkdallqk_final returns its options argument unchanged, so
%passing the recursion's opts keeps opts.orig/opts.used in the returned opts
[dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,opts);
return
spkdallqk_trunc
function [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% takes the matrix c as calculated by
% spkdallqk_recur for two labeled spike trains, and determines the slice clast_trunc that
% is required for calculating distances, if the spike trains are truncated at time tmax.
%
% CRUCIAL that opts is passed; this contains the original spike trains
% stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb.
% and the optimized ones that c was calculated for,
% stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
% c calculated by [clast,rmax,smax,optsused,c]=spkdallqk_recur(sa,la,sb,lb,opts);
% clast_trunc,rmax_trunc,smax_trunc should equal clast, rmax,smax IF tmax=Inf.
%
% c can also be the structure cseq, as returned by spkdallqk_truncseq
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
% size(clast_trunc)=[rmax_trunc+1 smax_trunc+1]
%
% na_trunc: number of spikes of the (used) train a at or before tmax
% bmulti_trunc: 0-based linear multi-index of the truncated train b, i.e.
%   sum over labels of (spike count at or before tmax)*bdiv
% When both trains are empty, all five outputs are 0.
%
% See spkdallqk.doc.
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL, SPKDALLQK_TRUNCSEQ.
%
if (nargin<3) || ~isstruct(opts) || ~isfield(opts,'used')
  error('spkdallqk_trunc: opts as returned by spkdallqk_recur (with field opts.used) must be passed');
end
% get optimized values (the trains c was actually calculated for)
sa=opts.used.sa;
sb=opts.used.sb;
la=opts.used.la;
lb=opts.used.lb;
%
%distinct labels; la/lb are flattened so row or column label vectors both work
lunique=unique([la(:);lb(:)]);
labs=length(lunique);
if (labs==0)
  %both trains empty: nothing to link; assign every output
  clast_trunc=0;
  rmax_trunc=0;
  smax_trunc=0;
  na_trunc=0;
  bmulti_trunc=0;
  return
end
%spike counts per label for the full (used) trains
nua=zeros(1,labs);
nub=zeros(1,labs);
for k=1:labs
  nua(k)=sum(la==lunique(k));
  nub(k)=sum(lb==lunique(k));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
%
%now find the truncated spike trains
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
%and find parameters for them (counts of an empty train are simply 0)
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k=1:labs
  nua_trunc(k)=sum(la_trunc==lunique(k));
  nub_trunc(k)=sum(lb_trunc==lunique(k));
end
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
%set up divisors for de-indexing and re-indexing with respect to full spike train
%(place values of the mixed-radix b multi-index; label k has radix nub(k)+1)
bdiv=zeros(1,labs);
for k=1:labs
  bdiv(k)=prod(nub(1:k-1)+1);
end
bmulti_trunc=sum(nub_trunc.*bdiv);
%
%separate cases, depending on whether c is the full matrix, or a structure
% (full matrix returned by spkdallqk if spaceopt=0; structure if spaceopt=1)
%
if (~isstruct(c))
  cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
  cw=reshape(c,cwdims);
  slice_trunc=reshape(cw(na_trunc+1,bmulti_trunc+1,:,:),[rmax+1 smax+1]);
  clast_trunc=slice_trunc(1:rmax_trunc+1,1:smax_trunc+1);
else
  %index of the last stored spike time not after tmax (0 if tmax precedes all spikes)
  j=max([0,find(tmax>=c.time)]);
  if (j==0)
    clast_trunc=0;
  else
    clast_trunc=c.clast{j};
  end
end
return