%% Final performance metrics for repeated 10-fold cross-validation.
%
% For every repetition i (directories data_1 .. data_10) and every fold j
% (directories model0 .. model9) this script:
%   1. loads vis.mat once and selects the epoch with the lowest validation
%      loss after a burn-in of ths(xx) epochs;
%   2. pools the validation predictions of all folds at their selected
%      epochs, computes the validation ROC/AUC and takes the decision
%      threshold at perfcurve's optimal operating point;
%   3. pools the test predictions at the same epochs and computes test AUC,
%      accuracy, sensitivity and specificity with that threshold.
%
% Workspace outputs (one row per repetition, one column per burn-in value):
%   AUC, ACC, SENS, SPEC - metrics on the pooled TEST data
%   VAL_AUC, VAL_ACC     - metrics on the pooled VALIDATION data
%   L                    - selected epoch per fold, shifted by -1
%
% L is also written to guide_importance/locations.mat and
% shap_importance/locations.mat.
%
% Requires perfcurve (Statistics and Machine Learning Toolbox) and
% classperf (Bioinformatics Toolbox).
close all; clear;

ths = [50]; % burn in epochs.
n_repeats = 10;
n_folds   = 10;

L = [];
for i = 1:n_repeats
    % Reset all per-repetition accumulators so that no value from a
    % previous repetition can leak into this one.
    M_auc     = [];
    M_acc     = [];
    M_sens    = [];
    M_spec    = [];
    M_val_auc = [];
    M_val_acc = [];

    for xx = 1
        loc = zeros(1, n_folds);
        val_pred  = [];
        val_true  = [];
        test_pred = [];
        test_true = [];

        % Each vis.mat is loaded exactly once; validation and test
        % predictions are both taken at the epoch selected on validation loss.
        for j = 0:n_folds-1
            ss = load(sprintf('../cross_validation/data_%d/setting/model%d/vis.mat',i,j));

            % Epochs considered after burn-in.
            % NOTE(review): the upper bound uses the length of
            % traintotal_loss while indexing valtotal_loss, as in the
            % original script - confirm both always have the same length.
            candidates = ths(xx)+1:length(ss.traintotal_loss);
            if isempty(candidates)
                error('cv_final_metrics:tooFewEpochs', ...
                      'data_%d/model%d has no epochs after a burn-in of %d.', ...
                      i, j, ths(xx));
            end

            % Identify the epoch that gives the minimum validation loss.
            [~, best] = min(ss.valtotal_loss(candidates));
            loc(j+1) = best + ths(xx);

            % The second index of the *class_pred / *class_true arrays is
            % the selected epoch; the third index is flattened into a column.
            val_pred  = [val_pred;  reshape(ss.valclass_pred(1, loc(j+1), :),  [], 1)];
            val_true  = [val_true;  reshape(ss.valclass_true(1, loc(j+1), :),  [], 1)];
            test_pred = [test_pred; reshape(ss.testclass_pred(1, loc(j+1), :), [], 1)];
            test_true = [test_true; reshape(ss.testclass_true(1, loc(j+1), :), [], 1)];
        end

        %% Validation set: AUC and decision threshold.
        [X, Y, T, val_auc, opt] = perfcurve(val_true, val_pred, 1);

        % Threshold at perfcurve's optimal operating point on the validation
        % ROC. Several ROC entries can share the optimal (X,Y) pair, so only
        % the first match is used to guarantee a scalar threshold.
        opt_idx = find((X == opt(1)) & (Y == opt(2)), 1, 'first');
        if isempty(opt_idx)
            error('cv_final_metrics:noOptimalPoint', ...
                  'Optimal ROC point not found for data_%d.', i);
        end
        threshold = T(opt_idx);

        % perfcurve counts a sample as positive when score >= threshold, so
        % the same rule is applied here to reproduce the selected operating
        % point (a strict > would drop the samples sitting on the threshold).
        perf = classperf(val_true, val_pred >= threshold, 'Positive', [1], 'Negative', [0]);
        val_acc = perf.CorrectRate;

        %% Evaluation over test data with the validation-derived threshold.
        [~, ~, ~, test_auc] = perfcurve(test_true, test_pred, 1);
        perf = classperf(test_true, test_pred >= threshold, 'Positive', [1], 'Negative', [0]);

        M_auc(xx)     = test_auc;
        M_acc(xx)     = perf.CorrectRate; % test accuracy, consistent with AUC/SENS/SPEC
        M_sens(xx)    = perf.Sensitivity;
        M_spec(xx)    = perf.Specificity;
        M_val_auc(xx) = val_auc;
        M_val_acc(xx) = val_acc;
    end

    AUC(i,:)     = M_auc;
    ACC(i,:)     = M_acc;
    SENS(i,:)    = M_sens;
    SPEC(i,:)    = M_spec;
    VAL_AUC(i,:) = M_val_auc;
    VAL_ACC(i,:) = M_val_acc;

    % This contains the epoch numbers which are later used.
    % NOTE(review): the -1 shift presumably converts to 0-based epoch
    % indices for the consumers of locations.mat - confirm.
    L(i,:) = loc-1;
end

% Create the output directories if needed so that the save calls cannot
% fail after all metrics have already been computed.
out_dirs = {'guide_importance', 'shap_importance'};
for k = 1:numel(out_dirs)
    if ~exist(out_dirs{k}, 'dir')
        mkdir(out_dirs{k});
    end
end
save('guide_importance/locations.mat','L');
save('shap_importance/locations.mat','L');






