

clear all
close all


%%
% Synthetic 2-D test objective sampled on an mI-by-mI grid.
% NOTE: the name "angle" shadows MATLAB's angle() function within this script.
angle = [35 80];    % [azimuth elevation] handed to view() by the 3-D plots
mI = 41;            % number of grid points per dimension
xx = linspace(1,2,mI);
yy = linspace(-1,0,mI);
[ma,mb] = meshgrid(xx,yy);
vx = 2+2*sin(xx*2);
vy = 0.2*sqrt((1:length(yy))*4);
% Oscillatory surface (note the transposed second term) damped by a Gaussian bump.
ran = cos((ma.*vx+mb.*vy)*4)+sin((ma.*vx-mb.*vy)'*4);
ran = ((ran-1).*(exp(-((xx-0.5).^2+(yy'-1).^2)/5)));
% Mean squared error of the best rank-4 approximation of the surface (truncated
% SVD). The script later uses it to set the Gamma-prior rate b0 = 0.1*AMSEest.
% The singular values must be part of the reconstruction; U(:,1:4)*V(:,1:4)'
% alone is not a low-rank approximation of ran.
[uu,ss,vv] = svd(ran,'econ');
uest = uu(:,1:4)*ss(1:4,1:4);   % left factor scaled by the top-4 singular values
vest = vv(:,1:4);
Yest = uest*vest';
AMSEest = (norm(Yest - ran, 'fro').^2)/numel(Yest);

tobemax = 1;   % 1: search for the maximum of f; 0: search for the minimum
% Ground-truth optimum on the grid (value, and linear index into ran(:)). It is
% used only to score the search (regret Fest) and in the final report, so it
% must follow the search direction; otherwise a minimization run would be
% scored against the global maximum.
if tobemax
    [y_optiTrue,x_optiTrue] = max(ran(:));
else
    [y_optiTrue,x_optiTrue] = min(ran(:));
end
f = ran(:);
dim = 2;
N0 = 30;      % size of the initial random design
Nobs = 50;    % number of BO acquisitions
hh = 1:dim;
seedr = 2;
rng(seedr)
% Initial design: N0 grid points drawn uniformly with replacement. The draw
% order (x-index, then y-index, row by row) matches the original loop, so the
% seeded design is unchanged.
x = zeros(N0, dim);
for d = 1:N0
    x(d,:) = [xx(randi(length(xx))), yy(randi(length(yy)))];
end
xdSpace = cell(dim,1);
xdSpace{1} = xx;
xdSpace{2} = yy;
% Full candidate set: ndgrid over the per-dimension grids, flattened so that
% row k of xallfine is the point whose objective value is f(k).
Xdfine = cell(dim,1);
[Xdfine{1:dim}] = ndgrid(xdSpace{1:dim});
xallfine = zeros(numel(Xdfine{1}), dim);
for d = 1:dim
    xallfine(:,d) = Xdfine{d}(:);
end
% Objective values at the initial design (exact lookup of each row in the grid).
[~,irow] = ismember(x, xallfine, 'rows');
y_true = f(irow);


%%

% ---- Surrogate model and Gibbs-sampler settings ----------------------------
R = 4;              % CP rank of the surrogate
max_iter = 1000;    % Gibbs sweeps per BO step
n_sample = 400;     % trailing sweeps kept as posterior samples

% Random initial factor matrices (one per dimension, drawn in dimension order so
% the seeded random stream is consumed exactly as before) and preallocated
% storage for the retained factor samples.
G = cell(dim,1);
G_save = cell(dim,1);
for d = 1:dim
    G{d} = 0.1*randn(numel(xdSpace{d}), R);
    G_save{d} = zeros(numel(xdSpace{d}), R, n_sample);
end

% First argument of covMatern/hyperMatern (presumably the Matern order d,
% nu = d/2 in the GPML convention - confirm).
ddd = 3;
tau0 = 1/1e-2;      % initial value of the noise precision tau
a0 = 9;             % Gamma prior shape for tau
b0 = 0.1*AMSEest;   % Gamma prior rate for tau, scaled by the rank-4 SVD error

% Hyperprior settings passed to hyperMatern; rows appear to be
% [mean, precision] for the two kernel hyperparameters - confirm.
hpri_mul = 0.2;
hpri_muv = sqrt(1);
hpri_precision = 1/1;
hpri_se = [hpri_mul, hpri_precision
           hpri_muv, hpri_precision];

Fest = [];          % regret history, one entry per BO step
AFvalue = [];       % acquisition value at each selected point

%%

% Grid index of every initial-design coordinate, per dimension. Stored as row
% vectors so later BO steps can append new indices with [xdidx{d}, k].
xdidx = cell(dim,1);
for d = 1:dim
    % ismember's second output is the position of each value in the grid, so
    % the mapping does not depend on xdSpace{d} being sorted or on its
    % orientation, and an off-grid coordinate is reported instead of silently
    % misaligning the indices.
    [onGrid, loc] = ismember(x(:,d), xdSpace{d});
    if ~all(onGrid)
        error('Initial design has coordinates off the grid in dimension %d.', d);
    end
    xdidx{d} = reshape(loc, 1, []);
end

% Regret of the initial design: distance between the best observed value
% (max or min, following tobemax) and the true optimum y_optiTrue.
if tobemax
    [val,pos] = max(y_true(:));
else
    [val,pos] = min(y_true(:));
end
Fest = [Fest, abs(y_optiTrue-val)];

% One wall-clock timer for the whole BO run; the toc inside the loop therefore
% prints the cumulative time since this point, not the per-step time.
tic
% ===== Bayesian optimization: one new observation per outer iteration j =====
% Each step rebuilds the observation tensor from all points seen so far, runs
% max_iter Gibbs sweeps of the kernelized CP surrogate (keeping the factor
% draws of the last n_sample sweeps), builds the acquisition surface acf from
% those draws, and evaluates f at its best grid point. G carries over from the
% previous step (warm start); tau, LambdaR and the kernel hyperparameters are
% re-initialized at every step.
for j = 1:Nobs

% Dense grid tensor holding the observed values (zeros elsewhere).
Yobs = zeros(size(Xdfine{1}));
% From the second step on, register the grid indices of the point appended to
% x at the end of the previous step.
if j>1
    for d = 1:dim
        xdidx{d} = [xdidx{d}, find(x(end,d) == xdSpace{d})];
    end
end

% idx{d}: linear index of each observation, with dimension d's subscript first.
% NOTE(review): sub2ind always uses size(Yobs), not the size of the mode-d
% unfolding, so the ind2sub(size(mask_matrix{d}), ...) calls below recover the
% intended row xdidx{d}(ii) only when all grid dimensions have the same length
% (true here: both have mI points). Assumes Unfold(T,d) returns an unfolding
% with n_d rows - confirm.
idx = cell(dim,1);
for d = 1:dim
    hhd = hh; hhd(d) = []; hhd = [d,hhd];
    idx{d} = sub2ind(size(Yobs), xdidx{hhd});
end
Yobs(idx{1}) = y_true;

% Observation mask = nonzero cells of Yobs.
% NOTE(review): an observed value that is exactly 0 is treated as missing here
% and is also dropped from the standardization statistics below.
mask_matrixT = double(Yobs & ones(size(Yobs)));
% Number of distinct observed cells. A revisited grid point counts once here,
% while y_truemm (and therefore Err below) has one entry per row of y_true.
num_Obser = length(find(mask_matrixT~=0));
mask_matrix = cell(dim,1);
for d = 1:dim
    mask_matrix{d} = Unfold(mask_matrixT, d);
end

% Standardize y_true with the mean/std of the nonzero observed cells; the
% reconstruction after sampling applies the inverse transform.
Y_tensor = Yobs;
Y_matrix = reshape(Y_tensor, [size(Yobs,1),length(Yobs(:))/size(Yobs,1)]);
train_matrix = Y_matrix.*mask_matrix{1};
train_matrix(find(train_matrix == 0)) = [];
y_truemm = (y_true-mean(train_matrix(:)))./std(train_matrix(:));

% Kernel hyperparameters per dimension on the log scale, initialized to
% [log(hpri_mul), log(hpri_muv)]; each Gibbs sweep appends one row.
% Sigma_G{d} is the covMatern covariance over the grid xdSpace{d}
% (presumably GPML hyp = [log lengthscale; log signal std] - confirm).
% cell(dim) is dim-by-dim; only the first dim (linear) entries are used.
ThetaG_est = cell(dim);
Sigma_G = cell(dim);
for d = 1:dim
        ThetaG_est{d} = [log(hpri_mul),log(hpri_muv)];
        Sigma_G{d} = covMatern(ddd, [ThetaG_est{d}(1);ThetaG_est{d}(2)], xdSpace{d});
end

% Sampler state for this BO step: CP weights LambdaR (one column per sweep) and
% noise precision tau (one entry per sweep; both arrays grow inside the loop).
LambdaR = randn(R,1);
tau = zeros(1,max_iter);
tau(1) = tau0;

% For each d: elemf(ii,:) = elementwise product of the rows of the OTHER
% factors at observation ii, and Gfr{d,r} = LambdaR(r)*elemf(:,r) is the
% coefficient that multiplies G{d}(xdidx{d}(ii),r) in observation ii's
% prediction. After this loop elemf holds the product for d = dim (every
% factor except G{dim}); the loop that builds elem relies on that.
elemf = zeros(length(xdidx{1}), R);
Gfr = cell(dim,R);
for d = 1:dim
    hhd = hh; hhd(d) = [];
    for ii = 1:length(xdidx{1})
        elemf(ii,:) = G{hhd(1)}(xdidx{hhd(1)}(ii),:);
        for dd = 2:dim-1
            elemf(ii,:) = elemf(ii,:).*G{hhd(dd)}(xdidx{hhd(dd)}(ii),:);
        end
    end
    for r = 1:R
        Gfr{d,r} = LambdaR(r).*elemf(:,r);
    end
end
% elem(ii,r) = LambdaR(r) * product over d of G{d}(xdidx{d}(ii),r), i.e.
% component r of the prediction for observation ii. Yf_obs{r} is the
% standardized residual with component r added back.
elem = zeros(length(xdidx{1}), R);
Yf_obs = cell(R,1);
for ii = 1:length(xdidx{1})
    elem(ii,:) = elemf(ii,:).*G{dim}(xdidx{dim}(ii),:);
end
for r = 1:R
    elem(:,r) = LambdaR(r).*elem(:,r);
end
for r = 1:R
    Yf_obs{r} = y_truemm-sum(elem,2)+elem(:,r);
end
% elemfL, GfrL and Gkronr (below) are allocated but not used anywhere in this script.
elemfL = zeros(length(xdidx{1}), R);
GfrL = cell(R);

% Per (d,r) Gaussian conditional of column r of G{d}, with Qg/Lg the
% eigendecomposition of Sigma_G{d}:
%   HHg{d,r}      diagonal matrix of Gfr{d,r}.^2 summed per grid row of dim d,
%   Hyg{d,r}      Yf_obs{r}.*Gfr{d,r} summed per grid row of dim d,
%   LambdaGr{d,r} tau*HHg{d,r} + inv(Sigma_G{d}) (inverse via the eigendecomposition).
Qg = cell(dim);
Lg = cell(dim);
HHg = cell(dim,R);
Hyg = cell(dim,R);
Gkronr = cell(dim,R);
LambdaGr = cell(dim,R);
for d = 1:dim
    [Qg{d}, Lg{d}] = eig(Sigma_G{d});
    for r = 1:R
        HHg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
        Hyg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
        for ii = 1:length(xdidx{1})
            [ri,~] = ind2sub(size(mask_matrix{d}), idx{d}(ii));
            HHg_temp(ri,ii) = Gfr{d,r}(ii)^2;
            Hyg_temp(ri,ii) = Yf_obs{r}(ii)*Gfr{d,r}(ii);
        end
        HHg_temp = sum(HHg_temp,2);
        HHg{d,r} = diag(HHg_temp);
        Hyg_temp = sum(Hyg_temp,2);
        Hyg{d,r} = Hyg_temp;               
        LambdaGr{d,r} = tau(1)*HHg{d,r} + Qg{d}*diag(1./diag(Lg{d}))*Qg{d}';
    end
end

% LambdaG{d}: block-diagonal precision of the R stacked columns of G{d};
% cholLG{d} is its upper Cholesky factor. With h = stacked Hyg(d,:),
%   likeli_KG(d) = 0.5*tau^2*h'*inv(LambdaG{d})*h - 0.5*log det(LambdaG{d})
%                  - 0.5*R*log det(Sigma_G{d}),
% which is handed to hyperMatern as the current value of this term.
LambdaG = cell(1,dim);
cholLG = cell(1,dim);
likeli_KG = zeros(dim,1);
for d = 1:dim
    LambdaG{d} = blkdiag(LambdaGr{d,1:R});
    cholLG{d} = chol(LambdaG{d});
    uu = cholLG{d}'\(cell2mat(Hyg(d,:)'));
    temp1 = 0.5*tau(1)^2*(uu'*uu);
    temp2 = -sum(log(diag(cholLG{d})))-0.5*R*sum(log(diag(Lg{d})));
    likeli_KG(d) = temp1 + temp2;
end


% ===== Gibbs sampler =====
for iter = 1:max_iter
       
    % Kernel hyperparameters: the first log-hyperparameter of each dimension is
    % updated by hyperMatern (external), which also returns the refreshed
    % Sigma_G, its eigendecomposition, cholLG and likeli_KG (per its outputs).
    % The second log-hyperparameter is carried over unchanged.
    theta_estg = cell(dim);   
    for d = 1:dim
        theta_estg{d} = ThetaG_est{d}(iter,:);
        [Theta_return,Sigma_G{d},Qg{d},Lg{d},cholLG{d},likeli_KG(d)] = hyperMatern([log(10), log(10)], theta_estg{d}, 1, likeli_KG(d), ...
            tau(iter), Hyg(d,:), HHg(d,:), [log(1e-3), log(1e-3)], [log(1e3), log(1e3)], hpri_se, xdSpace{d}, ddd, R);       
        ThetaG_est{d}(iter+1,1) = Theta_return;
        theta_estg{d}(1) = Theta_return;       
        ThetaG_est{d}(iter+1,2) = ThetaG_est{d}(iter,2);    
    end    
    
    % Draw vec(G{1}): mean tau*inv(LambdaG{1})*h via two triangular solves with
    % the Cholesky factor. mvnrndpre_ch is external; it is presumably a sampler
    % of N(mean, inv(precision)) given the precision's Cholesky factor - confirm.
    MeanGd = tau(iter)*(cholLG{1}\(cholLG{1}'\(cell2mat(Hyg(1,:)'))));
    Gd = mvnrndpre_ch(MeanGd, cholLG{1});
    G{1} = reshape(Gd, [length(xdSpace{1}),R]);
    
    % Draw G{2},...,G{dim} in turn, rebuilding each conditional from the
    % factors already updated in this sweep. elemf does not depend on r, so it
    % is recomputed identically for every r (redundant work).
    for d = 2:dim
        for r = 1:R
            hhd = hh; hhd(d) = [];
            for ii = 1:length(xdidx{1})
                elemf(ii,:) = G{hhd(1)}(xdidx{hhd(1)}(ii),:);
                for dd = 2:dim-1
                    elemf(ii,:) = elemf(ii,:).*G{hhd(dd)}(xdidx{hhd(dd)}(ii),:);
                end
            end
            Gfr{d,r} = LambdaR(r,iter).*elemf(:,r);           
            for ii = 1:length(xdidx{1})
                elem(ii,:) = elemf(ii,:).*G{d}(xdidx{d}(ii),:);
            end
            for rr = 1:R
                elem(:,rr) = LambdaR(rr,iter).*elem(:,rr);
            end
            Yf_obs{r} = y_truemm-sum(elem,2)+elem(:,r);            
            HHg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
            Hyg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
            for ii = 1:length(xdidx{1})
                [ri,~] = ind2sub(size(mask_matrix{d}), idx{d}(ii));
                HHg_temp(ri,ii) = Gfr{d,r}(ii)^2;
                Hyg_temp(ri,ii) = Yf_obs{r}(ii)*Gfr{d,r}(ii);
            end
            HHg_temp = sum(HHg_temp,2);
            HHg{d,r} = diag(HHg_temp);
            Hyg_temp = sum(Hyg_temp,2);
            Hyg{d,r} = Hyg_temp;
            LambdaGr{d,r} = tau(iter)*HHg{d,r} + Qg{d}*diag(1./diag(Lg{d}))*Qg{d}';
        end
        LambdaG{d} = blkdiag(LambdaGr{d,1:R});
        cholLG{d} = chol(LambdaG{d});       
        MeanGd = tau(iter)*(cholLG{d}\(cholLG{d}'\(cell2mat(Hyg(d,:)'))));
        Gd = mvnrndpre_ch(MeanGd, cholLG{d});
        G{d} = reshape(Gd, [length(xdSpace{d}),R]);
    end
         
    % Standardized prediction at the observed points with the new factors and
    % the current weights LambdaR(:,iter).
    for ii = 1:length(xdidx{1})
        elem(ii,:) = G{1}(xdidx{1}(ii),:);
        for d = 2:dim
            elem(ii,:) = elem(ii,:).*G{d}(xdidx{d}(ii),:);
        end
    end  
    for r = 1:R
        elem(:,r) = LambdaR(r,iter).*elem(:,r);
    end
    y_est = sum(elem,2);


    % Keep the factors of the last n_sample sweeps (slot 1..n_sample).
    if iter > (max_iter-n_sample)
        St_count = mod(iter-(max_iter-n_sample), n_sample);
        if St_count == 0
            St_count = n_sample;
        end        
        for d = 1:dim
            G_save{d}(:,:,St_count) = G{d};
        end

    end
    
    % Noise precision: tau ~ Gamma(shape a0 + num_Obser/2, rate b0 + Err/2);
    % gamrnd takes a scale parameter, hence 1/b_tau.
    Err = norm(y_truemm - y_est, 'fro').^2;
    a_tau = a0+0.5*num_Obser;
    b_tau = b0+0.5*Err;
    tau(iter+1) = gamrnd(a_tau,1/b_tau);   
    
    % CP weights: Gaussian conditional with precision tau*E'*E + I and mean
    % tau*inv(precision)*E'*y, where E(ii,r) is the unweighted component r.
    % NOTE(review): mvnrndpre_ch is given the precision matrix LambdaL itself,
    % while every other call passes a Cholesky factor (cholLG{d} = chol(...)).
    % If the helper expects a Cholesky factor this should be chol(LambdaL) - confirm.
    for ii = 1:length(xdidx{1})
        elem(ii,:) = G{1}(xdidx{1}(ii),:);
        for d = 2:dim
            elem(ii,:) = elem(ii,:).*G{d}(xdidx{d}(ii),:);
        end
    end    
    LambdaL = tau(iter+1)*(elem'*elem) + eye(R);
    MeanL = tau(iter+1)*((LambdaL)\elem'*y_truemm);
    LambdaR(:,iter+1) = mvnrndpre_ch(MeanL, LambdaL);   
    
    % Refresh Gfr, Yf_obs and the G-conditionals with the new tau and LambdaR,
    % so the next sweep's hyperparameter update and G{1} draw use current values.
    for d = 1:dim
        hhd = hh; hhd(d) = [];
        for ii = 1:length(xdidx{1})
            elemf(ii,:) = G{hhd(1)}(xdidx{hhd(1)}(ii),:);
            for dd = 2:dim-1
                elemf(ii,:) = elemf(ii,:).*G{hhd(dd)}(xdidx{hhd(dd)}(ii),:);
            end
        end
        for r = 1:R
            Gfr{d,r} = LambdaR(r,iter+1).*elemf(:,r);
        end
    end    
    for r = 1:R
        elem(:,r) = LambdaR(r,iter+1).*elem(:,r);
    end
    for r = 1:R
        Yf_obs{r} = y_truemm-sum(elem,2)+elem(:,r);
    end
   
    for d = 1:dim
        for r = 1:R
            HHg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
            Hyg_temp = zeros(size(mask_matrix{d},1), length(xdidx{1}));
            for ii = 1:length(xdidx{1})
                [ri,~] = ind2sub(size(mask_matrix{d}), idx{d}(ii));
                HHg_temp(ri,ii) = Gfr{d,r}(ii)^2;
                Hyg_temp(ri,ii) = Yf_obs{r}(ii)*Gfr{d,r}(ii);
            end
            HHg_temp = sum(HHg_temp,2);
            HHg{d,r} = diag(HHg_temp);
            Hyg_temp = sum(Hyg_temp,2);
            Hyg{d,r} = Hyg_temp;        
            LambdaGr{d,r} = tau(iter+1)*HHg{d,r} + Qg{d}*diag(1./diag(Lg{d}))*Qg{d}';
        end       
        LambdaG{d} = blkdiag(LambdaGr{d,1:R});
        cholLG{d} = chol(LambdaG{d});
        uu = cholLG{d}'\(cell2mat(Hyg(d,:)'));
        temp1 = 0.5*tau(iter+1)^2*(uu'*uu);
        temp2 = -sum(log(diag(cholLG{d})))-0.5*R*sum(log(diag(Lg{d})));
        likeli_KG(d) = temp1 + temp2;
    end
    

end
toc
fprintf('OptIter = %g, Epoch = %g, tau = %g \n',...
    j,iter,tau(iter+1));


%%

% Posterior samples of the full-grid surface on the original scale. Sample nn
% pairs the factors saved at sweep max_iter-n_sample+nn with the weights drawn
% at the end of that sweep (column +1). G_save{1} is scaled by the weights IN
% PLACE, so later reads of G_save{1} see the weighted factor.
Y_est_St = zeros([size(Y_matrix),n_sample]);
for nn = 1:n_sample
    Gkr = G_save{dim}(:,:,nn);
    for d = dim-1:-1:2
        Gkr = kr(Gkr, G_save{d}(:,:,nn));
    end
    for r = 1:R
        G_save{1}(:,r,nn) = LambdaR(r,max_iter-n_sample+1+nn).*G_save{1}(:,r,nn);
    end
    Y_est = (G_save{1}(:,:,nn)*Gkr').*std(train_matrix(:))+mean(train_matrix(:));
    Y_est_St(:,:,nn) = Y_est;
end


% Acquisition: cell-wise max (min) over the posterior samples, then pick the
% best cell; posAF is a linear index in the same ordering as xallfine's rows.
if tobemax
    acf = max(Y_est_St,[],3); 
    [AFopt,posAF] = max(acf(:));
else
    acf = min(Y_est_St,[],3); 
    [AFopt,posAF] = min(acf(:));
end
% Evaluate f at the selected grid point and append it to the data set.
xAF = xallfine(posAF,:);
AFvalue = [AFvalue, AFopt];
x(end+1,:) = xAF;
[~,irow] = ismember(x(end,:), xallfine, 'rows');
y_true(end+1) = f(irow);
    
% Record the regret of the best value observed so far.
if tobemax
    [val,pos] = max(y_true(:)); 
    Fest = [Fest, abs(y_optiTrue-val)];
else
    [val,pos] = min(y_true(:)); 
    Fest = [Fest, abs(y_optiTrue-val)];
end


end


% Summary: best observed point (ao at row bo of x/y_true) next to the true
% optimum of the grid (y_optiTrue at linear index x_optiTrue).
if ~tobemax
    [ao,bo] = min(y_true);
    str = 'Minimum';
else
    [ao,bo] = max(y_true);
    str = 'Maximum';
end
fprintf('Bayesian Optimization\n');
fprintf('  %s (estimated):\n\ty(%.6f,%.6f) = %.6f\n', ...
    str, x(bo,1), x(bo,2), ao);
fprintf('  %s (true function):\n\ty(%.6f,%.6f) = %.6f\n', ...
    str, xallfine(x_optiTrue,1), xallfine(x_optiTrue,2), y_optiTrue);


%%

figure;
% (1) Regret after each BO step; entry 0 is the initial design.
subplot(3,3,1)
plot(0:Nobs, Fest, 'linewidth', 1.3);
xlabel('No. of iterations'); 
ylabel('$|f^{\star}-\hat{f^{\star}}|$','Interpreter','latex');
grid on
title('BKTF-BUCB')
xlim([0,Nobs])


% (2) Posterior mean of factor G1 over the retained samples. The sampling
% stage rescales G_save{1} in place by the CP weights, so this is the weighted
% factor, unlike G2 in panel (3).
subplot(3,3,2)
plot(xdSpace{1},mean(G_save{1},3),'linewidth',1.2);
grid on;
title('G1')


% (3) Posterior mean of factor G2 over the retained samples.
subplot(3,3,3)
plot(xdSpace{2},mean(G_save{2},3),'linewidth',1.2);
grid on;
title('G2')



% (4) CP weights drawn in the last Gibbs sweep of the last BO step.
subplot(3,3,4)
plot(LambdaR(:,end),'linewidth',1.2);
grid on; 
xlim([1,R]);
xlabel('R')
ylabel('value')
title('final weights \lambda')



% (5) Trace of the sampled kernel lengthscale (first hyperparameter, stored on
% the log scale) for each dimension; one line per dimension.
subplot(3,3,5)
for d = 1:dim
    plot(exp(ThetaG_est{d}(:,1)),'linewidth',1.2)
    hold on;
end
legend(arrayfun(@(k) sprintf('G%d', k), 1:dim, 'UniformOutput', false));
grid on; xlim([1,max_iter]);
xlabel('iteration')
ylabel('value')
title('lengthscale of G')


% (6) Trace of the noise variance 1/tau for the last BO step.
subplot(3,3,6)
plot(1./tau,'linewidth',1.2)
grid on; xlim([1,max_iter]);
xlabel('iteration')
ylabel('value')
title('variance of noise')
ylim([0,1])


% (7) Posterior-mean surface with all observed points.
subplot(3,3,7)
surf(Xdfine{1},Xdfine{2},reshape(mean(Y_est_St,3),size(Xdfine{1})));
shading interp; 
hold on; view(angle);
scatter3(x(:,1),x(:,2),y_true,20,'g','filled',...
    'MarkerEdgeColor','k');                
% size(x,1) is the number of observed points (length(x) would return 2 when
% fewer than two points exist).
title(sprintf('%s, No. of observed points: %d',str,size(x,1)));



% (8) Acquisition surface of the last BO step and the point it selected.
subplot(3,3,8)
surf(Xdfine{1},Xdfine{2},acf);
shading interp; 
hold on; view(angle);
scatter3(x(end,1),x(end,2),AFvalue(end),20,'g','filled',...
    'MarkerEdgeColor','k');                  
title('acf');



