classdef nnSoftmaxLayer < nnActivationLayer
% nnSoftmaxLayer - class for softmax layers
%
% Syntax:
%    obj = nnSoftmaxLayer(name)
%
% Inputs:
%    name - name of the layer, defaults to type
%
% Outputs:
%    obj - generated object
%
% References:
%    [1] Bonaert et al. "Fast and Precise Certification of Transformers",
%    2021
%    [2] Dennis et al. "Convex Bounds on the Softmax Function with
%    Applications to Robustness Verification", 2023
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: neuralNetwork

% Authors:       Rayen Mhadhbi, Tobias Ladner
% Written:       20-June-2024
% Last update:   23-September-2024 (TL, optimizations as order=1)
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

properties
    % helper layers used to enclose the softmax for polynomial zonotopes:
    % softmax_i(x) = 1 / sum_j exp(x_j - x_i), see [1, Sec. 3.1]
    expLayer    % nnExpLayer, encloses exp(x_j - x_i)
    recipLayer  % nnReciprocalLayer, encloses 1/(sum of exponentials)
end

methods
    % constructor
    function obj = nnSoftmaxLayer(name)
        % the constructor has a single input, so nargin is either 0 or 1;
        % only fall back to the default if no name was given
        if nargin < 1
            name = [];
        end
        % call super class constructor
        obj@nnActivationLayer(name)

        obj.expLayer = nnExpLayer();
        obj.recipLayer = nnReciprocalLayer();
    end

    %get i-th derivative
    function df_i = getDf(obj, i)
        % returns a function handle based on the softmax Jacobian;
        % the requested order i is not used, the same handle is returned
        % for every i.
        % NOTE(review): the handle returns the Jacobian at x multiplied
        % with x itself (one product per page), not the Jacobian - confirm
        % that this is what callers expect.
        function r = deriv(x)
            % shifted exponential for numerical stability; max/sum act
            % along the first non-singleton dimension of x
            ex = exp(x-max(x));
            sx = ex ./ sum(ex);
            sx = permute(sx, [1, 3, 2]);
            % compute Jacobian of softmax: diag(sx) - sx*sx' per page
            J = pagemtimes(-sx, 'none', sx, 'transpose') + sx .* eye(size(x, 1));
            r = reshape(pagemtimes(J, permute(x, [1, 3, 2])), size(x));
        end

        df_i = @(x) deriv(x);
    end


    function der1 = getDerBounds(obj, l, u)
        % df_l and df_u as lower and upper bound for the derivative;
        % the given input bounds l and u are not used, the constant
        % interval [0,1] is returned
        der1 = interval(0, 1);
    end

end

% evaluate ----------------------------------------------------------------

methods (Access = {?nnLayer, ?neuralNetwork})

    % numeric
    function [r, obj] = evaluateNumeric(obj, input, options)
        % column-wise softmax of the input; an input with a single row is
        % transposed first and thus treated as one column vector
        if size(input, 1) == 1
            input = input';
        end
        % avoid numerical issues see [2, Chp. 4]
        input = input - max(input,[],1);
        z = exp(input);
        % normalize each column (reuse z instead of recomputing exp)
        r = z ./ sum(z, 1);
    end

    % sensitivity
    function S = evaluateSensitivity(obj, S, x, options)
        % multiplies the incoming sensitivity S from the right with the
        % softmax Jacobian evaluated at x
        sx = permute(obj.evaluateNumeric(x, options), [1, 3, 2]);
        % compute Jacobian of softmax: diag(sx) - sx*sx' (one page per
        % column of x)
        J = pagemtimes(-sx, 'none', sx, 'transpose') + sx .* eye(size(x, 1));
        % page-wise product; J has one page per input column, for which
        % the plain matrix product (*) is not defined if there is more
        % than one page
        S = pagemtimes(S, J);
    end

    % interval
    function bounds = evaluateInterval(obj, bounds, options)
        % interval arithmetic bounds; the bounds are flattened to a single
        % column vector, i.e., one input set is processed
        lb = bounds.inf;
        ub = bounds.sup; 
       
        lb = lb(:);
        ub = ub(:);

        n = numel(lb);
        onDiag = logical(eye(n));
        
        % column j of lbs: all inputs at their upper bound except input j
        % at its lower bound -> minimizes the j-th softmax output.
        % The entries are copied (instead of adding/subtracting diagonal
        % matrices) so that no floating-point cancellation alters them.
        lbs = repmat(ub, 1, n);
        lbs(onDiag) = lb;

        % column j of ubs: all inputs at their lower bound except input j
        % at its upper bound -> maximizes the j-th softmax output
        ubs = repmat(lb, 1, n);
        ubs(onDiag) = ub;
        
        % Apply the softmax function (column-wise)
        ubs_softmax = obj.evaluateNumeric(ubs, options);
        lbs_softmax = obj.evaluateNumeric(lbs, options);
        
        % the j-th bound is the j-th entry of the j-th column
        lowerBound = diag(lbs_softmax);
        upperBound = diag(ubs_softmax);
        
        bounds = interval(lowerBound(:), upperBound(:));
            
    end

    % zonotope/polyZonotope
    function [c, G, GI, E, id, id_, ind, ind_] = evaluatePolyZonotope(obj, c, G, GI, E, id, id_, ind, ind_, options)
        % computing y_i = 1 / \sum_j exp(x_j - x_i) as in [1, Sec. 3.1]
        %
        % c, G, GI, E, id, id_, ind, ind_ describe the input polynomial
        % zonotope and are overwritten with the enclosure of the output;
        % the per-neuron approximation errors are appended to GI or G,
        % depending on options.nn.add_approx_error_to_GI

        % check order
        order = max(obj.order);
        if order > 1
            throw(CORAerror("CORA:wrongValue", 'nnSoftmaxLayer only supports order 1.'))
        end

        % avoid numerical issues
        c = c - max(c);

        % initialization
        num_neurons = size(c, 1);
        c_ = zeros(num_neurons, 1);
        G_ = zeros(size(G));
        GI_ = zeros(size(GI));

        % row vector to store the approximation error of each neuron
        d = zeros(1, num_neurons);

        % main loop over all neurons in the current layer

        for i = 1:num_neurons
            options.nn.neuron_i = i;
            [c_(i), G_i, GI_i, d(i)] = ...
                obj.evaluatePolyZonotopeNeuronSoftmax(i, c(i), G(i, :), GI(i, :), E, 1, ind, ind_, c, G, GI, options);

            % stack the dependent generators vertically
            G_(i, 1:length(G_i)) = G_i;

            % stack the independent generators vertically
            GI_(i, 1:length(GI_i)) = GI_i;

        end

        % update properties
        c = c_;
        G = G_;
        GI = GI_;

        % add approximation error (one generator per neuron, all-zero
        % columns are dropped)
        D = diag(d);
        D = D(:, ~all(D == 0,1));
        if options.nn.add_approx_error_to_GI
            GI = [GI, D];
        else
            G = [G, D];
            sz = size(D, 2);
            E = blkdiag(E, eye(sz));
            % NOTE(review): new ids are generated as 1 + k*id_, which are
            % only distinct from each other and from existing ids if
            % id_ >= 1 - confirm that id_ is never 0 here
            id = [id; 1 + (1:sz)' * id_];
            id_ = max(id);
            % columns of E with only even exponents
            ind = find(prod(ones(size(E))-mod(E, 2), 1) == 1);
            ind_ = setdiff(1:size(E, 2), ind);
        end
     
    end

    % zonotope/polyZonotope neuron
    function [c_out, G_out, GI_out, d] = evaluatePolyZonotopeNeuronSoftmax(obj, i, c_i, G_i, GI_i, E, order, ind, ind_, c_all, G_all, GI_all, options)
        % encloses the i-th softmax output
        %    y_i = 1 / (1 + \sum_{j ~= i} exp(x_j - x_i))
        % by enclosing each exponential with expLayer, summing the
        % results, and enclosing the reciprocal of the sum with recipLayer
        %
        % c_i, G_i, GI_i are the i-th rows of c_all, G_all, GI_all;
        % d is the resulting approximation error of the neuron

        num_neurons = size(G_all, 1);

        % initialization
        c_exp = zeros(num_neurons, 1);
        G_exp = zeros(size(G_all));
        GI_exp = zeros(size(GI_all));
        d_exp = zeros(num_neurons, 1);

        % set poly_method to singh for enclosure with nonnegative output
        options.nn.poly_method = 'singh';
        options.nn.num_generators = [];
        options.nn.max_gens_post = [];

        % Loop over all neurons for the exponential calculation
        for j = 1:num_neurons
            if j ~= i
                c_diff = c_all(j) - c_i;
                G_diff = G_all(j, :) - G_i;
                GI_diff = GI_all(j, :) - GI_i;

                % Evaluate exponential of the difference using expLayer
                [c_exp(j), G_exp_j, GI_exp_j, d_exp(j)] = ...
                    obj.expLayer.evaluatePolyZonotopeNeuron(c_diff, G_diff, GI_diff, E, E, order, ind, ind_, options);

                % stack the generators vertically
                G_exp(j, 1:length(G_exp_j)) = G_exp_j;
                GI_exp(j, 1:length(GI_exp_j)) = GI_exp_j;
            else
                % exp(0) = 1, will be added later
            end
        end

        % sum dimensions
        c_sum = sum(c_exp, 1) + 1; % +1 due to exp(0) = 1 for i=j
        G_sum = sum(G_exp, 1);
        GI_sum = sum(GI_exp, 1);
        d_sum = sum(d_exp,1);

        % temporarily add approx error as an additional independent
        % generator so that it is passed through the reciprocal as well
        GI_sum = [GI_sum, d_sum];
 
        options.nn.tnn_layer = "reciprocal";
        % Evaluate reciprocal of the sum using recipLayer
        [c_out, G_out, GI_out, d_out] = ...
            obj.recipLayer.evaluatePolyZonotopeNeuron(c_sum, G_sum, GI_sum, E, E, order, ind, ind_, options);

        % separate approx error obtained from exp again: the last
        % independent generator carries the exp error, which is combined
        % with the error of the reciprocal enclosure
        d = abs(GI_out(:,end))+d_out;
        GI_out = GI_out(:,1:end-1);

    end
end

end

% ------------------------------ END OF CODE ------------------------------
