classdef nnFFNResidualConnectionLayer < nnLayer
% nnFFNResidualConnectionLayer - class for the residual connection of a
%    feed-forward network (FFN) followed by a normalization layer in
%    transformers, i.e.
%
%       r = LayerNorm(x + FFN(x)),  FFN(x) = lin_2(ReLU(lin_1(x)))
%
%    where lin_1/lin_2 are nnTNNLinearLayer objects built from (W_1,b_1)
%    and (W_2,b_2).
%
% Syntax:
%    obj = nnFFNResidualConnectionLayer(name, W_1, b_1, W_2, b_2, beta, gamma, epsilon, emb_dim)
%    obj = nnFFNResidualConnectionLayer(name, W_1, b_1, W_2, b_2, beta, gamma, epsilon, emb_dim, name_1, name_2, name_3)
%
% Inputs:
%    name - name of the layer, defaults to type
%    W_1 - weight matrix of first linear layer
%    b_1 - bias of first linear layer
%    W_2 - weight matrix of second linear layer
%    b_2 - bias of second linear layer
%    beta - shift parameter of the layer normalization
%    gamma - scale parameter of the layer normalization
%    epsilon - small constant for numerical stability (layer normalization)
%    emb_dim - embedding dimension, passed to the layer normalization
%    name_1 - (optional) name of the layer normalization sublayer
%    name_2 - (optional) name of the first linear sublayer
%    name_3 - (optional) name of the second linear sublayer
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: nnLayer, nnTNNLinearLayer, nnReLULayer,
%    nnLayerNormalizationLayer
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    linearlayer_fst     % first linear layer of the feed-forward network
    linearlayer_snd     % second linear layer of the feed-forward network
    relulayer           % ReLU activation between the two linear layers
    layernorm           % layer normalization applied after the residual sum
end


methods
    % constructor
    function obj = nnFFNResidualConnectionLayer(name,W_1, b_1, W_2, b_2, beta, gamma, epsilon, emb_dim, name_1, name_2, name_3)
        % parse input: the sublayer names are optional; an empty name
        % follows the nnLayer convention (presumably falls back to a
        % default name inside the sublayer constructors - confirm)
        if nargin < 10
            name_1 = [];
        end
        if nargin < 11
            name_2 = [];
        end
        if nargin < 12
            name_3 = [];
        end

        % call super class constructor
        obj@nnLayer(name)

        % feed-forward network: linear -> ReLU -> linear
        obj.linearlayer_fst = nnTNNLinearLayer(W_1, b_1, name_2);
        obj.relulayer = nnReLULayer();
        obj.linearlayer_snd = nnTNNLinearLayer(W_2, b_2, name_3);

        % normalization applied to (input + FFN output)
        obj.layernorm = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_dim, name_1);
    end

    function outputSize = getOutputSize(obj, inputSize)
        % The residual sum input + FFN(input) requires the FFN output to
        % have the size of the input; layer normalization is assumed to
        % preserve that size. Hence the output size equals the input size.
        outputSize = inputSize;
    end

    function [nin, nout] = getNumNeurons(obj)
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end

end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % sensitivity
    function S = evaluateSensitivity(obj, S, x, options)
        % NOTE(review): not implemented - always returns [] so callers
        % receive no sensitivity information for this layer. TODO:
        % back-propagate S through layernorm, the FFN and the skip path.
        S = [];
    end

    % numeric
    function [r, obj] = evaluateNumeric(obj, input, options)
        % feed-forward network
        lin_out = obj.linearlayer_fst.evaluateNumeric(input, options);
        relu_out = obj.relulayer.evaluateNumeric(lin_out, options);
        ffn_out = obj.linearlayer_snd.evaluateNumeric(relu_out, options);
        % residual connection followed by layer normalization
        r = obj.layernorm.evaluateNumeric(input + ffn_out, options);
    end

    % interval
    function bounds = evaluateInterval(obj, bounds, options)
        % propagate the input bounds through the feed-forward network
        bounds_lin = obj.linearlayer_fst.evaluateInterval(bounds, options);
        bounds_relu = obj.relulayer.evaluateInterval(bounds_lin, options);
        bounds_linsnd = obj.linearlayer_snd.evaluateInterval(bounds_relu, options);
        % residual connection: the interval sum ignores the dependency
        % between input and FFN output, so it may be conservative
        bounds = obj.layernorm.evaluateInterval(bounds + bounds_linsnd, options);
    end

    % zonotope/polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)

        % save input for the residual connection
        c_res = c;
        G_res = G;
        GI_res = GI;

        % prevent order reduction within relu (options is a local copy)
        options.nn.num_generators = [];
        options.nn.max_gens_post = [];
        % add approx error to GI (keeps E/id of the input intact)
        options.nn.add_approx_error_to_GI = true;

        % propagate sequentially through the feed-forward network
        [c, G, GI, E, id, id_, ind, ind_] = obj.linearlayer_fst.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
        [c, G, GI, E, id, id_, ind, ind_] = obj.relulayer.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);
        [c, G, GI, E, id, id_, ind, ind_] = obj.linearlayer_snd.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);

        % combine with input. Assumes the FFN keeps the input generators
        % in their original order in the leading columns and appends any
        % new ones (e.g. approximation errors at the end of GI).
        c = c_res + c;
        p = size(G_res, 2);
        if p > 0
            G(:, 1:p) = G(:, 1:p) + G_res;
        end
        q = size(GI_res, 2);
        if q > 0
            GI(:, 1:q) = GI(:, 1:q) + GI_res;
        end

        % compute layer normalization
        [c, G, GI, E, id, id_, ind, ind_] = obj.layernorm.evaluatePolyZonotope(c, G, GI, E, id, id_, ind, ind_, options);

    end

end

end

% ------------------------------ END OF CODE ------------------------------
