function res = testnn_nnLayerNormalizationLayer()
% testnn_nnLayerNormalizationLayer - tests the nnLayerNormalizationLayer class:
%    numeric evaluation against a reference computation, and containment
%    of evaluated samples in the polyZonotope and interval outputs
%
% Syntax:
%    res = testnn_nnLayerNormalizationLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - logical, true if all checks pass
%
% Other m-files required: nnLayerNormalizationLayer, polyZonotope,
%    zonotope, interval
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi, Tobias Ladner
% Written:       08-July-2024
% Last update:   22-September-2024
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% absolute tolerance for comparing numeric results
tol = 1e-6;

% Create a layer normalization layer
num_words = 20;
emb_size = 16;
beta = rand(emb_size,1);
gamma = rand(emb_size,1);
epsilon = 1e-6;
layer = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_size, "layernorm");

% numeric ---
input = rand(num_words, emb_size);
output = layer.evaluate(input);

% reference result: subtract the mean over dimension 2 from each row,
% then scale column j by gamma(j) and shift it by beta(j)
% NOTE(review): this reference does not divide by sqrt(var + epsilon), so it
% only matches a layer implementation that omits variance normalization;
% confirm against nnLayerNormalizationLayer's numeric evaluation
mean_val = mean(input, 2);
exp_output = (input - mean_val) * diag(gamma) + beta';

% compare with an absolute tolerance: comparing rounded values can fail
% spuriously when tiny numerical differences straddle a rounding boundary;
% all(... <= tol) also fails on NaN entries
resvec(end+1) = isequal(size(output), size(exp_output)) ...
    && all(abs(output(:) - exp_output(:)) <= tol);

% test polyZono ---
% perturb every entry of the numeric input by a small independent generator
layer = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_size, "layernorm");
c = reshape(input, [], 1);
noise = 0.001;
X = polyZonotope(c, noise * eye(numel(c)));

Y = layer.evaluate(X);
% the center c lies in X, so its reference image must lie in the output set
expected_output = reshape(exp_output, [], 1);
resvec(end+1) = contains(zonotope(Y), expected_output);

% check set-based evaluation ----------------------------------------------

% polyZonotope ---
C = [1 2; 3 4];
beta = rand(size(C,2),1);
gamma = rand(size(C,2),1);
layer = nnLayerNormalizationLayer(beta, gamma, epsilon, size(C,2), "layernorm");
% input
c = reshape(C, [], 1);
X = polyZonotope(c, 0.001 * eye(numel(c)));

% output
Y = layer.evaluate(X);

% evaluate the layer on random samples of the input set
N = 100;
xs = X.randPoint(N);
ys = zeros(numel(c), N);
for i = 1:N
    x = reshape(xs(:, i), size(C));
    y = layer.evaluate(x);
    ys(:, i) = reshape(y, [], 1);
end

% check containment
resvec(end+1) = all(contains(zonotope(Y), ys));

% interval ---
X = reshape(interval(X), size(C));

% output
Y = layer.evaluate(X);

% check containment
Y = reshape(Y, [], 1);
resvec(end+1) = all(contains(Y, ys));

% gather results
res = all(resvec);

end

% ------------------------------ END OF CODE ------------------------------
