function res = example_neuralNetwork_transformer(modelPath)
% example_neuralNetwork_transformer - example for the verification of a
%    transformer neural network (architecture as in [1]) used as a binary
%    classifier of a tokenized sentence: the prediction is verified against
%    bounded noise added to every entry of the sentence embedding
%
% Syntax:
%    res = example_neuralNetwork_transformer()
%    res = example_neuralNetwork_transformer(modelPath)
%
% Inputs:
%    modelPath - (optional) path to the .json file of the transformer
%                network, read via neuralNetwork.readTransformerNetwork;
%                defaults to the path used during development
%
% Outputs:
%    res - boolean, true if the lower bound of the predicted label's score
%          exceeds the upper bound of the other label's score for all
%          embeddings within the noise box
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork, polyZonotope

% Authors:       Rayen Mhadhbi
% Written:       05-July-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% model file (the default is machine-specific; pass modelPath to override)
if nargin < 1 || isempty(modelPath)
    modelPath = '/Users/rayenmhadhbi/PycharmProjects/BachelorArbeit/transformer_model_in_medical_queries.json';
end
if ~isfile(modelPath)
    error('example_neuralNetwork_transformer:fileNotFound', ...
        'Transformer model file not found: %s', modelPath);
end

% load nn
verbose = false;
nn_cora = neuralNetwork.readTransformerNetwork(modelPath, verbose);
if verbose
    display(nn_cora)
    % value range of the token embedding table of the first layer
    mn = min(nn_cora.layers{1, 1}.token_emb(:));
    mx = max(nn_cora.layers{1, 1}.token_emb(:));
    display(mn);
    display(mx);
end

% token ids of the input sentence (named 'tokens' to avoid shadowing the
% built-in 'input' function)
tokens = [101, 1045, 2031, 9016, 1998, 14978, 1010, 2026, 6650, 3480, 1998, 1045, ...
    2342, 1037, 3460, 2000, 2393, 2033, 1012, 102];
if verbose
    disp(size(tokens, 2));
end

% make point prediction with the full network
pred_cora = nn_cora.evaluate(tokens);

disp("Sanity Check:")
disp("    (Label | Prediction)")
% (:) makes the concatenation independent of the output's orientation
disp([(0:1)', pred_cora(:)])
[~, label] = max(pred_cora);
fprintf("Most likely label is: %d\n", label-1);

% check set-based prediction -----------------------------------------

disp("Making a set-based prediction...")

% absolute noise radius added to every entry of the embedding
noise = 0.003;
embedding = nn_cora.layers{1, 1}.evaluate(tokens);

% construct input set: box of radius 'noise' around the flattened
% embedding, one generator per embedding entry
c = reshape(embedding, [], 1);
dim = numel(c);
G = noise*eye(dim);
pZ = polyZonotope(c, G);

options = struct();
options.nn.num_generators = dim;
options = nnHelper.validateNNoptions(options);

% propagate the set through all layers after the embedding layer, so the
% set-based result corresponds to the point prediction above
pred = pZ;
numLayers = numel(nn_cora.layers);
for i = 2:numLayers
    pred = nn_cora.layers{i, 1}.evaluate(pred, options);
end

% apply verification trick for binary classification:
% row 1 becomes (score_1 - score_2), row 2 becomes 0, so comparing the
% projected rows decides which score dominates over the whole set
M = [1 -1; 0 0];
pred = M * pred;

pred = interval(pred);

% check results
label_pred = project(pred, label);
other_pred = project(pred, setdiff(1:2, label));

% sentence + noise is verified if lower bound of the true label is larger
% than the upper bound of all other labels.
isVerified = all(other_pred.sup < label_pred.inf);

if isVerified
    fprintf("VERIFIED with noise=%g.\n", noise)
else
    fprintf("Unable to verify sentence embedding with noise=%g.\n", noise)
end

res = isVerified;

% ------------------------------ END OF CODE ------------------------------
