
function nn_neuron_merge_mnist_automatic_reduction_comparison
% NN_NEURON_MERGE_MNIST_AUTOMATIC_REDUCTION_COMPARISON - draws a random
%    MNIST test image and verifies it under an L_inf perturbation of size
%    delta using the original network. If that succeeds, it computes
%    reduced networks at reduction rates 10%,20%,...,100%. For each rate
%    it plots the resulting output bounds and prints whether the reduced
%    output set is still verified.
%
% Requires on the path: mnist.mat (variables testX, testY) and the ONNX
% model mnist_sigmoid_6_200.onnx.

% from https://github.com/daniel-e/mnist_octave/raw/master/mnist.mat
load('mnist.mat','testX','testY');

% normalize pixel values to [0,1]
testX = double(testX) / 255;
testY = categorical(testY);

%% 

disp("Loading models ...")
modelfile = "mnist_sigmoid_6_200.onnx";
nn = neuralNetwork.readONNXNetwork(modelfile,false,"BCSS","BC");
% nn.setInputSize([28,28], true)

%%

% fixed seed for reproducibility (the seed actually used is printed; the
% previous version printed a random number that was then ignored)
seed = 769;
rng(seed);
fprintf("Seed: %d\n", seed);

acc = 0;   % number of correctly classified images
ver = 0;   % number of images verified with the original network

images = randi(length(testY),1,1);
delta = 0.01;

doVerify = true;
method = 'static';   % bucket type passed to computeReducedNetwork
numRates = 10;       % reduction rates k/numRates for k = 1..numRates

disp("---")
fprintf("N=%d, delta=%.4f\n", numel(images),delta)

X01 = interval(zeros(784,1),ones(784,1));
Xdelta = interval(-delta,delta);

for i=(images)
    fprintf("\nImage: %d \n", i);
    xi = testX(i,:)';
    ti = testY(i);

    % transpose the 28x28 pixel grid before flattening; presumably the
    % stored pixel order differs from the one the ONNX model expects
    % (TODO confirm). A Deep Learning Toolbox comparison, e.g.
    % predict(dltoolbox_net, dlarray(reshape(xi,1,28,28))), is disabled.
    xi = reshape(reshape(xi,28,28)',[],1);

    % compute CORA
    pred = nn.evaluate(xi);
    [~, y_pred] = max(pred);

    correct = (double(ti) == y_pred);
    acc = acc + correct;

    % verify
    if doVerify && correct

        % init input set with perturbation
        Xi = xi + Xdelta;  % add uncertainty to image
        Xi = Xi & X01;     % clip at [0,1] 
        Xi = zonotope(Xi); % convert to zonotope

        % original network
        tic;
        Yi = nn.evaluate(Xi);
        toc;

        res = aux_isVerified(Yi,ti);
        ver = ver + res;

        if res
            % reduced networks at increasing reduction rates
            figure; hold on;
            colorOrder = turbo(numRates);

            for k=1:numRates
                [~, YiRed] = nn.computeReducedNetwork(Xi, ...
                    "BucketType",method,'ReductionRate',k/numRates, ...
                    "Verbose",false,'InputCompression',false);

                % one horizontal segment per output dimension; only the
                % first segment of each rate appears in the legend
                Ii = interval(YiRed);
                vis = {};
                for d=1:10
                    plot([Ii.inf(d),Ii.sup(d)],[d,d],'Color',colorOrder(k,:), ...
                        'DisplayName',sprintf("%2.0f%%",k*100/numRates),vis{:})
                    vis = {'HandleVisibility','off'};
                end

                fprintf("Reduction Rate: %2.0f%% - Is Verified? %i\n",k*100/numRates,aux_isVerified(YiRed,ti))
            end

            % axis formatting is loop-invariant; apply once
            yticks(1:10);
            yticklabels(categorical(0:9))
            ylabel("Label")
            ylim([0,11])
            xlabel("Prediction bounds")
            legend('Location','eastoutside')
        end
    end
end

% summary over all drawn images
fprintf("\nAccuracy: %d/%d, verified (original network): %d/%d\n", ...
    acc, numel(images), ver, numel(images));

end

function res = aux_isVerified(Yi,ti)
% aux_isVerified - checks whether every point in the output set Yi is
%    assigned the label ti, i.e., y_j - y_t <= 0 for all j ~= t.
%
% Inputs:
%    Yi - output set of the network; must support left-multiplication by
%         a numeric matrix and conversion via interval(.) (e.g. a zonotope);
%         assumed 10-dimensional (one entry per MNIST digit)
%    ti - target label; double(ti) must give its 1-based output index
%
% Outputs:
%    res - true if sup(y_j - y_t) <= 0 holds for every j ~= t

    t = double(ti);

    % argmax trick: row j of C computes y_j - y_t
    C = eye(10);
    C(:, t) = C(:, t) - 1;
    % row t is identically zero and carries no information; drop it
    % (setting the whole column to -1 would make row t equal -y_t and
    % spuriously require y_t >= 0)
    C(t, :) = [];

    res = all(interval(C * Yi).sup <= 0);
end

