0001
0002
function [x,RESVEC,rw,rx] = tt_gmres(A, b, tol, maxout, maxin, eps_x, eps_z, M1, M2, M3, x0, verbose, varargin)
%TT_GMRES Restarted GMRES built on truncated tensor-train helper routines.
%   [x,RESVEC,rw,rx] = TT_GMRES(A,b,tol,maxout,maxin,eps_x,eps_z,
%                               M1,M2,M3,x0,verbose,...)
%
%   Inputs
%     A        operator; classified by tt_iterchk and applied through
%              tt_iterapp('mtimes',...).
%     b        right-hand side.
%     tol      the iteration stops once the estimated relative residual
%              drops below tol.
%     maxout   maximum number of outer (restart) iterations.
%     maxin    maximum number of inner iterations per restart (size of the
%              Krylov basis).
%     eps_x    accuracy passed to tt_axpy2 when the solution update is
%              assembled.
%     eps_z    accuracy passed to tt_iterapp/tt_axpy2 for matrix-vector
%              products, preconditioner solves, residual computation and
%              orthogonalization.
%     M1,M2,M3 optional preconditioners, applied through
%              tt_iterapp('mldivide',...). Omitted or empty means "none".
%     x0       optional initial guess; a zero tensor is used when omitted
%              or empty.
%     verbose  optional; defaults to 1 (progress is printed with fprintf).
%     varargin forwarded unchanged to every tt_iterapp call.
%
%   Outputs
%     x        the iterate with the smallest estimated residual found.
%     RESVEC   RESVEC(i,j) is the estimated relative residual at outer
%              iteration i, inner iteration j (entries never reached stay
%              at their initial value 1).
%     rw       rw(i,j) is the largest rank observed for the Krylov vector w.
%     rx       rx(i,j) is the largest rank observed while assembling x.
%
%   Norms are carried as logarithms split into an "order" part (a multiple
%   of 10) and a remainder in (-10,10), so that exp() is only ever applied
%   to the remainder.
%   NOTE(review): this presumes tt_dot2 returns the logarithm of the dot
%   product and that tt_scal2/tt_axpy2 take (log-magnitude, sign) scaling
%   pairs - confirm against those routines.

[atype,afun,afcnstr] = tt_iterchk(A);

% ---- fixed tuning constants ------------------------------------------
max_swp=12;                 % forwarded to tt_iterapp
max_sp_factor = 1.5;        % stop after maxin*max_sp_factor stagnation points
max_zrank_factor = 4;       % max_rank = max_zrank_factor * (largest rank of x)
derr_tol_for_sp = 1.0;      % residual ratio below this counts as stagnation

use_err_trunc = 1;          % relax eps_z as the residual decreases
err_trunc_power = 0.8;
compute_real_res = 0;       % 1 = additionally evaluate the true residual

t0 = tic;

% ---- optional arguments ----------------------------------------------
existsM1=1;
existsM2=1;
existsM3=1;
existsx0=1;
if ((nargin<8)||(isempty(M1)))
    existsM1=0;
end;
if ((nargin<9)||(isempty(M2)))
    existsM2=0;
end;
if ((nargin<10)||(isempty(M3)))
    existsM3=0;
end;
if ((nargin<11)||(isempty(x0)))
    existsx0=0;
end;

if ((nargin<12)||(isempty(verbose)))
    verbose=1;
end;

pre_b = b;
max_rank=[];

% pre_b2 (the preconditioned right-hand side) is only needed for the
% optional true-residual check.
if (compute_real_res==1)
    pre_b2 = b;
end;

if (existsM1)
    [m1type,m1fun,m1fcnstr] = tt_iterchk(M1);
    if (compute_real_res==1)
        pre_b2 = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,pre_b2,min(eps_z, 0.5), max_rank, max_swp,varargin{:});
    end;
end;
if (existsM2)
    [m2type,m2fun,m2fcnstr] = tt_iterchk(M2);
    if (compute_real_res==1)
        pre_b2 = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,pre_b2,min(eps_z, 0.5), max_rank, max_swp,varargin{:});
    end;
end;
if (existsM3)
    [m3type,m3fun,m3fcnstr] = tt_iterchk(M3);
    if (compute_real_res==1)
        pre_b2 = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,pre_b2,min(eps_z, 0.5), max_rank, max_swp,varargin{:});
    end;
end;

% Log-norm of the right-hand side, split as order_norm_f + mod_norm_f with
% |mod_norm_f| < 10 so that exp(mod_norm_f) cannot overflow.
norm_f = tt_dot2(pre_b,pre_b)*0.5;
mod_norm_f = sign(norm_f)*mod(abs(norm_f), 10);
order_norm_f = norm_f - mod_norm_f;
cur_norm_f = exp(mod_norm_f);

if (existsx0)
    x = x0;
else
    x = tt_zeros(max(size(b)), tt_size(b));
end;

H = zeros(maxin+1, maxin);      % Hessenberg matrix of the Arnoldi process
v = cell(maxin, 1);             % Krylov basis vectors

if (nargout>1)
    RESVEC = ones(maxout, maxin);
end;
if (nargout>2)
    rw = zeros(maxout,maxin);
end;
% rx is the 4th output: allocate it under the same condition that guards
% its writes and its final trimming.
if (nargout>3)
    rx = zeros(maxout,maxin);
end;

err=2;
max_err=Inf;
old_err = 1;
stagpoints=0;

% Best iterate so far. Starting from the initial guess guarantees that
% x_good is defined even if no inner step yields err < max_err (for
% example when the residual estimate is NaN).
x_good = x;

for nitout=1:maxout
    % Each restart begins with no rank cap and no tolerance relaxation.
    max_rank=[];
    err_for_trunc=1;

    % Residual r = b - A*x, followed by the preconditioner solves.
    Ax = tt_iterapp('mtimes',afun,atype,afcnstr,x,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
    r = tt_axpy2(0,1, pre_b, 0,-1, Ax, min(eps_z/err_for_trunc, 0.5), max_rank);
    % NOTE(review): M2 and M3 are applied only when M1 is present (they sit
    % inside this branch), whereas pre_b2 above applies them independently.
    % Confirm whether M1=[] with a non-empty M2/M3 is meant to be supported.
    if (existsM1)
        r = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
        if (existsM2)
            r = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
        end;
        if (existsM3)
            r = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
        end;
    end;

    % Log-norm of the residual, split the same way as norm_f.
    beta = tt_dot2(r,r)*0.5;
    mod_beta = sign(beta)*mod(abs(beta), 10);
    order_beta = beta - mod_beta;
    cur_beta = exp(mod_beta);
    if (verbose==1)
        real_beta = exp(beta);
        fprintf(' cur_beta = %g, real_beta = %g\n', cur_beta, real_beta);
    end;

    % First basis vector: r rescaled with log-factor -beta.
    v{1} = tt_scal2(r, -beta, 1);

    for j=1:maxin
        % ---- Arnoldi step: w = M \ (A * v_j) ---------------------------
        max_w_rank = 0;
        w = tt_iterapp('mtimes',afun,atype,afcnstr,v{j},min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
        max_w_rank = max([max_w_rank; tt_ranks(w)]);

        if (existsM1)
            w = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
            max_w_rank = max([max_w_rank; tt_ranks(w)]);
            if (existsM2)
                w = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
                max_w_rank = max([max_w_rank; tt_ranks(w)]);
            end;
            if (existsM3)
                w = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
                max_w_rank = max([max_w_rank; tt_ranks(w)]);
            end;
        end;

        max_wrank = max(tt_ranks(w));

        % ---- orthogonalize w against v_1..v_j --------------------------
        for i=1:j
            H(i,j)=tt_dot(w, v{i});
            max_w_rank = max([max_w_rank; tt_ranks(w)]);
            % w <- w - H(i,j)*v_i; the coefficient is passed as a
            % (log|.|, sign) pair, with 1e-308 guarding log(0).
            w = tt_axpy2(0,1, w, log(abs(H(i,j))+1e-308), -1*sign(H(i,j)), v{i}, eps_z, max_rank);
        end;

        if (nargout>2)
            rw(nitout,j)=max_w_rank;
        end;

        H(j+1,j) = sqrt(tt_dot(w, w));
        if (j<maxin)
            v{j+1}=tt_scal2(w, -log(H(j+1,j)), 1);
        end;

        % ---- small least-squares problem min ||H*y - beta*e1|| via SVD -
        [UH,SH,VH]=svd(H(1:j+1, 1:j), 0);
        SH = diag(SH);
        sigma_min_H = min(SH);
        sigma_max_H = max(SH);
        if (verbose==1)
            fprintf(' min(sigma(H)) = %g\n', sigma_min_H);
        end;
        % Invert only the singular values above 1e-100 (they are sorted in
        % descending order, so those are the leading entries).
        SH(1:numel(find(SH>1e-100)))=1./SH(1:numel(find(SH>1e-100)));
        SH = diag(SH);
        y = cur_beta*VH*SH*(UH(1,:)');

        % Estimated relative residual, assembled in log scale and then
        % exponentiated: ||H*y - beta*e1|| / ||b||.
        err = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_norm_f+1e-308);
        err = err+order_beta-order_norm_f;
        err = exp(err);
        if (use_err_trunc==1)
            % Relaxation factor for eps_z: residual reduction within this
            % restart, scaled by maxin and the conditioning of H, capped at 1.
            err_for_trunc = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_beta+1e-308);
            err_for_trunc = exp(err_for_trunc);
            err_for_trunc = (err_for_trunc*maxin*sigma_max_H/sigma_min_H)^(err_trunc_power);
            err_for_trunc = min(err_for_trunc, 1);
        end;

        if (nargout>1)
            RESVEC(nitout,j)=err;
        end;

        % ---- assemble the candidate solution x_new = x + sum y_i*v_i ---
        x_new = x;
        max_x_rank = 0;

        for i=j:-1:1
            x_new = tt_axpy2(0,1, x_new, log(abs(y(i))+1e-308)+order_beta, sign(y(i)), v{i}, eps_x);
            max_x_rank = max([max_x_rank; tt_ranks(x_new)]);
        end;

        if (nargout>3)
            rx(nitout,j)=max_x_rank;
        end;

        % From now on cap the working ranks relative to the rank of x_new.
        max_xrank = max(tt_ranks(x_new));
        max_rank = max_zrank_factor*max_xrank;

        if (compute_real_res==1)
            res = tt_iterapp('mtimes',afun,atype,afcnstr,x_new,eps_z, max_rank, max_swp,varargin{:});
            if (existsM1)
                res = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,res,eps_z, max_rank, max_swp,varargin{:});
                if (existsM2)
                    res = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,res,eps_z, max_rank, max_swp,varargin{:});
                end;
                if (existsM3)
                    res = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,res,eps_z, max_rank, max_swp,varargin{:});
                end;
            end;
            res = tt_dist3(res, pre_b2)/sqrt(tt_dot(pre_b2,pre_b2));
        end;

        % ---- stagnation tracking ---------------------------------------
        derr = old_err/err;
        old_err = err;

        if (derr<derr_tol_for_sp) stagpoints=stagpoints+1; end;
        if (verbose==1)
            if (compute_real_res==1)
                fprintf('iter = [%d,%d], derr = %3.2f, resid=%3.2e, real_res=%3.2e, rank_w=%d, rank_x=%d, sp=%d, time=%g\n', nitout, j, derr, err, res, max_wrank, max_xrank, stagpoints, toc(t0));
            else
                fprintf('iter = [%d,%d], derr = %3.2f, resid=%3.2e, rank_w=%d, rank_x=%d, sp=%d, time=%g\n', nitout, j, derr, err, max_w_rank, max_x_rank, stagpoints, toc(t0));
            end;
        end;
        % Keep the candidate only if it beats the best residual seen so far
        % (max_err is never reset, so this holds across restarts too).
        if (err<max_err)
            x_good = x_new;
            max_err=err;
        end;
        if (err<tol) break; end;
        if (stagpoints>=maxin*max_sp_factor) break; end;
    end;
    x = x_good;

    if (err<tol) break; end;
    if (stagpoints>=maxin*max_sp_factor) break; end;
end;

% ---- trim the history arrays to the iterations actually performed ------
if (nargout>1)
    RESVEC=RESVEC(1:nitout,:);
    if (nitout==1)
        RESVEC=RESVEC(:,1:j);
    end;
end;
if (nargout>2)
    rw=rw(1:nitout,:);
    if (nitout==1)
        rw=rw(:,1:j);
    end;
end;
if (nargout>3)
    rx=rx(1:nitout,:);
    if (nitout==1)
        rx=rx(:,1:j);
    end;
end;

end