function res = test_nn_nnSelfAttentionLayer()
% test_nn_nnSelfAttentionLayer - tests the SelfAttention layer
%
%    Checks that
%    - the custom layer name is stored,
%    - numeric evaluation matches a manual computation of
%      softmax(Q*K' / sqrt(d_k)) * V (row-wise softmax), up to a tolerance,
%    - set-based evaluation (polyZonotope, zonotope, interval) encloses the
%      point-wise outputs of the unperturbed input and of sampled inputs.
%
% Syntax:
%    res = test_nn_nnSelfAttentionLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - boolean, true if all checks pass
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       22-June-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% absolute tolerance for comparing the layer output with the manual
% reference; both are floating-point results that may be computed in a
% different operation order, so exact equality would be too strict
tol = 1e-10;

% initialize weight matrices (random, so each run uses different weights)
W_Q = rand(2, 2);
W_K = rand(2, 2);
W_V = rand(2, 2);

% test name ---------------------------------------------------------------
customName = 'SelfAttention';
layer = nnSelfAttentionLayer(W_Q, W_K, W_V, customName);
resvec(end+1) = strcmp(layer.name, customName);

% test numeric ------------------------------------------------------------
C = [1, 2; 3, 4];
Y_layer = layer.evaluate(C);

% manual reference computation
Q = C * W_Q;
K = C * W_K;
V = C * W_V;
scores = (Q * K') / sqrt(size(W_K, 2));
r = zeros(size(scores, 1), size(scores, 2));
for i = 1:size(scores, 1)
    % shift by the row maximum before exponentiating (numerical stability)
    s = scores(i, :) - max(scores(i, :));
    r(i, :) = exp(s) ./ sum(exp(s));
end
Y_manual = r * V;

resvec(end+1) = isequal(size(Y_layer), size(Y_manual)) ...
    && all(abs(Y_layer(:) - Y_manual(:)) <= tol);

% check set-based evaluation ----------------------------------------------

% polyZonotope ---
% small perturbation around the column-vectorized input C
c = reshape(C, [], 1);
X = polyZonotope(c, 0.01 * eye(4));

% output
Y = layer.evaluate(X);

% the output set must contain the output of the unperturbed input
resvec(end+1) = contains(zonotope(Y), reshape(Y_layer, [], 1));

% sample points from the input set and evaluate them point-wise
N = 1000;
xs = X.randPoint(N);
ys = zeros(4, N);
for i = 1:N
    ys(:, i) = reshape(layer.evaluate(reshape(xs(:, i), 2, 2)), 4, []);
end

% check containment
resvec(end+1) = all(contains(zonotope(Y), ys));

% zonotope ---
X = zonotope(X);

% output
Y = layer.evaluate(X);

% check containment
resvec(end+1) = all(contains(Y, ys));

% interval ---
X = reshape(interval(X), 2, 2);

% output
Y = reshape(layer.evaluate(X), 4, 1);

% check containment
resvec(end+1) = all(contains(Y, ys));

% gather results
res = all(resvec);

end

% ------------------------------ END OF CODE ------------------------------
