function completed = gnn_main()
% gnn_main - runs all scripts accompanying the
% "Formal Verification of Graph Neural Networks with 
% Uncertain Node Features and Uncertain Graph Structure" Paper
%
% results and plots will be saved to ./results/<yymmdd-hhMMss>
%
% Syntax:
%    completed = gnn_main()
%
% Inputs:
%    -
%
% Outputs:
%    completed - 1 once all scripts have been attempted; failures of
%                individual scripts are reported in the log (results.txt)
%                and do not abort the run

% ------------------------------ BEGIN CODE -------------------------------

% STARTUP -----------------------------------------------------------------

rng(1)

% silence warnings during the run; the caller's warning state is restored
% and the diary is closed when this function exits (also on error)
oldWarnState = warning;
warning off
cleanupObj = onCleanup(@() aux_restoreState(oldWarnState)); %#ok<NASGU>

% set up paths
basepath = '.';
datapath = sprintf("%s/experiments", basepath);
resultpath = sprintf("%s/results/%s", basepath, datestr(datetime,'yymmdd-hhMMss'));
mkdir(resultpath)
evalpath = sprintf("%s/evaluation", resultpath);
mkdir(evalpath)
plotpath = sprintf("%s/plots", resultpath);
mkdir(plotpath)

% for smooth images
set(0, 'defaultFigureRenderer', 'painters')

% start a fresh log file
resultstxt = sprintf("%s/results.txt", resultpath);
if exist(resultstxt, 'file')
    delete(resultstxt)
end
diary(resultstxt)
disp("--------------------------------------------------------")
disp('Repeatability Evaluation')
disp("Paper: Formal Verification of Graph Neural Networks with Uncertain Node Features and Uncertain Graph Structure")
disp("Transactions of Machine Learning Research")
fprintf("Date: %s\n", datestr(datetime()))
disp(CORAVERSION)
disp("--------------------------------------------------------")
disp(" ")

% RUN SCRIPTS ----------------------------------------------------------
% each row: {function handle to run, name used in the log and plot files}
scripts = {; ...
    % main evaluation
    @() gnn_example_quadMap, "example_quadMap";
    @() gnn_example_invsqrtroot, "example_invsqrtroot";
    @() gnn_example_2_message_passing, "example_2_message_passing";
    @() aux_evaluate_gnn(datapath,evalpath), "evaluation";
    @() gnn_eval_plots(evalpath), "plots";
    };

n = size(scripts, 1);
fprintf("Running %d scripts.. \n", n);
disp(" ")

for i = 1:n
    disp("--------------------------------------------------------")
    script = scripts{i, 1};
    name = scripts{i, 2};

    try
        % run script
        fprintf("Running '%s' ...\n", name)
        script();

        disp(" ")
        fprintf("'%s' was run successfully!\n", name)

    catch ME
        % error handling
        disp(" ")
        fprintf("An ERROR occured during execution of '%s':\n", name);
        disp(ME.getReport())
        disp("Continuing with next script..")
    end

    % save and close every figure the script opened -- also after a
    % failure, so its figures are not saved under the next script's name
    try
        aux_savePlots(plotpath, name);
    catch ME
        fprintf("An ERROR occured while saving plots of '%s':\n", name);
        disp(ME.getReport())
        % do not let unsaved figures leak into the next script's plots
        close('all')
    end

    disp(" ")
end

% -------------------------------------------------------------------------
disp("--------------------------------------------------------")
disp(" ")
completed = 1;
disp("Completed!")
fprintf("Date: %s\n", datestr(datetime()))
diary off;

% system('shutdown')

end


function aux_savePlots(plotpath, name)
% aux_savePlots - saves every open figure as '<plotpath>/<name>_<j>.fig'
% and '.png' and closes it afterwards

    fprintf("Saving plots to '%s'..\n", plotpath)
    figs = findobj('type', 'figure');
    for j = 1:numel(figs)
        fig = figs(j);
        savefig(fig, sprintf("%s/%s_%d.%s", plotpath, name, j, 'fig'));
        saveas(fig, sprintf("%s/%s_%d.%s", plotpath, name, j, 'png'));
        % saveas(fig, sprintf("%s/%s_%d.%s", plotpath, name, j, 'eps'), 'epsc');
        close(fig)
    end
end

function aux_restoreState(warnState)
% aux_restoreState - closes the diary and restores the given warning state

    diary('off');
    warning(warnState);
end


% Auxiliary functions -----------------------------------------------------

function aux_evaluate_gnn(datapath, evalpath)
    % aux_evaluate_gnn - runs the three evaluation campaigns one after the
    % other. Each campaign forwards its settings (number of checks, models,
    % deltas, pert_edges, enumeration flag) to aux_evalute_gnn_details,
    % which writes the results below evalpath.

    % model lists (relative to datapath)
    graphModels = {'models/Enzymes_34035'; 'models/PROTEINS_314496'};
    coraModels = {'models/Cora_332314', ... % 2 MP steps
        'models/Cora_717338'};              % 3 MP steps

    % pert_edge values shared by runs 2 and 3
    sharedPertEdges = [0,0.001,0.005,0.01,0.05];

    % run 1 (time saving): 20 checks, pert_edges 0:9, with enumeration
    aux_evalute_gnn_details(datapath, evalpath, 20, graphModels, 0.001, 0:9, true);

    % run 2 (enzymes and proteins): 50 checks, no enumeration
    aux_evalute_gnn_details(datapath, evalpath, 50, graphModels, 0.001, sharedPertEdges, false);

    % run 3 (cora): 50 checks, delta = 0, no enumeration
    aux_evalute_gnn_details(datapath, evalpath, 50, coraModels, 0, sharedPertEdges, false);
end

function aux_evalute_gnn_details(datapath,evalpath,NUM_CHECKS,models,deltas,pert_edges,do_enumeration)
    % aux_evalute_gnn_details - evaluates every combination of model, delta
    % and pert_edge with gnn_eval_graph_level_rmedge and saves each result
    % (variable 'evResult') to
    %   <evalpath>/<model>/evResult-edge<pert_edge>-d<delta>-enum<do_enumeration>.mat
    %
    % Inputs:
    %    datapath - folder the model paths are relative to
    %    evalpath - output folder; one subfolder per model is created
    %    NUM_CHECKS - forwarded to gnn_eval_graph_level_rmedge
    %    models - cell array of model paths (relative to datapath)
    %    deltas - numeric vector of delta values
    %    pert_edges - numeric vector of pert_edge values
    %    do_enumeration - flag forwarded to gnn_eval_graph_level_rmedge
    %
    % NOTE(review): file names use '%.3f', so values differing only beyond
    % the third decimal would overwrite each other.

    % a MATLAB for-loop iterates over the columns of its range; force row
    % vectors so column-vector inputs are iterated element by element
    deltas = reshape(deltas, 1, []);
    pert_edges = reshape(pert_edges, 1, []);

    % run loops
    for model_i = 1:numel(models)
        model = models{model_i};
        modelpath = sprintf('%s/%s',datapath,model);
        evalmodelpath = sprintf('%s/%s',evalpath,model);

        for delta = deltas
            for pert_edge = pert_edges
                disp('---')
                fprintf('pert_edge: %.3f, delta=%.3f, model=%s, do_enumeration=%i\n',pert_edge,delta,model,do_enumeration)

                % run one instance
                [~,evResult] = gnn_eval_graph_level_rmedge( ...
                    NUM_CHECKS,modelpath,delta,pert_edge,do_enumeration ...
                );

                % save results (folder is created on first use only)
                if ~exist(evalmodelpath, 'dir')
                    mkdir(evalmodelpath)
                end
                save(...
                    sprintf( ...
                        '%s/evResult-edge%.3f-d%.3f-enum%i.mat', ...
                        evalmodelpath,pert_edge,delta,do_enumeration), ...
                    'evResult' ...
                );
                disp(' ')
            end
        end
    end
end

% ------------------------------ END OF CODE ------------------------------
