function [y]=mvrk(A, x, eps, varargin)
%MVRK  Iterative two-site (DMRG-like) approximation of A*x in QTT-Tucker form.
%   y = MVRK(A, x, eps) sweeps over the Tucker factors and the core of the
%   result y, at every step projecting the product of the cores of A and x
%   onto the current basis of y, truncating by SVD and enriching the basis
%   with random vectors.
%   NOTE(review): that the routine approximates the matrix-by-vector product
%   is inferred from the contractions below (A-cores times x-cores compared
%   against y) and from the name; confirm against the toolbox documentation.
%
%   Inputs
%     A    - object exposing .core and .tuck; .tuck{i} must provide .n, .m,
%            .r and brace-indexed cores, .core must provide .r
%     x    - object exposing .dphys, .core and .tuck; .tuck{i} must provide
%            .d, .r and brace-indexed cores
%     eps  - truncation / stopping threshold (note: shadows the builtin eps
%            inside this function)
%
%   Optional name/value pairs (names are case-insensitive)
%     'nswp'     - maximal number of sweeps            (default 20)
%     'y0'       - initial guess with .core and .tuck  (default: random,
%                  all ranks equal to 1, generated by tt_rand)
%     'verb'     - 0: silent, 1: one line per sweep, >1: one line per step
%                  (default 1)
%     'kickrank' - number of random columns passed to reort for basis
%                  enrichment                          (default 2)
%   Unrecognized names, non-text names and a trailing name without a value
%   are reported with a warning and ignored.
%
%   Output
%     y    - qtt_tucker object with fields dphys, core and tuck set.

nswp = 20;
kickrank = 2;
verb = 1;
y = [];

% A name without a value used to be dropped silently; say so explicitly.
if (mod(length(varargin), 2) ~= 0)
    warning('mvrk:oddOptions', ...
        'Options must be given as name/value pairs; the trailing argument is ignored.');
end;

for i=1:2:length(varargin)-1
    if (~ischar(varargin{i}))
        warning('mvrk:badOptionName', ...
            'Option names must be character strings; argument %d is ignored.', i);
        continue;
    end;
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'y0'
            y=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        case 'kickrank'
            kickrank=varargin{i+1};
        otherwise
            % Previously a misspelled option silently fell back to the default.
            warning('mvrk:unknownOption', ...
                'Unrecognized option "%s" is ignored.', varargin{i});
    end;
end;

% Unpack the operands: number of physical dimensions, cores and factors.
d = x.dphys;
xc = x.core;
xf = x.tuck;
Af = A.tuck;
Ac = A.core;

% Core ranks (rc*), factor ranks (rf*) and factor mode sizes (n: rows of A,
% m: columns of A).
rca = Ac.r;
rfa = cell(d,1);
rcx = xc.r;
rfx = cell(d,1);
n = cell(d,1);
m = cell(d,1);

% L(i) is the number of cores in the i-th factor.
L = zeros(d,1);
for i=1:d
    L(i) = xf{i}.d;
    n{i} = Af{i}.n;
    m{i} = Af{i}.m;
    rfa{i} = Af{i}.r;
    rfx{i} = xf{i}.r;
end;

% Initial guess: either supplied through 'y0' or random with all ranks 1.
rfy = cell(d,1);
if (isempty(y))
    yc = tt_rand(1, d, 1);
    yf = cell(d,1);
    for i=1:d
        yf{i} = tt_rand(n{i}, L(i), [1; 1*ones(L(i),1)]);
    end;
else
    yc = y.core;
    yf = y.tuck;
end;
rcy = yc.r;
for i=1:d
    rfy{i} = yf{i}.r;
end;

% Projection ("Phi") storage, all filled by compute_next_Phi or by the local
% projections below; the boundary entries are the scalar 1.
%   phcl / phcr - projections along the core, left and right of a position
%   phfb{i}     - projections along factor i accumulated from its first core
%   phfc{i}     - projection of the whole factor i (feeds the core)
%   phft{i}     - projections along factor i accumulated from the core side
phcl = cell(d+1,1); phcl{1} = 1;
phcr = cell(d+1,1); phcr{d+1} = 1;
phfb = cell(d,1);
for i=1:d
    phfb{i} = cell(L(i),1);
    phfb{i}{1} = 1;
end;
phfc = cell(d,1);
phft = cell(d,1);
for i=1:d
    phft{i} = cell(L(i)+1,1);
    phft{i}{L(i)+1} = 1;
end;

% Tucker ranks (last rank of each factor) of y, x and A.
rty = zeros(d,1);
rtx = zeros(d,1);
rta = zeros(d,1);


for swp=1:nswp
    dy_max = 0;
    r_max = 0;

    % --- Stage 1: QR-orthogonalize every factor of y from core 1 to L(i),
    % pushing the triangular factor into the next factor core or, for the
    % last one, into the Tucker index of the core block yc{i}.
    for i=d:-1:1
        cury = yf{i};
        for j=1:L(i)
            cr = cury{j};
            cr = reshape(cr, rfy{i}(j)*n{i}(j), rfy{i}(j+1));
            [cr, rv]=qr(cr, 0);
            if (j<L(i))
                % Absorb rv into the next factor core.
                cr2 = cury{j+1};
                cr2 = reshape(cr2, rfy{i}(j+1), n{i}(j+1)*rfy{i}(j+2));
                cr2 = rv*cr2;
                rfy{i}(j+1) = size(cr,2);
                cury{j+1} = reshape(cr2, rfy{i}(j+1), n{i}(j+1), rfy{i}(j+2));
            else
                % Absorb rv into the Tucker (second) index of the core block.
                cr2 = yc{i};
                cr2 = permute(cr2, [2, 1, 3]);
                cr2 = reshape(cr2, rfy{i}(j+1), rcy(i)*rcy(i+1));
                cr2 = rv*cr2;
                rfy{i}(j+1) = size(cr,2);
                cr2 = reshape(cr2, rfy{i}(j+1), rcy(i), rcy(i+1));
                yc{i} = permute(cr2, [2, 1, 3]);
            end;
            cr = reshape(cr, rfy{i}(j), n{i}(j), rfy{i}(j+1));
            cury{j} = cr;

            % Update the factor projections with the orthogonalized core.
            if (j<L(i))
                phfb{i}{j+1} = compute_next_Phi(phfb{i}{j}, cr, Af{i}{j}, xf{i}{j}, 'lr');
            else
                phfc{i} = compute_next_Phi(phfb{i}{j}, cr, Af{i}{j}, xf{i}{j}, 'lr');
            end;
        end;
        yf{i} = cury;
    end;

    % --- Stage 2: project the core of A onto the factor bases, then
    % orthogonalize the core of y from block d down to 2 and accumulate the
    % right core projections.
    Acr = build_real_core_matrix(Ac, phfc);
    for i=d:-1:2
        rty(i) = rfy{i}(L(i)+1);
        cr = yc{i};
        cr = reshape(cr, rcy(i), rty(i)*rcy(i+1));
        [cr, rv] = qr(cr.', 0);
        cr2 = yc{i-1};
        rty(i-1) = rfy{i-1}(L(i-1)+1);
        cr2 = reshape(cr2, rcy(i-1)*rty(i-1), rcy(i));
        cr2 = cr2*(rv.');
        rcy(i) = size(cr, 2);
        cr = reshape(cr.', rcy(i), rty(i), rcy(i+1));
        yc{i-1} = reshape(cr2, rcy(i-1), rty(i-1), rcy(i));
        yc{i} = cr;

        phcr{i} = compute_next_Phi(phcr{i+1}, cr, Acr{i}, xc{i}, 'rl');
    end;

    % --- Stage 3: main sweep over the physical dimensions.
    for i=1:d
        rty(i) = rfy{i}(L(i)+1);

        % Merge the core block into the factor: multiply the factor by the
        % unfolded core block and fold the resulting last rank into the last
        % mode size (see the size vector passed to tt_reshape).
        curx = xf{i}; curxc = xc{i}; rtx(i) = rfx{i}(L(i)+1);
        cura = Af{i}; curac = Ac{i}; rta(i) = rfa{i}(L(i)+1);
        cury = yf{i}; curyc = yc{i};
        curxc = permute(curxc, [2, 1, 3]);
        curxc = reshape(curxc, rtx(i), rcx(i)*rcx(i+1));

        curx = curx*curxc;

        curx = tt_reshape(curx, (curx.n).*[ones(L(i)-1,1); curx.r(L(i)+1)]);
        curyc = permute(curyc, [2, 1, 3]);
        curyc = reshape(curyc, rty(i), rcy(i)*rcy(i+1));
        cury = cury*curyc;

        cury = tt_reshape(cury, (cury.n).*[ones(L(i)-1,1); cury.r(L(i)+1)]);

        % Same for A: sandwich its core block between the left and right core
        % projections, then attach the result to the last factor core of A.
        ph1 = phcl{i};
        ph1 = permute(ph1, [1,3,2]);
        ph1 = reshape(ph1, rcy(i)*rcx(i), rca(i));
        ph2 = phcr{i+1};
        ph2 = permute(ph2, [2, 1,3]);
        ph2 = reshape(ph2, rca(i+1), rcy(i+1)*rcx(i+1));
        curac = reshape(curac, rca(i), rta(i)*rca(i+1));
        curac = ph1*curac;
        curac = reshape(curac, rcy(i)*rcx(i)*rta(i), rca(i+1));
        curac = curac*ph2;
        curac = reshape(curac, rcy(i), rcx(i), rta(i), rcy(i+1), rcx(i+1));
        curac = permute(curac, [3, 1, 4, 2, 5]);
        curac = reshape(curac, rta(i), rcy(i)*rcy(i+1)*rcx(i)*rcx(i+1));
        cura = tt_tensor(cura);
        cura = cura*curac;
        % Regroup the last core so that the y-side core ranks join the row
        % index and the x-side core ranks join the column index.
        lasta = cura{L(i)};
        lasta = reshape(lasta, rfa{i}(L(i)), n{i}(L(i)), m{i}(L(i)), rcy(i)*rcy(i+1), rcx(i)*rcx(i+1));
        lasta = permute(lasta, [1, 2, 4, 3, 5]);
        cura{L(i)} = reshape(lasta, rfa{i}(L(i)), n{i}(L(i))*rcy(i)*rcy(i+1)*m{i}(L(i))*rcx(i)*rcx(i+1));

        % Mode sizes of the extended factor (last mode enlarged by the core
        % ranks).
        curn = n{i}.*[ones(L(i)-1,1); rcy(i)*rcy(i+1)];
        curm = m{i}.*[ones(L(i)-1,1); rcx(i)*rcx(i+1)];
        cura = tt_matrix(cura, curn, curm);

        % Two-site sweep over the extended factor, from its last core down
        % to the second one.
        for j=L(i):-1:2
            rx1 = rfx{i}(j-1); rx2 = rfx{i}(j);
            ry1 = rfy{i}(j-1); ry2 = rfy{i}(j);
            ra1 = rfa{i}(j-1); ra2 = rfa{i}(j);
            if (j==L(i))
                % The Tucker rank has been folded into the mode size above.
                rx3 = 1; ry3 = 1; ra3 = 1;
            else
                rx3 = rfx{i}(j+1); ra3 = rfa{i}(j+1); ry3 = rfy{i}(j+1);
            end;

            % Right half: phft{j+1} contracted with x{j} and A{j}.
            rhs2 = reshape(phft{i}{j+1}, ry3*ra3, rx3);
            x2 = curx{j};
            x2 = reshape(x2, rx2*curm(j), rx3);
            rhs2 = rhs2*(x2.');
            rhs2 = reshape(rhs2, ry3, ra3, rx2, curm(j));
            rhs2 = permute(rhs2, [2, 4, 1, 3]);
            rhs2 = reshape(rhs2, ra3*curm(j), ry3*rx2);
            a2 = cura{j};
            a2 = permute(a2, [2, 1, 4, 3]);
            a2 = reshape(a2, curn(j)*ra2, ra3*curm(j));
            rhs2 = a2*rhs2;

            rhs2 = reshape(rhs2, curn(j), ra2, ry3, rx2);
            rhs2 = permute(rhs2, [1, 3, 2, 4]);

            % Left half: continue with x{j-1}, A{j-1} and phfb{j-1}.
            rhs = reshape(rhs2, curn(j)*ry3*ra2, rx2);
            x1 = curx{j-1};
            x1 = reshape(x1, rx1*curm(j-1), rx2);
            rhs = rhs*(x1.');
            rhs = reshape(rhs, curn(j), ry3, ra2, rx1, curm(j-1));
            rhs = permute(rhs, [5, 3, 4, 1, 2]);
            rhs = reshape(rhs, curm(j-1)*ra2, rx1*curn(j)*ry3);
            a1 = cura{j-1};
            a1 = reshape(a1, ra1*curn(j-1), curm(j-1)*ra2);
            rhs = a1*rhs;
            rhs = reshape(rhs, ra1, curn(j-1), rx1, curn(j)*ry3);
            rhs = permute(rhs, [1, 3, 2, 4]);
            rhs = reshape(rhs, ra1*rx1, curn(j-1)*curn(j)*ry3);
            rhs = reshape(phfb{i}{j-1}, ry1, ra1*rx1)*rhs;
            rhs = reshape(rhs, ry1*curn(j-1), curn(j)*ry3);

            % Current two-core block of y, used only to measure the change.
            y_prev = cury{j-1};
            y_prev = reshape(y_prev, ry1*curn(j-1), ry2);
            y_prev = y_prev*reshape(cury{j}, ry2, curn(j)*ry3);

            dy = norm(rhs-y_prev, 'fro')/norm(rhs, 'fro');
            dy_max = max(dy_max, dy);

            % Truncated SVD; the singular values go to the left factor so
            % that the right factor v stays orthonormal.
            [u,s,v]=svd(rhs, 'econ');
            s = diag(s);
            nrm = norm(s);
            r = my_chop2(s, eps*nrm/sqrt(L(i))/sqrt(d));
            v = conj(v(:,1:r));
            u = u(:,1:r)*diag(s(1:r));

            % Enrich v with random directions; pad u with zero columns so
            % that the product u*v.' is unchanged.
            v = reort(v, randn(curn(j)*ry3, kickrank));
            radd = size(v,2)-r;
            u = [u, zeros(ry1*curn(j-1), radd)];
            r = r+radd;

            cury{j} = reshape(v.', r, curn(j), ry3);
            cury{j-1} = reshape(u, ry1, curn(j-1), r);

            % Project the right half onto the new basis for the next step.
            rhs2 = reshape(rhs2, curn(j)*ry3, ra2*rx2);
            rhs2 = (v')*rhs2;
            phft{i}{j} = reshape(rhs2, r, ra2, rx2);
            rfy{i}(j) = r;
            r_max = max(r_max, r);
            if (verb>1)
                fprintf('=mvrk= swp %d, factor {%d}{%d}, dy: %3.3e, r: %d\n', swp, i, j, dy, r);
            end;
        end;

        % Go back from the first core to the last: QR for the inner cores,
        % SVD at the last core to split the core block off the factor again.
        for j=1:L(i)
            cr = cury{j};
            if (j<L(i))
                cr = reshape(cr, rfy{i}(j)*n{i}(j), rfy{i}(j+1));
                [cr, rv]=qr(cr, 0);
                cr2 = cury{j+1};
                n2 = size(cr2, 2); ry3 = size(cr2, 3);
                cr2 = reshape(cr2, rfy{i}(j+1), n2*ry3);
                cr2 = rv*cr2;
                rfy{i}(j+1) = size(cr,2);
                cury{j+1} = reshape(cr2, rfy{i}(j+1), n2, ry3);
                cr = reshape(cr, rfy{i}(j), n{i}(j), rfy{i}(j+1));
                cury{j} = cr;
                phfb{i}{j+1} = compute_next_Phi(phfb{i}{j}, cr, Af{i}{j}, xf{i}{j}, 'lr');
            else
                % Separate the physical mode (n1) from the merged core ranks
                % (n2 = rcy(i)*rcy(i+1)).
                ry1 = rfy{i}(j); n1 = n{i}(j); n2 = rcy(i)*rcy(i+1);
                cr = reshape(cr, ry1*n1, n2);
                [u,s,v]=svd(cr, 'econ');
                s = diag(s);
                nrm = norm(s);
                r = my_chop2(s, eps*nrm/sqrt(L(i))/sqrt(d));
                u = u(:,1:r);
                v = diag(s(1:r))*(v(:,1:r)');

                % Enrich the factor basis; pad the core part with zero rows.
                u = reort(u, randn(ry1*n1, kickrank));
                radd = size(u,2)-r;
                v = [v; zeros(radd, n2)];
                r = r+radd;
                u = reshape(u, ry1, n1, r);
                cury{j} = u;
                v = reshape(v, r, rcy(i), rcy(i+1));
                yc{i} = permute(v, [2, 1, 3]);
                rfy{i}(L(i)+1) = r;
                r_max = max(r_max, r);
                phfc{i} = compute_next_Phi(phfb{i}{j}, u, Af{i}{j}, xf{i}{j}, 'lr');
                if (verb>1)
                    fprintf('=mvrk= swp %d, tucker_rank(%d), r: %d\n', swp, i, r);
                end;
            end;
        end;
        yf{i} = cury;
        if (i<d)
            % --- Two-site update of the core blocks i and i+1.
            rx1 = rcx(i); rx2 = rcx(i+1); rx3 = rcx(i+2);
            ry1 = rcy(i); ry2 = rcy(i+1); ry3 = rcy(i+2);
            ra1 = rca(i); ra2 = rca(i+1); ra3 = rca(i+2);
            curn1 = rfy{i}(L(i)+1); curn2 = rfy{i+1}(L(i+1)+1);
            curm1 = rfx{i}(L(i)+1); curm2 = rfx{i+1}(L(i+1)+1);
            % Re-project the i-th core block of A with the refreshed factor
            % projection phfc{i} and store it back into Acr.
            ph = phfc{i};
            ph = permute(ph, [1, 3, 2]);
            ph = reshape(ph, curn1*curm1, rfa{i}(L(i)+1));
            cura = Ac{i};
            cura = permute(cura, [2, 1, 3]);
            cura = reshape(cura, rfa{i}(L(i)+1), rca(i)*rca(i+1));
            cura = ph*cura;
            cura = reshape(cura, curn1*curm1, rca(i), rca(i+1));
            cura = permute(cura, [2, 1, 3]);
            cura = reshape(cura, rca(i), curn1, curm1, rca(i+1));
            Acr{i} = cura;
            cura = permute(cura, [2, 4, 1, 3]);
            cura = reshape(cura, curn1*ra2, ra1*curm1);

            % Left half: phcl{i} contracted with xc{i} and the projected A.
            rhs1 = reshape(phcl{i}, ry1*ra1, rx1);
            rhs1 = rhs1*reshape(xc{i}, rx1, curm1*rx2);
            rhs1 = reshape(rhs1, ry1, ra1, curm1, rx2);
            rhs1 = permute(rhs1, [2, 3, 1, 4]);
            rhs1 = reshape(rhs1, ra1*curm1, ry1*rx2);
            rhs1 = cura*rhs1;
            rhs1 = reshape(rhs1, curn1, ra2, ry1, rx2);

            rhs1 = permute(rhs1, [3, 1, 2, 4]);

            % Right half: continue with xc{i+1}, Acr{i+1} and phcr{i+2}.
            rhs = reshape(rhs1, ry1*curn1*ra2, rx2);
            rhs = rhs*reshape(xc{i+1}, rx2, curm2*rx3);
            rhs = reshape(rhs, ry1*curn1, ra2, curm2, rx3);
            rhs = permute(rhs, [3, 2, 4, 1]);
            rhs = reshape(rhs, curm2*ra2, rx3*ry1*curn1);
            cura = Acr{i+1};
            cura = permute(cura, [2, 4, 3, 1]);
            cura = reshape(cura, curn2*ra3, curm2*ra2);
            rhs = cura*rhs;
            rhs = reshape(rhs, curn2, ra3*rx3, ry1*curn1);
            rhs = permute(rhs, [2, 3, 1]);
            rhs = reshape(rhs, ra3*rx3, ry1*curn1*curn2);
            rhs = reshape(phcr{i+2}, ry3, ra3*rx3)*rhs;
            rhs = reshape(rhs.', ry1*curn1, curn2*ry3);

            % Current two-block core of y, used only to measure the change.
            y_prev = reshape(yc{i}, ry1*curn1, ry2);
            y_prev = y_prev*reshape(yc{i+1}, ry2, curn2*ry3);

            dy = norm(rhs-y_prev, 'fro')/norm(rhs, 'fro');
            dy_max = max(dy_max, dy);

            % Truncated SVD; here the singular values go to the right factor
            % so that the left block u stays orthonormal.
            [u,s,v]=svd(rhs, 'econ');
            s = diag(s);
            nrm = norm(s);
            r = my_chop2(s, eps*nrm/sqrt(d));
            u = u(:,1:r);
            v = conj(v(:,1:r))*diag(s(1:r));

            % Enrich u; pad v with zero columns so that u*v.' is unchanged.
            u = reort(u, randn(ry1*curn1, kickrank));
            radd = size(u,2)-r;
            v = [v, zeros(curn2*ry3, radd)];
            r = r+radd;
            yc{i} = reshape(u, ry1, curn1, r);
            yc{i+1} = reshape(v.', r, curn2, ry3);

            % Left core projection for the next position.
            rhs1 = reshape(rhs1, ry1*curn1, ra2*rx2);
            phcl{i+1} = reshape((u')*rhs1, r, ra2, rx2);
            rcy(i+1)=r;
            r_max = max(r_max, r);
            if (verb>1)
                fprintf('=mvrk= swp %d, core {%d}, dy: %3.3e, r: %d\n', swp, i, dy, r);
            end;
        end;
    end;

    if (verb>0)
        fprintf('=mvrk= swp %d, dy_max: %3.3e, r_max: %d\n', swp, dy_max, r_max);
    end;
    % Stop as soon as no local block changed by more than eps (relative).
    if (dy_max<eps)
        break;
    end;
end;

% Assemble the result.
y = qtt_tucker;
y.dphys = d;
y.core = yc;
y.tuck = yf;

end
0415
0416
function [Phi] = compute_next_Phi(Phi_prev, x, A, y, direction)
%COMPUTE_NEXT_PHI  One step of the recurrent projection Phi = Phi_prev*(x'*A*y).
%   Phi_prev  - previous projection, reshaped here to (rx1*ra1) x ry1
%   x         - 3-way core rx1 x n x rx2; enters conjugate-transposed
%   A         - 4-way operator core ra1 x n x m x ra2, or [] to contract
%               x and y directly (this path requires n == m)
%   y         - 3-way core ry1 x m x ry2
%   direction - 'rl' swaps the first and last (rank) indices of x, y and A
%               before the contraction; any other value leaves them as is
%
%   Returns an rx2 x ra2 x ry2 array when A is given and an rx2 x ry2
%   matrix when A is empty.

withOp = ~isempty(A);

if (strcmp(direction, 'rl'))
    % Exchange the rank indices so that the same left-to-right code below
    % performs the right-to-left recursion.
    x = permute(x, [3, 2, 1]);
    y = permute(y, [3, 2, 1]);
    if (withOp)
        A = permute(A, [4, 2, 3, 1]);
    end
end

xr1 = size(x,1); nx = size(x,2); xr2 = size(x,3);
yr1 = size(y,1); ny = size(y,2); yr2 = size(y,3);

xmat = reshape(x, [xr1*nx, xr2]);
ymat = reshape(y, [yr1, ny*yr2]);

if (withOp)
    ar1 = size(A,1); ar2 = size(A,4);

    % Contract with y over ry1 and bring (ra1, m) to the front.
    T = reshape(Phi_prev, [xr1*ar1, yr1])*ymat;
    T = permute(reshape(T, [xr1, ar1, ny, yr2]), [2, 3, 1, 4]);
    T = reshape(T, [ar1*ny, xr1*yr2]);

    % Contract with A over (ra1, m).
    Amat = reshape(permute(A, [4, 2, 1, 3]), [ar2*nx, ar1*ny]);
    T = reshape(Amat*T, [ar2, nx, xr1, yr2]);

    % Contract with x' over (rx1, n).
    T = reshape(permute(T, [3, 2, 1, 4]), [xr1*nx, ar2*yr2]);
    Phi = reshape((xmat')*T, [xr2, ar2, yr2]);
else
    % No operator: its ranks are treated as 1.
    T = reshape(Phi_prev, [xr1, yr1])*ymat;
    T = permute(reshape(T, [xr1, 1, ny, yr2]), [2, 3, 1, 4]);
    T = reshape(permute(T, [3, 2, 1, 4]), [xr1*nx, yr2]);
    Phi = (xmat')*T;
end
end
0461
function [Ac]=build_real_core_matrix(Ac, phfc)
%BUILD_REAL_CORE_MATRIX  Contract every core block of Ac with a projection.
%   Ac   - TT object exposing .d (number of blocks), .n (mode sizes),
%          .r (ranks) and brace-indexed 3-way blocks r(k) x n(k) x r(k+1)
%   phfc - cell array; phfc{k} is a 3-way array whose middle dimension
%          matches n(k)
%
%   The mode index of block k is contracted with the middle index of
%   phfc{k}, so the block becomes r(k) x (p*q) x r(k+1) with
%   p = size(phfc{k},1) and q = size(phfc{k},3). The result is wrapped by
%   tt_matrix(Ac, p, q), with p and q collected over all blocks.

numBlocks = Ac.d;
modeSz = Ac.n;
rk = Ac.r;
firstSz = zeros(numBlocks,1);
lastSz = zeros(numBlocks,1);

for k=1:numBlocks
    proj = phfc{k};
    firstSz(k) = size(proj, 1);
    lastSz(k) = size(proj, 3);
    pq = firstSz(k)*lastSz(k);

    % Projection unfolded as (p*q) x n(k).
    projMat = reshape(permute(proj, [1, 3, 2]), pq, modeSz(k));

    % Core block unfolded as n(k) x (r(k)*r(k+1)).
    blk = Ac{k};
    blkMat = reshape(permute(blk, [2, 1, 3]), modeSz(k), rk(k)*rk(k+1));

    % Contract over the mode index and move the left rank back to the front.
    blk = reshape(projMat*blkMat, pq, rk(k), rk(k+1));
    blk = permute(blk, [2, 1, 3]);
    Ac{k} = reshape(blk, rk(k), pq, rk(k+1));
end

Ac = tt_matrix(Ac, firstSz, lastSz);
end