function [y,swp]=tt_wround(W, x, eps, varargin)
%TT_WROUND Two-site (DMRG-style) sweeping approximation of a TT tensor.
%   [Y,SWP]=TT_WROUND(W, X, EPS, OPTIONS) builds an approximation Y of X
%   by alternating a right-to-left QR orthogonalization sweep with a
%   left-to-right sweep that re-computes each pair of neighbouring cores
%   from a truncated SVD.
%
%   Inputs
%     W   - optional operator given as a cell array of cores, or [] to
%           work without a weight. Core i is indexed as (n, n, rw1, rw2);
%           the first core is expected as (n, n, rw2) and is reshaped here.
%     X   - d-by-1 cell array of cores. Core i is indexed as
%           (n, r_left, r_right); the first core is expected as (n, r)
%           and is reshaped here to (n, 1, r).
%     EPS - accuracy threshold. When W is empty it drives the singular
%           value truncation; otherwise it is compared with the relative
%           residual of the local weighted system.
%
%   Options (name/value pairs; a pair with an empty value is ignored)
%     'nswp'        - maximal number of sweeps                  [25]
%     'rmax'        - upper bound for the ranks of Y            [1000]
%     'y0'          - initial guess, same format as X           [scaled tt_ones]
%     'verb'        - 0: silent, 1: per-sweep, >1: per-block    [1]
%     'kickrank'    - number of random columns handed to reort  [5]
%     'ddpow'       - step by which dpows(i) is raised/lowered  [0.1]
%     'ddrank'      - step by which drank(i) is raised/lowered  [1]
%     'd_pow_check' - stopping threshold is eps/d^d_pow_check   [0]
%     'bot_conv'    - dy/dy_old below this lowers drank, dpows  [0.1]
%     'top_conv'    - dy/dy_old above this raises drank, dpows  [0.99]
%
%   Outputs
%     Y   - cell array of cores in the same layout as X (first core
%           returned as a matrix again).
%     SWP - index of the last sweep that was executed.
%
%   NOTE(review): tt_ones, tt_size, tt_scal2, tt_dot2, my_chop2 and reort
%   are defined outside this file; remarks about them below are
%   assumptions to be confirmed against their sources.

kickrank = 5;
dropsweeps = 1;

ddpow = 0.1;
ddrank = 1;
d_pow_check = 0;
bot_conv = 0.1;
top_conv = 0.99;
verb = 1;

d = size(x,1);

% The default initial guess is built only after option parsing, so that
% the tt_* calls are skipped when the caller supplies 'y0'.
y = [];

rmax=1000;
nswp=25;

for i=1:2:length(varargin)-1
    if (~isempty(varargin{i+1}))
        switch lower(varargin{i})
            case 'nswp'
                nswp=varargin{i+1};
            case 'rmax'
                % rmax is numeric: it is stored as-is (the original code
                % wrapped it in lower(), which is only meant for names).
                rmax=varargin{i+1};
            case 'y0'
                y=varargin{i+1};
            case 'verb'
                verb=varargin{i+1};
            case 'kickrank'
                kickrank=varargin{i+1};
            case 'ddpow'
                ddpow=varargin{i+1};
            case 'ddrank'
                ddrank=varargin{i+1};
            case 'd_pow_check'
                d_pow_check=varargin{i+1};
            case 'bot_conv'
                bot_conv=varargin{i+1};
            case 'top_conv'
                top_conv=varargin{i+1};
            otherwise
                error('Unrecognized option: %s\n',varargin{i});
        end
    end
end

if (isempty(y))
    % Default start: tt_ones rescaled via tt_scal2/tt_dot2
    % (presumably a normalisation - confirm against those helpers).
    y = tt_ones(d, tt_size(x));
    y = tt_scal2(y, -tt_dot2(y,y)/2, 1);
end

% Give the first cores an explicit singleton left-rank dimension so that
% every core can be indexed uniformly as (n, r_left, r_right).
x{1}=reshape(x{1}, size(x{1},1),1,size(x{1},2));
y{1}=reshape(y{1}, size(y{1},1),1,size(y{1},2));
if (~isempty(W))
    W{1}=reshape(W{1}, size(W{1},1), size(W{1},2), 1,size(W{1},3));
end

% Interface ("phi") arrays; entries 1 and d+1 are the trivial boundaries.
if (~isempty(W))
    phiywy = cell(d+1,1); phiywy{1}=1; phiywy{d+1}=1;   % y' * W * y
    phiywx = cell(d+1,1); phiywx{1}=1; phiywx{d+1}=1;   % y' * W * x
end
phiyx = cell(d+1,1); phiyx{1}=1; phiyx{d+1}=1;          % y' * x

last_sweep = false;
dy_old = ones(d-1,1);
dy = zeros(d-1,1);

% Per-block rank surplus added on top of the truncation rank.
drank = ones(d-1,1);
% Per-block exponent: the local tolerance is eps/d^dpows(i).
dpows = ones(d-1,1);

for swp=1:nswp
    % This predicate only depends on the sweep, not on the block, so it
    % is evaluated once per sweep (last_sweep changes only at sweep end).
    incr_update = (mod(swp,dropsweeps)~=0)&&(swp>1)&&(~last_sweep);

    % ---- Right-to-left sweep: QR-orthogonalize cores d..2 and rebuild
    % the right interfaces phi{i}.
    for i=d:-1:2
        y1 = y{i}; n1 = size(y1,1); r1 = size(y1,2); r2 = size(y1,3);
        y1 = permute(y1, [1 3 2]);
        y1 = reshape(y1, n1*r2, r1);
        [y1,rv]=qr(y1,0);
        % Push the triangular factor into the left neighbour.
        y2 = y{i-1}; n2 = size(y2,1); r0 = size(y2,2);
        y2 = reshape(y2, n2*r0, r1);
        y2 = y2*(rv.');
        r1 = size(y1,2);
        y1 = reshape(y1, n1, r2, r1);
        y1 = permute(y1, [1 3 2]);
        y{i}=y1;
        y{i-1}=reshape(y2, n2,r0,r1);

        % phiyx{i} = y{i}' * (x{i} contracted with phiyx{i+1})
        y1 = permute(y1, [3 1 2]);
        y1 = reshape(y1, r2*n1, r1);
        x1 = x{i}; rx1 = size(x1,2); rx2 = size(x1,3);
        x1 = reshape(x1, n1*rx1, rx2);
        phiyx{i} = phiyx{i+1}*(x1.');
        phiyx{i} = reshape(phiyx{i}, r2*n1, rx1);
        phiyx{i}=(y1')*phiyx{i};

        if (~isempty(W))
            % Same contraction with the operator core inserted.
            w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
            phiywx{i}=reshape(phiywx{i+1}, r2*rw2, rx2)*(x1.');
            phiywx{i}=reshape(phiywx{i}, r2, rw2, n1, rx1);
            phiywx{i}=permute(phiywx{i}, [3 2 4 1]);
            phiywx{i}=reshape(phiywx{i}, n1*rw2, rx1*r2);
            w1 = permute(w1, [1 3 2 4]);
            w1 = reshape(w1, n1*rw1, n1*rw2);
            phiywx{i}=w1*phiywx{i};
            phiywx{i}=reshape(phiywx{i}, n1, rw1, rx1, r2);
            phiywx{i}=permute(phiywx{i}, [4 1 2 3]);
            phiywx{i}=reshape(phiywx{i}, r2*n1, rw1*rx1);
            phiywx{i}=(y1')*phiywx{i};
            phiywx{i}=reshape(phiywx{i}, r1, rw1, rx1);

            y1 = reshape(y1, r2, n1*r1);
            phiywy{i}=reshape(phiywy{i+1}, r2*rw2, r2)*y1;
            phiywy{i}=reshape(phiywy{i}, r2, rw2, n1, r1);
            phiywy{i}=permute(phiywy{i}, [3 2 4 1]);
            phiywy{i}=reshape(phiywy{i}, n1*rw2, r1*r2);
            phiywy{i}=w1*phiywy{i};
            phiywy{i}=reshape(phiywy{i}, n1, rw1, r1, r2);
            phiywy{i}=permute(phiywy{i}, [4 1 2 3]);
            phiywy{i}=reshape(phiywy{i}, r2*n1, rw1*r1);
            y1 = reshape(y1, r2*n1, r1);
            phiywy{i}=(y1')*phiywy{i};
            phiywy{i}=reshape(phiywy{i}, r1, rw1, r1);
        end
    end

    % ---- Left-to-right sweep: update the core pair (i, i+1).
    for i=1:d-1
        x1 = x{i}; n1 = size(x1,1); rx1 = size(x1,2); rx2 = size(x1,3);
        x2 = x{i+1}; n2 = size(x2,1); rx3 = size(x2,3);
        if (~isempty(W))
            w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
            w2 = W{i+1}; rw3 = size(w2,4);
        end
        y1 = y{i}; ry1 = size(y1,2); ry2 = size(y1,3);
        y2 = y{i+1}; ry3 = size(y2,3);

        % ynew: two-site block of x seen through the interfaces
        % phiyx{i} and phiyx{i+2}, as a (ry1*n1)-by-(n2*ry3) matrix.
        x1 = permute(x1, [2 1 3]);
        x1 = reshape(x1, rx1, n1*rx2);
        ynew = phiyx{i}*x1;
        ynew = reshape(ynew, ry1*n1, rx2);
        x2 = permute(x2, [2 1 3]);
        x2 = reshape(x2, rx2, n2*rx3);
        ynew = ynew*x2;
        ynew = reshape(ynew, ry1*n1*n2, rx3);
        ynew = ynew*(phiyx{i+2}.');
        ynew = reshape(ynew, ry1*n1, n2*ry3);

        if (~isempty(W))
            % rhs: right-hand side of the local weighted system, built
            % from phiywx{i}, W{i}, x{i} (left half) and phiywx{i+2},
            % W{i+1}, x{i+1} (right half).
            rhs = reshape(phiywx{i}, ry1*rw1, rx1)*x1;
            rhs = reshape(rhs, ry1, rw1, n1, rx2);
            rhs = permute(rhs, [3 2 1 4]);
            rhs = reshape(rhs, n1*rw1, ry1*rx2);
            w1 = permute(w1, [1 4 2 3]);
            w1 = reshape(w1, n1*rw2, n1*rw1);
            rhs = w1*rhs;
            rhs = reshape(rhs, n1,rw2,ry1, rx2);
            rhs = permute(rhs, [3 1 2 4]);
            rhs = reshape(rhs, ry1*n1, rw2*rx2);

            rhs2 = reshape(phiywx{i+2}, ry3*rw3, rx3);
            x2 = reshape(x2, rx2*n2, rx3);
            rhs2 = rhs2*(x2.');
            rhs2 = reshape(rhs2, ry3, rw3, rx2, n2);
            rhs2 = permute(rhs2, [4 2 3 1]);
            rhs2 = reshape(rhs2, n2*rw3, rx2*ry3);
            w2 = permute(w2, [1 3 2 4]);
            w2 = reshape(w2, n2*rw2, n2*rw3);
            rhs2 = w2*rhs2;
            rhs2 = reshape(rhs2, n2, rw2*rx2, ry3);
            rhs2 = permute(rhs2, [1 3 2]);
            rhs2 = reshape(rhs2, n2*ry3, rw2*rx2);
            rhs = rhs*(rhs2.');
            rhs = reshape(rhs, ry1*n1*n2*ry3, 1);

            % mtx{1}, mtx{2}: left and right halves of the local matrix
            % in the layout consumed by bfun2, coupled over rank rw2.
            mtx = cell(2,1);
            mtx{1} = permute(phiywy{i}, [1 3 2]);
            mtx{1} = reshape(mtx{1}, ry1*ry1, rw1);
            w1 = reshape(w1, n1*rw2*n1, rw1);
            mtx{1} = mtx{1}*(w1.');
            mtx{1} = reshape(mtx{1}, ry1, ry1, n1, rw2, n1);
            mtx{1} = permute(mtx{1}, [1 3 2 5 4]);
            mtx{1} = reshape(mtx{1}, ry1*n1, ry1*n1, rw2);

            mtx{2} = permute(phiywy{i+2}, [1 3 2]);
            mtx{2} = reshape(mtx{2}, ry3*ry3, rw3);
            w2 = reshape(w2, n2*rw2*n2, rw3);
            mtx{2}= mtx{2}*(w2.');
            mtx{2} = reshape(mtx{2}, ry3, ry3, n2, rw2, n2);
            mtx{2} = permute(mtx{2}, [3 1 5 2 4]);
            mtx{2} = reshape(mtx{2}, n2*ry3, n2*ry3, rw2);
        end

        % yprev: the two-site block currently stored in y.
        y1 = permute(y1, [2 1 3]);
        y1 = reshape(y1, ry1*n1, ry2);
        y2 = permute(y2, [2 1 3]);
        y2 = reshape(y2, ry2, n2*ry3);
        yprev = y1*y2;

        % Relative change of the block. In the weighted case the
        % difference is passed through the LOCAL matrix mtx (the
        % original code passed the global operator W here, whose cores
        % do not have the layout bfun2 expects).
        vdy = ynew-yprev;
        if (~isempty(W))
            vdy = bfun2(mtx, vdy, ry1, n1, n2, ry3, ry1, n1, n2, ry3);
        else
            rhs = ynew;
        end
        nrm_rhs = norm(rhs, 'fro');
        if (nrm_rhs==0)
            dy(i)=0;
        else
            dy(i) = norm(vdy, 'fro')/nrm_rhs;
        end
        if (swp==1)
            dy_old(i)=dy(i);
        end

        % Stagnation: allow more rank and a tighter local tolerance.
        if (dy(i)/dy_old(i)>top_conv)&&(dy(i)>eps/(d^d_pow_check))
            drank(i)=drank(i)+ddrank;
            dpows(i)=dpows(i)+ddpow;
        end
        % Fast convergence (or already below threshold): relax again.
        if (dy(i)/dy_old(i)<bot_conv)||(dy(i)<eps/(d^d_pow_check))
            drank(i)=max(drank(i)-ddrank, 1);
            dpows(i)=max(dpows(i)-ddpow, 1);
        end

        if (last_sweep)
            dpows(i)=0.5;
        end

        if (incr_update)
            % Decompose only the correction; it is merged with the old
            % cores further below.
            [u,s,v]=svd(ynew-yprev,'econ');
        else
            [u,s,v]=svd(ynew, 'econ');
        end
        if (~isempty(W))
            % Bisection on the rank, judged by the local residual ...
            r0 = 1; rM = min(size(s,1),rmax); r = round((r0+rM)/2);
            while (rM-r0>1)
                cur_err = norm(s(r+1:end,r+1:end), 'fro')/norm(s,'fro');
                cur_sol = reshape(u(:,1:r)*s(1:r,1:r)*v(:,1:r)', ry1*n1*n2*ry3, 1);
                cur_res = norm(bfun2(mtx,cur_sol,ry1,n1,n2,ry3,ry1,n1,n2,ry3) - rhs)/norm(rhs);
                if (verb>1)
                    fprintf('sweep %d, block %d, rank: %d, resid: %3.3e, L2-err: %3.3e\n', swp, i, r, cur_res, cur_err);
                end
                if (cur_res<eps/(d^dpows(i)))
                    rM = r-1;
                    r = round((r0+rM)/2);
                else
                    r0 = r;
                    r = round((r0+rM)/2);
                end
            end
            % ... followed by a linear search upwards until the residual
            % criterion is met or the rank bound is hit.
            while (r<min(size(s,1), rmax))
                r=r+1;
                cur_sol = reshape(u(:,1:r)*s(1:r,1:r)*v(:,1:r)', ry1*n1*n2*ry3, 1);
                cur_err = norm(s(r+1:end,r+1:end), 'fro')/norm(s,'fro');
                cur_res = norm(bfun2(mtx,cur_sol,ry1,n1,n2,ry3,ry1,n1,n2,ry3) - rhs)/norm(rhs);
                if (verb>1)
                    fprintf('sweep %d, block %d, rank: %d, resid: %3.3e, L2-err: %3.3e\n', swp, i, r, cur_res, cur_err);
                end
                if (cur_res<eps/(d^dpows(i)))
                    break;
                end
            end
        else
            % my_chop2 presumably returns the rank that keeps the
            % discarded singular values below the given absolute
            % tolerance - confirm against its source.
            r = my_chop2(diag(s), eps/(d^dpows(i))*norm(ynew,'fro'));
        end
        if (~last_sweep)
            r = r+drank(i);
        end
        r = min(r, max(size(s)));
        r = min(r,rmax);

        if (verb>1)
            fprintf('sweep %d, block %d, rank: %d, dy: %3.3e, dy_old: %3.3e, drank: %g, dpow: %g\n', swp, i, r, dy(i), dy_old(i), drank(i), dpows(i));
        end

        u = u(:,1:r);
        v = conj(v(:,1:r))*s(1:r,1:r);
        if (incr_update)
            % Append the correction to the previous factors and
            % re-orthogonalize the left one.
            u = [y1, u];
            v = [y2.', v];
            [u,rv]=qr(u,0);
            ry2 = size(u,2);
            v = v*(rv.');
        else
            if (~last_sweep)
                % Hand kickrank random columns to reort (presumably it
                % orthogonalizes them against u and appends them -
                % confirm) and pad v with zeros to the new column count.
                u = reort(u, randn(size(u,1),kickrank));
                r = size(u,2);
                v = [v, zeros(size(v,1),r-size(v,2))];
            end
            ry2 = size(u,2);
        end

        y{i}=permute(reshape(u, ry1, n1, ry2), [2 1 3]);
        y{i+1}=permute(reshape(v, n2, ry3, ry2), [1 3 2]);

        % Extend the left interfaces by the freshly computed core i.
        x1 = x{i}; rx1 = size(x1,2); rx2 = size(x1,3);
        x1 = reshape(permute(x1, [2 1 3]), rx1, n1*rx2);
        phiyx{i+1} = phiyx{i}*x1;
        phiyx{i+1} = reshape(phiyx{i+1}, ry1*n1, rx2);
        phiyx{i+1}=(u')*phiyx{i+1};

        if (~isempty(W))
            w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
            phiywx{i+1}=reshape(phiywx{i}, ry1*rw1, rx1)*x1;
            phiywx{i+1}=reshape(phiywx{i+1}, ry1, rw1, n1, rx2);
            phiywx{i+1}=permute(phiywx{i+1}, [3 2 4 1]);
            phiywx{i+1}=reshape(phiywx{i+1}, n1*rw1, rx2*ry1);
            w1 = permute(w1, [2 3 1 4]);
            w1 = reshape(w1, n1*rw1, n1*rw2);
            phiywx{i+1}=(w1.')*phiywx{i+1};
            phiywx{i+1}=reshape(phiywx{i+1}, n1, rw2, rx2, ry1);
            phiywx{i+1}=permute(phiywx{i+1}, [4 1 2 3]);
            phiywx{i+1}=reshape(phiywx{i+1}, ry1*n1, rw2*rx2);
            phiywx{i+1}=(u')*phiywx{i+1};
            phiywx{i+1}=reshape(phiywx{i+1}, ry2, rw2, rx2);

            u = reshape(u, ry1, n1*ry2);
            phiywy{i+1}=reshape(phiywy{i}, ry1*rw1, ry1)*u;
            phiywy{i+1}=reshape(phiywy{i+1}, ry1, rw1, n1, ry2);
            phiywy{i+1}=permute(phiywy{i+1}, [3 2 4 1]);
            phiywy{i+1}=reshape(phiywy{i+1}, n1*rw1, ry2*ry1);
            phiywy{i+1}=(w1.')*phiywy{i+1};
            phiywy{i+1}=reshape(phiywy{i+1}, n1, rw2, ry2, ry1);
            phiywy{i+1}=permute(phiywy{i+1}, [4 1 2 3]);
            phiywy{i+1}=reshape(phiywy{i+1}, ry1*n1, rw2*ry2);
            u = reshape(u, ry1*n1, ry2);
            phiywy{i+1}=(u')*phiywy{i+1};
            phiywy{i+1}=reshape(phiywy{i+1}, ry2, rw2, ry2);
        end
    end

    if (verb>0)
        fprintf('=wround= Sweep %d, dy_max: %3.3e, conv_max: %1.5f\n', swp, max(dy), max(dy)/max(dy_old));
    end
    if (last_sweep)
        break;
    end

    % Once every block is below the threshold, run one more sweep
    % (without rank surplus or random enrichment) and stop.
    if (max(dy)<eps/(d^d_pow_check))
        last_sweep = true;
    end
    dy_old = dy;
end

% Drop the artificial singleton left rank of the first core again.
y{1}=reshape(y{1}, size(y{1},1), size(y{1},3));

if (swp==nswp)&&(max(dy)>eps/(d^d_pow_check))
    fprintf('tt_wround warning: error is not fixed for maximal number of sweeps %d, err_max: %3.3e\n', swp, max(dy));
end

end
0398
0399
function [y]=bfun2(B, x, rxm1, m1, m2, rxm3, rxn1, k1, k2, rxn3)
%BFUN2 Apply a two-factor operator to a vectorized two-site block.
%   B{1} is read as an (rxn1*k1)-by-(rxm1*m1)-by-rB array and B{2} as a
%   (k2*rxn3)-by-(m2*rxm3)-by-rB array, where rB = size(B{1},3).
%   X is reshaped to (rxm1*m1)-by-(m2*rxm3); the slices of B{1} act on
%   its rows, the slices of B{2} on its columns, and the rB terms are
%   summed. The result is a column vector of length rxn1*k1*k2*rxn3.

% Sizes of the row/column index groups.
rows_in  = rxm1*m1;
cols_in  = m2*rxm3;
rows_out = rxn1*k1;
rank_b   = size(B{1},3);

% Left factor with the coupling rank moved to the front, then flattened
% so that one matrix product applies all rank_b slices at once.
left_mat  = reshape(permute(B{1}, [3 1 2]), rank_b*rows_out, rows_in);
% Right factor flattened so that its product also sums over the rank.
right_mat = reshape(B{2}, k2*rxn3, cols_in*rank_b);

% Act on the rows of x.
y = left_mat*reshape(x, rows_in, cols_in);
% Bring (column index, rank) to the front for the second product.
y = permute(reshape(y, rank_b, rows_out, cols_in), [3 1 2]);
% Act on the columns and contract the coupling rank.
y = right_mat*reshape(y, cols_in*rank_b, rows_out);
% Back to (row, column) ordering, returned as a column vector.
y = reshape(y.', rows_out*k2*rxn3, 1);
end