function res = test_nn_transformer_verify_instance()
% test_nn_transformer_verify_instance - verifies one instance of a
%    transformer: checks the embedding and the prediction of a fixed
%    input, then checks that the outputs of polyZonotope, zonotope, and
%    interval input sets (centered at the embedding) contain the outputs
%    of sampled points
%
% Syntax:
%    res = test_nn_transformer_verify_instance
%
% Inputs:
%    -
%
% Outputs:
%    res - boolean, true if all checks pass
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Tobias Ladner
% Written:       22-September-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

resvec = [];

% load model
model = neuralNetwork.readTransformerNetwork([CORAROOT '/models/Cora/nn/transformer_model_in_medical_queries_big.json']);
input = [101, 10694, 4629, 3255, 1999, 5110, 10120, 6740, 2043, 10917, 1012, 102, 0];

% define options
options = struct();
options.nn.num_generators = 1000;
options.nn.add_approx_error_to_GI = true;

% compute embedding (evaluate only the first layer)
embedding = model.evaluate(input,options,1);

% expected embedding: one row per input token
% (each row is kept on a single source line; inside square brackets a
% newline starts a new row, so a row must never be wrapped)
embedding_expected = [ ...
    -0.0005640182644129 -0.0241730203852057 -0.0460567697882652 -0.0042961268918589 -0.0068041794002056  0.0494054399896413  0.0314212534576654  0.0420499704778194 ; ...
    -0.0557963581522927 -0.0541303241625428  0.0278755305334926 -0.0382628673687577  0.0897296927869320 -0.0056197056546807 -0.0309314662590623  0.0139696069527417 ; ...
    -0.0420412803068757 -0.0255224527791142  0.0010005291551352 -0.0066663213074207  0.0302781821228564  0.0019108429551125 -0.0688383989036083  0.0012397654354572 ; ...
    -0.0531686100875959  0.0356216728687286 -0.0509245367720723  0.0256451014429331  0.1039550444111228  0.0740018077194691 -0.0912788007408381  0.0483412109315395 ; ...
    -0.0180325889959931 -0.0343617522157729  0.0169891938567162 -0.0192955452948809  0.0083953022840433 -0.0290503608994186 -0.0501143960282207 -0.0444871392101049 ; ...
    -0.0503410790115595  0.0554020237177610 -0.0293949730694294  0.0389996282756329  0.0378459361381829  0.0528018325567245 -0.0026217079721391 -0.0112485052159172 ; ...
    -0.0057230121456087  0.0342804538086057  0.0007017161697149  0.0240233130753040  0.0690761324949563  0.0236580765340477 -0.0302294231951237  0.0282487608492374 ; ...
    -0.0110494444379583 -0.0331779867410660  0.0157203482303885  0.0013946415856481  0.0521200671792030  0.0218681474216282 -0.0285612624138594  0.0088050374761224 ; ...
    -0.0454647669102997  0.0403748604003340 -0.0588305182754993 -0.0232503125444055  0.0577112454921007  0.0587209477089345 -0.0238700862973928  0.0616108868271112 ; ...
     0.0048491209745407  0.0142202544957399  0.0161546575836837 -0.0218190317973495  0.0475376211106777 -0.0564418155699968 -0.0753316860646009 -0.0204363486263901 ; ...
    -0.0051322062499821 -0.0466150781139731 -0.0625990647822618  0.0343972258269787  0.0445417924784124 -0.0001154011115432  0.0061667473055422  0.0336610209196806 ; ...
    -0.0422777254134417 -0.0855807550251484 -0.0319487771485001 -0.0357087114825845 -0.0645918846130371  0.0199477560818195 -0.0894165337085724 -0.0704615004360676 ; ...
    -0.0000155051238835 -0.0072098672389984 -0.0120314471423626 -0.0124147878959775  0.0041543846018612 -0.0154295649845153 -0.0046574100852013 -0.0059186853468418 ];
resvec(end+1) = all(withinTol(embedding,embedding_expected), 'all');

% get correct label
pred = model.evaluate(input);
[~, label] = max(pred);
resvec(end+1) = all(withinTol(pred,[ 0.3717860500147145 ; 0.6282139499852853 ]));
resvec(end+1) = label == 2;


% set-based evaluation-----------------------------------------------------

% layers to propagate sets through: skip the embedding layer (already
% applied above) and the final layer (softmax)
idxLayer = 2:length(model.layers)-1;

% polyZonotope ---
c = reshape(embedding, [], 1);
G = 1e-3 * eye(numel(c));
X = polyZonotope(c, G);

% evaluate input set (without softmax)
Y = model.evaluate(X, options, idxLayer);

% compute outputs for the center, N random points, and N extreme points
N = 100;
xs = [c,X.randPoint(N),X.randPoint(N,'extreme')];
% preallocate for all 1+2N samples (softmax keeps the output dimension,
% so the output size matches the prediction size)
ys = zeros(numel(pred),size(xs,2));
for i = 1:size(xs,2)
    x = reshape(xs(:, i), size(embedding));
    y = model.evaluate(x,options,idxLayer);
    ys(:, i) = reshape(y,[], 1);
end

% check containment
resvec(end+1) = all(contains(interval(Y), ys));

% zonotope ---
X = zonotope(X);

% evaluate input set (without softmax)
Y = model.evaluate(X, options, idxLayer);

% check containment
resvec(end+1) = all(contains(interval(Y), ys));

% interval ---
X = interval(X);
X = reshape(X, size(embedding));

% evaluate input set (without softmax); transpose to match the column
% layout of ys
Y = model.evaluate(X, options, idxLayer)';

% check containment
resvec(end+1) = all(contains(Y, ys));

% gather results
res = all(resvec);

end

% ------------------------------ END OF CODE ------------------------------
