Home > tt2 > cross > dmrg_cross.m

dmrg_cross

PURPOSE ^

DMRG-cross method for the approximation of TT-tensors

SYNOPSIS ^

function [y]=dmrg_cross(d,n,fun,eps,varargin)

DESCRIPTION ^

DMRG-cross method for the approximation of TT-tensors
   [A]=DMRG_CROSS(D,N,FUN,EPS,OPTIONS) Computes the approximation of a
   given tensor via the adaptive DMRG-cross procedure. The input is a pair
   (D,N) which determines the size of the tensor (N can be either a
   number, or array of mode sizes). FUN is the function to compute
   a prescribed element of a tensor (FUN(IND)), or it can be vectorized to
   compute series of elements of a tensor (see OPTIONS). To pass parameters
   to FUN, use anonymous function handles. EPS is the accuracy
   of the approximation. Options are provided in the form
   'PropertyName1',PropertyValue1,'PropertyName2',PropertyValue2 and so
   on. Parameters not given are set to their default values (shown in
   brackets below). The list of option names and default values is:
       o nswp - number of DMRG sweeps [10]
       o vec  - FUN is vectorized, i.e. accepts an M x D index array [ true | {false} ]
       o verb - output debug information [ {true} | false ]
       o y0   - initial approximation [random rank-2]
       o radd - minimal rank change [0] (currently not used)
       o rmin - minimal rank that is allowed [1] (currently not used)
       o kickrank - stabilization parameter [2]

   Example:
       d=10; n=2; fun = @(ind) sum(ind);
       tt=dmrg_cross(d,n,fun,1e-7);


 TT-Toolbox 2.2, 2009-2012

This is TT Toolbox, written by Ivan Oseledets et al.
Institute of Numerical Mathematics, Moscow, Russia
webpage: http://spring.inm.ras.ru/osel

For all questions, bugs and suggestions please mail
ivan.oseledets@gmail.com
---------------------------
Default parameters

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [y]=dmrg_cross(d,n,fun,eps,varargin)
%DMRG-cross method for the approximation of TT-tensors
%   [A]=DMRG_CROSS(D,N,FUN,EPS,OPTIONS) Computes the approximation of a
%   given tensor via the adaptive DMRG-cross procedure. The input is a pair
%   (D,N) which determines the size of the tensor (N can be either a
%   number, or array of mode sizes). FUN is the function to compute
%   a prescribed element of a tensor (FUN(IND), IND is a D x 1 index
%   vector), or it can be vectorized to compute a series of elements of a
%   tensor at once (FUN(IND) with IND an M x D index array, returning M
%   values; see OPTIONS). To pass parameters to FUN, use anonymous function
%   handles. EPS is the accuracy of the approximation. Options are provided
%   in the form 'PropertyName1',PropertyValue1,'PropertyName2',PropertyValue2
%   and so on. Parameters not given are set to their default values (shown
%   in brackets below). The list of option names and default values is:
%       o nswp - number of DMRG sweeps [10]
%       o vec  - FUN is vectorized [ true | {false} ]
%       o verb - output debug information [ {true} | false ]
%       o y0   - initial approximation [random rank-2]
%       o radd - minimal rank change [0] (accepted, currently not used)
%       o rmin - minimal rank that is allowed [1] (accepted, currently not used)
%       o kickrank - stabilization parameter [2]
%
%   D must be at least 2: the method updates pairs of neighbouring cores.
%   The iteration stops after NSWP full (forward + backward) sweeps, or
%   earlier once the largest relative local error of a sweep is below EPS.
%
%   Example:
%       d=10; n=2; fun = @(ind) sum(ind);
%       tt=dmrg_cross(d,n,fun,1e-7);
%
%
% TT-Toolbox 2.2, 2009-2012
%
%This is TT Toolbox, written by Ivan Oseledets et al.
%Institute of Numerical Mathematics, Moscow, Russia
%webpage: http://spring.inm.ras.ru/osel
%
%For all questions, bugs and suggestions please mail
%ivan.oseledets@gmail.com
%---------------------------
%Default parameters
rmin=1;     %Accepted for compatibility; not used by the algorithm below
verb=true;
radd=0;     %Accepted for compatibility; not used by the algorithm below
kickrank=2;
nswp=10;
y=[];
vectorized=false;
for i=1:2:length(varargin)-1
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'y0'
            y=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        case 'rmin'
            rmin=varargin{i+1};
        case 'radd'
            radd=varargin{i+1};
        case 'vec'
            vectorized=varargin{i+1};
        case 'kickrank'
            kickrank=varargin{i+1};

        otherwise
            error('Unrecognized option: %s\n',varargin{i});
    end
end

%The supercore is always built from cores (i,i+1), so at least two cores
%are required (otherwise rmat{3}/y{2} below would be out of range).
if ( d < 2 )
    error('dmrg_cross: D must be at least 2');
end

if ( numel(n) == 1 )
   n=n*ones(d,1);
end

sz=n;
if ( isempty(y) )
    y=tt_rand(sz,d,2);
end
%ELEM is always called with an M x d array of indices and must return M
%values. Previously ELEM was left undefined when 'vec' was true.
if ( vectorized )
    elem=fun;                          %FUN already handles an M x d array
else
    elem=@(ind) my_vec_fun(ind,fun);   %Evaluate FUN row by row
end
y=round(y,0); %To avoid overranks
ry=y.r;
[y,rm]=qr(y,'rl');
y=rm*y;
%Warmup procedure: orthogonalization from right to left of the initial
%approximation & computation of the index sets & computation of the
%right-to-left R matrix
swp=1;
rmat=cell(d+1,1);
rmat{d+1}=1;
rmat{1}=1; %These are R-matrices from the QR-decomposition.
index_array{d+1}=zeros(0,ry(d+1));
index_array{1}=zeros(ry(1),0);
r1=1;
for i=d:-1:2
    cr=y{i}; cr=reshape(cr,[ry(i)*n(i),ry(i+1)]);
    cr = cr*r1; cr=reshape(cr,[ry(i),n(i)*ry(i+1)]); cr=cr.';
    [cr,rm]=qr(cr,0);
    [ind]=maxvol2(cr);
    ind_old=index_array{i+1};
    rnew=min(n(i)*ry(i+1),ry(i));
    ind_new=zeros(d-i+1,rnew);
    for s=1:rnew
       f_in=ind(s);
       %Rows of CR come from transposing an ry(i) x (n(i)*ry(i+1)) matrix,
       %so the mode index n(i) varies fastest (same layout as the backward
       %sweep below, which decodes with [n(i+1),ry(i+2)]).
       w1=tt_ind2sub([n(i),ry(i+1)],f_in);
       js=w1(1); rs=w1(2);
       ind_new(:,s)=[js,ind_old(:,rs)'];
    end
    index_array{i}=ind_new;
    r1=cr(ind,:);
    cr=cr/r1;
    r1=r1*rm;
    r1=r1.';

    cr=cr.';
    y{i}=reshape(cr,[ry(i),n(i),ry(i+1)]);
    cr=reshape(cr,[ry(i)*n(i),ry(i+1)]);
    cr=cr*rmat{i+1}; cr=reshape(cr,[ry(i),n(i)*ry(i+1)]);
    cr=cr.';
    [~,rm]=qr(cr,0);
    rmat{i}=rm; %The R-matrix
end
%Put the remaining r1 factor onto the first core
cr=y{1}; cr=reshape(cr,[ry(1)*n(1),ry(2)]);
y{1}=reshape(cr*r1,[ry(1),n(1),ry(2)]);
not_converged = true;
dir = 1; %The direction of the sweep
i=1; %Current position
er_max=0;
%swp counts completed forward+backward sweeps starting from 1, so "<="
%performs exactly NSWP sweeps as documented.
while ( swp <= nswp && not_converged )
    % A sweep through the cores
    %Compute the current index set, compute the current supercore
    %(right now without any 2D cross inside, but it is trivial to
    %implement). The supercore is (i,i+1) now.
    %Left index set is index_array{i}, right index set is index_array{i+2}
    %We will modify ry(i+1) at this step and use rmat{i} and rmat{i+2}
    %as "weighting" matrices for the low-rank approximation. The initial
    %approximation is simply rmat{i}*u{i}*u{i+1}*rmat{i+2}.
    %We also have to store the submatrix in the current factors.
    %Then the algorithm is as follows: compute the index sets, compute the
    %supercore, compute rmat{i}*Phi*rmat{i+2} = U*V by SVD, then split.
    cr1=y{i}; cr2=y{i+1};
    ind1=index_array{i};
    ind2=index_array{i+2};
    big_index=zeros(ry(i),n(i),n(i+1),ry(i+2),d);
    for i1=1:n(i)
        for i2=1:n(i+1)
            for s1=1:ry(i)
                for s2=1:ry(i+2)
                    ind=[ind1(s1,:),i1,i2,ind2(:,s2)'];
                    big_index(s1,i1,i2,s2,:)=ind;
                end
            end
        end
    end
    big_index=reshape(big_index,[numel(big_index)/d,d]);
    score=elem(big_index);
    %Now plug in the rmat matrices
    score=reshape(score,[ry(i),n(i)*n(i+1)*ry(i+2)]);
    score=rmat{i}*score;
    ry(i)=size(score,1);
    score=reshape(score,[ry(i)*n(i)*n(i+1),ry(i+2)]);
    score=score*rmat{i+2};
    ry(i+2)=size(score,2);

    %Do the SVD splitting (later on we can replace it by cross for large
    %mode sizes)
    score=reshape(score,[ry(i)*n(i),n(i+1)*ry(i+2)]);
    [u,s,v]=svd(score,'econ');
    s=diag(s);
    r=my_chop2(s,norm(s)*eps/sqrt(d-1)); %Truncation
    u=u(:,1:r); v=v(:,1:r); s=s(1:r);
    %Kick rank: enrich the orthogonal factor with KICKRANK random
    %directions (via reort) and pad the other factor with zero columns.
    %NKICK is the number of columns actually added.
    if ( dir == 1 )
        v=v*diag(s);

        ur=randn(size(u,1),kickrank);
        u=reort(u,ur);
        nkick=size(u,2)-r;
        if ( nkick > 0 )
            vr=zeros(size(v,1),nkick);
            v=[v,vr];
        end
        r=r+nkick;
    else
        u=u*diag(s);
        vr=randn(size(v,1),kickrank);
        v=reort(v,vr);
        nkick=size(v,2)-r;
        if ( nkick > 0 )
            ur=zeros(size(u,1),nkick);
            u=[u,ur];
        end
        r=r+nkick;
    end

    v=v';

    %Compute the previous approximation (weighted the same way as SCORE)
    %to estimate the relative local error of this step
    appr=reshape(cr1,[numel(cr1)/ry(i+1),ry(i+1)])*reshape(cr2,[ry(i+1),numel(cr2)/ry(i+1)]);
    appr=reshape(appr,[ry(i),n(i)*n(i+1)*ry(i+2)]);
    appr=rmat{i}*appr;
    appr=reshape(appr,[ry(i)*n(i)*n(i+1),ry(i+2)]);
    appr=appr*rmat{i+2};
    er_loc=norm(score(:)-appr(:))/norm(score(:));
    er_max=max(er_max,er_loc);
    if ( verb )
        fprintf('swp=%d block=%d new_rank=%d local_er=%3.1e\n',swp,i,r,er_loc);
    end
    ry(i+1)=r;

    %Remove the weighting matrices from the new factors
    u = reshape(u,[ry(i),n(i)*r]);
    u = rmat{i}\u; %NOTE: assumes rmat{i} is well conditioned
    v=reshape(v,[r*n(i+1),ry(i+2)]);
    u=reshape(u,[ry(i)*n(i),ry(i+1)]);
    v=v/rmat{i+2}; v=reshape(v,[r,n(i+1)*ry(i+2)]);
    if ( dir == 1 )
        [u,rm]=qr(u,0);
        ind=maxvol2(u);
        r1=u(ind,:);
        u=u/r1; y{i}=reshape(u,[ry(i),n(i),ry(i+1)]);
        r1=r1*rm;
        v=r1*v; y{i+1}=reshape(v,[ry(i+1),n(i+1),ry(i+2)]);
        %Recalculate rmat
        u1=reshape(u,[ry(i),n(i)*ry(i+1)]);
        u1=rmat{i}*u1;
        u1=reshape(u1,[ry(i)*n(i),ry(i+1)]);
        [~,rm]=qr(u1,0);
        rmat{i+1}=rm;
        %Recalculate index array (rows of U: ry(i) varies fastest)
        ind_old=index_array{i};
        ind_new=zeros(ry(i+1),i);
        for s=1:ry(i+1)
            f_in=ind(s);
            w1=tt_ind2sub([ry(i),n(i)],f_in);
            rs=w1(1); js=w1(2);
            ind_new(s,:)=[ind_old(rs,:),js];
        end
        index_array{i+1}=ind_new;
        if ( i == d - 1 )
            dir = -dir;
        else
            i=i+1;
        end
    else %Reverse direction
        v=v.'; %v is standing
        [v,rm]=qr(v,0);
        ind=maxvol2(v);
        r1=v(ind,:);
        v=v/r1; v2=reshape(v,[n(i+1),ry(i+2),ry(i+1)]); y{i+1}=permute(v2,[3,1,2]);
        r1=r1*rm; r1=r1.';
        u=u*r1; y{i}=reshape(u,[ry(i),n(i),ry(i+1)]);
        %Recalculate rmat
        v=v.';
        v=reshape(v,[ry(i+1)*n(i+1),ry(i+2)]);
        v=v*rmat{i+2};
        v=reshape(v,[ry(i+1),n(i+1)*ry(i+2)]); v=v.';
        [~,rm]=qr(v,0);
        rmat{i+1}=rm;
        %Recalculate index array (rows of V: n(i+1) varies fastest)
        ind_old=index_array{i+2};
        ind_new=zeros(d-i,ry(i+1));
        for s=1:ry(i+1)
            f_in=ind(s);
            w1=tt_ind2sub([n(i+1),ry(i+2)],f_in);
            rs=w1(2); js=w1(1);
            ind_new(:,s)=[js,ind_old(:,rs)'];
        end
        index_array{i+1}=ind_new;
        if ( i == 1 )
            %End of a full forward+backward sweep: check convergence
            dir=-dir;
            swp = swp + 1;
            if ( er_max < eps )
                not_converged=false;
            else
                er_max=0;
            end
        else
            i=i-1;
        end
    end
end
end
function val=my_vec_fun(ind,fun)
%MY_VEC_FUN Row-by-row evaluation of a non-vectorized tensor function
%   [VAL]=MY_VEC_FUN(IND,FUN) treats each row of the M x d index array IND
%   as one multi-index, hands it to FUN as a d x 1 column vector, and
%   collects the M scalar results in the M x 1 double column VAL.
npts = size(ind,1);
val = zeros(npts,1);
for k = 1:npts
    val(k) = fun(reshape(ind(k,:),[],1));
end
end

Generated on Wed 08-Feb-2012 18:20:24 by m2html © 2005