Home > tt2 > @tt_matrix > mvk2.m

mvk2

PURPOSE ^

Two-sided DMRG fast matrix-by-vector product

SYNOPSIS ^

function [y,swp]=mvk2(a,x,eps,nswp,z,rmax,varargin)

DESCRIPTION ^

Two-sided DMRG fast matrix-by-vector product
   [Y,SWP]=MVK2(A,X,EPS,[NSWP],[Z],[RMAX],[OPTIONS]) Two-sided DMRG (mvk
   is one-sided). Matrix-by-vector product of a TT-matrix A
   by a TT-tensor X with accuracy EPS. Also, one can specify the number of
   sweeps NSWP, the initial approximation Z and the maximal TT-rank RMAX (to
   cap the ranks if they become too large)


 TT-Toolbox 2.2, 2009-2012

This is TT Toolbox, written by Ivan Oseledets et al.
Institute of Numerical Mathematics, Moscow, Russia
webpage: http://spring.inm.ras.ru/osel

For all questions, bugs and suggestions please mail
ivan.oseledets@gmail.com
---------------------------

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [y,swp]=mvk2(a,x,eps,nswp,z,rmax,varargin)
0002 %Two-sided DMRG fast matrix-by-vector product
0003 %   [Y,SWP]=MVK2(A,X,EPS,[NSWP],[Z],[RMAX],[OPTIONS]) Two-sided DMRG (mvk
0004 %   is one-sided). Matrix-by-vector product of a TT-matrix A
0005 %   by a TT-tensor X with accuracy EPS. Also, one can specify the number of
0006 %   sweeps NSWP (default 40), the initial approximation Z and the maximal
0007 %   TT-rank RMAX (default 1000) that caps the ranks of the result.
0008 %   SWP is the sweep counter on exit (number of sweeps performed plus one).
0009 %   OPTIONS (varargin) is accepted but not used in this function.
0010 % TT-Toolbox 2.2, 2009-2012
0011 %
0012 %This is TT Toolbox, written by Ivan Oseledets et al.
0013 %Institute of Numerical Mathematics, Moscow, Russia
0014 %webpage: http://spring.inm.ras.ru/osel
0015 %
0016 %For all questions, bugs and suggestions please mail
0017 %ivan.oseledets@gmail.com
0018 %---------------------------
0019 n=a.n; %mode sizes on the output (Y) side
0020 m=a.m; %mode sizes on the input (X) side
0021 att=a.tt;
0022 corea=att.core; %all cores of A stored as one vector
0023 psa=att.ps; %start positions of the cores of A inside corea
0024 ra=att.r; %TT-ranks of A
0025 d=att.d; %number of cores
0026 %start_val='fast_mv';
0027 start_val='rough_mv'; %strategy used to build Z when none is supplied
0028 %start_val='random';
0029 rmax_loc=8; %For the "rough" matrix-by-vector product
0030 if ( nargin <= 5 || isempty(rmax) )
0031    rmax=1000;
0032 end
0033 if ( nargin <= 4 || isempty(z) )
0034     %No initial guess supplied: build one according to start_val
0035     if ( strcmp(start_val,'random') )
0036     rz1=rank(a,1)*rank(x,1);
0037     rzd=rank(a,d+1)*rank(x,d+1);
0038     kf=5; %inner TT-ranks of the random guess
0039     rz=[rz1;kf*ones(d-1,1);rzd];
0040     z=tt_rand(n,ndims(x),rz);
0041     elseif (strcmp(start_val,'rough_mv'))
0042         %Both operands are first truncated to rank rmax_loc (set above)
0043         xloc=round(x,0,rmax_loc);
0044         aloc=round(a,0,rmax_loc);
0045         z=round(aloc*xloc,eps); 
0046         %z=aloc*xloc;
0047     elseif (strcmp(start_val,'fast_mv') )
0048          rz1=rank(a,1)*rank(x,1);
0049       rzd=rank(a,d+1)*rank(x,d+1);
0050       kf=5;
0051       rz=[rz1;kf*ones(d-1,1);rzd];
0052       z=tt_rand(n,ndims(x),rz);
0053        z=mvk2(a,x,max(eps,1e-2),10,z,rmax); %First, do it with bad accuracy
0054     end
0055 end
0056 if ( nargin <= 3 || isempty(nswp) )
0057   nswp = 40;
0058 end
0059 y=z;
0060 
0061 %Parameters section
0062 kick_rank=6; %number of random columns appended to the basis at every step
0063 verb=false; %set to true to print the per-sweep error
0064 
0065 %Warmup is to orthogonalize Y from right-to-left and compute psi-matrices
0066 %for Ax
0067 psi=cell(d+1,1); %Psi-matrices
0068 %psi{d+1}=1; psi{1}=1;
0069 psi{d+1}=eye(rank(y,d+1)); 
0070 psi{1}=eye(rank(y,1));
0071 %Here we will add convergence test
0072 %Warmup: right-to-left QR + computation of psi matrices
0073 
0074 corex=x.core;
0075 psx=x.ps;
0076 rx=x.r;
0077 
0078 swp=1;
0079 converged=false;
0080 while (swp <= nswp && ~converged)  
0081 psy=y.ps;
0082   ry=y.r;
0083   corey=y.core;
0084      pos1=psy(d+1); %write position, moves towards the start of corey
0085 cr1=corey(psy(d):psy(d+1)-1);
0086 %Right-to-left pass: QR of every core and update of psi{i}
0087 for i=d:-1:2  
0088    cr2=corey(psy(i-1):psy(i)-1);
0089    cr1=reshape(cr1,[ry(i),n(i)*ry(i+1)]);
0090    cr2=reshape(cr2,[ry(i-1)*n(i-1),ry(i)]);
0091    cr1=cr1.';
0092    [q,rm]=qr(cr1,0); rn=size(q,2); rm=rm.'; %economy QR of the transposed core
0093    q=q.'; 
0094    ry(i)=rn;
0095    corey(pos1-ry(i+1)*n(i)*ry(i):pos1-1)=q(:);
0096    %Convolution is now performed for psi(i) using psi(i+1) and corea, and
0097    %(new) core q
0098    cra=corea(psa(i):psa(i+1)-1); cra=reshape(cra,[ra(i),n(i),m(i),ra(i+1)]);
0099    cry=reshape(conj(q),[ry(i),n(i),ry(i+1)]);
0100    crx=corex(psx(i):psx(i+1)-1); crx=reshape(crx,[rx(i),m(i),rx(i+1)]);
0101    pscur=psi{i+1}; pscur=reshape(pscur,[ra(i+1),rx(i+1),ry(i+1)]); %ra,rx,ry
0102    %First, convolve over rx(i+1)
0103    crx=reshape(crx,[rx(i)*m(i),rx(i+1)]);
0104    pscur=permute(pscur,[2,1,3]); pscur=reshape(pscur,[rx(i+1),ra(i+1)*ry(i+1)]);
0105    pscur=crx*pscur; %pscur is now rx(i)*m(i)*ra(i+1)*ry(i+1)
0106    %Convolve over m(i),ra(i+1),n(i),ry(i+1)
0107     pscur=reshape(pscur,[rx(i),m(i)*ra(i+1),ry(i+1)]);
0108     pscur=permute(pscur,[1,3,2]); 
0109     pscur=reshape(pscur,[rx(i)*ry(i+1),m(i)*ra(i+1)]);
0110     cra=reshape(cra,[ra(i)*n(i),m(i)*ra(i+1)]); cra=cra.';  
0111     pscur=pscur*cra; 
0112     %pscur is now rx(i)*ry(i+1)*ra(i)*n(i), it is left to convolve over
0113     %n(i)*ry(i+1)
0114     pscur=reshape(pscur,[rx(i),ry(i+1),ra(i),n(i)]);
0115     pscur=permute(pscur,[3,1,4,2]);
0116     pscur=reshape(pscur,[rx(i)*ra(i),n(i)*ry(i+1)]);
0117     cry=reshape(cry,[ry(i),n(i)*ry(i+1)]); cry=cry.';
0118     pscur=pscur*cry;
0119     psi{i}=pscur;
0120    %End of psi-block
0121    pos1=pos1-ry(i+1)*n(i)*ry(i);
0122    cr1=cr2*rm; %push the triangular factor into the core to the left
0123    
0124 end
0125 corey(pos1-ry(2)*n(1)*ry(1):pos1-1)=cr1(:);
0126 pos1=pos1-ry(2)*n(1)*ry(1);
0127 corey=corey(pos1:numel(corey)); %Truncate unused elements
0128 
0129   
0130 
0131   %left-to-right dmrg sweep
0132   pos1=1;
0133   cry_old=corey; %cores before this sweep, used for the convergence check
0134   psy=cumsum([1;n.*ry(1:d).*ry(2:d+1)]);
0135   converged=true; %reset to false below if any superblock changes by more than eps
0136   ermax=0;
0137   for i=1:d-1
0138      %We care for two cores, with number i & number i+1, and use
0139      %psi(i) and psi(i+2) as a basis; also we will need to recompute
0140      %psi(i+1)
0141      ps1=psi{i}; ps2=psi{i+2};
0142      cra1=corea(psa(i):psa(i+1)-1); cra2=corea(psa(i+1):psa(i+2)-1);
0143      crx1=corex(psx(i):psx(i+1)-1); crx2=corex(psx(i+1):psx(i+2)-1);
0144      %our convolution is
0145      %ps1(ra(i),rx(i),ry(i))*cra1(ra(i),n(i),m(i),ra(i+1))*
0146      %cra2(ra(i+1),n(i+1),m(i+1),ra(i+2))*
0147      %*ps2(ra(i+2),rx(i+2),ry(i+2))
0148      %*crx1(rx(i),m(i),rx(i+1))*crx2(rx(i+1),m(i+1),rx(i+2))
0149      %Scheme ps1*crx1 over rx(i)
0150      
0151      ps1=reshape(ps1,[ra(i),rx(i),ry(i)]); 
0152      ps1=permute(ps1,[1,3,2]); ps1=reshape(ps1,[ra(i)*ry(i),rx(i)]);
0153      crx1=reshape(crx1,[rx(i),m(i)*rx(i+1)]);
0154      ps1=ps1*crx1; %ps1 is now ra(i)*ry(i)*m(i)*rx(i+1)
0155      %Now convolve with matrix A over ra(i)*m(i)
0156      ps1=reshape(ps1,[ra(i),ry(i),m(i),rx(i+1)]);
0157      ps1=permute(ps1,[2,4,1,3]);
0158      ps1=reshape(ps1,[ry(i)*rx(i+1),ra(i)*m(i)]);
0159      cra1=reshape(cra1,[ra(i),n(i),m(i),ra(i+1)]);
0160      cra1=permute(cra1,[1,3,2,4]); 
0161      cra1=reshape(cra1,[ra(i)*m(i),n(i)*ra(i+1)]);
0162      ps1=ps1*cra1; %ps1 is now ry(i)*rx(i+1)*n(i)*ra(i+1)
0163      %Then the ``same'' convolution is carried over for second pair of
0164      %cores
0165      ps2=reshape(ps2,[ra(i+2),rx(i+2),ry(i+2)]);
0166      ps2=permute(ps2,[2,1,3]);
0167      ps2=reshape(ps2,[rx(i+2),ra(i+2)*ry(i+2)]);
0168      crx2=reshape(crx2,[rx(i+1)*m(i+1),rx(i+2)]);
0169      ps2=crx2*ps2; %ps2 is now rx(i+1)*m(i+1)*ra(i+2)*ry(i+2)
0170      %Convolve over m(i+1)*ra(i+2)
0171      ps2=reshape(ps2,[rx(i+1),m(i+1)*ra(i+2),ry(i+2)]);
0172      ps2=permute(ps2,[2,1,3]);
0173      ps2=reshape(ps2,[m(i+1)*ra(i+2),rx(i+1)*ry(i+2)]);
0174      cra2=reshape(cra2,[ra(i+1)*n(i+1),m(i+1)*ra(i+2)]);
0175      ps2=cra2*ps2; 
0176      %ps2 is now ra(i+1)*n(i+1)*rx(i+1)*ry(i+2)
0177      %Now form superblock by contraction ps1 & ps2
0178      %over ra(i+1)*rx(i+1)
0179      ps0=reshape(ps1,[ry(i),rx(i+1),n(i),ra(i+1)]);
0180      ps0=permute(ps0,[1,3,2,4]);
0181      ps0=reshape(ps0,[ry(i)*n(i),rx(i+1)*ra(i+1)]);
0182      ps2=reshape(ps2,[ra(i+1),n(i+1),rx(i+1),ry(i+2)]);
0183      ps2=permute(ps2,[3,1,2,4]); 
0184      ps2=reshape(ps2,[rx(i+1)*ra(i+1),n(i+1)*ry(i+2)]);
0185      super_core=ps0*ps2; %super_core is ry(i)*n(i)*n(i+1)*ry(i+2)
0186     if ( i > 1 )
0187      %Compute previous supercore
0188         cr1=corey(pos1:pos1+ry(i)*n(i)*ry(i+1)-1);
0189         cr2=cry_old(psy(i+1):psy(i+2)-1); 
0190         cr1=reshape(cr1,[ry(i)*n(i),ry(i+1)]);
0191         cr2=reshape(cr2,[ry(i+1),n(i+1)*ry(i+2)]);
0192          super_core_old=cr1*cr2; 
0193          er=norm(super_core_old(:)-super_core(:))/norm(super_core(:)); %relative change
0194      
0195         if ( er > eps ) 
0196             converged=false;
0197         end
0198         ermax=max(er,ermax);
0199      %if ( verb )
0200      %  fprintf('i=%d er=%3.2e \n',i,er);
0201      %end
0202      end
0203      [u,s,v]=svd(super_core,'econ');
0204      s=diag(s); 
0205      r=my_chop2(s,eps/sqrt(d-1)*norm(s)); r=min(r,rmax); %truncation rank (my_chop2 presumably bounds the discarded tail), capped by rmax
0206      u=u(:,1:r); s=s(1:r); v=v(:,1:r); v=v*diag(s);
0207      
0208      %Kick rank
0209      
0210      ur=randn(size(u,1),kick_rank);
0211      %Orthogonalize ur to u by Golub-Kahan reorth
0212      u=reort(u,ur);
0213      radd=size(u,2)-r; %number of columns actually added by reort
0214      if ( radd > 0 )
0215         vr=zeros(size(v,1),radd); %zero padding keeps the product u*v' unchanged
0216         v=[v,vr];
0217      end
0218      r=size(u,2);
0219      
0220      
0221      ry(i+1)=r;
0222      %u is ry(i)*n(i)*ry(i+1)
0223      %core_new(pos1:pos1+ry(i)*n(i)*ry(i+1)-1)=u(:);
0224      corey(pos1:pos1+ry(i)*n(i)*ry(i+1)-1)=u(:); 
0225      
0226      u=reshape(u,[ry(i),n(i),ry(i+1)]);
0227 
0228      %Compute new psi
0229      %ps1 is ry(i)*rx(i+1)*n(i)*ra(i+1) with u over ry(i)*n(i)
0230      u=conj(u);
0231      ps1=reshape(ps1,[ry(i),rx(i+1),n(i),ra(i+1)]);
0232      ps1=permute(ps1,[4,2,1,3]);
0233      ps1=reshape(ps1,[ra(i+1)*rx(i+1),ry(i)*n(i)]);
0234      u=reshape(u,[ry(i)*n(i),ry(i+1)]);
0235      ps1=ps1*u;
0236      psi{i+1}=ps1;
0237      %Store the second factor (v') as core i+1
0238      pos1=pos1+ry(i)*n(i)*ry(i+1);
0239      v=v';
0240      %v=reshape(v,[ry(i+1),n(i+1),ry(i+2)]);
0241      corey(pos1:pos1+ry(i+1)*n(i+1)*ry(i+2)-1)=v(:);
0242   end
0243   psy=cumsum([1;n.*ry(1:d).*ry(2:d+1)]);
0244 y.core=corey;
0245 y.r=ry;
0246 y.ps=psy;
0247 swp=swp+1;
0248 if ( verb )
0249 fprintf('swp=%d er=%3.2e trunk=%3.2e \n',swp,ermax,eps/sqrt(d-1));
0250 end
0251 end
0252 if ( ~converged ) %swp is nswp+1 here, so test the flag rather than swp==nswp
0253   fprintf('mvk2 warning: error is not fixed for maximal number of sweeps %d\n', nswp); 
0254 end
0255 %p1=tt_mvdot(core(a),core(x),y0);
0256 %p2=dot(y,tt_tensor(y0));
0257 %abs(p1-p2)/abs(p1)
0258 %keyboard;
0259 
0260 
0261

Generated on Wed 08-Feb-2012 18:20:24 by m2html © 2005