function [c] = mtimes(a,b,varargin)
%MTIMES  Product C = A*B where at least one operand is a tt_matrix.
%   C = MTIMES(A,B) dispatches on the classes of the operands:
%
%     tt_matrix and scalar double (either order)
%         Returns a tt_matrix with the same n and m as the tt_matrix
%         operand; its tt field is the operand's tt field times the scalar.
%
%     tt_matrix A, tt_tensor B
%         Returns a tt_tensor with mode sizes A.n whose ranks are the
%         element-wise products of the ranks of A.tt and of B.
%
%     tt_matrix A, non-scalar double B
%         Contracts the cores of A with B one after another and returns a
%         full double array with size(B,2) columns.
%
%     tt_matrix A, tt_matrix B
%         Returns a tt_matrix with n = A.n, m = B.m and ranks equal to the
%         element-wise products of the ranks of A.tt and B.tt.
%
%   C = MTIMES(A,B,...) with a tt_matrix A, a tt_tensor B and more than
%   three input arguments forwards to mvk(A,B,varargin).
%
%   Combinations that are recognised but not implemented (tt_tensor times
%   tt_matrix, three-argument tt_matrix times tt_matrix) and every other
%   combination of operands raise an error instead of returning with the
%   output unassigned.

if (isa(a,'tt_matrix') && isa(b,'double') && numel(b) == 1 || (isa(b,'tt_matrix') && isa(a,'double') && numel(a) == 1) )
    % ---- tt_matrix times scalar (either order) ----
    c=tt_matrix;
    if ( isa(b,'double') )
        c.n=a.n; c.m=a.m; c.tt=(a.tt)*b;
    else
        c.n=b.n; c.m=b.m; c.tt=a*(b.tt);
    end

elseif ( isa(a,'tt_matrix') && isa(b,'tt_tensor') && nargin == 2)
    % ---- tt_matrix times tt_tensor ----
    c=tt_tensor;
    n=a.n; m=a.m; crm=a.tt.core; ttm=a.tt; psm=ttm.ps; d=ttm.d; rm=ttm.r;
    rv=b.r; crv=b.core; psv=b.ps;
    % Ranks of the product are the element-wise products of the ranks.
    rp=rm.*rv;
    % Start offset of every product core inside the flat core vector.
    psp=cumsum([1;n.*rp(1:d).*rp(2:d+1)]);
    crp=zeros(psp(d+1)-1,1);
    c.d=d;
    c.r=rp;
    c.n=n;
    c.ps=psp;
    for i=1:d
        mcur=crm(psm(i):psm(i+1)-1);
        vcur=crv(psv(i):psv(i+1)-1);
        % Bring the column index m(i) of the matrix core to the last
        % position so that it can be contracted by a plain matrix product.
        mcur=reshape(mcur,[rm(i)*n(i),m(i),rm(i+1)]);
        mcur=permute(mcur,[1,3,2]); mcur=reshape(mcur,[rm(i)*n(i)*rm(i+1),m(i)]);
        % Bring the mode index of the tensor core to the first position.
        vcur=reshape(vcur,[rv(i),m(i),rv(i+1)]);
        vcur=permute(vcur,[2,1,3]); vcur=reshape(vcur,[m(i),rv(i)*rv(i+1)]);
        pcur=mcur*vcur;
        % Regroup as (rm(i),rv(i)) x n(i) x (rm(i+1),rv(i+1)).
        pcur=reshape(pcur,[rm(i),n(i),rm(i+1),rv(i),rv(i+1)]);
        pcur=permute(pcur,[1,4,2,3,5]);
        crp(psp(i):psp(i+1)-1)=pcur(:);
    end
    c.core=crp;

elseif ( isa(a,'tt_tensor') && isa(b,'tt_matrix') && nargin == 2)
    % Raise explicitly: returning here would leave c unassigned.
    error('tt_matrix:mtimes:notImplemented', ...
          'vector-by-matrix not implemented yet');

elseif ( isa(a,'tt_matrix') && isa(b,'double') && nargin == 2 )
    % ---- tt_matrix times full (non-scalar) double array ----
    n=a.n; m=a.m; tt=a.tt; cra=tt.core; d=tt.d; ps=tt.ps; r=tt.r;
    rb=size(b,2);
    % m' assumes a.m is stored as a column vector.
    c=reshape(b,[m',rb]);

    for k=1:d
        % Contract core k over (r(k), m(k)); the produced index n(k) is
        % then rotated to the end so the next core meets its own indices
        % at the front.
        cr=cra(ps(k):ps(k+1)-1);
        cr=reshape(cr,[r(k),n(k),m(k),r(k+1)]);
        cr=permute(cr,[2,4,1,3]); cr=reshape(cr,[n(k)*r(k+1),r(k)*m(k)]);
        M=numel(c);
        c=reshape(c,[r(k)*m(k),M/(r(k)*m(k))]);
        c=cr*c; c=reshape(c,[n(k),numel(c)/n(k)]);
        c=permute(c,[2,1]);
    end
    % The column index of b is now the leading one; move it back to columns.
    c=c(:); c=reshape(c,[rb,numel(c)/rb]);
    c=c.';

elseif ( isa(a,'tt_matrix') && isa(b,'tt_matrix') && nargin == 2)
    % ---- tt_matrix times tt_matrix ----
    c = tt_matrix;
    d = a.tt.d;
    % p is the contracted index (column sizes of a).
    n = a.n; m = b.m; p = a.m;
    c.n = n; c.m = m;
    tt = tt_tensor;
    ra = a.tt.r;
    rb = b.tt.r;
    tt.d = d;
    tt.r = ra.*rb;
    tt.n = n.*m;
    tt.ps = cumsum([1; tt.r(1:d).*n.*m.*tt.r(2:d+1)]);
    tt.core = zeros(tt.ps(d+1)-1, 1);
    for i=1:d
        cra = a.tt{i};
        cra = reshape(cra, ra(i), n(i), p(i), ra(i+1));
        crb = b.tt{i};
        crb = reshape(crb, rb(i), p(i), m(i), rb(i+1));
        % Put the contracted index p(i) last in cra and first in crb.
        cra = permute(cra, [1, 2, 4, 3]);
        cra = reshape(cra, ra(i)*n(i)*ra(i+1), p(i));
        crb = permute(crb, [2, 1, 3, 4]);
        crb = reshape(crb, p(i), rb(i)*m(i)*rb(i+1));
        crc = cra*crb;
        % Regroup as (ra(i),rb(i)) x (n(i),m(i)) x (ra(i+1),rb(i+1)).
        crc = reshape(crc, ra(i), n(i), ra(i+1), rb(i), m(i), rb(i+1));
        crc = permute(crc, [1, 4, 2, 5, 3, 6]);
        tt.core(tt.ps(i):tt.ps(i+1)-1) = crc(:);
    end
    c.tt = tt;

elseif ( isa(a,'tt_matrix') && isa(b,'tt_tensor') && nargin > 3)
    % NOTE(review): mvk is defined outside this file. The message printed
    % after the call is kept from the original; confirm whether mvk is
    % usable and whether "nargin > 3" (rather than "> 2") is intended.
    c=mvk(a,b,varargin);
    fprintf('Krylov matrix-by-vector not implemented yet \n');

elseif ( isa(a,'tt_matrix') && isa(b,'tt_matrix') && nargin == 3)
    % Raise explicitly: returning here would leave c unassigned.
    error('tt_matrix:mtimes:notImplemented', ...
          'Krylov matrix-by-matrix not implemented yet');

else
    % No branch matched: fail with a clear message instead of falling off
    % the end of the function with c unassigned.
    error('tt_matrix:mtimes:unsupportedOperands', ...
          'mtimes: unsupported operands of class %s and %s with %d input argument(s)', ...
          class(a), class(b), nargin);
end
end