function res = compare_pZ_enum_transformer()
% compare_pZ_enum_transformer - compares transformer verification times for 
% exhaustive enumeration against pZ approach and generates a plot
%
% Description:
%    1) loads the transformer model and builds a polynomial zonotope around
%       the embedding of a fixed example sentence,
%    2) prints a LaTeX table of the synonyms available per token,
%    3) times repeated plain model evaluations (enumeration baseline) for
%       2^1 ... 2^17 evaluations,
%    4) times the set-based (pZ) verification of the input set,
%    5) optionally saves the timings and plots both curves.
%
% Syntax:
%    res = compare_pZ_enum_transformer()
%
% Inputs:
%    -
%
% Outputs:
%    res - boolean (true once the script has run through)
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Rayen Mhadhbi
% Written:       22-September-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% load nn
model = neuralNetwork.readTransformerNetwork([CORAROOT, '/models/Cora/nn/transformer_model_in_medical_queries_small_no_var.json'], false);

% smallest/largest entry of the token embedding matrix (local names chosen
% so that the built-in 'inf' is not shadowed)
embLower = min(model.layers{1, 1}.token_emb(:), [], 1);
embUpper = max(model.layers{1, 1}.token_emb(:), [], 1);
boundsEmbed = interval(embLower, embUpper);
boundsEmbedRadius = rad(boundsEmbed);

% largest verified noise radius for this sentence example according to verifyTransformer 
noise = 0.014 * boundsEmbedRadius;

% token ids of the sentence used for comparison (not named 'input' to avoid
% shadowing the built-in function)
tokenIds = [101, 10634, 3108, 3255, 2043, 14457, 1999, 3056, 7826, 1998, 2043, 5505, 1998, 2107, 102, 0]; 

% set options 
options = struct();
options.nn.num_generators = 1000;
options.nn.add_approx_error_to_GI = true;

% get correct label
pred = model.evaluate(tokenIds);
[~, label] = max(pred);

% init input set: evaluate only the first layer to obtain the embedding
embedding = model.evaluate(tokenIds,options,1);
% noise ball (one independent generator per embedding entry) around it
c = reshape(embedding, [], 1);
G = noise * eye(numel(c));
X = polyZonotope(c, G);

% set number of repetitions
num_trials = 5;

% load synonym token arrays 
tokens = aux_read_synonym_sentence_tokens([CORAROOT, '/examples/nn/transformer_verification/synonym_token_ids.json']);
tokenmap = jsondecode(fileread([CORAROOT, '/examples/nn/transformer_verification/token_mapping.json']));

% print synonym sentence (not named 'table' to avoid shadowing the built-in)
synTable = CORAtable('latex',{'Token','\#Synonyms','Synonyms'},{'s','i','s'});
synTable.printHeading();

num_synonym_sets = size(tokens, 1);
% count the distinct synonym tokens for every position of the sentence;
% sized by the sentence length so that no unfilled zero entry can turn the
% product below into zero
numSynonyms = zeros(1,numel(tokenIds));
for i=1:numel(tokenIds)
    tokens_t = unique(tokens(:,i));
    % look up the word of each distinct token id; jsondecode prefixes
    % numeric keys with 'x' to obtain valid field names
    synonyms = cell(1,numel(tokens_t));
    for j = 1:numel(tokens_t)
        synonyms{j} = tokenmap.(sprintf('x%i',tokens_t(j)));
    end
    synTable.printContentRow({tokenmap.(sprintf('x%i',tokenIds(i))), numel(synonyms), strjoin(synonyms,', ')})
    numSynonyms(i) = numel(tokens_t);
end
synTable.printBottom;

numSynonymSentences = prod(numSynonyms);
disp('Number of synonyms per word:')
disp(numSynonyms);
fprintf('Number of synonym sentences: %i\n', numSynonymSentences);

% exponents: the enumeration baseline performs 2^numWords(s) evaluations
numWords = 1:17;
all_times = cell(1, numel(numWords));

for s = 1:numel(numWords)

    % repeat the enumeration timing num_trials times
    times_trials = zeros(1,num_trials);
    for i=1:num_trials
        % compute time for enumeration with this many evaluations
        times_trials(i) = aux_evaluate_enumeration(model, tokenIds, 2^numWords(s));
    end
    all_times{s} = times_trials;
    % fprintf prints by itself; wrapping it in disp would additionally
    % print the number of bytes written
    fprintf('Number of synonyms: %i -> Time for enumeration: %.2f\n', numWords(s), mean(times_trials));

end

% compute mean and standard deviation of times
enum_times_mean = cellfun(@mean,all_times);
enum_times_std = cellfun(@std,all_times);


pZ_times = zeros(1, num_trials);
isVerified = false;

% rerun the pZ evaluation for n = num_trials trials
for i = 1:num_trials
    % start timer
    tic;
    
    % evaluate the input set (without softmax)
    idxLayer = 2:length(model.layers) - 1;
    Y = model.evaluate(X, options, idxLayer);
    
    % apply verification
    M = eye(2);
    M(:, label) = M(:, label) - 1;
    M = sum(M, 1); % As output is 2D, transform into 1D <= 0 check
    Y = M * Y;
    
    % Check specification
    Y = interval(Y);
    isVerified = all(Y.sup <= 0);
    
    % Record verification time for pZ approach
    pZ_times(i) = toc;
end

% compute mean and standard deviation for pZ approach
pZ_mean_time = mean(pZ_times);
pZ_std_time = std(pZ_times);
fprintf('Time to compute input set using polynomial zonotopes: %.3f\n',pZ_mean_time);
fprintf('Specification verified using polynomial zonotopes: %i\n',isVerified);

generatePlot = true;
saveData = true; 

saveFilename = 'verification_comparison_results.mat';

% save results
if saveData
    save(saveFilename, 'enum_times_mean', 'enum_times_std', 'pZ_mean_time', 'pZ_std_time', 'num_synonym_sets');
end

if generatePlot
    % plot results
    figure; subplot(1,2,2); hold on; box on;

    % plot for enumeration approach
    plot(numWords, enum_times_mean, 'Color', CORAcolor('CORA:color1'), 'DisplayName', 'Enumeration');
    
    % shaded area for enumeration standard deviation
    fill([numWords, fliplr(numWords)], ...
        [enum_times_mean-enum_times_std, fliplr(enum_times_mean + enum_times_std)], ...
        CORAcolor('CORA:color1'), 'EdgeColor', 'none', 'FaceAlpha', 0.2, 'HandleVisibility','off');
    
    % plot for pZ approach 
    plot(numWords, repmat(pZ_mean_time, 1, numel(numWords)), ...
        'Color', CORAcolor('CORA:color2'), 'DisplayName', 'Our approach');
    
    % shaded area for pZ standard deviation
    fill([numWords, fliplr(numWords)], ...
        [repmat(pZ_mean_time - pZ_std_time, 1, numel(numWords)), ...
        fliplr(repmat(pZ_mean_time + pZ_std_time, 1, numel(numWords)))], ...
        CORAcolor('CORA:color2'), 'EdgeColor', 'none', 'FaceAlpha', 0.2, 'HandleVisibility','off');
    
    xlabel('Number of synonym words');
    ylabel('Verification time [s]');
    title('Comparison of pZ and Enumeration Approaches');
    xlim([min(numWords),max(numWords)]);
    % clip the y-axis at 75% of the slowest enumeration time (rounded to
    % tens); ylim requires increasing limits, so keep the automatic limits
    % if the rounded value is not positive (fast machines)
    yUpper = round(max(enum_times_mean)*0.75,-1);
    if yUpper > 0
        ylim([0,yUpper]);
    end
    legend('Location','northwest');
end

res = true;

end


% Auxiliary functions -----------------------------------------------------

function tokens = aux_read_synonym_sentence_tokens(json_file)
% reads the JSON file at the given path and returns the content of its
% top-level 'token_ids' field

    % read the whole file as text, then parse it
    rawText = fileread(json_file);
    decoded = jsondecode(rawText);

    % hand back only the token arrays
    tokens = decoded.token_ids;
end


function time = aux_evaluate_enumeration(model, input, numSynonyms)
% measures the accumulated wall-clock time of numSynonyms evaluations of
% the model; note that the same token array 'input' is evaluated in every
% iteration, and only the evaluation itself is timed (not the loop overhead)

    % accumulated evaluation time
    time = 0;

    for k = 1:numSynonyms
        % time a single forward pass
        tic
        model.evaluate(input);
        elapsed = toc;

        % add it to the running total
        time = time + elapsed;
    end

end
  
% ------------------------------ END OF CODE ------------------------------
