function xai_run_evaluation(name)
% xai_run_evaluation - runs the evaluation: for every run configuration,
%    explains N correctly classified images and stores the results, the
%    explanation plots, and a diary in ./results/<name>
%
% Syntax:
%    xai_run_evaluation()
%    xai_run_evaluation(name)
%
% Inputs:
%    name - (optional) name of the results folder; defaults to the
%           current date and time in the format 'yyyyMMdd-HHmm'
%
% Outputs:
%    - (results are written to disk: resAllRuns.mat, resRun_<i>.mat,
%       plots/<i>_<j>.fig and .png, and diary.txt)
%
% Other m-files required: xai_readData, xai_readNetwork,
%    xai_readInputImage, xai_evaluate_cumulative_time
% Subfunctions: aux_loadRunConfigs, aux_findCorrectClassifiedImages
% MAT-files required: none
%
% See also: -

% Authors:       Anonymous
% Written:       17-July-2024
% Last update:   ---
% Last revision: ---

% ------------------------------ BEGIN CODE -------------------------------

% parse input
if nargin < 1 || isempty(name)
    % 'HH' = 24-hour clock ('hh' would give 1am and 1pm the same name)
    name = char(datetime,'yyyyMMdd-HHmm');
end
name = char(name);

% SETTINGS ---

verbose = true;
basepath = '.';

% init paths (fullfile avoids treating name as part of a format string)
datapath = fullfile(basepath,'data');
resultspath = fullfile(basepath,'results',name);
mkdir(resultspath,'plots')
diary(fullfile(resultspath,'diary.txt'))

% stop logging and restore the thread count even if an error aborts the run
cleanupDiary = onCleanup(@() diary('off')); %#ok<NASGU>
origNumThreads = maxNumCompThreads;
cleanupThreads = onCleanup(@() maxNumCompThreads(origNumThreads)); %#ok<NASGU>

% Load run configs ---

runConfigs = aux_loadRunConfigs();
numRuns = size(runConfigs,1);

% Load data ---

data = struct;
datasets = unique(runConfigs(:,1));
for i = 1:numel(datasets)
    dataset = datasets{i};
    data.(dataset) = xai_readData(dataset,datapath);
end

% init seed (random, but the same for all runs so runs are comparable)
rng('shuffle')
seed = randi(1000);

% save general result
resAllRuns = struct;
resAllRuns.configs = runConfigs;
resAllRuns.date = string(datetime,'yyyyMMdd-HHmm');
resAllRuns.seed = seed;
save(fullfile(resultspath,'resAllRuns.mat'),'resAllRuns');

% Go through all configs --

for i = 1:numRuns
    fprintf('\n--------------- RUN #%i/%i ---------------\n\n',i,numRuns);

    resRun = struct;

    % read config (columns as defined in aux_loadRunConfigs)
    resRun.config = runConfigs(i,:);
    [dataset,network,inputFormat,epsilon,method,bucketType, ...
        refineSteps,numThreads,N] = resRun.config{:};

    % set seed; derive one seed per image so each explanation is reproducible
    resRun.seed = resAllRuns.seed;
    rng(resRun.seed);
    resRun.seeds = randi(1000,1,N);

    % restrict number of CPUs used
    maxNumCompThreads(numThreads);

    % init results
    resRun.idxFreedFeats = cell(1,N);
    resRun.featOrders = cell(1,N);
    resRun.timesPerFeat = cell(1,N);
    resRun.time = nan(1,N);
    resRun.MEs = {};

    % read network
    nn = xai_readNetwork(network,inputFormat,verbose);

    % find correctly classified images
    [xs,labels,imageIDs] = aux_findCorrectClassifiedImages(N,nn,data,dataset);
    resRun.imageIDs = imageIDs;

    % iterate over all images
    for j = 1:N

        try
            fprintf('\n--------------- Image #%i/%i ---------------\n\n',j,N)
            rng(resRun.seeds(j));

            % read data
            x = xs(:,j);
            label = labels(j);
            nn.reset();

            % create explanation
            [idxFreedFeat,featOrder,timesPerFeat] = nn.explain(x,label,epsilon, ...
                'Method',method,'Verbose',verbose, ...
                'InputSize',data.(dataset).inputSize, ...
                'BucketType',bucketType,'RefinementSteps',refineSteps);

            % save results
            resRun.idxFreedFeats{j} = idxFreedFeat;
            resRun.featOrders{j} = featOrder;
            resRun.timesPerFeat{j} = timesPerFeat;
            resRun.time(j) = sum(timesPerFeat);
            disp('Finished with creating explanation.')
            fprintf('Elapsed time since calling nn.explain(): %.4f.\n',resRun.time(j))

            % save (current) figure as .fig and .png
            figpath = fullfile(resultspath,'plots',sprintf('%i_%i',i,j));
            savefig(figpath)
            saveas(gcf, [figpath '.png']);
            close

        catch ME
            % keep going with the next image; errors are stored with the run
            disp(ME.message)
            resRun.MEs{end+1} = ME;
        end

        % save run result after every image (partial results survive crashes)
        save(fullfile(resultspath,sprintf('resRun_%i.mat',i)),'resRun');
    end
end

fprintf('\n Visualize ...\n\n')
xai_evaluate_cumulative_time(resultspath);

fprintf('\n--------------- DONE ---------------\n\n')

diary off

end


% Auxiliary functions -----------------------------------------------------

function runConfigs = aux_loadRunConfigs()
% aux_loadRunConfigs - returns the run configurations of the evaluation
%
% Each row of the returned cell array describes one run:
%    {dataset, network, inputFormat, epsilon, method, bucketType,
%     refineSteps, numThreads, numImages}
% Per network there are 11 rows: 'abstract+refine' and 'standard' with all
% refinement steps, followed by 'abstract+refine' with each single
% refinement step 0.1, 0.2, ..., 0.9.

% default configs
refineStepsAll = 0.1:0.1:1;
numThreadsMNIST = 1;
numThreadsCIFAR = 4; %#ok<NASGU> only used by the disabled CIFAR configs
numThreadsGTSRB = 8; %#ok<NASGU> only used by the disabled GTSRB configs
numImagesMNIST = 50;
numImagesCIFAR = 50; %#ok<NASGU>
numImagesGTSRB = 50; %#ok<NASGU>

% MNIST (enabled)
runConfigs = aux_buildRunConfigs( ...
    {'MNIST','mnist_sigmoid_6_200.onnx','BCSS',0.01}, ...
    'static',refineStepsAll,numThreadsMNIST,numImagesMNIST);

% GTSRB (disabled; uncomment to include)
% runConfigs = [runConfigs; aux_buildRunConfigs( ...
%     {'GTSRB','gtsrb_cnn_avgp_fc_sigmoid_84_48.onnx','BCSS',0.01}, ...
%     'dynamic',refineStepsAll,numThreadsGTSRB,numImagesGTSRB)];

% CIFAR (disabled; uncomment to include)
% runConfigs = [runConfigs; aux_buildRunConfigs( ...
%     {'CIFAR','cifar10_large.onnx','BCSS',0.001 * 255}, ...
%     'dynamic',refineStepsAll,numThreadsCIFAR,numImagesCIFAR)];

end

function rows = aux_buildRunConfigs(setup,bucketType,refineStepsAll,numThreads,numImages)
% aux_buildRunConfigs - builds the 11 run configurations of one network
%
% Inputs:
%    setup - 1x4 cell {dataset, network, inputFormat, epsilon}
%    bucketType - bucket type passed to every row
%    refineStepsAll - refinement steps of the two "all steps" rows
%    numThreads - number of computational threads
%    numImages - number of images to explain
%
% Outputs:
%    rows - 11x9 cell array of run configurations

resources = {numThreads, numImages};

% runs using all refinement steps
rows = [
    setup, {'abstract+refine', bucketType, refineStepsAll}, resources
    setup, {'standard', bucketType, refineStepsAll}, resources
];

% runs using a single refinement step each
for step = [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
    rows(end+1,:) = [setup, {'abstract+refine', bucketType, step}, resources]; %#ok<AGROW>
end

end

function [xs,labels,imageIDs] = aux_findCorrectClassifiedImages(N,nn,data,dataset)
% aux_findCorrectClassifiedImages - collects the first N images of a
%    dataset (in index order) that the network classifies correctly
%
% Inputs:
%    N - number of correctly classified images to collect
%    nn - network providing evaluate(x); the predicted class is taken as
%         the index of the maximal output
%    data - struct with one field per dataset containing the labels Y
%    dataset - name of the dataset field in data
%
% Outputs:
%    xs - collected images, one per column
%    labels - labels of the collected images
%    imageIDs - indices of the collected images within data.(dataset).Y
%
% Throws:
%    CORA:specialError if fewer than N images are classified correctly or
%    if the accuracy is below 1/max(Y) (NOTE(review): this assumes labels
%    are the class indices 1..max(Y), i.e. chance level)

disp('Looking for correctly classified images..')

Y = data.(dataset).Y;
numImages = numel(Y);

xs = [];
labels = [];
imageIDs = [];
tries = 0;

% images are taken in index order (so every run uses the same images);
% stop once the dataset is exhausted instead of indexing out of bounds
while numel(labels) < N && tries < numImages
    tries = tries + 1;
    n = tries;

    % read data
    label = Y(n);
    x = xai_readInputImage(data,dataset,n);

    % evaluate
    y = nn.evaluate(x);
    [~,pred] = max(y);

    % check with label
    if label == pred
        xs = [xs,x]; %#ok<AGROW>
        labels = [labels,label]; %#ok<AGROW>
        imageIDs = [imageIDs,n]; %#ok<AGROW>
    end
end

if numel(labels) < N
    throw(CORAerror('CORA:specialError', sprintf( ...
        'Only %i of %i requested images are classified correctly.', ...
        numel(labels),N)))
end

% sanity check: compute accuracy (skipped if no image was requested)
if tries > 0
    acc = N/tries;
    fprintf('Accuracy of the network: %.2f%%\n',acc * 100)
    if acc < 1/max(Y)
        throw(CORAerror('CORA:specialError','Accuracy too low. Check network and input.'))
    end
end

end

% ------------------------------ END OF CODE ------------------------------
