Home > tt2 > solve > als_solve_rx.m

als_solve_rx

PURPOSE ^

Computes an approximate low-rank solution of a 2D TT-matrix linear system

SYNOPSIS ^

function [x]=als_solve_rx(mat, rhs, tol, drx, nswp, addswp)

DESCRIPTION ^

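Computes an approximate low-rank solution of a 2D TT-matrix linear system
   [x]=ALS_SOLVE_RX(MAT, RHS, [TOL], [DRX], [NSWP], [ADDSWP])
   Finds an approximate solution of MAT*X = RHS by ALS, where MAT is a
   2D TTM matrix given as a cell array {A1, A2} (A1 of size n1 x m1 x ra,
   A2 of size n2 x m2 x ra) and X is sought as a 2D TT tensor
   X1*X2.' of size m1 x m2. The RHS (length n1*n2) and the returned X
   (length m1*m2) are represented as full vectors. Each block is updated
   by solving the projected normal equations MAT'*MAT*X = MAT'*RHS.
   TOL is the tolerance for ||x_{i+1}-x_i||/||x_i||,
   DRX is the random kick rank: every ADDSWP-th sweep, DRX random
   columns are appended to the first block (DRX is clamped to min(m1,m2)),
   NSWP - maximal number of ALS sweeps.
   The relative residual of each sweep is printed (via TT_MAT_FULL_VEC).
   default values:
   tol: 1e-12
   drx: 1
   nswp: 10
   addswp: 2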


 TT-Toolbox 2.2, 2009-2012

This is TT Toolbox, written by Ivan Oseledets et al.
Institute of Numerical Mathematics, Moscow, Russia
webpage: http://spring.inm.ras.ru/osel

For all questions, bugs and suggestions please mail
ivan.oseledets@gmail.com
---------------------------

CROSS-REFERENCE INFORMATION ^

This function calls: tt_mat_full_vec. This function is called by: (none listed)

SOURCE CODE ^

0001 function [x]=als_solve_rx(mat, rhs, tol, drx, nswp, addswp)
0002 %Computes an approximate low-rank solution of a 2D TT-matrix linear system
0003 %   [x]=ALS_SOLVE_RX(MAT, RHS, [TOL], [DRX], [NSWP], [ADDSWP])
0004 %   Finds an approximate solution of MAT*X = RHS by ALS, where MAT is a
0005 %   2D TTM matrix given as a cell array {A1, A2} (A1 of size n1 x m1 x ra,
0006 %   A2 of size n2 x m2 x ra) and X is sought as a 2D TT tensor
0007 %   X1*X2.' of size m1 x m2. The RHS (length n1*n2) and the returned X
0008 %   (length m1*m2) are represented as full vectors. Each block is updated
0009 %   by solving the projected normal equations MAT'*MAT*X = MAT'*RHS.
0010 %   TOL is the tolerance for ||x_{i+1}-x_i||/||x_i||,
0011 %   DRX is the random kick rank: every ADDSWP-th sweep, DRX random
0012 %   columns are appended to the first block (DRX is clamped to min(m1,m2)),
0013 %   NSWP - maximal number of ALS sweeps.
0014 %   The relative residual of each sweep is printed (via TT_MAT_FULL_VEC).
0015 %   default values:
0016 %   tol: 1e-12
0017 %   drx: 1
0018 %   nswp: 10
0019 %   addswp: 2
0020 %
0021 %
0022 % TT-Toolbox 2.2, 2009-2012
0023 %
0024 %This is TT Toolbox, written by Ivan Oseledets et al.
0025 %Institute of Numerical Mathematics, Moscow, Russia
0026 %webpage: http://spring.inm.ras.ru/osel
0027 %
0028 %For all questions, bugs and suggestions please mail
0029 %ivan.oseledets@gmail.com
0030 %---------------------------
0031 
0032 
0033 a1 = mat{1}; a2 = mat{2};
0034 n1 = size(a1,1); m1 = size(a1,2);
0035 n2 = size(a2,1); m2 = size(a2,2);
0036 ra = size(a1,3);
0037 
0038 rhs = reshape(rhs, n1, n2);
0039 
0040 if (nargin<3)||(isempty(tol))
0041     tol=1e-12;
0042 end;
0043 if (nargin<4)||(isempty(drx))
0044     drx=1;
0045 end;
0046 if (nargin<5)||(isempty(nswp))
0047     nswp=10;
0048 end;
0049 if (nargin<6)||(isempty(addswp))
0050     addswp=2;
0051 end;
0052 
0053 if (drx>m1)||(drx>m2)
0054     drx = min(m1,m2);
0055 end;
0056 
0057 % A zero RHS has the zero solution; the relative errors below would be 0/0.
0058 rhs_norm = norm(rhs(:));
0059 if (rhs_norm==0)
0060     x = zeros(m1*m2, 1);
0061     return;
0062 end;
0063 
0064 % Loop-invariant data (independent of the current iterate):
0065 % the cores reshaped to (n*ra) x m, used to build the interfaces phi,
0066 a1nm = reshape(permute(a1, [1 3 2]), n1*ra, m1);
0067 a2nm = reshape(permute(a2, [1 3 2]), n2*ra, m2);
0068 % the Gram tensors of the cores, (ra*ra) x (m*m),
0069 gram1 = local_gram(a1, n1, m1, ra);
0070 gram2 = local_gram(a2, n2, m2, ra);
0071 % and MAT'*RHS as an m1 x m2 matrix: sum_k A1(:,:,k)'*RHS*conj(A2(:,:,k)).
0072 atb = rhs*conj(reshape(a2, n2, m2*ra)); % size n1, m2*ra
0073 atb = reshape(permute(reshape(atb, n1, m2, ra), [1 3 2]), n1*ra, m2);
0074 atb = conj(reshape(permute(a1, [2 1 3]), m1, n1*ra))*atb; % size m1, m2
0075 
0076 rx=1;
0077 
0078 curx = cell(2,1);
0079 curx{1}=rand(m1,rx);
0080 curx{2}=rand(m2,rx);
0081 x = curx{1}*(curx{2}.');
0082 resid_old = 1;
0083 for swp=1:nswp
0084     
0085     % First block: orthonormalize the second block (QR 2-1), then solve
0086     % the projected normal equations for curx{1}.
0087     [q,rv]=qr(curx{2},0); % m2,rx' - rx',rx
0088     rx = size(q,2);
0089     curx{2}=q;
0090     phi = local_phi(a2nm, q, n2, ra);
0091     b1 = local_reduced(phi, gram1, rx, m1);
0092     rhs1 = atb*conj(q); % size m1,rx
0093     rhs1 = reshape(rhs1.', rx*m1, 1);
0094     curx{1}=reshape(b1 \ rhs1, rx, m1).'; % new first block
0095     
0096     % Random rank kick: append drx random columns every addswp-th sweep.
0097     if (mod(swp,addswp)==0)
0098         curx{1}=[curx{1}, randn(m1,drx)];
0099     end;
0100     
0101     % Second block: orthonormalize the first block (QR 1-2), then solve
0102     % the projected normal equations for curx{2}.
0103     [q,rv]=qr(curx{1},0); % m1,rx' - rx',rx
0104     rx = size(q,2);
0105     curx{1}=q;
0106     phi = local_phi(a1nm, q, n1, ra);
0107     b2 = local_reduced(phi, gram2, rx, m2);
0108     rhs2 = (q')*atb; % size rx,m2
0109     rhs2 = reshape(rhs2, rx*m2, 1);
0110     curx{2}=reshape(b2 \ rhs2, rx, m2).'; % new second block
0111     
0112     x_new = curx{1}*(curx{2}.'); % size m1,m2
0113     derr = norm(x_new(:)-x(:))/norm(x(:));
0114     
0115     x = x_new;
0116     
0117     resid = norm(tt_mat_full_vec(mat, x(:))-rhs(:))/rhs_norm;
0118     conv_fact = resid_old/resid;
0119     
0120     fprintf('als_solve: swp=%d, derr=%3.3e, rx=%d, resid=%3.3e, conv-1=%3.5e\n', swp, derr, rx, resid, conv_fact-1);
0121     if (derr<tol)
0122         break;
0123     end;
0124     resid_old = resid;
0125 end;
0126 
0127 x = x(:);
0128 
0129 end
0130 
0131 
0132 function g = local_gram(a, n, m, ra)
0133 % Gram tensor of a core a (n x m x ra):
0134 %   g(k,k',j,j') = sum_i conj(a(i,j,k))*a(i,j',k'),
0135 % returned as an (ra*ra) x (m*m) matrix.
0136 g = reshape(permute(a, [3 2 1]), ra*m, n);
0137 g = conj(g)*reshape(a, n, m*ra); % size ra*m, m*ra
0138 g = reshape(g, ra, m, m, ra);
0139 g = reshape(permute(g, [1 4 2 3]), ra*ra, m*m);
0140 end
0141 
0142 
0143 function phi = local_phi(anm, q, n, ra)
0144 % Interface of a core with an orthonormal basis q (m x rx): with P = A*q
0145 % viewed as n x ra x rx (anm is the core reshaped to (n*ra) x m),
0146 %   phi(k,k',r,r') = sum_i conj(P(i,k,r))*P(i,k',r'),
0147 % returned as an (ra*ra) x (rx*rx) matrix.
0148 rx = size(q,2);
0149 phi = anm*q; % size n*ra, rx
0150 phi = reshape(phi, n, ra*rx);
0151 phi = (phi')*phi; % size ra*rx, ra*rx
0152 phi = reshape(phi, ra, rx, ra, rx);
0153 phi = reshape(permute(phi, [1 3 2 4]), ra*ra, rx*rx);
0154 end
0155 
0156 
0157 function b = local_reduced(phi, gram, rx, m)
0158 % Projected normal-equation matrix W'*MAT'*MAT*W for the block basis W:
0159 %   b((r,j),(r',j')) = sum_{k,k'} phi(k,k',r,r')*gram(k,k',j,j'),
0160 % with the linear index r + rx*(j-1). The conjugated indices (r,j) must
0161 % index the rows so that b is the Hermitian matrix itself; its transpose
0162 % (= conj(b)) coincides with b only for real data.
0163 b = (phi.')*gram; % size rx*rx, m*m
0164 b = reshape(b, rx, rx, m, m);
0165 b = reshape(permute(b, [1 3 2 4]), rx*m, rx*m);
0166 end

Generated on Wed 08-Feb-2012 18:20:24 by m2html © 2005