function res = test_nn_nnTNNGlobalAveragePoolingLayer()
% test_nn_nnTNNGlobalAveragePoolingLayer - tests the nnTNNGlobalAveragePoolingLayer
%    numerically (against the column-wise mean) and set-based (polyZonotope
%    and interval inputs must enclose the outputs of sampled points)
%
% Syntax:
%    res = test_nn_nnTNNGlobalAveragePoolingLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - true/false 
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       05-July-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% test dimensions: inputs are num_words x emb_size matrices
emb_size = 16;
num_words = 20;
num_gens = num_words*emb_size;   % one generator per input entry
noise = 0.01;                    % generator scaling for the set-based inputs
tol = 1e-6;                      % tolerance for numeric point comparisons

% create a global pooling layer (reused by all checks below)
layer = nnTNNGlobalAveragePoolingLayer(emb_size, "global average pooling layer");

% numeric -----------------------------------------------------------------

x = rand(num_words, emb_size);
output = layer.evaluate(x);

% expected result: mean over the words (explicit dim 1, so the result does
% not silently change orientation when num_words == 1)
exp_output = mean(x, 1);

% compare with a tolerance instead of rounding both sides: rounding can
% disagree for values that lie close to a rounding boundary
resvec(end+1) = isequal(size(output), size(exp_output)) ...
    && all(abs(output(:) - exp_output(:)) <= tol);

% polyZonotope: center ----------------------------------------------------

c = reshape(x, [], 1);
X = polyZonotope(c, noise*eye(num_gens)); % add small perturbations
Y = layer.evaluate(X);

% the numeric mean of the center must lie in the output set
expected_output = reshape(exp_output, [], 1);
resvec(end+1) = contains(Y, expected_output, 'approx');

% check set-based evaluation ----------------------------------------------

C = rand(num_words, emb_size);

% polyZonotope ---
c = reshape(C, [], 1);
X = polyZonotope(c, noise*eye(num_gens));
N = 100;
xs = X.randPoint(N);

% output
Y = layer.evaluate(X);

% evaluate each sampled point individually; samples are reshaped back into
% num_words x emb_size matrices (column-major, matching reshape(C,[],1))
ys = zeros(emb_size, N);
for i = 1:N
    x_i = reshape(xs(:, i), num_words, emb_size);
    y_i = layer.evaluate(x_i);
    ys(:, i) = y_i(:);
end

% check containment
resvec(end+1) = all(contains(zonotope(Y), ys));

% interval ---
X = interval(X);
X = reshape(X, size(C));

% output
Y = layer.evaluate(X);
Y = reshape(Y, [], 1);

% check containment
resvec(end+1) = all(contains(Y, ys));

% gather results
res = all(resvec);

end


% ------------------------------ END OF CODE ------------------------------
