% Sizes used by this script: length of a stored sample vector (d), number of
% test points, and number of classes.
d = 98690;
num_test = 25000;
N_class = 2;

% Read the test labels from the data file.
data = 'BayesianNN_IMDB_data.mat';
load(data, 'y_test')
y_val = y_test.';

% One-hot encoding of the labels: label k sets column k+1 of its row.
y_val_onehot = zeros(num_test, N_class);
for index_test = 1:num_test
    y_val_onehot(index_test, y_val(index_test) + 1) = 1;
end
% smexact_sol holds the same one-hot matrix under a second name.
smexact_sol = y_val_onehot;

%% HMC
% Run configuration. N_hmc, thin and burnin are also embedded in the result
% file names below.
N_hmc = 1;
L = 1;
burnin = 45;
thin = 45;
P = 256;
R = 5;
select_index = [32 64 128 256];

% Containers for what each node file provides.
hmc_single_pred = zeros(P*N_hmc,R,num_test,N_class);
hmc_single_x = zeros(P*N_hmc,R,d);
hmc_single_time = zeros(P*N_hmc,R);
Lsum = zeros(P,R);

% One result file per (index_P, index_R) pair; files are numbered so that
% index_R varies fastest.
for index_P = 1:P
    for index_R = 1:R
        index_node = (index_P-1)*R + index_R;
        node_file = ['BayesianNN_IMDB_phmc_SimpleMLP_scaleGaussian_d' num2str(d) ...
            '_N' num2str(N_hmc) '_thin' num2str(thin) '_burnin' num2str(burnin) ...
            '_node' num2str(index_node) '.mat'];
        data = load(node_file);

        % Predictions and samples are stored as cell arrays in the file.
        hmc_single_pred(index_P,index_R,:,:) = cell2mat(data.hmc_single_pred);
        hmc_single_x(index_P,index_R,:) = cell2mat(data.hmc_single_x);
        hmc_single_time(index_P,index_R) = data.hmc_single_time;
        Lsum(index_P,index_R) = data.Lsum;
    end
end

% Report the averages over all loaded nodes.
fprintf('HMC: time \n')
fprintf('%.4f \n', mean(hmc_single_time,"all"))
fprintf('HMC: Lsum \n')
fprintf('%.4f \n', mean(Lsum,'all'))

% Running estimators: entry (index_P, index_R) uses chains 1..index_P of
% realization index_R.
hmc_xest = zeros(P,R,d);
hmc_smest = zeros(P,R,num_test,N_class);  % HMC estimator for softmax
for index_R = 1:R
    for index_P = 1:P
        % Mean of the stored sample vectors over the first index_P chains.
        hmc_xest(index_P,index_R,:) = mean(hmc_single_x(1:index_P,index_R,:),1);

        % Mean of the stored predictions over the same chains
        % (num_test x N_class after squeeze).
        hmc_pred = squeeze(mean(hmc_single_pred(1:index_P,index_R,:,:),1));

        % Softmax across classes; the row maximum is subtracted first so
        % that exp() cannot overflow.
        hmc_pred = hmc_pred - max(hmc_pred,[],2);
        exp_pred = exp(hmc_pred);
        hmc_pred_sum = sum(exp_pred,2);
        % The small constant keeps every stored probability strictly positive.
        hmc_smest(index_P,index_R,:,:) = exp_pred./hmc_pred_sum + 10^(-40);
    end
end

% Classification accuracy of the HMC softmax estimator.
%
% y_pred_hmc(index_P,index_R,index_test) is the predicted label (0-based) and
% prob_succ_hmc_prerel(index_P,index_R) counts the correctly classified test
% points.
%
% The class is taken from the index output of max(), which returns the FIRST
% maximum. The previous find(x == max(x)) form returned two indices when the
% class probabilities tied, which made the scalar assignment fail. Using the
% index output also gives y_pred_hmc its full P x R x num_test size in one
% step instead of growing a P x R array inside the loop.
[~, best_class] = max(hmc_smest, [], 4);      % P x R x num_test, 1-based
y_pred_hmc = best_class - 1;                  % back to 0-based labels

% Compare against the true labels, broadcast along the third dimension.
true_labels = reshape(double(y_val(1:num_test)), 1, 1, []);
prob_succ_hmc_prerel = sum(y_pred_hmc == true_labels, 3);   % P x R counts

% Mean accuracy over the R realizations, as a fraction in [0,1].
prob_succ_hmc = mean(prob_succ_hmc_prerel,2) / num_test;
% NOTE(review): sd and se are computed on the accuracy in PERCENT, whereas
% prob_succ_hmc above is a fraction. Kept as in the original output; confirm
% which unit is intended before comparing the printed numbers.
prob_succ_hmc_sd = std(prob_succ_hmc_prerel/num_test*100,0,2);
prob_succ_hmc_se = prob_succ_hmc_sd/sqrt(R);
% succ. prob. of HMC
fprintf('HMC: accuracy \n')
fprintf('%.4f \n', prob_succ_hmc(select_index))
fprintf('HMC: sd of accuracy \n')
fprintf('%.4f \n', prob_succ_hmc_sd(select_index))
fprintf('HMC: se of accuracy \n')
fprintf('%.4f \n', prob_succ_hmc_se(select_index))

%% HMC: metrics
% Recall and precision
% Per-class recall, precision and F1 for every (index_P, index_R), then
% averaged over realizations (dim 2) and over classes (macro average).
recall_perclass_hmc = zeros(P,R,N_class);
precision_perclass_hmc = zeros(P,R,N_class);
F1_perclass_hmc = zeros(P,R,N_class);

% Force the true classes into a column vector. The predictions below are a
% column; if y_val were a row, implicit expansion would turn every logical
% mask into an N x N matrix and TP/FN/FP would no longer be scalars.
true_class = y_val(:) + 1;                    % 1-based true class

for index_P = 1:P
    for index_R = 1:R
        pred_class = reshape(y_pred_hmc(index_P,index_R,:), [], 1) + 1;   % 1-based prediction
        for index_class = 1:N_class
            is_true = (true_class == index_class);
            is_pred = (pred_class == index_class);

            TP = sum( is_true &  is_pred);
            FN = sum( is_true & ~is_pred);
            FP = sum(~is_true &  is_pred);

            rec = TP/(TP+FN);
            prec = TP/(TP+FP);
            recall_perclass_hmc(index_P,index_R,index_class) = rec;
            precision_perclass_hmc(index_P,index_R,index_class) = prec;
            % Harmonic mean of recall and precision (NaN when a class is
            % never predicted or never present, exactly as before).
            F1_perclass_hmc(index_P,index_R,index_class) = 2*rec*prec/(rec+prec);
        end
    end
end
recall_perclass_hmc_prerel = squeeze(mean(recall_perclass_hmc,2));
recall_hmc = squeeze(mean(recall_perclass_hmc_prerel,2));
fprintf('HMC: recall \n')
fprintf('%.4f\n',recall_hmc(select_index))
precision_perclass_hmc_prerel = squeeze(mean(precision_perclass_hmc,2));
precision_hmc = squeeze(mean(precision_perclass_hmc_prerel,2));
fprintf('HMC: precision \n')
fprintf('%.4f\n',precision_hmc(select_index))
F1_perclass_hmc_prerel = squeeze(mean(F1_perclass_hmc,2));
F1_hmc = squeeze(mean(F1_perclass_hmc_prerel,2));
fprintf('HMC: F1 \n')
fprintf('%.4f\n', F1_hmc(select_index))

% AUC (One vs Rest for each class)
% For each class, the one-hot column is the binary label and the estimated
% class probability is the score; positive class is 1.
auc_perclass_hmc = zeros(P,R,N_class);
for index_P = 1:P
    for index_R = 1:R
        for index_class = 1:N_class
            class_labels = squeeze(y_val_onehot(:,index_class));
            class_scores = squeeze(hmc_smest(index_P,index_R,:,index_class));
            [~,~,~,AUC] = perfcurve(class_labels, class_scores, 1);
            auc_perclass_hmc(index_P,index_R,index_class) = AUC;
        end
    end
end

% Average over realizations (dim 2), then over classes.
auc_hmc_prerel = squeeze(mean(auc_perclass_hmc,2));
auc_hmc = mean(auc_hmc_prerel,2);
fprintf('HMC: AUC \n')
fprintf('%.4f \n', auc_hmc(select_index))

% AUC-PR
% Area under the precision-recall curve, one-vs-rest per class.
%
% perfcurve integrates Y over X, so the PR area needs recall on the X axis
% and precision on the Y axis. The previous call had the two criteria
% swapped (X = precision, Y = recall), which integrates over a non-monotone
% abscissa and does not yield the PR area.
auc_pr_perclass_hmc = zeros(P,R,N_class);
for index_P = 1:P
    for index_R = 1:R
        for index_class = 1:N_class
            class_labels = squeeze(y_val_onehot(:,index_class));
            class_scores = squeeze(hmc_smest(index_P,index_R,:,index_class));
            [~,~,~,AUC_PR] = perfcurve(class_labels, class_scores, 1, ...
                'XCrit', 'reca', 'YCrit', 'prec');
            auc_pr_perclass_hmc(index_P,index_R,index_class) = AUC_PR;
        end
    end
end

% Average over realizations (dim 2), then over classes.
auc_pr_hmc_prerel = squeeze(mean(auc_pr_perclass_hmc,2));
auc_pr_hmc = mean(auc_pr_hmc_prerel,2);
fprintf('HMC: AUC-PR \n')
fprintf('%.4f \n', auc_pr_hmc(select_index))

% ENLL and NLLE
% Negative log-likelihood of the true labels under the HMC softmax
% estimator, summed over the test set for each (index_P, index_R).
nll_hmc_prerel = zeros(P,R);
% NOTE(review): nlle_hmc_prerel is allocated but never filled in this
% section; kept so that its presence in the workspace is unchanged.
nlle_hmc_prerel = zeros(P,R);
for index_P = 1:P
    for index_R = 1:R
        log_prob = log(squeeze(hmc_smest(index_P,index_R,:,:)));   % num_test x N_class
        temp1 = y_val_onehot .* log_prob;
        % 0 * (-Inf) evaluates to NaN; zero those entries so that classes
        % which are not the true label contribute nothing to the sum.
        temp1(isinf(log_prob) & (y_val_onehot == 0)) = 0;
        nll_hmc_prerel(index_P,index_R) = -sum(sum(temp1));
    end
end
nll_hmc = mean(nll_hmc_prerel,2);
% Label corrected from 'PSMC' to 'HMC' to match the other results printed for
% this estimator.
fprintf('HMC: NLL \n')
fprintf('%.4f \n', nll_hmc(select_index))

%% entropy
% Predictive entropy of the ensemble built from the first index_P chains:
%   total     = entropy of the ensemble-averaged class probabilities
%   expected  = average of the individual members' entropies
%   epistemic = total - expected

% Per-sample softmax across classes (dim 4), shifted by the maximum so that
% exp() cannot overflow.
hmc_pred = hmc_single_pred - max(hmc_single_pred,[],4);
hmc_pred_sum = sum(exp(hmc_pred),4);
hmc_prob = exp(hmc_pred)./hmc_pred_sum;

entropy_total = zeros(P, num_test, R);
entropy_epistemic = zeros(P, num_test, R);

% Number of samples in the ensemble for each index_P (rows 1:index_P*N_hmc).
n_used = (1:P)' * N_hmc;

% All prefix means are obtained from cumulative sums instead of recomputing
% a mean over 1:index_P for every index_P and test point. The class
% dimension is taken from N_class rather than a hard-coded 2. Results agree
% with the explicit loop up to floating-point round-off.
for index_r = 1:R
    % (P*N_hmc) x num_test x N_class slice for this realization.
    prob_r = reshape(hmc_prob(:,index_r,:,:), [], num_test, N_class);

    % Ensemble-averaged probabilities for every prefix length in n_used.
    prob_cumsum = cumsum(prob_r, 1);
    ensemble_pred = prob_cumsum(n_used,:,:) ./ n_used;        % P x num_test x N_class
    H_total = -sum(ensemble_pred .* log(ensemble_pred + eps), 3);

    % Entropy of every individual sample, then its running average.
    H_sample = -sum(prob_r .* log(prob_r + eps), 3);          % (P*N_hmc) x num_test
    H_cumsum = cumsum(H_sample, 1);
    H_expected = H_cumsum(n_used,:) ./ n_used;                % P x num_test

    entropy_total(:, :, index_r) = H_total;
    entropy_epistemic(:, :, index_r) = H_total - H_expected;
end

% Average over realizations (dim 3) and then over the test set (dim 2).
avg_entropy_total = squeeze(mean(mean(entropy_total, 3),2));
avg_entropy_epistemic = squeeze(mean(mean(entropy_epistemic, 3),2));
% could also add errorbar since we have 5 realizations

fprintf('PHMC: Total Entropy (averaged over test set) \n')
fprintf('%.8f \n', avg_entropy_total(select_index))
fprintf('PHMC: Epistemic Entropy (averaged over test set) \n')
fprintf('%.8f \n', avg_entropy_epistemic(select_index))