function [y,swp]=tt_mvk2(a,x,eps, max_r, max_swp)
%TT_MVK2 Approximate matrix-by-vector product on cell-array cores.
%   [Y,SWP] = TT_MVK2(A,X,EPS,[MAX_R],[MAX_SWP]) builds Y so that it
%   approximates the product of the operator stored in A with the vector
%   stored in X. Y starts from a random guess and is refined in sweeps
%   that update two neighbouring cores at a time through a truncated SVD.
%
%   Inputs (layouts are the ones this function reads via size/reshape):
%     A       - d-by-1 cell array, d >= 3. A{1} and A{d} are n x m x ra,
%               A{i} for 1<i<d is n x m x ra_left x ra_right.
%     X       - cell array with d cores. X{1} and X{d} are m x rx,
%               X{i} for 1<i<d is m x rx_left x rx_right.
%     EPS     - relative accuracy used in all truncation thresholds.
%               (The name shadows the builtin eps inside this function.)
%     MAX_R   - optional cap on the rank picked by the truncation. The
%               random enrichment is appended after the cap is applied.
%     MAX_SWP - optional sweep limit (default 15). The loop runs while
%               swp < MAX_SWP.
%
%   Outputs:
%     Y       - cell array of cores laid out like X (Y{1}, Y{d} matrices,
%               interior cores n x r_left x r_right).
%     SWP     - value of the sweep counter on exit.
%
%   Uses the helpers tt_random, tt_size, ten_conv and my_chop2, which are
%   defined outside this file.

% Remember whether the caller supplied a rank cap.
exists_max_r=1;
if ((nargin<4)||(isempty(max_r)))
    exists_max_r=0;
end;

% Sweep limit, optionally overridden by the caller.
nswp=15;
if ((nargin>=5)&&(~isempty(max_swp)))
    nswp = max_swp;
end;

kickrank = 5;      % number of random columns appended to the left factor
dropsweeps = 1;    % period used in mod(swp,dropsweeps) for the residual branch
ddpow = 0.1;       % step by which dpows(i) is raised / lowered
ddrank = 1;        % step by which drank(i) is raised / lowered
d_pow_check = 0;   % stopping threshold is eps/(d^d_pow_check)
bot_conv = 0.1;    % dy/dy_old below this counts as fast convergence
top_conv = 0.99;   % dy/dy_old above this counts as stagnation
verb = 2;          % >0: per-sweep report, >1: per-bond rank report

d=size(a,1);

% The sweep addresses ph{3}, a{2} and ph{d-2}, so fewer than three cores
% can never work. Fail early with a clear message instead of an index error.
if (d<3)
    error('tt_mvk2:tooFewCores', ...
        'tt_mvk2: at least 3 cores are required, got %d', d);
end;

% Initial guess. Presumably a random TT with all ranks equal to 2 --
% confirm against tt_random / tt_size.
y=tt_random(tt_size(a),d,2);
% ph{i} holds the interface tensors built from y, a and x (see below).
ph=cell(d,1);

swp=1;

% Logs of the norms that are scaled out of the interfaces and of y.
log_norms_ph = zeros(d,1);
log_norms_y = zeros(d,1);
last_sweep = false;

% Local relative change of each two-core block, current and previous sweep.
dy_old = ones(d-1,1);
dy = zeros(d-1,1);

% Per-bond extra rank added on top of the truncation rank.
drank = ones(d-1,1);

% Per-bond exponent p in the truncation threshold eps/(d^p).
dpows = ones(d-1,1);

% Largest local change seen in the most recent sweep.
err_max = 0;

while ( swp < nswp)

    % Scale every core of y to unit Frobenius norm, keeping the log-norms.
    for q=1:d
        log_norms_y(q)=log(norm(reshape(y{q}, size(y{q},1)*size(y{q},2), size(y{q},3)), 'fro')+1e-308);
        y{q}=y{q}./exp(log_norms_y(q));
    end;

    % ---- Backward pass: orthogonalise y from core d down to core 2 and
    % ---- build the interfaces ph{d}, ..., ph{2}.

    % Last core: economy QR, keep the R factor to push into core d-1.
    mat=y{d};
    [q,rv]=qr(mat,0);
    log_norms_y(d)=log(norm(q,'fro')+1e-308);
    y{d}=q;
    y{d}=y{d}./exp(log_norms_y(d));

    % ph{d}: contract a{d} with x{d} over m, then with y{d} over n.
    core=a{d}; n=size(core,1); m=size(core,2); ra=size(core,3);
    mat_x=x{d}; rx=size(mat_x,2);
    mat_y=y{d}; ry=size(mat_y,2);
    core=permute(core,[1,3,2]);
    core=reshape(core,[n*ra,m]);
    core=core*mat_x;
    core=reshape(core,[n,ra,rx]);
    core=permute(core,[2,3,1]);
    core=reshape(core,[ra*rx,n]);
    ph{d}=core*mat_y;
    log_norms_ph(d) = log(norm(ph{d}, 'fro')+1e-308);
    ph{d}=ph{d}./exp(log_norms_ph(d));
    log_norms_ph(d) = log_norms_ph(d) + log_norms_y(d);
    ph{d}=reshape(ph{d},[ra,rx,ry]);
    ph{d}=permute(ph{d},[2,1,3]);   % stored as rx x ra x ry

    for i=(d-1):-1:2
        % Absorb the R factor of the core to the right, then QR this core.
        % Presumably ten_conv contracts mode 3 of core with rv' -- confirm.
        core=y{i};
        core=ten_conv(core,3,rv');
        ncur=size(core,1);
        r2=size(core,2);
        r3=size(core,3);
        core=permute(core,[1,3,2]);
        core=reshape(core,[ncur*r3,r2]);
        [y{i},rv]=qr(core,0);
        log_norms_y(i)=log(norm(y{i}, 'fro')+1e-308);
        y{i}=y{i}./exp(log_norms_y(i));
        rnew=min(r2,ncur*r3);
        y{i}=reshape(y{i},[ncur,r3,rnew]);
        y{i}=permute(y{i},[1,3,2]);

        % ph{i}: extend ph{i+1} by x{i}, then a{i}, then y{i}.
        a1=a{i}; n=size(a1,1); m=size(a1,2); ra1=size(a1,3); ra2=size(a1,4);
        x1=x{i};
        rx1=size(x1,2); rx2=size(x1,3);
        y1=y{i};
        ry1=size(y1,2); ry2=size(y1,3);
        phi=ph{i+1};
        x1=reshape(x1,[m*rx1,rx2]);
        phi=reshape(phi,[rx2,ra2*ry2]);
        phi=x1*phi;
        phi=reshape(phi,[m,rx1,ra2,ry2]);

        phi=permute(phi,[1,3,2,4]);
        phi=reshape(phi,[m*ra2,rx1*ry2]);
        a1=permute(a1,[1,3,2,4]);
        a1=reshape(a1,[n*ra1,m*ra2]);
        phi=a1*phi;

        phi=reshape(phi,[n,ra1,rx1,ry2]);
        phi=permute(phi,[1,4,2,3]);
        phi=reshape(phi,[n*ry2,ra1*rx1]);
        y1=permute(y1,[1,3,2]);
        y1=reshape(y1,[n*ry2,ry1]);
        phi=y1'*phi;
        % Scale out the norm and accumulate it together with the norms
        % already removed to the right and the norm removed from y{i}.
        log_norms_ph(i)=log(norm(phi, 'fro')+1e-308);
        phi = phi./exp(log_norms_ph(i));
        log_norms_ph(i)=log_norms_ph(i)+log_norms_ph(i+1)+log_norms_y(i);

        phi=reshape(phi,[ry1,ra1,rx1]);
        phi=permute(phi,[3,2,1]);   % stored as rx1 x ra1 x ry1
        ph{i}=phi;

    end

    % ---- Forward pass, first block: cores 1 and 2.

    core1=a{1}; n1=size(core1,1); m1=size(core1,2); ra1=size(core1,3);
    core2=a{2}; n2=size(core2,1); m2=size(core2,2); ra2=size(core2,4);
    x1=x{1}; rx1=size(x1,2);
    x2=x{2}; rx2=size(x2,3);
    phi2=ph{3};
    ry2=size(phi2,3);

    % Right half: x{2} and a{2} contracted with the interface ph{3}.
    phi2=reshape(x2,[m2*rx1,rx2])*reshape(phi2,[rx2,ra2*ry2]);
    phi2=reshape(phi2,[m2,rx1,ra2,ry2]);
    phi2=permute(phi2,[1,3,2,4]);
    core2=permute(core2,[1,3,2,4]);
    phi2=reshape(core2,[n2*ra1,m2*ra2])*reshape(phi2,[m2*ra2,rx1*ry2]);
    phi2=reshape(phi2,[n2,ra1,rx1,ry2]);
    phi2=permute(phi2,[2,3,1,4]);
    phi2=reshape(phi2,[ra1*rx1,n2*ry2]);

    % Left half: a{1} contracted with x{1} over m1.
    core1=permute(core1,[1,3,2]);
    x1=reshape(core1,[n1*ra1,m1])*x1;

    % Local two-core block, n1 x (n2*ry2).
    mat=reshape(x1,[n1,ra1*rx1])*phi2;
    mat=reshape(mat,[n1,n2*ry2]);

    % Truncated SVD. Presumably my_chop2 returns the smallest rank whose
    % discarded tail fits the given absolute tolerance -- confirm.
    [u,s,v]=svd(mat,0);
    s=diag(s);
    nrm=norm(s);
    r=my_chop2(s,eps/sqrt(d)*nrm);

    if (exists_max_r) r = min(r, max_r); end;
    if ( verb>1 )
        fprintf('We can push rank %d to %d \n',1,r);
    end
    u=u(:,1:r);
    % NOTE(review): the interior update below uses conj(v) at this step;
    % the two agree for real data -- confirm the intent for complex input.
    v=v(:,1:r)*diag(s(1:r));

    % Enrich with kickrank random columns (zero weight in v), then
    % re-orthogonalise u and move the R factor into v.
    u = [u, randn(size(u,1),kickrank)];
    v = [v, zeros(size(v,1),kickrank)];
    [u,rv]=qr(u,0);
    r = size(u,2);
    v = v*(rv.');
    y{1}=u;
    log_norms_y(1)=log(norm(u, 'fro')+1e-308);
    y{1}=y{1}./exp(log_norms_y(1));
    log_norms_y(2)=log(norm(v, 'fro')+1e-308);
    y{2}=reshape(v,[n2,ry2,r]);
    y{2}=permute(y{2},[1,3,2]);
    y{2}=y{2}./exp(log_norms_y(2));
    % mat was built from the normalised ph{3}; add its scale back.
    log_norms_y(2)=log_norms_y(2)+log_norms_ph(3);

    % Left interface ph{1}: y{1}' applied to a{1} contracted with x{1}.
    core=a{1}; n=size(core,1); m=size(core,2); ra=size(core,3);
    x1=x{1}; rx=size(x1,2);
    y1=y{1}; ry=size(y1,2);
    core=permute(core,[1,3,2]); core=reshape(core,[n*ra,m]);
    phi1=core*x1;
    phi1=y1'*reshape(phi1,[n,ra*rx]);
    log_norms_ph(1)=log(norm(phi1, 'fro')+1e-308);
    phi1 = phi1./exp(log_norms_ph(1));
    log_norms_ph(1) = log_norms_ph(1) + log_norms_y(1);
    ph{1}=reshape(phi1,[ry,ra,rx]);

    % ---- Forward pass, interior blocks: cores i and i+1.
    for i=2:d-2

        core1=a{i}; n1=size(core1,1); m1=size(core1,2); ra1=size(core1,3); ra2=size(core1,4);
        core2=a{i+1}; n2=size(core2,1); m2=size(core2,2); ra3=size(core2,4);
        x1=x{i}; rx1=size(x1,2); rx2=size(x1,3);
        x2=x{i+1}; rx3=size(x2,3);
        ph1=ph{i-1}; ry1=size(ph1,1);
        ph2=ph{i+2}; ry3=size(ph2,3);

        % Left half: ph{i-1} extended by x{i} and a{i}.
        x1=permute(x1,[2,1,3]); x1=reshape(x1,[rx1,m1*rx2]);
        ph1=reshape(ph1,[ry1*ra1,rx1]);
        ph1=ph1*x1;

        ph1=reshape(ph1,[ry1,ra1,m1,rx2]);
        ph1=permute(ph1,[4,1,3,2]);
        ph1=reshape(ph1,[rx2*ry1,m1*ra1]);
        core1=permute(core1,[2,3,4,1]);
        core1=reshape(core1,[m1*ra1,ra2*n1]);
        ph1=ph1*core1;
        ph1=reshape(ph1,[rx2,ry1,ra2,n1]);

        % Kept to build ph{i} once the new y{i} is known.
        ph_save=ph1;

        % Right half: ph{i+2} extended by x{i+1} and a{i+1}.
        ph2=reshape(x2,[m2*rx2,rx3])*reshape(ph2,[rx3,ra3*ry3]);

        ph2=reshape(ph2,[m2,rx2,ra3,ry3]);

        core2=permute(core2,[1,3,2,4]); core2=reshape(core2,[n2*ra2,m2*ra3]);
        ph2=permute(ph2,[1,3,2,4]); ph2=reshape(ph2,[m2*ra3,rx2*ry3]);
        ph2=core2*ph2;

        ph2=reshape(ph2,[n2,ra2,rx2,ry3]);

        % Local two-core block, (ry1*n1) x (n2*ry3).
        ph1=permute(ph1,[2,4,3,1]);
        ph1=reshape(ph1,[ry1*n1,ra2*rx2]);
        ph2=permute(ph2,[1,4,2,3]);
        ph2=reshape(ph2,[n2*ry3,ra2*rx2]);
        mat=ph1*ph2';

        % Current approximation of the same block from y{i}, y{i+1},
        % rescaled to the normalisation in which mat is expressed.
        y1=y{i};
        y2=y{i+1};
        ry2=size(y1,3);
        y1=permute(y1,[2,1,3]);
        y2=permute(y2,[2,1,3]);
        y1=reshape(y1,[ry1*n1,ry2]);
        y2=reshape(y2,[ry2,n2*ry3]);
        y2 = y2*exp(log_norms_y(i)+log_norms_y(i+1)-log_norms_ph(i-1)-log_norms_ph(i+2));
        app=y1*y2;
        % Relative local change of this block.
        dy(i)=norm(mat-app,'fro')/norm(mat,'fro');
        if (swp==1)
            dy_old(i)=dy(i);
        end;

        % Stagnating and not yet accurate: raise extra rank and exponent.
        if (dy(i)/dy_old(i)>top_conv)&&(dy(i)>eps/(d^d_pow_check))
            drank(i)=drank(i)+ddrank;
            dpows(i)=dpows(i)+ddpow;
        end;

        % Converging fast or already accurate: lower them, floor at 1.
        if (dy(i)/dy_old(i)<bot_conv)||(dy(i)<eps/(d^d_pow_check))
            drank(i)=max(drank(i)-ddrank, 1);
            dpows(i)=max(dpows(i)-ddpow, 1);
        end;

        % Final sweep uses the fixed threshold eps/sqrt(d).
        if (last_sweep)
            dpows(i)=0.5;
        end;

        % Decompose the residual on sweeps that are not multiples of
        % dropsweeps, otherwise the block itself. With dropsweeps = 1 the
        % mod is always 0, so the residual branch is not taken.
        if (mod(swp,dropsweeps)~=0)&&(swp>1)&&(~last_sweep)
            [u,s,v]=svd(mat-app,'econ');
        else
            [u,s,v]=svd(mat,'econ');
        end;
        s=diag(s);

        r=my_chop2(s,eps/(d^dpows(i))*norm(mat, 'fro'));
        if (~last_sweep)
            r = r+drank(i);
        end;
        r = min(r, max(size(s)));
        if (exists_max_r) r = min(r, max_r); end;

        if ( verb>1 )
            fprintf('We can push rank %d to %d \n',i,r);
        end
        u=u(:,1:r);
        v=conj(v(:,1:r))*diag(s(1:r));
        if (mod(swp,dropsweeps)~=0)&&(swp>1)&&(~last_sweep)
            % Residual branch: append the correction to the old factors.
            u = [y1, u];
            v = [y2.', v];
            [u,rv]=qr(u,0);
            v = v*(rv.');
            r = size(u,2);
        else
            if (~last_sweep)
                % Enrich with kickrank random columns, then re-orthogonalise.
                u = [u, randn(size(u,1),kickrank)];
                v = [v, zeros(size(v,1),kickrank)];
                [u,rv]=qr(u,0);
                r = size(u,2);
                v = v*(rv.');
            end;
        end;
        % Store the new cores with unit Frobenius norm.
        log_norms_y(i)=log(norm(u, 'fro')+1e-308);
        log_norms_y(i+1)=log(norm(v, 'fro')+1e-308);
        u=reshape(u,[ry1,n1,r]);
        y{i}=permute(u,[2,1,3]);
        y{i}=y{i}./exp(log_norms_y(i));
        v=reshape(v,[n2,ry3,r]);
        y{i+1}=permute(v,[1,3,2]);
        y{i+1}=y{i+1}./exp(log_norms_y(i+1));
        % mat was built from normalised interfaces; add their scales back.
        log_norms_y(i+1)=log_norms_y(i+1)+log_norms_ph(i-1)+log_norms_ph(i+2);

        % New left interface ph{i}: project the saved left half on y{i}.
        ph_save=reshape(ph_save,[rx2,ry1,ra2,n1]);
        ph_save=permute(ph_save,[4,2,3,1]);
        ph_save=reshape(ph_save,[n1*ry1,ra2*rx2]);
        ph_save=reshape(y{i},[n1*ry1,r])'*ph_save;
        log_norms_ph(i)=log(norm(ph_save, 'fro')+1e-308);
        ph_save = ph_save./exp(log_norms_ph(i));
        log_norms_ph(i)=log_norms_ph(i)+log_norms_y(i)+log_norms_ph(i-1);
        ph{i}=reshape(ph_save,[r,ra2,rx2]);
    end

    % ---- Forward pass, last block: cores d-1 and d.

    core1=a{d-1}; n1=size(core1,1); m1=size(core1,2); ra1=size(core1,3); ra2=size(core1,4);
    core2=a{d}; n2=size(core2,1); m2=size(core2,2);
    x1=x{d-1}; rx1=size(x1,2); rx2=size(x1,3);
    x2=x{d};
    ph1=ph{d-2};
    ry1=size(ph1,1);

    % Left half: ph{d-2} extended by x{d-1} and a{d-1}.
    x1=permute(x1,[2,1,3]); x1=reshape(x1,[rx1,m1*rx2]);
    ph1=reshape(ph1,[ry1*ra1,rx1]);
    ph1=ph1*x1;

    ph1=reshape(ph1,[ry1,ra1,m1,rx2]);
    ph1=permute(ph1,[4,1,3,2]);
    ph1=reshape(ph1,[rx2*ry1,m1*ra1]);
    core1=permute(core1,[2,3,4,1]);
    core1=reshape(core1,[m1*ra1,ra2*n1]);
    ph1=ph1*core1;
    ph1=reshape(ph1,[rx2,ry1,ra2,n1]);

    % Right half: a{d} contracted with x{d} over m2.
    core2=permute(core2,[1,3,2]); core2=reshape(core2,[n2*ra2,m2]);
    ph2=core2*x2;
    ph_save=ph2;
    ph2=reshape(ph2,[n2,ra2*rx2]);
    ph1=reshape(ph1,[rx2,ry1,ra2,n1]);
    ph1=permute(ph1,[2,4,3,1]); ph1=reshape(ph1,[ry1*n1,ra2*rx2]);
    % Local two-core block, (ry1*n1) x n2.
    mat=ph1*ph2';
    [u,s,v]=svd(mat,'econ');
    s=diag(s);
    r=my_chop2(s,eps/sqrt(d)*norm(s));

    if (exists_max_r) r = min(r, max_r); end;

    u=u(:,1:r);
    % NOTE(review): the interior update uses conj(v) at this step; the two
    % agree for real data -- confirm the intent for complex input.
    v=v(:,1:r)*diag(s(1:r));

    % Enrich with kickrank random columns, then re-orthogonalise.
    u = [u, randn(size(u,1),kickrank)];
    v = [v, zeros(size(v,1),kickrank)];
    [u,rv]=qr(u,0);
    r = size(u,2);
    v = v*(rv.');

    log_norms_y(d-1)=log(norm(u, 'fro')+1e-308);
    log_norms_y(d)=log(norm(v, 'fro')+1e-308);
    u=reshape(u,[ry1,n1,r]);
    u=permute(u,[2,1,3]);
    y{d-1}=u;
    y{d-1}=y{d-1}./exp(log_norms_y(d-1));
    y{d}=v;
    y{d}=y{d}./exp(log_norms_y(d));
    % mat was built from the normalised ph{d-2}; add its scale back.
    log_norms_y(d-1)=log_norms_y(d-1)+log_norms_ph(d-2);

    % Refresh ph{d} from the new (unnormalised) last factor v.
    ph_save=reshape(ph_save,[n2,ra2,rx2]);
    ph_save=permute(ph_save,[1,3,2]); ph_save=reshape(ph_save,[n2,rx2*ra2]);
    ph_save=ph_save'*v;
    log_norms_ph(d)=log(norm(ph_save, 'fro')+1e-308);
    ph_save = ph_save./exp(log_norms_ph(d));
    ph_save=reshape(ph_save,[rx2,ra2,r]);
    ph{d}=ph_save;

    % Put the scale back: every core is multiplied by the geometric mean
    % of the norms that were removed.
    norm1_y = exp(sum(log_norms_y)/d);
    for i=1:d
        y{i}=y{i}.*norm1_y;
    end;

    % Largest local change of this sweep. Kept in err_max so that the
    % final warning reports the real value instead of a constant zero.
    err_max = max(dy);
    if (verb>0)
        fprintf('tt_mvk2: sweep %d, err_max = %3.3e\n', swp, err_max);
    end;
    if (last_sweep)
        swp=swp+1;
        break;
    end;
    % Converged: run one more sweep in last_sweep mode, then stop.
    if (err_max<=eps/(d^d_pow_check))
        last_sweep=true;
    end;

    dy_old = dy;
    swp=swp+1;
    % Sweep budget nearly exhausted: make the next sweep the final one.
    if (swp==nswp-1)
        last_sweep=true;
    end;
end
if ( swp == nswp )&&(err_max>eps/(d^d_pow_check))
    fprintf('tt_mvk2 warning: error is not fixed for maximal number of sweeps %d, err_max: %3.3e\n', swp, err_max);
end

return
end