function [x]=dmrg_rake_solve2(A, y, tol, varargin)
%DMRG_RAKE_SOLVE2 DMRG-type sweeping solver for A*x = y in a core/factor format.
%   x = DMRG_RAKE_SOLVE2(A, y, tol) runs alternating two-block local solves
%   over the factor chains (fields .tuck) and the core chain (field .core)
%   of the operator A and right-hand side y, and returns the result as a
%   qtt_tucker object with fields dphys, tuck and core.
%
%   Inputs
%     A    object with fields .core and .tuck (cell array of factor chains).
%     y    object with fields .dphys, .core and .tuck.
%          NOTE(review): presumably both are qtt_tucker objects (the output
%          is built with qtt_tucker) -- confirm against callers.
%     tol  accuracy parameter: local solves receive tol/sqrt(L(i))/sqrt(d)
%          (factors) or tol/sqrt(d) (core); sweeping stops once the check
%          residual chk_res_max drops below tol.
%
%   Name/value options (names are case-insensitive)
%     'nswp'          sweep limit (default 20); see the stopping note below
%     'rmax'          upper bound applied to every truncated rank (Inf)
%     'x0'            initial guess with fields .core and .tuck ([] -> random)
%     'verb'          verbosity: 0 silent, 1 per-sweep line, >1 per-block (1)
%     'nrestart'      restart length forwarded to the local GMRES (25)
%     'gmres_iters'   outer iteration count forwarded to GMRES (2)
%     'kickrank'      number of random columns appended via reort (2)
%     'max_full_size' local size below which a dense matrix is formed (1500)
%     'resid_damp'    residual damping factor for the local solves (1.1)
%
%   Stopping: the sweep loop breaks when chk_res_max<tol or when
%   swp==nswp-1, so for nswp>1 at most nswp-1 sweeps are executed and
%   last_sweep is never switched on in this version.
%
%   Relies on tt_rand, tt_matrix, qtt_tucker and reort, which are defined
%   outside this file.

nswp = 20;
local_format = 'full';

max_full_size = 1500;
max_full_size2 = Inf;   % threshold for switching local solves to 'tt' format
nrestart = 25;
gmres_iters = 2;
verb = 1;
kickrank = 2;
checkrank = 1;          % rank of the "check" projections (c-prefixed data)
resid_damp_glob = 1.1;
resid_damp_loc = 1.1;
rmax = Inf;

x = [];

for i=1:2:length(varargin)-1
    switch lower(varargin{i})
        case 'nswp'
            nswp=varargin{i+1};
        case 'rmax'
            % rmax is a numeric rank bound; no string normalisation needed.
            rmax=varargin{i+1};
        case 'x0'
            x=varargin{i+1};
        case 'verb'
            verb=varargin{i+1};
        case 'nrestart'
            nrestart=varargin{i+1};
        case 'gmres_iters'
            gmres_iters=varargin{i+1};
        case 'kickrank'
            kickrank=varargin{i+1};
        case 'max_full_size'
            max_full_size=varargin{i+1};
        case 'resid_damp'
            resid_damp_loc=varargin{i+1};
        otherwise
            error('Unrecognized option: %s\n',varargin{i});
    end
end

d = y.dphys;
yc = y.core;
yf = y.tuck;
Af = A.tuck;
Ac = A.core;

% Chain lengths first, so that the mode-size table is allocated with its
% final shape (max(L) rows) instead of growing inside the loop.
L = zeros(1,d);
for i=1:d
    L(i) = yf{i}.d;
end;
n = zeros(max(L), d);
for i=1:d
    n(1:L(i), i) = yf{i}.n;
end;

% Random initial guess when none was supplied.
if (isempty(x))
    xc = tt_rand(2,d,2);
    xf = cell(d,1);
    for i=1:d
        xf{i} = tt_rand(n(1:L(i),i), L(i), [1;2*ones(L(i),1)]);
    end;
else
    xc = x.core;
    xf = x.tuck;
end;

% Rank tables: rc* are core ranks, rf*(:,i) are ranks of the i-th factor chain.
rcy = yc.r;
rfy = zeros(max(L)+1, d);
for i=1:d
    rfy(1:L(i)+1, i) = yf{i}.r;
end;
rcA = Ac.r;
rfA = zeros(max(L)+1, d);
for i=1:d
    rfA(1:L(i)+1, i) = Af{i}.r;
end;
rcx = xc.r;
rfx = zeros(max(L)+1, d);
for i=1:d
    rfx(1:L(i)+1, i) = xf{i}.r;
end;

% Projection storage: phc* for the core chain, phf* for the factor chains,
% ph*fc for the factor-to-core interfaces.
phcA = cell(d+1,1); phcA{1} = 1; phcA{d+1}=1;
phcy = cell(d+1,1); phcy{1} = 1; phcy{d+1}=1;
phfA = cell(d,1);
phfy = cell(d,1);
phAfc = cell(d,1);
phyfc = cell(d,1);
for i=1:d
    phfA{i} = cell(L(i)+1,1);
    phfA{i}{1} = 1; phfA{i}{L(i)+1} = 1;
    phfy{i} = cell(L(i)+1,1);
    phfy{i}{1} = 1; phfy{i}{L(i)+1} = 1;
end;

% Same layout for the "check" projections, which use all-ones test blocks
% on the left instead of the current solution blocks.
cphcA = cell(d+1,1); cphcA{1} = 1; cphcA{d+1}=1;
cphcy = cell(d+1,1); cphcy{1} = 1; cphcy{d+1}=1;
cphfA = cell(d,1);
cphfy = cell(d,1);
cphAfc = cell(d,1);
cphyfc = cell(d,1);
for i=1:d
    cphfA{i} = cell(L(i)+1,1);
    cphfA{i}{1} = 1; cphfA{i}{L(i)+1} = 1;
    cphfy{i} = cell(L(i)+1,1);
    cphfy{i}{1} = 1; cphfy{i}{L(i)+1} = 1;
end;

last_sweep = false;

for swp=1:nswp
    % Ranks of the check projections.
    rcchk = [1; checkrank*ones(d-1,1); 1];

    rfchk = zeros(max(L)+1, d);
    for i=1:d
        rfchk(1:L(i)+1, i) = [1; checkrank*ones(L(i),1)];
    end;

    dx_max = 0;
    res_max = 0;
    r_max = 0;
    chk_res_max = 0;

    % --- Phase 1: QR-orthogonalise every factor chain (block 1 -> L(i)),
    %     pushing the R factor into the next block or into the core block,
    %     and rebuild the left-to-right projections.
    for i=d:-1:1
        for j=1:L(i)
            cr = xf{i}{j};
            cr = reshape(cr, rfx(j,i)*n(j,i), rfx(j+1,i));
            [cr, rv] = qr(cr, 0);

            if (j<L(i))
                % R goes into the next factor block
                cr2 = xf{i}{j+1};
                cr2 = reshape(cr2, rfx(j+1,i), n(j+1,i)*rfx(j+2,i));
                cr2 = rv*cr2;
                rfx(j+1,i) = size(cr, 2);
                xf{i}{j} = reshape(cr, rfx(j,i), n(j,i), rfx(j+1,i));
                xf{i}{j+1} = reshape(cr2, rfx(j+1,i), n(j+1,i), rfx(j+2,i));
            else
                % last factor block: R goes into the middle index of the core block
                cr2 = xc{i};
                cr2 = permute(cr2, [2, 1, 3]);
                cr2 = reshape(cr2, rfx(j+1,i), rcx(i)*rcx(i+1));
                cr2 = rv*cr2;
                rfx(j+1,i) = size(cr, 2);
                cr2 = reshape(cr2, rfx(j+1,i), rcx(i), rcx(i+1));
                cr2 = permute(cr2, [2, 1, 3]);
                xf{i}{j} = reshape(cr, rfx(j,i), n(j,i), rfx(j+1,i));
                xc{i} = cr2;
            end;

            cr = reshape(cr, rfx(j,i), n(j,i), rfx(j+1,i));
            if (j<L(i))
                phfA{i}{j+1} = compute_next_Phi(phfA{i}{j}, cr, Af{i}{j}, cr, 'lr');
                phfy{i}{j+1} = compute_next_Phi(phfy{i}{j}, cr, [], yf{i}{j}, 'lr');
            else
                phAfc{i} = compute_next_Phi(phfA{i}{j}, cr, Af{i}{j}, cr, 'lr');
                phyfc{i} = compute_next_Phi(phfy{i}{j}, cr, [], yf{i}{j}, 'lr');
            end;

            % check projections: all-ones block on the test side
            ccr = ones(1, n(j,i), 1);
            if (j<L(i))
                cphfA{i}{j+1} = compute_next_Phi(cphfA{i}{j}, ccr, Af{i}{j}, cr, 'lr');
                cphfy{i}{j+1} = compute_next_Phi(cphfy{i}{j}, ccr, [], yf{i}{j}, 'lr');
            else
                cphAfc{i} = compute_next_Phi(cphfA{i}{j}, ccr, Af{i}{j}, cr, 'lr');
                cphyfc{i} = compute_next_Phi(cphfy{i}{j}, ccr, [], yf{i}{j}, 'lr');
            end;
        end;
    end;

    % --- Phase 2: project the core blocks of A and y through the factor
    %     interfaces, then QR the core chain right-to-left.
    Acr = tt_matrix(Ac, Ac.n, ones(d,1));
    cAcr = tt_matrix(Ac, Ac.n, ones(d,1));
    ycr = yc;
    cycr = yc;
    for i=1:d
        Acr{i} = core_matrix(Ac{i}, phAfc{i});
        ycr{i} = core_vector(yc{i}, phyfc{i});
        cAcr{i} = core_matrix(Ac{i}, cphAfc{i});
        cycr{i} = core_vector(yc{i}, cphyfc{i});
    end;
    for i=d:-1:2
        rtx = rfx(L(i)+1, i);
        cr = xc{i};
        cr = reshape(cr, rcx(i), rtx*rcx(i+1));
        [cr, rv] = qr(cr.', 0);
        cr2 = xc{i-1};
        rtx2 = rfx(L(i-1)+1, i-1);
        cr2 = reshape(cr2, rcx(i-1)*rtx2, rcx(i));
        cr2 = cr2*(rv.');
        rcx(i) = size(cr, 2);
        cr = reshape(cr.', rcx(i), rtx, rcx(i+1));
        xc{i-1} = reshape(cr2, rcx(i-1), rtx2, rcx(i));
        xc{i} = cr;

        phcA{i} = compute_next_Phi(phcA{i+1}, cr, Acr{i}, cr, 'rl');
        phcy{i} = compute_next_Phi(phcy{i+1}, cr, [], ycr{i}, 'rl');

        ccr = 1;

        cphcA{i} = compute_next_Phi(cphcA{i+1}, ccr, cAcr{i}, cr, 'rl');
        cphcy{i} = compute_next_Phi(cphcy{i+1}, ccr, [], cycr{i}, 'rl');
    end;

    % --- Phase 3: for each i, optimise its factor chain, then the core pair (i, i+1).
    for i=1:d
        rta = rfA(L(i)+1,i); rtx = rfx(L(i)+1,i); rty = rfy(L(i)+1,i); rtchk = rfchk(L(i)+1,i);

        % Contract the left core projection with the i-th core block of A.
        a1left = phcA{i};
        a1left = permute(a1left, [1,3,2]);
        a1left = reshape(a1left, rcx(i)*rcx(i), rcA(i));
        a1left = a1left*reshape(Ac{i}, rcA(i), rta*rcA(i+1));

        ca1left = cphcA{i};
        ca1left = permute(ca1left, [1,3,2]);
        ca1left = reshape(ca1left, rcchk(i)*rcx(i), rcA(i));
        ca1left = ca1left*reshape(Ac{i}, rcA(i), rta*rcA(i+1));

        a1left = reshape(a1left, rcx(i)*rcx(i), rta, rcA(i+1));
        a1left = permute(a1left, [2, 1, 3]);
        a1left = reshape(a1left, rta, rcx(i)*rcx(i)*rcA(i+1));

        ca1left = reshape(ca1left, rcchk(i)*rcx(i), rta, rcA(i+1));
        ca1left = permute(ca1left, [2, 1, 3]);
        ca1left = reshape(ca1left, rta, rcchk(i)*rcx(i)*rcA(i+1));

        % Merge it into the last factor block of A: the mode size of that
        % block becomes n*rcx(i).
        a2 = Af{i}{L(i)};
        a2 = reshape(a2, rfA(L(i),i)*n(L(i),i)*n(L(i),i), rta);
        a2 = a2*a1left;
        a2 = reshape(a2, rfA(L(i),i), n(L(i),i), n(L(i),i), rcx(i), rcx(i), rcA(i+1));
        a2 = permute(a2, [1, 2, 4, 3, 5, 6]);
        a2 = reshape(a2, rfA(L(i),i), n(L(i),i)*rcx(i), n(L(i),i)*rcx(i), rcA(i+1));

        ca2 = Af{i}{L(i)};
        ca2 = reshape(ca2, rfA(L(i),i)*n(L(i),i)*n(L(i),i), rta);
        ca2 = ca2*ca1left;
        ca2 = reshape(ca2, rfA(L(i),i), n(L(i),i), n(L(i),i), rcchk(i), rcx(i), rcA(i+1));
        ca2 = permute(ca2, [1, 2, 4, 3, 5, 6]);
        ca2 = reshape(ca2, rfA(L(i),i), n(L(i),i)*rcchk(i), n(L(i),i)*rcx(i), rcA(i+1));

        Afr = Af{i};
        Afr{L(i)} = a2;

        cAfr = Af{i};
        cAfr{L(i)} = ca2;

        % Same merge for the right-hand side ...
        y1left = phcy{i};
        y1left = y1left*reshape(yc{i}, rcy(i), rty*rcy(i+1));

        ph2 = phcy{i+1};
        ph2 = ph2.';
        y1top = reshape(y1left, rcx(i)*rty, rcy(i+1))*ph2;
        y1top = reshape(y1top, rcx(i), rty, rcx(i+1));
        y1top = permute(y1top, [1, 3, 2]);
        y1top = reshape(y1top, rcx(i)*rcx(i+1), rty);
        y2 = yf{i}{L(i)};
        y2 = reshape(y2, rfy(L(i),i)*n(L(i),i), rty);
        y2 = y2*y1top.';
        y2 = reshape(y2, rfy(L(i),i), n(L(i),i)*rcx(i)*rcx(i+1), 1);
        yfr = yf{i};
        yfr{L(i)} = y2;

        % ... its check counterpart ...
        cy1left = cphcy{i};
        cy1left = cy1left*reshape(yc{i}, rcy(i), rty*rcy(i+1));

        ph2 = cphcy{i+1};
        ph2 = ph2.';
        cy1top = reshape(cy1left, rcchk(i)*rty, rcy(i+1))*ph2;
        cy1top = reshape(cy1top, rcchk(i), rty, rcchk(i+1));
        cy1top = permute(cy1top, [1, 3, 2]);
        cy1top = reshape(cy1top, rcchk(i)*rcchk(i+1), rty);
        cy2 = yf{i}{L(i)};
        cy2 = reshape(cy2, rfy(L(i),i)*n(L(i),i), rty);
        cy2 = cy2*cy1top.';
        cy2 = reshape(cy2, rfy(L(i),i), n(L(i),i)*rcchk(i)*rcchk(i+1), 1);
        cyfr = yf{i};
        cyfr{L(i)} = cy2;

        % ... and the current solution.
        x1 = xc{i};
        x1 = permute(x1, [2, 1, 3]);
        x1 = reshape(x1, rtx, rcx(i)*rcx(i+1));
        x2 = xf{i}{L(i)};
        x2 = reshape(x2, rfx(L(i),i)*n(L(i),i), rtx);
        x2 = x2*x1;
        x2 = reshape(x2, rfx(L(i),i), n(L(i),i)*rcx(i)*rcx(i+1), 1);
        xfr = xf{i};
        xfr{L(i)} = x2;

        % Sizes and ranks of the extended chain.
        curn = xfr.n;
        curm = xfr.n;
        curm(L(i)) = n(L(i),i);
        curra = Afr.r;
        curry = yfr.r;
        currx = xfr.r;
        currchk = [rfchk(1:L(i),i); 1];

        % --- Two-block sweep over the factor chain, last block towards the first.
        for j=L(i):-1:2
            if (j<L(i))
                Phi2 = phfA{i}{j+1};
                cPhi2 = cphfA{i}{j+1};
            else
                Phi2 = phcA{i+1};
                cPhi2 = cphcA{i+1};
            end;
            a2 = Afr{j};
            a1 = Afr{j-1};
            Phi1 = phfA{i}{j-1};
            ca2 = cAfr{j};
            ca1 = cAfr{j-1};
            cPhi1 = cphfA{i}{j-1};

            % projected right-hand side, split into the two blocks
            y2 = phfy{i}{j+1}.';
            y2 = reshape(yfr{j}, curry(j)*curn(j), curry(j+1))*y2;
            y2 = reshape(y2, curry(j), curn(j)*currx(j+1));
            y2 = y2.';

            y1 = phfy{i}{j-1};
            y1 = y1*reshape(yfr{j-1}, curry(j-1), curn(j-1)*curry(j));
            y1 = reshape(y1, currx(j-1)*curn(j-1), curry(j));

            x2 = reshape(xfr{j}, currx(j), curn(j)*currx(j+1));
            x2 = x2.';
            x1 = reshape(xfr{j-1}, currx(j-1)*curn(j-1), currx(j));

            % check right-hand side
            cy2 = cphfy{i}{j+1}.';
            cy2 = reshape(cyfr{j}, curry(j)*curm(j), curry(j+1))*cy2;
            cy2 = reshape(cy2, curry(j), curm(j)*currchk(j+1));
            cy1 = cphfy{i}{j-1};
            cy1 = cy1*reshape(cyfr{j-1}, curry(j-1), curm(j-1)*curry(j));
            cy1 = reshape(cy1, currchk(j-1)*curm(j-1), curry(j));
            cy = cy1*cy2;

            % The merged last block carries the core ranks; the size tables
            % are switched only after the reshapes above have used the old values.
            if (j==L(i))
                currx(j+1)=rcx(i+1);
                currchk(j+1)=rcchk(i+1);
                curn(j) = n(L(i),i)*rcx(i);
                curm(j) = n(L(i),i)*rcchk(i);
            end;

            if (verb>1)
                fprintf('=rake_solve2= swp %d, factor {%d}{%d}, ', swp, i, j);
            end;
            local_format = 'full';
            if (currx(j-1)*curn(j-1)*curn(j)*currx(j+1)>max_full_size2)
                local_format = 'tt';
            end;

            [u,s,v,r,dx_max,res_max]=local_solve(Phi1,a1, a2, Phi2, y1, y2, x1, x2, ...
                currx(j-1), curn(j-1), curn(j), currx(j+1), curra(j), ...
                tol/sqrt(L(i))/sqrt(d), res_max, dx_max, resid_damp_loc, ...
                local_format, max_full_size, nrestart, gmres_iters, verb);
            r = min(rmax, r);
            u = u(:,1:r); s = s(1:r,1:r); v = v(:,1:r);
            u = u*s;

            % residual of the truncated local solution in the check projection
            Asol = bfun3(cPhi1,ca1,ca2,cPhi2, u*(v.'));
            chk_res_max = max(chk_res_max, norm(Asol-cy(:))/norm(cy(:)));

            if (~last_sweep)
                % enlarge the right basis by kickrank random directions
                v = reort(v, randn(curn(j)*currx(j+1), kickrank));
            end;
            radd = size(v,2)-r;
            u = [u, zeros(currx(j-1)*curn(j-1), radd)];
            r = r+radd;
            xfr{j} = reshape(v.', r, curn(j), currx(j+1));
            xfr{j-1} = reshape(u, currx(j-1), curn(j-1), r);
            currx(j) = r;
            rfx(j,i)=r;
            r_max = max(r_max, r);

            % update right-to-left projections at position j
            phfy{i}{j} = (v')*y2;

            phfA{i}{j} = compute_next_Phi(Phi2, xfr{j}, a2, xfr{j}, 'rl');

            ccr = ones(n(j,i), 1);

            cphfy{i}{j} = (ccr')*(cy2.');

            cphfA{i}{j} = compute_next_Phi(cPhi2, ccr.', ca2, xfr{j}, 'rl');
        end;

        % --- Forward pass over the chain: QR for j<L(i); at j==L(i) split
        %     the merged block back into factor block and core block by a
        %     residual-controlled SVD truncation.
        for j=1:L(i)
            cr = xfr{j};

            if (j<L(i))
                cr = reshape(cr, rfx(j,i)*n(j,i), rfx(j+1,i));
                [cr, rv] = qr(cr, 0);
                cr2 = xfr{j+1};
                ncur = size(cr2, 2);
                r3cur = size(cr2, 3);
                cr2 = reshape(cr2, rfx(j+1,i), ncur*r3cur);
                cr2 = rv*cr2;
                rfx(j+1,i) = size(cr, 2);
                xfr{j} = reshape(cr, rfx(j,i), n(j,i), rfx(j+1,i));
                xfr{j+1} = reshape(cr2, rfx(j+1,i), ncur, r3cur);
            else
                cr = reshape(cr, rfx(j,i)*n(j,i), rcx(i)*rcx(i+1));
                [u,s,v]=svd(cr, 'econ');

                Phi2 = phcA{i+1};
                curA2 = reshape(a1left, rta, rcx(i), rcx(i), rcA(i+1));
                curA1 = Af{i}{L(i)};
                Phi1 = phfA{i}{j};

                rhs = reshape(yf{i}{j}, rfy(j,i), n(j,i)*rty);
                rhs = phfy{i}{j}*rhs;
                rhs = reshape(rhs, rfx(j,i)*n(j,i), rty);
                rhs = rhs*(y1top.');
                rhs = reshape(rhs, rfx(j,i)*n(j,i)*rcx(i)*rcx(i+1),1);
                r = 1;
                normy = norm(rhs);

                % smallest rank whose residual stays within the threshold
                res_true = norm(bfun3(Phi1, curA1, curA2, Phi2, cr)-rhs)/normy;
                while (r<=size(s,1))
                    cursol = u(:,1:r)*s(1:r,1:r)*(v(:,1:r)');

                    res = norm(bfun3(Phi1, curA1, curA2, Phi2, cursol)-rhs)/normy;
                    if (res<max(tol/sqrt(L(i)), res_true*2))
                        break;
                    end;
                    r = r+1;
                end;
                % If the test never passed (e.g. NaN residual when rhs is
                % zero) the loop leaves r one past the last singular value;
                % clamp as local_solve does so the slicing below stays valid.
                r = min(r, size(s,1));
                if (verb>1)
                    fprintf('=rake_solve2= swp %d, tuckerrank {%d}, res: %3.3e, r: %d\n', swp, i, res, r);
                end;
                r = min(rmax, r);
                u = u(:,1:r);
                v = conj(v(:,1:r));
                s = s(1:r,1:r);
                v = v*s;
                if (~last_sweep)
                    % enlarge the factor-side basis by kickrank random directions
                    u = reort(u, randn(rfx(j,i)*n(j,i), kickrank));
                end;
                radd = size(u,2)-r;
                v = [v, zeros(rcx(i)*rcx(i+1), radd)];
                r = r+radd;
                rfx(j+1,i) = r;
                cr = u;
                xfr{j} = reshape(cr, rfx(j,i), n(j,i), r);
                v = reshape(v.', r, rcx(i), rcx(i+1));
                v = permute(v, [2, 1, 3]);
                xc{i} = v;
                r_max = max(r_max, r);
            end;

            cr = reshape(cr, rfx(j,i), n(j,i), rfx(j+1,i));
            if (j<L(i))
                phfA{i}{j+1} = compute_next_Phi(phfA{i}{j}, cr, Af{i}{j}, cr, 'lr');
                phfy{i}{j+1} = compute_next_Phi(phfy{i}{j}, cr, [], yf{i}{j}, 'lr');
            else
                phAfc{i} = compute_next_Phi(phfA{i}{j}, cr, Af{i}{j}, cr, 'lr');
                phyfc{i} = compute_next_Phi(phfy{i}{j}, cr, [], yf{i}{j}, 'lr');
            end;

            ccr = ones(1, n(j,i), 1);
            if (j<L(i))
                cphfA{i}{j+1} = compute_next_Phi(cphfA{i}{j}, ccr, Af{i}{j}, cr, 'lr');
                cphfy{i}{j+1} = compute_next_Phi(cphfy{i}{j}, ccr, [], yf{i}{j}, 'lr');
            else
                cphAfc{i} = compute_next_Phi(cphfA{i}{j}, ccr, Af{i}{j}, cr, 'lr');
                cphyfc{i} = compute_next_Phi(cphfy{i}{j}, ccr, [], yf{i}{j}, 'lr');
            end;
        end;
        xf{i} = xfr;

        % --- Two-block solve on the core chain for the pair (i, i+1).
        if (i<d)
            rtx = rfx(L(i)+1,i); rtx2 = rfx(L(i+1)+1,i+1);
            rx1 = rcx(i); rx2 = rcx(i+1); rx3 = rcx(i+2);
            ra2 = rcA(i+1); ra3 = rcA(i+2);
            ry2 = rcy(i+1); ry3 = rcy(i+2);

            rtchk = rfchk(L(i)+1,i); rtchk2 = rfchk(L(i+1)+1,i+1);
            rchk1 = rcchk(i); rchk3 = rcchk(i+2);

            Phi1 = phcA{i};
            a1 = core_matrix(Ac{i}, phAfc{i});
            a2 = Acr{i+1};
            Phi2 = phcA{i+2};

            cPhi1 = cphcA{i};
            ca1 = core_matrix(Ac{i}, cphAfc{i});
            ca2 = cAcr{i+1};
            cPhi2 = cphcA{i+2};

            y1 = reshape(y1left, rx1, rty, ry2);
            y1 = core_vector(y1, phyfc{i});
            y1 = reshape(y1, rx1*rtx, ry2);

            y2 = phcy{i+2}.';
            y2 = reshape(ycr{i+1}, ry2*rtx2, ry3)*y2;
            y2 = reshape(y2, ry2, rtx2*rx3);
            y2 = y2.';

            x1 = reshape(xc{i}, rx1*rtx, rx2);
            x2 = reshape(xc{i+1}, rx2, rtx2*rx3);
            x2 = x2.';

            if (verb>1)
                fprintf('=rake_solve2= swp %d, core {%d}, ', swp, i);
            end;
            local_format = 'full';
            if (rx1*rtx*rtx2*rx3>max_full_size2)
                local_format = 'tt';
            end;

            [u,s,v,r,dx_max,res_max]=local_solve(Phi1, a1, a2, Phi2, y1, y2, x1, x2, ...
                rx1, rtx, rtx2, rx3, ra2, ...
                tol/sqrt(d), res_max, dx_max, resid_damp_loc,...
                local_format, max_full_size, nrestart, gmres_iters, verb);

            r = min(rmax, r);
            u = u(:,1:r); s = s(1:r,1:r); v = v(:,1:r);
            v = v*s;

            % check residual for the core pair
            Asol = bfun3(cPhi1, ca1, ca2, cPhi2, u*(v.'));
            cy1 = reshape(cy1left, rchk1, rty, ry2);
            cy1 = core_vector(cy1, cphyfc{i});
            cy1 = reshape(cy1, rchk1*rtchk, ry2);
            cy2 = cphcy{i+2}.';
            cy2 = reshape(cycr{i+1}, ry2*rtchk2, ry3)*cy2;
            cy2 = reshape(cy2, ry2, rtchk2*rchk3);
            cy = cy1*cy2;
            chk_res_max = max(chk_res_max, norm(Asol-cy(:))/norm(cy(:)));

            if (~last_sweep)
                % enlarge the left basis by kickrank random directions
                u = reort(u, randn(rx1*rtx, kickrank));
            end;
            radd = size(u,2)-r;
            v = [v, zeros(rtx2*rx3, radd)];
            r = r+radd;
            xc{i} = reshape(u, rx1, rtx, r);
            xc{i+1} = reshape(v.', r, rtx2, rx3);
            rcx(i+1)=r;
            r_max = max(r_max, r);

            % update left-to-right core projections
            phcy{i+1} = (u')*y1;

            phcA{i+1} = compute_next_Phi(phcA{i}, xc{i}, a1, xc{i}, 'lr');

            ccr = 1;
            cphcy{i+1} = (ccr')*cy1;
            ccr = reshape(ccr, rchk1, rtchk, rcchk(i+1));
            cphcA{i+1} = compute_next_Phi(cphcA{i}, ccr, ca1, xc{i}, 'lr');
        end;
    end;

    if (verb>0)
        real_res = NaN;   % the true residual is not evaluated in this version
        fprintf('\n=rake_solve2= swp %d, dx_max: %3.3e, res_max: %3.3e, r_max: %d, real_res: %3.3e, chk_res: %3.3e\n\n', swp, dx_max, res_max, r_max, real_res, chk_res_max);
    end;
    if (last_sweep)
        break;
    end;
    if (chk_res_max<tol)||(swp==nswp-1)
        break;
    end;
end;

x = qtt_tucker;
x.dphys = d;
x.tuck = xf;
x.core = xc;
end
0905
0906
function [Phi] = compute_next_Phi(Phi_prev, x, A, y, direction)
%COMPUTE_NEXT_PHI Extend a projection by one block: Phi = x' * A * y on top of Phi_prev.
%   Phi_prev   previous projection, reshaped here to [rx1*ra1, ry1]
%   x          3-D test block, sizes read as (rx1, n, rx2)
%   A          4-D operator block, sizes read as (ra1, n, m, ra2); pass []
%              to project a vector block (then ra1 = ra2 = 1 and the
%              result is left as an rx2-by-ry2 matrix)
%   y          3-D block, sizes read as (ry1, m, ry2)
%   direction  'rl' reverses the rank indices of x, y and A before the
%              contraction; any other value uses the blocks as given
%   Returns Phi of size [rx2, ra2, ry2] when A is given.

has_op = ~isempty(A);

% For a right-to-left step, mirror every block so the same contraction applies.
if (strcmp(direction, 'rl'))
    x = permute(x, [3, 2, 1]);
    y = permute(y, [3, 2, 1]);
    if (has_op)
        A = permute(A, [4, 2, 3, 1]);
    end
end

rx_in = size(x,1);  nx = size(x,2);  rx_out = size(x,3);
ry_in = size(y,1);  ny = size(y,2);  ry_out = size(y,3);
ra_in = 1;
ra_out = 1;
if (has_op)
    ra_in = size(A,1);
    ra_out = size(A,4);
end

% Step 1: absorb y.
acc = reshape(Phi_prev, [rx_in*ra_in, ry_in]) * reshape(y, [ry_in, ny*ry_out]);
acc = permute(reshape(acc, [rx_in, ra_in, ny, ry_out]), [2, 3, 1, 4]);

% Step 2: absorb the operator block, if any.
if (has_op)
    Amat = reshape(permute(A, [4, 2, 1, 3]), [ra_out*nx, ra_in*ny]);
    acc = Amat * reshape(acc, [ra_in*ny, rx_in*ry_out]);
    acc = reshape(acc, [ra_out, nx, rx_in, ry_out]);
end

% Step 3: absorb the (conjugated) test block x.
acc = reshape(permute(acc, [3, 2, 1, 4]), [rx_in*nx, ra_out*ry_out]);
Phi = reshape(x, [rx_in*nx, rx_out])' * acc;
if (has_op)
    Phi = reshape(Phi, [rx_out, ra_out, ry_out]);
end
end
0951
function [A] = core_matrix(core_block, Phi_factor)
%CORE_MATRIX Contract the middle index of a core block with a 3-D projection.
%   core_block  array with sizes read as (r1, rtuck, r2)
%   Phi_factor  array with sizes read as (n1, rtuck, n2)
%   Returns A of size (r1, n1, n2, r2), where
%   A(a,p,q,b) = sum_t Phi_factor(p,t,q) * core_block(a,t,b).

rank_left  = size(core_block, 1);
rank_tuck  = size(core_block, 2);
rank_right = size(core_block, 3);
rows = size(Phi_factor, 1);
cols = size(Phi_factor, 3);

% Bring the contracted index last in the projection and first in the core.
proj = reshape(permute(Phi_factor, [1, 3, 2]), rows*cols, rank_tuck);
unfolded = reshape(permute(core_block, [2, 1, 3]), rank_tuck, rank_left*rank_right);

A = reshape(proj*unfolded, rows, cols, rank_left, rank_right);
A = permute(A, [3, 1, 2, 4]);

end
0971
0972
function [A] = core_vector(core_block, Phi_factor)
%CORE_VECTOR Apply a projection matrix to the middle index of a core block.
%   core_block  array with sizes read as (r1, rtuck, r2)
%   Phi_factor  matrix used as (n1, rtuck)
%   Returns A of size (r1, n1, r2), where
%   A(a,p,b) = sum_t Phi_factor(p,t) * core_block(a,t,b).

rank_left  = size(core_block, 1);
rank_tuck  = size(core_block, 2);
rank_right = size(core_block, 3);
rows = size(Phi_factor, 1);

% Unfold with the contracted index first, multiply, and fold back.
unfolded = reshape(permute(core_block, [2, 1, 3]), rank_tuck, rank_left*rank_right);
A = permute(reshape(Phi_factor*unfolded, rows, rank_left, rank_right), [2, 1, 3]);

end
0990
0991
function [y]=bfun2(B, x, rxm1, m1, m2, rxm3, rxn1, k1, k2, rxn3)
%BFUN2 Matrix-vector product with a two-block operator stored in a cell array.
%   B{1} is used as an array whose third dimension is the coupling rank rB
%   and is unfolded to [rB*rxn1*k1, rxm1*m1]; B{2} is unfolded to
%   [k2*rxn3, m2*rxm3*rB]. x is a vector of length rxm1*m1*m2*rxm3.
%   Returns a column vector of length rxn1*k1*k2*rxn3.

rank_mid = size(B{1},3);
rows_out1 = rxn1*k1;
cols_in2  = m2*rxm3;

% First block acts on the leading (rxm1*m1) part of x.
left_op = reshape(permute(B{1}, [3 1 2]), rank_mid*rows_out1, rxm1*m1);
y = left_op * reshape(x, rxm1*m1, cols_in2);

% Move the coupling rank next to the trailing input index.
y = permute(reshape(y, rank_mid, rows_out1, cols_in2), [3 1 2]);
y = reshape(y, cols_in2*rank_mid, rows_out1);

% Second block acts on (m2*rxm3*rB).
right_op = reshape(B{2}, k2*rxn3, cols_in2*rank_mid);
y = right_op * y;

y = reshape(y.', rows_out1*k2*rxn3, 1);
end
1008
1009
function [y]=bfun3(Phi1,B1,B2,Phi2, x)
%BFUN3 Apply the projected two-block operator Phi1 * B1 * B2 * Phi2 to x.
%   Phi1, Phi2  3-D projections, sizes read as (ry, rb, rx)
%   B1, B2      4-D operator blocks, sizes read as (rb_in, k, m, rb_out)
%   x           array with rx1*m1*m2*rx3 elements
%   Returns a column vector of length ry1*k1*k2*ry3.

ry_l = size(Phi1,1);   ry_r = size(Phi2,1);
rx_l = size(Phi1,3);   rx_r = size(Phi2,3);
rb_l = size(B1,1);     rb_m = size(B1,4);    rb_r = size(B2,4);
m_a = size(B1,3);      m_b = size(B2,3);
k_a = size(B1,2);      k_b = size(B2,2);

% Left projection.
y = reshape(Phi1, ry_l*rb_l, rx_l) * reshape(x, rx_l, m_a*m_b*rx_r);
y = permute(reshape(y, ry_l, rb_l*m_a, m_b, rx_r), [2, 1, 3, 4]);
y = reshape(y, rb_l*m_a, ry_l*m_b*rx_r);

% First operator block.
op1 = reshape(permute(B1, [2, 4, 1, 3]), k_a*rb_m, rb_l*m_a);
y = op1 * y;
y = permute(reshape(y, k_a, rb_m, ry_l, m_b, rx_r), [2, 4, 3, 1, 5]);
y = reshape(y, rb_m*m_b, ry_l*k_a*rx_r);

% Second operator block.
op2 = reshape(permute(B2, [2, 4, 1, 3]), k_b*rb_r, rb_m*m_b);
y = op2 * y;
y = permute(reshape(y, k_b, rb_r, ry_l*k_a, rx_r), [2, 4, 3, 1]);
y = reshape(y, rb_r*rx_r, ry_l*k_a*k_b);

% Right projection and vectorisation.
y = reshape(Phi2, ry_r, rb_r*rx_r) * y;
y = reshape(y.', ry_l*k_a*k_b*ry_r, 1);
end
1045
1046
1047
1048
1049
function [u,s,v,r,dx_max,res_max]=local_solve(Phi1,a1, a2, Phi2, y1, y2, x1, x2, ...
    rx1, n1, n2, rx3, ra2, ...
    real_tol, res_max, dx_max, resid_damp, ...
    local_format, max_full_size, nrestart, gmres_iters, verb)
%LOCAL_SOLVE Solve one projected two-block system and split the solution.
%   Phi1, Phi2      left / right projections of the operator
%   a1, a2          the two operator blocks, coupled through rank ra2
%   y1, y2          factors of the right-hand side, rhs = y1*y2.'
%   x1, x2          factors of the previous solution, sol_prev = x1*x2.'
%   rx1,n1,n2,rx3   sizes of the unknown: it is reshaped to [rx1*n1, n2*rx3]
%   real_tol        target relative residual
%   res_max,dx_max  running maxima, updated and returned
%   resid_damp      damping factor: GMRES is asked for real_tol/resid_damp,
%                   truncation may raise the residual up to
%                   max(real_tol, res_true*resid_damp)
%   local_format    'full' (vectorised solve) or anything else (two-block
%                   path built on tt_* routines defined outside this file)
%   max_full_size   in 'full' format, systems smaller than this are
%                   assembled densely; larger ones are applied through bfun3
%   nrestart, gmres_iters   restart length and outer iterations for GMRES
%   verb            >1 prints a diagnostic line
%
%   Returns u, s, v and rank r with sol ~ u*s*v.' (v is already conjugated
%   in the 'full' branch).

if (strcmp(local_format, 'full'))
    sol_prev = x1*(x2.');
    sol_prev = sol_prev(:);
    rhs = y1*(y2.');
    rhs = rhs(:);
    normy = norm(rhs);
    % Small systems are assembled as a dense matrix, large ones stay matrix-free.
    use_dense = (rx1*n1*n2*rx3<max_full_size);
    if (use_dense)
        % Assemble B = Phi1 * a1 * a2 * Phi2 as a square matrix.
        B = permute(Phi1, [1,3,2]);
        B = reshape(B, rx1*rx1, size(a1,1));
        a1 = reshape(a1, size(a1,1), n1*n1*ra2);
        B = B*a1;
        B = reshape(B, rx1, rx1, n1, n1, ra2);
        B = permute(B, [1, 3, 2, 4, 5]);
        B = reshape(B, rx1*n1*rx1*n1, ra2);
        ra3 = size(a2,4);
        a2 = reshape(a2, ra2, n2*n2*ra3);
        B = B*a2;
        B = reshape(B, rx1*n1*rx1*n1*n2*n2, ra3);
        Phi2 = permute(Phi2, [2, 1, 3]);
        Phi2 = reshape(Phi2, ra3, rx3*rx3);
        B = B*Phi2;
        B = reshape(B, rx1*n1, rx1*n1, n2, n2, rx3, rx3);
        B = permute(B, [1, 3, 5, 2, 4, 6]);

        B = reshape(B, rx1*n1*n2*rx3, rx1*n1*n2*rx3);
        % Normal-equations solve, used as the starting guess for GMRES.
        sol = (B'*B) \ (B'*rhs);

        [sol,flg]=gmres(B, rhs, ...
            nrestart, real_tol/resid_damp, gmres_iters, [], [], sol);
        if (flg>0)
            fprintf('--warn-- gmres did not converge\n');
        end;

        res_true = norm(B*sol-rhs)/normy;
        res_prev = norm(B*sol_prev-rhs)/normy;
    else
        res_prev = norm(bfun3(Phi1, a1, a2, Phi2, sol_prev)-rhs)/normy;

        [sol,flg]=gmres(@(v)bfun3(Phi1, a1, a2, Phi2, v), rhs, ...
            nrestart, max(real_tol/resid_damp,res_prev*0), gmres_iters, [], [], sol_prev);
        if (flg>0)
            fprintf('--warn-- gmres did not converge\n');
        end;

        res_true = norm(bfun3(Phi1, a1, a2, Phi2, sol)-rhs)/normy;
    end;

    if ((res_prev/res_true)<resid_damp)&&(res_true>real_tol/resid_damp)
        fprintf('--warn-- the residual damp by gmres was smaller than in the truncation\n');
    end;

    dx = norm(sol-sol_prev)/norm(sol);
    dx_max = max(dx_max, dx);
    if (use_dense)
        Bx = B*sol; Bx_prev = B*sol_prev;
    else
        Bx = bfun3(Phi1, a1, a2, Phi2, sol);
        Bx_prev = bfun3(Phi1, a1, a2, Phi2, sol_prev);
    end;
    res_max = max(res_max, norm(Bx-Bx_prev)/norm(Bx));

    % Rank selection by residual: bisection to bracket the rank, followed
    % by a linear scan upwards from the bracket.
    sol = reshape(sol, rx1*n1, n2*rx3);
    [u,s,v]=svd(sol, 'econ');
    s = diag(s);
    r1 = 1; r2 = numel(s); r = round((r1+r2)/2);
    while (r2-r1>1)
        cursol = u(:,1:r)*diag(s(1:r))*(v(:,1:r)');
        if (use_dense)
            res = norm(B*cursol(:)-rhs)/normy;
        else
            res = norm(bfun3(Phi1, a1, a2, Phi2, cursol)-rhs)/normy;
        end;
        if (res<max(real_tol, res_true*resid_damp))
            r2 = r;
        else
            r1 = r;
        end;
        r = round((r1+r2)/2);
    end;

    while (r<=numel(s))
        cursol = u(:,1:r)*diag(s(1:r))*(v(:,1:r)');
        if (use_dense)
            res = norm(B*cursol(:)-rhs)/normy;
        else
            res = norm(bfun3(Phi1, a1, a2, Phi2, cursol)-rhs)/normy;
        end;
        if (res<max(real_tol, res_true*resid_damp))
            break;
        end;
        r = r+1;
    end;
    % the scan can leave r one past the last singular value
    r = min(r, numel(s));
    if (verb>1)
        fprintf('dx: %3.3e, res: %3.3e, res_prev: %3.3e, r: %d\n', dx, res, res_prev, r);
    end;

    s = diag(s(1:r));
    u = u(:,1:r);
    v = conj(v(:,1:r));   % callers rebuild the solution as u*s*v.'
else
    % Two-block path using tt_* routines defined outside this file.
    % NOTE(review): B holds only a1 and a2 here; Phi1/Phi2 are not applied
    % in this branch -- confirm that this is intended.
    B = cell(2,1);
    B{1} = a1;
    B{2} = a2;

    iB = [];
    sol_prev = cell(2,1);
    sol_prev{1} = x1;
    sol_prev{2} = x2;
    rhs = cell(2,1);
    rhs{1} = y1;
    rhs{2} = y2;
    normy = tt_dist3(rhs, tt_scal(rhs,0));
    drhs = tt_mv(B, sol_prev);
    res_prev = tt_dist3(drhs, rhs)/normy;
    drhs = tt_add(rhs, tt_scal(drhs, -1));
    drhs = tt_compr2(drhs, real_tol);
    % Use this function's own damping parameter: the original referenced
    % resid_damp_loc, which is not defined in this scope.
    dsol = tt_gmres(B, drhs, real_tol*resid_damp/res_prev, gmres_iters, nrestart, real_tol, real_tol, iB, [], [], [], 1);

    sol = tt_add(sol_prev, dsol);
    sol = tt_compr2(sol, real_tol);
    normsol = tt_dist3(sol, tt_scal(sol,0));
    dx = tt_dist3(sol, sol_prev)/normsol;
    res = tt_dist3(tt_mv(B, sol), rhs)/normy;

    dx_max = max(dx_max, dx);
    res_max = max(res_max, res_prev);
    % Orthogonalise the second block, push its R factor into the first, then QR that.
    [v, s]=qr(sol{2}, 0);
    sol{1} = sol{1}*(s.');
    [u, s]=qr(sol{1}, 0);
    r = size(sol{1},2);
    if (verb>1)
        fprintf('dx: %3.3e, res: %3.3e, res_prev: %3.3e, r: %d\n', dx, res, res_prev, r);
    end;
end;
end