classdef nnTNNGlobalAveragePoolingLayer < nnLayer
% nnTNNGlobalAveragePoolingLayer - class for global average pooling layer
% for transformers
%
% Description:
%    Averages the input over the sequence dimension. The sensitivity and
%    polyZonotope evaluation treat the input as a vectorized
%    (seq_len x emb_dim) matrix in column-major order, i.e., the seq_len
%    entries belonging to one embedding dimension are stored contiguously
%    and seq_len is derived as (number of rows)/emb_dim.
%
% Syntax:
%    obj = nnTNNGlobalAveragePoolingLayer(emb_dim)
%    obj = nnTNNGlobalAveragePoolingLayer(emb_dim,name)
%
% Inputs:
%    emb_dim - word embedding size
%    name - name of the layer, defaults to type
%
% Outputs:
%    obj - generated object
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% ------------------------------ BEGIN CODE -------------------------------

properties
    emb_dim     % embedding size; number of values remaining after pooling
end

methods
    % constructor
    function obj = nnTNNGlobalAveragePoolingLayer(emb_dim, name)
        % parse input
        if nargin > 2
            throw(CORAerror('CORA:tooManyInputArgs', 2))
        end
        if nargin < 2
            % name is optional: forward an empty name so that the super
            % class can choose its default
            % NOTE(review): assumes nnLayer accepts an empty name - confirm
            name = [];
        end

        % call super class constructor
        obj@nnLayer(name)
        obj.emb_dim = emb_dim;
    end

    function outputSize = getOutputSize(obj, inputSize)
        % getOutputSize - size of the layer output
        %
        % Pooling removes the sequence dimension, so one value per
        % embedding dimension remains regardless of the input size. This
        % matches the emb_dim rows produced by evaluatePolyZonotope and
        % evaluateSensitivity.
        outputSize = [obj.emb_dim, 1];
    end

    function [nin, nout] = getNumNeurons(obj)
        % getNumNeurons - number of input and output neurons
        %
        % Returns empty values as long as obj.inputSize has not been set.
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end

end

% evaluate ----------------------------------------------------------------

methods(Access = {?nnLayer, ?neuralNetwork})

    % numeric
    function r = evaluateNumeric(obj, input, options)
        % Averages each column of input over all of its rows; the result
        % has a single row and as many columns as input.
        % NOTE(review): seq_len is taken as size(input,1) here, whereas
        % the sensitivity/polyZonotope evaluation use (rows)/emb_dim.
        % Presumably the numeric input is a (seq_len x emb_dim) matrix
        % rather than its vectorization - confirm against the callers.
        seq_len = size(input, 1);

        % construct averaging row vector
        M = ones(1,seq_len) * 1/seq_len;

        % propagate
        r = aux_vecleftmtimes(obj,M,input);
    end

    % sensitivity
    function S = evaluateSensitivity(obj, S, x, options)
        % Right-multiplies the sensitivity S with the matrix representing
        % the pooling operation on the vectorized input.

        % init
        seq_len = numel(x)/obj.emb_dim;

        % construct matrix: row i holds 1/seq_len at the seq_len
        % consecutive entries that belong to embedding dimension i
        M = kron(speye(obj.emb_dim),ones(1, seq_len));
        M = M * 1/seq_len;

        % output
        S = S * M;
    end

    % interval
    function bounds = evaluateInterval(obj, bounds, options)
        % All averaging weights are positive, so the lower and upper
        % bound can be propagated independently.
        % (local names avoid shadowing the functions inf/sup)
        lb = obj.evaluateNumeric(bounds.inf, options);
        ub = obj.evaluateNumeric(bounds.sup, options);
        bounds = interval(lb, ub);
    end

    % zonotope/polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        % Applies the (linear) pooling to the center and the generator
        % matrices; E, id, id_, ind, ind_ are passed through unchanged.

        seq_len = size(c,1)/obj.emb_dim;

        % construct averaging row vector
        M = ones(1,seq_len) * 1/seq_len;

        % propagate
        c = aux_vecleftmtimes(obj,M,c);
        G = aux_vecleftmtimes(obj,M,G);

        if ~isempty(GI)
            GI = aux_vecleftmtimes(obj,M,GI);
        end

    end
end

% Auxiliary functions -----------------------------------------------------

methods (Access = protected)

    function v = aux_vecleftmtimes(obj,A,v)
        % aux_vecleftmtimes - computes vec(A*V) for every column of v,
        % where each column of v is interpreted as a vectorized (k x m)
        % matrix V (column-major) and A is an (n x k) matrix
        %
        % Inputs:
        %    A - (n x k) matrix
        %    v - (k*m x h) matrix; each column is a vectorized matrix
        %
        % Outputs:
        %    v - (n*m x h) matrix; each column is the vectorized product

        % init dimensions
        [n,k] = size(A);
        [km,h] = size(v);
        m = km/k;

        % reshape vectorized matrix back to matrix (one page per column)
        M = reshape(v,k,m,h);

        % do matrix multiplication (including broadcasting)
        M = pagemtimes(A,M);

        % reshape back to vector
        v = reshape(M,n*m,h);

    end
end

end

% ------------------------------ END OF CODE ------------------------------
