function res = verifyTransformer(name)
% verifyTransformer - Verifies a transformer neural network using different
%    set-based approaches by determining the maximal verifiable noise radius
%    for each tokenized sentence of the configured datasets.
%
% Syntax:
%    res = verifyTransformer()
%    res = verifyTransformer(name)
%
% Inputs:
%    name - (optional) name of the results folder './results/<name>';
%           defaults to the current timestamp 'yyyyMMdd-HHmmss'
%
% Outputs:
%    res - boolean, true once all configurations have been processed
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%

% ------------------------------ BEGIN CODE -------------------------------

% parse input
if nargin < 1 || isempty(name)
    name = char(datetime,'yyyyMMdd-HHmmss');
end

% limit number of cores; restore the previous limit when this function exits
prevNumThreads = maxNumCompThreads(8);
cleanupThreads = onCleanup(@() maxNumCompThreads(prevNumThreads)); %#ok<NASGU>

% init
resultsPath = sprintf('./results/%s', name);
if ~exist(resultsPath, 'dir')
    mkdir(resultsPath)
end
diary(sprintf('%s/diary.txt', resultsPath));
% make sure the diary is closed even if an error is thrown below
cleanupDiary = onCleanup(@() diary('off')); %#ok<NASGU>

% load networks and dataset -----------------------------------------------

disp('Loading networks and dataset..')
configs = aux_loadModelDataConfigs();

% set options
options = struct();
options.nn.num_generators = 4000;
options.nn.add_approx_error_to_GI = true;

%  approaches '<set>_<error-order>'
approaches = {'polyZonotope_1','interval_0','zonotope_1','polyZonotope_2','polyZonotope_5','polyZonotope_7','polyZonotope_10','polyZonotope_15','polyZonotope_20','polyZonotope_50','polyZonotope_100','polyZonotope_200','polyZonotope_500','polyZonotope_1000'};

% Run verification --------------------------------------------------------

disp('Run verification..')

for c = 1:size(configs,1)
    disp('===============================================================')
    dataname = configs{c,1};
    data = configs{c,2};
    modelname = configs{c,3};
    fprintf('Data: %s, model: %s\n\n',dataname, modelname)
    resultsfile = sprintf('%s/results-data-%s-model-%s.mat',resultsPath,dataname,replace(modelname,'.json',''));

    % load model
    model = neuralNetwork.readTransformerNetwork([CORAROOT, '/models/Cora/nn/' modelname], false);

    % compute bounds of the embeddings and radius
    boundsEmbed = aux_computeEmbeddingBounds(model);
    boundsEmbedRadius = rad(boundsEmbed);

    % init model results
    model_results = struct;

    % iterate over approaches and verify with each of them
    for a = 1:length(approaches)
        % split '<set>_<error-order>' at the last underscore
        approach = approaches{a};
        idx_ = find(approach == '_', 1, 'last');
        options.nn.transformer_approach = approach(1:(idx_-1));
        options.nn.transformer_error_order = str2double(approach((idx_+1):end));
        fprintf('Approach: %s\n', approach)

        % determine maximal verifiable noise radius
        disp('- Maximal verifiable noise radius:')
        resultsNoise = aux_verifyTransformer_NoiseRange(model, data, boundsEmbedRadius, options);
        model_results.noise.(approach) = resultsNoise;
        fprintf('%s - mean: %.5e, std: %.5e\n\n',approach,mean([resultsNoise.maxVerifiableNoise]),std([resultsNoise.maxVerifiableNoise]))

        % save results after each approach (intermediate results survive crashes)
        save(resultsfile, 'model_results')
    end

    % evaluate all approaches
    aux_evaluateResults(resultsfile,model,data,approaches);
end

diary off
res = true;
end


% Auxiliary functions -----------------------------------------------------

function configs = aux_loadModelDataConfigs()
% aux_loadModelDataConfigs - defines the evaluated (dataset, model) pairs
%
% Outputs:
%    configs - cell array with one row per configuration:
%              {dataset name, tokenized sentences (column cell array of
%              row vectors of token ids), model file name}

% load transformer neural networks ---
% sort by dataset, #blocks, #heads
model_medical_1h_1b = 'transformer_model_in_medical_queries_1h_1b.json';
model_medical_2h_2b = 'transformer_model_in_medical_queries_2h_2b.json';
model_yelp_2h_1b = 'transformer_model_yelp_2h_1b.json';
model_yelp_2h_2b = 'transformer_model_yelp_2h_2b.json';
model_yelp_4h_3b = 'transformer_model_yelp_4h_3b.json';
% defined so that the corresponding config row below can simply be uncommented
model_yelp_2h_4b = 'transformer_model_yelp_2h_4b.json';

% data sets (one tokenized sentence per row) ---
dataMedical = { ... % medical queries
    [101, 2339, 2079, 3056, 3096, 3688, 2191, 2033, 5060, 1037, 2919, 5437, 2006, 2026, ...
    3096, 1029, 102, 0]; ...
    [101, 6034, 1998, 21446, 6181, 3255, 2043, 5777, 102, 0]; ...
    [101, 2129, 2000, 3066, 2007, 5458, 2159, 1998, 8300, 4167, 102, 0]; ...
    [101, 5505, 3314, 1012, 2048, 7435, 2360, 1045, 2031, 5410, 5505, 2021, ...
    2123, 1005, 1056, 2113, 2054, 1005, 1055, 2039, 1012, 2151, 2825, 12369, ...
    1029, 102, 0]; ...
    [101, 6387, 22390, 14978, 2015, 2008, 3402, 4487, 18719, 17585, 2043, 1045, ...
    25430, 4509, 2677, 28556, 102, 0]; ...
    [101, 10634, 3108, 3255, 2043, 14457, 1999, 3056, 7826, 1998, 2043, 5505, ...
    1998, 2107, 102, 0]; ...
    [101, 4157, 4593, 12943, 17643, 16952, 2015, 2026, 10089, 1011, 2003, 2045, ...
    2151, 3246, 1997, 5948, 2009, 2153, 1029, 102, 0]; ...
    [101, 7223, 3255, 2044, 5059, 1010, 5457, 2006, 2129, 2000, 7438, 1998, 2129, 2000, ...
    4468, 2582, 4544, 102, 0]; ...
    [101, 2054, 2024, 2691, 2966, 3471, 3141, 2000, 22935, 4295, 1029, 102, 0]; ...
    [101, 2054, 2785, 1997, 3450, 2064, 2393, 24251, 10699, 15074, 2015, 2005, ...
    8040, 4048, 6844, 8458, 7389, 6558, 1029, 102, 0]; ...
    [101, 5458, 1010, 11480, 1010, 2058, 20192, 3064, 1998, 4390, 5777, 1011, 2054, 2003, 3308, 2007, 2033, 102, 0];
    [101, 6898, 2038, 5573, 14447, 4654, 26775, 14194, 15370, 2067, 3255, 1012, 2025, 2469, 2054, 2000, 2079, 1012, 102, 0];
    [101, 8777, 1013, 5729, 2896, 2067, 20398, 1998, 3110, 2200, 16342, 2094, 1998, 5305, 1012, 2151, 4784, 2006, 2054, 2009, 2071, 2022, 1029, 102, 0];
    [101, 4895, 7913, 4383, 6911, 1011, 2755, 5397, 1010, 2054, 2024, 1996, 12763, 1029, 102, 0];
    [101, 2054, 7870, 3426, 17630, 27427, 14088, 13521, 2006, 1996, 2227, 1029, 102, 0];
    [101, 2339, 2572, 1045, 2383, 4629, 1998, 10634, 20398, 1999, 2026, 2896, 2187, 13878, 1029, 102, 0];
    [101, 2129, 2515, 10327, 5788, 2146, 2438, 1999, 1996, 4308, 2000, 2191, 2149, 5305, 1029, 102, 0];
    [101, 10694, 4629, 3255, 1999, 5110, 10120, 6740, 2043, 10917, 1012, 102, 0];
    [101, 6547, 7355, 1013, 2116, 8030, 1013, 2053, 11616, 2342, 6998, 102, 0];
    [101, 2128, 10085, 10841, 18807, 21419, 3255, 2003, 2067, 2007, 1037, 14096, 999, 2054, 2000, 2079, 1029, 102, 0]
};

dataYelp = {... % yelp
    [101, 1996, 2326, 2001, 4248, 1998, 1996, 2833, 2001, 12090, 102, 0]; ...
    [101, 1045, 5632, 2026, 2051, 2012, 2023, 26931, 2210, 3962, 102, 0]; ...
    [101, 1996, 3095, 2001, 5379, 1010, 1998, 1996, 7224, 2001, 18066, 102, 0]; ...
    [101, 1996, 12183, 2018, 1037, 2843, 1997, 7047, 1010, 2021, 1996, 7597, 2020, 2152, 1998, 1996, 2833, 2246, 2919, 102, 0]; ...
    [101, 2026, 5440, 9841, 2001, 1996, 23621, 28005, 2121, 1010, 2009, 2001, 2200, 11937, 21756, 102, 0]; ...
    [101, 1045, 2052, 16755, 2023, 4825, 2005, 1037, 10017, 4596, 2007, 2178, 2711, 102, 0]; ...
    [101, 1996, 8810, 2020, 2062, 2084, 2438, 2057, 8823, 2127, 2057, 2020, 8510, 2057, 2097, 2272, 2067, 2000, 2009, 102, 0]; ...
    [101, 1996, 4825, 9557, 2204, 1010, 1998, 1996, 3095, 2001, 12382, 1998, 2785, 1998, 3281, 2012, 2149, 102, 0]; ...
    [101, 1996, 18064, 2015, 2020, 11757, 11937, 21756, 1010, 2926, 1996, 7967, 9850, 2009, 2001, 6014, 102, 0]; ...
    [101, 1996, 3524, 2051, 2001, 2205, 2146, 1010, 1045, 6283, 2296, 2117, 2045, 102, 0]; ...
    [101, 1996, 2833, 2001, 3147, 1998, 1996, 2326, 2001, 4030, 2057, 4741, 2005, 2048, 2847, 102, 0]; ...
    [101, 1996, 4044, 2001, 20810, 1998, 8796, 1998, 1996, 5437, 2001, 11808, 2503, 102, 0]; ...
    [101, 1996, 8241, 2001, 12726, 1998, 4895, 21572, 7959, 28231, 2389, 2002, 2106, 2025, 2130, 2868, 2012, 2149, 102, 0]; ...
    [101, 1045, 2052, 2196, 2272, 2067, 2000, 2023, 2173, 2009, 2003, 1037, 10103, 102, 0]; ...
    [101, 1996, 8974, 2020, 2058, 18098, 6610, 2094, 1998, 12595, 2214, 1998, 2919, 102, 0]; ...
    [101, 2023, 4825, 2003, 2058, 9250, 1998, 2058, 18098, 6610, 2094, 1010, 2053, 2028, 4122, 2000, 5247, 2023, 2172, 102, 0]; ...
    [101, 1996, 3403, 2051, 2001, 2460, 1998, 2057, 2288, 2366, 2256, 2833, 5901, 1045, 2293, 2009, 102, 0]; ...
    [101, 1996, 25545, 2001, 25963, 1998, 14477, 4779, 26884, 1010, 1045, 2507, 2009, 5717, 3340, 102, 0]; ...
    [101, 1996, 2326, 2001, 6581, 1010, 2027, 2020, 2200, 3835, 1998, 7570, 13102, 6590, 3468, 102, 0]; ...
    [101, 1996, 2173, 2001, 2205, 10789, 1998, 20810, 2027, 2209, 9202, 2189, 102, 0]
};

% set up config (one row per evaluated dataset/model pair)
configs = {
    'medical', dataMedical, model_medical_1h_1b;
    % 'medical', dataMedical, model_medical_2h_2b;
    % 'yelp', dataYelp, model_yelp_2h_1b;
    % 'yelp', dataYelp, model_yelp_2h_2b;
    % 'yelp', dataYelp, model_yelp_4h_3b;
    % 'yelp', dataYelp, model_yelp_2h_4b;
};

end

function boundsEmbed = aux_computeEmbeddingBounds(model)
% aux_computeEmbeddingBounds - computes the interval spanned by all entries
%    of the token embedding matrix of the first layer
%
% Inputs:
%    model - transformer network; model.layers{1,1}.token_emb is read
%
% Outputs:
%    boundsEmbed - interval [min(token_emb(:)), max(token_emb(:))]

% flatten once; bounds are named lb/ub to avoid shadowing the built-in inf
tokenEmb = model.layers{1, 1}.token_emb(:);
lb = min(tokenEmb);
ub = max(tokenEmb);
boundsEmbed = interval(lb, ub);
end

function results = aux_verifyTransformer_NoiseRange(nn_model, sentences, boundsEmbedRadius, options)
% aux_verifyTransformer_NoiseRange - determines, for each sentence, the
%    maximal verifiable relative noise via binary search on [0,1]
%
% Inputs:
%    nn_model - transformer network
%    sentences - cell array of tokenized sentences
%    boundsEmbedRadius - radius of the embedding bounds; the perturbation
%                        radius is noise*boundsEmbedRadius
%    options - options passed on to aux_verifyInstance
%
% Outputs:
%    results - struct array with one entry per sentence. NOTE: the result
%              of sentence s is stored at index numel(sentences)+1-s
%              (reverse order, so the first assignment pre-allocates the
%              whole struct array). Field ME is set if an error occurred.

% number of regular binary-search steps; the search continues beyond this
% until at least one radius is verified, but never more than maxSteps
% (prevents an endless loop if no radius can be verified at all)
stepsBinarySearch = 15;
maxStepsBinarySearch = 100;
numSentences = numel(sentences);

% init results
results = struct;

% show intermediate results in table
tbl = CORAtable("minimalistic", ...
    {'Sentence','Length','Step','Noise Radius','Radius [%]','Verified','Computation Time','(Output Radius,','Upper Bound)'}, ...
    {'i','i','i','.4e','.4e','i','.3f','.4f','.4f'});
tbl.printHeading();

% iterate over all tokenized sentences
for s = 1:numSentences
    input = sentences{s};
    % storage index (reverse order, see header)
    k = numSentences+1-s;

    % init output arrays (grow automatically if more steps are needed)
    noiseArray = zeros(1,stepsBinarySearch);
    pertRadiusArray = zeros(1,stepsBinarySearch);
    isVerifiedArray = false(1,stepsBinarySearch);
    timesArray = zeros(1,stepsBinarySearch);
    radiiArray = zeros(1,stepsBinarySearch);
    ubArray = zeros(1,stepsBinarySearch);

    try
        % find maximal noise radius
        noise_range = [0,1];
        n = 0;
        while n < maxStepsBinarySearch && (n < stepsBinarySearch || ~any(isVerifiedArray))
            n = n+1;

            % compute perturbation radius
            noise = mean(noise_range);
            pertRadius = noise*boundsEmbedRadius;

            % perform verification for the current sentence and noise
            [isVerified,time,ub,radius] = aux_verifyInstance(nn_model, input, pertRadius, options);

            % do binary search
            if isVerified
                % increase radius
                noise_range(1) = noise;
            else
                % decrease radius
                noise_range(2) = noise;
            end

            % save results
            noiseArray(n) = noise;
            pertRadiusArray(n) = pertRadius;
            isVerifiedArray(n) = isVerified;
            radiiArray(n) = radius;
            timesArray(n) = time;
            ubArray(n) = ub;
            % noise is a fraction of the embedding radius; shown in percent
            tbl.printContentRow({s,numel(input),n,pertRadius,100*noise,isVerified,time,radius,ub})
        end
    catch ME
        % save error message
        fprintf('# ERROR: %s\n', ME.message)
        results(k).ME = ME;
    end

    % verbose output
    if s < numSentences
        tbl.printMidBoundaryRow();
    end

    % store in results (empty maxVerifiableNoise if nothing was verified)
    results(k).maxVerifiableNoise = max(noiseArray(isVerifiedArray));
    results(k).boundsEmbedRadius = boundsEmbedRadius;
    results(k).noiseArray = noiseArray;
    results(k).pertRadiusArray = pertRadiusArray;
    results(k).isVerifiedArray = isVerifiedArray;
    results(k).timesArray = timesArray;
    results(k).radiiArray = radiiArray;
    results(k).ubArray = ubArray;
end
tbl.printBottom();

end

function [isVerified, time, ub, radius] = aux_verifyInstance(model, input, noise, options)
% aux_verifyInstance - verifies one tokenized sentence against a box
%    perturbation of its embedding
%
% Inputs:
%    model - transformer network
%    input - tokenized sentence
%    noise - absolute perturbation radius (generator scaling of each
%            embedding entry)
%    options - options; options.nn.transformer_approach selects the set
%              representation: 'polyZonotope' (prefix match), 'zonotope',
%              or 'interval'
%
% Outputs:
%    isVerified - true if the upper bound of the specification is <= 0
%    time - time [s] for set propagation and specification check
%    ub - upper bound of the specification output
%    radius - ub minus the center of the specification output

% get correct label (prediction on the unperturbed input)
pred = model.evaluate(input);
[~, label] = max(pred);

% init input set: embedding (layer 1) with one generator per entry
embedding = model.evaluate(input,options,1);
c = reshape(embedding, [], 1);
G = noise * eye(numel(c));
X = polyZonotope(c, G);

% convert input set based on the approach
% (strcmp instead of ismember: ismember on char arrays compares characters)
approach = options.nn.transformer_approach;
if contains(approach, 'polyZonotope')
    % keep X as polynomial zonotope
elseif strcmp(approach, 'zonotope')
    X = zonotope(X);
elseif strcmp(approach, 'interval')
    X = interval(X);
    X = reshape(X, length(input), []);
else
    throw(CORAerror('CORA:wrongValue', sprintf('Unknown approach: %s', options.nn.transformer_approach), {'polyZonotope', 'zonotope', 'interval', 'zonotope_precise'}));
end

% start timer (own handle, so nested tic calls cannot reset it)
timerStart = tic;

% evaluate input set (without softmax)
idxLayer = 2:length(model.layers)-1;
Y = model.evaluate(X, options, idxLayer);
if isa(Y,'interval')
    Y = Y';
end

% apply verification trick: M*Y = y_other - y_label
% NOTE: the 2x2 construction assumes a binary classifier (two outputs)
M = eye(2);
M(:, label) = M(:, label) - 1;
M = sum(M, 1); % as output is 2d, transform into 1D <= 0 check
Y = M * Y;

% check specification
ub = supportFunc(Y,1);
isVerified = ub <= 0;

% stop timer
time = toc(timerStart);

% gather outputs
radius = ub - center(Y);

end

function aux_evaluateResults(resultsfile,model,data,approaches)
% aux_evaluateResults - prints a LaTeX table comparing time, verified radius
%    and verified volume of all approaches relative to 'zonotope_1'
%
% Inputs:
%    resultsfile - .mat file containing the struct 'model_results'
%    model - transformer network (embedding dimension is read from
%            model.layers{1,1}.token_emb)
%    data - cell array of tokenized sentences used for the results
%    approaches - cell array of approach names '<set>_<error-order>'

    % load results file into a struct (no implicit workspace variables)
    loaded = load(resultsfile,'model_results');
    model_results = loaded.model_results;

    % read out noise evaluation
    baselineApproach = 'zonotope_1';
    noisestruct = model_results.noise;
    baselineStruct = noisestruct.(baselineApproach);
    numsentences = numel(baselineStruct);

    % number of perturbed input dimensions per sentence (#tokens x emb. dim);
    % aux_verifyTransformer_NoiseRange stores sentence s at index
    % numel(data)+1-s, hence the lengths are flipped to match that order
    sentencelengths = cellfun(@numel, data(:))';
    embeddingDim = size(model.layers{1, 1}.token_emb,2);
    numDims = flip(sentencelengths) * embeddingDim;

    % show table
    tbl = CORAtable('latex',{'Approach','Time [s]','Verified Radius','Verified Volume [\%]'},{'s','s','s','s'});
    tbl.printHeading();

    % iterate through all approaches
    for a = 1:numel(approaches)
        approach = approaches{a};
        try
            sentencestruct = noisestruct.(approach);

            % init (NaN marks sentences without a verified radius)
            radii_approach = nan(1,numsentences);
            time_approach = zeros(1,numsentences);
            radii_baseline = nan(1,numsentences);

            % iterate through all sentences
            for s = 1:numsentences
                % maxVerifiableNoise is empty if no radius was verified
                r = sentencestruct(s).maxVerifiableNoise;
                if ~isempty(r)
                    radii_approach(s) = r;
                end
                rb = baselineStruct(s).maxVerifiableNoise;
                if ~isempty(rb)
                    radii_baseline(s) = rb;
                end
                time_approach(s) = mean(sentencestruct(s).timesArray);
            end

            % remove sentences without verified radius (NaN > 0 is false)
            idx = radii_approach > 0;
            radii_approach = radii_approach(idx);
            time_approach = time_approach(idx);
            radii_baseline = radii_baseline(idx);

            % compute volume comparison (exponents filtered like the radii)
            rel_volume = (radii_approach ./ radii_baseline).^numDims(idx) * 100;

            % print table
            tbl.printContentRow({approach, ...
                sprintf('$%.2f\\pm%.2f$',mean(time_approach),std(time_approach)), ...
                sprintf('$%.4e\\pm%.4e$',mean(radii_approach),std(radii_approach)), ...
                sprintf('$%.2f\\pm%.2f$',mean(rel_volume),std(rel_volume))})
        catch ME
            % report the failure instead of silently dropping the row
            fprintf('# ERROR evaluating approach %s: %s\n', approach, ME.message)
        end
    end
    tbl.printBottom();

    % min/max sentence length
    fprintf('Sentence length - min: %i, max: %i, mean: %.2f, std: %.2f.\n', min(sentencelengths), max(sentencelengths), mean(sentencelengths), std(sentencelengths))

end

% ------------------------------ END OF CODE ------------------------------
