Matlab Code for Spike Time Distances Between Labeled (Multineuronal) Spike Trains, Parallel in q and k
Matlab code for Multineuronal Spike Time Metric, Parallel in q and k, Optimized
There are five modules (spkdallqk_recur, spkdallqk_recur_ctc, spkdallqk_final, spkdallqk_dist, and spkdallqk_trunc) with the following relationships:
- spkdallqk_recur
carries out the core dynamic programming algorithm and produces
two output arrays, clast and c. clast applies to the entire spike trains; c can be
used (via spkdallqk_trunc) to calculate distances from truncated spike trains. Version 6.
- spkdallqk_recur_ctc
optimized version of the above, by John Zollweg at the Cornell Theory Center.
Version 5. It is invoked by spkdallqk_recur if optsin.ctc=1.
- spkdallqk_final
carries out the final step of the algorithm to calculate distances
from the array clast produced by spkdallqk_recur.
- spkdallqk_dist
calls spkdallqk_recur and then spkdallqk_final.
- spkdallqk_trunc
recalculates an array clast from truncated spike trains,
from the array c calculated by spkdallqk_recur.
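The reason for splitting the work between a recursion (spkdallqk_recur) and a cheap final step (spkdallqk_final) can be seen in the unlabeled case. The following is a minimal sketch, not part of the package: it assumes a single label (so there are no mismatched links and k plays no role), a cost of q per unit of time shift and 1 per insertion or deletion, and uses hypothetical variable names. The table cmin holds, for each number r of matched pairs, the smallest total shift of any non-crossing set of r pairs; it is built once with the same three cases as the recursion (last spike of a unlinked, last spike of b unlinked, or the two linked), after which every q is evaluated by a minimum over r.

```matlab
% Minimal single-label sketch of the "all q at once" idea (hypothetical names).
sa = [0.1 0.5 0.9];          % spike times, train a
sb = [0.2 0.8];              % spike times, train b
qlist = [0 1 10];            % costs per unit time shift
na = numel(sa); nb = numel(sb);
rmax = min(na, nb);
% cmin(i+1,j+1,r+1): minimal total shift using r matched pairs
% drawn from the first i spikes of a and the first j spikes of b
cmin = Inf(na+1, nb+1, rmax+1);
cmin(:,:,1) = 0;             % no links: no shift needed
for i = 1:na
  for j = 1:nb
    for r = 1:min([i j rmax])
      best = min(cmin(i, j+1, r+1), cmin(i+1, j, r+1));     % spike i of a or spike j of b unlinked
      best = min(best, cmin(i, j, r) + abs(sa(i) - sb(j)));  % spike i of a linked to spike j of b
      cmin(i+1, j+1, r+1) = best;
    end
  end
end
clast = reshape(cmin(na+1, nb+1, :), 1, rmax+1);  % analog of clast (s = 0 only)
dists = zeros(size(qlist));
for iq = 1:numel(qlist)
  % shift cost plus one unit for each unmatched spike in either train
  dists(iq) = min(qlist(iq) * clast + na + nb - 2*(0:rmax));
end
disp(dists)
```

Entries with r larger than the number of available spikes stay at Inf, which plays the role of the NaN entries in the package's working arrays.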
spkdallqk_recur
function [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin)
% [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin) does the recursion for the Aronov
% multiunit algorithm, extended to simultaneous calculation for all values of q and k.
%
% this is v6: cachers option (does not save time)
%
% sa, sb: spike times on the two spike trains
% la, lb: spike labels (positive integers)
% optsin: available for various debugging and optimizing options
% optsin.recurdebug: show some debugging (levels 1 to 3) within the recursion
% optsin.showmethod: show method options
% optsin.switchab: exchange a and b if necessary to reduce total number of inner loops
% 0->disable, 1->enable (default), 2->enable and show, 3->force swap
% optsin.disjoint: merge labels in one train if the other train does not contain them
% 0->disable, 1->enable (default), 2->enable and show
% optsin.sizeonly: 1->clast returns as a scalar, indicating the size of the working matrix;
% c returns as a vector, indicating the dimensions of the working matrix
% optsin.spaceopt: 0-> original version (default), 1->space-optimized version (after Ifije Ohiorhenuan)
% optsin.ctc: 0-> original version (default), 1->use Cornell Theory Center version
% optsin.cachers: 0-> original version (default), 1-> cache the references to cw outside of the rs loop
% cachers=1 forces optsin.ctc and optsin.spaceopt to 0
%
% sa,la,sb,lb are potentially swapped and otherwise optimized before proceeding.
% original values are stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb.
% after optimization, they are stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
%
% See spkdallqk.doc.
%
% size(c)=[length(sa)+1 length(sb_1)+1 ... length(sb_L)+1 rmax+1 smax+1]
%
% Copyright (c) by Jonathan Victor, 2006.
%
% See also SPKDALLQK_FINAL, SPKDALLQK_DIST, SPKD, LABDIST_FAST, SPKDALLQ, RUNRECUR_MULTIPHA,
% SPKDALLQK_OPTS_DEMO.
%
if (nargin<=4) optsin=[]; end
opts=optsin;
if (~isfield(opts,'recurdebug')) opts.recurdebug=0; end
if (~isfield(opts,'switchab')) opts.switchab=1; end
if (~isfield(opts,'disjoint')) opts.disjoint=1; end
if (~isfield(opts,'sizeonly')) opts.sizeonly=0; end
if (~isfield(opts,'spaceopt')) opts.spaceopt=0; end
if (~isfield(opts,'cachers')) opts.cachers=0; end
if (~isfield(opts,'showmethod')) opts.showmethod=1; end
if (~isfield(opts,'ctc')) opts.ctc=0; end
if (opts.cachers)
opts.spaceopt=0;
opts.ctc=0;
end
if (opts.ctc==1)
if (opts.recurdebug>0) | (opts.showmethod>0)
disp('calling spkdallqk_recur_ctc from spkdallqk_recur');
end
[clast,rmax,smax,opts,c]=spkdallqk_recur_ctc(sa,la,sb,lb,optsin);
return
end
%save input values
opts.orig.sa=sa;opts.orig.sb=sb;opts.orig.la=la;opts.orig.lb=lb;
%
if opts.disjoint>=1
if (opts.disjoint>=2);disp('before disjoint optimization, la, lb=');disp(la);disp(lb);end
lbnota=setdiff(unique(lb),unique(la));
lanotb=setdiff(unique(la),unique(lb));
if (length(lbnota)>1);lb(find(ismember(lb,lbnota)))=lbnota(1);end
if (length(lanotb)>1);la(find(ismember(la,lanotb)))=lanotb(1);end
if (opts.disjoint>=2);disp(' after disjoint optimization, la, lb=');disp(la);disp(lb);end
end
%
[lurelabel,index,relabel] = unique([la lb]);
relabel=relabel(:)'; %force a row vector (newer versions of Matlab return a column)
labs = prod(size(lurelabel));
%replace each label by 1, ..., length(lunique)
%disp('after replacement')
la=relabel(1:length(la));
lb=relabel(length(la)+1:end);
lunique=unique([la lb]);
nua=[];nub=[];
for k = 1:labs
sua{k}=sa(find(la==lunique(k)));
nua(k)=length(find(la==lunique(k)));
sub{k}=sb(find(lb==lunique(k)));
nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
if (opts.switchab>=1)
asize=prod(nua+1);
absize=(na+1)*bsize;
basize=(nb+1)*asize;
if (opts.switchab>=2)
disp(' na+1 asize, nb+1 bsize, absize, basize')
disp([na+1 asize nb+1 bsize absize basize]);
end
if ((absize>basize) | (opts.switchab>=3))
if (opts.switchab>=2) disp('a and b swapped'); end
lt=la;la=lb;lb=lt;
st=sa;sa=sb;sb=st;
nt=na;na=nb;nb=nt;
nut=nua;nua=nub;nub=nut;
tsize=asize;asize=bsize;bsize=tsize;
for k=1:labs
sut{k}=sua{k};sua{k}=sub{k};sub{k}=sut{k};
end
else
if (opts.switchab>=2) disp(' a and b not swapped.'); end
end
end
cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
csdims=[2 bsize rmax+1 smax+1]; %working dimensions for space optimized version
%save values used for later use
opts.used.sa=sa;opts.used.sb=sb;opts.used.la=la;opts.used.lb=lb;opts.used.cdims=cdims;opts.used.cwdims=cwdims;
%
if (opts.recurdebug)
disp('cwdims');disp(cwdims);
end
%if only sizes are requested, do this and return
if (opts.sizeonly)
if (opts.spaceopt==1)
c=csdims;
else
c=cwdims;
end
clast=prod(c);
return;
end
%initialize c for spaceopt
if (opts.spaceopt==1)
c.time=[];
c.na=[];
c.bmulti=[];
c.rmax=[];
c.smax=[];
c.clast=cell(0);
end
%check for trivial situations
if (min(na,nb)==0)
clast=0;
rmax=0;
smax=0;
if (opts.spaceopt==0);
c=zeros(cdims);
else
c.time=unique([sa sb]);
if (na==0)
c.na=zeros(1,nb);
c.bmulti=[1:nb];
end
if (nb==0)
c.na=[1:na];
c.bmulti=zeros(1,na);
end
c.rmax=zeros(1,max(na,nb));
c.smax=zeros(1,max(na,nb));
for j=1:max(na,nb)
c.clast{j}=0;
end
return
end
end
%set up divisors for de-indexing and re-indexing
bdiv=[];
for k=1:labs
bdiv(k)=prod(nub(1:k-1)+1);
end
if (opts.recurdebug)
disp('bdiv');disp(bdiv);
end
if (opts.spaceopt==0) & (opts.cachers==0)
if (opts.recurdebug>=1) | (opts.showmethod>0) disp('original (non-space-optimized)'); end
%the original recursion (non space-optimized)
cw=repmat(NaN,cwdims); % working array
cw(:,:,1,1)=0; %initialize: no length needed for no links
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%find the b-multi-index
bv=[];
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based indices needed for the recursion
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
disp('jbrec');disp(jbrec);
end
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
if (opts.recurdebug>=3)
disp('ia');disp(ia);
disp('av,bv');disp([av;bv]);
disp('maxij,rmaxij,smaxij');disp([maxij,rmaxij,smaxij]);
end
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0) cwtent=min(cwtent,cw(ia,jb+1,r+1,s+1)); end
%matched link case
if ((ia>0) & (bv(u)>0) & (r>0)) cwtent=min(cwtent,cw(ia,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end
%unmatched link case
if ((ia>0) & (s>0))
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cw(ia,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end
end
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cw(ia+1,jbrec(k)+1,r+1,s+1));
end
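if (cwtent<Inf) cw(ia+1,jb+1,r+1,s+1)=cwtent; end
end %loop over (r,s)
end %loop over r+s
end %loop over ia
end %loop over jb
clast=reshape(cw(na+1,bsize,:,:),[rmax+1 smax+1]);
c=reshape(cw,cdims);
end
if (opts.spaceopt==0) & (opts.cachers==1)
if (opts.recurdebug>=1) | (opts.showmethod>0) disp('original (non-space-optimized) with cachers=1'); end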
cwt_init=repmat(NaN,rmax+1,smax+1);
cwt_init(1,1)=0;
%the original recursion (non space-optimized)
cw=repmat(NaN,cwdims); % working array
cw(:,:,1,1)=0; %initialize: no length needed for no links
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%find the b-multi-index
bv=[];
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based indices needed for the recursion
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
disp('jbrec');disp(jbrec);
end
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
if (opts.recurdebug>=3)
disp('ia');disp(ia);
disp('av,bv');disp([av;bv]);
disp('maxij,rmaxij,smaxij');disp([maxij,rmaxij,smaxij]);
end
for k=1:labs
if (jbrec(k)>=0)
if (ia>0) cwb{k}=reshape(cw(ia,jbrec(k)+1,:,:),rmax+1,smax+1); end
cwp{k}=reshape(cw(ia+1,jbrec(k)+1,:,:),rmax+1,smax+1);
end
end
cwa=[];
if (ia>0) cwa=reshape(cw(ia,jb+1,:,:),rmax+1,smax+1); end
cwt=cwt_init;
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0) cwtent=min(cwtent,cwa(r+1,s+1)); end
%matched link case
if ((ia>0) & (bv(u)>0) & (r>0)) cwtent=min(cwtent,cwb{u}(r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end
%unmatched link case
if ((ia>0) & (s>0))
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cwb{k}(r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end
end
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cwp{k}(r+1,s+1));
end
if (cwtent< Inf) cwt(r+1,s+1)=cwtent; end
end %loop over (r,s)
end %loop over r+s
cw(ia+1,jb+1,:,:)=reshape(cwt,[1 1 rmax+1 smax+1]);
end %loop over ia
end %loop over jb
clast=reshape(cw(na+1,bsize,:,:),[rmax+1 smax+1]);
c=reshape(cw,cdims);
end
if (opts.spaceopt==1)
if (opts.recurdebug>=1) | (opts.showmethod>0) disp('space-optimized'); end
%the space-optimized recursion
bv_list=zeros(bsize,labs);
jbrec_list=zeros(bsize,labs);
for jb=0:(bsize-1)
%find the b-multi-index
if (labs==0)
bv=0;
jbrec=0;
else
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based indices needed for the recursion
end
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
disp('labs');disp(labs);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
disp('jbrec');disp(jbrec);
end
bv_list(jb+1,:)=bv;
jbrec_list(jb+1,:)=jbrec;
end
%set up table to pull out cseq if needed, using logic of spkdallqk_trunc
c.time=unique([sa sb]);
for j=1:length(c.time);
tmax=c.time(j);
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k=1:labs
if (length(la_trunc)>0) nua_trunc(k)=length(find(la_trunc==lunique(k))); end
if (length(lb_trunc)>0) nub_trunc(k)=length(find(lb_trunc==lunique(k))); end
end
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
bmulti=sum(nub_trunc.*bdiv);
c.na(j)=na_trunc;
c.bmulti(j)=bmulti;
c.rmax(j)=rmax_trunc;
c.smax(j)=smax_trunc;
c.clast=[]; %make sure that there is something in this field
end
nclast=0; %number of clast structures calculated
%
cs=repmat(NaN,csdims); % working array
cs(:,:,1,1)=0; %initialize: no length needed for no links
% cs(1,:,:,:): values already computed
% cs(2,:,:,:): values to be computed
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%look up the b-multi-index
bv=bv_list(jb+1,:);
jbrec=jbrec_list(jb+1,:);
%
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0) cwtent=min(cwtent,cs(1,jb+1,r+1,s+1)); end
%matched link case
if ((ia>0) & (bv(u)>0) & (r>0)) cwtent=min(cwtent,cs(1,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end
%unmatched link case
if ((ia>0) & (s>0))
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cs(1,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end
end
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cs(2,jbrec(k)+1,r+1,s+1));
end
if (cwtent0)
for jval=find(c.na==ia)
nclast=nclast+1;
if (opts.recurdebug>=1);disp(sprintf(' making cseq with ia, jval,nclast= %d %d %d',ia,jval,nclast));end
rmax_trunc=c.rmax(jval);
smax_trunc=c.smax(jval);
c.clast{nclast}=reshape(...
cs(2,c.bmulti(jval)+1,1+[0:rmax_trunc],1+[0:smax_trunc]),[1+rmax_trunc 1+smax_trunc]);
end
end
%
if (ia0) disp(sua{k}); else disp(' '); end;
if (length(sub{k})>0) disp(sub{k}); else disp(' '); end;
end
disp(' rmax smax');disp(sprintf(' %4d %4d',rmax,smax));
disp('cdims');disp(cdims);
end
return
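The preprocessing at the top of spkdallqk_recur (disjoint-label merging, relabeling to 1, ..., L, per-label counts, and the bounds rmax and smax) can be run on its own to see how large the recursion will be. The sketch below repeats that logic on made-up spike trains; the data are hypothetical, the counts use la==k because after relabeling the unique labels are exactly 1, ..., labs, and relabel(:)' keeps it working whether unique returns its third output as a row or a column.

```matlab
% Label preprocessing as in spkdallqk_recur (example data are made up).
sa = [0.05 0.12 0.30 0.41]; la = [1 2 1 7];   % labels 2 and 7 occur only in a
sb = [0.10 0.33 0.50];      lb = [1 9 4];     % labels 9 and 4 occur only in b
% disjoint optimization: labels present in only one train are interchangeable,
% so all of them are merged into a single label
lbnota = setdiff(unique(lb), unique(la));
lanotb = setdiff(unique(la), unique(lb));
if numel(lbnota) > 1, lb(ismember(lb, lbnota)) = lbnota(1); end
if numel(lanotb) > 1, la(ismember(la, lanotb)) = lanotb(1); end
% relabel to consecutive integers 1..labs
[lu, ~, relabel] = unique([la lb]);
relabel = relabel(:)';                 % force a row vector
labs = numel(lu);
la = relabel(1:numel(la));
lb = relabel(numel(la)+1:end);
% per-label spike counts and the bounds on matched (r) and mismatched (s) links
nua = zeros(1, labs); nub = zeros(1, labs);
for k = 1:labs
  nua(k) = sum(la == k);
  nub(k) = sum(lb == k);
end
na = sum(nua); nb = sum(nub);
rmax = sum(min([nua; nub], [], 1));
smax = min(sum(min([nua; nb-nub], [], 1)), sum(min([na-nua; nub], [], 1)));
bsize = prod(nub + 1);                 % number of b multi-indices (second dimension of cw)
```

Here the two a-only labels collapse to one and the two b-only labels collapse to another, so the working array has 3 labels instead of 5.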
spkdallqk_recur_ctc
function [clast,rmax,smax,opts,c]=spkdallqk_recur_ctc(sa,la,sb,lb,optsin)
% [clast,rmax,smax,opts,c]=spkdallqk_recur_ctc(sa,la,sb,lb,optsin) does the recursion for the Aronov
% multiunit algorithm, extended to simultaneous calculation for all values of q and k.
%
% spkdallqk_recur_ctc_v5 derived from spkdallqk_recur_ctc_v4
% this is spkdallqk_recur_ctc, optimized by John Zollweg at Cornell Theory Center zollweg@tc.cornell.edu
% with trivial changes (numel->prod(size()), &&->&, and ||->|) for backward compatibility with Matlab 5
%
% sa, sb: spike times on the two spike trains
% la, lb: spike labels (positive integers)
% optsin: available for various debugging and optimizing options
% optsin.recurdebug: show some debugging (levels 1 to 3) within the recursion
% optsin.switchab: exchange a and b if necessary to reduce total number of inner loops
% 0->disable, 1->enable (default), 2->enable and show, 3->force swap
% optsin.disjoint: merge labels in one train if the other train does not contain them
% 0->disable, 1->enable (default), 2->enable and show
% optsin.sizeonly: 1->clast returns as a scalar, indicating the size of the working matrix;
% c returns as a vector, indicating the dimensions of the working matrix
% optsin.spaceopt: 0-> original version (default), 1->space-optimized version (after Ifije Ohiorhenuan)
%
% sa,la,sb,lb are potentially swapped and otherwise optimized before proceeding.
% original values are stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb.
% after optimization, they are stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
%
% See spkdallqk.doc.
%
% size(c)=[length(sa)+1 length(sb_1)+1 ... length(sb_L)+1 rmax+1 smax+1]
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_FINAL, SPKDALLQK_DIST, SPKD, LABDIST_FAST, SPKDALLQ, RUNRECUR_MULTIPHA.
%
if (nargin<=4), optsin=[]; end;
opts=optsin;
if (~isfield(opts,'recurdebug')), opts.recurdebug=0; end;
if (~isfield(opts,'switchab')), opts.switchab=1; end;
if (~isfield(opts,'disjoint')), opts.disjoint=1; end;
if (~isfield(opts,'sizeonly')), opts.sizeonly=0; end;
if (~isfield(opts,'spaceopt')), opts.spaceopt=0; end;
%save input values
opts.orig.sa=sa;opts.orig.sb=sb;opts.orig.la=la;opts.orig.lb=lb;
%
if opts.disjoint>=1
if (opts.disjoint>=2);disp('before disjoint optimization, la, lb=');disp(la);disp(lb);end
lbnota=setdiff(unique(lb),unique(la));
lanotb=setdiff(unique(la),unique(lb));
if (length(lbnota)>1);lb(find(ismember(lb,lbnota)))=lbnota(1);end
if (length(lanotb)>1);la(find(ismember(la,lanotb)))=lanotb(1);end
if (opts.disjoint>=2);disp(' after disjoint optimization, la, lb=');disp(la);disp(lb);end
end
%
[lurelabel,index,relabel] = unique([la lb]);
relabel=relabel(:)'; %force a row vector (newer versions of Matlab return a column)
labs = prod(size(lurelabel));
%replace each label by 1, ..., length(lunique)
%disp('after replacement')
la=relabel(1:length(la));
lb=relabel(length(la)+1:end);
lunique=unique([la lb]);
nua=[];nub=[];
for k = 1:labs
sua{k}=sa(find(la==lunique(k)));
nua(k)=length(find(la==lunique(k)));
sub{k}=sb(find(lb==lunique(k)));
nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
if (opts.switchab>=1)
asize=prod(nua+1);
absize=(na+1)*bsize;
basize=(nb+1)*asize;
if (opts.switchab>=2)
disp(' na+1 asize, nb+1 bsize, absize, basize')
disp([na+1 asize nb+1 bsize absize basize]);
end
if ((absize>basize) | (opts.switchab>=3))
if (opts.switchab>=2), disp('a and b swapped'); end;
lt=la;la=lb;lb=lt;
st=sa;sa=sb;sb=st;
nt=na;na=nb;nb=nt;
nut=nua;nua=nub;nub=nut;
tsize=asize;asize=bsize;bsize=tsize;
for k=1:labs
sut{k}=sua{k};sua{k}=sub{k};sub{k}=sut{k};
end
else
if (opts.switchab>=2), disp(' a and b not swapped.'); end;
end
end
cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
csdims=[bsize rmax+1 smax+1 2]; %working dimensions for space optimized version
%save values used for later use
opts.used.sa=sa;opts.used.sb=sb;opts.used.la=la;opts.used.lb=lb;opts.used.cdims=cdims;opts.used.cwdims=cwdims;
%
if (opts.recurdebug)
disp('cwdims');disp(cwdims);
end
%if only sizes are requested, do this and return
if (opts.sizeonly)
if (opts.spaceopt==1)
c=csdims;
else
c=cwdims;
end
clast=prod(c);
return;
end
%initialize c for spaceopt
if (opts.spaceopt==1)
c.time=[];
c.na=[];
c.bmulti=[];
c.rmax=[];
c.smax=[];
c.clast=cell(0);
end
%check for trivial situations
if (min(na,nb)==0)
clast=0;
rmax=0;
smax=0;
if (opts.spaceopt==0);
c=zeros(cdims);
else
c.time=unique([sa sb]);
if (na==0)
c.na=zeros(1,nb);
c.bmulti=[1:nb];
end
if (nb==0)
c.na=[1:na];
c.bmulti=zeros(1,na);
end
c.rmax=zeros(1,max(na,nb));
c.smax=zeros(1,max(na,nb));
for j=1:max(na,nb)
c.clast{j}=0;
end
return
end
end
%set up divisors for de-indexing and re-indexing
bdiv=[];
for k=1:labs
bdiv(k)=prod(nub(1:k-1)+1);
end
if (opts.recurdebug)
disp('bdiv');disp(bdiv);
end
if (opts.spaceopt==0)
if (opts.recurdebug>=1), disp('original (non-space-optimized)'); end;
%the original recursion (non space-optimized)
cw=repmat(NaN,cwdims); % working array
cw(:,:,1,1)=0; %initialize: no length needed for no links
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%find the b-multi-index
bv=[];
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based indices needed for the recursion
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end;
disp('jbrec');disp(jbrec);
end;
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
for ia=0:na
if (ia>0), u=la(ia); av(u)=av(u)+1; end; % u is label of the ith spike
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
if (opts.recurdebug>=3)
disp('ia');disp(ia);
disp('av,bv');disp([av;bv]);
disp('maxij,rmaxij,smaxij');disp([maxij,rmaxij,smaxij]);
end;
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
if (ia>0), cwtent=min(cwtent,cw(ia,jb+1,r+1,s+1)); end;
%matched link case
if (ia>0)
if (bv(u)>0)
if (r>0), cwtent=min(cwtent,cw(ia,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u)))); end;
end
end
%unmatched link case
if (ia>0)
if (s>0)
for k=setdiff(find(bv>0),u)
cwtent=min(cwtent,cw(ia,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
end;
end
end;
%last spike in b is unlinked
for k=find(bv>0)
cwtent=min(cwtent,cw(ia+1,jbrec(k)+1,r+1,s+1));
end;
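if (cwtent<Inf), cw(ia+1,jb+1,r+1,s+1)=cwtent; end;
end %loop over (r,s)
end %loop over r+s
end %loop over ia
end %loop over jb
clast=reshape(cw(na+1,bsize,:,:),[rmax+1 smax+1]);
c=reshape(cw,cdims);
end
if (opts.spaceopt==1)
if (opts.recurdebug>=1), disp('space-optimized'); end;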
%the space-optimized recursion
bv_list=zeros(labs,bsize);
jbrec_list=zeros(labs,bsize);
for jb=0:(bsize-1)
%find the b-multi-index
if (labs==0)
bv=0;
jbrec=0;
else
for k=1:labs
bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
end;
bvm=repmat(bv,[labs 1])-eye(labs);
jbrec=bdiv*bvm'; %0-based indices needed for the recursion
end;
if (opts.recurdebug>=2)
disp('jb');disp(jb);
disp('bv');disp(bv);
disp('labs');disp(labs);
if (opts.recurdebug>=3);disp('bvm');disp(bvm);end;
disp('jbrec');disp(jbrec);
end;
bv_list(:,jb+1)=bv';
jbrec_list(:,jb+1)=jbrec';
end;
%set up table to pull out cseq if needed, using logic of spkdallqk_trunc
c.time=unique([sa sb]);
for j=1:length(c.time);
tmax=c.time(j);
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k=1:labs
if (length(la_trunc)>0), nua_trunc(k)=length(find(la_trunc==lunique(k))); end;
if (length(lb_trunc)>0), nub_trunc(k)=length(find(lb_trunc==lunique(k))); end;
end;
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
bmulti=sum(nub_trunc.*bdiv);
c.na(j)=na_trunc;
c.bmulti(j)=bmulti;
c.rmax(j)=rmax_trunc;
c.smax(j)=smax_trunc;
c.clast=[]; %make sure that there is something in this field
end;
nclast=0; %number of clast structures calculated
%
cs=repmat(NaN,csdims); % working array
cs(:,1,1,:)=0; %initialize: no length needed for no links
% cs(:,:,:,1): values already computed
% cs(:,:,:,2): values to be computed
%loop over the non-subdivided train (a)
av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
%prologue -- ia=0
%loop over the subdivided train (b)
for jb=0:(bsize-1)
%look up the b-multi-index
bv=bv_list(:,jb+1)';
jbrec=jbrec_list(:,jb+1)';
bvl=find(bv>0);
%do the calculation here
cwtent=Inf;
%last spike in b is unlinked
for k=bvl
cwtent=min(cwtent, cs(jbrec(k)+1,1,1,2));
end;
if (cwtent0)
for jval=find(c.na==0)
nclast=nclast+1;
if (opts.recurdebug>=1);disp(sprintf(' making cseq with ia, jval,nclast= %d %d %d',0,jval,nclast));end;
rmax_trunc=c.rmax(jval);
smax_trunc=c.smax(jval);
c.clast{nclast}=reshape(cs(c.bmulti(jval)+1,1:rmax_trunc+1,1:smax_trunc+1,2),rmax_trunc+1,smax_trunc+1);
end;
end;
%
if (00);
klist=bvl(~ismembc(bvl,u));
bvu=bv(u);
if (bvu>0)
incr=abs(sa(ia)-sub{u}(bv(u)));
jbrcp=jbrec(u)+1;
else
incr=Inf;
jbrcp=1;
end;
for k=klist
kincr(k)=abs(sa(ia)-sub{k}(bv(k)));
end;
rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
maxij=min(ia,sum(bv)); %maximum number of unmatched links and also maximum total number of links
smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
%loop over r and s
for rs=0:maxij %loop over r+s
for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
%do the calculation here
cwtent=Inf;
%last spike in a is unlinked
cwtent=min(cwtent,cs(jb+1,r+1,s+1,1));
%matched link case
if (r>0), cwtent=min(cwtent,cs(jbrcp,r,s+1,1)+incr); end;
%unmatched link case
if (s>0)
for k=klist
cwtent=min(cwtent,cs(jbrec(k)+1,r+1,s,1)+kincr(k));
end;
end;
%last spike in b is unlinked
for k=bvl
cwtent=min(cwtent, cs(jbrec(k)+1,r+1,s+1,2));
end;
if (cwtent0)
for jval=find(c.na==ia)
nclast=nclast+1;
if (opts.recurdebug>=1);disp(sprintf(' making cseq with ia, jval,nclast= %d %d %d',ia,jval,nclast));end;
rmax_trunc=c.rmax(jval);
smax_trunc=c.smax(jval);
c.clast{nclast}=reshape(cs(c.bmulti(jval)+1,1:rmax_trunc+1,1:smax_trunc+1,2),rmax_trunc+1,smax_trunc+1);
end;
end;
%
if (ia0), disp(sua{k}); else disp(' '); end;
if (length(sub{k})>0), disp(sub{k}); else disp(' '); end;
end
disp(' rmax smax');disp(sprintf(' %4d %4d',rmax,smax));
disp('cdims');disp(cdims);
end
return
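Both recursions enumerate the states of the subdivided train b with a single integer jb, treating the divisors bdiv as a mixed-radix code: digit k of jb, in base nub(k)+1, is the number of label-k spikes of b considered so far. The sketch below, with made-up counts nub, decodes every jb into bv, checks that re-encoding recovers jb, and forms jbrec exactly as the listings do (bvm and bdiv*bvm'). It is an illustration of the indexing only, not code from the package.

```matlab
% Mixed-radix indexing of the b multi-index (example counts are made up).
nub = [2 1 3];                      % number of b spikes with each label
labs = numel(nub);
bdiv = cumprod([1 nub(1:end-1)+1]); % same as prod(nub(1:k-1)+1) for each k
bsize = prod(nub + 1);
ok = true;
for jb = 0:bsize-1
  bv = mod(floor(jb ./ bdiv), nub + 1);   % decode: per-label spike counts
  ok = ok && (bdiv * bv' == jb);          % re-encoding recovers jb
  bvm = repmat(bv, [labs 1]) - eye(labs); % row k: one fewer label-k spike
  jbrec = bdiv * bvm';                    % 0-based predecessor indices
  % jbrec(k) is a valid predecessor only when bv(k) > 0, which is
  % why the recursion only uses it inside find(bv>0)
end
bv9 = mod(floor(9 ./ bdiv), nub + 1);     % jb = 9 for these counts
```

Since row k of bvm is bv minus a unit vector, jbrec(k) is simply jb-bdiv(k); the matrix form just computes all labels at once.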
spkdallqk_final
function [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin)
% [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin) does the
% final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains
% clast, rmax, smax: calculated by spkdallqk_recur
% optsin: available for various debugging and optimizing options
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_DIST, SPKD, LABDIST_FAST.
%
if (nargin<=6) optsin=[]; end
opts=optsin;
%
na=length(sa);
nb=length(sb);
nqk=size(qklist,1);
%
%make matrix of na+nb-2r-2s, indicating costs of deletions and insertions
%
nanbrs=(na+nb)*ones(1+rmax,1+smax)-2*repmat([0:rmax]',1,1+smax)-2*repmat([0:smax],1+rmax,1);
%
% find the best strategy (all r (matched links) and all s (mismatched links)
%
for iqk=1:nqk
posscosts=qklist(iqk,1)*clast+qklist(iqk,2)*repmat([0:smax],1+rmax,1)+nanbrs;
dists(iqk,1)=min(min(posscosts));
end
return
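For each pair (q,k), spkdallqk_final scores a linkage with r matched and s mismatched links as q*clast(r+1,s+1) + k*s + (na+nb-2r-2s), the last term counting deletions and insertions, and takes the minimum over (r,s). The sketch below evaluates the same expressions on a small clast written out by hand for a made-up pair of trains: a has spikes at 0.2 (label 1) and 0.5 (label 2), b has one spike at 0.3 (label 1). The entry for r=1, s=1 is NaN because b has only one spike, and min skips NaN.

```matlab
% Final step of the algorithm on a hand-built clast (rows: r = 0..rmax, columns: s = 0..smax).
na = 2; nb = 1; rmax = 1; smax = 1;
clast = [0    0.2;        % r=0: no links; one mismatched link 0.5 <-> 0.3
         0.1  NaN];       % r=1: matched link 0.2 <-> 0.3; r=1,s=1 impossible
qklist = [0 0; 1 1; 10 2];   % each row is [q k]
nanbrs = (na+nb)*ones(1+rmax,1+smax) - 2*repmat((0:rmax)',1,1+smax) - 2*repmat(0:smax,1+rmax,1);
dists = zeros(size(qklist,1),1);
for iqk = 1:size(qklist,1)
  posscosts = qklist(iqk,1)*clast + qklist(iqk,2)*repmat(0:smax,1+rmax,1) + nanbrs;
  dists(iqk) = min(posscosts(:));   % min ignores NaN entries
end
```

At q=10, k=2 the matched link (cost 1, plus one deletion) still beats deleting and inserting all three spikes (cost 3).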
spkdallqk_dist
function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin)
% function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin) does the
% recursion and final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains
% la, lb: labels on the two spike trains
% clast, rmax, smax: returned as calculated by spkdallqk_recur
% optsin: available for various debugging and optimizing options
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL.
%
if (nargin<=5) optsin=[]; end
%do the recursion
[clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin);
%do the final stage
[dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin);
return
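spkdallqk_dist takes a flat list of (q,k) pairs. One way to scan a rectangular grid of q and k values, shown as a hypothetical usage pattern rather than part of the package, is to build qklist with ndgrid and fold the returned column of distances back into a grid. The distances below are a placeholder vector so the snippet runs on its own; the commented line shows where the real call would go.

```matlab
% Build a (q,k) grid for spkdallqk_dist and fold the results back (values are examples).
qvals = [0 1 2 4 8];             % costs per unit time shift
kvals = [0 0.5 1 2];             % relabeling costs; k = 2 equals one deletion plus one insertion
[Q, K] = ndgrid(qvals, kvals);
qklist = [Q(:) K(:)];            % one [q k] pair per row, q varying fastest
% dists = spkdallqk_dist(qklist, sa, la, sb, lb);   % actual call
dists = (1:size(qklist,1))';     % placeholder so this sketch runs standalone
distgrid = reshape(dists, numel(qvals), numel(kvals));  % distgrid(iq,ik) <-> qvals(iq), kvals(ik)
```

Because the recursion is done only once per pair of spike trains, adding more (q,k) pairs only adds cheap evaluations in spkdallqk_final.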
spkdallqk_trunc
function [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% takes the matrix c as calculated by
% spkdallqk_recur for two labeled spike trains, and determines the slice clast_trunc that
% is required for calculating distances, if the spike trains are truncated at time tmax.
%
% It is CRUCIAL that opts be passed; it contains the original spike trains,
% stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb,
% and the optimized ones for which c was calculated,
% stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
% c calculated by [clast,rmax,smax,optsused,c]=spkdallqk_recur(sa,la,sb,lb,opts);
% clast_trunc,rmax_trunc,smax_trunc should equal clast, rmax,smax IF tmax=Inf.
%
% c can also be the structure cseq, as returned by spkdallqk_truncseq
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
% size(clast_trunc)=[rmax_trunc+1 smax_trunc+1]
%
% See spkdallqk.doc.
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL, SPKDALLQK_TRUNCSEQ.
%
% get optimized values
sa=opts.used.sa;
sb=opts.used.sb;
la=opts.used.la;
lb=opts.used.lb;
%
lunique=unique([la lb]);
labs=length(lunique);
if (labs==0)
clast_trunc=0;
rmax_trunc=0;
smax_trunc=0;
na_trunc=0;
bmulti_trunc=0;
return
end
for k = 1:labs
nua(k)=length(find(la==lunique(k)));
nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
if (~isstruct(c))
cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
cw=reshape(c,cwdims);
end
%
%now find the truncated spike trains
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
%and find parameters for them
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k = 1:labs
if (length(la_trunc)>0) nua_trunc(k)=length(find(la_trunc==lunique(k))); end
if (length(lb_trunc)>0) nub_trunc(k)=length(find(lb_trunc==lunique(k))); end
end
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
%set up divisors for de-indexing and re-indexing with respect to full spike train
for k=1:labs
bdiv(k)=prod(nub(1:k-1)+1);
end
bmulti_trunc=sum(nub_trunc.*bdiv);
%
%separate cases, depending on whether c is the full matrix, or a structure
% (full matrix returned by spkdallqk if spaceopt=0; structure if spaceopt=1)
%
if (~isstruct(c))
slice_trunc=reshape(cw(na_trunc+1,bmulti_trunc+1,:,:),[rmax+1 smax+1]);
clast_trunc=slice_trunc([1:rmax_trunc+1],[1:smax_trunc+1]);
else
j=max([0,find(tmax>=c.time)]);
if (j==0) clast_trunc=0;
else
clast_trunc=c.clast{j};
end
end
return
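spkdallqk_trunc locates the slice of c that corresponds to truncating both trains at tmax: it counts, per label, the spikes at or before tmax, recomputes rmax and smax for the truncated trains, and converts the truncated b counts into the 0-based index bmulti_trunc using the divisors bdiv of the full train. When c is the space-optimized structure, it instead takes the last entry of c.time not exceeding tmax. The sketch repeats that arithmetic on made-up trains that are assumed to be already relabeled to 1, ..., labs.

```matlab
% Truncation bookkeeping as in spkdallqk_trunc (relabeled example trains are made up).
sa = [0.1 0.4 0.7]; la = [1 2 1];
sb = [0.2 0.5 0.6 0.9]; lb = [2 1 2 1];
tmax = 0.55;
labs = 2;
nub = [sum(lb==1) sum(lb==2)];                 % full-train b counts
bdiv = cumprod([1 nub(1:end-1)+1]);            % divisors for the full train
la_trunc = la(sa <= tmax); lb_trunc = lb(sb <= tmax);
nua_trunc = zeros(1,labs); nub_trunc = zeros(1,labs);
for k = 1:labs
  nua_trunc(k) = sum(la_trunc == k);
  nub_trunc(k) = sum(lb_trunc == k);
end
na_trunc = sum(nua_trunc); nb_trunc = sum(nub_trunc);
rmax_trunc = sum(min([nua_trunc; nub_trunc], [], 1));
smax_trunc = min(sum(min([nua_trunc; nb_trunc-nub_trunc], [], 1)), ...
                 sum(min([na_trunc-nua_trunc; nub_trunc], [], 1)));
bmulti_trunc = sum(nub_trunc .* bdiv);         % 0-based b multi-index of the truncated train
% space-optimized case: index of the last event time at or before tmax (0 if none)
ctime = unique([sa sb]);
j = max([0, find(tmax >= ctime)]);
```

The truncated slice is then cw(na_trunc+1, bmulti_trunc+1, 1:rmax_trunc+1, 1:smax_trunc+1), as in the listing.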