%Figure 4.d
% Start from an empty workspace and command window, then prepend the
% project sub-folders to the MATLAB path (same order as listed, so the last
% one ends up first on the path).
% NOTE(review): presumably these folders hold the gram*/...bestresult_D
% helpers called below - confirm.
clear
clc
for folder = {'rntk/', 'NTK/', 'Poly/', 'RBF/', 'utill/'}
    addpath(folder{1})
end
%% Experiment settings and data
% Window / cross-validation settings.
kfold = 5;                             % number of folds passed to crossvalind
fixed_length = 10;                     % shortest input window
T_var = 10;                            % windows are fixed_length + randi([0 T_var]) samples long
number_train_data = [10,20,30,40,50];  % training-set sizes (x-axis of the figure)
n_test = 5000;                         % test windows drawn per repetition
repeat = 1000;                         % repetitions per training-set size

% Load the series (5th table column) and split it by row index:
% rows 1..700 for training, rows 701..975 for testing.
% NOTE(review): assumes column 5 is the intended price column (e.g.
% "Close") - confirm against the CSV header.
data = readtable('GOOG_data.csv');
data = data{:,5};
data_train = data(1:700);
data_test = data(701:975);
data_length_train = numel(data_train);
data_length_test = numel(data_test);

% Rescale each split by its own maximum.
% NOTE(review): the test split is divided by max(data_test), not
% max(data_train), so test statistics enter the preprocessing; left as-is
% so the published figure is reproduced unchanged.
data_train = data_train/max(data_train);
data_test = data_test/max(data_test);

% allresults{model, size index, repetition} receives one result struct each.
allresults = cell(4,numel(number_train_data),repeat);
%% Monte-Carlo experiment
% For each training-set size n_train = number_train_data(a) and each
% repetition r = 1..repeat:
%   1. draw random window lengths in [fixed_length, fixed_length + T_var],
%      a K-fold assignment of the training windows, and random start
%      positions in the normalized train / test splits;
%   2. build window matrices (each window right-aligned in an nn-wide row,
%      zeros on the left) and one-step-ahead targets (the next sample);
%   3. for each model (row 1 = RNTK, 2 = NTK, 3 = RBF, 4 = POLY) obtain
%      hyper-parameters from its *bestresult_D routine, fit kernel ridge
%      regression, and store the returned struct extended with the test
%      SNR / MSE in allresults{model, a, r}.
% Keep the order of the random draws below: reordering them changes the
% sampled experiment for a given RNG state.
max_variable_length = T_var;
nn = max_variable_length + fixed_length;     % width of every padded window
model_names = {'RNTK','NTK','RBF','POLY'};   % row order of allresults
n_iterations = repeat*length(number_train_data);
counter = 0;
time = zeros(1,n_iterations);                % seconds spent per iteration
for a = 1:length(number_train_data)
    n_train = number_train_data(a);
    for r = 1:repeat
        tic
        % -- random window lengths and cross-validation folds --------------
        train_length = fixed_length + randi([0 max_variable_length],1,n_train);
        test_length = fixed_length + randi([0 max_variable_length],1,n_test);
        id = crossvalind('Kfold',n_train,kfold);

        % -- random start positions ------------------------------------------
        % The upper bound keeps the target index (start + length) <= N - 1.
        data_position_train = randi([1,data_length_train - nn - 1],1,n_train);
        data_position_test = randi([1,data_length_test - nn - 1],1,n_test);

        % -- windows (right-aligned, zero-padded) and next-sample targets ---
        x_train = zeros(n_train,nn);
        y_train = zeros(n_train,1);
        for i = 1:n_train
            w_start = data_position_train(i);
            w_end = w_start + train_length(i) - 1;
            x_train(i,nn - train_length(i) + 1:nn) = data_train(w_start:w_end);
            y_train(i) = data_train(w_end + 1);
        end
        x_test = zeros(n_test,nn);
        y_test = zeros(n_test,1);
        for i = 1:n_test
            w_start = data_position_test(i);
            w_end = w_start + test_length(i) - 1;
            x_test(i,nn - test_length(i) + 1:nn) = data_test(w_start:w_end);
            y_test(i) = data_test(w_end + 1);
        end

        % -- tune, fit and score every kernel ------------------------------
        for m = 1:numel(model_names)
            switch model_names{m}
                case 'RNTK'
                    res = RNTKbestresult_D(x_train,y_train,train_length,id);
                    kernel_train = gramRNTKdifferentlength(x_train,x_train,res.param,train_length,train_length);
                    kernel_test = gramRNTKdifferentlength(x_test,x_train,res.param,test_length,train_length);
                case 'NTK'
                    res = NTKbestresult_D(x_train,y_train,id);
                    kernel_train = gramNTK(x_train,x_train,res.param);
                    kernel_test = gramNTK(x_test,x_train,res.param);
                case 'RBF'
                    res = RBFbestresult_D(x_train,y_train,id);
                    kernel_train = gramRBF(x_train,x_train,res.alpha);
                    kernel_test = gramRBF(x_test,x_train,res.alpha);
                case 'POLY'
                    res = POLYbestresult_D(x_train,y_train,id);
                    kernel_train = gramPOLY(x_train,x_train,res.d,res.r);
                    kernel_test = gramPOLY(x_test,x_train,res.d,res.r);
            end
            % Kernel ridge regression: b = pinv(K + lambda*I)*y, then predict
            % with the test-vs-train Gram matrix.
            b = pinv(kernel_train + res.lambda*eye(size(kernel_train)))*y_train;
            y_hat_test = kernel_test*b;
            res.snr_noise = snr(y_test,y_test - y_hat_test);
            res.mse_noise = mean((y_hat_test - y_test).^2);
            allresults{m,a,r} = res;
        end

        % -- progress report -----------------------------------------------
        counter = counter + 1;
        time(counter) = toc;
        percent_complete = 100*counter/n_iterations;
        remaining_time = (100 - percent_complete)*(sum(time)/percent_complete);
        fprintf('%.2f percent complete. Estimated remaining time: %.2f minutes \n', percent_complete,remaining_time/60)
    end
end
%% Average test SNR over repetitions
% SNRs(m, k) is the mean over r of allresults{m, k, r}.snr_noise, built as a
% running sum of snr/repeat (row m = model, column k = training-set size).
SNRs = zeros(4,length(number_train_data));
for r = 1:repeat
    for k = 1:length(number_train_data)
        for m = 1:4
            SNRs(m,k) = SNRs(m,k) + allresults{m,k,r}.snr_noise/repeat;
        end
    end
end
%% Plot average test SNR versus training-set size
% One curve per model; columns of SNRs' follow the row order used when
% filling allresults (1 = RNTK, 2 = NTK, 3 = RBF, 4 = POLY).
plot(number_train_data,SNRs','LineWidth',3.5,'Marker','*')
set(gca,'FontSize',20)
% Without a legend and axis labels the four curves cannot be told apart.
legend({'RNTK','NTK','RBF','POLY'},'Location','best')
xlabel('Number of training samples')
ylabel('Average test SNR (dB)')
% Expand the axes to fill the figure, leaving only the tight inset needed
% for tick labels and axis labels (computed after the labels are added).
ax = gca;
outerpos = ax.OuterPosition;
ti = ax.TightInset;
left = outerpos(1) + ti(1);
bottom = outerpos(2) + ti(2);
ax_width = outerpos(3) - ti(1) - ti(3);
ax_height = outerpos(4) - ti(2) - ti(4);
ax.Position = [left bottom ax_width ax_height];


