function res = testnn_nnTNNGlobalAveragePoolingLayer()
% testnn_nnTNNGlobalAveragePoolingLayer - tests constructor and methods of
%    nnTNNGlobalAveragePoolingLayer (numeric and polyZonotope evaluation)
%
% Syntax:
%    res = testnn_nnTNNGlobalAveragePoolingLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - true/false 
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       05-July-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% test dimensions: the input is a num_words x emb_size matrix, which is
% vectorized (column-major) for the set-based evaluations below
emb_size = 16;
num_words = 20;
n = num_words * emb_size;
noise = 0.01;   % size of the per-entry perturbation for set inputs
tol = 1e-10;    % tolerance for numeric comparisons

% create a global average pooling layer
layer = nnTNNGlobalAveragePoolingLayer(emb_size, "global average pooling layer");

% check numeric evaluation ---
x = rand(num_words, emb_size);
output = layer.evaluate(x);

% expected result: mean over all words (rows), i.e. 1 x emb_size
exp_output = mean(x, 1);

% compare with a tolerance instead of rounding both sides, which can
% spuriously fail when a value lies close to a rounding boundary
resvec(end+1) = isequal(size(output), size(exp_output)) ...
    && all(abs(output(:) - exp_output(:)) <= tol);

% check polyZonotope evaluation ---
% perturb every entry of the vectorized input independently
X = polyZonotope(x(:), noise * eye(n));
Y = layer.evaluate(X);
% the exact numeric result for the center must lie inside the output set
resvec(end+1) = contains(Y, exp_output(:), 'approx');

% check set-based evaluation against sampled points ---
C = rand(num_words, emb_size);
X = polyZonotope(C(:), noise * eye(n));
Y = layer.evaluate(X);

N = 100;
xs = X.randPoint(N);
ys = zeros(emb_size, N);
for i = 1:N
    % undo the vectorization before the numeric evaluation
    out = layer.evaluate(reshape(xs(:, i), num_words, emb_size));
    ys(:, i) = out(:);
end

% every sampled output must be contained in the output set
resvec(end+1) = all(contains(zonotope(Y), ys));

% gather results
res = all(resvec);

end


% ------------------------------ END OF CODE ------------------------------
