function res = test_nn_nnMultiHeadAttentionLayer()
% test_nn_nnMultiHeadAttentionLayer - tests the MultiHeadAttention layer
%
% Syntax:
%    res = test_nn_nnMultiHeadAttentionLayer()
%
% Inputs:
%    -
%
% Outputs:
%    res - boolean, true if all checks pass
%
% Other m-files required: none
% Subfunctions: aux_attentionHead
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       25-June-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% initialize weight matrices of head one
W_Q_1 = [1 2; 1 0]; 
W_K_1 = [1 1; 0 3]; 
W_V_1 = [2 2; 1 1];

% initialize weight matrices of head two
W_Q_2 = [4 5; 1 2]; 
W_K_2 = [0 2; 3 3]; 
W_V_2 = [0 5; 3 5];

% concatenated weight matrices to represent different heads
W_Q = [W_Q_1, W_Q_2]; 
W_K = [W_K_1, W_K_2];
W_V = [W_V_1, W_V_2];

% output projection applied to the concatenated head outputs
W_O = [1 2; 1 0; 2 0; 1 1];

% test name
customName = 'MultiHeadAttention';
layer = nnMultiHeadAttentionLayer(W_Q, W_K, W_V, W_O, 2,customName);
resvec(end+1) = strcmp(layer.name,customName);

% initialize multihead-attention layer
layer = nnMultiHeadAttentionLayer(W_Q, W_K, W_V, W_O, 2,'MultiHead Attention');

% test numeric
C = [1, 2; 3, 4];
output = layer.evaluate(C);

% compute expected output with a reference implementation of each head
output_head_1 = aux_attentionHead(C, W_Q_1, W_K_1, W_V_1);
output_head_2 = aux_attentionHead(C, W_Q_2, W_K_2, W_V_2);

% concatenate output of both heads horizontally
concat_output = [output_head_1, output_head_2];

% multiply with projection matrix
expected_output = concat_output * W_O;

resvec(end+1) = all(withinTol(output, expected_output, 1e-5),'all');

% check set-based evaluation ----------------------------------------------

layer = nnMultiHeadAttentionLayer(W_Q, W_K, W_V, W_O, 2,'MultiHead Attention');

% polyZonotope ---
% the 2x2 input is vectorized column-major; sampled points are reshaped back
% to 2x2 the same way before point-wise evaluation
c = reshape(C, [], 1);
X = polyZonotope(c,  0.01 * eye(4));
N = 1000;
xs = X.randPoint(N);

% output
Y = layer.evaluate(X);
ys = zeros(4,N);
for i = 1:N
    ys(:, i) = reshape(layer.evaluate(reshape(xs(:, i),2,2)),4,[]);
end

% check containment
resvec(end+1) = all(contains(zonotope(Y), ys));

% zonotope ---
X = zonotope(X);

% output
Y = layer.evaluate(X);

% check containment
resvec(end+1) = all(contains(Y, ys));

% interval ---
X = reshape(interval(X),2,2);

% output
Y = reshape(layer.evaluate(X),4,1);

% check containment
resvec(end+1) = all(contains(Y, ys));


% gather results
res = all(resvec);

end


% Auxiliary functions -----------------------------------------------------

function out = aux_attentionHead(X, W_Q, W_K, W_V)
% aux_attentionHead - reference output of a single attention head:
%    softmax((X*W_Q) * (X*W_K)' / sqrt(d_k)) * (X*W_V),
%    where the softmax is taken row-wise and d_k = size(W_K, 2)

Q = X * W_Q;
K = X * W_K;
V = X * W_V;

scores = (Q * K') / sqrt(size(W_K, 2));

% subtract the row-wise maximum before exponentiating; this keeps exp from
% overflowing and does not change the softmax result
scores = scores - max(scores, [], 2);
expScores = exp(scores);
weights = expScores ./ sum(expScores, 2);

out = weights * V;

end

% ------------------------------ END OF CODE ------------------------------
