Home > tt2 > core > tt_reshape.m

tt_reshape

PURPOSE ^

Reshape of the TT-tensor

SYNOPSIS ^

function [tt2]=tt_reshape(tt1,sz,eps, rl, rr)

DESCRIPTION ^

Reshape of the TT-tensor
   [TT1]=TT_RESHAPE(TT,SZ) reshapes TT-tensor or TT-matrix into another 
   with mode sizes SZ, accuracy 1e-14

   [TT1]=TT_RESHAPE(TT,SZ,EPS) reshapes TT-tensor/matrix into another with
   mode sizes SZ and accuracy EPS
   
   [TT1]=TT_RESHAPE(TT,SZ,EPS, RL) reshapes TT-tensor/matrix into another 
   with mode size SZ and left tail rank RL

   [TT1]=TT_RESHAPE(TT,SZ,EPS, RL, RR) reshapes TT-tensor/matrix into 
   another with mode sizes SZ, left tail rank RL and right tail rank RR.
   The result is a new TT-tensor/matrix with dimensions specified by SZ.

   If the input is a TT-matrix, SZ must contain the sizes of both modes, 
   i.e. it must be a matrix of size d2-by-2.
   If the input is TT-tensor, SZ may be either a column or a row vector.
   


 TT-Toolbox 2.2, 2009-2012

This is TT Toolbox, written by Ivan Oseledets et al.
Institute of Numerical Mathematics, Moscow, Russia
webpage: http://spring.inm.ras.ru/osel

For all questions, bugs and suggestions please mail
ivan.oseledets@gmail.com
---------------------------

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [tt2]=tt_reshape(tt1,sz,eps, rl, rr)
0002 %Reshape of the TT-tensor
0003 %   [TT1]=TT_RESHAPE(TT,SZ) reshapes TT-tensor or TT-matrix into another
0004 %   with mode sizes SZ, accuracy 1e-14
0005 %
0006 %   [TT1]=TT_RESHAPE(TT,SZ,EPS) reshapes TT-tensor/matrix into another with
0007 %   mode sizes SZ and accuracy EPS
0008 %
0009 %   [TT1]=TT_RESHAPE(TT,SZ,EPS, RL) reshapes TT-tensor/matrix into another
0010 %   with mode size SZ and left tail rank RL
0011 %
0012 %   [TT1]=TT_RESHAPE(TT,SZ,EPS, RL, RR) reshapes TT-tensor/matrix into
0013 %   another with mode sizes SZ, left tail rank RL and right tail rank RR.
0014 %   Omitted or empty EPS, RL, RR default to 1e-14, 1 and 1 respectively.
0015 %
0016 %   If the input is a TT-matrix, SZ must contain the sizes of both modes,
0017 %   i.e. it must be a matrix of size d2-by-2 with rows [n, m].
0018 %   If the input is TT-tensor, SZ may be either a column or a row vector.
0019 %
0020 %   The error 'Reshape: incorrect sizes' is raised when the total number of
0021 %   elements (tail ranks included) of the input and of the request differ.
0022 % TT-Toolbox 2.2, 2009-2012
0023 %
0024 %This is TT Toolbox, written by Ivan Oseledets et al.
0025 %Institute of Numerical Mathematics, Moscow, Russia
0026 %webpage: http://spring.inm.ras.ru/osel
0027 %
0028 %For all questions, bugs and suggestions please mail
0029 %ivan.oseledets@gmail.com
0030 %---------------------------
0031 
0032 
0033 d1=tt1.d; % number of cores (dimension) of the input
0034 if (nargin<3)||(isempty(eps)) % note: the argument EPS hides MATLAB's built-in eps inside this function
0035     eps = 1e-14;
0036 end;
0037 if (nargin<4)||(isempty(rl))
0038     rl = 1;
0039 end;
0040 if (nargin<5)||(isempty(rr))
0041     rr = 1;
0042 end;
0043 
0044 ismatrix = false; % set to true below when the input is a tt_matrix
0045 if (isa(tt1, 'tt_matrix'))
0046     d2 = size(sz, 1); % one row of SZ per mode of the result
0047     ismatrix = true;
0048     % The size should be [n,m] in R^{d x 2}
0049     restn2_n = sz(:,1); % part of each new n-size that is not yet filled by the old cores
0050     restn2_m = sz(:,2); % same for the m-sizes
0051     sz_n = sz(:,1); % untouched copies, passed to tt_matrix() at the very end
0052     sz_m = sz(:,2);
0053     n1_n = tt1.n;
0054     n1_m = tt1.m;    
0055     sz = prod(sz, 2); % We will split/convolve using the vector form anyway
0056     tt1 = tt1.tt; % from here on work with the .tt field (presumably the underlying tt_tensor - TODO confirm)
0057 else
0058     d2 = numel(sz);
0059 end;
0060 
0061 % Fold the tail ranks into the first and the last mode sizes, both for the
0062 % requested sizes sz (rl, rr) and for tt1 (its own r(1), r(d1+1))
0063 sz(1)=sz(1)*rl;
0064 sz(d2)=sz(d2)*rr;
0065 tt1.n(1) = tt1.n(1)*tt1.r(1);
0066 tt1.n(d1) = tt1.n(d1)*tt1.r(d1+1);
0067 if (ismatrix) % in matrix: 1st tail rank goes to the n-mode, last to the m-mode
0068     restn2_n(1)=restn2_n(1)*rl;
0069     restn2_m(d2)=restn2_m(d2)*rr;
0070     n1_n(1) = n1_n(1)*tt1.r(1);
0071     n1_m(d1) = n1_m(d1)*tt1.r(d1+1);
0072 end;
0073 tt1.r(1)=1; % the tail ranks now live inside n(1) and n(d1)
0074 tt1.r(d1+1)=1;
0075 
0076 n1=tt1.n;
0077 
0078 if ( prod(n1) ~= prod(sz) ) % the total number of elements must be preserved
0079  error('Reshape: incorrect sizes');
0080 end
0081 
0082 % Decide whether the QR sweep below is needed: always when the number of cores grows, otherwise when some old mode size does not divide the new mode size it falls into
0083 needQRs = false;
0084 if (d2>d1)
0085     needQRs = true;
0086 end;
0087 if (d2<=d1)
0088     i2=1;
0089     n2 = sz; % scratch copy: n2(i2) is divided by every old mode size that fits into it
0090     for i1=1:d1
0091         if (n2(i2)==1)
0092             i2 = i2+1;
0093             if (i2>d2)
0094                 break;
0095             end;
0096         end;
0097         if (mod(n2(i2), n1(i1))==0)
0098             n2(i2)=n2(i2)/n1(i1);
0099         else
0100             needQRs = true;
0101             break;
0102         end;
0103     end;
0104 end;
0105 
0106 if (needQRs) % We have to split some cores -> perform QRs
0107     for i=d1:-1:2 % right-to-left sweep over the cores
0108         cr = tt1{i};
0109         r1 = size(cr,1); r2 = size(cr,3);
0110         cr = reshape(cr, r1, n1(i)*r2);
0111         [cr,rv]=qr(cr.', 0); % economy QR of the transposed core: cr is n*r2 x r1new, rv is r1new x r1
0112         cr0 = tt1{i-1};
0113         r0 = size(cr0, 1);
0114         cr0 = reshape(cr0, r0*n1(i-1), r1);
0115         cr0 = cr0*(rv.'); % r0*n0, r1new
0116         r1 = size(cr,2);        
0117         cr0 = reshape(cr0, r0, n1(i-1), r1);
0118         cr = reshape(cr.', r1, n1(i), r2);
0119         tt1{i} = cr;
0120         tt1{i-1} = cr0;  
0121     end;
0122 end;
0123 
0124 r1 = tt1.r; % r1 and r2 were scalars inside the sweep above; from here on they are the rank vectors of tt1 and tt2
0125 r2 = ones(d2+1,1);
0126     
0127 i1 = 1; % Working index in tt1
0128 i2 = 1; % Working index in tt2
0129 core2 = zeros(0,1); % all finished cores of tt2, stored one after another
0130 last_ps2 = 1; % position in core2 where the next finished core starts
0131 curcr2 = 1; % core of tt2 under construction, kept as a matrix with r2(i2)*n2(i2) rows
0132 restn2 = sz; % part of each new mode size that is still to be filled
0133 n2 = ones(d2,1); % part of each new mode size that is already filled
0134 if (ismatrix)
0135     n2_n = ones(d2, 1);
0136     n2_m = ones(d2, 1);
0137 end;
0138 
0139 while (i1<=d1)
0140     curcr1 = tt1{i1};    
0141     if (gcd(restn2(i2), n1(i1))==n1(i1))
0142         % The whole core1 fits to core2. Convolve it
0143         if (i1<d1)&&(needQRs) % QR to the next core - for safety
0144             curcr1 = reshape(curcr1, r1(i1)*n1(i1), r1(i1+1));
0145             [curcr1, rv]=qr(curcr1, 0);
0146             curcr12 = tt1{i1+1};
0147             curcr12 = reshape(curcr12, r1(i1+1), n1(i1+1)*r1(i1+2));
0148             curcr12 = rv*curcr12; % push the triangular factor into the next core of tt1
0149             r1(i1+1)=size(curcr12, 1);
0150             tt1{i1+1} = reshape(curcr12, r1(i1+1), n1(i1+1), r1(i1+2));
0151         end;
0152         % NOTE(review): for tt_matrix the fit test above uses only the products n*m, while the n- and m-sizes are divided separately further down - confirm that they always divide evenly
0153         curcr1 = reshape(curcr1, r1(i1), n1(i1)*r1(i1+1));
0154         curcr2 = curcr2*curcr1; % size r21*nold, dn*r22
0155         if (ismatrix) % Permute if we are working with tt_matrix
0156             curcr2 = reshape(curcr2, r2(i2), n2_n(i2), n2_m(i2), n1_n(i1), n1_m(i1), r1(i1+1));
0157             curcr2 = permute(curcr2, [1, 2, 4, 3, 5, 6]); % bring the n-parts together and the m-parts together
0158             % Update the "matrix" sizes
0159             n2_n(i2) = n2_n(i2)*n1_n(i1);
0160             n2_m(i2) = n2_m(i2)*n1_m(i1);
0161             restn2_n(i2)=restn2_n(i2)/n1_n(i1);
0162             restn2_m(i2)=restn2_m(i2)/n1_m(i1);
0163         end;
0164         r2(i2+1)=r1(i1+1);
0165         % Update the sizes of tt2
0166         n2(i2)=n2(i2)*n1(i1);
0167         restn2(i2)=restn2(i2)/n1(i1);
0168         curcr2 = reshape(curcr2, r2(i2)*n2(i2), r2(i2+1));
0169 %         if (i1<d1)
0170             i1 = i1+1; % current core1 is over
0171 %         end;
0172     else
0173         if (gcd(restn2(i2), n1(i1))~=1)||(restn2(i2)==1)
0174             % There exists a nontrivial divisor, or the current new mode is already complete (then n12 is 1). Split it and convolve
0175             n12 = gcd(restn2(i2), n1(i1));
0176             if (ismatrix) % Permute before the truncation
0177                 % Matrix sizes we are able to split
0178                 n12_n = gcd(restn2_n(i2), n1_n(i1));
0179                 n12_m = gcd(restn2_m(i2), n1_m(i1));
0180                 curcr1 = reshape(curcr1, r1(i1), n12_n, (n1_n(i1)/n12_n), n12_m, (n1_m(i1)/n12_m), r1(i1+1));
0181                 curcr1 = permute(curcr1, [1, 2, 4, 3, 5, 6]);
0182                 % Update the matrix sizes of tt2 and tt1
0183                 n2_n(i2)=n2_n(i2)*n12_n;
0184                 n2_m(i2)=n2_m(i2)*n12_m;
0185                 restn2_n(i2)=restn2_n(i2)/n12_n;
0186                 restn2_m(i2)=restn2_m(i2)/n12_m;
0187                 n1_n(i1) = n1_n(i1)/n12_n;
0188                 n1_m(i1) = n1_m(i1)/n12_m;
0189             end;
0190             
0191             curcr1 = reshape(curcr1, r1(i1)*n12, (n1(i1)/n12)*r1(i1+1));
0192             [u,s,v]=svd(curcr1, 'econ');
0193             s = diag(s);
0194             r = my_chop2(s, eps*norm(s)/sqrt(d2-1)); % truncation rank; my_chop2 is defined elsewhere - presumably it picks the rank for the given absolute threshold (TODO confirm)
0195             u = u(:,1:r);
0196             v = v(:,1:r)*diag(s(1:r)); % the singular values are kept with the part that stays in tt1
0197             u = reshape(u, r1(i1), n12*r);
0198             curcr2 = curcr2*u; % size r21*nold, dn*r22
0199             r2(i2+1)=r;
0200             % Update the sizes of tt2
0201             n2(i2)=n2(i2)*n12;
0202             restn2(i2)=restn2(i2)/n12;
0203             curcr2 = reshape(curcr2, r2(i2)*n2(i2), r2(i2+1));
0204             r1(i1) = r; % the remainder of core i1 now has left rank r
0205             % and tt1
0206             n1(i1) = n1(i1)/n12;
0207             curcr1 = reshape(v', r1(i1), n1(i1), r1(i1+1));
0208             tt1{i1} = curcr1;
0209         else
0210             % Bad case. We have to merge cores of tt1 until a common divisor
0211             % appears
0212             i1new = i1+1;
0213             curcr1 = reshape(curcr1, r1(i1)*n1(i1), r1(i1+1));
0214             while (gcd(restn2(i2), n1(i1))==1)&&(i1new<=d1) % n1(i1) grows inside the loop, so the test sees the merged size
0215                 cr1new = tt1{i1new};
0216                 cr1new = reshape(cr1new, r1(i1new), n1(i1new)*r1(i1new+1));
0217                 curcr1 = curcr1*cr1new; % size r1(i1)*n1(i1), n1new*r1new
0218                 if (ismatrix) % Permutes and matrix size updates
0219                     curcr1 = reshape(curcr1, r1(i1), n1_n(i1), n1_m(i1), n1_n(i1new), n1_m(i1new), r1(i1new+1));
0220                     curcr1 = permute(curcr1, [1, 2, 4, 3, 5, 6]);
0221                     n1_n(i1) = n1_n(i1)*n1_n(i1new);
0222                     n1_m(i1) = n1_m(i1)*n1_m(i1new);
0223                 end;
0224                 n1(i1) = n1(i1)*n1(i1new);
0225                 curcr1 = reshape(curcr1, r1(i1)*n1(i1), r1(i1new+1));
0226                 i1new = i1new+1;
0227             end;
0228             % Reduce dimension of tt1
0229             tt1.n = [n1(1:i1); n1(i1new:d1)];
0230             tt1.r = [r1(1:i1); r1(i1new:d1+1)];
0231             if (i1new<=d1)
0232                 crlast = tt1.core(tt1.ps(i1new):end); % raw data of the cores that were not merged
0233             else
0234                 crlast = [];
0235             end;
0236             tt1.core = [tt1.core(1:tt1.ps(i1)-1); curcr1(:); crlast];
0237             n1 = tt1.n;
0238             r1 = tt1.r;
0239             d1 = numel(n1);
0240             tt1.d = d1;
0241             tt1.ps = cumsum([1; r1(1:d1).*n1.*r1(2:d1+1)]); % recompute the starting positions of the cores inside tt1.core
0242         end;
0243     end;
0244     
0245     if (restn2(i2)==1)&&((i1>d1)||((i1<=d1)&&(n1(i1)~=1))) % The core of tt2 is finished
0246         % The second condition works, if we are squeezing the tailing singletons.
0247         core2(last_ps2:last_ps2+r2(i2)*n2(i2)*r2(i2+1)-1) = curcr2(:);
0248         last_ps2 = last_ps2 + r2(i2)*n2(i2)*r2(i2+1);
0249 %         if (i2<d2)
0250             i2 = i2+1;
0251 %         end;
0252         % Start new core2
0253         curcr2 = 1;
0254     end;
0255 end;
0256 
0257 % If we've asked for trailing singleton modes: each of them gets a single entry equal to 1
0258 while (i2<=d2)
0259     core2(last_ps2) = 1;
0260     last_ps2 = last_ps2+1;
0261     r2(i2)=1;
0262     i2 = i2+1;
0263 end;
0264 
0265 tt2 = tt_tensor;
0266 tt2.d = d2;
0267 tt2.n = n2;
0268 tt2.r = r2;
0269 tt2.core = core2;
0270 tt2.ps = cumsum([1; r2(1:d2).*n2.*r2(2:d2+1)]);
0271 
0272 % Move the requested tail ranks out of the first/last mode sizes again
0273 tt2.n(1) = tt2.n(1)/rl;
0274 tt2.n(d2) = tt2.n(d2)/rr;
0275 tt2.r(1) = rl;
0276 tt2.r(d2+1) = rr;
0277 
0278 if (ismatrix)
0279     tt2 = tt_matrix(tt2, sz_n, sz_m); % sz_n, sz_m are the sizes originally requested in SZ
0280 end;
0281 
0282 end

Generated on Wed 08-Feb-2012 18:20:24 by m2html © 2005