%Figure 4.b
clear
clc
addpath('rntk/')
addpath('NTK/')
addpath('Poly/')
addpath('RBF/')
addpath('utill/')
%%
% ---- Experiment configuration -------------------------------------------
% Clean signal: one full period of a sine wave sampled at data_length points.
data_length = 1000;
nonoise_data = sin((1:data_length)*2*pi/data_length);

% Standard deviations of the additive Gaussian noise that will be swept.
noise_std_all = [0.01, 0.05, 0.1, 0.5];

% Every sample is a window of fixed_length points plus a random number
% (between 0 and T_var) of extra points.
fixed_length = 10;
T_var = 10;

% Data set sizes and number of cross-validation folds.
n_train = 20;
n_test = 5000;
kfold = 5;

% Number of independent repetitions per noise level.
repeat = 1;

% Result container indexed as {kernel, noise level, repetition};
% the four kernel rows are filled in by the main loop.
allresults = cell(4, numel(noise_std_all), repeat);
%%
% Main experiment: for every noise level and repetition, draw a fresh noisy
% data set, fit four kernel ridge regressors (RNTK, NTK, RBF, polynomial)
% and store their test scores in allresults{kernel, a, r}.
counter = 0;   % number of completed runs so far
time = 0;      % wall-clock duration of each run in seconds, indexed by counter
for a = 1:numel(noise_std_all)
    for r = 1:repeat
        tic
        noise_std = noise_std_all(a);
        max_variable_length = T_var;

        % Noisy observation of the clean signal for this run.
        data = nonoise_data + noise_std*randn(1,data_length);

        % Per-sample window lengths: fixed part plus 0..max_variable_length extra points.
        train_length = fixed_length + randi([0 max_variable_length],1,n_train);
        test_length = fixed_length + randi([0 max_variable_length],1,n_test);

        % Fold index of every training sample (used by the *bestresult_D helpers).
        id = crossvalind('Kfold',n_train,kfold);

        % ---- Build the data sets ----------------------------------------
        % Random window start positions; the upper bound keeps the window
        % and the target that follows it inside data.
        data_position_train = randi([1,data_length - max_variable_length - fixed_length - 1],1,n_train);
        data_position_test = randi([1,data_length - max_variable_length - fixed_length - 1],1,n_test);

        % Rows are right-aligned windows, zero-padded on the left up to nn columns.
        nn = max_variable_length + fixed_length;
        x_train = zeros(n_train,nn);
        y_train = zeros(n_train,1);
        for i = 1:n_train
            x_train(i,(end - train_length(i) + 1):end) = data(data_position_train(i) + (0:(train_length(i) - 1)));
            % Target is the sample that immediately follows the window.
            y_train(i) = data(data_position_train(i) + train_length(i));
        end

        x_test = zeros(n_test,nn);
        y_test = zeros(n_test,1);
        y_test_nonoise = zeros(n_test,1);
        for i = 1:n_test
            x_test(i,(end - test_length(i) + 1):end) = data(data_position_test(i) + (0:(test_length(i) - 1)));
            y_test(i) = data(data_position_test(i) + test_length(i));
            % Same target position, read from the clean signal.
            y_test_nonoise(i) = nonoise_data(data_position_test(i) + test_length(i));
        end

        % ---- RNTK -------------------------------------------------------
        % RNTKbestresult_D is defined elsewhere (presumably hyper-parameter
        % selection over the folds in id -- confirm); only its .param and
        % .lambda fields are read here.
        rntkresult = RNTKbestresult_D(x_train,y_train,train_length,id);
        kernel_train = gramRNTKdifferentlength(x_train,x_train,rntkresult.param,train_length,train_length);
        kernel_test = gramRNTKdifferentlength(x_test,x_train,rntkresult.param,test_length,train_length);
        % Kernel ridge regression weights via the pseudo-inverse of (K + lambda*I).
        b = pinv(kernel_train + rntkresult.lambda*eye(size(kernel_train)))*y_train;
        y_hat_test = kernel_test*b;

        % Scores against the noisy targets and against the clean targets;
        % snr receives the target as signal and the prediction error as noise.
        mse_noise_rntk = mean(  (y_hat_test - y_test).^2 );
        snr_noise_rntk = snr(y_test,y_test - y_hat_test);
        mse_nonoise_rntk = mean(  (y_hat_test - y_test_nonoise).^2 );
        snr_nonoise_rntk = snr(y_test_nonoise,y_test_nonoise - y_hat_test);

        rntkresult.snr_noise = snr_noise_rntk;
        rntkresult.mse_noise = mse_noise_rntk;
        rntkresult.snr_nonoise = snr_nonoise_rntk;
        rntkresult.mse_nonoise = mse_nonoise_rntk;
        allresults{1,a,r} = rntkresult;

        % ---- NTK --------------------------------------------------------
        ntkresult = NTKbestresult_D(x_train,y_train,id);
        kernel_train = gramNTK(x_train,x_train,ntkresult.param);
        kernel_test = gramNTK(x_test,x_train,ntkresult.param);
        b = pinv(kernel_train + ntkresult.lambda*eye(size(kernel_train)))*y_train;
        y_hat_test = kernel_test*b;

        mse_noise_ntk = mean(  (y_hat_test - y_test).^2 );
        snr_noise_ntk = snr(y_test,y_test - y_hat_test);
        mse_nonoise_ntk = mean(  (y_hat_test - y_test_nonoise).^2 );
        snr_nonoise_ntk = snr(y_test_nonoise,y_test_nonoise - y_hat_test);

        ntkresult.snr_noise = snr_noise_ntk;
        ntkresult.mse_noise = mse_noise_ntk;
        ntkresult.snr_nonoise = snr_nonoise_ntk;
        ntkresult.mse_nonoise = mse_nonoise_ntk;
        allresults{2,a,r} = ntkresult;

        % ---- RBF --------------------------------------------------------
        rbfresult = RBFbestresult_D(x_train,y_train,id);
        kernel_train = gramRBF(x_train,x_train,rbfresult.alpha);
        kernel_test = gramRBF(x_test,x_train,rbfresult.alpha);
        b = pinv(kernel_train + rbfresult.lambda*eye(size(kernel_train)))*y_train;
        y_hat_test = kernel_test*b;

        mse_noise_rbf = mean(  (y_hat_test - y_test).^2 );
        snr_noise_rbf = snr(y_test,y_test - y_hat_test);
        mse_nonoise_rbf = mean(  (y_hat_test - y_test_nonoise).^2 );
        snr_nonoise_rbf = snr(y_test_nonoise,y_test_nonoise - y_hat_test);

        rbfresult.snr_noise = snr_noise_rbf;
        rbfresult.mse_noise = mse_noise_rbf;
        rbfresult.snr_nonoise = snr_nonoise_rbf;
        rbfresult.mse_nonoise = mse_nonoise_rbf;
        allresults{3,a,r} = rbfresult;

        % ---- Polynomial -------------------------------------------------
        polyresult = POLYbestresult_D(x_train,y_train,id);
        kernel_train = gramPOLY(x_train,x_train,polyresult.d,polyresult.r);
        kernel_test = gramPOLY(x_test,x_train,polyresult.d,polyresult.r);
        b = pinv(kernel_train + polyresult.lambda*eye(size(kernel_train)))*y_train;
        y_hat_test = kernel_test*b;

        mse_noise_poly = mean(  (y_hat_test - y_test).^2 );
        snr_noise_poly = snr(y_test,y_test - y_hat_test);
        mse_nonoise_poly = mean(  (y_hat_test - y_test_nonoise).^2 );
        snr_nonoise_poly = snr(y_test_nonoise,y_test_nonoise - y_hat_test);

        polyresult.snr_noise = snr_noise_poly;
        polyresult.mse_noise = mse_noise_poly;
        polyresult.snr_nonoise = snr_nonoise_poly;
        polyresult.mse_nonoise = mse_nonoise_poly;

        % ---- Bookkeeping and progress report ----------------------------
        counter = counter+1;
        time(counter) = toc;
        allresults{4,a,r} = polyresult;

        % Linear extrapolation of the elapsed time to the remaining share of runs.
        percent_complete = 100*((a-1)*repeat + r)/(repeat*length(noise_std_all));
        remaining_time = (100 - percent_complete)*(  sum(time)/percent_complete );
        fprintf('%.2f percent complete. Estimated remaining time: %.2f minutes \n', percent_complete,remaining_time/60)
    end
end
%%
% Average the clean-target SNR over repetitions.
% SNRs(j,i) holds the mean for kernel j (row of allresults) at noise level i.
SNRs = zeros(4,numel(noise_std_all));
for j = 1:4
    for i = 1:numel(noise_std_all)
        for r = 1:repeat
            % Accumulate each repetition's contribution to the mean.
            SNRs(j,i) = SNRs(j,i) + (allresults{j,i,r}.snr_nonoise)/repeat;
        end
    end
end
%%
% Plot mean SNR versus noise standard deviation, one line per kernel.
plot(noise_std_all,SNRs.','LineWidth',3.5,'Marker','*')
LEG = legend('RNTK','NTK','RBF','Polynomial','NumColumns',1,'Interpreter','latex', 'Location','best');
set(LEG,'FontSize',35);

% Vertical range: a small margin below the minimum and above the maximum.
ylim([(min(SNRs(:)) - 0.5) (max(SNRs(:)) + 1)])
set(gca,'FontSize',20)

% Shrink the axes to the outer position minus the tight inset so that
% no extra white space surrounds the plot.
ax = gca;
outerpos = get(ax,'OuterPosition');
ti = get(ax,'TightInset');
left = outerpos(1) + ti(1);
bottom = outerpos(2) + ti(2);
ax_width = outerpos(3) - ti(1) - ti(3);
ax_height = outerpos(4) - ti(2) - ti(4);
set(ax,'Position',[left bottom ax_width ax_height]);

