function [res,evResult] = gnn_eval_graph_level_rmedge(varargin)
% gnn_eval_graph_level_rmedge - evaluate graph-level predictions of a graph
%    neural network with uncertain graph structure, i.e., when a random
%    subset of edges (outside a spanning forest) may be removed
%
% Syntax:
%    [res,evResult] = gnn_eval_graph_level_rmedge()
%    [res,evResult] = gnn_eval_graph_level_rmedge(NUM_CHECKS,MODEL,DELTA,PERT_EDGES,DO_ENUMERATION)
%
% Inputs:
%    NUM_CHECKS - number of successful runs to collect (default: 2)
%    MODEL - folder containing model_export.json and data_export.json
%    DELTA - each input feature is perturbed by at most +-DELTA;
%       0 evaluates point inputs (default: 0.01)
%    PERT_EDGES - integer value: number of removable edges; non-integer
%       value: fraction of the edges, rounded up (default: 0.01)
%    DO_ENUMERATION - if true, additionally evaluate every graph obtained
%       by removing a subset of the k perturbed edges (2^k graphs)
%       (default: false)
%
% Outputs:
%    res - true iff NUM_CHECKS runs succeeded and in each of them the
%       point evaluation matched the reference output (tolerance tol)
%    evResult - struct with settings, seeds, per-run results, timings,
%       graph sizes and the caught errors
%
% Note: for 'Cora_' models each run evaluates a random subgraph of the
%    single data graph; otherwise the correctly classified graphs are
%    visited in random order, and idxdata/failedIdx index the rows of that
%    filtered (correctly classified) table.

% settings
tol = 1e-14;
[NUM_CHECKS,MODEL,DELTA,PERT_EDGES,DO_ENUMERATION] = setDefaultValues({ ...
    2, ...      % NUM_CHECKS
    './gnn-verification/experiments/models/Enzymes_34035', ... % MODEL
    0.01, ...   % DELTA
    0.01, ...   % PERT_EDGES (count if integer, fraction of edges otherwise)
    false, ...  % DO_ENUMERATION
},varargin);

runDate = datestr(datetime());
seed = randi(1000,1);
fprintf('Seed: %i\n', seed)
rng(seed)

% load model and data
fprintf('Reading network and data in: %s\n', MODEL)
nn = neuralNetwork.readGNNetwork([MODEL filesep 'model_export.json']);
data = neuralNetwork.readGNNdata([MODEL filesep 'data_export.json']);

% result storage; slot k holds the results of the k-th successful run
resvec = false(1,NUM_CHECKS);
resVerifiedSet = false(1,NUM_CHECKS);
resVerifiedEnum = false(1,NUM_CHECKS);
resViolated = false(1,NUM_CHECKS);

% counterSuccessful is also the slot the current run writes into
counterSuccessful = 1;
counterSkipped = 0;
counterFailed = 0;

timeSet = nan(1,NUM_CHECKS);
timeEnum = nan(1,NUM_CHECKS);

numNodes = nan(1,NUM_CHECKS);
numEdges = nan(1,NUM_CHECKS);

% for trouble-shooting
MEs = {};
failedIdx = [];

% sample NUM_CHECKS data points
if contains(MODEL,'Cora_')
    % only one graph; a random subgraph is drawn in every run
    MAX_RUNS = 20*NUM_CHECKS;
    idxdataperm = ones(1,MAX_RUNS);
else
    data = data(data{:,'output_label'} == data{:,'target_label'},:); % correctly classified
    idxdataperm = randperm(height(data));
    data = data(idxdataperm,:);
    MAX_RUNS = height(data);
end

% get run seeds
seeds = randi(1000,[1,MAX_RUNS]);

for i=1:MAX_RUNS
    fprintf('Successful: %i/%i. Skipped %i. Failed: %i; \t Running data %i/%i (%i) with seed=%i.. \n', ...
        counterSuccessful-1,NUM_CHECKS,counterSkipped,counterFailed,i,MAX_RUNS,idxdataperm(i),seeds(i))
    rng(seeds(i))

    try
        % read data -------------------------------------------------------

        [nn_red,G,x_vec,y_ref,target] = aux_constructGraph(MODEL, data, i, nn);

        numNodes(counterSuccessful) = G.numnodes;
        numEdges(counterSuccessful) = G.numedges;

        % reset neural network
        nn_red.resetGNN();
        options.nn = struct;

        % check reference output ------------------------------------------

        % propagate input through network
        options.nn.graph = G;
        y_pred = nn_red.evaluate(x_vec,options);

        % compare with reference output
        resvec(counterSuccessful) = compareMatrices(y_pred,y_ref,tol);

        % reset gnn
        nn_red.resetGNN();

        % check set output ------------------------------------------------

        % perturb edges: integer PERT_EDGES is a count, otherwise a fraction
        if floor(PERT_EDGES) == PERT_EDGES
            numPertEdges = PERT_EDGES;
        else
            numPertEdges = ceil(G.numedges * PERT_EDGES);
        end

        [G,idxPertEdges] = aux_perturb_graph_rmedge(G,numPertEdges);
        if numel(idxPertEdges) < numPertEdges
            % graph probably too small to perturb numPertEdges edges..
            counterSkipped = counterSkipped + 1;
            continue
        end

        % store in options
        options.nn.graph = G;
        options.nn.idx_pert_edges = idxPertEdges;

        % other nn options
        % options.nn.poly_method = 'singh';
        options.nn.num_generators = 10000;

        % make uncertain
        if DELTA == 0
            X = polyZonotope(x_vec);
        else
            X = compact(polyZonotope(x_vec,DELTA*diag(ones(size(x_vec)))));
        end

        % propagate set through network
        tic;
        Y = nn_red.evaluate(X,options);

        % check verification
        resVerifiedSet(counterSuccessful) = aux_isVerified(Y,target);
        timeSet(counterSuccessful) = toc;

        % check enumeration -----------------------------------------------

        if DO_ENUMERATION
            % init: one graph per subset of the perturbed edges
            numGraphs = 2^numPertEdges;
            resVerifiedEnum_i = false(1,numGraphs);
            resViolatedEnum_i = false(1,numGraphs);

            tic;

            for j=0:numGraphs-1
                % bit pattern of j selects which perturbed edges are removed
                idxrmedges = idxPertEdges(dec2bin(j,numPertEdges) == '1');
                G_pert = G.rmedge(idxrmedges);

                % reset parameter
                nn_red.resetGNN();
                options.nn.graph = G_pert;
                options.nn.idx_pert_edges = [];

                % evaluate
                Y = nn_red.evaluate(X,options);

                % check verification
                resVerifiedEnum_i(j+1) = aux_isVerified(Y,target);

                % check violation
                if DELTA == 0
                    % can only be computed for point-wise evaluation
                    resViolatedEnum_i(j+1) = ~resVerifiedEnum_i(j+1);
                end
            end

            resVerifiedEnum(counterSuccessful) = all(resVerifiedEnum_i);
            resViolated(counterSuccessful) = any(resViolatedEnum_i);
            timeEnum(counterSuccessful) = toc;

        end

        % reset gnn
        nn_red.resetGNN();

        % done ------------------------------------------------------------

        counterSuccessful = counterSuccessful + 1;

        % check if done
        if counterSuccessful > NUM_CHECKS
            break
        end

    catch ME
        counterFailed = counterFailed + 1;
        MEs{end+1} = ME;
        failedIdx(end+1) = idxdataperm(i);
        % inspect MEs/failedIdx after the run. 'CORA:outOfDomain' errors in
        % nnInvSqrtRootLayer/nnGCNLayer (nodes without edges) should no
        % longer occur since a spanning forest is never perturbed.
    end
end

% A skipped or failed run may already have written partial results into
% the slot of the next successful run. If the loop ended before a
% successful run overwrote that slot, discard the stale entries so they do
% not count as results.
idxUnused = counterSuccessful:NUM_CHECKS;
resvec(idxUnused) = false;
resVerifiedSet(idxUnused) = false;
resVerifiedEnum(idxUnused) = false;
resViolated(idxUnused) = false;
timeSet(idxUnused) = nan;
timeEnum(idxUnused) = nan;
numNodes(idxUnused) = nan;
numEdges(idxUnused) = nan;

disp(' ')
fprintf("Sanity check: %.2f\n", mean(resvec));
fprintf('Successful: %i/%i. Skipped %i. Failed: %i.\n',counterSuccessful-1,NUM_CHECKS,counterSkipped,counterFailed)
fprintf('Verified: %.2f - Violated: %.2f\n', mean(resVerifiedSet),mean(resViolated))

% gather results
evResult = struct;
evResult.NUM_CHECKS = NUM_CHECKS;
evResult.MODEL = MODEL;
evResult.DELTA = DELTA;
evResult.NUM_PERT_EDGES = PERT_EDGES;

evResult.date = runDate;
evResult.seed = seed;
evResult.seeds = seeds;
evResult.idxdata = idxdataperm;

evResult.counterSuccessful = counterSuccessful-1;
evResult.counterSkipped = counterSkipped;
evResult.counterFailed = counterFailed;
evResult.MEs = MEs;
evResult.failedIdx = failedIdx;

evResult.res = all(resvec);
evResult.resvec = resvec;
evResult.resVerifiedSet = resVerifiedSet;
evResult.resVerifiedEnum = resVerifiedEnum;
evResult.resViolated = resViolated;

evResult.timeSet = timeSet;
evResult.timeEnum = timeEnum;

evResult.numNodes = numNodes;
evResult.numEdges = numEdges;

res = all(resvec);

end


% Auxiliary functions -----------------------------------------------------

function G = aux_create_graph(numNodes, adj_list)
    % Build an undirected graph with numNodes nodes from a zero-based edge
    % list (row 1: source node ids, row 2: target node ids). Every node
    % gets a self loop and duplicate edges are collapsed.

    % shift zero-based node ids to MATLAB's one-based indexing
    srcNodes = adj_list(1,:) + 1;
    dstNodes = adj_list(2,:) + 1;
    G = graph(srcNodes, dstNodes);

    % append isolated nodes that never occur in the edge list
    G = addnode(G, numNodes - height(G.Nodes));

    % one self loop per node, then drop multi-edges but keep the loops
    loopNodes = 1:numNodes;
    G = simplify(addedge(G, loopNodes, loopNodes), "keepselfloops");
end

function res = aux_isVerified(Y,target_label)
    % Check whether the output set Y is classified as target_label.
    % The argmax trick requires y_k - y_target <= 0 for every output k.
    % All differences come from a single linear map, and their upper bounds
    % are taken from interval(). Ties (difference exactly 0) count as
    % verified.

    numOutputs = length(Y.c);

    % row k of D maps y to y_k - y_target (row target_label becomes zero)
    D = eye(numOutputs);
    D(:,target_label) = D(:,target_label) - 1;

    % check bounds: all upper bounds of the differences must be <= 0
    diffBounds = interval(D*Y);
    res = all(diffBounds.sup <= 0);
end

function [G,idxPertEdges] = aux_perturb_graph_rmedge(G,numPertEdges)
    % Randomly choose up to numPertEdges edges of G that may be removed.
    % Candidates are the edges that are neither in a spanning forest of G
    % nor self loops, so removing any subset of them keeps every node
    % connected within its component. G is returned unchanged;
    % idxPertEdges holds the chosen edge indices into G.

    % spanning forest, grown from a node of maximal degree
    nodeDegrees = degree(G);
    [~,rootNode] = max(nodeDegrees);
    forest = minspantree(G, ...
        'Root',rootNode, ...
        'Type','forest'); % in case graph is disconnected
    forestEdges = forest.Edges.EndNodes;

    % candidate graph: drop forest edges and self loops
    allNodes = 1:numnodes(G);
    candidates = rmedge(G,forestEdges(:,1),forestEdges(:,2));
    candidates = rmedge(candidates,allNodes,allNodes);

    % draw without replacement (fewer if not enough candidates exist)
    numCandidates = numedges(candidates);
    pick = randsample(numCandidates,min(numPertEdges,numCandidates));
    chosenEdges = candidates.Edges.EndNodes(pick,:);

    % translate back into edge indices of the unmodified graph
    idxPertEdges = findedge(G,chosenEdges(:,1),chosenEdges(:,2));
end

function [nn_red,G,x_vec,y_ref,target_label] = aux_constructGraph(model, data, i, nn)
    % Build the network, graph, flattened input, reference output and
    % one-based target label for evaluation run i.
    %
    % 'Cora_' models: data holds a single graph. A random correctly
    % classified node n0 whose connected component contains more than half
    % of all nodes is drawn. The subgraph of n0 and its neighbors (up to
    % numMPsteps+1 hops, via nearest) is returned, and nn is reduced to node
    % 1 of that subgraph. i is not used in this case.
    % Otherwise: row i of data is one complete graph and nn is used as is.

    if contains(model,'Cora_')
        % selects a subgraph from G

        Xorg = data.input{1};
        Yorg = data.output{1};
        target = data.target_label{1};
        G = aux_create_graph(size(Xorg,1),data{1,'edge_index'}{1});

        % find correctly classified nodes (stored labels are zero-based)
        [~,pred_label] = max(Yorg,[],2);
        correct_nodes = find(pred_label == target+1);

        % Make sure a valid candidate exists; otherwise the sampling loop
        % below would never terminate. This check draws no random numbers.
        [compIdx,compSize] = conncomp(G);
        if ~any(compSize(compIdx(correct_nodes)) > 0.5*G.numnodes)
            error('No correctly classified node lies in the main connected component.');
        end

        numMPsteps = nn.getNumMessagePassingSteps();

        % randomly select a correctly predicted node within the main subgraph
        while true
            % randsample(v,1) samples from 1:v when v is a scalar, so a
            % single candidate has to be handled separately
            if isscalar(correct_nodes)
                n0 = correct_nodes;
            else
                n0 = randsample(correct_nodes,1);
            end
            neighbors = G.bfsearch(n0);
            if numel(neighbors) > 0.5*G.numnodes
                % within main subgraph

                % find khop neighbors
                khopNeighbors = G.nearest(n0,numMPsteps+1);
                subNodes = [n0;khopNeighbors];

                % create subgraph with n0 in first position
                G = G.subgraph(subNodes);
                x_vec = reshape(Xorg(subNodes,:),[],1);
                y_ref = Yorg(n0,:)';
                target_label = target(n0) +1;

                % reduce network
                nn_red = nn.reduceGNNForNode(1,G);
                break
            end

            % better luck next time..
        end

    else
        % use entire graph
        nn_red = nn;
        Xorg = data{i,'input'}{1};
        x_vec = reshape(Xorg,[],1);
        y_ref = data{i,'output'}';
        G = aux_create_graph(size(Xorg,1),data{i,'edge_index'}{1});
        target_label = data{i,'target_label'}+1;
    end
end
