classdef nnTransformerEmbeddingLayer < nnLayer
% nnTransformerEmbeddingLayer - class for transformer embedding layers;
%    maps a sequence of 0-based token ids to the sum of their token
%    embeddings and the embeddings of their sequence positions
%
% Syntax:
%    obj = nnTransformerEmbeddingLayer(token_emb, pos_emb)
%
% Inputs:
%    token_emb - token embedding matrix; row k+1 holds the embedding of
%                token id k
%    pos_emb - position embedding matrix; row p+1 holds the embedding of
%              sequence position p (must have as many columns as
%              token_emb so both embeddings can be added)
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Vaswani et al. "Attention is all you need", 2017
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    token_emb % token embedding matrix (one row per token id)
    pos_emb % position embedding matrix (one row per sequence position)
end

methods
    % constructor
    function obj = nnTransformerEmbeddingLayer(token_emb, pos_emb)
        % parse input
        inputArgsCheck({ ...
            {token_emb, 'att', 'numeric'}; ...
            {pos_emb, 'att', 'numeric'}; ...
        });

        obj.token_emb = token_emb;
        obj.pos_emb = pos_emb;
    end

    function outputSize = getOutputSize(obj, inputSize)
        % getOutputSize - size of the layer output for an input of size
        %    inputSize
        %
        % The sequence length is the number of token ids. For a vector of
        % ids (row or column) or a scalar size, that is prod(inputSize).
        % For any other shape the second dimension is used, as before.
        % NOTE(review): evaluateNumeric returns a seqLen-by-d matrix,
        % while this reports [d, seqLen]. The element counts agree, but
        % confirm which orientation downstream layers expect.
        if isscalar(inputSize) || (numel(inputSize) == 2 && any(inputSize == 1))
            seqLen = prod(inputSize);
        else
            seqLen = inputSize(2);
        end
        outputSize = [size(obj.token_emb, 2), seqLen];
    end


    function [nin, nout] = getNumNeurons(obj)
        if isempty(obj.inputSize)
            nin = [];
            nout = [];
        else
            % we can only compute the number of neurons if the input
            % size was set.
            nin = prod(obj.inputSize);
            outputSize = getOutputSize(obj, obj.inputSize);
            nout = prod(outputSize);
        end
    end

end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % numeric
    function r = evaluateNumeric(obj, input, options)
        % input holds 0-based token ids of a single sequence. Vectors are
        % normalized to a row so that the sequence length is counted
        % correctly for column inputs as well. Using size(input, 2) alone
        % would give length 1 for a column vector, and every token would
        % then receive the embedding of position 0.
        if isvector(input)
            ids = reshape(input, 1, []);
        else
            % non-vector inputs are passed through unchanged (as before)
            ids = input;
        end

        % token embeddings: row (id + 1) of the token embedding matrix
        % for each token id
        token_embs = obj.token_emb(ids + 1, :);

        % position embeddings for positions 0, ..., seqLen-1
        seqLen = size(ids, 2);
        pos_embs = obj.pos_emb(1:seqLen, :);

        % add token and position embeddings (one row per position)
        r = token_embs + pos_embs;
    end

    function varargout = evaluateInterval(obj, varargin)
        throw(CORAerror('CORA:nnLayerNotSupported', obj, 'There cannot be uncertainty to the input of an embedding layer.'))
    end


    function varargout = evaluatePolyZonotope(obj, varargin)
        throw(CORAerror('CORA:nnLayerNotSupported', obj, 'There cannot be uncertainty to the input of an embedding layer.'))
    end
end

end

% ------------------------------ END OF CODE ------------------------------
