% Put the folder that is expected to hold the RNTK helpers (newRNTK is called
% later in this script) on the MATLAB path.
addpath('rntk/')

% Sizes used by the finite-width network below.
Tp = 2;     % number of time steps (columns) read from xp
T = 2;      % number of time steps (columns) read from x
n = 100;    % hidden width: w is n-by-n, u is n-by-m, b and v are n-by-1
m = 1;      % rows of the input sequence; also used in the sigmau/sqrt(m) scale

% Settings handed to newRNTK. The four sigma values are also the scale factors
% applied to w, u, v and b in the empirical computation. sigmah, L and
% nonlinearity are only forwarded to newRNTK; their meaning is defined there.
param.sigmaw = sqrt(2);
param.sigmau = 1;
param.sigmav = 1;
param.sigmab = 0;
param.sigmah = 0;
param.L = 1;
param.nonlinearity = 1;

% Plain-variable copies so the empirical formulas stay short.
sigmaw = param.sigmaw;
sigmau = param.sigmau;
sigmav = param.sigmav;
sigmab = param.sigmab;
%%
% Standard-normal weights of the finite-width network. The draws are made in
% the order w, u, b, v so a seeded random stream yields the same values.
w = randn(n);       % recurrent weights, n-by-n
u = randn(n, m);    % input weights, n-by-m
b = randn(n, 1);    % bias, n-by-1
v = randn(n, 1);    % read-out weights, n-by-1

% One result slot per sweep angle: empirical kernel and newRNTK value.
erntk = zeros(1, 50);
arntk = zeros(size(erntk));

% Angles 2*pi*k/K for k = 1..K, where K is the number of result slots.
alpha = 2*pi*(1:numel(erntk))/numel(erntk);

% Fixed second input sequence; column t is the input at time step t.
xp = [0.2, 0.5];
% Sweep over the angles in alpha. For each angle:
%   1. build x = [cos(alpha), sin(alpha)] (column t is the input at step t),
%   2. run the ReLU recurrence on x and on the fixed sequence xp,
%   3. back-propagate the scalar read-out to w, u, b and v for both inputs,
%   4. store the inner product of the two gradients in erntk(qqq) and the
%      value returned by newRNTK(x, xp, param) in arntk(qqq).
% Reads from the workspace: n, m, T, Tp, sigmaw, sigmau, sigmav, sigmab,
% w, u, b, v, alpha, xp, param. erntk and arntk are filled entry by entry.
for qqq = 1:numel(alpha)
    tic
    x = [cos(alpha(qqq)),sin(alpha(qqq))];

    % ---- forward pass on x: g = pre-activation, h = max(0,g), d = heaviside(g)
    % heaviside (Symbolic Math Toolbox) returns 0.5 at exactly 0.
    g = zeros(n,T);
    h = zeros(n,T);
    d = zeros(n,T);
    g(:,1) = (sigmau/sqrt(m))*u*x(:,1) + sigmab*b;
    h(:,1) = max(0,g(:,1));
    d(:,1) = heaviside(g(:,1));
    for t = 2:T
        g(:,t) = (sigmaw/sqrt(n))*w*h(:,t-1) + (sigmau/sqrt(m))*u*x(:,t) + sigmab*b;
        h(:,t) = max(0,g(:,t));
        d(:,t) = heaviside(g(:,t));
    end

    % ---- forward pass on xp (same recurrence, Tp steps)
    gp = zeros(n,Tp);
    hp = zeros(n,Tp);
    dp = zeros(n,Tp);
    gp(:,1) = (sigmau/sqrt(m))*u*xp(:,1) + sigmab*b;
    hp(:,1) = max(0,gp(:,1));
    dp(:,1) = heaviside(gp(:,1));
    for t = 2:Tp
        gp(:,t) = (sigmaw/sqrt(n))*w*hp(:,t-1) + (sigmau/sqrt(m))*u*xp(:,t) + sigmab*b;
        hp(:,t) = max(0,gp(:,t));
        dp(:,t) = heaviside(gp(:,t));
    end

    % Scalar network outputs. Not used in the kernel; left in the workspace
    % for inspection.
    f = (sigmav/sqrt(n))*v'*h(:,T);
    fp = (sigmav/sqrt(n))*v'*hp(:,Tp);

    % ---- backward pass on x: q(:,t) is the derivative of f w.r.t. g(:,t).
    % d(:,t).*y equals diag(d(:,t))*y without building an n-by-n matrix, and
    % the parentheses keep w'*q a matrix-vector product.
    q = zeros(n,T);                       % every column is overwritten below
    q(:,T) = (sigmav/sqrt(n))*(d(:,T).*v);
    for t = (T-1):-1:1
        q(:,t) = (sigmaw/sqrt(n))*(d(:,t).*(w'*q(:,t+1)));
    end

    % ---- backward pass on xp
    qp = zeros(n,Tp);                     % every column is overwritten below
    qp(:,Tp) = (sigmav/sqrt(n))*(dp(:,Tp).*v);
    for t = (Tp-1):-1:1
        qp(:,t) = (sigmaw/sqrt(n))*(dp(:,t).*(w'*qp(:,t+1)));
    end

    % ---- gradients of f w.r.t. w, u, b, v
    gw = zeros(n,n);
    for t = 2:T
        gw = gw + (sigmaw/sqrt(n))*q(:,t)*h(:,t-1)';
    end
    gu = zeros(n,m);
    for t = 1:T
        gu = gu + (sigmau/sqrt(m))*q(:,t)*x(:,t)';
    end
    % Sum over ALL time steps. The loop variable must be the index used here;
    % the previous code looped over i but indexed with a stale t.
    gb = zeros(n,1);
    for t = 1:T
        gb = gb + sigmab*q(:,t);
    end
    gv = (sigmav/sqrt(n))*h(:,T);

    % ---- gradients of fp w.r.t. w, u, b, v
    gpw = zeros(n,n);
    for t = 2:Tp
        gpw = gpw + (sigmaw/sqrt(n))*qp(:,t)*hp(:,t-1)';
    end
    gpu = zeros(n,m);
    for t = 1:Tp
        gpu = gpu + (sigmau/sqrt(m))*qp(:,t)*xp(:,t)';
    end
    gpb = zeros(n,1);
    for t = 1:Tp
        gpb = gpb + sigmab*qp(:,t);
    end
    gpv = (sigmav/sqrt(n))*hp(:,Tp);

    % Empirical kernel: sum of the inner products of matching gradients.
    % A(:)'*B(:) equals trace(A'*B) without forming the full matrix product.
    erntk(qqq) = gw(:)'*gpw(:) + gu(:)'*gpu(:) + gb'*gpb + gv'*gpv;

    % Reference value from the project's newRNTK for the same pair of inputs.
    arntk(qqq) = newRNTK(x,xp,param);

    % Progress output: iteration index (left unsuppressed on purpose) and the
    % elapsed time of this iteration.
    qqq
    toc
end
%%
% Draw both curves against the sweep index in the current axes:
% first the empirical kernel, then the values returned by newRNTK.
plot(erntk, 'Linewidth', 2);
hold('on');     % keep the first curve when the second one is added
plot(arntk, 'Linewidth', 2);

