function [nn_red, pZ] = computeReducedNetwork(obj, pZ, varargin)
% computeReducedNetwork - computes a reduced network by merging similar
%   neurons based on the given input
%
% Description:
%   The network is first brought into normal form. For each linear layer
%   L_{k-1} followed by an activation layer L_{k} and a linear layer
%   L_{k+1}, the output bounds of L_{k} are computed. Neurons whose bounds
%   lie within [b-tol, b+tol] of a bucket b that contains at least two
%   neurons are removed from the network. Their rows of L_{k-1} and their
%   columns of L_{k+1} are deleted. Their output interval, mapped by the
%   weight matrix of L_{k+1}, is added to the approximation error d of
%   L_{k+1}. The given polyZonotope is propagated through the reduced
%   network.
%
% Syntax:
%   [nn_red, pZ] = neuralNetwork.computeReducedNetwork(obj, pZ)
%   [nn_red, pZ] = neuralNetwork.computeReducedNetwork(obj, pZ, verbose, method, tols)
%
% Inputs:
%    obj - neuralNetwork
%    pZ - polyZonotope (input set)
%    verbose - whether additional information should be printed
%              (default: false)
%    method - method to create merge buckets (default: 'static'):
%              'static'  - buckets returned by the activation layer
%                          (getMergeBuckets)
%              'dynamic' - buckets at the centers of the neuron bounds
%    tols - row vector of bucket tolerances, processed in ascending order
%           (default: [1e-3, 1e-2])
%
% Outputs:
%    nn_red - reduced neuralNetwork
%    pZ - output polyZonotope of the reduced network
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       Tobias Ladner
% Written:       20-December-2022
% Last update:   09-January-2023
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

if nargin < 2
    throw(CORAerror("CORA:notEnoughInputArgs", 2))
elseif nargin > 5
    throw(CORAerror("CORA:tooManyInputArgs", 5))
end

[verbose, method, tols] = setDefaultValues( ...
    {false, 'static', [1e-3, 1e-2]}, varargin);

inputArgsCheck({ ...
    {pZ, 'att', 'polyZonotope'}, ...
    {verbose, 'att', 'logical'}, ...
    {method, 'str', {'static', 'dynamic'}}, ...
    {tols, 'att', 'numeric', 'row'} ...
});
tols = sort(tols); % process smaller tolerances first

% bring network in normal form
nn = obj.getNormalForm();
% save number of neurons of original network
numNeurons = nn.getNumNeurons();

% parameters 
evParams = struct();
evParams.poly_method = "regression";
evParams.bound_approx = true;
evParams.num_generators = 50000;
evParams.propagate_bounds = false;
evParams.do_pre_order_reduction = false;
evParams.remove_Grest = false;
evParams.add_approx_error_to_Grest = true;
% set explicitly so evParams is complete even if the loop below is skipped
evParams.reuse_bounds = false;

% compute bounds of input
bounds = interval(pZ);

doPlot = false;

% number of leading layers through which pZ and bounds have already been
% propagated; used instead of the loop variable after the loop, as the
% loop variable is empty if the loop body never executes
numPropagated = 0;

% assuming alternating linear and nonlinear layers
for k = 2:2:(length(nn.layers)-1)
    % extract layers
    nnLinIn = nn.layers{k-1};
    nnAct = nn.layers{k};
    nnLinOut = nn.layers{k+1};
    
    % propagate bounds through linear layer k-1 and activation layer k
    evParams.reuse_bounds = false;
    bounds_pre = nn.evaluate(bounds, evParams, k-1:k);
    
    if doPlot
        figure; hold on;
        histogram(bounds_pre.inf, 100)
        histogram(bounds_pre.sup, 100)
    end

    % determine buckets
    buckets = [];
    if strcmp(method, 'static')
        buckets = nnAct.getMergeBuckets();
    elseif strcmp(method, 'dynamic')
        buckets = center(bounds_pre)';
    end

    % select neurons to be merged (idx) and count the used buckets
    [idx, num_merged] = aux_determineMergedNeurons( ...
        bounds_pre.inf, bounds_pre.sup, buckets, tols);

    % init merge matrix
    M_unmerged = diag(sparse(~idx)); % keep un-merged neurons
    M_unmerged = M_unmerged(any(M_unmerged, 2), :); % delete zero rows
    
    % merge 'input' weight matrix (keep rows of un-merged neurons)
    W1 = nnLinIn.W;
    b1 = nnLinIn.b;
    W1m = M_unmerged * W1;
    b1m = M_unmerged * b1;
    
    % merge 'output' weight matrix (keep columns of un-merged neurons)
    W2 = nnLinOut.W;
    b2 = nnLinOut.b;
    W2m = W2 * M_unmerged'; % sum
    b2m = b2;  % unchanged!

    % init linear layers \widehat{L}_{k-1}, \widehat{L}_{k+1}
    nnLinInNew = nnLinearLayer(full(W1m), full(b1m));
    nnLinOutNew = nnLinearLayer(full(W2m), b2m);

    % compute approx error
    if num_merged > 0
        % select bounds of merged neurons (zero for un-merged neurons)
        approx_error = bounds_pre .* idx';
        % propagate forward through the original output weights
        nnLinOutNew.d = W2*approx_error;
    end

    if ~isempty(nnLinOut.d)
        % add approx error from previous reduction
        if isempty(nnLinOutNew.d)
            nnLinOutNew.d = nnLinOut.d;
        else
            nnLinOutNew.d = nnLinOutNew.d + nnLinOut.d;
        end
    end

    % keep old approx error in L_{k-1} for unmerged dimensions
    if ~isempty(nnLinIn.d)
        nnLinInNew.d = nnLinIn.d(~idx);
    end
     
    % update layers
    nn.layers{k-1} = nnLinInNew;     % \widehat{L}_{k-1}
    nn.layers{k} = nnAct;            % \widehat{L}_{k}
    nn.layers{k+1} = nnLinOutNew;    % \widehat{L}_{k+1}
    
    % update bounds for next iteration
    evParams.reuse_bounds = true;
    bounds = nn.evaluate(bounds, evParams, k-1:k);
    
    % evaluate pZ
    pZ = nn.evaluate(pZ, evParams, k-1:k);
    bounds = and_(bounds, interval(pZ), 'exact'); % update bounds

    % layers 1:k are now processed
    numPropagated = k;
end

% propagate through remaining layers (all layers if none were reduced)
pZ = nn.evaluate(pZ, evParams, (numPropagated+1):length(nn.layers));

% compute reduction rate
numNeuronsRed = nn.getNumNeurons();
rate = sum(numNeuronsRed(2:2:end-1))/sum(numNeurons(2:2:end-1));

if verbose
    % display resulting number of neurons
    disp([numNeurons; numNeuronsRed])
    fprintf("Remaining neurons: %.2f%%\n", rate*100)
end

% sanity check
% N = 500;
% xs = pZ.randPoint(N);
% ys = nn.evaluate(xs);
% bounds = nn.evaluate(bounds, evParams, 13);
% 
% res = false(1, N);
% for i=1:N
%     res(i) = bounds.contains(ys(:, i));
% end
% disp(all(res))

if doPlot
    figure; hold on;
    histogram(bounds.inf, 100)
    histogram(bounds.sup, 100)
end

nn_red = nn;

end

% Auxiliary functions -----------------------------------------------------

function [idx, numMerged] = aux_determineMergedNeurons(lb, ub, buckets, tols)
% aux_determineMergedNeurons - marks all neurons whose bounds [lb, ub] lie
%    within [b-tol, b+tol] of a bucket b that contains at least two neurons
%
% Inputs:
%    lb, ub - lower/upper bounds of the neurons (column vectors)
%    buckets - bucket centers (row vector, one column per bucket)
%    tols - tolerances, processed in the given order
%
% Outputs:
%    idx - logical row vector, true for neurons selected for merging
%    numMerged - number of buckets used for merging

n = numel(lb);
idx = false(1, n);
numMerged = 0;

for t = 1:length(tols)
    % bucket bounds
    bInf = buckets-tols(t);
    bSup = buckets+tols(t);

    % containment of remaining neurons (rows) in buckets (columns);
    % filter only the first bucket each neuron belongs to
    M = (bInf <= lb) & (ub <= bSup);
    M = (cumsum(M, 2) == 1) & (M == 1);

    % only select buckets with more than one containment
    M = M(:, sum(M, 1) > 1);
    numMerged = numMerged + size(M, 2);

    % remove chosen neurons from further consideration
    idx_n = any(M, 2)';
    if any(idx_n)
        if all(idx_n)
            idx(~idx) = true;
            break
        end

        lb = lb(~idx_n);
        ub = ub(~idx_n);
        idx(~idx) = idx_n;
    end
end

end

% ------------------------------ END OF CODE ------------------------------
