function [y]=dmrg_cross(d,n,fun,eps,varargin)
%DMRG_CROSS Two-site (DMRG-style) cross approximation of a d-dimensional array.
%   Y = DMRG_CROSS(D,N,FUN,EPS) sweeps over neighbouring pairs of cores of a
%   tensor-train-like object Y, sampling the entries of each pair through FUN
%   and re-splitting the sampled block with a truncated SVD.
%
%   Inputs
%     D    - number of dimensions.
%     N    - mode sizes; a scalar is expanded to N*ones(D,1).
%     FUN  - element evaluator.
%            * default (non-vectorized): FUN is called once per element with
%              a D-by-1 column multi-index and must return a scalar.
%            * 'vec',true: FUN is called with an M-by-D matrix (one
%              multi-index per row) and must return M values.
%     EPS  - relative accuracy; it scales the SVD truncation threshold
%            (norm(s)*EPS/sqrt(D-1)) and is the stopping tolerance on the
%            largest local error seen in a full forward+backward sweep.
%
%   Options (name/value pairs, names are case-insensitive)
%     'nswp'     [10]    sweep limit. The counter starts at 1 and the loop
%                        runs while swp < nswp, so at most nswp-1 full
%                        forward+backward sweeps are executed.
%     'y0'       []      initial approximation; when empty,
%                        tt_rand(n,d,2) is used.
%     'verb'     [true]  print one progress line per processed block.
%     'vec'      [false] FUN is vectorized (see above).
%     'kickrank' [2]     number of random columns offered to REORT for
%                        enriching the basis at every block.
%     'rmin'     [1]     accepted, but not used by the code in this function.
%     'radd'     [0]     accepted, but overwritten inside the sweep before it
%                        is ever read.
%   A trailing option name without a value is silently ignored.
%
%   Output
%     Y - the approximation, same container type as the initial guess
%         (presumably a tt_tensor - depends on tt_rand / the supplied y0).

rmin=1; %#ok<NASGU>  (option is parsed but not used below)
verb=true;
radd=0; %#ok<NASGU>  (option is parsed but overwritten in the sweep)
kickrank=2;
nswp=10;
y=[];
vectorized=false;
for i=1:2:length(varargin)-1
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'y0'
            y=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        case 'rmin'
            rmin=varargin{i+1}; %#ok<NASGU>
        case 'radd'
            radd=varargin{i+1}; %#ok<NASGU>
        case 'vec'
            vectorized=varargin{i+1};
        case 'kickrank'
            kickrank=varargin{i+1};

        otherwise
            error('Unrecognized option: %s\n',varargin{i});
    end
end

% A scalar mode size means "the same size in every dimension".
if ( numel(n) == 1 )
    n=n*ones(d,1);
end

sz=n;
if (isempty(y) )
    y=tt_rand(sz,d,2);
end

% elem always takes an M-by-d matrix of multi-indices (one per row).
% FIX: previously elem was left undefined when 'vec' was true, so the first
% sampling step failed; a vectorized FUN is now used directly.
if ( ~vectorized )
    elem=@(ind) my_vec_fun(ind,fun);
else
    elem=fun;
end
y=round(y,0);
ry=y.r;
[y,rm]=qr(y,'rl');
y=rm*y;

% ---------------------------------------------------------------------
% Warm-up: right-to-left pass over the initial guess.
% For every core i = d..2 it builds
%   index_array{i} - (d-i+1)-by-r matrix, one right multi-index per column,
%   rmat{i}        - R factor of the (weighted) interpolated core,
% and rewrites y{i} as the core divided by its selected submatrix.
% ---------------------------------------------------------------------
swp=1;
rmat=cell(d+1,1);
rmat{d+1}=1;
rmat{1}=1;
index_array{d+1}=zeros(0,ry(d+1));
index_array{1}=zeros(ry(1),0);
r1=1;
for i=d:-1:2
    cr=y{i}; cr=reshape(cr,[ry(i)*n(i),ry(i+1)]);
    % Rows of the transposed matrix are ordered (n(i), ry(i+1)), mode index
    % fastest; columns are the left rank ry(i).
    cr = cr*r1; cr=reshape(cr,[ry(i),n(i)*ry(i+1)]); cr=cr.';
    [cr,rm]=qr(cr,0);
    % maxvol2 presumably returns the row indices of a well-conditioned
    % square submatrix - confirm against its implementation.
    [ind]=maxvol2(cr);
    ind_old=index_array{i+1};
    rnew=min(n(i)*ry(i+1),ry(i));
    ind_new=zeros(d-i+1,rnew);
    for s=1:rnew
        f_in=ind(s);
        % FIX: decode the flat row index with the mode index first, which is
        % the actual row ordering established by the reshape/transpose above
        % and the convention used in the backward half-sweep below.
        % The previous code decoded it as [ry(i+1),n(i)] with the rank first.
        w1=tt_ind2sub([n(i),ry(i+1)],f_in);
        js=w1(1); rs=w1(2);
        ind_new(:,s)=[js,ind_old(:,rs)'];
    end
    index_array{i}=ind_new;
    r1=cr(ind,:);
    cr=cr/r1;
    r1=r1*rm;
    r1=r1.';

    cr=cr.';
    y{i}=reshape(cr,[ry(i),n(i),ry(i+1)]);
    cr=reshape(cr,[ry(i)*n(i),ry(i+1)]);
    cr=cr*rmat{i+1}; cr=reshape(cr,[ry(i),n(i)*ry(i+1)]);
    cr=cr.';
    [~,rm]=qr(cr,0);
    rmat{i}=rm;
end

% Push the remaining factor into the first core.
cr=y{1}; cr=reshape(cr,[ry(1)*n(1),ry(2)]);
y{1}=reshape(cr*r1,[ry(1),n(1),ry(2)]);
not_converged = true;
dir = 1;
i=1;
er_max=0;

% ---------------------------------------------------------------------
% Main loop. Each pass handles the pair of cores (i, i+1):
%   1. sample the block on ind1 x n(i) x n(i+1) x ind2 through elem,
%   2. weight it by rmat{i} (left) and rmat{i+2} (right),
%   3. truncated SVD + enrichment with kickrank random columns,
%   4. measure the local error against the current pair of cores,
%   5. undo the weights, re-select indices, update y, rmat and
%      index_array, and move to the next block (dir = +1 forward,
%      dir = -1 backward).
% ---------------------------------------------------------------------
while ( swp < nswp && not_converged )
    cr1=y{i}; cr2=y{i+1};
    ind1=index_array{i};      % ry(i)-by-(i-1): left multi-indices, one per row
    ind2=index_array{i+2};    % (d-i-1)-by-ry(i+2): right multi-indices, one per column

    % Full multi-index (length d) of every element of the block.
    big_index=zeros(ry(i),n(i),n(i+1),ry(i+2),d);
    for i1=1:n(i)
        for i2=1:n(i+1)
            for s1=1:ry(i)
                for s2=1:ry(i+2)
                    ind=[ind1(s1,:),i1,i2,ind2(:,s2)'];
                    big_index(s1,i1,i2,s2,:)=ind;
                end
            end
        end
    end
    big_index=reshape(big_index,[numel(big_index)/d,d]);
    score=elem(big_index);

    % Apply the left and right weights.
    score=reshape(score,[ry(i),n(i)*n(i+1)*ry(i+2)]);
    score=rmat{i}*score;
    ry(i)=size(score,1);
    score=reshape(score,[ry(i)*n(i)*n(i+1),ry(i+2)]);
    score=score*rmat{i+2};
    ry(i+2)=size(score,2);

    % Split the block between the two modes by a truncated SVD.
    score=reshape(score,[ry(i)*n(i),n(i+1)*ry(i+2)]);
    [u,s,v]=svd(score,'econ');
    s=diag(s);
    % my_chop2 presumably returns the smallest rank whose discarded tail is
    % below the given absolute threshold - confirm against its implementation.
    r=my_chop2(s,norm(s)*eps/sqrt(d-1));
    u=u(:,1:r); v=v(:,1:r); s=s(1:r);

    % Enrichment: offer kickrank random columns to the factor that will be
    % orthogonalized; reort may return more columns than it was given in
    % its first argument, and the other factor is zero-padded to match.
    if ( dir == 1 )
        v=v*diag(s);

        ur=randn(size(u,1),kickrank);
        u=reort(u,ur);
        radd=size(u,2)-r;
        if ( radd > 0 )
            vr=zeros(size(v,1),radd);
            v=[v,vr];
        end
        r=r+radd;
    else
        u=u*diag(s);
        vr=randn(size(v,1),kickrank);
        v=reort(v,vr);
        radd=size(v,2)-r;
        if ( radd > 0 )
            ur=zeros(size(u,1),radd);
            u=[u,ur];
        end
        r=r+radd;
    end

    v=v';

    % Local error: sampled (weighted) block vs. the block formed from the
    % cores held on entry to this pass, with the same weights applied.
    appr=reshape(cr1,[numel(cr1)/ry(i+1),ry(i+1)])*reshape(cr2,[ry(i+1),numel(cr2)/ry(i+1)]);
    appr=reshape(appr,[ry(i),n(i)*n(i+1)*ry(i+2)]);
    appr=rmat{i}*appr;
    appr=reshape(appr,[ry(i)*n(i)*n(i+1),ry(i+2)]);
    appr=appr*rmat{i+2};
    er_loc=norm(score(:)-appr(:))/norm(score(:));
    er_max=max(er_max,er_loc);
    if ( verb )
        fprintf('swp=%d block=%d new_rank=%d local_er=%3.1e\n',swp,i,r,er_loc);
    end
    ry(i+1)=r;

    % Undo the weights on both factors.
    u = reshape(u,[ry(i),n(i)*r]);
    u = rmat{i}\u;
    v=reshape(v,[r*n(i+1),ry(i+2)]);
    u=reshape(u,[ry(i)*n(i),ry(i+1)]);
    v=v/rmat{i+2}; v=reshape(v,[r,n(i+1)*ry(i+2)]);
    if ( dir == 1 )
        % Forward step: core i gets the interpolated left factor; the
        % remaining factor is pushed into core i+1.
        [u,rm]=qr(u,0);
        ind=maxvol2(u);
        r1=u(ind,:);
        u=u/r1; y{i}=reshape(u,[ry(i),n(i),ry(i+1)]);
        r1=r1*rm;
        v=r1*v; y{i+1}=reshape(v,[ry(i+1),n(i+1),ry(i+2)]);

        % New left weight for the next block.
        u1=reshape(u,[ry(i),n(i)*ry(i+1)]);
        u1=rmat{i}*u1;
        u1=reshape(u1,[ry(i)*n(i),ry(i+1)]);
        [~,rm]=qr(u1,0);
        rmat{i+1}=rm;

        % Extend the left index set: rows of u are ordered (ry(i), n(i)).
        ind_old=index_array{i};
        ind_new=zeros(ry(i+1),i);
        for s=1:ry(i+1)
            f_in=ind(s);
            w1=tt_ind2sub([ry(i),n(i)],f_in);
            rs=w1(1); js=w1(2);
            ind_new(s,:)=[ind_old(rs,:),js];
        end
        index_array{i+1}=ind_new;
        if ( i == d - 1 )
            dir = -dir;     % last block reached: turn around
        else
            i=i+1;
        end
    else
        % Backward step: core i+1 gets the interpolated right factor; the
        % remaining factor is pushed into core i.
        v=v.';
        [v,rm]=qr(v,0);
        ind=maxvol2(v);
        r1=v(ind,:);
        v=v/r1; v2=reshape(v,[n(i+1),ry(i+2),ry(i+1)]); y{i+1}=permute(v2,[3,1,2]);
        r1=r1*rm; r1=r1.';
        u=u*r1; y{i}=reshape(u,[ry(i),n(i),ry(i+1)]);

        % New right weight for the next block.
        v=v.';
        v=reshape(v,[ry(i+1)*n(i+1),ry(i+2)]);
        v=v*rmat{i+2};
        v=reshape(v,[ry(i+1),n(i+1)*ry(i+2)]); v=v.';
        [~,rm]=qr(v,0);
        rmat{i+1}=rm;

        % Extend the right index set: rows of v are ordered (n(i+1), ry(i+2)).
        ind_old=index_array{i+2};
        ind_new=zeros(d-i,ry(i+1));
        for s=1:ry(i+1)
            f_in=ind(s);
            w1=tt_ind2sub([n(i+1),ry(i+2)],f_in);
            rs=w1(2); js=w1(1);
            ind_new(:,s)=[js,ind_old(:,rs)'];
        end
        index_array{i+1}=ind_new;
        if ( i == 1 )
            % A full forward+backward sweep is complete.
            dir=-dir;
            swp = swp + 1;
            if ( er_max < eps )
                not_converged=false;
            else
                er_max=0;
            end
        else
            i=i-1;
        end
    end
end
return
end
function val=my_vec_fun(ind,fun)
%MY_VEC_FUN Evaluate a single-element function on every row of an index matrix.
%   VAL = MY_VEC_FUN(IND,FUN) calls FUN once per row of IND, passing that row
%   as a column vector, and collects the results in a column vector VAL with
%   size(IND,1) entries.
nrows=size(ind,1);
val=zeros(nrows,1);
row=0;
while ( row < nrows )
    row=row+1;
    % FUN expects the multi-index as a column.
    val(row)=fun(reshape(ind(row,:),[],1));
end
end