classdef nnMultiHeadAttentionLayer < nnLayer
% nnMultiHeadAttentionLayer - class for multi-head attention layers
%
% Syntax:
%    obj = nnMultiHeadAttentionLayer(W_Q,W_K,W_V,W_O,numHeads)
%    obj = nnMultiHeadAttentionLayer(W_Q,W_K,W_V,W_O,numHeads,name)
%
% Inputs:
%    W_Q - query weight matrix for all attention heads, stacked
%          horizontally (columns are split evenly across the heads)
%    W_K - key weight matrix for all attention heads, stacked
%          horizontally
%    W_V - value weight matrix for all attention heads, stacked
%          horizontally
%    W_O - aggregation projection matrix applied to the concatenated
%          head outputs
%    numHeads - number of attention heads (positive integer that evenly
%          divides the number of columns of W_Q)
%    name - (optional) name of the layer, defaults to type
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    W_O         % aggregation projection matrix
    heads       % 1 x numHeads cell array of nnSelfAttentionLayer objects
    numHeads    % number of attention heads
end

methods
    % constructor
    function obj = nnMultiHeadAttentionLayer(W_Q, W_K, W_V, W_O, numHeads, varargin)
        if nargin < 5
            throw(CORAerror("CORA:notEnoughInputArgs", 5))
        end
        [name] = setDefaultValues({[]}, varargin);
        inputArgsCheck({ ...
            {W_Q, 'att', 'numeric'}; ...
            {W_K, 'att', 'numeric'}; ...
            {W_V, 'att', 'numeric'}; ...
            {W_O, 'att', 'numeric'}; ...
            {numHeads, 'att', 'numeric', 'scalar'}; ...
            })

        % split the weight matrices for each head
        d_model = size(W_Q, 2); % d_model: the dimensionality of the model
        d_head = d_model / numHeads; % dimension of each head

        % the column ranges below are only valid if the columns can be
        % split into numHeads equally sized integer blocks; fail early
        % with a descriptive message instead of a subscript error
        if numHeads < 1 || mod(numHeads, 1) ~= 0 || mod(d_head, 1) ~= 0
            throw(CORAerror('CORA:wrongInputInConstructor', ...
                'numHeads must be a positive integer that evenly divides the number of columns of W_Q.'))
        end

        % call superclass constructor
        obj@nnLayer(name);

        % initialize properties
        obj.W_O = W_O;
        obj.numHeads = numHeads;

        % initialize attention layers for each head
        obj.heads = cell(1, numHeads);
        for i = 1:numHeads
            % columns belonging to the i-th head
            cols = (i - 1)*d_head+1:i*d_head;
            obj.heads{i} = nnSelfAttentionLayer(W_Q(:, cols), W_K(:, cols), W_V(:, cols), ['Head ', num2str(i)]);
        end
    end

    function outputSize = getOutputSize(obj, inputSize)
        % NOTE(review): placeholder - always returns 0 regardless of
        % inputSize, so getNumNeurons reports nout = 0 once the input
        % size is set. The real output size is not derived here.
        outputSize = 0;
    end

    function [nin, nout] = getNumNeurons(obj)
        % returns the number of input and output neurons, or empty
        % values if the input size of the layer has not been set yet
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end
end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % evaluate numerically
    function [Z, obj] = evaluateNumeric(obj, X, options)
        % compute attention for each head
        head_outputs = cell(1, obj.numHeads);
        for i = 1:obj.numHeads
            head_outputs{i} = obj.heads{i}.evaluateNumeric(X, options);
        end

        % concatenate the head outputs horizontally and apply the
        % linear projection W_O
        Z = horzcat(head_outputs{:}) * obj.W_O;
    end

    % sensitivity
    function S = evaluateSensitivity(obj, S, x, options)
        % NOTE(review): placeholder - the incoming sensitivity is
        % discarded and an empty matrix is returned.
        S = [];
    end

    % evaluate interval (used for abstract transformers like zonotopes)
    function bounds = evaluateInterval(obj, bounds, options)
        % compute attention for each head and concatenate results
        head_bounds = cell(obj.numHeads,1);

        for i = 1:obj.numHeads
            head_bounds{i} = obj.heads{i}.evaluateInterval(bounds, options);
        end
        % {:} expands to a comma-separated list, so the per-head results
        % are concatenated horizontally (as in evaluateNumeric)
        head_bounds = [head_bounds{:}];

        % apply projection matrix W_O to concatenated results
        bounds = head_bounds * obj.W_O;
    end

    % zonotope/polyZonotope evaluation
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        % Every head is evaluated on the same input set. Instead of
        % concatenating the head outputs and multiplying with W_O, each
        % head output is projected with its own row block of W_O and the
        % projected sets are summed up.

        if strcmp(options.nn.transformer_approach, "zonotope")
            assert(isempty(G),'zonotope approach should only uses independent generators')
            newIds = isempty(id);
            if newIds
                % use identifier vector for independent generators as labels
                % for parallel computations
                id = id_ + (1:size(GI,2));
                id_ = max([id_,id]);
            end
        
            c_res = zeros(size(c));
            GI_res = zeros(numel(c),0);
            id_res = [];
            for i = 1:obj.numHeads
                % evaluate head
                [c_head, ~, GI_head, ~, id_head, ~, ~, ~] = obj.heads{i}.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
                
                % project output of head with the rows of W_O that
                % belong to this head
                d_V = obj.heads{i}.d_V;
                W_O_proj = obj.W_O(((i-1)*d_V+1):(i*d_V), :);
                c_head = nnHelper.affineMap(c_head, W_O_proj, 0);
                GI_head = nnHelper.affineMap(GI_head, W_O_proj, 0);

                % add to previous heads; generators are merged according
                % to their labels
                c_res = c_res+c_head;
                [GI_res,id_res] = nnHelper.mergeLabelledGenerators(GI_res,id_res,GI_head,id_head);
            end

            % assign output
            c = c_res;
            GI = GI_res;
            id = id_res;
            
            % delete ids if they were created here
            if newIds
                id = [];
            end

        else
            % init output
            PZ_out = polyZonotope(zeros(size(c)));
    
            for i = 1:obj.numHeads
                % evaluate head
                [c_head, G_head, GI_head, E_head, id_head, ~, ~, ~] = obj.heads{i}.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
                
                % project output of head with the rows of W_O that
                % belong to this head
                d_V = obj.heads{i}.d_V;
                W_O_proj = obj.W_O(((i-1)*d_V+1):(i*d_V), :);
                c_head = nnHelper.affineMap(c_head, W_O_proj, 0);
                G_head = nnHelper.affineMap(G_head, W_O_proj, 0);
                GI_head = nnHelper.affineMap(GI_head, W_O_proj, 0);
    
                % add to previous heads
                PZ_head = polyZonotope(c_head,G_head,GI_head,E_head,id_head);
                PZ_out = exactPlus(PZ_out,PZ_head);
                
                % used for approx error in next head
                id_ = max([id_;id_head]);
            end
    
            % read output properties
            c = PZ_out.c;
            G = PZ_out.G;
            GI = PZ_out.GI;
            E = PZ_out.E;
            id = PZ_out.id;
            % keep the running maximum from the loop: max(id) alone is
            % empty if the result has no dependent identifiers and could
            % move the counter backwards
            id_ = max([id_; id]);
    
            % update ind, ind_
            % ind: columns of E in which every exponent is even
            ind = find(prod(ones(size(E))-mod(E, 2), 1) == 1);
            ind_ = setdiff(1:size(E, 2), ind);
        end
    end
end

end

% ------------------------------ END OF CODE ------------------------------
