classdef nnTNNResidualConnectionLayer < nnLayer
% nnTNNResidualConnectionLayer - class for the residual connection of
%    attention and normalization layers in transformers, i.e., the layer
%    computes layernorm(x + MHA(x))
%
% Syntax:
%    obj = nnTNNResidualConnectionLayer(W_Q, W_K, W_V, W_O, num_heads, beta, gamma, epsilon, emb_dim, name_1, name_2)
%    obj = nnTNNResidualConnectionLayer(W_Q, W_K, W_V, W_O, num_heads, beta, gamma, epsilon, emb_dim, name_1, name_2, name)
%
% Inputs:
%    W_Q - query weight matrix for all attention heads stacked horizontally
%    W_K - key weight matrix for all attention heads stacked horizontally
%    W_V - value weight matrix for all attention heads stacked horizontally
%    W_O - linear projection matrix
%    num_heads - number of attention heads
%    beta - shift parameter of the layer normalization
%    gamma - scale parameter of the layer normalization
%    epsilon - small constant for numerical stability of the layer
%              normalization
%    emb_dim - (presumably) embedding dimension; forwarded to
%              nnLayerNormalizationLayer
%    name_1 - name of the internal multi-head attention layer
%    name_2 - name of the internal layer normalization layer
%    name - (optional) name of this layer, defaults to type
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    layerMHA    % internal multi-head attention layer (nnMultiHeadAttentionLayer)
    layernorm   % internal layer normalization layer (nnLayerNormalizationLayer)
end


methods
    % constructor
    function obj = nnTNNResidualConnectionLayer(W_Q, W_K, W_V, W_O, num_heads, beta, gamma, epsilon, emb_dim, name_1, name_2,varargin)
        % parse input: the only optional argument is the layer name
        [name] = setDefaultValues({[]}, varargin);

        % call super class constructor
        obj@nnLayer(name);

        % build the two sub-layers that make up the residual block
        obj.layerMHA = nnMultiHeadAttentionLayer(W_Q, W_K, W_V, W_O, num_heads,name_1);
        obj.layernorm = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_dim, name_2);
    end

    function outputSize = getOutputSize(obj, inputSize)
        % The residual sum x + MHA(x) requires the attention output to
        % have the same size as the input, and the layer normalization
        % keeps the size; hence the output size equals the input size.
        outputSize = inputSize;
    end

    function [nin, nout] = getNumNeurons(obj)
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end

end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})
    
    % numeric: r = layernorm(input + MHA(input))
    function [r, obj] = evaluateNumeric(obj, input, options)
       r = obj.layerMHA.evaluateNumeric(input, options);
       r = input + r;
       r = obj.layernorm.evaluateNumeric(r, options);
    end

    % sensitivity
    function S = evaluateSensitivity(obj, S, x, options)
        % NOTE(review): sensitivity is not implemented for this layer;
        % an empty result is returned
        S = [];
    end
    
    % interval: the residual sum is over-approximated by adding the input
    % bounds and the attention output bounds
    function bounds = evaluateInterval(obj, bounds, options)
        att_bounds = obj.layerMHA.evaluateInterval(bounds, options);
        bounds = bounds + att_bounds;
        bounds = obj.layernorm.evaluateInterval(bounds, options);
    end

    % zonotope/polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        if strcmp(options.nn.transformer_approach, "zonotope")
            assert(isempty(G),'zonotope approach should only uses independent generators')
            % use identifier vector for independent generators as labels
            % for parallel computations
            id_pre = id_ + (1:size(GI,2));
            % id_ must stay an upper bound of all labels in use, including
            % the labels just created; otherwise identifiers created later
            % from id_ would collide with id_pre
            id_ = max([id_; id(:); id_pre(:)]);

            % order reduction pre
            [c, GI, id_pre] = nnHelper.reduceLabelledZono(c, GI, id_pre, options.nn.num_generators);
            % keep id_ an upper bound of the labels returned by the reduction
            id_ = max([id_; id_pre(:)]);

            % compute multi-head attention
            [c_MHA, ~, GI_MHA, ~, id_MHA, ~, ~, ~] = obj.layerMHA.evaluatePolyZonotope(c, G, GI, E, id_pre, id_, ind, ind_, options);

            % residual connection: generators with equal labels are merged
            c = c+c_MHA;
            [GI,id] = nnHelper.mergeLabelledGenerators(GI,id_pre,GI_MHA,id_MHA);

            % order reduction post
            [c, GI, ~] = nnHelper.reduceLabelledZono(c, GI, id, options.nn.num_generators);

            % reset id
            id = [];
            id_ = 1;
            
        else
            % save input polyZonotope for residual connection
            PZ = polyZonotope(c, G, GI, E, id);

            % order reduction pre
            [c, G, GI, E, id, id_, ind, ind_] = aux_orderReductionPolyZono(c, G, GI, E, id, id_, options);
            
            % compute multi-head attention
            [c, G, GI, E, id, id_, ~, ~] = obj.layerMHA.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
            PZ_MHA = polyZonotope(c, G, GI, E, id);
          
            % residual connection (exactPlus keeps dependencies via shared ids)
            res = exactPlus(PZ,PZ_MHA);
            c = res.c; G = res.G; GI = res.GI; E = res.E; id = res.id;

            % order reduction post
            [c, G, GI, E, id, id_, ind, ind_] = aux_orderReductionPolyZono(c, G, GI, E, id, id_, options);

        end
    
        % apply layer normalization
        [c, G, GI, E, id, id_, ind, ind_] = obj.layernorm.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
    end  
end

end


% Auxiliary functions -----------------------------------------------------

function [c,G,GI,E,id,id_,ind,ind_] = aux_orderReductionPolyZono(c,G,GI,E,id,id_,options)
% aux_orderReductionPolyZono - reduces the polynomial zonotope to
%    options.nn.num_generators generators and turns every remaining
%    independent generator and every positive reduction-error component
%    into a dependent generator with its own fresh identifier

    % generator reduction; d is the reduction error returned by the helper
    [c, G, GI, E, id, d] = nnHelper.reducePolyZono(c, G, GI, E, id, options.nn.num_generators);

    % axis-aligned error generators, keeping only components with d > 0
    errGens = diag(d);
    errGens = errGens(:, d > 0);

    % append independent and error generators as dependent generators,
    % each with exponent 1 in a newly allocated identifier
    numNew = size(GI,2) + size(errGens,2);
    G = [G, GI, errGens];
    E = blkdiag(E, eye(numNew));
    id = [id; id_ + (1:numNew)'];
    GI = zeros(numel(c), 0);

    % largest identifier used so far
    id_ = max([id_; id]);

    % generators whose exponents are all even (ind) and all others (ind_)
    allEven = prod(1 - mod(E, 2), 1) == 1;
    ind = find(allEven);
    ind_ = setdiff(1:size(E, 2), ind);

end

% ------------------------------ END OF CODE ------------------------------
