Matlab Code for Spike Time Distances Between Labeled (Multineuronal) Spike Trains, Parallel in q and k

Dmitriy Aronov: da2006@columbia.edu

Jonathan Victor: jdvicto@med.cornell.edu



Matlab code for Multineuronal Spike Time Metric, Parallel in q and k

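There are four modules, spkdallqk_recur, spkdallqk_final, spkdallqk_dist, and spkdallqk_trunc, related as follows: spkdallqk_dist calls spkdallqk_recur to carry out the recursion and then spkdallqk_final to compute the distance for each (q,k) pair; spkdallqk_trunc takes the array c (or, with optsin.spaceopt=1, the structure c) returned by spkdallqk_recur and extracts the slice needed for distances between the spike trains truncated at a given time.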

spkdallqk_recur

function [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin)
% [clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin) does the recursion for the Aronov
% multiunit algorithm, extended to simultaneous calculation for all values of q and k.
%
% sa, sb: spike times on the two spike trains
% la, lb: spike labels (positive integers)
% optsin: available for various debugging and optimizing options
% optsin.recurdebug: show some debugging (1 or 2 levels) within the recursion
% optsin.switchab: exchange a and b if necessary to reduce total number of inner loops
%   0->disable, 1->enable (default), 2->enable and show, 3->force swap
% optsin.disjoint: merge labels in one train if the other train does not contain them
%   0->disable, 1->enable (default), 2->enable and show
% optsin.sizeonly: 1->clast returns as a scalar, indicating the size of the working matrix;
%   c returns as a vector, indicating the dimensions of the working matrix
% optsin.spaceopt: 0-> original version (default), 1->space-optimized version (after Ifije Ohiorhenuan)
%
% sa,la,sb,lb are potentially swapped and otherwise optimized before proceeding.
% original values are stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb.
% after optimization, they are stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
%
% size(c)=[length(sa)+1 length(sb_1)+1 ... length(sb_L)+1 rmax+1 smax+1]
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_FINAL, SPKDALLQK_DIST, SPKD, LABDIST_FAST.
%
if (nargin<=4) optsin=[]; end
opts=optsin;
if (~isfield(opts,'recurdebug')) opts.recurdebug=0; end
if (~isfield(opts,'switchab')) opts.switchab=1; end
if (~isfield(opts,'disjoint')) opts.disjoint=1; end
if (~isfield(opts,'sizeonly')) opts.sizeonly=0; end
if (~isfield(opts,'spaceopt')) opts.spaceopt=0; end
%save input values
opts.orig.sa=sa;opts.orig.sb=sb;opts.orig.la=la;opts.orig.lb=lb;
%
if opts.disjoint>=1
  if (opts.disjoint>=2);disp('before disjoint optimization, la, lb=');disp(la);disp(lb);end
  lbnota=setdiff(unique(lb),unique(la));
  lanotb=setdiff(unique(la),unique(lb));
  if (length(lbnota)>1);lb(find(ismember(lb,lbnota)))=lbnota(1);end
  if (length(lanotb)>1);la(find(ismember(la,lanotb)))=lanotb(1);end
  if (opts.disjoint>=2);disp(' after disjoint optimization, la, lb=');disp(la);disp(lb);end
end
%
[lurelabel,index,relabel] = unique([la lb]);
relabel=relabel(:)'; %newer MATLAB versions return this output of unique as a column; keep it a row
labs = prod(size(lurelabel));
%replace each label by 1, ..., length(lunique)
%disp('after replacement')
la=relabel(1:length(la));
lb=relabel(length(la)+1:end);
lunique=unique([la lb]);
nua=[];nub=[];
for k = 1:labs
  sua{k}=sa(find(la==lunique(k)));
  nua(k)=length(find(la==lunique(k)));
  sub{k}=sb(find(lb==lunique(k)));
  nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
if (opts.switchab>=1)
  asize=prod(nua+1);
  absize=(na+1)*bsize;
  basize=(nb+1)*asize;
  if (opts.switchab>=2)
    disp(' na+1 asize, nb+1 bsize, absize, basize')
    disp([na+1 asize nb+1 bsize absize basize]);
  end
  if ((absize>basize) | (opts.switchab>=3))
    if (opts.switchab>=2) disp('a and b swapped'); end
    lt=la;la=lb;lb=lt;
    st=sa;sa=sb;sb=st;
    nt=na;na=nb;nb=nt;
    nut=nua;nua=nub;nub=nut;
    tsize=asize;asize=bsize;bsize=tsize;
    for k=1:labs
      sut{k}=sua{k};sua{k}=sub{k};sub{k}=sut{k};
    end
  else
    if (opts.switchab>=2) disp(' a and b not swapped.'); end
  end
end
cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
csdims=[2 bsize rmax+1 smax+1]; %working dimensions for space optimized version
%save values used for later use
opts.used.sa=sa;opts.used.sb=sb;opts.used.la=la;opts.used.lb=lb;opts.used.cdims=cdims;opts.used.cwdims=cwdims;
%
if (opts.recurdebug) disp('cwdims');disp(cwdims); end
%if only sizes are requested, do this and return
if (opts.sizeonly)
  if (opts.spaceopt==1)
    c=csdims;
  else
    c=cwdims;
  end
  clast=prod(c);
  return;
end
%initialize c for spaceopt
if (opts.spaceopt==1)
  c.time=[]; c.na=[]; c.bmulti=[]; c.rmax=[]; c.smax=[]; c.clast=cell(0);
end
%check for trivial situations
if (min(na,nb)==0)
  clast=0; rmax=0; smax=0;
  if (opts.spaceopt==0)
    c=zeros(cdims);
  else
    c.time=unique([sa sb]);
    if (na==0) c.na=zeros(1,nb); c.bmulti=[1:nb]; end
    if (nb==0) c.na=[1:na]; c.bmulti=zeros(1,na); end
    c.rmax=zeros(1,max(na,nb));
    c.smax=zeros(1,max(na,nb));
    for j=1:max(na,nb)
      c.clast{j}=0;
    end
    return
  end
end
%set up divisors for de-indexing and re-indexing
bdiv=[];
for k=1:labs
  bdiv(k)=prod(nub(1:k-1)+1);
end
if (opts.recurdebug) disp('bdiv');disp(bdiv); end
if (opts.spaceopt==0)
  if (opts.recurdebug>=1) disp('original (non-space-optimized)'); end
  %the original recursion (non space-optimized)
  cw=repmat(NaN,cwdims); % working array
  cw(:,:,1,1)=0; %initialize: no length needed for no links
  %loop over the subdivided train (b)
  for jb=0:(bsize-1)
    %find the b-multi-index
    bv=[];
    for k=1:labs
      bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
    end
    bvm=repmat(bv,[labs 1])-eye(labs);
    jbrec=bdiv*bvm'; %0-based the indices needed for the recursion
    if (opts.recurdebug>=2)
      disp('jb');disp(jb);
      disp('bv');disp(bv);
      if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
      disp('jbrec');disp(jbrec);
    end
    %loop over the non-subdivided train (a)
    av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
    for ia=0:na
      if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
      rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
      maxij=min(ia,sum(bv)); %maximum total number of links (matched plus unmatched)
      smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
      %loop over r and s
      if (opts.recurdebug>=3)
        disp('ia');disp(ia);
        disp('av,bv');disp([av;bv]);
        disp('maxij,rmaxij,smaxij');disp([maxij,rmaxij,smaxij]);
      end
      for rs=0:maxij %loop over r+s
        for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
          s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
          %do the calculation here
          cwtent=Inf;
          %last spike in a is unlinked
          if (ia>0) cwtent=min(cwtent,cw(ia,jb+1,r+1,s+1)); end
          %matched link case
          if ((ia>0) & (bv(u)>0) & (r>0))
            cwtent=min(cwtent,cw(ia,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u))));
          end
          %unmatched link case
          if ((ia>0) & (s>0))
            for k=setdiff(find(bv>0),u)
              cwtent=min(cwtent,cw(ia,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
            end
          end
          %last spike in b is unlinked
          for k=find(bv>0)
            cwtent=min(cwtent,cw(ia+1,jbrec(k)+1,r+1,s+1));
          end
          if (cwtent<Inf) cw(ia+1,jb+1,r+1,s+1)=cwtent; end
        end %loop over (r,s)
      end %loop over r+s
    end %loop over ia
  end %loop over jb
  clast=reshape(cw(na+1,bsize,:,:),[rmax+1 smax+1]);
  c=reshape(cw,cdims);
end
if (opts.spaceopt==1)
  if (opts.recurdebug>=1) disp('space-optimized'); end
  %the space-optimized recursion
  bv_list=zeros(bsize,labs);
  jbrec_list=zeros(bsize,labs);
  for jb=0:(bsize-1)
    %find the b-multi-index
    if (labs==0)
      bv=0; jbrec=0;
    else
      for k=1:labs
        bv(k)=mod(floor(jb/bdiv(k)),nub(k)+1); %lengths of current subtrains for b
      end
      bvm=repmat(bv,[labs 1])-eye(labs);
      jbrec=bdiv*bvm'; %0-based the indices needed for the recursion
    end
    if (opts.recurdebug>=2)
      disp('jb');disp(jb);
      disp('bv');disp(bv);
      disp('labs');disp(labs);
      if (opts.recurdebug>=3);disp('bvm');disp(bvm);end
      disp('jbrec');disp(jbrec);
    end
    bv_list(jb+1,:)=bv;
    jbrec_list(jb+1,:)=jbrec;
  end
  %set up table to pull out cseq if needed, using logic of spkdallqk_trunc
  c.time=unique([sa sb]);
  for j=1:length(c.time)
    tmax=c.time(j);
    la_trunc=la(find(sa<=tmax));
    lb_trunc=lb(find(sb<=tmax));
    nua_trunc=zeros(1,labs);
    nub_trunc=zeros(1,labs);
    for k=1:labs
      if (length(la_trunc)>0) nua_trunc(k)=length(find(la_trunc==lunique(k))); end
      if (length(lb_trunc)>0) nub_trunc(k)=length(find(lb_trunc==lunique(k))); end
    end
    na_trunc=sum(nua_trunc);
    nb_trunc=sum(nub_trunc);
    rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
    smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
    bmulti=sum(nub_trunc.*bdiv);
    c.na(j)=na_trunc;
    c.bmulti(j)=bmulti;
    c.rmax(j)=rmax_trunc;
    c.smax(j)=smax_trunc;
    c.clast=[]; %make sure that there is something in this field
  end
  nclast=0; %number of clast structures calculated
  %
  cs=repmat(NaN,csdims); % working array
  cs(:,:,1,1)=0; %initialize: no length needed for no links
  % cs(1,:,:,:): values already computed
  % cs(2,:,:,:): values to be computed
  %loop over the non-subdivided train (a)
  av=zeros(1,labs); %will track how many spikes of which kind are in the first ia spikes of a
  for ia=0:na
    if (ia>0) u=la(ia); av(u)=av(u)+1; end % u is label of the ith spike
    %loop over the subdivided train (b)
    for jb=0:(bsize-1)
      %look up the b-multi-index
      bv=bv_list(jb+1,:);
      jbrec=jbrec_list(jb+1,:);
      %
      rmaxij=sum(min([av;bv],[],1)); %maximum number of matched links
      maxij=min(ia,sum(bv)); %maximum total number of links (matched plus unmatched)
      smaxij=min(sum(min([av;(sum(bv)-bv)],[],1)),sum(min([(sum(av)-av);bv],[],1))); %maximum number of unmatched links
      %loop over r and s
      for rs=0:maxij %loop over r+s
        for r=max(0,rs-smaxij):min(rmaxij,rs) %loop over (r,s)
          s=rs-r; %guarantee that r is in [0:rmaxij] and s is in [0:smaxij]
          %do the calculation here
          cwtent=Inf;
          %last spike in a is unlinked
          if (ia>0) cwtent=min(cwtent,cs(1,jb+1,r+1,s+1)); end
          %matched link case
          if ((ia>0) & (bv(u)>0) & (r>0))
            cwtent=min(cwtent,cs(1,jbrec(u)+1,r,s+1)+abs(sa(ia)-sub{u}(bv(u))));
          end
          %unmatched link case
          if ((ia>0) & (s>0))
            for k=setdiff(find(bv>0),u)
              cwtent=min(cwtent,cs(1,jbrec(k)+1,r+1,s)+abs(sa(ia)-sub{k}(bv(k))));
            end
          end
          %last spike in b is unlinked
          for k=find(bv>0)
            cwtent=min(cwtent,cs(2,jbrec(k)+1,r+1,s+1));
          end
          if (cwtent<Inf) cs(2,jb+1,r+1,s+1)=cwtent; end
        end %loop over (r,s)
      end %loop over r+s
    end %loop over jb
    %
    %cs(2,jb+1,r+1,s+1) plays the role of cw(ia+1,jb+1,r+1,s+1)
    if (length(c.time)>0)
      for jval=find(c.na==ia)
        nclast=nclast+1;
        if (opts.recurdebug>=1);disp(sprintf(' making cseq with ia, jval,nclast= %d %d %d',ia,jval,nclast));end
        rmax_trunc=c.rmax(jval);
        smax_trunc=c.smax(jval);
        c.clast{nclast}=reshape(...
          cs(2,c.bmulti(jval)+1,1+[0:rmax_trunc],1+[0:smax_trunc]),[1+rmax_trunc 1+smax_trunc]);
      end
    end
    %
    if (ia<na) % do the rollover
      cs(1,:,:,:)=cs(2,:,:,:);
      cs(2,:,:,:)=NaN;
      cs(2,:,1,1)=0; %initialize: no length needed for no links
    end
  end %loop over ia
  clast=reshape(cs(2,bsize,:,:),[rmax+1 smax+1]);
end
%
if (opts.recurdebug)
  disp(' k lunique(k) nua nub');
  for k=1:labs
    disp(sprintf(' %2d %3d %4d %4d',k,lunique(k),nua(k),nub(k)));
  end
  disp('sua{k} and sub{k}');
  for k=1:labs
    disp(sprintf(' k=%2d',k));
    if (length(sua{k})>0) disp(sua{k}); else disp(' '); end
    if (length(sub{k})>0) disp(sub{k}); else disp(' '); end
  end
  disp(' rmax smax');disp(sprintf(' %4d %4d',rmax,smax));
  disp('cdims');disp(cdims);
end
return
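The bookkeeping at the top of spkdallqk_recur depends only on the per-label spike counts nua and nub. As a minimal sketch, assuming two small hypothetical trains whose labels have already been relabeled to 1..L, the MATLAB/Octave script below reproduces rmax, the smax bound, the number of working columns bsize, and the swap test applied when optsin.switchab>=1. Variable names mirror spkdallqk_recur; the label data are made up for illustration.

```matlab
% Hypothetical labels, already relabeled to 1..L (here L = 2)
la = [1 1 2 1];          % labels of train a (4 spikes)
lb = [2 2 1];            % labels of train b (3 spikes)
labs = 2;

nua = zeros(1,labs); nub = zeros(1,labs);
for k = 1:labs
  nua(k) = sum(la == k); % spikes of label k in train a
  nub(k) = sum(lb == k); % spikes of label k in train b
end
na = sum(nua); nb = sum(nub);

% for each label k, at most min(nua(k),nub(k)) same-label pairs can be matched
rmax = sum(min([nua; nub], [], 1));
% bound on mismatched pairs, computed as in spkdallqk_recur
smax = min(sum(min([nua; nb-nub], [], 1)), sum(min([na-nua; nub], [], 1)));

% b is subdivided by label, so the working array has prod(nub+1) b-columns
bsize = prod(nub+1);
asize = prod(nua+1);
absize = (na+1)*bsize;   % working size if b is the subdivided train
basize = (nb+1)*asize;   % working size if a were subdivided instead
swap = absize > basize;  % spkdallqk_recur swaps a and b when this holds

fprintf('rmax=%d smax=%d absize=%d basize=%d swap=%d\n', rmax, smax, absize, basize, swap);
```

Here a has three label-1 spikes and one label-2 spike while b has one label-1 and two label-2 spikes, so at most two same-label matches and three cross-label links are possible; keeping b as the subdivided train gives the smaller array (30 versus 32 entries per (r,s) slice), so no swap occurs.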

spkdallqk_final

function [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin)
% [dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin) does the
% final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains
% clast, rmax, smax: calculated by spkdallqk_recur
% optsin: available for various debugging and optimizing options
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_DIST, SPKD, LABDIST_FAST.
%
if (nargin<=6) optsin=[]; end
opts=optsin;
%
na=length(sa);
nb=length(sb);
nqk=size(qklist,1);
dists=zeros(nqk,1); %one distance per (q,k) pair
%
%make matrix of na+nb-2r-2s, indicating costs of deletions and insertions
%
nanbrs=(na+nb)*ones(1+rmax,1+smax)-2*repmat([0:rmax]',1,1+smax)-2*repmat([0:smax],1+rmax,1);
%
% find the best strategy (all r (matched links) and all s (mismatched links))
%
for iqk=1:nqk
  posscosts=qklist(iqk,1)*clast+qklist(iqk,2)*repmat([0:smax],1+rmax,1)+nanbrs;
  dists(iqk,1)=min(min(posscosts));
end
return
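The final stage is a minimization over r (matched links) and s (mismatched links): for each pair (q,k), the distance is the minimum over r and s of q*clast(r+1,s+1) + k*s + (na+nb-2r-2s), that is, shift cost plus relabeling cost plus one unit for each unlinked spike. The MATLAB/Octave sketch below applies that formula, in the same vectorized layout as spkdallqk_final, to a clast worked out by hand for two tiny hypothetical trains; the trains and the qklist values are assumptions for illustration.

```matlab
% Hypothetical trains (times in s): a = [0.1 (label 1), 1.0 (label 2)],
% b = [0.3 (label 1), 1.1 (label 1)], so na = nb = 2, rmax = 1, smax = 1.
% clast(r+1,s+1) = smallest summed |time shift| with r matched and s mismatched links:
%   r=0,s=0: no links                      -> 0
%   r=0,s=1: a(1.0)-b(1.1), mismatched     -> 0.1
%   r=1,s=0: a(0.1)-b(0.3), matched        -> 0.2
%   r=1,s=1: both links above              -> 0.3
clast = [0 0.1; 0.2 0.3];
na = 2; nb = 2; rmax = 1; smax = 1;

% q-values in column 1, k-values in column 2, as in spkdallqk_final
qklist = [0 0; 1 1; 1 2; 10 2];

% each unlinked spike costs 1: there are na+nb-2r-2s of them
nanbrs = (na+nb)*ones(1+rmax,1+smax) - 2*repmat((0:rmax)',1,1+smax) - 2*repmat(0:smax,1+rmax,1);
dists = zeros(size(qklist,1),1);
for iqk = 1:size(qklist,1)
  q = qklist(iqk,1); k = qklist(iqk,2);
  posscosts = q*clast + k*repmat(0:smax,1+rmax,1) + nanbrs;
  dists(iqk) = min(posscosts(:));   % min skips NaN entries (unreachable (r,s))
end
disp(dists')
```

With q=0 and k=0 linking is free, so the distance is 0; with q=1 and k=2 the relabeling cost makes the single matched link the cheapest option (2.2); with q=10 shifting is expensive enough that deleting and reinserting every spike (cost 4) ties with keeping one matched link.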

spkdallqk_dist

function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin)
% function [dists,clast,rmax,smax,opts]=spkdallqk_dist(qklist,sa,la,sb,lb,optsin) does the
% recursion and final portion of the Aronov multiunit algorithm, extended to simultaneous
% calculation for all values of q and k.
%
% qklist: a list of pairs of values of q and k. qklist(:,1): q-values; qklist(:,2): k-values
% sa, sb: spike times on the two spike trains
% la, lb: labels on the two spike trains
% clast, rmax, smax: calculated by spkdallqk_recur
% optsin: available for various debugging and optimizing options
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL.
%
if (nargin<=5) optsin=[]; end
%do the recursion
[clast,rmax,smax,opts,c]=spkdallqk_recur(sa,la,sb,lb,optsin);
%do the final stage
[dists,opts]=spkdallqk_final(qklist,sa,sb,clast,rmax,smax,optsin);
return
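spkdallqk_dist returns one distance per row of qklist. To scan a rectangular grid of q and k values, one option is to build qklist with meshgrid and reshape the result back into the grid. The MATLAB/Octave sketch below shows that bookkeeping; the q and k values and the two short trains are hypothetical, and the call to spkdallqk_dist is guarded so the snippet still runs when the functions on this page are not on the path.

```matlab
% Hypothetical grid of cost parameters
qvals = [0 1 10 100];              % q values (1/s)
kvals = [0 0.5 1 2];               % k values
[Q, K] = meshgrid(qvals, kvals);   % Q(i,j) = qvals(j), K(i,j) = kvals(i)
qklist = [Q(:) K(:)];              % one (q,k) pair per row, column-major order

% Example call, run only if spkdallqk_dist is on the MATLAB path
if exist('spkdallqk_dist', 'file')
  sa = [0.1 1.0]; la = [1 2];      % hypothetical train a: times (s) and labels
  sb = [0.3 1.1]; lb = [1 1];      % hypothetical train b
  dists = spkdallqk_dist(qklist, sa, la, sb, lb);
  D = reshape(dists, size(Q));     % D(i,j) is the distance at q = qvals(j), k = kvals(i)
end
```

Because Q(:) and K(:) are taken in the same column-major order, reshape(dists,size(Q)) puts each distance back at the grid position of its (q,k) pair.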

spkdallqk_trunc

function [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% [clast_trunc,rmax_trunc,smax_trunc,na_trunc,bmulti_trunc]=spkdallqk_trunc(tmax,c,opts)
% takes the matrix c as calculated by spkdallqk_recur for two labeled spike trains,
% and determines the slice clast_trunc that is required for calculating distances,
% if the spike trains are truncated at time tmax.
%
% CRUCIAL that opts is passed; this contains the original spike trains,
% stored in opts.orig.sa, opts.orig.la, opts.orig.sb, opts.orig.lb,
% and the optimized ones that c was calculated for,
% stored in opts.used.sa, opts.used.la, opts.used.sb, opts.used.lb.
%
% c calculated by [clast,rmax,smax,optsused,c]=spkdallqk_recur(sa,la,sb,lb,opts);
% clast_trunc,rmax_trunc,smax_trunc should equal clast,rmax,smax IF tmax=Inf.
%
% c can also be the structure cseq, as returned by spkdallqk_truncseq
%
% rmax=max number of possible matches, sum(min(length(sa_k),length(sb_k)))
% smax=max number of possible mismatches
% c(m+1,n1+1,..., nL+1,r+1,s+1) is the shortest linkage connecting r pairs
% of spikes that match, and s pairs of spikes that do not match.
% clast=c(end,...,end,:,:) is the portion of c necessary for calculating distances,
% which is done in spkdallqk_final.
% size(clast)=[rmax+1 smax+1]
% size(clast_trunc)=[rmax_trunc+1 smax_trunc+1]
%
% See spkdallqk.doc.
%
% Copyright (c) by Jonathan Victor.
%
% See also SPKDALLQK_RECUR, SPKDALLQK_FINAL, SPKDALLQK_TRUNCSEQ.
%
% get optimized values (labels forced to row vectors so that [la lb] concatenates)
sa=opts.used.sa; sb=opts.used.sb; la=opts.used.la(:)'; lb=opts.used.lb(:)';
%
lunique=unique([la lb]);
labs=length(lunique);
if (labs==0)
  clast_trunc=0; rmax_trunc=0; smax_trunc=0; na_trunc=0; bmulti_trunc=0;
  return
end
for k = 1:labs
  nua(k)=length(find(la==lunique(k)));
  nub(k)=length(find(lb==lunique(k)));
end
na=sum(nua);
nb=sum(nub);
rmax=sum(min([nua;nub],[],1));
smax=min(sum(min([nua;(nb-nub)],[],1)),sum(min([(na-nua);nub],[],1)));
bsize=prod(nub+1);
if (~isstruct(c))
  cdims=[na+1 nub+1 rmax+1 smax+1]; %actual dimensions for c
  cwdims=[na+1 bsize rmax+1 smax+1]; %working dimensions for cw
  cw=reshape(c,cwdims);
end
%
%now find the truncated spike trains
la_trunc=la(find(sa<=tmax));
lb_trunc=lb(find(sb<=tmax));
%and find parameters for them
nua_trunc=zeros(1,labs);
nub_trunc=zeros(1,labs);
for k = 1:labs
  if (length(la_trunc)>0) nua_trunc(k)=length(find(la_trunc==lunique(k))); end
  if (length(lb_trunc)>0) nub_trunc(k)=length(find(lb_trunc==lunique(k))); end
end
na_trunc=sum(nua_trunc);
nb_trunc=sum(nub_trunc);
rmax_trunc=sum(min([nua_trunc;nub_trunc],[],1));
smax_trunc=min(sum(min([nua_trunc;(nb_trunc-nub_trunc)],[],1)),sum(min([(na_trunc-nua_trunc);nub_trunc],[],1)));
%set up divisors for de-indexing and re-indexing with respect to full spike train
for k=1:labs
  bdiv(k)=prod(nub(1:k-1)+1);
end
bmulti_trunc=sum(nub_trunc.*bdiv);
%
%separate cases, depending on whether c is the full matrix, or a structure
% (full matrix returned by spkdallqk_recur if spaceopt=0; structure if spaceopt=1)
%
if (~isstruct(c))
  slice_trunc=reshape(cw(na_trunc+1,bmulti_trunc+1,:,:),[rmax+1 smax+1]);
  clast_trunc=slice_trunc([1:rmax_trunc+1],[1:smax_trunc+1]);
else
  j=max([0,find(tmax>=c.time)]);
  if (j==0)
    clast_trunc=0;
  else
    clast_trunc=c.clast{j};
  end
end
return
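Train b is addressed by a single 0-based multi-index that packs the number of spikes taken from each labeled subtrain, with place values bdiv(k) = prod(nub(1:k-1)+1). spkdallqk_trunc packs the truncated counts into bmulti_trunc this way, and spkdallqk_recur unpacks jb with mod and floor. The MATLAB/Octave sketch below, with hypothetical subtrain sizes, round-trips a count vector through that encoding and forms jbrec, the indices of the states with one fewer spike of each label, as the recursion does.

```matlab
% Hypothetical per-label spike counts of train b (3 labels)
nub = [2 1 3];
labs = numel(nub);
bsize = prod(nub+1);                 % number of b-states: 3*2*4 = 24

% place values, as in spkdallqk_recur and spkdallqk_trunc
bdiv = zeros(1,labs);
for k = 1:labs
  bdiv(k) = prod(nub(1:k-1)+1);      % gives [1 3 6]
end

% pack: spikes taken from each subtrain -> 0-based multi-index
bv = [1 0 2];
jb = sum(bv.*bdiv);                  % 1*1 + 0*3 + 2*6 = 13

% unpack: 0-based multi-index -> spikes taken from each subtrain
bv_back = mod(floor(jb./bdiv), nub+1);

% states with one spike of label k removed (meaningful only where bv(k)>0)
bvm = repmat(bv,[labs 1]) - eye(labs);
jbrec = bdiv*bvm';                   % gives [12 10 7]
```

Since bv(2)=0 here, jbrec(2) does not correspond to a valid state; the recursion only reads jbrec(k) for labels with bv(k)>0.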