classdef nnLayerNormalizationLayer < nnLayer
% nnLayerNormalizationLayer - class for layer normalization
%    this layer computes the modified version according to [2, Appendix E]:
%    the mean over the embedding dimension is subtracted, then the result
%    is scaled by gamma and shifted by beta. The division by the standard
%    deviation is not performed by the evaluation methods in this file.
%
% Syntax:
%    obj = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_dim)
%    obj = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_dim, name)
%
% Inputs:
%   beta - shift parameter (stored as double)
%   gamma - scale parameter (stored as double)
%   epsilon - small constant for numerical stability (stored, but not
%       used by the evaluation methods in this file)
%   emb_dim - embedding dimension, i.e., the number of entries over which
%       the mean is computed
%   name - (optional) name of the layer, forwarded to nnLayer
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%    [2] Shi et al. "Robustness Verification for Transformer", 2020
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% Authors:       Rayen Mhadhbi
% Written:       28-June-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

properties
    beta        % shift parameter
    gamma       % scale parameter
    epsilon     % constant for numerical stability (currently unused)
    emb_dim     % embedding dimension
end

methods
    % Constructor
    function obj = nnLayerNormalizationLayer(beta, gamma, epsilon, emb_dim, name)
        % the layer name is optional; fall back to an empty name
        % NOTE(review): assumes nnLayer picks a default name for an
        % empty input - confirm against the nnLayer constructor
        if nargin < 5
            name = [];
        end

        % call superclass constructor
        obj@nnLayer(name);

        obj.beta = double(beta);
        obj.gamma = double(gamma);
        obj.epsilon = epsilon;
        obj.emb_dim = emb_dim;
    end

    function outputSize = getOutputSize(obj, inputSize)
        % the normalization does not change the size of its input
        outputSize = inputSize;
    end

    function [nin, nout] = getNumNeurons(obj)
        % returns the number of input and output neurons; both are empty
        % as long as the input size of the layer has not been set
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end
end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % numeric
    function outputs = evaluateNumeric(obj, inputs, options)
        % subtract the row-wise mean, then scale and shift
        % NOTE(review): assumes inputs is (num_words x emb_dim), as the
        % mean is taken along dimension 2 and gamma is applied from the
        % right - TODO confirm against the callers
        mean_val = mean(inputs, 2);
        % the division by sqrt(variance + obj.epsilon) is intentionally
        % omitted, see [2, Appendix E]
        outputs = (inputs - mean_val);
        outputs = outputs * diag(obj.gamma) + obj.beta';
    end

    % interval
    function bounds = evaluateInterval(obj, bounds, options)
        % mean over the embedding dimension, written as a matrix product
        % with the averaging vector (1/emb_dim, ..., 1/emb_dim)'
        bounds_mean = bounds * (ones(obj.emb_dim,1) / obj.emb_dim);
        % center, then scale by gamma and shift by beta
        bounds = (bounds - bounds_mean) * diag(obj.gamma) + obj.beta';
    end

    % polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        % apply layer normalization to the center and the generators;
        % E, id, id_, ind, and ind_ are returned unchanged
        num_words = size(c, 1) / obj.emb_dim;

        % compute for center
        c = reshape(c, num_words, obj.emb_dim);
        c = (c - mean(c,2));
        c = reshape(c,[],1);
        c = nnHelper.affineMap(c,diag(obj.gamma),obj.beta);

        % compute for dependent generators
        % (generators do not carry an offset, hence the shift is 0)
        G = reshape(G, num_words, obj.emb_dim, size(G,2));
        G = (G - mean(G,2));
        G = reshape(G,num_words*obj.emb_dim,[]);
        G = nnHelper.affineMap(G,diag(obj.gamma),0);

        % compute for independent generators
        GI = reshape(GI, num_words, obj.emb_dim, size(GI,2));
        GI = (GI - mean(GI,2));
        GI = reshape(GI,num_words*obj.emb_dim,[]);
        GI = nnHelper.affineMap(GI,diag(obj.gamma),0);
    end
end

end

% ------------------------------ END OF CODE ------------------------------
