function [y,swp]=tt_mvk3(W, x, eps, varargin)
%TT_MVK3 TT approximation of the matrix-by-vector product W*x by two-site sweeps.
%   [y,swp] = TT_MVK3(W, x, eps) returns a TT approximation y of W*x, where
%   W is a TT-matrix and x a TT-tensor, both given as cell arrays of cores.
%   Every sweep first goes right to left, QR-orthogonalizing the cores of y
%   and building the right interface projections. It then goes left to
%   right and replaces each neighbouring pair of cores (y{i}, y{i+1}) by a
%   truncated SVD of the projection of W*x onto the current interfaces.
%
%   [y,swp] = TT_MVK3(W, x, eps, 'name', value, ...) sets options:
%     'nswp'        maximal number of sweeps (default 25)
%     'rmax'        maximal rank kept by the SVD truncation (default 1000);
%                   the 'kickrank' enrichment is applied after this cap
%     'y0'          initial guess, tt_tensor or cell array of cores
%                   (default: the all-ones tensor rescaled via tt_scal2)
%     'verb'        0 - silent, 1 - one line per sweep (default),
%                   >1 - additionally one line per block
%     'kickrank'    number of random columns appended (via reort) to the
%                   left factor after each block update (default 5)
%     'ddpow'       step of the per-block truncation exponent (default 0.1)
%     'ddrank'      step of the per-block rank surplus (default 1)
%     'd_pow_check' convergence threshold is eps/d^d_pow_check (default 0)
%     'bot_conv'    if dy/dy_old falls below this, the surplus and the
%                   exponent are decreased (default 0.1)
%     'top_conv'    if dy/dy_old exceeds this while dy is above the
%                   threshold, they are increased (default 0.99)
%   Options with an empty value are ignored; unknown names raise an error.
%
%   Core layout, as indexed by this code:
%     x{1} is n1 x r1, x{k} is n_k x r_{k-1} x r_k (x{d} may be 2-D);
%     d = size(x,1).
%     W{1} is n1 x n1 x R1, W{k} is n_k x n_k x R_{k-1} x R_k. The first
%     index of a W core is the output mode, the second is contracted with
%     the mode of x. Mode sizes are assumed square - TODO confirm.
%
%   Outputs
%     y   - cell array of cores in the same layout as x (y{1} is 2-D)
%     swp - index of the last sweep performed
%
%   Stopping rule: let dy be the maximal relative change of the two-site
%   blocks, ||rhs - y_prev||_F / ||rhs||_F. Once dy < eps/d^d_pow_check,
%   one more "last" sweep is done (no rank surplus, no kick, truncation
%   exponent 0.5) and the iteration ends. A warning is printed if nswp
%   sweeps end without reaching the threshold.
%
%   Relies on TT-Toolbox helpers: core, tt_ones, tt_size, tt_scal2,
%   tt_dot2, my_chop2, reort.

% ---- Fixed algorithm parameters -----------------------------------------
kickrank = 5;
% With dropsweeps = 1, mod(swp,dropsweeps) is always 0. The "correction"
% mode below (SVD of rhs-yprev appended to the previous cores) is therefore
% never entered; it is kept for experimentation.
dropsweeps = 1;

ddpow = 0.1;
ddrank = 1;
d_pow_check = 0;
bot_conv = 0.1;
top_conv = 0.99;
verb = 1;

d = size(x,1);

% Default initial guess: the all-ones tensor, rescaled by tt_scal2
% (presumably to unit norm - depends on the tt_dot2/tt_scal2 conventions).
y = core(tt_ones(tt_size(x)));
y = tt_scal2(y, -tt_dot2(y,y)/2, 1);

rmax=1000;
nswp=25;

% ---- Optional name/value arguments ---------------------------------------
for i=1:2:length(varargin)-1
    if (~isempty(varargin{i+1}))
        switch lower(varargin{i})
            case 'nswp'
                nswp=varargin{i+1};
            case 'rmax'
                % Numeric rank cap: take the value as is. It was previously
                % passed through lower(), which is meant for text.
                rmax=varargin{i+1};
            case 'y0'
                y=varargin{i+1};
            case 'verb'
                verb=varargin{i+1};
            case 'kickrank'
                kickrank=varargin{i+1};
            case 'ddpow'
                ddpow=varargin{i+1};
            case 'ddrank'
                ddrank=varargin{i+1};
            case 'd_pow_check'
                d_pow_check=varargin{i+1};
            case 'bot_conv'
                bot_conv=varargin{i+1};
            case 'top_conv'
                top_conv=varargin{i+1};
            otherwise
                error('Unrecognized option: %s\n',varargin{i});
        end
    end;
end

if ( isa(y,'tt_tensor') )
    y=core(y);
end

% Give the first cores an explicit singleton left rank, so that every core
% is handled by the same 3-D (4-D for W) index arithmetic.
x{1}=reshape(x{1}, size(x{1},1),1,size(x{1},2));
y{1}=reshape(y{1}, size(y{1},1),1,size(y{1},2));
W{1}=reshape(W{1}, size(W{1},1), size(W{1},2), 1,size(W{1},3));

% phiywx{k} is an (ry x rw x rx) interface projection of y'*W*x.
% phiywx{1} and phiywx{d+1} are the trivial boundary interfaces. The
% right-to-left sweep fills in right interfaces; the left-to-right sweep
% overwrites them with left interfaces as it advances.
phiywx = cell(d+1,1); phiywx{1}=1; phiywx{d+1}=1;

last_sweep = false;
dy_old = ones(d-1,1);   % relative block changes of the previous sweep
dy = zeros(d-1,1);      % relative block changes of the current sweep
drank = ones(d-1,1);    % per-block rank surplus added after truncation
dpows = ones(d-1,1);    % per-block exponent p of the threshold eps/d^p

for swp=1:nswp
    % ---- Right-to-left sweep: orthogonalize y, build right interfaces ----
    for i=d:-1:2
        % QR of the (n1*r2) x r1 unfolding of y{i}; R is pushed into y{i-1}.
        y1 = y{i}; n1 = size(y1,1); r1 = size(y1,2); r2 = size(y1,3);
        y1 = permute(y1, [1 3 2]);
        y1 = reshape(y1, n1*r2, r1);
        [y1,rv]=qr(y1,0);
        y2 = y{i-1}; n2 = size(y2,1); r0 = size(y2,2);
        y2 = reshape(y2, n2*r0, r1);
        y2 = y2*(rv.');
        r1 = size(y1,2);
        y1 = reshape(y1, n1, r2, r1);
        y1 = permute(y1, [1 3 2]);
        y{i}=y1;
        y{i-1}=reshape(y2, n2,r0,r1);

        % phiywx{i}: contract phiywx{i+1} with x{i}, then W{i}, then y{i}.
        y1 = permute(y1, [3 1 2]);
        y1 = reshape(y1, r2*n1, r1);
        x1 = x{i}; rx1 = size(x1,2); rx2 = size(x1,3);
        x1 = reshape(x1, n1*rx1, rx2);

        w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
        phiywx{i}=reshape(phiywx{i+1}, r2*rw2, rx2)*(x1.');
        phiywx{i}=reshape(phiywx{i}, r2, rw2, n1, rx1);
        phiywx{i}=permute(phiywx{i}, [3 2 4 1]);
        phiywx{i}=reshape(phiywx{i}, n1*rw2, rx1*r2);
        w1 = permute(w1, [1 3 2 4]);
        w1 = reshape(w1, n1*rw1, n1*rw2);
        phiywx{i}=w1*phiywx{i};
        phiywx{i}=reshape(phiywx{i}, n1, rw1, rx1, r2);
        phiywx{i}=permute(phiywx{i}, [4 1 2 3]);
        phiywx{i}=reshape(phiywx{i}, r2*n1, rw1*rx1);
        phiywx{i}=(y1')*phiywx{i};
        phiywx{i}=reshape(phiywx{i}, r1, rw1, rx1);
    end;

    % Correction-mode flag (see dropsweeps above). Its inputs do not change
    % during the sweep, so it is evaluated once here.
    corr_mode = (mod(swp,dropsweeps)~=0)&&(swp>1)&&(~last_sweep);

    % ---- Left-to-right sweep: two-site updates of (y{i}, y{i+1}) ---------
    for i=1:d-1
        x1 = x{i}; n1 = size(x1,1); rx1 = size(x1,2); rx2 = size(x1,3);
        x2 = x{i+1}; n2 = size(x2,1); rx3 = size(x2,3);
        w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
        w2 = W{i+1}; rw3 = size(w2,4);
        y1 = y{i}; ry1 = size(y1,2); ry2 = size(y1,3);
        y2 = y{i+1}; ry3 = size(y2,3);

        x1 = permute(x1, [2 1 3]);
        x1 = reshape(x1, rx1, n1*rx2);
        x2 = permute(x2, [2 1 3]);
        x2 = reshape(x2, rx2*n2, rx3);

        % Left half: phiywx{i} * x{i} * W{i} -> (ry1*n1) x (rw2*rx2).
        rhs = reshape(phiywx{i}, ry1*rw1, rx1)*x1;
        rhs = reshape(rhs, ry1, rw1, n1, rx2);
        rhs = permute(rhs, [3 2 1 4]);
        rhs = reshape(rhs, n1*rw1, ry1*rx2);
        w1 = permute(w1, [1 4 2 3]);
        w1 = reshape(w1, n1*rw2, n1*rw1);
        rhs = w1*rhs;
        rhs = reshape(rhs, n1,rw2,ry1, rx2);
        rhs = permute(rhs, [3 1 2 4]);
        rhs = reshape(rhs, ry1*n1, rw2*rx2);

        % Right half: W{i+1} * x{i+1} * phiywx{i+2} -> (n2*ry3) x (rw2*rx2).
        rhs2 = reshape(phiywx{i+2}, ry3*rw3, rx3);
        rhs2 = rhs2*(x2.');
        rhs2 = reshape(rhs2, ry3, rw3, rx2, n2);
        rhs2 = permute(rhs2, [4 2 3 1]);
        rhs2 = reshape(rhs2, n2*rw3, rx2*ry3);
        w2 = permute(w2, [1 3 2 4]);
        w2 = reshape(w2, n2*rw2, n2*rw3);
        rhs2 = w2*rhs2;
        rhs2 = reshape(rhs2, n2, rw2*rx2, ry3);
        rhs2 = permute(rhs2, [1 3 2]);
        rhs2 = reshape(rhs2, n2*ry3, rw2*rx2);
        % Projected two-site block of W*x: (ry1*n1) x (n2*ry3).
        rhs = rhs*(rhs2.');

        % Current two-site block of y, in the same unfolding.
        y1 = permute(y1, [2 1 3]);
        y1 = reshape(y1, ry1*n1, ry2);
        y2 = permute(y2, [2 1 3]);
        y2 = reshape(y2, ry2, n2*ry3);
        yprev = y1*y2;
        if (corr_mode)
            norm_y1 = norm(y1, 'fro');
            y1 = y1/norm_y1;
            y2 = y2*norm_y1;
        end;

        % Relative change of the block; defined as 0 if rhs vanishes.
        nrm_rhs = norm(rhs, 'fro');
        if (nrm_rhs==0)
            dy(i)=0;
        else
            dy(i) = norm(rhs-yprev, 'fro')/nrm_rhs;
        end;
        if (swp==1)
            dy_old(i)=dy(i);
        end;

        % In the last sweep the current y block is only recompressed.
        if (last_sweep)
            rhs=yprev;
        end;

        % Adapt the rank surplus and truncation exponent to the observed
        % convergence rate dy/dy_old.
        if (dy(i)/dy_old(i)>top_conv)&&(dy(i)>eps/(d^d_pow_check))
            drank(i)=drank(i)+ddrank;
            dpows(i)=dpows(i)+ddpow;
        end;
        if (dy(i)/dy_old(i)<bot_conv)||(dy(i)<eps/(d^d_pow_check))
            drank(i)=max(drank(i)-ddrank, 1);
            dpows(i)=max(dpows(i)-ddpow, 1);
        end;
        if (last_sweep)
            dpows(i)=0.5;
        end;

        % Truncated SVD with threshold eps/d^dpows(i) relative to ||rhs||_F.
        if (corr_mode)
            [u,s,v]=svd(rhs-yprev,'econ');
        else
            [u,s,v]=svd(rhs, 'econ');
        end;
        r = my_chop2(diag(s), eps/(d^dpows(i))*norm(rhs,'fro'));
        if (~last_sweep)
            r = r+drank(i);
        end;
        r = min(r, max(size(s)));
        r = min(r, rmax);

        if (verb>1)
            fprintf('sweep %d, block %d, rank: %d, dy: %3.3e, dy_old: %3.3e, drank: %g, dpow: %g\n', swp, i, r, dy(i), dy_old(i), drank(i), dpows(i));
        end;

        % rhs ~ u*s*v', so the right factor is kept as (s*v').' = conj(v)*s.
        u = u(:,1:r);
        v = conj(v(:,1:r))*s(1:r,1:r);
        if (corr_mode)
            % Append the correction to the previous cores, re-orthogonalize.
            u = [y1, u];
            v = [y2.', v];
            [u,rv]=qr(u,0);
            ry2 = size(u,2);
            v = v*(rv.');
        else
            if (~last_sweep)
                % Enrich u with kickrank random directions; v gets matching
                % zero columns. Presumably reort keeps u's columns first, so
                % u*v.' is unchanged - TODO confirm against reort.
                u = reort(u, randn(size(u,1),kickrank));
                r = size(u,2);
                v = [v, zeros(size(v,1),r-size(v,2))];
            end;
            ry2 = size(u,2);
        end;

        y{i}=permute(reshape(u, ry1, n1, ry2), [2 1 3]);
        y{i+1}=permute(reshape(v, n2, ry3, ry2), [1 3 2]);

        % Left interface phiywx{i+1} from phiywx{i}, x{i}, W{i} and new y{i}.
        x1 = x{i}; rx1 = size(x1,2); rx2 = size(x1,3);
        x1 = reshape(permute(x1, [2 1 3]), rx1, n1*rx2);

        w1 = W{i}; rw1 = size(w1,3); rw2 = size(w1,4);
        phiywx{i+1}=reshape(phiywx{i}, ry1*rw1, rx1)*x1;
        phiywx{i+1}=reshape(phiywx{i+1}, ry1, rw1, n1, rx2);
        phiywx{i+1}=permute(phiywx{i+1}, [3 2 4 1]);
        phiywx{i+1}=reshape(phiywx{i+1}, n1*rw1, rx2*ry1);
        w1 = permute(w1, [2 3 1 4]);
        w1 = reshape(w1, n1*rw1, n1*rw2);
        phiywx{i+1}=(w1.')*phiywx{i+1};
        phiywx{i+1}=reshape(phiywx{i+1}, n1, rw2, rx2, ry1);
        phiywx{i+1}=permute(phiywx{i+1}, [4 1 2 3]);
        phiywx{i+1}=reshape(phiywx{i+1}, ry1*n1, rw2*rx2);
        phiywx{i+1}=(u')*phiywx{i+1};
        phiywx{i+1}=reshape(phiywx{i+1}, ry2, rw2, rx2);
    end;

    % ---- End-of-sweep report and convergence check ------------------------
    if (verb>0)
        % Effective rank: sqrt(total number of core entries / sum of modes).
        erank=0; sumn=0;
        for i=1:d
            erank = erank+size(y{i},1)*size(y{i},2)*size(y{i},3);
            sumn = sumn+size(y{i},1);
        end;
        erank = sqrt(erank/sumn);
        fprintf('=mvk3= Sweep %d, dy_max: %3.3e, conv_max: %1.5f, erank: %g\n', swp, max(dy), max(dy)/max(dy_old), erank);
    end;
    if (last_sweep)
        break;
    end;

    if (max(dy)<eps/(d^d_pow_check))
        last_sweep = true;
    end;
    dy_old = dy;
end;

% Restore the 2-D layout of the first core.
y{1}=reshape(y{1}, size(y{1},1), size(y{1},3));

if (swp==nswp)&&(max(dy)>eps/(d^d_pow_check))
    fprintf('tt_mvk3 warning: error is not fixed for maximal number of sweeps %d, err_max: %3.3e\n', swp, max(dy));
end;

end
0315