classdef nnSelfAttentionLayer < nnLayer
% nnSelfAttentionLayer - class for self-attention layers
%
% Syntax:
%    obj = nnSelfAttentionLayer(W_Q, W_K, W_V)
%    obj = nnSelfAttentionLayer(W_Q, W_K, W_V, name)
%
% Inputs:
%    W_Q  - queries weight matrix (d_in x d_K)
%    W_K  - keys weight matrix (d_in x d_K)
%    W_V  - values weight matrix (d_in x d_V)
%    name - (optional) name of the layer
%
% Outputs:
%    obj - generated object
%
% Description:
%    For an input X whose rows are tokens, the layer computes
%       Z = softmax( (X*W_Q) * (X*W_K)' / sqrt(d_K) ) * (X*W_V),
%    where the softmax is applied to each row of the score matrix [1].
%    Interval and (poly)zonotope evaluation follow [2].
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%    [2] Bonaert et al. "Fast and Precise Certification of Transformers", 2021
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    d_K, d_V
    W_Q, W_K, W_V
    softmaxLayer
end

methods
    % constructor
    function obj = nnSelfAttentionLayer(W_Q, W_K, W_V, varargin)
        if nargin < 3
            throw(CORAerror('CORA:notEnoughInputArgs', 3));
        end
        [name] = setDefaultValues({[]}, varargin);
        inputArgsCheck({ ...
            {W_Q, 'att', 'numeric'}; ...
            {W_K, 'att', 'numeric'}; ...
            {W_V, 'att', 'numeric'}; ...
        })

        % all three projections are applied to the same input X, and the
        % product Q*K' requires queries and keys of equal width
        if size(W_K, 1) ~= size(W_Q, 1) || size(W_V, 1) ~= size(W_Q, 1)
            throw(CORAerror('CORA:wrongInputInConstructor', ...
                'W_Q, W_K, and W_V must have the same number of rows.'));
        end
        if size(W_K, 2) ~= size(W_Q, 2)
            throw(CORAerror('CORA:wrongInputInConstructor', ...
                'W_Q and W_K must have the same number of columns.'));
        end

        % call super class constructor
        obj@nnLayer(name)

        obj.W_Q = W_Q;
        obj.W_K = W_K;
        obj.W_V = W_V;
        obj.d_K = size(W_K, 2);
        obj.d_V = size(W_V, 2);
        obj.softmaxLayer = nnSoftmaxLayer();
    end

    % get i-th derivative (not needed; returns a handle yielding [])
    function df_i = getDf(obj, i)
        function r = deriv(x)
            r = [];
        end
        df_i = @(x) deriv(x);
    end

    function der1 = getDerBounds(obj, l, u)
        % df_l and df_u as lower and upper bound for the derivative
        der1 = interval(0, 1);
    end

    function outputSize = getOutputSize(obj, inputSize)
        % The input consists of numTokens tokens with size(W_Q,1) features
        % each; every token is mapped to d_V output features. The result
        % is returned as a flattened column, matching the vectorized
        % representation used in evaluatePolyZonotope.
        numTokens = prod(inputSize) / size(obj.W_Q, 1);
        outputSize = [numTokens * obj.d_V, 1];
    end

    function [nin, nout] = getNumNeurons(obj)
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end

end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % numeric
    function [Z, obj] = evaluateNumeric(obj, X, options)
        % compute query, key and value matrices as linear projections
        Q = X * obj.W_Q;
        K = X * obj.W_K;
        V = X * obj.W_V;

        % compute the scaled dot-product attention
        scores = (Q * K') / sqrt(obj.d_K);

        % apply softmax to scores matrix row by row
        for i = 1:size(scores,1)
            scores(i, :) = obj.softmaxLayer.evaluateNumeric((scores(i, :)'), options);
        end
        Z = scores * V;
    end

    % sensitivity (not supported; returns [])
    function S = evaluateSensitivity(obj, S, x, options)
        S = [];
    end

    % interval
    function bounds = evaluateInterval(obj, H, options)

        % compute K,Q,V
        Q = H * obj.W_Q;
        K = H * obj.W_K;
        V = H * obj.W_V;

        % compute bounds of multiplication Q*K' / sqrt(d_k)
        bounds_scores = Q * K';
        bounds_scores = bounds_scores / sqrt(obj.d_K);

        % compute softmax row by row
        bounds_softmax = 0 * bounds_scores;
        for i = 1:size(bounds_softmax,1)
            bounds_softmax_i = bounds_scores(i,:)';
            res = obj.softmaxLayer.evaluateInterval(bounds_softmax_i,options);
            bounds_softmax(i,:) = res';
        end

        % compute multiplication with V
        bounds = bounds_softmax * V;
    end

    % zonotope/polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        % Propagates a polynomial zonotope (center c, dependent generators
        % G with exponents E and ids id, independent generators GI)
        % through the layer. With options.nn.transformer_approach equal to
        % "zonotope", the zonotope-based evaluation is used instead.

        if strcmp(options.nn.transformer_approach, "zonotope")
            % NOTE(review): only c, GI, and id are updated here; G, E,
            % id_, ind, and ind_ are returned unchanged and G is not
            % propagated -- confirm callers pass an empty G in this mode
            [c, GI, id] = aux_evaluateZonotope(obj, c, GI, id, id_, options);
            return
        end

        % compute Q, K, V ---

        % compute query polyZonotope
        c_Q = nnHelper.affineMap(c, obj.W_Q, 0);
        G_Q = nnHelper.affineMap(G, obj.W_Q, 0);
        GI_Q = nnHelper.affineMap(GI, obj.W_Q, 0);
        PZ_Q = polyZonotope(c_Q, G_Q, GI_Q, E, id);

        % compute transpose key polyZonotope
        c_Kt = nnHelper.transpose(nnHelper.affineMap(c, obj.W_K, 0), obj.d_K);
        G_Kt = nnHelper.transpose(nnHelper.affineMap(G, obj.W_K, 0), obj.d_K);
        GI_Kt = nnHelper.transpose(nnHelper.affineMap(GI, obj.W_K, 0), obj.d_K);
        PZ_Kt = polyZonotope(c_Kt, G_Kt, GI_Kt, E, id);

        % compute value polyZonotope
        c_V = nnHelper.affineMap(c, obj.W_V, 0);
        G_V = nnHelper.affineMap(G, obj.W_V, 0);
        GI_V = nnHelper.affineMap(GI, obj.W_V, 0);
        PZ_V = polyZonotope(c_V, G_V, GI_V, E, id);

        % extract number of tokens
        n = size(c_Q,1)/obj.d_K;

        % compute Q*K' / sqrt(d_k) ---

        % compute attention scores (n x d_K times d_K x n)
        error_order = options.nn.transformer_error_order;
        PZ_scores = nnHelper.quadMapPolyZono(PZ_Q,PZ_Kt,n,obj.d_K,n,error_order);
        % optional: PZ_scores = reduce(PZ_scores,'girard',size(G,2)/numel(c)+3);

        % scale attention scores
        PZ_scores = PZ_scores * (1 / sqrt(obj.d_K));

        % prepare propagation (empty generator sets get n^2 x 0 matrices)
        c_scores = PZ_scores.c;
        if isempty(PZ_scores.G)
            G_scores = zeros(n^2,0);
            E_scores = zeros(0,0);
            id_scores = [];
        else
            G_scores = PZ_scores.G;
            E_scores = PZ_scores.E;
            id_scores = PZ_scores.id;
        end
        if isempty(PZ_scores.GI)
            GI_scores = zeros(n^2,0);
        else
            GI_scores = PZ_scores.GI;
        end
        id__scores = max([id_;id_scores]);
        % columns of E_scores whose exponents are all even
        ind = find(prod(ones(size(E_scores))-mod(E_scores, 2), 1) == 1);
        ind_ = setdiff(1:size(E_scores, 2), ind);

        % compute softmax(...) ---

        % extract needed dimensions
        h = size(G_scores,2);
        q = size(GI_scores,2);

        % reshape to respective matrix
        c_scores = reshape(c_scores,n,n);
        G_scores = reshape(G_scores,n,n,h);
        GI_scores = reshape(GI_scores,n,n,q);

        % preallocate results
        c_sm_approx = zeros(n,n);
        G_sm_approx = zeros(n,n,h);
        GI_sm_approx = zeros(n,n,q);
        GI_sm_error = zeros(n,n,n,n);
        id_ = id__scores;

        % adapt options for softmax (identical for every row)
        options_softmax = options;
        options_softmax.nn.add_approx_error_to_GI = true;

        % apply softmax to each row of the score matrix
        for i = 1:n
            % extract i-th row as column
            c_i = (c_scores(i, :))';
            G_i = reshape(G_scores(i, :, :), n, h);
            GI_i = reshape(GI_scores(i, :, :), n, q);

            % compute softmax
            [c_res_i, G_sm_i, GI_sm_i, ~, ~, ~] = ...
                obj.softmaxLayer.evaluatePolyZonotope(c_i, G_i, GI_i, E_scores, id_scores, id__scores, ind, ind_, options_softmax);

            % transpose back to rows; presumably the softmax layer appends
            % n approximation-error generators to GI (hence q+n columns)
            c_res_i = c_res_i';
            G_sm_i = reshape(G_sm_i,n,h);
            GI_sm_i = reshape(GI_sm_i,n,q+n);

            % save results (as rows)
            c_sm_approx(i,:) = c_res_i;
            G_sm_approx(i,:,:) = G_sm_i;
            GI_sm_approx(i,:,:) = GI_sm_i(:,1:q);
            GI_sm_error(i,:,i,:) = GI_sm_i(:,(q+1):end); % make approx errors independent
        end

        % reshape result back to polynomial zonotope
        c_sm_approx = reshape(c_sm_approx, [], 1);
        G_sm_approx = reshape(G_sm_approx, [], h);
        GI_sm_approx = reshape(GI_sm_approx, [], q);
        GI_sm_error = reshape(GI_sm_error, [], n*n);
        E_sm_approx = E_scores;
        id_sm_approx = id_scores;

        % apply order reduction (currently disabled; d_red stays zero)
        d_red = zeros(size(c_sm_approx));
        % [c_sm_approx, G_sm_approx, GI_sm_approx, E_sm_approx, id_sm_approx, d_red] = nnHelper.reducePolyZono(c_sm_approx, G_sm_approx, GI_sm_approx, E_sm_approx, id_sm_approx, options.nn.num_generators);

        % init softmax approximation
        PZ_sm_approx = polyZonotope(c_sm_approx,G_sm_approx,[],E_sm_approx,id_sm_approx);

        % turn independent generators and approximation errors into
        % dependent generators with fresh ids
        c_sm_error = zeros(size(c_sm_approx));
        PZ_sm_error = polyZonotope(c_sm_error, [GI_sm_approx GI_sm_error+diag(d_red)]);
        % shift ids above all ids in use so they are independent
        PZ_sm_error = PZ_sm_error.replaceId(id_+PZ_sm_error.id);

        % compute softmax
        PZ_sm = exactPlus(PZ_sm_approx, PZ_sm_error);

        % compute softmax(...) * V --- (n x n times n x d_V)

        PZ_out = nnHelper.quadMapPolyZono(PZ_sm,PZ_V,n,n,obj.d_V,error_order);
        % PZ_out = reduce(PZ_out,'girard',order_in+2);

        % assign output ---
        c = PZ_out.c;
        G = PZ_out.G;
        GI = PZ_out.GI;
        E = PZ_out.E;
        id = PZ_out.id;
        id_ = max(id);
        ind = find(prod(ones(size(E))-mod(E, 2), 1) == 1);
        ind_ = setdiff(1:size(E, 2), ind);
    end

end

end


% Auxiliary functions -----------------------------------------------------

function [c, G, id] = aux_evaluateZonotope(obj, c, G, id, id_, options)
% aux_evaluateZonotope - zonotope-based evaluation of a self-attention
%    layer (Bonaert et al., "Fast and Precise Certification of
%    Transformers", 2021)
%
% Inputs:
%    obj     - nnSelfAttentionLayer object
%    c       - center of the input zonotope (flattened input)
%    G       - generator matrix of the input zonotope
%    id      - labels of the columns of G; if empty, temporary labels
%              id_+1, id_+2, ... are used and removed before returning
%    id_     - largest label currently in use
%    options - evaluation options (options.nn.num_generators is used for
%              order reduction; options are also forwarded to softmax)
%
% Outputs:
%    c  - center of the output zonotope
%    G  - generator matrix of the output zonotope
%    id - labels of the columns of G (empty if the input id was empty)

    % assign temporary labels if none are given
    newIds = isempty(id);
    if newIds
        id = id_ + (1:size(G,2));
    end

    % compute Q, K, V ---

    % compute query zonotope
    c_Q = nnHelper.affineMap(c, obj.W_Q, 0);
    G_Q = nnHelper.affineMap(G, obj.W_Q, 0);

    % compute transpose key zonotope
    c_Kt = nnHelper.transpose(nnHelper.affineMap(c, obj.W_K, 0), obj.d_K);
    G_Kt = nnHelper.transpose(nnHelper.affineMap(G, obj.W_K, 0), obj.d_K);

    % compute value zonotope
    c_V = nnHelper.affineMap(c, obj.W_V, 0);
    G_V = nnHelper.affineMap(G, obj.W_V, 0);

    % extract number of tokens
    n = size(c_Q,1)/obj.d_K;

    % compute Q*K' / sqrt(d_k) ---

    % compute attention scores (n x d_K times d_K x n)
    [c_scores,G_scores,id_scores] = nnHelper.quadMapZonotope(c_Q,G_Q,id,c_Kt,G_Kt,id,n,obj.d_K,n);

    % assign ids to new generators (will be the same across all heads)
    % id_scores = [id_scores, id_+1:(size(G_scores,2)-numel(id_scores))];
    % id_ = max([id_,id_scores]);

    % scale attention scores
    c_scores = c_scores * (1 / sqrt(obj.d_K));
    G_scores = G_scores * (1 / sqrt(obj.d_K));

    % compute softmax(...) ---

    % extract needed dimensions
    h = size(G_scores,2);

    % reshape to respective matrix
    c_scores = reshape(c_scores,n,n);
    G_scores = reshape(G_scores,n,n,h);

    % preallocate results
    c_sm = zeros(n,n);
    G_sm_approx = zeros(n,n,h);
    G_sm_error = zeros(n,n,n,n);

    % adapt options for softmax (identical for every row)
    options_softmax = options;
    options_softmax.nn.add_approx_error_to_GI = true;

    % apply softmax to each row of the score matrix
    for i = 1:n
        % extract i-th row as column
        c_i = (c_scores(i, :))';
        G_i = reshape(G_scores(i, :, :), n, h);

        % compute softmax; the row generators are passed as independent
        % generators (no dependent part)
        [c_res_i, ~, G_sm_i, ~, ~, ~] = ...
            obj.softmaxLayer.evaluatePolyZonotope(c_i, zeros(n,0), G_i, [], [], 1, [], [], options_softmax);

        % transpose back to rows; presumably the softmax layer appends n
        % approximation-error generators (hence h+n columns)
        c_res_i = c_res_i';
        G_sm_i = reshape(G_sm_i,n,h+n);

        % save results (as rows)
        c_sm(i,:) = c_res_i;
        G_sm_approx(i,:,:) = G_sm_i(:,1:h);
        G_sm_error(i,:,i,:) = G_sm_i(:,(h+1):end); % make approx errors independent
    end

    % reshape result back to zonotope
    c_sm = reshape(c_sm, [], 1);
    G_sm_approx = reshape(G_sm_approx, [], h);
    G_sm_error = reshape(G_sm_error, [], n*n);
    G_sm = [G_sm_approx G_sm_error];

    % apply order reduction
    [c_sm,G_sm,id_sm] = nnHelper.reduceLabelledZono(c_sm,G_sm,id_scores,options.nn.num_generators);

    % compute softmax(...) * V --- (n x n times n x d_V)

    [c,G,id] = nnHelper.quadMapZonotope(c_sm,G_sm,id_sm,c_V,G_V,id,n,n,obj.d_V);

    % remove temporary labels again
    if newIds
        id = [];
    end
end


% ------------------------------ END OF CODE ------------------------------
