function [y]=tt_rc(d,n,elem_fun,eps,varargin)
%TT_RC  Cross-type approximation of a d-dimensional array in the TT format.
%   Y = TT_RC(D,N,ELEM_FUN,EPS,OPTIONS) builds a tt_tensor Y from point
%   evaluations of ELEM_FUN.
%
%   Inputs:
%     D        - number of dimensions.
%     N        - mode sizes; a scalar is expanded to D equal modes, a vector
%                is used as a D-by-1 column.
%     ELEM_FUN - function handle. It is called with a single multi-index
%                (a vector of D indices) and, when 'vec' is true, also with
%                a D-by-M matrix whose columns are multi-indices (then it
%                must return M values).
%     EPS      - relative accuracy used by the cross enlargement, by the
%                final core selection and by the restart check.
%
%   OPTIONS are name/value pairs (names are case-insensitive):
%     'nswp'          - maximal number of sweeps (default 50)
%     'rmax'          - accepted for compatibility; not used by this routine
%     'x0'            - initial approximation; its fields r, ps and core are
%                       read (presumably a tt_tensor - confirm with callers)
%     'verb'          - progress messages are printed when verb >= 1
%                       (default 2)
%     'vec'           - ELEM_FUN accepts a D-by-M index matrix (default false)
%     'change_dir_on' - allow the sweep to reverse direction based on the
%                       error of the neighbouring super cores (default false)
%
%   Output:
%     Y - tt_tensor whose fields core, ps, r, n and d are filled in.
%
%   NOTE(review): for D == 2 the sweep index appears to advance past D-1
%   (there is only one super core); the routine looks intended for D >= 3.

if ( numel(n) == 1 )
    n=n*ones(d,1);
end
% Ranks are kept as a column; force the mode sizes to the same orientation
% so that n.*ry(1:d) below is an element-wise product.
n=n(:);

nswp=50;
rmax=1000;  % parsed for compatibility only; never read below
x0=[];
verb=2;
vec=false;
change_dir_on=false;
for i=1:2:length(varargin)-1
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'rmax'
            rmax=varargin{i+1};
        case 'x0'
            x0=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        case 'vec'
            vec=varargin{i+1};
        case 'change_dir_on'
            change_dir_on=varargin{i+1};
        otherwise
            error('Unrecognized option: %s\n',varargin{i});
    end
end

% For every position the cross is stored incrementally:
%   i_left{i}, r_left{i}   - mode index and parent row of the left set
%   i_right{i}, r_right{i} - mode index and parent row of the right set
i_left=cell(d,1);
i_right=cell(d,1);
r_left=cell(d,1);
r_right=cell(d,1);

if ( isempty(x0) )
    % Rank-one start: one multi-index found by a coordinate search.
    ry=ones(d+1,1);
    % Start from index 2 where it exists; clamp for modes of size 1.
    ind=min(2*ones(d,1),n);
    ind=find_fiber_maximum(elem_fun,n,ind);
    for i=1:d
        i_left{i}=ind(i);
        i_right{i}=ind(i);
        r_left{i}=1;
        r_right{i}=1;
    end
else
    % Derive the index sets from the supplied approximation.
    ry=x0.r;
    ps=x0.ps;
    cr=x0.core;

    % Left-to-right pass: QR of each unfolded core, then maxvol2 picks the
    % rows that become the left index set of that position.
    rm=1;
    for i=1:d-1
        cur_core=cr(ps(i):ps(i+1)-1);
        cur_core=reshape(cur_core,[ry(i),n(i)*ry(i+1)]);
        cur_core=rm*cur_core;
        cur_core=reshape(cur_core,[ry(i)*n(i),ry(i+1)]);
        [cur_core,rm]=qr(cur_core,0);
        ind=maxvol2(cur_core); ind=ind';
        sbm=cur_core(ind,:);
        cur_core=cur_core / sbm;
        rm=sbm*rm;
        cr(ps(i):ps(i+1)-1)=cur_core(:);
        [radd,iadd]=ind2sub([ry(i);n(i)],ind);
        i_left{i}=iadd;
        r_left{i}=radd;
    end
    cur_core=cr(ps(d):ps(d+1)-1);
    cur_core=reshape(cur_core,[ry(d),n(d)*ry(d+1)]);
    cur_core=rm*cur_core;
    cr(ps(d):ps(d+1)-1)=cur_core(:);

    % Right-to-left pass: same procedure on the transposed unfoldings gives
    % the right index sets.
    rm=1;
    for i=d:-1:2
        cur_core=cr(ps(i):ps(i+1)-1);
        cur_core=reshape(cur_core,[ry(i)*n(i),ry(i+1)]);
        cur_core=cur_core*rm;
        cur_core=reshape(cur_core,[ry(i),n(i)*ry(i+1)]);
        cur_core=cur_core.';
        [cur_core,rm]=qr(cur_core,0);
        ind=maxvol2(cur_core); ind=ind';
        sbm=cur_core(ind,:);
        cur_core=cur_core/sbm;
        rm=sbm*rm;
        cur_core=cur_core.';
        cr(ps(i):ps(i+1)-1)=cur_core;
        rm=rm.';
        [iadd,radd]=ind2sub([n(i);ry(i+1)],ind);
        i_right{i}=iadd;
        r_right{i}=radd;
    end
    cur_core=cr(ps(1):ps(2)-1);
    cur_core=reshape(cur_core,[ry(1)*n(1),ry(2)]);
    cur_core=cur_core*rm;
    cr(ps(1):ps(2)-1)=cur_core(:);
    x0.core=cr(:);
end

% Expand the incremental storage into full multi-indices.
ind_left=get_multi_left(i_left,r_left,ry);
ind_right=get_multi_right(i_right,r_right,ry);

% Super core i holds the values on
% (left set of i-1) x mode i x mode i+1 x (right set of i+2).
super_core=cell(d-1,1);
for i=1:d-1
    super_core{i}=compute_supercore(i,elem_fun,d,n,ry,ind_left,ind_right,vec);
end

restart=true;
while ( restart )
    swp=1;
    dir='lr';
    i=1;
    converged=false;
    rank_increase=false;
    while ( swp <= nswp && ~converged)
        cur_core=super_core{i};
        cur_core=reshape(cur_core,[ry(i)*n(i),n(i+1)*ry(i+2)]);

        % Positions of the current cross inside the unfolded super core.
        i1=r_left{i}+(i_left{i}-1)*ry(i);
        i2=i_right{i+1}+(r_right{i+1}-1)*n(i+1);

        [iadd1,iadd2]=enlarge_cross(cur_core,i1,i2,eps);
        if ( ~isempty(iadd1) )
            rank_increase=true;

            [radd_left,iadd_left]=ind2sub([ry(i);n(i)],iadd1);
            [iadd_right,radd_right]=ind2sub([n(i+1);ry(i+2)],iadd2);
            i_left{i}=[i_left{i};iadd_left];
            r_left{i}=[r_left{i};radd_left];
            i_right{i+1}=[i_right{i+1};iadd_right];
            r_right{i+1}=[r_right{i+1};radd_right];
            rold=ry(i+1);
            radd=numel(iadd1);
            ry(i+1)=ry(i+1)+radd;

            ind_left=get_multi_left(i_left,r_left,ry);
            ind_right=get_multi_right(i_right,r_right,ry);

            % Only the newly added indices need fresh evaluations: rtmp
            % describes a bond of size radd, and the index sets passed on
            % are stripped of their first rold (already evaluated) rows.
            rtmp=ry; rtmp(i+1)=radd;

            if ( i > 1 )
                % Left neighbour grows along its last dimension.
                tmp2=ind_right;
                tmp2{i+1}(1:rold,:)=[];
                core_add=compute_supercore(i-1,elem_fun,d,n,rtmp,ind_left,tmp2,vec);
                new_core=zeros(ry(i-1),n(i-1),n(i),ry(i+1));
                new_core(:,:,:,1:rold)=super_core{i-1};
                new_core(:,:,:,rold+1:end)=core_add;
                super_core{i-1}=new_core;
            end
            if ( i < d - 1)
                % Right neighbour grows along its first dimension.
                tmp1=ind_left;
                tmp1{i}(1:rold,:)=[];
                core_add=compute_supercore(i+1,elem_fun,d,n,rtmp,tmp1,ind_right,vec);
                % Super core i+1 spans modes i+1 and i+2.
                new_core=zeros(ry(i+1),n(i+1),n(i+2),ry(i+3));
                if ( rold > 0 )
                    new_core(1:rold,:,:,:)=super_core{i+1};
                end
                new_core(rold+1:end,:,:,:)=core_add;
                super_core{i+1}=new_core;
            end
        end

        if ( i ~= 1 && i ~= d - 1 && change_dir_on)
            % Skeleton approximations of both neighbouring super cores on
            % their current crosses; their errors steer the sweep direction.
            sm=super_core{i-1}; sp=super_core{i+1};
            i1m=r_left{i-1}+(i_left{i-1}-1)*ry(i-1);
            i2m=i_right{i}+(r_right{i}-1)*n(i);
            i1p=r_left{i+1}+(i_left{i+1}-1)*ry(i+1);
            i2p=i_right{i+2}+(r_right{i+2}-1)*n(i+2);

            sm=reshape(sm,[ry(i-1)*n(i-1),n(i)*ry(i+1)]);
            sp=reshape(sp,[ry(i+1)*n(i+1),n(i+2)*ry(i+3)]);
            smb=sm(:,i2m); [smb,~]=qr(smb,0);
            sub_m=smb(i1m,:);
            sm_appr=(smb / sub_m) * sm(i1m,:);

            spb=sp(:,i2p); [spb,~]=qr(spb,0);
            sub_p=spb(i1p,:);
            sp_appr=(spb / sub_p) * sp(i1p,:);
        end

        if ( verb >= 1 )
            fprintf('Step %d sweep %d ry(%d): %d -> %d rank=%3.1f \n',i,swp,i+1,ry(i+1)-numel(iadd1),ry(i+1),mean(ry));
        end
        if ( i ~= 1 && i ~= d-1 && change_dir_on )
            er1=norm(sm-sm_appr,'fro')/norm(sm,'fro');
            er2=norm(sp-sp_appr,'fro')/norm(sp,'fro');
        else
            er1=0; er2=0;
        end

        % Choose the next position; optionally reverse towards the worse
        % neighbour.
        if ( (strcmp(dir,'lr') && er1 > er2) && er1 > eps && change_dir_on )
            dir='rl';
            if ( verb >= 1 )
                fprintf('i=%d CHANGED DIRECTION LR->RL er1=%3.2e, er2=%3.2e\n',i,er1,er2);
            end
            i=i-1;
            rank_increase=true;
        elseif ( strcmp(dir,'lr') )
            i=i+1;
        elseif ( (strcmp(dir,'rl') && er2 > er1 ) && er2 > eps && change_dir_on)
            dir='lr';
            i=i+1;
            if ( verb >= 1 )
                fprintf('i=%d CHANGED DIRECTION RL->LR er1=%3.2e, er2=%3.2e\n',i,er1,er2);
            end
            rank_increase=true;
        elseif ( strcmp(dir,'rl') )
            i=i-1;
        end

        % Turn around at the ends. A sweep is counted when position 1 is
        % reached going right-to-left; a sweep without any rank increase
        % means convergence.
        if ( strcmp(dir,'lr') )
            if ( i == d-1 )
                dir='rl';
            end
        else
            if ( i == 1 )
                dir ='lr';
                swp=swp+1;
                if ( rank_increase )
                    rank_increase=false;
                else
                    converged=true;
                end
            end
        end
    end
    if ( verb >= 1 )
        fprintf('Converged in %d sweeps \n',swp);
    end

    % Assemble the TT cores from the super cores.
    ps=cumsum([1;n.*ry(1:d).*ry(2:d+1)]);
    cr=zeros(ps(d+1)-1,1);

    for i=1:d-1
        score=super_core{i};
        score=reshape(score,[ry(i)*n(i),n(i+1)*ry(i+2)]);
        i2=i_right{i+1}+(r_right{i+1}-1)*n(i+1);
        i1=r_left{i}+(i_left{i}-1)*ry(i);

        % Grow a sub-cross taken from the SVD of the cross submatrix until
        % the skeleton approximation of the super core reaches eps. Starting
        % from er=inf guarantees at least one pass, so iu and w1 are always
        % defined.
        r=0;
        er=inf;
        sbm=score(i1,i2);
        [u,s,v]=svd(sbm,'econ'); s=diag(s); %#ok<NASGU>
        maxr=max(size(sbm));
        while ( r < maxr && er > eps )
            r=r+1;
            u0=u(:,1:r); v0=v(:,1:r);
            iu=maxvol2(u0);
            iv=maxvol2(v0);
            iu1=i1(iu); iv1=i2(iv);
            w1=score(:,iv1); [w1,~]=qr(w1,0);
            w1=w1 / w1(iu1,:);
            appr=w1*score(iu1,:);
            er=norm(score-appr,'fro')/norm(score,'fro');
        end

        % Columns not selected stay zero.
        w2=zeros(ry(i)*n(i),ry(i+1));
        w2(:,iu)=w1;
        cr(ps(i):ps(i+1)-1)=w2(:);
    end

    % Last core: rows of the last super core on the left cross of d-1.
    score=super_core{d-1};
    score=reshape(score,[ry(d-1)*n(d-1),n(d)*ry(d+1)]);
    i1=r_left{d-1}+(i_left{d-1}-1)*ry(d-1);
    score=score(i1,:);
    score=reshape(score,[ry(d),n(d),ry(d+1)]);
    cr(ps(d):ps(d+1)-1)=score(:);

    y=tt_tensor;
    y.core=cr;
    y.ps=ps;
    y.r=ry;
    y.n=n;
    y.d=d;

    % Probe the residual: coordinate search for a large |elem_fun - y|.
    tmp_fun = @(ind) elem_fun(ind) - y(ind);
    ind=ones(1,d);
    for s=1:5
        [ind]=find_fiber_maximum(tmp_fun,n,ind);
    end
    % The search maximises the magnitude, so compare magnitudes as well;
    % a signed comparison misjudges negative or complex entries.
    if ( abs(tmp_fun(ind)) > 10*eps*abs(elem_fun(ind)) )
        restart=true;

        % Append the offending multi-index to every cross (all internal
        % ranks grow by one) and recompute the super cores.
        for i=1:d-1
            i_left{i}=[i_left{i};ind(i)];
            if ( i > 1 )
                r_left{i}=[r_left{i};ry(i)+1];
            else
                r_left{i}=[r_left{i};1];
            end
        end
        for i=2:d
            i_right{i}=[i_right{i};ind(i)];
            if ( i < d )
                r_right{i}=[r_right{i}; ry(i+1)+1];
            else
                r_right{i}=[r_right{i}; 1];
            end
        end
        ry(2:d)=ry(2:d)+1;

        ind_left=get_multi_left(i_left,r_left,ry);
        ind_right=get_multi_right(i_right,r_right,ry);
        for i=1:d-1
            super_core{i}=compute_supercore(i,elem_fun,d,n,ry,ind_left,ind_right,vec);
        end
    else
        restart=false;
    end
end
return
end
0460
function [super_core]=compute_supercore(i,elem_fun,d,n,ry,ind_left,ind_right,vec)
%COMPUTE_SUPERCORE  Evaluate elem_fun on the index grid of super core i.
%   The result has size [ry(i), n(i), n(i+1), ry(i+2)]; entry (s1,i1,i2,s2)
%   is elem_fun at the multi-index [ind_left{i-1}(s1,:), i1, i2,
%   ind_right{i+2}(s2,:)]. For i == 1 the left part is empty and for
%   i == d-1 the right part is empty.
%   When vec is true, elem_fun is called once with a d-by-M index matrix;
%   otherwise it is called once per column.

% Boundary super cores have no left / right index part.
if ( i == 1 )
    left_part=zeros(1,0);
else
    left_part=ind_left{i-1};
end
if ( i == d - 1 )
    right_part=zeros(1,0);
else
    right_part=ind_right{i+2};
end

out_size=[ry(i),n(i),n(i+1),ry(i+2)];
total=ry(i)*n(i)*n(i+1)*ry(i+2);

% Fill the columns in the order (s1,i1,i2,s2) with s1 running fastest,
% i.e. the column-major order of the output array.
index_set=zeros(d,total);
col=0;
for s2=1:ry(i+2)
    for i2=1:n(i+1)
        for i1=1:n(i)
            for s1=1:ry(i)
                col=col+1;
                index_set(:,col)=[left_part(s1,:),i1,i2,right_part(s2,:)];
            end
        end
    end
end

if ( vec )
    vals=elem_fun(index_set);
else
    vals=zeros(total,1);
    for k=1:total
        vals(k)=elem_fun(index_set(:,k));
    end
end
super_core=reshape(vals,out_size);
return
end
0498
function [iadd1,iadd2]=enlarge_cross(mat,i1,i2,eps)
%ENLARGE_CROSS  Find rows/columns to add to a cross of a matrix.
%   [IADD1,IADD2] = ENLARGE_CROSS(MAT,I1,I2,EPS) starts from the cross with
%   row indices I1 and column indices I2, forms the skeleton approximation
%   of MAT on a well-conditioned part of that cross and, while the
%   Frobenius norm of the residual exceeds EPS*norm(MAT,'fro'), picks the
%   entry of largest magnitude of the residual as a new pivot.
%
%   IADD1, IADD2 are column vectors with the row and column numbers of the
%   selected pivots (both empty when the current cross is accurate enough).
%
%   Relies on the external helpers my_chop2 and maxvol2 (presumably rank
%   truncation and maximal-volume row selection - confirm against their
%   definitions).

magic_eps=1e-12;

% Drop numerically dependent rows/columns of the cross before using it.
sbm=mat(i1,i2);
[u,s,v]=svd(sbm,'econ');
s=diag(s); rm=my_chop2(s,min(magic_eps,eps)*norm(s));
u=u(:,1:rm); v=v(:,1:rm);
indu=maxvol2(u); indv=maxvol2(v);
i1=i1(indu); i2=i2(indv);

cbas=mat(:,i2);
rbas=mat(i1,:);

% Orthogonalise the column basis so the interpolation step is stable.
[cbas,~]=qr(cbas,0);
qsbm=cbas(i1,:);
appr=cbas / qsbm * rbas;

iadd1=[];
iadd2=[];

desirata=eps*norm(mat(:),'fro');
mat=mat-appr;
er=norm(mat(:),'fro');

% Every elimination step removes one rank from the residual, so more than
% min(size(mat)) pivots can only chase round-off; bound the loop to
% guarantee termination for very small eps.
max_add=min(size(mat));
while ( er > desirata && numel(iadd1) < max_add )
    [piv,pos]=max(abs(mat(:)));
    if ( piv == 0 )
        break;  % nothing left to eliminate; avoid dividing by zero
    end
    [i,j]=ind2sub(size(mat),pos);
    iadd1=[iadd1;i]; %#ok<AGROW>
    iadd2=[iadd2;j]; %#ok<AGROW>
    u1=mat(:,j)/mat(i,j); u2=mat(i,:);
    mat=mat-u1*u2;
    er=norm(mat(:),'fro');
end
return
end
0586
function [ind_left]=get_multi_left(i_left,r_left,ry)
%GET_MULTI_LEFT  Expand incrementally stored left index sets.
%   ind_left{k} is a ry(k+1)-by-k array; row s is row r_left{k}(s) of
%   ind_left{k-1} followed by the mode index i_left{k}(s). The first set is
%   i_left{1} itself and the last cell is left empty.

ncell=numel(i_left);
ind_left=cell(ncell,1);
ind_left{1}=i_left{1};
k=2;
while ( k <= ncell-1 )
    nrows=ry(k+1);
    parent_rows=r_left{k};
    parent_set=ind_left{k-1};
    mode_idx=i_left{k};
    expanded=zeros(nrows,k);
    for row=1:nrows
        prefix=parent_set(parent_rows(row),:);
        expanded(row,:)=[prefix,mode_idx(row)];
    end
    ind_left{k}=expanded;
    k=k+1;
end
return
end
0608
function [ind_right]=get_multi_right(i_right,r_right,ry)
%GET_MULTI_RIGHT  Expand incrementally stored right index sets.
%   ind_right{k} is a ry(k)-by-(d-k+1) array; row s is the mode index
%   i_right{k}(s) followed by row r_right{k}(s) of ind_right{k+1}. The last
%   set is i_right{d} itself and the first cell is left empty.

ncell=numel(i_right);
ind_right=cell(ncell,1);
ind_right{ncell}=i_right{ncell};
k=ncell-1;
while ( k >= 2 )
    nrows=ry(k);
    parent_rows=r_right{k};
    parent_set=ind_right{k+1};
    mode_idx=i_right{k};
    expanded=zeros(nrows,ncell-k+1);
    for row=1:nrows
        suffix=parent_set(parent_rows(row),:);
        expanded(row,:)=[mode_idx(row),suffix];
    end
    ind_right{k}=expanded;
    k=k-1;
end
return
end
0632
function [ind]=find_fiber_maximum(elem_fun,n,ind)
%FIND_FIBER_MAXIMUM  Coordinate search for an entry of large magnitude.
%   Starting from the multi-index ind, two passes are made over all
%   dimensions. In each dimension every index 1..n(k) is tried with the
%   other indices frozen at their values from the start of that scan, and
%   ind(k) moves to any candidate whose |elem_fun| is at least the best
%   magnitude seen so far (ties go to the later candidate).

ndim=numel(n);
best=abs(elem_fun(ind));
for pass=1:2
    for k=1:ndim
        probe=ind;  % snapshot taken before scanning dimension k
        for cand=1:n(k)
            probe(k)=cand;
            cur=abs(elem_fun(probe));
            if ( cur >= best )
                best=cur;
                ind(k)=cand;
            end
        end
    end
end
return
end