function [res, x_, y_] = verify(nn, x, r, A, b, safeSet, varargin)
% verify - automated verification for specification on neural networks.
%
% Syntax:
%    [res, x_, y_] = nn.verify(x, r, A, b, safeSet)
%    [res, x_, y_] = nn.verify(x, r, A, b, safeSet, options, timeout, ...
%                              verbose, plotDims)
%
% Inputs:
%    nn - object of class neuralNetwork
%    x, r - center and radius of the initial set (can already be a batch)
%    A, b - specification, prove A*y + b <= 0
%    safeSet - bool; if true, {y | A*y + b <= 0} is the safe set, 
%           otherwise it is the unsafe set
%    options - evaluation options (default: struct)
%    timeout - timeout in seconds (default: 100)
%    verbose - print verbose output (default: false)
%    plotDims - 2x2 plot dimensions; empty for no plotting (default: []); 
%           plotDims(1,:) for input and plotDims(2,:) for output; sets 
%           are stored in res.Xs, res.Ys (and res.uXs, res.uYs for the
%           'zonotack' refinement)
%
% Outputs:
%    res - result struct: res.str is 'VERIFIED', 'COUNTEREXAMPLE', or 
%           'UNKNOWN'; res.time is the computation time in seconds; when 
%           plotting, intermediate sets are stored in further fields
%    x_ - counterexample in terms of an initial point violating the specs
%           (empty if none was found)
%    y_ - output for x_ (empty if none was found)
%
% References:
%    [1] VNN-COMP'24
%
% Other m-files required: none
% Subfunctions: none
% MAT-files required: none
%
% See also: -

% Authors:       ---
% Written:       23-November-2021
% Last update:   30-November-2022 (TL, removed neuralNetworkOld, adaptive)
%                25-July-2023 (TL, input parsing, improvements)
%                23-November-2023 (TL, verbose, bug fix)
%                14-June-2024 (LK, rewritten with efficient splitting)
%                20-January-2025 (LK, constraint zonotope splitting)
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% Check number of input arguments.
narginchk(6,10);

% Validate parameters.
[options, timeout, verbose, plotDims] = ...
    setDefaultValues({struct, 100, false, []}, varargin);
plotting = ~isempty(plotDims);

% Validate parameters.
inputArgsCheck({ ...
    {nn, 'att','neuralNetwork'}; ...
    {x, 'att',{'numeric','gpuArray'}}; ...
    {r, 'att',{'numeric','gpuArray'}}; ...
    {A, 'att',{'numeric','gpuArray'}}; ...
    {b, 'att',{'numeric','gpuArray'}}; ...
    {options,'att','struct'}; ...
    {timeout,'att','numeric','scalar'}; ...
    {verbose,'att','logical'}; ...
    {plotting,'att','logical'}; ...
})
options = nnHelper.validateNNoptions(options,true);

nSplits = options.nn.num_splits; % Number of input splits per dimension.
nDims = options.nn.num_dimensions; % Number of input dimension to splits.
nNeur = options.nn.num_neuron_splits; % Number of neurons to split.
% Add x >= 0 constraints to tighten the bounds.
nReluTightConstr = options.nn.num_relu_tighten_constraints; 

% Extract parameters.
bSz = options.nn.train.mini_batch_size;

% Obtain number of input dimensions.
n0 = size(x,1);
% Limit the number of dimensions to split.
nDims = min(nDims,n0);
% Check the maximum number of input generators.
numInitGens = min(n0,options.nn.train.num_init_gens);
% Obtain the number of approximation error generators per layer.
numApproxErrGens = options.nn.train.num_approx_err;
% Obtain the maximum number of approximation errors in an activation layer.
nk_max = max(cellfun(@(li) ...
    isa(li,'nnActivationLayer')*prod(li.getOutputSize(li.inputSize)), ...
    nn.layers) ...
);
% We always have to use the approximation during set propagation to ensure
% soundness.
options.nn.use_approx_error = true;
% Ensure the interval-center flag is set, if there are less generators than
% input dimensions.
options.nn.interval_center = ...
    (numApproxErrGens < nk_max) | (numInitGens < n0);

% To speed up computations and reduce gpu memory, we only use single 
% precision.
inputDataClass = single(1);
% Check if a gpu is used during training.
useGpu = options.nn.train.use_gpu;
if useGpu
    % Training data is also moved to gpu.
    inputDataClass = gpuArray(inputDataClass);
end
% (potentially) move weights of the network to gpu.
nn.castWeights(inputDataClass);

% Specify indices of layers for propagation.
idxLayer = 1:length(nn.layers);

% In each layer, store ids of active generators and identity matrices 
% for fast adding of approximation errors.
numGen = nn.prepareForZonoBatchEval(x,options,idxLayer);
% Allocate generators for initial perturbation set.
batchG = zeros([n0 numGen bSz],'like',inputDataClass);

% Initialize queue.
xs = x;
rs = r;
% Initialize result.
res.str = 'UNKNOWN';
x_ = [];
y_ = [];

% Initialize iteration stats.
numVerified = 0;
% Initialize iteration counter.
iter = 1;

if plotting
    % Initialize cell arrays to store intermediate results for plotting.
    res.Xs = {};
    res.Ys = {};
    res.xs_ = {};
    res.ys_ = {};

    if strcmp(options.nn.split_refinement_method,'zonotack')
        % Unsafe sets are only required for the 'zonotack'.
        res.uXs = {};
        res.uYs = {};
    end
    
    % Compute samples.
    sampX = randPoint(interval(x - r,x + r),1000);
    sampY = gather(double(nn.evaluate(sampX)));

    % Create a new figure.
    fig = figure;
    % Initialize plot.
    [fig,hx0,hspec] = aux_initPlot(fig,plotDims, ...
        sampX,sampY,x,r,A,b,safeSet);
    drawnow;
end

if verbose
    % Setup table.
    table = CORAtable('double', ...
        {'Iteration','#Queue','#Verified','Avg. radius','Unknown Vol.'}, ...
        {'d','d','d','.5f','.5f'});
    table.printHeader();
end

tic

% Obtain initial split parameters.
initSplits = options.nn.init_split(1);
initDims = options.nn.init_split(2);
% Apply initial split.
if initSplits > 1
    % TODO: handle sensitivity.
    for i=1:initDims
        [xs,rs,~] = aux_split(xs,rs,ones(size(rs)),initSplits);
    end
end

% Main splitting loop.
while size(xs,2) > 0

    time = toc;
    if time > timeout
        % Time is up.
        res.time = time;
        break;
    end

    if verbose
        % Compute iteration stats.
        queueLen = size(xs,2);
        avgRad = mean(rs,'all');
        unknVol = sum(prod(2*rs,1),'all');
        % Print new table row.
        table.printContentRow({iter,queueLen,numVerified,avgRad,unknVol});
    end

    % Pop next batch from the queue.
    [xi,ri,xs,rs] = aux_pop(xs,rs,bSz);
    % Move the batch to the GPU.
    xi = cast(xi,'like',inputDataClass);
    ri = cast(ri,'like',inputDataClass);

    % Obtain the current batch size.
    [~,cbSz] = size(xi);

    % Compute the sensitivity (store sensitivity for neuron-splitting).
    % The sensitivity is used for selecting input generators, neuron
    % -splitting, and FGSM attacks.
    storeSensitivity =  (nNeur > 0) ...
        | strcmp(options.nn.approx_error_order,'sensitivity*length');
    [S,~] = nn.calcSensitivity(xi,options,storeSensitivity);
    % The sensitivity should not be lower than 1e-3, otherwise it is too 
    % low to be effective for the (neuron-) splitting heuristic.
    S = max(abs(S),1e-3);
    sens = permute(sum(S),[2 1 3]);
    sens = sens(:,:);
    % TODO: investigate a more efficient implementation of the sensitivity
    % computation using backpropagation.

    % 1. Verification -----------------------------------------------------
    % 1.1. Use batch-evaluation of zonotopes.

    % Construct input zonotope. 
    [cxi,Gxi,~] = aux_constructInputZonotope(xi,ri,sens,batchG, ...
        numInitGens,options);

    if nNeur > 0 || nReluTightConstr > 0
        % Store inputs by enabling backpropagation. The inputs are needed 
        % for the neuron splitting.
        options.nn.train.backprop = true;
    end
    % Compute output enclosure.
    [yi,Gyi] = nn.evaluateZonotopeBatch_(cxi,Gxi,options,idxLayer);
    % Disable backpropagation.
    options.nn.train.backprop = false;
    % Obtain number of output dimensions.
    [nK,~] = size(yi);
    if options.nn.interval_center
        % Compute the center and the radius of the center-interval.
        yic = reshape(1/2*(yi(:,2,:) + yi(:,1,:)),[nK cbSz]);
        % Compute approximation error.
        yid = 1/2*(yi(:,2,:) - yi(:,1,:));
    else
        % The center is just a vector.
        yic = yi;
        % There are no approximation errors stored in the center.
        yid = zeros([nK 1 cbSz],'like',yi);
    end
    % 1.2. Compute logit difference.
    ld_yi = A*yic + b;
    % Compute the logit difference of the input generators.
    ld_Gyi = pagemtimes(A,Gyi(:,1:numInitGens,:));
    % Compute logit difference of the approximation errors.
    ld_Gyi_err = sum(abs(pagemtimes(A,Gyi(:,(numInitGens+1):end,:))),2) ...
         + sum(abs(A.*permute(yid,[2 1 3])),2);
    % Compute the radius of the logit difference.
    ld_ri = sum(abs(ld_Gyi),2) + ld_Gyi_err;
    % 1.3. Check specification.
    if safeSet
        checkSpecs = any(ld_yi + ld_ri(:,:) > 0,1);
    else
        checkSpecs = all(ld_yi - ld_ri(:,:) < 0,1);
    end
    unknown = checkSpecs;
    % Update counter for verified patches.
    numVerified = numVerified + sum(~unknown,'all');

    if plotting
        % % Reset the figure.
        % clf(fig);
        % [fig,hx0,hspec] = aux_initPlot(fig,plotDims,x,r,A,b,safeSet);
        % Store input sets.
        if options.nn.interval_center
            xid = 1/2*(cxi(:,2,:) - cxi(:,1,:));
            res.Xs{end+1} = struct( ...
                'c',gather(xi),'G',gather(cat(2,Gxi,xid.*eye(n0))),...
                'verified',gather(~unknown) ...
            );
            % Store the output set.
            res.Ys{end+1} = struct( ...
                'c',gather(yic),'G',gather(cat(2,Gyi,yid.*eye(nK))) ...
            );
        else
            res.Xs{end+1} = struct('c',gather(cxi),'G',gather(Gxi),...
                'verified',gather(~unknown));
            % Store the output set.
            res.Ys{end+1} = struct('c',gather(yic),'G',gather(Gyi));
        end
        % Plot current input sets and propagated output sets.
        [fig,hxi,hx,hxv,hy,hyv] = aux_plotInputAndOutputSets(fig, ...
            plotDims,x,r,res);
        drawnow;
    end

    % Only keep un-verified patches.
    xi(:,~unknown) = [];
    ri(:,~unknown) = [];
    sens(:,~unknown) = [];
    % Update the current batch size.
    [~,cbSz] = size(xi);

    if options.nn.interval_center
        cxi(:,:,~unknown) = [];
    else
        cxi(:,~unknown) = [];
    end
    Gxi(:,:,~unknown) = [];
    yic(:,~unknown) = [];
    Gyi(:,:,~unknown) = [];
    yid(:,:,~unknown) = [];
    ld_yi(:,~unknown) = [];
    ld_Gyi(:,:,~unknown) = [];
    ld_Gyi_err(:,:,~unknown) = [];

    % 2. Falsification ----------------------------------------------------

    % 2.1. Compute adversarial examples.
    switch options.nn.falsification_method
        case 'fgsm'
            % Try to falsify with a FGSM attack.
            
            % sens = permute(pagemtimes(A,S),[2 1 3]);
            % if safeSet
            %     % Maximize A*yi + b.
            %     [~,constrIdx] = max(sum(sens,1),[],2);
            % else
            %     % Minimize A*yi + b.
            %     [~,constrIdx] = min(sum(sens,1),[],2);
            % end
            % sens = sens(:,constrIdx);
            % Compute adversarial attacks.
            xi_ = xi + ri.*sign(sens);
        case 'center'
            % Use the center for falsification.
            xi_ = xi;
        case 'zonotack'
            % TODO: more clever attacks based on the output set.

            % Obtain number of constraints.
            [p,~] = size(A);

            % Compute the vertex that minimizes the distance to each 
            % halfspace.
            beta_ = permute(sign(ld_Gyi),[2 4 3 1]);
            if ~safeSet
                beta_ = -beta_;
            end
            % Put multiple candidates into the batch.
            beta = reshape(beta_,[numInitGens 1 cbSz*p]);
            % Compute attack.
            delta = pagemtimes(repmat(Gxi(:,1:numInitGens,:),1,1,p),beta);
            delta = reshape(delta,[n0 cbSz*p]);
            % Compute candidates for falsification.
            xi_ = repmat(xi,1,p) + delta;
            % beta = mean(beta_,4);
            % xi_ = xi + reshape(pagemtimes(Gxi(:,1:n0,:),beta),[n0 cbSz]);
        otherwise
            % Invalid option.
            throw(CORAerror('CORA:wrongFieldValue', ...
                'options.nn.falsification_method', ...
                    {'fgsm','center','zonotack'}));
    end

    % 2.2. Check adversarial examples.
    yi_ = nn.evaluate_(xi_,options,idxLayer);
    checkSpecsVal = max(A*yi_ + b,[],1);
    if safeSet
        checkSpecs = (checkSpecsVal >= 0);
    else
        checkSpecs = (checkSpecsVal <= 0);
    end

    if plotting
        % Plot counterexample candidates.
        res.xs_{end+1} = gather(xi_);
        res.ys_{end+1} = gather(yi_);
        % Only show current counterexamples; delete previous ones.
        if exist('hxs_','var')
            delete(hxs_);
        end
        if exist('hys_','var')
            delete(hys_);
        end
        % Plot current counterexamples.
        % [fig,hxs_,hys_] = ...
        %   aux_plotCounterExampleCandidates(fig,plotDims,res);
        % drawnow;
    end

    % Check if the batch was extended with multiple candidates.
    if size(checkSpecs,2) > cbSz
        checkSpecs_ = reshape(checkSpecs,1,cbSz,[]);
        % For each batch entry, select the first candidate that violates
        % the specification (the first candidate if none does).
        [checkSpecs,idx_] = max(checkSpecs_,[],3);
        % Compute the batch index of the selected candidate.
        idx = sub2ind(size(checkSpecs_,[2 3]),1:cbSz,idx_);
        % Take the selected candidate.
        checkSpecsVal = checkSpecsVal(:,idx);
        xi_ = xi_(:,idx);
        yi_ = yi_(:,idx);
    end
    
    if any(checkSpecs)
        % Found a counterexample.
        res.str = 'COUNTEREXAMPLE';
        idNzEntry = find(checkSpecs);
        id = idNzEntry(1);
        x_ = gather(xi_(:,id));
        % Gathering weights from gpu. There are precision errors when 
        % using single gpuArray.
        nn.castWeights(single(1));
        y_ = nn.evaluate_(x_,options,idxLayer); % yi_(:,id);
        break;
    end

    % 3. Refine input sets. -----------------------------------------------

    switch options.nn.split_refinement_method
        case 'naive'
            % The sets are not refined.
        case 'zonotack'
            % Compute constraints.
            A_ = ld_Gyi;
            b_ = -ld_yi;

            % Compute approximation errors.
            apprErr = ld_Gyi_err(:,:);

            % 3.1. Project the unsafe outputs to input space. 
            if options.nn.interval_center
                uXi = struct( ...
                    'c',reshape(1/2*(cxi(:,2,:) + cxi(:,1,:)),[n0 cbSz]), ...
                    'r',1/2*(cxi(:,2,:) - cxi(:,1,:)) ...
                );
            else
                uXi = struct('c',cxi,'r',zeros([n0 1 cbSz]));
            end
            uXi.G = Gxi(:,1:numInitGens,:);
            uXi.A = A_;
            % Offset by approximation errors.
            if safeSet
                uXi.b = b_ - apprErr;
            else
                uXi.b = b_ + apprErr;
            end

            if nReluTightConstr > 0
                % Compute tightening constraints for unstable relu neurons.
                [At,bt] = aux_reluTightenConstraints(nn,numInitGens, ...
                    unknown,nReluTightConstr);
                % Append constraints.
                uXi.A = [uXi.A; At];
                uXi.b = [uXi.b; bt];
            end
            
            if nNeur > 0
                % Create split constraints for neurons within the network.
                [Anri,bnri,apprErri,nrIdx] = ...
                    aux_splitNeurons(nn,numInitGens,unknown,nNeur);

                % Compute number of new splits.
                newSplits = 2^size(Anri,1);
                % Construct a matrix with all combination of -1 and 1 of 
                % length nNeur as columns. 
                cf = permute(2*(dec2bin(0:(newSplits-1)) - '0') - 1,[2 3 1]);
                % Construct constraint matrix for new constraints.
                Anr = permute(Anri,[1 2 4 3]).*cf;
                Anr = Anr(:,:,:);
                bnr = permute(bnri.*cf,[1 3 2]); % + apprErri.*cf;
                bnr = bnr(:,:);
                % Duplicate sets for splitting; the copies of each batch 
                % entry are contiguous (repelem), matching the page order 
                % of Anr and the column order of bnr.
                uXi.c = repelem(uXi.c,1,newSplits);
                uXi.G = repelem(uXi.G,1,1,newSplits);
                uXi.r = repelem(uXi.r,1,1,newSplits);
                uXi.A = [repelem(uXi.A,1,1,newSplits); Anr];
                uXi.b = [repelem(uXi.b,1,newSplits); bnr];

                % The heuristics have to be duplicated in the same order 
                % as the sets (repelem, not repmat); otherwise they are
                % attached to the wrong sets.
                sens = repelem(sens,1,newSplits);
                checkSpecsVal = repelem(checkSpecsVal,1,newSplits);
            end
          
            % Compute the bounds of the unsafe inputs.
            [l_,u_] = aux_boundsOfConZonotope(uXi,safeSet, ...
                options.nn.exact_conzonotope_bounds);

            % Remove empty sets.
            emptyIdx = any(isnan(l_),1) | any(isnan(u_),1);
            l_(:,emptyIdx) = [];
            u_(:,emptyIdx) = [];
            sens(:,emptyIdx) = [];
            checkSpecsVal(:,emptyIdx) = [];

            % Update input sets.
            xi = 1/2*(u_ + l_);
            ri = 1/2*(u_ - l_);

            if plotting
                % Add a slack variable to convert between equality and 
                % inequality constraints.
                uYi = struct('c',yic,'G',Gyi,'r',yid, ...
                    'A',pagemtimes(A,Gyi),'b',b_);
                % Store constraint zonotope.
                res.uYs{end+1} = aux_2ConZonoWithEqConst(uYi,apprErr);

                % Store input constraint zonotope.
                res.uXs{end+1} = aux_2ConZonoWithEqConst(uXi,0);
            end
        otherwise
            % Invalid option.
            throw(CORAerror('CORA:wrongFieldValue', ...
                'options.nn.split_refinement_method', ...
                    {'naive','zonotack'}));
    end

    % Order remaining sets by their criticality.
    [~,idx] = sort(checkSpecsVal,'descend');
    % Split input sets.
    xi = xi(:,idx);
    ri = ri(:,idx);
    sens = sens(:,idx);

    % We simply split the input sets.
    xis = xi;
    ris = ri;
    for i=1:nDims
        [xis,ris,sens] = aux_split(xis,ris,sens,nSplits);
    end
    
    % Compute bounds of the new input sets.
    lis = xis - ris;
    uis = xis + ris;

    % % Find containing intervals.
    tol = 1e-6;
    if nNeur > 0
        % The splitted sets can be contained if neuron-splitting is
        % enabled.
        isContained = reshape(any( ...
            all(permute(lis - tol,[1 3 2]) <= lis,1) ... lower bounds are larger
            & all(uis <= permute(uis + tol,[1 3 2]),1) ... upper bounds are smaller
            & 1:size(xis,2) > permute(1:size(xis,2),[1 3 2]) ... not the same
                ,3),[1 size(xis,2)]);
        % Remove contained intervals.
        xis(:,isContained) = [];
        ris(:,isContained) = [];
    else
        % No sets are contained.
        isContained = [];
    end

    % Add new splits to the queue.
    xs = [xis xs];
    rs = [ris rs];

    if plotting
        % Delete previously contained sets; delete each stored handle.
        if exist('hx_','var')
            cellfun(@(hxi_) delete(hxi_), hx_);
        end
        % Plot the unsafe output sets and the new input sets.
        [fig,huy,huy_,hux,hx,hx_] = ...
            aux_plotUnsafeOutputAndNewInputSets(fig,plotDims,res, ...
                lis,uis,isContained,nSplits^nDims);
        drawnow;
    end

    % To save memory, we clear all variables that are no longer used.
    batchVars = {'xi','ri','Gxi','yi','Gyi','ld_yi','ld_Gyi','ld_ri'};
    clear(batchVars{:});
     
    % Increment iteration counter.
    iter = iter + 1;

    if iter > options.nn.max_verif_iter
        break;
    end
end

if size(xs,2) == 0 && ~strcmp(res.str,'COUNTEREXAMPLE')
    % Verified all patches.
    res.str = 'VERIFIED';
    x_ = [];
    y_ = [];
end

% Store time.
res.time = toc;

if verbose
    % Compute final stats.
    queueLen = size(xs,2);
    if ~isempty(rs)
        avgRad = mean(rs,'all');
        unknVol = sum(prod(2*rs,1),'all');
    else
        avgRad = 0;
        unknVol = 0;
    end
    % Print new table row.
    table.printContentRow({iter,queueLen,numVerified,avgRad,unknVol});
    % Print table footer.
    table.printFooter();
    % Print the result.
    fprintf('--- Result: %s (time: %.3f [s])\n',res.str,res.time);
end

end


% Auxiliary functions -----------------------------------------------------

function [cxi,Gxi,dimIdx] = aux_constructInputZonotope(xi,ri,sens, ...
    batchG,numInitGens,options)
    % aux_constructInputZonotope - build a batch of input zonotopes from
    % boxes with centers xi and radii ri (both [n0 x batch]).
    %
    % The first numInitGens generators of each zonotope are axis-aligned
    % and carry the radius of the selected input dimensions. The radius of
    % the remaining dimensions goes into an interval center when
    % options.nn.interval_center is set; otherwise the center is xi.
    %
    % Outputs:
    %    cxi - center: [n0 x batch], or [n0 x 2 x batch] (lower/upper)
    %          for an interval center
    %    Gxi - generators, a slice of the preallocated batchG
    %    dimIdx - [numInitGens x batch] input dimension of each generator
    [nIn,nBatch] = size(xi);

    % Take the (zero) generator storage for the current batch.
    Gxi = batchG(:,:,1:nBatch);

    if numInitGens < nIn
        % Not every dimension gets a generator: prefer the dimensions with
        % the largest sensitivity-weighted radius.
        [~,order] = sort(sens.*ri,'descend');
        dimIdx = order(1:numInitGens,:);
    else
        % One generator per input dimension.
        dimIdx = repmat((1:nIn)',1,nBatch);
    end

    % Subscripts (dimension, generator, batch entry) of the non-zeros.
    genSub = repmat((1:numInitGens)',1,nBatch);
    batchSub = repelem(1:nBatch,numInitGens,1);
    Gxi(sub2ind(size(Gxi),dimIdx,genSub,batchSub)) = ...
        ri(sub2ind(size(ri),dimIdx,batchSub));

    % Radius not covered by any generator.
    restRad = (ri - reshape(sum(Gxi,2),[nIn nBatch]));

    if ~options.nn.interval_center
        % Plain vector center.
        cxi = xi;
    else
        % Interval center [xi - restRad, xi + restRad] as [n0 x 2 x batch].
        cxi = permute(cat(3,xi - restRad,xi + restRad),[1 3 2]);
    end
end

function [xi,ri,xs,rs] = aux_pop(xs,rs,bSz)
    % aux_pop - take up to bSz leading columns from the queue (xs, rs) and
    % return them as (xi, ri) together with the shortened queue.
    nTake = min(bSz,size(xs,2));
    % Popped centers and radii.
    xi = xs(:,1:nTake);
    ri = rs(:,1:nTake);
    % Remaining queue.
    xs = xs(:,nTake+1:end);
    rs = rs(:,nTake+1:end);
end

function [xis,ris,sens] = aux_split(xi,ri,sens,nSplits)
    % aux_split - cut every box of a batch into nSplits equal pieces
    % along a single input dimension.
    %
    % For each batch column, the dimension with the largest |sens.*ri| is
    % split. The result has nSplits*batch columns, ordered batch entry
    % first, then piece index. sens is replicated in the same order.
    [n,bSz] = size(xi);

    % Dimension to split for each batch column.
    [~,order] = sort(abs(sens.*ri),1,'descend');
    splitDim = order(1,:);
    batchLin = sub2ind([n bSz],splitDim,1:bSz);

    % Move the center to the lower bound and shrink the radius of the
    % selected dimension.
    lowerBnd = xi;
    lowerBnd(batchLin) = lowerBnd(batchLin) - ri(batchLin);
    radRed = ri;
    radRed(batchLin) = radRed(batchLin)/nSplits;

    % One page per piece.
    xis = repmat(lowerBnd,1,1,nSplits);
    ris = repmat(radRed,1,1,nSplits);

    % Shift the k-th piece by (2k-1) reduced radii from the lower bound.
    pieceIdx = repmat(1:nSplits,1,bSz);
    entryIdx = repelem(1:bSz,nSplits);
    pageLin = sub2ind([n bSz nSplits], ...
        repelem(splitDim,nSplits),entryIdx,pieceIdx);
    xis(pageLin) = xis(pageLin) + (2*pieceIdx - 1).*ris(pageLin);

    % Collapse pages into columns.
    xis = xis(:,:);
    ris = ris(:,:);
    sens = repmat(sens,1,nSplits);
end

function cZeq = aux_2ConZonoWithEqConst(cZineq,apprErr)
    % aux_2ConZonoWithEqConst - convert a batch of constraint zonotopes
    % with inequality constraints A*beta <= b into equality form.
    %
    % Each inequality gets a slack factor sigma in [-1,1] with scale
    % s = (sum(|A|,2) + b)/2:
    %    A*beta + s*sigma = b - s
    % A nonzero interval radius r is appended as extra axis-aligned
    % generators that do not appear in the constraints. All fields are
    % returned as host double arrays.
    toHost = @(v) double(gather(v));
    cntr = toHost(cZineq.c);
    gens = toHost(cZineq.G);
    rad = toHost(cZineq.r);
    conA = toHost(cZineq.A);
    conB = toHost(cZineq.b);

    % Dimensions, number of generators, batch size, number of constraints.
    [nDim,~,nBatch] = size(gens);
    nCon = size(conA,1);

    cZeq.c = cntr;
    if ~all(rad == 0,'all')
        % Radius -> diagonal generators without constraint coefficients.
        gens = cat(2,gens,rad.*eye(nDim));
        conA = cat(2,conA,zeros([nCon nDim nBatch]));
    end
    % Slack factors do not change the set, only the constraints.
    cZeq.G = cat(2,gens,zeros([nDim nCon nBatch]));
    slack = 0.5*(sum(abs(conA),2) + permute(conB,[1 3 2]));
    cZeq.A = cat(2,conA,eye(nCon).*slack);
    cZeq.b = conB - slack(:,:);
    cZeq.apprErr = toHost(apprErr);
end

function [l,u] = aux_boundsOfConZonotope(cZs,safeSet,exactBounds)
    % aux_boundsOfConZonotope - box enclosure of a batch of constraint
    % zonotopes with inequality constraints:
    %    {c + r.*delta + G*beta | beta in [-1,1]^q, delta in [-1,1]^n,
    %     A*beta <= b}
    %
    % Inputs:
    %    cZs - struct with fields c [n x bSz], G [n x q x bSz], 
    %          r [n x 1 x bSz], A [p x q x bSz], b [p x bSz]
    %    safeSet - if true, the set is the union of the complements of the
    %          individual constraints (each constraint is negated and
    %          handled as a separate set; the resulting boxes are unified)
    %    exactBounds - if true, compute the bounds with linear programs
    %          (slow, for validation); otherwise use iterative bound 
    %          tightening of the factors
    %
    % Outputs:
    %    l, u - [n x bSz] lower and upper bounds; columns of empty sets are
    %          NaN

    % Extract parameters of the constraint zonotope.
    c = cZs.c;
    G = cZs.G;
    r = cZs.r;
    A = cZs.A;
    b = cZs.b;

    % Obtain number of dimensions.
    [n,bSz_] = size(c);
    % Obtain number of generators.
    [~,q,~] = size(G);
    % Obtain number of constraints.
    [p_,~,~] = size(A);

    if safeSet
        % The unsafe set is the union of all constraints. Thus, we 
        % have to create a new set for each constraint.
        % Update number of constraints and batch size.
        p = 1;
        bSz = bSz_*p_;
        % Move constraints into the batch.
        A = reshape(permute(-A,[4 2 3 1]),[1 q bSz]);
        b = reshape(permute(-b,[3 2 1]),[1 bSz]);
        % Duplicate sets.
        c = repmat(c,1,p_);
        r = repmat(r,1,1,p_);
        G = repmat(G,1,1,p_);
    else
        % The number of constraints and batch size are not changed.
        p = p_;
        bSz = bSz_;
    end

    if ~exactBounds
        % Efficient approximation by isolating the i-th variable. ---------
        % We compute a box-approximation of the valid factor for the 
        % constraint zonotope, 
        % i.e., [\underline{\beta},\overline{\beta}] 
        %   \supseteq \{\beta \in [-1,1]^q \mid A\,\beta\leq b\}.
        % We view each constraint separately and use the tightest bounds.
        % For each constraint A_{(i,\cdot)}\,\beta\leq b_{(i)}, we isolate 
        % each factor \beta_{(j)} and extract bounds:
        % A_{(i,\cdot)}\,\beta\leq b_{(i)} 
        %   \implies A_{(i,j)}\,\beta_{(j)} \leq 
        %       b_{(i)} - \sum_{k=1,...,q, k\neq j} A_{(i,k)}\,\beta_{(k)}
        % Based on the sign of A_{(i,j)} we can either tighten the lower or
        % upper bound of \beta_{(j)}.
    
        % Initialize bounds of the factors.
        bl = -ones([q bSz],'like',c);
        bu = ones([q bSz],'like',c);
    
        % Permute the dimension of the constraints for easier handling.
        A_ = permute(A,[2 1 3]);
        b_ = permute(b,[3 1 2]);
        % Reshape factor bounds for easier multiplication.
        bl_ = permute(bl,[1 3 2]);
        bu_ = permute(bu,[1 3 2]);
        % Extract a mask for the sign of the coefficient of the i-th 
        % variable in the j-th constraint.
        nMsk = (A_ < 0);
        pMsk = (A_ > 0);
        % Decompose the matrix into positive and negative entries.
        An = A_.*nMsk;
        Ap = A_.*pMsk;
        % Do summation with matrix multiplication: sum all but the i-th entry.
        sM = (1 - eye(q,'like',c));
    
        tighterBnds = 1;
        while tighterBnds
            % Scale the matrix entries with the current bounds.
            ABnd = Ap.*bl_ + An.*bu_;
            % Isolate the i-th variable of the j-th constraint.
            sABnd = pagemtimes(sM,ABnd);
            % Compute right-hand side of the inequalities.
            rh = min(max((b_ - sABnd)./A_,bl_),bu_);
            % Update the bounds.
            bl_ = max(nMsk.*rh + (~nMsk).*bl_,[],2);
            bu_ = min(pMsk.*rh + (~pMsk).*bu_,[],2);
            % Check if the bounds could be tightened.
            tighterBnds = any((bl < bl_(:,:) | bu_(:,:) < bu) ...
                & bl_(:,:) <= bu_(:,:),'all');
            bl = bl_(:,:);
            bu = bu_(:,:);
        end
    
        % Map bounds of the factors to bounds of the constraint zonotope. 
        % We use interval arithmetic for that.
        % First, split the positive and the negative generators entries.
        Gneg = G.*(G < 0);
        Gpos = G.*(G > 0);
        % Map bounds of the factors to bounds of the constraint zonotope.
        l = c - r(:,:) + reshape(pagemtimes(Gpos,bl_) ...
            + pagemtimes(Gneg,bu_),[n bSz]);
        u = c + r(:,:) + reshape(pagemtimes(Gpos,bu_) ...
            + pagemtimes(Gneg,bl_),[n bSz]);
    
        % Identify empty sets.
        emptyIdx = any(bl_ > bu_,1);
        l(:,emptyIdx) = NaN;
        u(:,emptyIdx) = NaN;

    else % ----------------------------------------------------------------

        % Slow implementation with exact bounds for validation.
        
        % Initialize result.
        l = NaN(n,bSz);
        u = NaN(n,bSz);
        % Radius of the interval center as [n x bSz].
        rMat = r(:,:);

        for i=1:bSz
            % Obtain parameters of the i-th batch entry.
            ci = double(gather(c(:,i)));
            ri = double(gather(rMat(:,i)));
            Gi = double(gather(G(:,:,i)));
            Ai = double(gather(A(:,:,i)));
            bi = double(gather(b(:,i)));
            % Constraint-free enclosure; used as a sound fallback if the
            % solver fails for a reason other than infeasibility.
            boxRad = ri + sum(abs(Gi),2);
            % Loop over the dimensions.
            for j=1:n
                % Construct linear program.
                prob = struct('Aineq',Ai,'bineq',bi, ...
                    'lb',-ones(q,1),'ub',ones(q,1));
                % Find the lower bound for the j-th dimension.
                prob.f = Gi(j,:)';
                [bl,~,efl] = CORAlinprog(prob);
                if aux_isInfeasible(efl)
                    % The constraints cannot be satisfied: empty set.
                    l(:,i) = NaN;
                    u(:,i) = NaN;
                    break;
                elseif efl > 0
                    % Solution found; the interval radius is not part of
                    % the LP and has to be added (as in the approximation).
                    l(j,i) = ci(j) - ri(j) + Gi(j,:)*bl;
                else
                    % Solver failure: do not drop the set.
                    l(j,i) = ci(j) - boxRad(j);
                end
                % Find the upper bound for the j-th dimension.
                prob.f = -Gi(j,:)';
                [bu,~,efu] = CORAlinprog(prob);
                if aux_isInfeasible(efu)
                    % The constraints cannot be satisfied: empty set.
                    l(:,i) = NaN;
                    u(:,i) = NaN;
                    break;
                elseif efu > 0
                    % Solution found.
                    u(j,i) = ci(j) + ri(j) + Gi(j,:)*bu;
                else
                    % Solver failure: do not drop the set.
                    u(j,i) = ci(j) + boxRad(j);
                end
            end
        end

        % % Update the constraint zonotope, because of safeSet the
        % % constraints might have been transformed.
        % cZs = struct('c',c,'G',G,'r',r,'A',A,'b',b);
        % % Convert the inequality constraints to equality constraints.
        % cZseq = aux_2ConZonoWithEqConst(cZs,0);
        % % Extract the parameters.
        % c = double(gather(cZseq.c));
        % G = double(gather(cZseq.G));
        % A = double(gather(cZseq.A));
        % b = double(gather(cZseq.b));
        % % Loop over the batch and compute the bounds of each constraint
        % % zonotope.
        % for i=1:bSz
        %     % Obtain parameters of the i-th batch entry.
        %     ci = c(:,i);
        %     Gi = G(:,:,i);
        %     Ai = A(:,:,i);
        %     bi = b(:,i);
        %     % Instantiate constraint zonotope and add tolerance; compute
        %     % the bounds.
        %     cZiBnds = interval(conZonotope(ci,Gi,Ai,bi));
        %     % Check if the constrain zonotope is empty.
        %     if representsa(cZiBnds,'emptySet')
        %         l(:,i) = NaN;
        %         u(:,i) = NaN;
        %     else
        %         % Extract the computed bounds.
        %         l(:,i) = cZiBnds.inf;
        %         u(:,i) = cZiBnds.sup;
        %     end
        % end
    end

    % ---------------------------------------------------------------------
        
    if safeSet
        % Unify sets if a safe set is specified (min/max ignore the NaN
        % entries of empty sets).
        l = min(reshape(l,[n bSz_ p_]),[],3);
        u = max(reshape(u,[n bSz_ p_]),[],3);
    end

    function res = aux_isInfeasible(exitflag)
        % Exit flags signaling infeasibility (linprog convention: -2 no
        % feasible point, -5 primal and dual infeasible).
        res = (exitflag == -2) || (exitflag == -5);
    end
end

function [Anri,bnri,apprErri,nrIdx] = ...
    aux_splitNeurons(nn,numInitGens,unknown,numSplits)
% Select, per batch entry, the numSplits hidden neurons with the largest
% splitting heuristic, and return split constraints for them.
% Assume: input was propagated and stored (layeri.backprop.store.inc and
% layeri.backprop.store.inG are available for every layer).
% Output: 
% - Anri [numSplits_ x numInitGens x bSz], bnri [numSplits_ x bSz]: 
%   individual constraints for neuron splits, 
%   e.g. A(i,:)*beta <= b(i) and -A(i,:)*beta >= -b(i)
% - apprErri [numSplits_ x bSz]: approximation error of the selected 
%   neurons
% - nrIdx: layer and dimension indices of the neuron splits

% Compute batch size.
bSz = nnz(unknown);

% Initialize constraints.
Anri = [];
bnri = [];
apprErri = [];
% Initial heuristics.
h = [];
% Initialize indices of neuron split.
nrIdx = struct('layerIdx',[],'dimIdx',[]);

% Iterate through the layers and find max heuristics and propagate
% constrains.
for i=1:length(nn.layers)
    % Obtain i-th layer.
    layeri = nn.layers{i};
    % if ~isa(layeri,'nnActivationLayer')
    %     continue;
    % end
    % Obtain the i-th input.
    ci = layeri.backprop.store.inc(:,unknown);
    Gi = layeri.backprop.store.inG(:,:,unknown);
    % Obtain number of hidden neurons.
    [nk,~] = size(ci);
    % Compute splitting heuristic.
    r = sum(abs(Gi),2);
    apprErr = sum(abs(Gi(:,(numInitGens+1):end,:)),2);
    % Alternative heuristics, e.g., {abs(r.*sens),apprErr.*r, apprErr}
    % with sens = permute(sum(max(abs(layeri.sensitivity(:,:,unknown)),
    % 1e-3),1),[2 1 3]).
    % The heuristic is the distance of the split point (center) to the 
    % nearest bound for neurons whose range contains 0.
    ci_ = permute(ci,[1 3 2]);
    hi = (abs(ci_) < r).*min(-ci_ + r,ci_ + r);
    % Append new candidate constraints (split at the center).
    Anri = cat(2,Anri,permute(Gi(:,1:numInitGens,:),[2 1 3]));
    bnri = [bnri; zeros([nk bSz],'like',Gi)];
    apprErri = [apprErri; apprErr(:,:)];
    nrIdx.layerIdx = [nrIdx.layerIdx; repelem(i,nk,bSz)];
    nrIdx.dimIdx = [nrIdx.dimIdx; repmat((1:nk)',1,bSz)];
    % Append heuristic and sort each batch entry (column) separately.
    [h,idx] = sort([h; hi(:,:)],1,'descend');
    % Number of candidates per batch entry.
    numCand = size(h,1);
    % Only keep the constraints for the top neurons.
    numSplits_ = min(numSplits,numCand);
    h = h(1:numSplits_,:);
    % sort returns row indices per column; convert them to linear indices
    % so that each batch entry selects its own candidates.
    linIdx = idx(1:numSplits_,:) + (0:bSz-1)*numCand;
    % Extract constraints.
    Anri = reshape(Anri(:,linIdx),[numInitGens numSplits_ bSz]);
    bnri = reshape(bnri(linIdx),[numSplits_ bSz]);
    apprErri = reshape(apprErri(linIdx),[numSplits_ bSz]);
    % Update indices.
    nrIdx.layerIdx = reshape(nrIdx.layerIdx(linIdx),[numSplits_ bSz]);
    nrIdx.dimIdx = reshape(nrIdx.dimIdx(linIdx),[numSplits_ bSz]);
end
% Transpose constraint matrix.
Anri = permute(Anri,[2 1 3]);

end

function [At,bt] = ...
    aux_reluTightenConstraints(nn,numInitGens,unknown,numConstr)
% Compute tightening constraints for the ReLU neurons with the largest
% heuristic value r - |c| (radius minus absolute center of the neuron
% input); positive values mark neurons whose input range contains 0.
% Assume: input was propagated and stored, i.e., every layer holds its
% input center (backprop.store.inc) and generators (backprop.store.inG).
%
% Inputs:
%    nn - neural network
%    numInitGens - number of generators stemming from the initial set; the
%       remaining generators are bounded as approximation errors
%    unknown - selects the batch entries to consider (used as column/page
%       index; presumably a logical mask since bSz = nnz(unknown))
%    numConstr - maximum number of neurons per batch entry to constrain
%
% Outputs:
%    At, bt - constraints At*\beta <= bt over the factors \beta of the
%       initial generators; At is [2*k numInitGens bSz] and bt is
%       [2*k bSz] with k <= numConstr (both empty if no ReLU layer is
%       followed by another layer). The first k rows encode ReLU(x) >= 0,
%       the last k rows encode ReLU(x) >= x.

% Compute batch size.
bSz = nnz(unknown);
% Number of layers.
numLayers = length(nn.layers);

% Initialize constraints.
% (i) ReLU(x) >= 0
At0 = [];
bt0 = [];
% (ii) ReLU(x) >= x
At1 = [];
bt1 = [];
% Heuristic values of the currently selected neurons.
h = [];
% Find indices of ReLU layers. The constraints use the input of the next
% layer as the ReLU output; hence, a ReLU in the last layer cannot
% contribute and is skipped (it would only desynchronize h from At0/bt0).
idxLayer = 1:(numLayers-1);
idxLayer = idxLayer( ...
    arrayfun(@(i) isa(nn.layers{i},'nnReLULayer'),idxLayer));

% Iterate through the layers and find maximal unstable neurons.
for i=idxLayer
    % Obtain the i-th input.
    layeri = nn.layers{i};
    ci = layeri.backprop.store.inc(:,unknown);
    Gi = layeri.backprop.store.inG(:,:,unknown);
    % Obtain the output of the relu enclosure; use the input of the next
    % layer.
    layerip1 = nn.layers{i+1};
    co = layerip1.backprop.store.inc(:,unknown);
    Go = layerip1.backprop.store.inG(:,:,unknown);

    % Compute splitting heuristic.
    r = sum(abs(Gi),2);
    hi = r - permute(abs(ci),[1 3 2]);
    % Append heuristic and sort (per batch entry, i.e., per column).
    [h,idx] = sort([h; hi(:,:)],1,'descend');
    % Number of candidate neurons per batch entry.
    numCand = size(idx,1);
    % Only keep the constraints for the top neurons.
    numConstr_ = min(numConstr,numCand);
    h = h(1:numConstr_,:);

    % Append new constraints.
    % ReLU(x) >= 0 <--> -Go*\beta <= co + apprErr
    At0 = cat(2,At0,permute(-Go(:,1:numInitGens,:),[2 1 3]));
    apprErro = sum(abs(Go(:,(numInitGens+1):end,:)),2);
    bt0 = [bt0; co + apprErro(:,:)];
    % ReLU(x) >= x <--> (Gi-Go)*\beta <= co - ci + apprErr'
    % Compute difference of generator matrices.
    Gd = Gi - Go;
    At1 = cat(2,At1,permute(Gd(:,1:numInitGens,:),[2 1 3]));
    apprErrd = sum(abs(Gd(:,(numInitGens+1):end,:)),2);
    bt1 = [bt1; co - ci + apprErrd(:,:)];

    % Obtain the indices for the relevant constraints. sort returns row
    % indices within each column; At0(:,cIdx) and bt0(cIdx) index with
    % linear indices over all batch entries, so the offset of each batch
    % entry has to be added (otherwise every batch entry would receive the
    % constraints of the first one).
    cIdx = idx(1:numConstr_,:) + (0:bSz-1)*numCand;
    % Select the relevant constraints.
    At0 = reshape(At0(:,cIdx),[numInitGens numConstr_ bSz]);
    bt0 = reshape(bt0(cIdx),[numConstr_ bSz]);
    At1 = reshape(At1(:,cIdx),[numInitGens numConstr_ bSz]);
    bt1 = reshape(bt1(cIdx),[numConstr_ bSz]);
end
% Transpose constraint matrix.
At = permute(cat(2,At0,At1),[2 1 3]);
bt = [bt0; bt1];

end

function [fig,hx0,hspec] = aux_initPlot(fig,plotDims,xs,ys,x0,r0,A,b,safeSet)
    % aux_initPlot - draw the initial input box (left panel) and the output
    % specification (right panel).
    %
    % Inputs:
    %    fig - figure handle (returned unchanged)
    %    plotDims - 2x2 plot dimensions (row 1: input, row 2: output)
    %    xs, ys - not used (point plotting is disabled)
    %    x0, r0 - center and radius of the initial input set
    %    A, b - specification A*y + b <= 0
    %    safeSet - true if the specification describes a safe set
    %
    % Outputs:
    %    fig - figure handle
    %    hx0 - handle of the initial input set
    %    hspec - handle of the specification

    % --- Left panel: input space ---
    subplot(1,2,1); hold on;
    title('Input Space')
    inputBox = interval(x0 - r0,x0 + r0);
    hx0 = plot(inputBox,plotDims(1,:), ...
        'DisplayName','Input Set', ...
        'EdgeColor',CORAcolor('CORA:simulations'),'LineWidth',2);

    % --- Right panel: output space ---
    subplot(1,2,2); hold on;
    title('Output Space')
    % The specification A*y + b <= 0 is the halfspace set A*y <= -b.
    specSet = polytope(A,-b);
    if ~safeSet
        kindStr = 'unsafe';
    else
        kindStr = 'safe';
    end
    specColor = CORAcolor(sprintf('CORA:%s',kindStr));
    hspec = plot(specSet,plotDims(2,:),...
        'DisplayName',sprintf('Specification (%s)',kindStr), ...
        'FaceColor',specColor,'FaceAlpha',0.2, ...
        'EdgeColor',specColor,'LineWidth',2);
end

function [fig,hxi,hx,hxv,hy,hyv] = aux_plotInputAndOutputSets(fig, ...
    plotDims,x0,r0,res)
    % aux_plotInputAndOutputSets - plot the initial input set and the latest
    % batch of input/output zonotopes stored in res; batch entries flagged
    % as verified are drawn in the safe color.
    %
    % Inputs:
    %    fig - figure handle (returned unchanged)
    %    plotDims - 2x2 plot dimensions (row 1: input, row 2: output)
    %    x0, r0 - center and radius of the initial input set
    %    res - struct with cell arrays Xs and Ys; their last entries hold
    %       centers c (one column per set) and generators G; Xs{end} also
    %       holds the flags verified
    %
    % Outputs:
    %    fig - figure handle
    %    hxi - handle of the initial input set
    %    hx, hxv - handles of the unverified/verified input sets
    %    hy, hyv - handles of the unverified/verified output sets

    % Latest batch of input and output sets.
    Xs = res.Xs{end};
    Ys = res.Ys{end};
    % Small intervals to avoid plotting errors. Input and output space can
    % have different dimensions, hence each space gets its own interval
    % (an input-sized interval cannot be added to an output set).
    nX = size(x0,1);
    pIx = 1e-8*interval(-ones(nX,1),ones(nX,1));
    nY = size(Ys.c,1);
    pIy = 1e-8*interval(-ones(nY,1),ones(nY,1));

    % Plot the input sets.
    subplot(1,2,1); hold on;
    % Plot the initial input set.
    hxi = plot(interval(x0 - r0,x0 + r0),plotDims(1,:), ...
        'DisplayName','Input Set', ...
        'EdgeColor',CORAcolor('CORA:simulations'),'LineWidth',2);
    % Store plot handles for potential deletion.
    hx = {};
    hxv = {};
    for j=1:size(Xs.c,2)
        Xij = zonotope(Xs.c(:,j),Xs.G(:,:,j)) + pIx;
        if Xs.verified(j)
            hxv{end+1} = plot(Xij,plotDims(1,:), ...
                'DisplayName','Input Set (verified)', ...
                'FaceColor',CORAcolor('CORA:safe'),'FaceAlpha',0.5, ...
                'EdgeColor',CORAcolor('CORA:safe'),'LineWidth',2);
        else
            hx{end+1} = plot(Xij,plotDims(1,:), ...
                'DisplayName','Input Set', ...
                ... 'FaceColor',CORAcolor('CORA:reachSet'),'FaceAlpha',0.2, ...
                'EdgeColor',CORAcolor('CORA:reachSet'),'LineWidth',2);
        end
    end
    % Plot the output sets.
    subplot(1,2,2); hold on;
    % Store plot handles for potential deletion.
    hy = {};
    hyv = {};
    for j=1:size(Ys.c,2)
        Yij = zonotope(Ys.c(:,j),Ys.G(:,:,j)) + pIy;
        % The verification status is taken from the corresponding input set.
        if Xs.verified(j)
            hyv{end+1} = plot(Yij,plotDims(2,:),'DisplayName','Output Set', ...
                ...'FaceColor',CORAcolor('CORA:reachSet'),'FaceAlpha',0.2, ...
                'EdgeColor',CORAcolor('CORA:safe'),'LineWidth',2);
        else
            hy{end+1} = plot(Yij,plotDims(2,:),'DisplayName','Output Set', ...
                ...'FaceColor',CORAcolor('CORA:reachSet'),'FaceAlpha',0.2, ...
                'EdgeColor',CORAcolor('CORA:reachSet'),'LineWidth',2);
        end
    end
end

function [fig,hxs_,hys_] = aux_plotCounterExampleCandidates(fig, ...
    plotDims,res)
    % aux_plotCounterExampleCandidates - mark the latest counterexample
    % candidates as red circles: inputs res.xs_{end} in the left panel,
    % outputs res.ys_{end} in the right panel.
    %
    % Outputs:
    %    fig - figure handle (unchanged)
    %    hxs_, hys_ - handles of the plotted input/output points
    candLabel = 'Counterexample Candidate';
    markerSpec = 'or';

    % Left panel: candidate inputs.
    subplot(1,2,1); hold on;
    hxs_ = plotPoints(res.xs_{end},plotDims(1,:),markerSpec, ...
        'DisplayName',candLabel);

    % Right panel: corresponding outputs.
    subplot(1,2,2); hold on;
    hys_ = plotPoints(res.ys_{end},plotDims(2,:),markerSpec, ...
        'DisplayName',candLabel);
end

function [fig,huy,huy_,hux,hx,hx_] = ...
    aux_plotUnsafeOutputAndNewInputSets(fig,plotDims,res,lis,uis, ...
        isContained,splitsPerUnsafeSet)
    % aux_plotUnsafeOutputAndNewInputSets - plot the unsafe output/input
    % constrained zonotopes (if stored in res) and the new input boxes
    % [lis(:,j), uis(:,j)].
    %
    % Inputs:
    %    fig - figure handle (returned unchanged)
    %    plotDims - 2x2 plot dimensions (row 1: input, row 2: output)
    %    res - struct; optional fields uYs and uXs are cell arrays whose last
    %       entries hold c, G, A, b of constrained zonotopes
    %    lis, uis - lower/upper bounds of the new input sets (one per column)
    %    isContained - per new input set; sets flagged as contained are drawn
    %       dashed (empty: all drawn solid)
    %    splitsPerUnsafeSet - number of new input sets per unsafe input set;
    %       unsafe input set j_ is drawn along with the first of its boxes
    %
    % Outputs:
    %    fig - figure handle
    %    huy - handles of the unsafe output sets
    %    huy_ - always empty (plotting with approximation error disabled)
    %    hux - handles of the unsafe input sets
    %    hx, hx_ - handles of the new input sets (not contained/contained)

    % Store plot handles for potential deletion.
    huy = {};
    huy_ = {};
    if isfield(res,'uYs')
        uYs = res.uYs{end};
        % Small interval in the output dimension to avoid plotting errors;
        % the output dimension can differ from the input dimension, so the
        % input-space interval must not be reused here.
        nY = size(uYs.c,1);
        pIy = 1e-8*interval(-ones(nY,1),ones(nY,1));
        % Plot unsafe output constraint zonotope.
        subplot(1,2,2); hold on;
        for j=1:size(uYs.c,2)
            % Plot with approximation error.
            % uYij_ = conZonotope( ...
            %     uYs.c(:,j),uYs.G(:,:,j),...
            %     uYs.A(:,:,j),uYs.b(:,j) ...
            %         + uYs.apprErr(:,j)) + pIy;
            % huy_{end+1} = plot(uYij_,plotDims(2,:),'--', ...
            %     'DisplayName','Output Set (unsafe, w. Approx. Err.)', ...
            %     'EdgeColor',CORAcolor('CORA:highlight1'),'LineWidth',1, ...
            %     'FaceColor',CORAcolor('CORA:reachSet'),'FaceAlpha',0.2 ...
            %     );
            % Plot without approximation error.
            uYij = conZonotope( ...
                uYs.c(:,j),uYs.G(:,:,j),...
                uYs.A(:,:,j),uYs.b(:,j)) + pIy;
            huy{end+1} = plot(uYij,plotDims(2,:), ...
                'DisplayName','Output Set (unsafe)', ...
                'FaceColor',CORAcolor('CORA:highlight1'),'FaceAlpha',0.2, ...
                'EdgeColor',CORAcolor('CORA:highlight1'),'LineWidth',2);
        end
    end
    % Small interval in the input dimension to avoid plotting errors.
    nX = size(lis,1);
    pIx = 1e-8*interval(-ones(nX,1),ones(nX,1));
    % Plot new input sets.
    subplot(1,2,1); hold on;
    % Store plot handles for potential deletion.
    hux = {};
    hx = {};
    hx_ = {};
    for j=1:size(lis,2)
        if isfield(res,'uXs') && mod(j-1,splitsPerUnsafeSet) == 0
            % Index of the unsafe input set the j-th new input set stems from.
            j_ = (j-1)/splitsPerUnsafeSet + 1;
            % Plot unsafe input constraint zonotope.
            uXij = conZonotope( ...
                res.uXs{end}.c(:,j_),res.uXs{end}.G(:,:,j_), ...
                res.uXs{end}.A(:,:,j_),res.uXs{end}.b(:,j_)) + pIx;
            hux{end+1} = plot(uXij,plotDims(1,:), ...
                'DisplayName','Input Set (unsafe)', ...
                'EdgeColor',CORAcolor('CORA:highlight1'),'LineWidth',1, ...
                'FaceColor',CORAcolor('CORA:reachSet'),'FaceAlpha',0.2 ...
                );
        end
        % Obtain new input set.
        Xij = interval(lis(:,j),uis(:,j)) + pIx;
        if isempty(isContained) || ~isContained(j)
            hx{end+1} = plot(Xij,plotDims(1,:), ...
                'DisplayName','Input Set', ...
                'EdgeColor',CORAcolor('CORA:simulations'),'LineWidth',2);
        else
            hx_{end+1} = plot(Xij,plotDims(1,:),'--', ...
                'DisplayName','Input Set', ...
                'EdgeColor',CORAcolor('CORA:simulations'),'LineWidth',1);
        end
    end
end

% ------------------------------ END OF CODE ------------------------------
