function res = example_neuralNetwork_transformer_medical_queries_zono_03()
% example_neuralNetwork_transformer_medical_queries_zono_03 - example for
%    the zonotope-based verification of a transformer neural network [1]
%    on tokenized medical queries: for each sentence, the largest verified
%    noise radius on the token embedding is computed, and the averages over
%    all sentences are printed
%
% Syntax:
%    res = example_neuralNetwork_transformer_medical_queries_zono_03()
%
% Inputs:
%    -
%
% Outputs:
%    res - boolean (always true; results are printed to the console)
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Authors:       Rayen Mhadhbi
% Written:       01-Sep-2024
% Last update:   ---
% Last revision: ---
% ------------------------------ BEGIN CODE -------------------------------

% load nn
% NOTE(review): absolute, machine-specific path; adjust it to the location
% of the model file on your system
modelFile = '/Users/rayenmhadhbi/PycharmProjects/BachelorArbeit/transformer_model_in_medical_queries_big.json';
if ~exist(modelFile, 'file')
    error('Transformer model file not found: %s', modelFile);
end
verbose = false;
nn_cora = neuralNetwork.readTransformerNetwork(modelFile, verbose);

% tokenized sentences (uncomment entries to include them in the evaluation)
tokenized_sentences = {
    %[101, 2339, 2079, 3056, 3096, 3688, 2191, 2033, 5060, 1037, 2919, 5437, 2006, 2026, ...
    %3096, 1029,  102, 0];
    %[101, 6034, 1998, 21446, 6181, 3255, 2043, 5777, 102, 0];
    [101, 2129, 2000, 3066, 2007, 5458, 2159, 1998, 8300, 4167, 102, 0];
    %[101, 5505, 3314, 1012, 2048, 7435, 2360, 1045, 2031, 5410, 5505, 2021, ...
     %2123, 1005, 1056, 2113, 2054, 1005, 1055, 2039, 1012, 2151, 2825, 12369, ...
     %1029, 102, 0];
    %[101, 6387, 22390, 14978, 2015, 2008, 3402, 4487, 18719, 17585, 2043, 1045, ...
     %25430, 4509, 2677, 28556, 102, 0];
     %[101, 10634, 3108, 3255, 2043, 14457, 1999, 3056, 7826, 1998, 2043, 5505, ...
     %1998, 2107, 102, 0];
     %[101, 4157, 4593, 12943, 17643, 16952, 2015, 2026, 10089, 1011, 2003, 2045, ...
     %2151, 3246, 1997, 5948, 2009, 2153, 1029, 102, 0];
     %[101, 7223, 3255, 2044, 5059, 1010, 5457, 2006, 2129, 2000, 7438, 1998, 2129, 2000, ...
     %4468, 2582, 4544, 102, 0];
     %[101, 2054, 2024, 2691, 2966, 3471, 3141, 2000, 22935, 4295, 1029, 102, 0];
     %[101, 2054, 2785, 1997, 3450, 2064, 2393, 24251, 10699, 15074, 2015, 2005, ...
     %8040, 4048, 6844, 8458, 7389, 6558, 1029, 102, 0];
};

% averages below are undefined (0/0) without any sentence
numSentences = numel(tokenized_sentences);
if numSentences == 0
    error('No tokenized sentences given.');
end

total_best_noise = 0;
total_numVerified = 0;

% iterate over all tokenized sentences
% ('tokens' instead of 'input' to avoid shadowing the builtin input())
for i = 1:numSentences
    tokens = tokenized_sentences{i};
    [best_noise, numVerified] = aux_verifyInstance(nn_cora, tokens);

    % accumulate results
    total_best_noise = total_best_noise + best_noise;
    total_numVerified = total_numVerified + numVerified;
end

% compute averages
avg_best_noise = total_best_noise / numSentences;
avg_numVerified = total_numVerified / numSentences;

fprintf('Average Largest Noise: %e\n', avg_best_noise);
fprintf('Average Number of Verified Entries: %f\n', avg_numVerified);

res = true;
end

function [best_noise, numVerified] = aux_verifyInstance(nn_cora, input)
% aux_verifyInstance - computes, for one token sequence, the largest tested
%    embedding noise radius for which the predicted label is verified and
%    the number of tested radii that are verified
%
% Syntax:
%    [best_noise, numVerified] = aux_verifyInstance(nn_cora, input)
%
% Inputs:
%    nn_cora - transformer network read by
%              neuralNetwork.readTransformerNetwork; layers{1,1} maps the
%              tokens to their embedding, layers{2..5,1} are propagated
%              set-based, layers{6,1} is only used for the point prediction
%    input - token ids of one sentence
%
% Outputs:
%    best_noise - largest tested noise radius for which verification
%                 succeeds (0 if no tested radius is verified)
%    numVerified - number of tested noise radii that are verified

% point prediction: embedding first, then the remaining layers 2-6
embedding = nn_cora.layers{1, 1}.evaluate(input);
pred = embedding;
for k = 2:6
    pred = nn_cora.layers{k, 1}.evaluate(pred);
end
[~, label] = max(pred);

% index of the competing output (two outputs are assumed, cf. reshape below)
other = 3 - label;

% define embedding noise range
noise_range = logspace(log10(0.1), log10(0.25), 10);

% set-based evaluation options (loop invariant)
options = struct();
options.nn.num_generators = 1000;
options.nn.add_approx_error_to_GI = true;
options.nn.approach_transformer = "zonotope";
options = nnHelper.validateNNoptions(options);

% center of every input set: the flattened embedding
c = reshape(embedding, [], 1);
dim = length(c);

% iterate over noise range
isVerifiedArray = false(size(noise_range));
for i = 1:length(noise_range)
    % axis-aligned box of radius noise_range(i) around the embedding
    z = zonotope(c, noise_range(i) * eye(dim));

    % propagate through layers 2-5 (layer 6 is not evaluated set-based)
    % NOTE(review): the label comes from the layer-6 output while the
    % bounds come from the layer-5 output; this presumes layer 6 preserves
    % the argmax (e.g. softmax) -- TODO confirm
    for k = 2:5
        z = nn_cora.layers{k, 1}.evaluate(z, options);
    end

    % interval enclosure of the two outputs
    outInt = interval(z);
    lb = reshape(outInt.inf, 2, 1);
    ub = reshape(outInt.sup, 2, 1);

    % sound check: every point of the output set keeps the predicted label
    % iff the label's lower bound exceeds the other output's upper bound
    isVerifiedArray(i) = lb(label) > ub(other);
end

numVerified = sum(isVerifiedArray);
fprintf('Number of verified entries: %d\n', numVerified);

% largest verified noise; the input boxes are nested, so a verified radius
% also certifies all smaller radii
lastIdx = find(isVerifiedArray, 1, 'last');
if isempty(lastIdx)
    best_noise = 0;
else
    best_noise = noise_range(lastIdx);
end

fprintf('Largest noise for which verification is true (refined): %e\n', best_noise);

end

% ------------------------------ END OF CODE ------------------------------
