addpath('rntk/')
%% figure 3
%% relu
%% different sigmaw
clear param
% ReLU kernel: sweep over param.sigmaw.
% For every value in M, r random input pairs (x, xp) are drawn. For each
% pair, the first output of newRNTK is differentiated with respect to each
% column i of x using a forward finite difference of step eps, and the norm
% of that gradient is stored in der(j,i,k). Each slice der(:,:,k) is then
% rescaled so that its column-wise mean peaks at 1, and the three mean
% curves are plotted on the current axes.
L = 1;                      % length of the sigmab vector
m = 1;                      % number of rows of x / xp
T = 100;                    % number of columns of x / xp
eps = 0.0001;               % finite-difference step (shadows the built-in eps in this script)
M = [sqrt(2),1.36,1.47];    % sigmaw values to sweep
r = 10000;                  % random draws per sigmaw value
der = zeros(r,T,length(M));
for k = 1:length(M)
    param.sigmaw = M(k);
    param.sigmau = 1;
    param.sigmav = 1;
    param.sigmab = 1*ones(1,L);
    param.sigmah = 0;
    % The semicolon must come before the comment; otherwise the assignment
    % is unsuppressed and the whole struct is echoed on every iteration.
    param.nonlinearity = 1; % 1 is relu, 2 is erf
    for j = 1:r
        j   % intentionally unsuppressed: progress indicator
        x = randn(m,T);
        xp = randn(m,T);
        y = newRNTK(x,xp,param);
        grad = zeros(m,1);
        for i = 1:T
            for h = 1:m
                % Perturb a single entry of x and re-evaluate the kernel.
                temp = x;
                temp(h,i) = temp(h,i) + eps;
                y1 = newRNTK(temp,xp,param);
                grad(h) = (y1 - y)/(eps);
            end
            der(j,i,k) = norm(grad);
        end
    end
    % Average over draws (dimension 1 explicitly, so r == 1 still yields a
    % 1-by-T curve instead of a scalar) and normalise the peak to 1.
    temp = mean(der(:,:,k),1);
    der(:,:,k) = der(:,:,k)/max(temp);
end
plot(mean(der(:,:,1),1), 'Linewidth', 3)
hold on
plot(mean(der(:,:,2),1),'Linewidth', 3)
hold on
plot(mean(der(:,:,3),1),'Linewidth', 3)
%% different sigmau
clear param
% ReLU kernel: sweep over param.sigmau.
% Same finite-difference procedure as the other sections of this script:
% der(j,i,k) holds the gradient norm of the first newRNTK output with
% respect to column i of x, for draw j and sweep value M(k). In this
% section xp is drawn with an offset of +1.
eps = 1e-4;            % forward-difference step
T = 100;               % columns of x / xp
m = 1;                 % rows of x / xp
r = 10000;             % draws per sweep value
M = [0.1, 1, 10];      % sigmau values
der = zeros(r, T, numel(M));
for k = 1:numel(M)
    % Rebuild the parameter struct for this sweep value (field order kept).
    param = struct('sigmaw', sqrt(2), ...
                   'sigmau', M(k), ...
                   'sigmav', 1, ...
                   'sigmab', 0, ...
                   'sigmah', 0, ...
                   'nonlinearity', 1);
    for j = 1:r
        j
        x = randn(m, T);
        xp = 1 + randn(m, T);
        y = newRNTK(x, xp, param);
        grad = zeros(m, 1);
        for i = 1:T
            for h = 1:m
                % Bump one entry of x by eps and difference the kernel value.
                temp = x;
                temp(h, i) = x(h, i) + eps;
                y1 = newRNTK(temp, xp, param);
                grad(h) = (y1 - y) / eps;
            end
            der(j, i, k) = norm(grad);
        end
    end
    % Scale the slice so that its mean curve has maximum 1.
    temp = mean(der(:, :, k));
    der(:, :, k) = der(:, :, k) ./ max(temp);
end
% Overlay the three mean curves on the current axes.
for k = 1:3
    plot(mean(der(:, :, k)), 'Linewidth', 3)
    hold on
end
%% different sigmab
clear param
% ReLU kernel: sweep over param.sigmab.
% For every value in M, r random input pairs (x, xp) are drawn. For each
% pair, the first output of newRNTK is differentiated with respect to each
% column i of x using a forward finite difference of step eps, and the norm
% of that gradient is stored in der(j,i,k). Each slice der(:,:,k) is then
% rescaled so that its column-wise mean peaks at 1.
%
% param is built purely as a struct below. It must not be pre-assigned a
% numeric vector: field assignment on a double array fails.
m = 1;              % number of rows of x / xp
T = 100;            % number of columns of x / xp
eps = 0.001;        % finite-difference step (shadows the built-in eps in this script)
M = [0.5,1,2];      % sigmab values to sweep
r = 10000;          % random draws per sigmab value
der = zeros(r,T,length(M));
for k = 1:length(M)
    param.sigmaw = sqrt(2);
    param.sigmau = 1;
    param.sigmav = 1;
    param.sigmab = M(k);
    param.sigmah = 0;
    param.nonlinearity = 1;
    for j = 1:r
        j   % intentionally unsuppressed: progress indicator
        x = randn(m,T);
        xp = randn(m,T);
        y = newRNTK(x,xp,param);
        grad = zeros(m,1);
        for i = 1:T
            for h = 1:m
                % Perturb a single entry of x and re-evaluate the kernel.
                temp = x;
                temp(h,i) = temp(h,i) + eps;
                y1 = newRNTK(temp,xp,param);
                grad(h) = (y1 - y)/(eps);
            end
            der(j,i,k) = norm(grad);
        end
    end
    % Average over draws (dimension 1 explicitly, so r == 1 still yields a
    % 1-by-T curve instead of a scalar) and normalise the peak to 1.
    temp = mean(der(:,:,k),1);
    der(:,:,k) = der(:,:,k)/max(temp);
end
plot(mean(der(:,:,1),1), 'Linewidth', 3)
hold on
plot(mean(der(:,:,2),1),'Linewidth', 3)
hold on
plot(mean(der(:,:,3),1),'Linewidth', 3)
%% different sigmah
clear param
% ReLU kernel: sweep over param.sigmah.
% For every value in M, r random input pairs (x, xp) are drawn. For each
% pair, the first output of newRNTK is differentiated with respect to each
% column i of x using a forward finite difference of step eps, and the norm
% of that gradient is stored in der(j,i,k). Each slice der(:,:,k) is then
% rescaled so that its column-wise mean peaks at 1, and the three mean
% curves are plotted with a legend naming the swept sigmah values.
m = 1;              % number of rows of x / xp
T = 100;            % number of columns of x / xp
eps = 0.001;        % finite-difference step (shadows the built-in eps in this script)
M = [0.2,1,5];      % sigmah values to sweep
r = 100;            % random draws per sigmah value
der = zeros(r,T,length(M));
for k = 1:length(M)
    param.sigmaw = sqrt(2);
    param.sigmau = 1;
    param.sigmav = 1;
    param.sigmab = 0;
    param.sigmah = M(k);
    param.nonlinearity = 1;
    for j = 1:r
        j   % intentionally unsuppressed: progress indicator
        x = randn(m,T);
        xp = randn(m,T);
        y = newRNTK(x,xp,param);
        grad = zeros(m,1);
        for i = 1:T
            for h = 1:m
                % Perturb a single entry of x and re-evaluate the kernel.
                temp = x;
                temp(h,i) = temp(h,i) + eps;
                y1 = newRNTK(temp,xp,param);
                grad(h) = (y1 - y)/(eps);
            end
            der(j,i,k) = norm(grad);
        end
    end
    % Average over draws (dimension 1 explicitly, so r == 1 still yields a
    % 1-by-T curve instead of a scalar) and normalise the peak to 1.
    temp = mean(der(:,:,k),1);
    der(:,:,k) = der(:,:,k)/max(temp);
end
% Keep the line handles so the legend is bound to exactly these three
% curves, even if other lines are already held on the current axes.
h1 = plot(mean(der(:,:,1),1), 'Linewidth', 3);
hold on
h2 = plot(mean(der(:,:,2),1),'Linewidth', 3);
hold on
h3 = plot(mean(der(:,:,3),1),'Linewidth', 3);
% Labels are generated from M so they always name the parameter and the
% values actually swept in this section, in plotting order.
labels = arrayfun(@(v) sprintf('$\\sigma_h = %g$', v), M, 'UniformOutput', false);
lgd = legend([h1,h2,h3], labels, 'Interpreter','latex');
lgd.FontSize = 20;
%% erf
%% different sigmau
clear param
% erf kernel (param.nonlinearity = 2): sweep over param.sigmau.
% The second output of newRNTK is differentiated with respect to each
% column of x by a forward finite difference; der(j,i,k) stores the
% gradient norm for draw j, column i and sweep value M(k).
eps = 1e-4;                  % forward-difference step
T = 100;                     % columns of x / xp
m = 1;                       % rows of x / xp
L = 1;                       % length of the sigmab vector
r = 10000;                   % draws per sweep value
M = [0.005, 0.01, 0.05];     % sigmau values
der = zeros(r, T, numel(M));
for k = 1:numel(M)
    % Rebuild the parameter struct for this sweep value (field order kept).
    param = struct('sigmaw', 1, ...
                   'sigmau', M(k), ...
                   'sigmav', 1, ...
                   'sigmab', 0.05 * ones(1, L), ...
                   'sigmah', 0, ...
                   'nonlinearity', 2);
    for j = 1:r
        j
        x = randn(m, T);
        xp = randn(m, T);
        [~, y] = newRNTK(x, xp, param);
        grad = zeros(m, 1);
        for i = 1:T
            for h = 1:m
                % Bump one entry of x by eps and difference the kernel value.
                temp = x;
                temp(h, i) = x(h, i) + eps;
                [~, y1] = newRNTK(temp, xp, param);
                grad(h) = (y1 - y) / eps;
            end
            der(j, i, k) = norm(grad);
        end
    end
    % Scale the slice so that its mean curve has maximum 1.
    temp = mean(der(:, :, k));
    der(:, :, k) = der(:, :, k) ./ max(temp);
end
% Overlay the three mean curves on the current axes.
for k = 1:3
    plot(mean(der(:, :, k)), 'Linewidth', 3)
    hold on
end
%% different sigmaw
clear param
% erf kernel (param.nonlinearity = 2): sweep over param.sigmaw.
% Each column of x and xp is scaled to unit Euclidean norm before use.
% The second output of newRNTK is differentiated with respect to each
% column of x by a forward finite difference; der(j,i,k) stores the
% gradient norm for draw j, column i and sweep value M(k).
eps = 1e-4;          % forward-difference step
T = 100;             % columns of x / xp
m = 1;               % rows of x / xp
L = 1;               % length of the sigmab vector
r = 10000;           % draws per sweep value
M = [0.5, 1, 2];     % sigmaw values
der = zeros(r, T, numel(M));
for k = 1:numel(M)
    % Rebuild the parameter struct for this sweep value (field order kept).
    param = struct('sigmaw', M(k), ...
                   'sigmau', 0.01, ...
                   'sigmav', 1, ...
                   'sigmab', 0.05 * ones(1, L), ...
                   'sigmah', 0, ...
                   'nonlinearity', 2);
    for j = 1:r
        j
        % Draw the inputs and normalise every column.
        x = randn(m, T);
        x = x ./ sqrt(sum(x.^2, 1));
        xp = randn(m, T);
        xp = xp ./ sqrt(sum(xp.^2, 1));
        [~, y] = newRNTK(x, xp, param);
        grad = zeros(m, 1);
        for i = 1:T
            for h = 1:m
                % Bump one entry of x by eps and difference the kernel value.
                temp = x;
                temp(h, i) = x(h, i) + eps;
                [~, y1] = newRNTK(temp, xp, param);
                grad(h) = (y1 - y) / eps;
            end
            der(j, i, k) = norm(grad);
        end
    end
    % Scale the slice so that its mean curve has maximum 1.
    temp = mean(der(:, :, k));
    der(:, :, k) = der(:, :, k) ./ max(temp);
end
% Overlay the three mean curves on the current axes.
for k = 1:3
    plot(mean(der(:, :, k)), 'Linewidth', 3)
    hold on
end
%% different sigmab
clear param
% erf kernel (param.nonlinearity = 2): sweep over param.sigmab.
% Each column of x and xp is scaled to unit Euclidean norm before use.
% The second output of newRNTK is differentiated with respect to each
% column of x by a forward finite difference; der(j,i,k) stores the
% gradient norm for draw j, column i and sweep value M(k).
eps = 1e-4;            % forward-difference step
T = 100;               % columns of x / xp
m = 1;                 % rows of x / xp
L = 1;                 % length of the sigmab vector
r = 10000;             % draws per sweep value
M = [0, 0.05, 0.1];    % sigmab values
der = zeros(r, T, numel(M));
for k = 1:numel(M)
    % Rebuild the parameter struct for this sweep value (field order kept).
    param = struct('sigmaw', 1, ...
                   'sigmau', 0.01, ...
                   'sigmav', 1, ...
                   'sigmab', M(k) * ones(1, L), ...
                   'sigmah', 0, ...
                   'nonlinearity', 2);
    for j = 1:r
        j
        % Draw the inputs and normalise every column.
        x = randn(m, T);
        x = x ./ sqrt(sum(x.^2, 1));
        xp = randn(m, T);
        xp = xp ./ sqrt(sum(xp.^2, 1));
        [~, y] = newRNTK(x, xp, param);
        grad = zeros(m, 1);
        for i = 1:T
            for h = 1:m
                % Bump one entry of x by eps and difference the kernel value.
                temp = x;
                temp(h, i) = x(h, i) + eps;
                [~, y1] = newRNTK(temp, xp, param);
                grad(h) = (y1 - y) / eps;
            end
            der(j, i, k) = norm(grad);
        end
    end
    % Scale the slice so that its mean curve has maximum 1.
    temp = mean(der(:, :, k));
    der(:, :, k) = der(:, :, k) ./ max(temp);
end
% Overlay the three mean curves on the current axes.
for k = 1:3
    plot(mean(der(:, :, k)), 'Linewidth', 3)
    hold on
end
%% different sigmah
clear param
% erf kernel (param.nonlinearity = 2): sweep over param.sigmah.
% Each column of x and xp is scaled to unit Euclidean norm before use.
% The second output of newRNTK is differentiated with respect to each
% column of x by a forward finite difference; der(j,i,k) stores the
% gradient norm for draw j, column i and sweep value M(k).
eps = 1e-4;          % forward-difference step
T = 100;             % columns of x / xp
m = 1;               % rows of x / xp
L = 1;               % length of the sigmab vector
flag = 1;            % not referenced anywhere else in this section
r = 10000;           % draws per sweep value
M = [0, 0.1, 1];     % sigmah values
der = zeros(r, T, numel(M));
for k = 1:numel(M)
    % Rebuild the parameter struct for this sweep value (field order kept).
    param = struct('sigmaw', 1, ...
                   'sigmau', 0.01, ...
                   'sigmav', 1, ...
                   'sigmab', 0.05 * ones(1, L), ...
                   'sigmah', M(k), ...
                   'nonlinearity', 2);
    for j = 1:r
        j
        % Draw the inputs and normalise every column.
        x = randn(m, T);
        x = x ./ sqrt(sum(x.^2, 1));
        xp = randn(m, T);
        xp = xp ./ sqrt(sum(xp.^2, 1));
        [~, y] = newRNTK(x, xp, param);
        grad = zeros(m, 1);
        for i = 1:T
            for h = 1:m
                % Bump one entry of x by eps and difference the kernel value.
                temp = x;
                temp(h, i) = x(h, i) + eps;
                [~, y1] = newRNTK(temp, xp, param);
                grad(h) = (y1 - y) / eps;
            end
            der(j, i, k) = norm(grad);
        end
    end
    % Scale the slice so that its mean curve has maximum 1.
    temp = mean(der(:, :, k));
    der(:, :, k) = der(:, :, k) ./ max(temp);
end
% Overlay the three mean curves on the current axes.
for k = 1:3
    plot(mean(der(:, :, k)), 'Linewidth', 3)
    hold on
end