% Rows of the P-indexed summaries (i.e. pooling nodes 1..index_P) that are
% printed in the report sections.
select_index = [1 2 4 8];
% Length of the last axis of psmc_single_x; also part of the result file names.
d = 263562;

%% PSMC-HMC: estimator 
% Run configuration. N, M and d appear in the result file names; P and R
% index the P*R per-node result files; N is the length of the third axis of
% the per-node arrays. L is not referenced in the visible part of the script.
[N, R, P, M, L] = deal(32, 5, 8, 5, 1);

% Test-set size and number of classes.
num_test = 10000;
N_class = 10;

% Per-(index_P, index_R) containers, filled from the .mat files below.
psmc_single_pred = zeros(P,R,N,num_test,N_class);
psmc_single_x    = zeros(P,R,N,d);
[psmc_single_time, Z, K, count, Lsum] = deal(zeros(P,R));

% Load the result file of each of the P*R runs. Node numbers run row-major
% over (index_P, index_R): node = (index_P-1)*R + index_R.
for index_P = 1:P
    for index_R = 1:R
        index_node = R*(index_P-1) + index_R;
        result_file = ['BayesianNN_CIFAR_psmc_SimpleMLP_scaleGaussian_d' num2str(d) ...
                       '_N' num2str(N) '_M' num2str(M) '_node' num2str(index_node) '.mat'];
        data = load(result_file);

        % Per-sample arrays from this node.
        psmc_single_pred(index_P,index_R,:,:,:) = data.psmc_single_pred;
        psmc_single_x(index_P,index_R,:,:)      = data.psmc_single_x;

        % Scalar per-node statistics.
        psmc_single_time(index_P,index_R) = data.psmc_single_time;
        Z(index_P,index_R)     = data.Z;
        K(index_P,index_R)     = data.K;
        count(index_P,index_R) = data.count;
        Lsum(index_P,index_R)  = data.Lsum;
    end
end
% Report wall time, Lsum and count, each averaged over all P*R runs.
psmc_time = mean(psmc_single_time,"all");
summary_headers = {'PSMC: time \n', 'PSMC: Lsum \n', 'PSMC: count \n'};
summary_values  = [psmc_time, mean(Lsum,'all'), mean(count,'all')];
for k = 1:numel(summary_headers)
    fprintf(summary_headers{k})
    fprintf('%.4f \n', summary_values(k))
end

% Load the test labels and build the one-hot target smexact_sol
% (row = test point, column = class). Labels are 0-based, hence the +1.
data = 'BayesianNN_CIFAR10_data.mat';
load(data,'y_test')
y_val = y_test';
smexact_sol = double((reshape(y_val(1:num_test),[],1) + 1) == (1:N_class));


% Pool the estimates of nodes 1..index_P into one estimator per repetition:
% row index_P of psmc_xest / psmc_smest is the estimator built from the first
% index_P nodes. Each node contributes the mean over the third axis (N) of
% its psmc_single_x / psmc_single_pred. The contribution is weighted by
% Zf(node, r) and normalised by the sum of the weights. The pooled
% prediction is then mapped through a max-shifted softmax over classes.
%
% NOTE(review): the weights exp(Z + K - max(Z + K)) used to be computed and
% then immediately overwritten with ones, so nodes were always weighted
% equally. That choice is now an explicit switch. The default (false)
% reproduces the equal-weight results. Set use_evidence_weights = true
% before this section to weight nodes by exp(Z + K). Confirm that Z + K is
% the intended log-weight before relying on it.
if ~exist('use_evidence_weights','var')
    use_evidence_weights = false;
end

psmc_xest = zeros(P,R,d);
psmc_smest = zeros(P,R,num_test,N_class); % PSMC estimator for softmax
for index_P = 1:P
    if use_evidence_weights
        % Subtract the per-repetition max before exp() to avoid overflow;
        % the shift cancels when dividing by norm_const.
        log_w = Z(1:index_P,:) + K(1:index_P,:);
        Zf = exp(log_w - max(log_w,[],1));
    else
        Zf = ones(index_P,R);
    end
    for index_r = 1:R
        psmc_x_unnorm = 0;
        psmc_pred_unnorm = 0;
        norm_const = 0;
        for index_node = 1:index_P
            psmc_x_unnorm = psmc_x_unnorm + Zf(index_node,index_r) * squeeze(mean(psmc_single_x(index_node,index_r,:,:),3));
            psmc_pred_unnorm = psmc_pred_unnorm + Zf(index_node,index_r) * squeeze(mean(psmc_single_pred(index_node,index_r,:,:,:),3));
            norm_const = norm_const + Zf(index_node, index_r);
        end
        psmc_xest(index_P,index_r,:) = psmc_x_unnorm / norm_const;
        psmc_pred = psmc_pred_unnorm / norm_const;

        % Softmax over the class axis (dim 2), shifted by the row max for
        % numerical stability.
        psmc_pred = psmc_pred - max(psmc_pred,[],2);
        psmc_pred_sum = sum(exp(psmc_pred),2);
        psmc_smest(index_P,index_r,:,:) = exp(psmc_pred)./psmc_pred_sum;
    end
end

% Per (index_P, r): mean cross entropy between the one-hot target
% smexact_sol and psmc_smest, minus the mean entropy of the target itself.
% Terms of the form 0*log(0) are treated as 0.
psmc_smmse_single = zeros(P,R);

% The target-entropy term does not depend on (index_P, r); compute it once.
log_exact = log(smexact_sol);
temp2 = -smexact_sol.*log_exact;
temp2(isinf(log_exact) & smexact_sol == 0) = 0;
temp2 = sum(temp2,2);
target_entropy = sum(temp2)/num_test;

for index_P = 1:P
    for index_r = 1:R
        log_est = log(squeeze(psmc_smest(index_P,index_r,:,:)));
        temp1 = -smexact_sol.*log_est;
        temp1(isinf(log_est) & smexact_sol == 0) = 0;
        temp1 = sum(temp1,2);
        psmc_smmse_single(index_P,index_r) = sum(temp1)/num_test - target_entropy;
    end
end

% Mean over repetitions and its standard error (population variance).
psmc_smmse = mean(psmc_smmse_single,2);
psmc_smmse_errorbar = sqrt(var(psmc_smmse_single,1,2))/sqrt(R);

%% PSMC-HMC: accuracy
% Predicted class = argmax of the softmax estimate, shifted to 0-based to
% match y_val. Accuracy is the fraction of correct predictions averaged over
% the R repetitions. Its sd/se across repetitions are in percentage points.

prob_succ_psmc_prerel = zeros(P,R);
y_pred_psmc = zeros(P,R,num_test);
y_true = reshape(y_val(1:num_test), [], 1);
for index_r = 1:R
    for index_P = 1:P
        % max() returns exactly one (the first) index per row even when
        % classes tie. find(p == max(p)) would return several indices and
        % break the assignment.
        [~, arg_max] = max(reshape(psmc_smest(index_P,index_r,:,:), num_test, N_class), [], 2);
        y_pred_psmc(index_P,index_r,:) = arg_max - 1;
        prob_succ_psmc_prerel(index_P,index_r) = sum((arg_max - 1) == y_true);
    end
end
prob_succ_psmc = mean(prob_succ_psmc_prerel,2);
% Spread of the per-repetition accuracy (in %): standard deviation, and
% standard error of the mean over R repetitions.
acc_pct_psmc = prob_succ_psmc_prerel/num_test*100;
prob_succ_psmc_sd = std(acc_pct_psmc,0,2);
prob_succ_psmc_se = prob_succ_psmc_sd/sqrt(R);
prob_succ_psmc = prob_succ_psmc/num_test;
% succ. prob. of PSMC
fprintf('PSMC: accuracy \n')
fprintf('%.4f \n', prob_succ_psmc(select_index))
fprintf('PSMC: sd of accuracy \n')
fprintf('%.4f \n', prob_succ_psmc_sd(select_index))
fprintf('PSMC: se of accuracy \n')
fprintf('%.4f \n', prob_succ_psmc_se(select_index))

%% PSMC-HMC: metrics
% One-hot encoding of the 0-based test labels (row = test point, column = class).
y_val_onehot = double((reshape(y_val(1:num_test),[],1) + 1) == (1:N_class));

% Recall, precision and F1 per class (one vs rest). Values are averaged over
% the R repetitions and then over classes (macro average), giving one value
% per index_P.
%
% Labels and predictions are both forced into column vectors. Comparing a
% row of labels with a column of predictions would broadcast to a
% num_test x num_test matrix and silently corrupt TP/FN/FP.
% A zero denominator (class never present / never predicted) yields 0
% instead of NaN, so a single such class cannot turn the macro average
% into NaN.
y_true_cls = reshape(y_val(1:num_test), [], 1) + 1;
recall_perclass_psmc = zeros(P,R,N_class);
precision_perclass_psmc = zeros(P,R,N_class);
F1_perclass_psmc = zeros(P,R,N_class);
for index_P = 1:P
    for index_r = 1:R
        y_pred_cls = reshape(y_pred_psmc(index_P,index_r,:), [], 1) + 1;
        for index_class = 1:N_class
            is_true = (y_true_cls == index_class);
            is_pred = (y_pred_cls == index_class);
            TP = sum(is_true & is_pred);
            FN = sum(is_true & ~is_pred);
            FP = sum(~is_true & is_pred);

            rec = 0;
            if TP + FN > 0
                rec = TP/(TP+FN);
            end
            prec = 0;
            if TP + FP > 0
                prec = TP/(TP+FP);
            end
            f1 = 0;
            if rec + prec > 0
                f1 = 2*rec*prec/(rec+prec);
            end

            recall_perclass_psmc(index_P,index_r,index_class) = rec;
            precision_perclass_psmc(index_P,index_r,index_class) = prec;
            F1_perclass_psmc(index_P,index_r,index_class) = f1;
        end
    end
end
recall_perclass_psmc_prerel = squeeze(mean(recall_perclass_psmc,2));
recall_psmc = squeeze(mean(recall_perclass_psmc_prerel,2));
fprintf('PSMC: recall \n')
fprintf('%.4f \n',recall_psmc(select_index))
precision_perclass_psmc_prerel = squeeze(mean(precision_perclass_psmc,2));
precision_psmc = squeeze(mean(precision_perclass_psmc_prerel,2));
fprintf('PSMC: precision \n')
fprintf('%.4f \n',precision_psmc(select_index))
F1_perclass_psmc_prerel = squeeze(mean(F1_perclass_psmc,2));
F1_psmc = squeeze(mean(F1_perclass_psmc_prerel,2));
fprintf('PSMC: F1 \n')
fprintf('%.4f \n',F1_psmc(select_index))

% ROC AUC, one vs rest: for every class, the softmax score of that class is
% scored against the one-hot label. Averaged over repetitions, then classes.
auc_perclass_psmc = zeros(P,R,N_class);
for index_P = 1:P
    for index_r = 1:R
        class_scores = reshape(psmc_smest(index_P,index_r,:,:), num_test, N_class);
        for index_class = 1:N_class
            [X,Y,~,AUC] = perfcurve(y_val_onehot(:,index_class), class_scores(:,index_class), 1);
            auc_perclass_psmc(index_P,index_r,index_class) = AUC;
        end
    end
end
auc_psmc_prerel = squeeze(mean(auc_perclass_psmc,2));
auc_psmc = mean(auc_psmc_prerel,2);
fprintf('PSMC: AUC \n')
fprintf('%.4f \n',auc_psmc(select_index))

% AUC-PR (one vs rest for each class): area under the precision-recall
% curve, i.e. precision (y-axis) integrated over recall (x-axis).
% Averaged over repetitions, then over classes.
auc_pr_perclass_psmc = zeros(P,R,N_class);
for index_P = 1:P
    for index_r = 1:R
        for index_class = 1:N_class
            [~,~,~,AUC_PR] = perfcurve(y_val_onehot(:,index_class),squeeze(psmc_smest(index_P,index_r,:,index_class)),1, 'XCrit', 'reca', 'YCrit', 'prec');
            auc_pr_perclass_psmc(index_P,index_r,index_class) = AUC_PR;
        end
    end
end
auc_pr_psmc_prerel = squeeze(mean(auc_pr_perclass_psmc,2));
auc_pr_psmc = mean(auc_pr_psmc_prerel,2);
fprintf('PSMC: AUC-PR \n')
fprintf('%.4f \n',auc_pr_psmc(select_index))

% Negative log-likelihood of the true labels under psmc_smest. It is summed
% over the whole test set (not divided by num_test), then averaged over
% repetitions. 0*log(0) terms are treated as 0.
% NOTE(review): nlle_psmc_prerel is allocated but not filled in the visible
% part of the script.
nll_psmc_prerel = zeros(P,R);
nlle_psmc_prerel = zeros(P,R);
for index_P = 1:P
    for index_r = 1:R
        log_est = log(squeeze(psmc_smest(index_P,index_r,:,:)));
        temp1 = y_val_onehot .* log_est;
        temp1(isinf(log_est) & y_val_onehot == 0) = 0;
        nll_psmc_prerel(index_P,index_r) = -sum(sum(temp1));
    end
end
nll_psmc = mean(nll_psmc_prerel,2);
fprintf('PSMC: NLL \n')
fprintf('%.4f \n',nll_psmc(select_index))

%% entropy
% Uncertainty of the equally weighted pool of all stored softmax vectors
% from nodes 1..index_P (index_P*N vectors per test point and repetition):
%   total     = H(mean_s p_s)
%   epistemic = H(mean_s p_s) - mean_s H(p_s)
% eps is added inside log() so that zero probabilities do not yield -Inf.

% Per-sample softmax over the class axis (dim 5), max-shifted for stability.
psmc_pred = psmc_single_pred - max(psmc_single_pred,[],5);
psmc_pred_sum = sum(exp(psmc_pred),5);
psmc_prob = exp(psmc_pred)./psmc_pred_sum;

entropy_total = zeros(P, num_test, R);
entropy_epistemic = zeros(P, num_test, R);

for index_P = 1:P
    for index_r = 1:R
        % Softmax vectors of nodes 1..index_P for this repetition.
        pooled = psmc_prob(1:index_P,index_r,:,:,:);
        for index_test = 1:num_test
            % (index_P*N) x N_class matrix: one row per pooled softmax vector.
            members = reshape(pooled(:,1,:,index_test,:), [], N_class);
            p_bar = mean(members,1);
            H_total = -sum(p_bar .* log(p_bar + eps));
            H_expected = -mean(sum(members .* log(members + eps),2));

            entropy_total(index_P, index_test, index_r) = H_total;
            entropy_epistemic(index_P, index_test, index_r) = H_total - H_expected;
        end
    end
end

% Average over repetitions (dim 3), then over the test set (dim 2).
avg_entropy_total = squeeze(mean(mean(entropy_total, 3),2));
avg_entropy_epistemic = squeeze(mean(mean(entropy_epistemic, 3),2));
% could also add errorbar since we have 5 realizations

fprintf('PSMC: Total Entropy (averaged over test set) \n')
fprintf('%.8f \n', avg_entropy_total(select_index))
fprintf('PSMC: Epistemic Entropy (averaged over test set) \n')
fprintf('%.8f \n', avg_entropy_epistemic(select_index))
