function [y,ev] = dmrg_eigb(a,k,eps,varargin)
%DMRG_EIGB  Block two-site DMRG sweeps for K eigenpairs of a TT matrix.
%   [Y,EV] = DMRG_EIGB(A,K,EPS)
%   [Y,EV] = DMRG_EIGB(A,K,EPS,'option',value,...)
%
%   Each step of a sweep does the following:
%     1. Merges two neighbouring cores of the block TT tensor Y. The block
%        index of size K travels with the pair being optimised.
%     2. Projects A onto the current left/right interfaces (phi).
%     3. Solves the local eigenproblem for K eigenpairs. Small problems use
%        dense EIG, taking the K smallest eigenvalues after an ascending
%        sort; large ones use LOBPCG.
%     4. Splits the merged core again with the smallest rank whose local
%        residual stays below max(1.2*residual of the full-rank split, EPS).
%     5. Enlarges the new interface by KICK_RANK random directions.
%
%   Inputs
%     A    TT matrix; the fields n, d and tt (with ps, r, core) are read.
%          NOTE(review): A is assumed symmetric. LOBPCG requires it, and
%          the dense local matrix is symmetrised before EIG.
%     K    number of eigenpairs (block size).
%     EPS  LOBPCG tolerance and floor for the rank-selection residual.
%
%   Options (name/value pairs, default in brackets)
%     'nswp'        [4]    number of one-directional sweeps. If the count
%                          would end on a right-to-left sweep, one more
%                          left-to-right sweep is run so that the block
%                          index ends up in the last core.
%     'rmax'        [2500] maximal TT rank.
%     'x0'          []     initial block TT tensor with last rank K;
%                          default: tt_random with inner ranks max(5,K).
%     'msize'       [1000] local problems with at least max(5*K,msize)
%                          rows use LOBPCG, smaller ones dense EIG.
%     'max_l_steps' [200]  maximal number of LOBPCG iterations.
%     'kick_rank'   [5]    random directions added to each new interface.
%     'verb'        [true] print per-step diagnostics.
%
%   Outputs
%     Y    block TT tensor whose last rank is K; the K vectors are indexed
%          by that last rank.
%     EV   the K eigenvalues from the last local solve ([] if no sweep ran).

y0=[];
rmax=2500;
nswp=4;
msize=1000;
max_l_steps=200;
kick_rank=5;
verb=true;
for i=1:2:length(varargin)-1
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'rmax'
            % Numeric value: no lower() here.
            rmax=varargin{i+1};
        case 'x0'
            y0=varargin{i+1};
        case 'msize'
            msize=varargin{i+1};
        case 'max_l_steps'
            max_l_steps=varargin{i+1};
        case 'kick_rank'
            kick_rank=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        otherwise
            error('Unrecognized option: %s\n',varargin{i});
    end
end

ev=[];
n=a.n;
d=a.d;
if ( d < 2 )
    % The two-site scheme needs at least one pair of neighbouring cores.
    error('dmrg_eigb: A must have at least two cores (d >= 2).');
end
if ( isempty(y0) )
    kk=max(5,k);
    r=kk*ones(1,d+1);
    r(d+1)=k;
    r(1)=1;
    y0=tt_random(n,d,r);
end
y=round(y0,0);

% ---- Left-to-right QR orthogonalisation of Y --------------------------
% At the same time, build the left interface projections
% phi{i+1} = (ra(i+1) x ry(i+1) x ry(i+1)) of A.
psy=y.ps;
ry=y.r;
tta=a.tt;
psa=tta.ps;
ra=tta.r;
cra=tta.core;
cry=y.core;
rm=1;
phi=cell(d+1,1);
phi{1}=1;
phi{d+1}=1;
pos1=1;
for i=1:d
    cr=cry(psy(i):psy(i+1)-1);
    cr=reshape(cr,[ry(i),n(i)*ry(i+1)]);
    cr=rm*cr;
    ry(i)=size(cr,1);
    cr=reshape(cr,[ry(i)*n(i),ry(i+1)]);
    [u,rm]=qr(cr,0);
    rnew=size(u,2);
    % Compact the orthogonal cores in place. The write never reaches past
    % the part of cry that has already been read.
    cry(pos1:pos1+ry(i)*n(i)*rnew-1)=u(:);
    pos1=pos1+ry(i)*n(i)*rnew;

    crm=cra(psa(i):psa(i+1)-1);
    crm=reshape(crm,[ra(i),n(i),n(i),ra(i+1)]);
    phx=reshape(phi{i},[ra(i),ry(i),ry(i)]);
    phx=permute(phx,[1,3,2]);
    phx=reshape(phx,[ra(i)*ry(i),ry(i)]);
    phx=phx*reshape(u,[ry(i),n(i)*rnew]);
    phx=reshape(phx,[ra(i),ry(i),n(i),rnew]);
    crm=permute(crm,[1,3,2,4]); crm=reshape(crm,[ra(i)*n(i),n(i)*ra(i+1)]);
    phx=permute(phx,[2,4,1,3]);
    phx=reshape(phx,[ry(i)*rnew,ra(i)*n(i)]);
    phx=phx*crm;
    phx=reshape(phx,[ry(i),rnew,n(i),ra(i+1)]);
    phx=permute(phx,[2,4,1,3]);
    phx=reshape(phx,[rnew*ra(i+1),ry(i)*n(i)]);
    phx=phx*u;
    phx=reshape(phx,[rnew,ra(i+1),rnew]);
    phi{i+1}=permute(phx,[2,1,3]);
end
ry(d+1)=rnew;
cry=cry(1:pos1-1);
% The i=d projection is not an interface (it contains the block index).
phi{d+1}=1;
psy=cumsum([1;n.*ry(1:d).*ry(2:d+1)]);

% ---- Sweeps -----------------------------------------------------------
% During the sweeps the block index K is carried as an extra mode of one
% core, so the formal last rank is 1.
ry(d+1)=1;
direction='rl';
i=d-1;
% Keep all core storage as column vectors.
cry_left=cry(1:psy(d)-1);       cry_left=cry_left(:);
cry_right=cry(psy(d):psy(d+1)-1); cry_right=cry_right(:);
swp=1;
% No convergence test: exactly nswp sweeps are run. The last sweep must
% be left-to-right so the block index ends in core d (see the assembly below).
while ( swp <= nswp || strcmp(direction,'lr') )

    % Local operator for cores (i, i+1): two factors with shared rank ra(i+1).
    cra1=cra(psa(i):psa(i+1)-1);
    cra2=cra(psa(i+1):psa(i+2)-1);
    px1=reshape(phi{i},[ra(i),ry(i),ry(i)]);
    px1=permute(px1,[1,3,2]);
    px1=reshape(px1,[ra(i),ry(i)*ry(i)]);
    px2=reshape(phi{i+2},[ra(i+2),ry(i+2)*ry(i+2)]);
    cra2=reshape(cra2,[ra(i+1)*n(i+1)*n(i+1),ra(i+2)]);
    b2=cra2*px2;
    b1=px1.'*reshape(cra1,[ra(i),numel(cra1)/ra(i)]);
    b1=reshape(b1,[ry(i),ry(i),n(i),n(i),ra(i+1)]); b1_save=b1;
    b1=permute(b1,[1,3,2,4,5]);
    b1=reshape(b1,[ry(i)*n(i),ry(i)*n(i),ra(i+1)]);
    b2=reshape(b2,[ra(i+1),n(i+1),n(i+1),ry(i+2),ry(i+2)]);
    b2_save=b2;
    b2=permute(b2,[2,4,3,5,1]);
    b2=reshape(b2,[n(i+1)*ry(i+2),n(i+1)*ry(i+2),ra(i+1)]);
    mm=cell(2,1); mm{1}=b1; mm{2}=b2;

    % Merge the two cores into the current local block w (rows x k).
    pos=numel(cry_left);
    if ( strcmp(direction,'rl') )
        % Block index lives in the right core i+1.
        cry1=cry_left(pos-ry(i)*n(i)*ry(i+1)+1:pos);
        cry2=cry_right(1:ry(i+1)*n(i+1)*k*ry(i+2));
        cry1=reshape(cry1,[ry(i)*n(i),ry(i+1)]);
        cry2=reshape(cry2,[ry(i+1),n(i+1)*k*ry(i+2)]);
        w=cry1*cry2; w=reshape(w,[ry(i),n(i),n(i+1),k,ry(i+2)]);
        w=permute(w,[1,2,3,5,4]); w=reshape(w,[numel(w)/k,k]);
    else
        % Block index lives in the left core i.
        cry1=cry_left(pos-ry(i)*n(i)*k*ry(i+1)+1:pos);
        cry2=cry_right(1:ry(i+1)*n(i+1)*ry(i+2));
        cry1=reshape(cry1,[ry(i)*n(i)*k,ry(i+1)]);
        cry2=reshape(cry2,[ry(i+1),n(i+1)*ry(i+2)]);
        w=cry1*cry2; w=reshape(w,[ry(i),n(i),k,n(i+1),ry(i+2)]);
        w=permute(w,[1,2,4,5,3]); w=reshape(w,[numel(w)/k,k]);
    end

    % Local eigenproblem.
    if ( size(w,1) >= max(5*k,msize) )
        matvec='bfun';
        apply_loc=@(x) bfun(mm,x);
        [wnew,ev,fail_flag]=lobpcg(w,apply_loc,eps,max_l_steps);
        if ( verb && any(fail_flag(:)) )
            fprintf('sweep %d, block %d: lobpcg stopped before reaching tolerance %g\n',swp,i,eps);
        end
    else
        matvec='full';
        fm=full(tt_matrix(mm));
        % Symmetrise so EIG uses the symmetric solver. This gives real
        % eigenvalues and orthonormal eigenvectors; for symmetric A the
        % exact local matrix is already symmetric.
        fm=(fm+fm')/2;
        apply_loc=@(x) fm*x;
        [v,dg]=eig(fm);
        [ev,ind]=sort(diag(dg),'ascend');
        wnew=v(:,ind(1:k));
        ev=ev(1:k);
    end

    if ( verb )
        % Residuals of the previous block (er0) and of the new one (er1).
        bw=bfun(mm,w);
        er_old=bw-w*(bw'*w);
        er0=norm(er_old,'fro')/norm(w,'fro');
        er1=norm(bfun(mm,wnew)-wnew*diag(ev),'fro')/norm(wnew,'fro');
        fprintf('sweep=%d block=%d fv=%10.15f loc solve=%3.2e old_solve=%3.2e \n',swp,i,sum(ev),er1,er0);
        disp(sqrt(sum(abs(er_old).^2,1)));
    end

    % Residual of the untruncated local solution; the reference level for
    % choosing the truncation rank below.
    rhs=wnew*diag(ev);
    res_true=norm(apply_loc(wnew)-rhs)/norm(rhs);

    if ( strcmp(direction,'rl') )
        % Split so the block index moves into the left core i.
        wnew=reshape(wnew,[ry(i),n(i),n(i+1),ry(i+2),k]);
        wnew=permute(wnew,[1,2,5,3,4]); wnew=reshape(wnew,[ry(i)*n(i)*k,n(i+1)*ry(i+2)]);
        [u,s,v]=svd(wnew,'econ'); s=diag(s);

        % Bisection on the rank. Starting at r=0 guarantees at least one
        % trial, so u1/sol are always defined (also when r1 == 1).
        r0=1; r1=min(numel(s),rmax);
        r=0;
        while ( (r ~= r0 || r ~= r1) && r0 <= r1)
            r=min(floor((r0+r1)/2),rmax);
            u1=u(:,1:r)*diag(s(1:r));
            % Re-orthonormalise the K vectors of the truncated block.
            u1=reshape(u1,[ry(i)*n(i),k,r]);
            u1=permute(u1,[1,3,2]);
            u1=reshape(u1,[numel(u1)/k,k]);
            [u2,~,v2]=svd(u1,'econ');
            u1=u2*v2';
            u1=reshape(u1,[ry(i)*n(i),r,k]);
            u1=permute(u1,[1,3,2]);
            u1=reshape(u1,[numel(u1)/r,r]);
            sol=u1*(v(:,1:r))';
            sol=reshape(sol,[ry(i),n(i),k,n(i+1),ry(i+2)]);
            sol=permute(sol,[1,2,4,5,3]); sol=reshape(sol,[numel(sol)/k,k]);
            resid=norm(apply_loc(sol)-rhs)/norm(rhs);
            if ( verb )
                fprintf('sweep %d, block %d, r0=%d, r1=%d, r=%d, resid=%g, MatVec=%s\n', swp, i, r0, r1, r, resid,matvec);
            end
            if ( resid<max(res_true*1.2, eps) )
                r1=r;
            else
                r0=min(r+1,rmax);
            end
        end
        rnew=r;
        orth_dev=norm(sol'*sol-eye(k));
        if ( orth_dev > 1e-7 )
            warning('dmrg_eigb:nonorthogonal', ...
                'sweep %d, block %d: local block deviates from orthonormality by %g', swp, i, orth_dev);
        end

        u=u1; v=v(:,1:rnew);
        % Enrich the right interface with random directions orthogonal to v.
        radd=min(kick_rank,size(v,1)-rnew);
        rnew=rnew+radd;
        if ( radd > 0 )
            vr=randn(size(v,1),radd);
            ur=zeros(size(u,1),radd);
            mvr=v'*vr; vnew=vr-v*mvr;
            reort_flag=false;
            for j=1:radd
                if ( norm(vnew(:,j)) <= 0.5*norm(vr(:,j)))
                    reort_flag=true;
                end
            end
            if (reort_flag)
                % Second Gram-Schmidt pass against cancellation.
                sv=v'*vnew;
                vnew=vnew-v*sv;
            end
            [vnew,~]=qr(vnew,0);
            v=[v,vnew];
            u=[u,ur];
        end
        v=v';

        % Replace the two old cores with the new ones (old ranks used here).
        pos=numel(cry_left);
        cry_left(pos-ry(i)*n(i)*ry(i+1)+1:pos)=[];
        cry_right(1:ry(i+1)*n(i+1)*k*ry(i+2))=[];
        cry_right=[u(:); v(:); cry_right(:)];
        ry(i+1)=rnew;

        % New right interface phi{i+1} = (ra(i+1) x ry(i+1) x ry(i+1)).
        phx=reshape(b2_save,[ra(i+1),n(i+1),n(i+1),ry(i+2),ry(i+2)]);
        phx=permute(phx,[2,4,1,3,5]);
        phx=reshape(phx,[n(i+1)*ry(i+2),ra(i+1)*n(i+1)*ry(i+2)]);
        v0=reshape(v,[ry(i+1),n(i+1)*ry(i+2)]);
        phx=v0*phx;
        phx=reshape(phx,[ry(i+1),ra(i+1),n(i+1),ry(i+2)]);
        phx=permute(phx,[3,4,1,2]);
        phx=reshape(phx,[n(i+1)*ry(i+2),ry(i+1)*ra(i+1)]);
        phx=v0*phx;
        phx=reshape(phx,[ry(i+1),ry(i+1),ra(i+1)]);
        phi{i+1}=permute(phx,[3,2,1]);
    else
        % Split so the block index moves into the right core i+1.
        wnew=reshape(wnew,[ry(i),n(i),n(i+1),ry(i+2),k]);
        wnew=permute(wnew,[1,2,3,5,4]); wnew=reshape(wnew,[ry(i)*n(i),n(i+1)*k*ry(i+2)]);
        [u,s,v]=svd(wnew,'econ'); s=diag(s);

        % Bisection on the rank (r=0 start: at least one trial, see above).
        r0=1; r1=min(numel(s),rmax);
        r=0;
        while ( (r ~= r0 || r ~= r1) && r0 <= r1)
            r=min(floor((r0+r1)/2),rmax);
            v1=v(:,1:r)*diag(s(1:r));
            v1=reshape(v1,[n(i+1),k,ry(i+2),r]);
            v1=permute(v1,[1,3,4,2]);
            v1=reshape(v1,[numel(v1)/k,k]);
            [u2,~,v2]=svd(v1,'econ');
            v1=u2*v2';
            v1=reshape(v1,[n(i+1),ry(i+2),r,k]);
            v1=permute(v1,[1,4,2,3]);
            v1=reshape(v1,[numel(v1)/r,r]);
            sol=u(:,1:r)*v1';
            sol=reshape(sol,[ry(i),n(i),n(i+1),k,ry(i+2)]);
            sol=permute(sol,[1,2,3,5,4]); sol=reshape(sol,[numel(sol)/k,k]);
            resid=norm(apply_loc(sol)-rhs)/norm(rhs);
            if ( verb )
                fprintf('sweep %d, block %d, r0=%d r1=%d r=%d, resid=%g, MatVec=%s\n', swp, i, r0, r1, r, resid,matvec);
            end
            if ( resid<max(res_true*1.2, eps) )
                r1=r;
            else
                r0=min(r+1,rmax);
            end
        end
        rnew=r;
        orth_dev=norm(sol'*sol-eye(k));
        if ( orth_dev > 1e-7 )
            warning('dmrg_eigb:nonorthogonal', ...
                'sweep %d, block %d: local block deviates from orthonormality by %g', swp, i, orth_dev);
        end

        u=u(:,1:rnew); v=v1;
        % Enrich the left interface with random directions orthogonal to u.
        radd=min(kick_rank,size(u,1)-rnew);
        rnew=rnew+radd;
        if ( radd > 0 )
            vr=zeros(size(v,1),radd);
            ur=randn(size(u,1),radd);
            mur=u'*ur; unew=ur-u*mur;
            reort_flag=false;
            for j=1:radd
                if ( norm(unew(:,j)) <= 0.5*norm(ur(:,j)))
                    reort_flag=true;
                end
            end
            if (reort_flag)
                % Second Gram-Schmidt pass against cancellation.
                sv=u'*unew;
                unew=unew-u*sv;
            end
            [unew,~]=qr(unew,0);
            u=[u,unew];
            v=[v,vr];
        end
        v=v';

        % Replace the two old cores with the new ones (old ranks used here).
        cry_right(1:ry(i+1)*n(i+1)*ry(i+2))=[];
        pos=numel(cry_left);
        cry_left(pos-ry(i)*n(i)*k*ry(i+1)+1:pos)=[];
        cry_left=[cry_left(:); u(:); v(:)];
        ry(i+1)=rnew;

        % New left interface phi{i+1} = (ra(i+1) x ry(i+1) x ry(i+1)).
        phx=reshape(b1_save,[ry(i),ry(i),n(i),n(i),ra(i+1)]);
        u0=reshape(u,[ry(i)*n(i),ry(i+1)]);
        phx=permute(phx,[1,3,5,2,4]);
        phx=reshape(phx,[ry(i)*n(i)*ra(i+1),ry(i)*n(i)]);
        phx=phx*u0;
        phx=reshape(phx,[ry(i),n(i),ra(i+1),ry(i+1)]);
        phx=permute(phx,[3,4,1,2]);
        phx=reshape(phx,[ra(i+1)*ry(i+1),ry(i)*n(i)]);
        phx=phx*u0;
        phi{i+1}=reshape(phx,[ra(i+1),ry(i+1),ry(i+1)]);
    end

    % Advance to the next pair; at either end, reverse the direction and
    % move the core that holds the block index to the other storage side.
    if ( strcmp(direction,'rl') )
        if ( i > 1 )
            i=i-1;
        else
            direction='lr';
            cry_left=cry_right(1:ry(1)*n(1)*ry(2)*k);
            cry_right(1:ry(1)*n(1)*ry(2)*k)=[];
            swp=swp+1;
        end
    else
        if ( i < d-1 )
            i=i+1;
        else
            direction='rl';
            pos=numel(cry_left);
            cry_right=cry_left(pos-ry(d)*n(d)*ry(d+1)*k+1:pos);
            cry_left(pos-ry(d)*n(d)*ry(d+1)*k+1:pos)=[];
            swp=swp+1;
        end
    end
end

% The loop always ends with the block index in core d, so it becomes the
% formal last rank.
ry(d+1)=k;
y.r=ry;
y.core=[cry_left(:); cry_right(:)];
y.d=d;
y.ps=cumsum([1;n.*ry(1:d).*ry(2:d+1)]);
end
function [x]=bfun(a,x)
%BFUN  Matrix-free product of a two-factor local operator with a block.
%   Y = BFUN(A, X), where A = {U; V} with U of size n1-by-m1-by-R and V of
%   size n2-by-m2-by-R, returns
%       Y = sum_{r=1..R} kron(V(:,:,r), U(:,:,r)) * X
%   without forming any Kronecker product. The rows of X are indexed by
%   (m1, m2), with the U-side index running fastest. Y has n1*n2 rows
%   (U-side index fastest) and the same number of columns as X.

U = a{1};
V = a{2};
nvec  = size(x,2);
rowsU = size(U,1);
colsU = size(U,2);
rankR = size(U,3);
rowsV = size(V,1);
colsV = size(V,2);

% Step 1: contract U with the fast (U-side) index of every column of x.
Umat = reshape(permute(U,[1,3,2]), [], colsU);     % (rowsU*R) x colsU
t = Umat * reshape(x, colsU, []);                  % (rowsU*R) x (colsV*nvec)

% Step 2: bring the V-side column index and the rank index to the front.
t = permute(reshape(t,[rowsU,rankR,colsV,nvec]), [3,2,1,4]);
t = reshape(t, colsV*rankR, rowsU*nvec);

% Step 3: contract V over (column index, rank) in a single product.
t = reshape(V, rowsV, colsV*rankR) * t;            % rowsV x (rowsU*nvec)

% Step 4: reorder rows so the U-side row index runs fastest.
t = permute(reshape(t,[rowsV,rowsU,nvec]), [2,1,3]);
x = reshape(t, [], nvec);
end