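Computes an approximate low-rank solution for the 2D case (Method 2)

[x]=ALS_SOLVE_RX_2(MAT, RHS, [TOL], [MAXIT], [X0], [DRX], [NSWP])

Finds an approximate solution of MAT*x = RHS, where MAT is a 2D TTM matrix. Starting from X0, the solution is repeatedly corrected with ALS_SOLVE_RX, which uses ALS to compute each correction as a 2D TT tensor of rank rx. RHS and X themselves are represented as full vectors. Defaults: TOL=1e-12, MAXIT=2, X0=zeros(size(RHS)), DRX=1, NSWP=4 (DRX and NSWP are passed through to ALS_SOLVE_RX). The loop stops when the relative residual norm(RHS-MAT*x)/norm(RHS) drops below TOL, after MAXIT corrections, or after more than three steps that reduce the residual by less than a factor of 1.2.

TT-Toolbox 2.2, 2009-2012

This is TT Toolbox, written by Ivan Oseledets et al.
Institute of Numerical Mathematics, Moscow, Russia
webpage: http://spring.inm.ras.ru/osel

For all questions, bugs and suggestions please mail ivan.oseledets@gmail.com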
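Read from the source code, each outer iteration is a defect-correction (iterative refinement) step. The notation here is ours, not the toolbox's: A stands for MAT, f for RHS, and \(\mathcal{S}(A, r)\) for the approximate low-rank solve performed by ALS_SOLVE_RX.

```latex
\begin{aligned}
r_k &= f - A x_k, \qquad x_0 = \texttt{X0},\\
c_k &= \mathcal{S}(A, r_k) \approx A^{-1} r_k,\\
x_{k+1} &= x_k + c_k,\\
e_k &= \frac{\lVert f - A x_k \rVert}{\lVert f \rVert}, \qquad
q_{k+1} = \frac{e_k}{e_{k+1}}.
\end{aligned}
```

The function returns immediately if \(e_0 < \mathrm{TOL}\). Otherwise it stops once \(e_{k+1} < \mathrm{TOL}\), after MAXIT corrections, or when \(q_{k+1} < 1.2\) has occurred more than three times (not necessarily consecutively).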
function [x]=als_solve_rx_2(mat, rhs, tol, maxit, x0, drx, nswp)
%Computes an approximate low-rank solution for 2D case (Method 2)
%   [x]=ALS_SOLVE_RX_2(MAT, RHS, [TOL], [MAXIT], [X0], [DRX], [NSWP])
%   Finds an approximate solution of MAT*x = RHS for a 2D TTM matrix MAT,
%   using ALS corrections in the form of a 2D TT tensor with rank rx;
%   RHS and X are represented as full vectors.
%
%
% TT-Toolbox 2.2, 2009-2012
%
%This is TT Toolbox, written by Ivan Oseledets et al.
%Institute of Numerical Mathematics, Moscow, Russia
%webpage: http://spring.inm.ras.ru/osel
%
%For all questions, bugs and suggestions please mail
%ivan.oseledets@gmail.com
%---------------------------


nrmf = norm(rhs);   % norm of the RHS, used to scale the residual

% Default arguments
if (nargin<3)||(isempty(tol))
    tol=1e-12;
end;
if (nargin<4)||(isempty(maxit))
    maxit=2;
end;
if (nargin<5)||(isempty(x0))
    x0 = zeros(size(rhs));
end;
if (nargin<6)||(isempty(drx))
    drx=1;
end;
if (nargin<7)||(isempty(nswp))
    nswp=4;
end;


x=x0;
cur_rhs = rhs - tt_mat_full_vec(mat,x);   % initial residual

err = norm(cur_rhs)/nrmf;                 % relative residual
if (err<tol)
    return;
end;

spunct = 0;          % number of slowly converging iterations so far
err_old = err;
for i=1:maxit
    % Approximate correction: solve MAT*cur_x = cur_rhs by ALS
    cur_x = als_solve_rx(mat, cur_rhs, tol, drx, nswp);
%     cur_rhs = cur_rhs - tt_mat_full_vec(mat, cur_x);

    x = x+cur_x;
    cur_rhs = rhs - tt_mat_full_vec(mat,x);   % recompute the true residual

    err = norm(cur_rhs)/nrmf;
    conv_fact = err_old/err;                  % residual reduction in this step

    fprintf('als_solve_full: iter = %d, resid = %3.3e, conv_fact=%3.3f\n', i, err, conv_fact);

    if (conv_fact<1.2)
        spunct = spunct+1;
%         if (rx<15)
%             rx=rx+1;
%         end;
    end;
    if (spunct>3)
        break; % stagnation: more than 3 steps with conv_fact < 1.2
    end;
    err_old = err;
    if (err<tol)
        break;
    end;

%     if (rx==10)
%         break;
%     end;
%     if (conv_fact<1.01)
%         break;
%     end;
end;

end
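The outer loop of the function can be sketched in plain Python to show the control flow in isolation. This is a minimal sketch under stated assumptions: a small dense matrix replaces the TTM matrix, and a fixed number of Jacobi sweeps (a swapped-in inexact solver, not ALS) replaces ALS_SOLVE_RX. The names `refine`, `jacobi_sweeps`, `slow_factor` and `max_slow` are hypothetical; `slow_factor=1.2` and `max_slow=3` mirror the hard-coded thresholds in the MATLAB loop.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Matrix = List[List[float]]
# Hypothetical inner-solver type: maps (A, r) to a correction c with A c ~= r.
InnerSolver = Callable[[Matrix, Vector], Vector]


def matvec(a: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product a @ x."""
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def norm2(v: Vector) -> float:
    """Euclidean norm."""
    return math.sqrt(sum(t * t for t in v))


def jacobi_sweeps(a: Matrix, r: Vector, sweeps: int = 2) -> Vector:
    """Hypothetical stand-in for ALS_SOLVE_RX: a fixed number of Jacobi
    sweeps on a c = r, starting from c = 0 (not the toolbox's ALS)."""
    n = len(r)
    c = [0.0] * n
    for _ in range(sweeps):
        # The comprehension reads the previous c, so this is a true Jacobi sweep.
        c = [(r[i] - sum(a[i][j] * c[j] for j in range(n) if j != i)) / a[i][i]
             for i in range(n)]
    return c


def refine(a: Matrix, rhs: Vector, inner: InnerSolver, tol: float = 1e-12,
           maxit: int = 2, x0: Optional[Vector] = None,
           slow_factor: float = 1.2,
           max_slow: int = 3) -> Tuple[Vector, List[float]]:
    """Defect-correction loop mirroring the MATLAB code.

    Returns the final iterate and the history of relative residuals."""
    nrmf = norm2(rhs)
    if nrmf == 0.0:
        raise ValueError("rhs must be nonzero (residual is scaled by ||rhs||)")
    x = list(x0) if x0 is not None else [0.0] * len(rhs)
    res = [f - ax for f, ax in zip(rhs, matvec(a, x))]
    err = norm2(res) / nrmf
    history = [err]
    if err < tol:  # initial guess is already accurate enough
        return x, history

    slow = 0  # total count of slow steps, like spunct (never reset)
    err_old = err
    for _ in range(maxit):
        c = inner(a, res)                                   # approximate correction
        x = [xi + ci for xi, ci in zip(x, c)]
        res = [f - ax for f, ax in zip(rhs, matvec(a, x))]  # true residual
        err = norm2(res) / nrmf
        history.append(err)
        conv_fact = err_old / err if err > 0.0 else math.inf
        if conv_fact < slow_factor:
            slow += 1
        if slow > max_slow:                                 # stagnation
            break
        err_old = err
        if err < tol:
            break
    return x, history


if __name__ == "__main__":
    A = [[4.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 4.0]]
    f = [5.0, 6.0, 5.0]  # exact solution is [1, 1, 1]
    x, hist = refine(A, f, jacobi_sweeps, tol=1e-10, maxit=50)
    print(x, len(hist) - 1, hist[-1])
```

For this A and f, two Jacobi sweeps shrink the residual by a factor of 8 per outer step, so the convergence factor stays well above 1.2 and only the tolerance test ends the loop. An inner solver that returns a zero correction instead trips the stagnation counter after four iterations, which is the situation the `spunct` counter guards against.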