function res = example_neuralNetwork_transformer_medical_queries_interval_02(modelPath)
% example_neuralNetwork_transformer_medical_queries_interval_02 - example
%    for the verification of a transformer neural network (read from a
%    JSON file) on one tokenized medical query using interval bound
%    propagation (IBP): the token embedding is perturbed by a box of
%    half-width 'noise', the box is propagated through the remaining
%    layers, and the largest noise for which the predicted label is still
%    verified is searched and plotted
%
% Syntax:
%    res = example_neuralNetwork_transformer_medical_queries_interval_02()
%    res = example_neuralNetwork_transformer_medical_queries_interval_02(modelPath)
%
% Inputs:
%    modelPath - (optional) path to the transformer model in JSON format
%                as read by neuralNetwork.readTransformerNetwork; defaults
%                to the path used by the original author
%
% Outputs:
%    res - boolean (true once the example has run)
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: aux_propagateEmbeddingBall, aux_verifyLabel
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       05-July-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

if nargin < 1
    % machine-specific default kept for backward compatibility; pass your
    % own path when running on another machine
    modelPath = '/Users/rayenmhadhbi/PycharmProjects/BachelorArbeit/transformer_model_in_medical_queries_big_no_var.json';
end

% load nn
verbose = false;
nn_cora = neuralNetwork.readTransformerNetwork(modelPath, verbose);
display(nn_cora)
mn = min(nn_cora.layers{1, 1}.token_emb(:));
mx = max(nn_cora.layers{1, 1}.token_emb(:));
display(mn);
display(mx);

% tokenized sentence: "How to deal with tired eyes and awake brain"
% (presumably BERT word-piece ids with [CLS]=101, [SEP]=102, pad=0 - TODO confirm)
tokens = [101, 2129, 2000, 3066, 2007, 5458, 2159, 1998, 8300, 4167, 102, 0];

% point prediction through all layers; layer 1 holds the token embedding
% (token_emb), the remaining layers operate on its output
numLayers = 6;
preds = cell(numLayers, 1);
x = tokens;
for k = 1:numLayers
    x = nn_cora.layers{k, 1}.evaluate(x);
    preds{k} = x;
end
embedding = preds{1};   % loop-invariant input of the set propagation

[~, label] = max(preds{numLayers});
fprintf("Most likely label is: %d\n", label-1);

% evaluation options for set propagation (defined once, used by both
% the grid evaluation and the refinement below)
options = struct();
options.nn.num_generators = 1000;
options.nn.add_approx_error_to_GI = true;
options = nnHelper.validateNNoptions(options);

% define embedding noise range
noise_range = logspace(-8, log10(0.5), 100);
radii = zeros(size(noise_range));
isVerifiedArray = false(size(noise_range));

% evaluate verification on the noise grid; the point predictions are
% passed along to check that they lie inside the propagated intervals
for idx = 1:numel(noise_range)
    Y = aux_propagateEmbeddingBall(nn_cora, embedding, noise_range(idx), ...
        options, numLayers, preds);
    [isVerifiedArray(idx), radii(idx)] = aux_verifyLabel(Y, label);
end

% binary search for the largest verified noise on the grid (assumes that
% verification is monotone in the noise); stays 0 if no noise is verified,
% so an unverified grid point is never reported as verified
left = 1;
right = numel(noise_range);
best_noise = 0;

while left <= right
    mid = floor((left + right) / 2);
    if isVerifiedArray(mid)
        best_noise = noise_range(mid);
        left = mid + 1;
    else
        right = mid - 1;
    end
end

fprintf('Largest noise for which verification is true (initial): %e\n', best_noise);

% refine the search around the best noise: increase the noise by the
% current step while verification succeeds, otherwise shrink the step
precision = 1e-8;  % desired precision for the largest noise value
step_size = best_noise / 10;  % initial step size (0 if nothing verified)

while step_size > precision
    next_noise = best_noise + step_size;
    Y = aux_propagateEmbeddingBall(nn_cora, embedding, next_noise, ...
        options, numLayers);

    if aux_verifyLabel(Y, label)
        best_noise = next_noise;  % verification still holds
    else
        step_size = step_size / 10;  % reduce step size to refine further
    end
end

fprintf('Largest noise for which verification is true (refined): %e\n', best_noise);

numVerified = sum(isVerifiedArray);
fprintf('Number of verified entries: %d\n', numVerified);

figure;
ax = gca;
hold(ax, 'on');
grid(ax, 'on');
% semilogx does not switch to a log axis once hold is on, so set it here
set(ax, 'XScale', 'log');
xlabel('Embedding Noise');
ylabel('Radius of Intervals');
title('Certification Radii for Interval Propagation (IBP)');

% plot verified radii in one color (blue)
semilogx(noise_range(isVerifiedArray), radii(isVerifiedArray), 'bo', 'MarkerFaceColor', 'b');

% plot non-verified radii in another color (red)
semilogx(noise_range(~isVerifiedArray), radii(~isVerifiedArray), 'ro', 'MarkerFaceColor', 'r');

legend('Verified Radii', 'Non-verified Radii');
hold(ax, 'off');

res = true;

end


% Auxiliary functions -----------------------------------------------------

function Y = aux_propagateEmbeddingBall(nn, embedding, noise, options, numLayers, refPreds)
% aux_propagateEmbeddingBall - encloses the box embedding +/- noise
%    (interval enclosure of a polyZonotope with generator matrix
%    noise*eye) and propagates it through layers 2..numLayers of nn
%
% Inputs:
%    nn - neuralNetwork object
%    embedding - point output of layer 1 (its shape is kept for layer 2)
%    noise - half-width of the box around the embedding
%    options - evaluation options (see nnHelper.validateNNoptions)
%    numLayers - index of the last layer to evaluate
%    refPreds - (optional) cell array; refPreds{k} is the point output of
%               layer k and is checked for containment after each layer
%
% Outputs:
%    Y - interval enclosure of the output of layer numLayers

c = reshape(embedding, [], 1);
I = interval(polyZonotope(c, noise * eye(numel(c))));
% restore the matrix shape of the embedding, which layer 2 expects
Y = interval(reshape(I.inf, size(embedding)), reshape(I.sup, size(embedding)));

checkContainment = nargin >= 6;
for k = 2:numLayers
    Y = nn.layers{k, 1}.evaluate(Y, options);
    % sanity check: a sound enclosure must contain the point prediction
    if checkContainment && ~contains_(Y, refPreds{k}, "exact", 1e-8)
        warning('Point prediction of layer %d is not contained in its interval enclosure (noise = %e).', k, noise);
    end
end

end

function [isVerified, r] = aux_verifyLabel(Y, label)
% aux_verifyLabel - checks whether the output enclosure Y of a network
%    with two outputs guarantees that output 'label' is the largest one
%
% Inputs:
%    Y - interval enclosure of the network output (assumed 2x1)
%    label - index (1 or 2) of the predicted output
%
% Outputs:
%    isVerified - true if label is guaranteed to be the argmax
%    r - radius of the interval enclosure of [y1 - y2; 0]

M = [1 -1; 0 0];
% interval arithmetic bounds y1 - y2 correctly; applying M to Y.inf and
% Y.sup separately is not a valid enclosure because of the -1 entry
D = M * Y;
r = radius(D);
lo = D.inf;
hi = D.sup;

% label 1: verified iff min(y1 - y2) > 0; label 2: iff max(y1 - y2) < 0
% (row 2 of D is exactly [0, 0])
isVerified = hi(3 - label) < lo(label);

end
% ------------------------------ END OF CODE ------------------------------

