function res = verifyTransformer_snd_dataset(name)
% verifyTransformer_snd_dataset - Verifies a transformer neural network
%    using different set-based approaches and determines, per approach,
%    the maximal verifiable noise radius via binary search.
%
% Syntax:
%    res = verifyTransformer_snd_dataset()
%    res = verifyTransformer_snd_dataset(name)
%
% Inputs:
%    name - (optional) char, name of the results folder created under
%           ./results/; defaults to the current timestamp
%           (format yyyyMMdd-HHmmss)
%
% Outputs:
%    res - boolean, true once all evaluations have finished
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%

% Authors:       Rayen Mhadhbi, Tobias Ladner
% Written:       16-September-2024
% Last update:   20-September-2024 (TL, set up evaluation)
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% parse input
if nargin < 1 || isempty(name)
    % 'HH' is the 24-hour clock; 'hh' would map AM and PM runs to the
    % same folder name
    name = char(datetime,'yyyyMMdd-HHmmss');
end

% limit number of cores
maxNumCompThreads(4);

% init results folder (no warning if it already exists) and log file
resultsPath = sprintf('./results/%s', name);
if ~exist(resultsPath, 'dir')
    mkdir(resultsPath)
end
diary(sprintf('%s/diary.txt', resultsPath));
% close the diary on every exit path, including errors
cleanupDiary = onCleanup(@() diary('off'));

% load networks and dataset -----------------------------------------------

disp('Loading networks and dataset..')

% load transformer neural networks ---
% NOTE: machine-specific absolute path; adjust to the local model location
verbose = false;
modelPath1 = '/Users/rayenmhadhbi/PycharmProjects/BachelorArbeit/yelp_model/transformer_model_yelp_medium.json';
model1 = neuralNetwork.readTransformerNetwork(modelPath1, verbose);
models = {model1};

% data sets (datas{m} holds the tokenized sentences for models{m}) ---
data1 = {[ 101, 1045, 2018, 1037, 2307, 3325, 2012, 1996, 4825,  102, 0]};
datas = {data1};

% set options
options = struct();
options.nn.num_generators = 500;
options.nn.add_approx_error_to_GI = true;
options = nnHelper.validateNNoptions(options);

%  approaches
approaches = {"interval", "polyZonotope"};

% Run verification --------------------------------------------------------

disp('Run verification..')

for m = 1:numel(models)
    disp('---------------------------------------------------------------')
    fprintf('Model: %i\n\n',m)

    % get model and its data set
    model = models{m};
    data = datas{m};

    % compute bounds of the embeddings and radius
    boundsEmbed = aux_computeEmbeddingBounds(model);
    boundsEmbedRadius = rad(boundsEmbed);

    % init model results
    model_results = struct;

    % iterate over approaches and verify the current model for each
    for a = 1:length(approaches)
        % get approach
        approach = approaches{a};
        options.nn.approach_transformer = approach;
        fprintf('Approach: %s\n', approach)

        % determine maximal verifiable noise radius
        disp('- Maximal verifiable noise radius:')
        resultsNoise = aux_verifyTransformer_NoiseRange(model, data, boundsEmbedRadius, options);
        model_results.noise.(approach) = resultsNoise;
        fprintf('  mean: %.4e, std: %.4e\n\n',mean([resultsNoise.maxVerifiableNoise]),std([resultsNoise.maxVerifiableNoise]))

        % save results (after each approach, so partial results survive)
        save(sprintf('%s/model%i-results.mat',resultsPath,m), 'model_results')
    end

end

% diary is closed by cleanupDiary when the function returns
res = true;
end


% Auxiliary functions -----------------------------------------------------

function boundsEmbed = aux_computeEmbeddingBounds(model)
% aux_computeEmbeddingBounds - compute the interval spanned by the smallest
%    and largest entry of the token embedding matrix of the first layer
%
% Inputs:
%    model - transformer network whose layers{1,1} has a token_emb field
%
% Outputs:
%    boundsEmbed - interval [min(token_emb(:)), max(token_emb(:))]

% all embedding entries as one column
embedValues = model.layers{1, 1}.token_emb(:);

% named lb/ub so the built-in inf is not shadowed
lb = min(embedValues, [], 1);
ub = max(embedValues, [], 1);
boundsEmbed = interval(lb, ub);
end

function results = aux_verifyTransformer_NoiseRange(nn_model, sentences, boundsEmbedRadius, options)
% aux_verifyTransformer_NoiseRange - for each tokenized sentence, determine
%    by bisection on [0,1] the largest noise factor (relative to
%    boundsEmbedRadius) for which verification succeeds
%
% Inputs:
%    nn_model - transformer network (passed on to aux_verifyInstance)
%    sentences - cell array of tokenized sentences
%    boundsEmbedRadius - radius of the embedding bounds; the tested
%                        perturbation radius is noise * boundsEmbedRadius
%    options - verification options (passed on to aux_verifyInstance)
%
% Outputs:
%    results - struct array with one entry per sentence:
%       .maxVerifiableNoise - largest tested noise factor that verified
%                             (0 if no tested value verified)
%       .noiseArray, .isVerifiedArray, .timesArray, .radiiArray
%                           - per bisection step
%       .boundsEmbedRadius  - the given embedding radius

% number of bisection steps per sentence
stepsBinarySearch = 10;
numSentences = numel(sentences);

% init results
results = struct;

% Define table headers and preallocate one row per sentence and step
TableHeaders = {'Sentence', 'Step', 'Noise Radius [%]', 'Verified', 'Computation Time', 'Output Radius'};
ResultsTable = cell(numSentences * stepsBinarySearch, numel(TableHeaders));
row = 0;

% Iterate over all tokenized sentences
for s = 1:numSentences
    input = sentences{s};

    % Initialize output arrays
    noiseArray = zeros(1,stepsBinarySearch);
    isVerifiedArray = false(1,stepsBinarySearch);
    timesArray = zeros(1,stepsBinarySearch);
    radiiArray = zeros(1,stepsBinarySearch);

    % Binary search: noise_range(1) is always a verified (or zero) noise,
    % noise_range(2) an unverified (or the initial upper) one
    noise_range = [0, 1];
    for n = 1:stepsBinarySearch
        noise = mean(noise_range);

        % Perform verification for the current sentence and noise
        [isVerified, time, radius] = aux_verifyInstance(nn_model, input, noise * boundsEmbedRadius, options);

        % Shrink the search range
        if isVerified
            noise_range(1) = noise;
        else
            noise_range(2) = noise;
        end

        % Save results
        noiseArray(n) = noise;
        isVerifiedArray(n) = isVerified;
        radiiArray(n) = radius;
        timesArray(n) = time;

        % Add results to table
        row = row + 1;
        ResultsTable(row, :) = {s, n, noise * 100, isVerified, time, radius};
    end

    % Store in results; the largest verified noise is the lower end of the
    % search range (noiseArray(end) is merely the last tested value, which
    % may have failed verification)
    results(s).maxVerifiableNoise = noise_range(1);
    results(s).noiseArray = noiseArray;
    results(s).isVerifiedArray = isVerifiedArray;
    results(s).timesArray = timesArray;
    results(s).radiiArray = radiiArray;
    results(s).boundsEmbedRadius = boundsEmbedRadius;
end

% Display the results table (a 0-row table if there are no sentences)
ResultsTable = cell2table(ResultsTable, 'VariableNames', TableHeaders);
disp(ResultsTable);

end

function [isVerified, time, radius] = aux_verifyInstance(nn_model, input, noise, options)
% aux_verifyInstance - check whether the predicted label of one tokenized
%    sentence is preserved when every embedding entry is perturbed by at
%    most the given noise radius
%
% Inputs:
%    nn_model - transformer network; layers{1,1} is evaluated as embedding
%    input - tokenized sentence
%    noise - perturbation radius applied to each embedding entry
%    options - options; options.nn.approach_transformer selects the set
%              representation: 'polyZonotope', 'zonotope',
%              'zonotope_precise', or 'interval'
%
% Outputs:
%    isVerified - true if the output specification holds for the whole set
%    time - seconds spent on set propagation and specification check
%    radius - radius of the output interval after the verification trick

% get correct label (prediction on the unperturbed input)
pred = nn_model.evaluate(input);
[~, label] = max(pred);

% init input set: box of radius noise around the flattened embedding
embedding = nn_model.layers{1, 1}.evaluate(input);
c = reshape(embedding, [], 1);
G = noise * eye(numel(c));
X = polyZonotope(c, G);

% convert input set based on the approach
approach = options.nn.approach_transformer;
if strcmp(approach, 'polyZonotope')
    % no conversion needed
elseif ismember(approach, {'zonotope', 'zonotope_precise'})
    X = zonotope(X);
elseif strcmp(approach, 'interval')
    X = interval(X);
    % restore one row per token; the embedding dimension is derived from
    % the number of entries rather than hard-coded
    X = reshape(X, length(input), numel(c) / length(input));
else
    throw(CORAerror('CORA:wrongValue', sprintf('Unknown approach: %s', approach), {'polyZonotope', 'zonotope', 'interval', 'zonotope_precise'}));
end

% start timer (own handle, so tic calls inside evaluate cannot reset it)
timerVal = tic;

% evaluate input set through all layers after the embedding
Y = nn_model.evaluate(X, options, 2:length(nn_model.layers));

% apply verification trick (2 output classes): M*Y = y_other - y_label,
% so M*Y <= 0 certifies that the other class never outscores the label
M = eye(2);
M(:, label) = M(:, label) - 1;
M = sum(M, 1); % as output is 2d, transform into 1D <= 0 check
Y = M * Y;

% check specification
Y = interval(Y);
isVerified = Y.sup <= 0;

% stop timer
time = toc(timerVal);
radius = rad(Y);

end

% ------------------------------ END OF CODE ------------------------------


