% Signal-propagation experiment through a deep stack of random orthogonal
% linear layers, each followed by an element-wise activation.
%
% Forward pass : for each of L layers, h = W_i * x (pre-activation) and
%                x = phi(h) (post-activation); per-layer mean and mean-square
%                of x and h are recorded.
% Backward pass: a random error signal is propagated from layer L down to
%                layer 1 using the stored pre-activations; per-layer
%                statistics of phi'(h) and of the weight gradient are recorded.
%
% Choose the activation with ACT below. The forward function PHI and its
% derivative DPHI are defined side by side so the two passes cannot drift
% out of sync. Supported values:
%   'tanh', 'tanh-gpn', 'relu', 'relu-gpn', 'leakyrelu', 'leakyrelu-gpn',
%   'elu', 'elu-gpn', 'selu', 'selu-gpn', 'gelu', 'gelu-gpn'
% The *-gpn gain/bias constants are taken verbatim from the original script.

clear all;
close all;

n = 500;    % number of samples (columns of x)
d = 500;    % layer width
L = 200;    % number of layers

act = 'gelu-gpn';

% Sigmoid form of GELU used here: GELU(h) = h .* sig(h).
sig = @(z) 1 ./ (1 + exp(-1.702 * z));
% d/dh [h .* sig(h)] given s = sig(h); the product-rule term is ADDED.
dgelu = @(h, s) s + 1.702 * h .* s .* (1 - s);

% For the ELU/SELU derivatives, exp(min(h,0)) is used instead of exp(h) so
% that large positive h cannot produce Inf*0 = NaN in the masked term.
switch lower(act)
    case 'tanh'
        phi  = @(h) tanh(h);
        dphi = @(h) 1 - tanh(h).^2;
    case 'tanh-gpn'
        phi  = @(h) 1.4674 * tanh(h) + 0.3886;
        dphi = @(h) 1.4674 * (1 - tanh(h).^2);
    case 'relu'
        phi  = @(h) max(0, h);
        dphi = @(h) double(h > 0);
    case 'relu-gpn'
        phi  = @(h) 1.4142 * max(0, h);
        dphi = @(h) 1.4142 * (h > 0);
    case 'leakyrelu'
        phi  = @(h) max(0, h) + min(0, 0.01 * h);
        dphi = @(h) (h > 0) + 0.01 * (h < 0);
    case 'leakyrelu-gpn'
        phi  = @(h) 1.4141 * (max(0, h) + min(0, 0.01 * h));
        dphi = @(h) 1.4141 * ((h > 0) + 0.01 * (h < 0));
    case 'elu'
        phi  = @(h) max(0, h) + min(0, exp(h) - 1);
        dphi = @(h) (h > 0) + (h < 0) .* exp(min(h, 0));
    case 'elu-gpn'
        phi  = @(h) 1.2234 * (max(0, h) + min(0, exp(h) - 1)) + 0.0742;
        dphi = @(h) 1.2234 * ((h > 0) + (h < 0) .* exp(min(h, 0)));
    case 'selu'
        phi  = @(h) 1.0507 * (max(0, h) + min(0, 1.673263 * (exp(h) - 1)));
        dphi = @(h) 1.0507 * ((h > 0) + 1.673263 * (h < 0) .* exp(min(h, 0)));
    case 'selu-gpn'
        phi  = @(h) 0.9660 * 1.0507 * (max(0, h) + min(0, 1.673263 * (exp(h) - 1))) + 0.2584;
        dphi = @(h) 0.9660 * 1.0507 * ((h > 0) + 1.673263 * (h < 0) .* exp(min(h, 0)));
    case 'gelu'
        phi  = @(h) h .* sig(h);
        dphi = @(h) dgelu(h, sig(h));
    case 'gelu-gpn'
        phi  = @(h) 1.4915 * h .* sig(h) - 0.9097;
        dphi = @(h) 1.4915 * dgelu(h, sig(h));
    otherwise
        error('Unknown activation ''%s''.', act);
end

% Per-layer storage; the layer index is the LAST dimension so that each
% slice H(:,:,i) / W(:,:,i) is contiguous and needs no squeeze().
H = zeros(d, n, L);     % pre-activations h of layer i
W = zeros(d, d, L);     % weight matrix of layer i
x = randn(d, n);

[m_x, v_x, m_h, v_h, m_g, v_g, m_w, v_w] = deal(zeros(1, L));

% ---- Forward pass ----
for i = 1:L
    % Random orthogonal matrix: the orthogonal polar factor U*V' of a
    % Gaussian matrix whose rows have been normalised to unit length.
    w = randn(d);
    w = bsxfun(@rdivide, w, sqrt(sum(w.^2, 2)));
    [U, ~, V] = svd(w);
    w = U * V';
    W(:, :, i) = w;

    h = w * x;          % pre-activation
    x = phi(h);         % post-activation, fed to the next layer
    H(:, :, i) = h;

    m_x(i) = mean(x(:));
    v_x(i) = mean(x(:).^2);
    m_h(i) = mean(h(:));
    v_h(i) = mean(h(:).^2);
end

% ---- Backward pass ----
e = d * randn(d, n);    % random error signal injected at the top layer

for i = L:-1:1
    w = W(:, :, i);
    h = H(:, :, i);

    dh = dphi(h);       % element-wise activation derivative phi'(h)
    g  = e .* dh;       % signal at the pre-activation of layer i
    e  = w' * g;        % signal passed down to the input of layer i

    % Recover the layer input from h = w*x; valid because w is orthogonal
    % (w'*w = I), so the inputs need not be stored during the forward pass.
    x   = w' * h;
    w_g = g * x' / n;   % weight gradient, averaged over the n samples

    m_g(i) = mean(dh(:));          % mean of phi'(h)
    v_g(i) = mean(dh(:).^2);       % mean square of phi'(h)
    m_w(i) = mean(w_g(:));         % mean weight-gradient entry
    v_w(i) = norm(w_g, 'fro');     % Frobenius norm of the weight gradient
end

% ---- Plots: one row per recorded statistic, versus layer index ----
figure
stat_names = {'m_x', 'v_x', 'm_h', 'v_h', 'm_g', 'v_g', 'm_w', 'v_w'};
stat_data  = {m_x, v_x, m_h, v_h, m_g, v_g, m_w, v_w};
for k = 1:numel(stat_data)
    subplot(numel(stat_data), 1, k), plot(1:L, stat_data{k})
    ylabel(stat_names{k})
end
xlabel('layer')
