classdef nnTNNLinearLayer < nnLayer
% nnTNNLinearLayer - class for linear layers y = W*x + b whose output size
%    can scale with the number of nodes of a graph (see getOutputSize)
%
% Syntax:
%    obj = nnTNNLinearLayer(W, b)
%    obj = nnTNNLinearLayer(W, b, name)
%
% Inputs:
%    W - weight matrix (nrOutFeatures x nrInFeatures)
%    b - bias column vector (nrOutFeatures x 1), or a scalar which is
%        expanded to a column vector of matching length
%    name - (optional) name of the layer; if omitted, [] is passed to the
%        nnLayer constructor, which is expected to pick a default name
%
% Outputs:
%    obj - generated object
%
% Other m-files required: nnLayer, CORAerror, nnHelper, interval
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% ------------------------------ BEGIN CODE -------------------------------

properties
    W, b
end

methods
    % constructor
    function obj = nnTNNLinearLayer(W, b, name)
        % parse input: the layer name is optional
        if nargin < 3
            % NOTE(review): relies on nnLayer substituting a default name
            % for an empty name - confirm against nnLayer
            name = [];
        end

        % a scalar bias is broadcast to one entry per output neuron
        if isscalar(b)
            b = b * ones(size(W, 1), 1);
        end

        % check dimensions
        if size(b, 1) ~= size(W, 1)
           throw(CORAerror('CORA:wrongInputInConstructor', ...
               'The dimensions of W and b should match.'));
        end
        if size(b, 2) ~= 1
           throw(CORAerror('CORA:wrongInputInConstructor', ...
               "Second input 'b' should be a column vector."));
        end

        % call super class constructor
        obj@nnLayer(name);

        obj.W = double(W);
        obj.b = double(b);
    end

    % output size: one block of output features per graph node
    % (a single node if no graph is given)
    function outputSize = getOutputSize(obj, inputSize, graph)
        if nargin < 3
            nrNodes = 1;
        else
            nrNodes = graph.numnodes;
        end
        nrOutFeatures = size(obj.W,1);
        outputSize = [nrNodes*nrOutFeatures, 1];
    end

    % number of input/output neurons, read off the weight matrix
    function [nin, nout] = getNumNeurons(obj)
        nin = size(obj.W, 2);
        nout = size(obj.W, 1);
    end
end

% evaluate ----------------------------------------------------------------

methods  (Access = {?nnLayer, ?neuralNetwork})

    % interval: map the center through the affine map and the radius
    % through |W|, which yields an enclosure of the image of the interval
    function bounds = evaluateInterval(obj, bounds, options)
        mu = (bounds.sup + bounds.inf)/2;
        r = (bounds.sup - bounds.inf)/2;
        mu = aux_affineMap(obj, mu, true);
        r = r * abs(obj.W');
        bounds = interval(mu - r,mu + r);
    end

    % numeric: apply the affine map to each row of the input
    function r = evaluateNumeric(obj, input, options)
        r = aux_affineMap(obj, input, true);
    end

    % sensitivity: right-multiply by W expanded with kron over the
    % identity of size numnodes (one copy of W per graph node)
    function S = evaluateSensitivity(obj, S, x, options)
        Wv = kron(obj.W,speye(options.nn.graph.numnodes));
        S = S * Wv;
        S = full(S);
    end

    % zonotope/polyZonotope: the center receives the bias, while the
    % dependent (G) and independent (GI) generators are mapped without it;
    % E, id, id_, ind, ind_ are passed through unchanged
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        c = nnHelper.affineMap(c, obj.W', obj.b);
        G = nnHelper.affineMap(G, obj.W', 0);
        GI = nnHelper.affineMap(GI, obj.W', 0);
    end
end

% Auxiliary functions -----------------------------------------------------

methods(Access=protected)
    % aux_affineMap - computes input * W' and, if addBias is true (the
    %    default), adds b' to every row of the result
    function output = aux_affineMap(obj, input, addBias)
        if nargin < 3
            addBias = true;
        end

        output = input * obj.W';
        if addBias
            output = output + repmat(obj.b', size(input,1), 1);
        end
    end
end

end

% ------------------------------ END OF CODE ------------------------------
