function res = test_nn_nnTransformerEmbeddingLayer()
% test_nn_nnTransformerEmbeddingLayer - tests nnTransformerEmbeddingLayer
%
% Syntax:
%    res = test_nn_nnTransformerEmbeddingLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - true if all checks pass, false otherwise
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Tobias Ladner
% Written:       22-September-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% init layer
token_emb = [1 4 7 10; 2 5 8 11; 3 6 9 12];
pos_emb = [1 2 3 4; 5 6 7 8; 9 10 11 12];
layer = nnTransformerEmbeddingLayer(token_emb, pos_emb);

% point evaluation ---------------------------------------------------------
% The expected values below are consistent with 0-based token ids:
% token id k selects row k+1 of token_emb, and the i-th token of the
% input sequence adds row i of pos_emb.
% isequal checks size AND values; a plain elementwise '==' followed by
% all(...) would silently broadcast a wrongly-shaped result.

% single token: token_emb(3,:) + pos_emb(1,:)
input = 2;
x = layer.evaluate(input);
resvec(end+1) = isequal(x, [4 8 12 16]);

% two tokens: [token_emb(2,:) + pos_emb(1,:); token_emb(1,:) + pos_emb(2,:)]
input = [1,0];
x = layer.evaluate(input);
resvec(end+1) = isequal(x, [3 7 11 15; 6 10 14 18]);

% set-based evaluation ----------------------------------------------------

% should fail (doesn't make sense): evaluating on a set of token ids is
% not supported. Reuses the two-token 'input' from above.
setConstructors = {@interval, @zonotope, @polyZonotope};
for i = 1:numel(setConstructors)
    try
        layer.evaluate(setConstructors{i}(input));
        % no error raised -> unexpected
        resvec(end+1) = false;
    catch
        % any error counts as the expected rejection
        resvec(end+1) = true;
    end
end

res = all(resvec);

end

% ------------------------------ END OF CODE ------------------------------
