#!/usr/bin/env python3
"""
Combined OOD-Detection for SMC vs. HMC on filtered MNIST,
with MAP and Deep-Ensemble (DE) OOD-detection,
using entropy-based meta-classification, and for each replicate
one combined Meta-ROC plot that includes ALL methods (MAP, DE,
SMC_P1, SMC_P8, HMC_P1, HMC_P8).
"""

import os, random, pickle
import numpy as np
import pyro
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.func import functional_call
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    roc_auc_score, roc_curve, confusion_matrix,
    precision_score, recall_score, f1_score, precision_recall_curve
)
from matplotlib.legend_handler import HandlerTuple
from matplotlib.patches import Patch

# =============================================================================
# 1) Parameters
# =============================================================================
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- MAP / deep-ensemble training ---
R_map         = 5        # number of MAP (and DE) replicates trained
ensemble_size = 10       # networks per deep ensemble
N_tr          = 1000     # training-set size
N_val         = 200      # validation-set size (used for early stopping)
v             = 0.1      # prior variance: sets the init std and the L2 penalty 1/(2v)
var_conv      = v        # init variance of the conv weights
var_fc        = v        # init variance of the fully-connected weights
sdb           = np.sqrt(v)  # init std of all biases
DATA_PATH     = 'map_de_models.pkl'  # pickle cache of the trained MAP/DE weights

# --- evaluation ---
R           = 5          # replicates evaluated in the main loop
d           = 6320       # tag in result filenames; equals SimpleCNN's parameter count (40 + 6280)
M           = 10         # tag embedded in the SMC result filenames
thin        = 160        # tag embedded in the HMC result filenames
burnin      = 160        # tag embedded in the HMC result filenames
select_P    = [1, 8]     # values of P that are evaluated
N_particles = 10         # tag in SMC filenames; HMC loads N_particles * P files per replicate
N_test      = 2000       # maximum number of in-domain test images
N_ood_per   = 500        # images per OOD group
eps         = 1e-12      # added inside log() when computing entropies

# =============================================================================
# 2) Model & helpers
# =============================================================================
class SimpleCNN(nn.Module):
    """Small CNN: one 3x3 conv (1 -> 4 channels), 2x2 max-pool, one linear layer.

    The linear layer expects 4*14*14 features and emits 8 logits, so the
    network is sized for single-channel 28x28 inputs. All weights and biases
    are re-drawn from zero-mean normals using the module-level variances.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc   = nn.Linear(4*14*14, 8)
        # Draw order (conv weight, conv bias, fc weight, fc bias) is kept fixed
        # so that a given seed always yields the same initial network.
        for layer, weight_var in ((self.conv, var_conv), (self.fc, var_fc)):
            nn.init.normal_(layer.weight, mean=0, std=np.sqrt(weight_var))
            nn.init.normal_(layer.bias,   mean=0, std=sdb)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        pooled = self.pool(F.relu(self.conv(x)))
        features = pooled.view(pooled.size(0), -1)
        return self.fc(features)

def unflatten_params(flat, net):
    """Slice the 1-D tensor `flat` into a dict of views shaped like `net`'s parameters.

    Slices are taken in `net.named_parameters()` order; the returned dict maps
    each parameter name to the matching view of `flat`.
    """
    named = list(net.named_parameters())
    # Running start offset of every parameter inside the flat vector.
    starts = [0]
    for _, tensor in named:
        starts.append(starts[-1] + tensor.numel())
    return {
        name: flat[begin:begin + tensor.numel()].view(tensor.shape)
        for (name, tensor), begin in zip(named, starts)
    }

def softmax_np(x):
    """Softmax over the last axis of `x` (NumPy), shifted by the max for stability."""
    peak = np.max(x, axis=-1, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=-1, keepdims=True)
    return weights / total

def predict_with_particles(x, particles, Net):
    """Evaluate `Net` on `x` once per parameter vector in `particles`.

    Args:
        x: input tensor; moved to the module-level `device`.
        particles: iterable of flat parameter vectors (NumPy arrays or
            tensors), laid out in `Net().named_parameters()` order.
        Net: zero-argument constructor of the network.

    Returns:
        NumPy array of logits stacked along a new leading axis, one slice
        per particle.
    """
    net = Net().to(device)
    net.eval()
    # Cast every particle to the network's own dtype: arrays read from .mat
    # files are typically float64 and would otherwise clash with float32 inputs.
    reference = next(net.parameters(), None)
    param_dtype = reference.dtype if reference is not None else None
    x = x.to(device)  # move the batch once instead of once per particle
    outs = []
    with torch.no_grad():  # inference only; no autograd graph is needed
        for flat in particles:
            flat_t = torch.as_tensor(flat, dtype=param_dtype, device=device).reshape(-1)
            params = unflatten_params(flat_t, net)
            outs.append(functional_call(net, params, x).cpu().numpy())
    return np.stack(outs, axis=0)

def flatten_net(net):
    """Return all parameters of `net` concatenated into one detached 1-D tensor.

    Parameters are taken in `net.parameters()` order. The result is detached
    from autograd so callers can convert it with `.cpu().numpy()`; a tensor
    that still requires grad makes `.numpy()` raise a RuntimeError.
    """
    return torch.cat([p.detach().view(-1) for p in net.parameters()])

class FilteredDataset(torch.utils.data.Dataset):
    """In-memory dataset holding only the samples of `ds` whose label is in `allowed`.

    `ds` is iterated once at construction time and every kept (image, label)
    pair is stored in a list.
    """

    def __init__(self, ds, allowed):
        kept = []
        for img, lbl in ds:
            if lbl in allowed:
                kept.append((img, lbl))
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def binary_meta_labels(y_true, y_pred):
    """Return 0 where both `y_true` and `y_pred` are 0, and 1 everywhere else."""
    both_zero = np.logical_and(y_true == 0, y_pred == 0)
    return np.where(np.logical_not(both_zero), 1, 0)

# Compute p_incorrect = 1 − max_class_prob for ID and OOD and concatenate
def get_p_incorrect(particles):
    """Return 1 - max particle-averaged class probability, ID samples first, then OOD.

    Uses the module-level `x_id` and `ood_imgs` tensors and the `SimpleCNN`
    architecture; the two result vectors are concatenated in that order.
    """
    def _one_minus_confidence(inputs):
        # Average the softmax over particles (axis 0), then take the top class.
        probs = softmax_np(predict_with_particles(inputs, particles, SimpleCNN))
        return 1 - np.max(np.mean(probs, axis=0), axis=1)

    in_domain = _one_minus_confidence(x_id)
    out_of_domain = _one_minus_confidence(ood_imgs)
    return np.concatenate([in_domain, out_of_domain])

# =============================================================================
# Plot‐helper and color definitions (must be defined before the loop)
# =============================================================================
def plot_meta(label, y_true, scores, color):
    """Draw the ROC curve of `scores` against `y_true` (positive class 1) on the current axes."""
    roc = roc_curve(y_true, scores, pos_label=1)
    false_pos_rate, true_pos_rate = roc[0], roc[1]
    plt.plot(false_pos_rate, true_pos_rate, label=label, color=color)

# Line colour used for each method in the summary plots.
color_map = dict(
    MAP='blue',
    DE='orange',
    SMC_P1='red',
    SMC_P8='darkred',
    HMC_P1='green',
    HMC_P8='olive',
)

# =============================================================================
# 3) Load or train MAP & DE
# =============================================================================
def _train_penalized_cnn(seed, train_loader, val_loader, n_train):
    """Train one SimpleCNN with an L2 penalty and return its weights as a flat NumPy vector.

    Args:
        seed: seed applied to torch, NumPy and `random` before the model is built.
        train_loader: batches used for the Adam updates.
        val_loader: batches used for the early-stopping criterion.
        n_train: number of training samples; the penalty is divided by it.

    Training runs for at most 1000 epochs and stops once the mean validation
    loss over the last five epochs no longer improves on the best such mean.
    """
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
    model = SimpleCNN().to(device)
    opt   = optim.Adam(model.parameters(), lr=1e-3)
    crit  = nn.CrossEntropyLoss()
    window, best = [], np.inf
    for epoch in range(1, 1001):
        model.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            out = model(xb.to(device))
            ce  = crit(out, yb.to(device))
            # Sum of squared parameters scaled by 1/(2v), spread over the training set.
            reg = sum(p.pow(2).sum() for p in model.parameters())/(2*v)
            (ce + reg/n_train).backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            losses = [crit(model(xb.to(device)), yb.to(device)).item()
                      for xb, yb in val_loader]
        window.append(np.mean(losses))
        if len(window) > 5 and np.mean(window[-5:]) >= best:
            break
        best = min(best, np.mean(window[-5:]))
    # Detach first: .numpy() raises on a tensor that still requires grad.
    return flatten_net(model).detach().cpu().numpy()


try:
    # NOTE: pickle can execute arbitrary code on load; only use a cache file
    # that this script wrote itself.
    with open(DATA_PATH,'rb') as f:
        saved = pickle.load(f)
    replicate_map_particles = saved['replicate_map_particles']
    replicate_de_flats      = saved['replicate_de_flats']
    print("✅ Loaded MAP + DE")
except FileNotFoundError:
    # --- Prepare data for training MAP/DE (digits 0-7 only) ---
    transform    = transforms.ToTensor()
    full_train   = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    filtered_all = FilteredDataset(full_train, list(range(8)))
    filtered_all = Subset(filtered_all, range(N_tr+N_val))
    train_ds     = Subset(filtered_all, range(N_tr))
    val_ds       = Subset(filtered_all, range(N_tr, N_tr+N_val))
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)

    # --- MAP training: one network per replicate, seeded with the replicate index ---
    replicate_map_particles = [
        _train_penalized_cnn(r, train_loader, val_loader, len(train_ds))
        for r in range(1, R_map+1)
    ]

    # --- Deep-ensemble training: ensemble_size networks per replicate ---
    replicate_de_flats = []
    for r in range(1, R_map+1):
        de_models = [
            _train_penalized_cnn(r*1000 + m, train_loader, val_loader, len(train_ds))
            for m in range(ensemble_size)
        ]
        replicate_de_flats.append(np.stack(de_models))

    # --- Save trained particles ---
    with open(DATA_PATH,'wb') as f:
        pickle.dump({
            'replicate_map_particles': replicate_map_particles,
            'replicate_de_flats':      replicate_de_flats
        }, f)
    print("💾 Saved MAP + DE")

# =============================================================================
# 4) Prepare ID & OOD data
# =============================================================================
transform = transforms.ToTensor()
full_val  = torchvision.datasets.MNIST(root='.', train=False, download=True, transform=transform)

# In-domain set: the first N_test test images with a label below 8.
# Labels are read from the dataset's label tensor so that no image has to be
# decoded just to inspect its label.
indices_id = np.where(np.array(full_val.targets) < 8)[0][:N_test].tolist()
loader     = DataLoader(Subset(full_val,indices_id), batch_size=len(indices_id), shuffle=False)
x_id, y_id = next(iter(loader)); y_id = y_id.numpy()

def pick(d,n):
    """Return a tensor with the first `n` test images whose label equals `d`."""
    matches = np.where(np.array(full_val.targets) == d)[0][:n].tolist()
    batch = next(iter(DataLoader(Subset(full_val, matches), batch_size=n, shuffle=False)))
    return batch[0]

# OOD set, four groups of N_ood_per images each: digit 8, digit 9,
# noise-perturbed in-domain images, and uniform random noise.
d8, d9 = pick(8,N_ood_per), pick(9,N_ood_per)
random.seed(2); np.random.seed(2); torch.manual_seed(2); pyro.set_rng_seed(2)
pert  = torch.clamp(x_id[:N_ood_per] + 0.5*torch.randn_like(x_id[:N_ood_per]),0,1)
noise = torch.rand(N_ood_per,1,28,28)
ood_imgs = torch.cat([d8, d9, pert, noise], dim=0).to(device)

# =============================================================================
# 5) Metrics & helper
# =============================================================================
def _mean_predictive_entropy(predict_fn, inputs, particles):
    """Entropy of the particle-averaged softmax for every sample in `inputs`."""
    probs = np.mean(softmax_np(predict_fn(inputs, particles, SimpleCNN)), axis=0)
    return -np.sum(probs * np.log(probs + eps), axis=1)


def detection_metrics(predict_fn, particles, return_scores=False):
    """Score ID-vs-OOD detection using predictive entropy as the OOD score.

    Args:
        predict_fn: callable `(inputs, particles, Net) -> logits` with one
            slice per particle on axis 0. `None` selects
            `predict_with_particles`.
        particles: parameter vectors handed to `predict_fn`.
        return_scores: when true, also return the threshold, labels and scores.

    Returns:
        `(best_f1, auroc)`, or `(best_f1, auroc, best_thr, y_true, scores)`
        when `return_scores` is true. `y_true` is 0 for the module-level
        `x_id` samples and 1 for `ood_imgs`; `scores` are the entropies in
        the same order.
    """
    if predict_fn is None:
        predict_fn = predict_with_particles

    H_id  = _mean_predictive_entropy(predict_fn, x_id, particles)
    H_ood = _mean_predictive_entropy(predict_fn, ood_imgs, particles)

    y_true = np.concatenate([np.zeros_like(H_id), np.ones_like(H_ood)])
    scores = np.concatenate([H_id, H_ood])

    prec, rec, thr = precision_recall_curve(y_true, scores)
    f1_arr = 2*prec*rec/(prec+rec+eps)
    # The last precision/recall pair has no associated threshold, so skip it.
    idx    = np.nanargmax(f1_arr[:-1])
    best_f1, best_thr = f1_arr[idx], thr[idx]
    # y_true is already the binary OOD indicator, so it is used directly.
    auroc  = roc_auc_score(y_true, scores)

    if return_scores:
        return best_f1, auroc, best_thr, y_true, scores
    return best_f1, auroc

# =============================================================================
# 6) Main loop: per-replicate Meta-ROC + other plots
# =============================================================================
# --- before replicates ---
# Every container below has one entry per method, in this order.
_METHOD_KEYS = ('MAP', 'DE', 'SMC_P1', 'SMC_P8', 'HMC_P1', 'HMC_P8')

# Per-replicate metrics at the F1-optimal threshold.
metrics_storage = {
    key: {'prec': [], 'rec': [], 'f1': [], 'auroc': []} for key in _METHOD_KEYS
}

# Per-replicate metrics at the default threshold (0.5).
metrics_default = {
    key: {'prec': [], 'rec': [], 'f1': [], 'auroc': []} for key in _METHOD_KEYS
}

# Per-replicate (fpr, tpr) pairs.
roc_curves = {key: [] for key in _METHOD_KEYS}

# Threshold grid and the per-replicate accuracy evaluated on it.
ths = np.linspace(0, 1, 200)
accuracy_curves = {key: [] for key in _METHOD_KEYS}

# Per-replicate optimal thresholds.
thr_storage = {key: [] for key in _METHOD_KEYS}

# Per-replicate normalized confusion matrices at the optimal threshold.
confusion_rates = {key: [] for key in _METHOD_KEYS}

# Per-replicate normalized confusion matrices at the default threshold (0.5).
confusion_rates_default = {key: [] for key in _METHOD_KEYS}



for r in range(R):
    print(f"\n===== replicate {r+1}/{R} =====")

    # 6a) Load this replicate's SMC/HMC parameter particles and score them.
    y_true_collect    = {m: {} for m in ['SMC','HMC']}
    scores_collect    = {m: {} for m in ['SMC','HMC']}
    particles_collect = {m: {} for m in ['SMC','HMC']}

    for method, prefix in [('SMC','psmc'), ('HMC','hmc')]:
        key = f"{prefix}_single_x"
        for P in select_P:
            # SMC reads P node files per replicate, HMC reads N_particles * P.
            total = P if method=='SMC' else N_particles * P
            parts, n_missing = [], 0
            for i in range(total):
                if method=='SMC':
                    fn = f"BayesianNN_MNIST_{prefix}_SimpleNN_results_d{d}_train{N_tr}_val{N_val}_N{N_particles}_M{M}_node{r*total + i + 1}.mat"
                else:
                    fn = f"BayesianNN_MNIST_{prefix}_results_d{d}_train{N_tr}_val{N_val}_thin{thin}_burnin{burnin}_node{r*total + i + 1}.mat"
                if os.path.exists(fn):
                    parts.append(loadmat(fn)[key])
                else:
                    n_missing += 1
            # Report missing files: a partially loaded particle set would
            # otherwise change the results without any visible sign.
            if n_missing:
                print(f"WARNING: {method}_P{P}, replicate {r+1}: {n_missing} of {total} result files not found")
            if not parts:
                continue
            parts = np.vstack(parts)                   # one row per stacked parameter vector
            particles_collect[method][P] = parts       # kept for p_incorrect below
            _, _, _, y_r, s_r = detection_metrics(predict_with_particles, parts, return_scores=True)
            y_true_collect[method][P] = y_r
            scores_collect[method][P] = s_r

    # 6b) MAP & DE
    _, _, _, y_map, s_map = detection_metrics(predict_with_particles, [replicate_map_particles[r]], return_scores=True)
    _, _, _, y_de,  s_de  = detection_metrics(predict_with_particles, replicate_de_flats[r], return_scores=True)

    # 6c) ROC curves of the entropy scores, stored for the mean-ROC plot
    meta_list = [
        ('MAP', y_map, s_map),
        ('DE',  y_de,  s_de),
    ]
    for m in ['SMC','HMC']:
        for P in select_P:
            if P in y_true_collect[m]:
                meta_list.append((f"{m}_P{P}", y_true_collect[m][P], scores_collect[m][P]))

    for name, y, s in meta_list:
        fpr, tpr, _ = roc_curve(y, s, pos_label=1)
        roc_curves[name].append((fpr, tpr))

    # 6d) Accuracy vs threshold on p(correct), and per-method metrics/confusion

    # p_inc = 1 - max class probability for every available method
    p_inc = {
        'MAP': get_p_incorrect([replicate_map_particles[r]]),
        'DE':  get_p_incorrect(replicate_de_flats[r])
    }
    for method in ['SMC','HMC']:
        for P, parts in particles_collect[method].items():
            p_inc[f"{method}_P{P}"] = get_p_incorrect(parts)

    acc_list = [
        ('MAP', y_map, p_inc['MAP']),
        ('DE',  y_de,  p_inc['DE'])
    ]
    for method in ['SMC','HMC']:
        for P in select_P:
            name = f"{method}_P{P}"
            if name in p_inc:
                acc_list.append((name, y_true_collect[method][P], p_inc[name]))

    best_thr_dict = {}
    for name, y_true_full, p_inc_scores in acc_list:
        p_corr = 1 - p_inc_scores
        y_mt   = binary_meta_labels(y_true_full, np.zeros_like(y_true_full))

        # One sweep over the threshold grid yields both the accuracy curve and
        # the F1-optimal threshold; a sample is flagged when p_corr < thr.
        accs = []
        best_f1, best_thr_corr = -1, 0
        for thr in ths:
            y_pred_thr = (p_corr < thr).astype(int)
            # Both arrays are binary, so the agreements are exactly TP + TN.
            accs.append(np.sum(y_pred_thr == y_mt) / len(y_mt))
            # zero_division=0 yields the same 0.0 as the default without a
            # warning when no sample is flagged.
            f1m = f1_score(y_mt, y_pred_thr, pos_label=1, zero_division=0)
            if f1m > best_f1:
                best_f1, best_thr_corr = f1m, thr
        accuracy_curves[name].append(accs)
        best_thr_dict[name] = best_thr_corr
        thr_storage[name].append(best_thr_corr)

        # --- metrics at the F1-optimal threshold ---
        y_pred = (p_corr < best_thr_corr).astype(int)
        prec  = precision_score(y_mt, y_pred, pos_label=1, zero_division=0)
        rec   = recall_score(y_mt, y_pred, pos_label=1, zero_division=0)
        f1    = f1_score(y_mt, y_pred, pos_label=1, zero_division=0)
        auroc = roc_auc_score(y_mt, p_inc_scores)  # threshold-free, scored on p_inc
        print(f"Optimal:{name} → Prec={prec:.3f}, Rec={rec:.3f}, F1={f1:.3f}, AUROC={auroc:.3f}")
        storage = metrics_storage[name]
        storage['prec'].append(prec)
        storage['rec'].append(rec)
        storage['f1'].append(f1)
        storage['auroc'].append(auroc)

        # --- metrics at the default threshold (p_inc >= 0.5) ---
        y_pred_default = (p_inc_scores >= 0.5).astype(int)
        prec_d  = precision_score(y_mt, y_pred_default, pos_label=1, zero_division=0)
        rec_d   = recall_score(y_mt, y_pred_default, pos_label=1, zero_division=0)
        f1_d    = f1_score(y_mt, y_pred_default, pos_label=1, zero_division=0)
        auroc_d = auroc  # AUROC does not depend on the threshold
        print(f"Default:{name} → Prec={prec_d:.3f}, Rec={rec_d:.3f}, F1={f1_d:.3f}, AUROC={auroc_d:.3f}")
        md = metrics_default[name]
        md['prec'].append(prec_d)
        md['rec'].append(rec_d)
        md['f1'].append(f1_d)
        md['auroc'].append(auroc_d)

        # --- confusion matrices as proportions of all samples ---
        cm = confusion_matrix(y_mt, y_pred, labels=[0,1])
        confusion_rates[name].append(cm.astype(float) / cm.sum())

        cm_def = confusion_matrix(y_mt, y_pred_default, labels=[0,1])
        confusion_rates_default[name].append(cm_def.astype(float) / cm_def.sum())


# --- after all replicates are done ---
def _print_replicate_averages(tag, table):
    """Print mean ± std over replicates of every metric for each method in `table`.

    `tag` is the word printed before the method name. Methods without any
    stored replicate are reported as such instead of printing NaNs.
    """
    rows = (('Precision', 'prec'), ('Recall   ', 'rec'),
            ('F1       ', 'f1'),   ('AUROC    ', 'auroc'))
    for name, vals in table.items():
        if not vals['prec']:
            print(f"{tag} {name}: no replicates available")
            continue
        print(f"{tag} {name}:")
        for title, field in rows:
            print(f"  {title} = {np.mean(vals[field]):.3f} ± {np.std(vals[field]):.3f}")


print("\n=== Average over replicates ===")
_print_replicate_averages("Optimal", metrics_storage)

print("\n=== Average over replicates (default thr=0.5) ===")
_print_replicate_averages("Default", metrics_default)


# Friendly labels for legends
# Human-readable method names used in legends and tables.
label_map = dict(
    MAP='MAP',
    DE='DE',
    SMC_P1='SMC (P=1)',
    SMC_P8='SMC (P=8)',
    HMC_P1='HMC (P=1)',
    HMC_P8='HMC (P=8)',
)

# =============================================================================
# 7) Mean ROC across replicates (with AUC in legend)
# =============================================================================
# plt.figure(figsize=(8,6))
# fpr_grid = np.linspace(0, 1, 200)
# for name, curves in roc_curves.items():
#     # Interpolate tprs onto common FPR grid
#     tprs_interp = []
#     for fpr, tpr in curves:
#         tprs_interp.append(np.interp(fpr_grid, fpr, tpr))
#     tprs_interp = np.vstack(tprs_interp)

#     mean_tpr = np.mean(tprs_interp, axis=0)
#     std_tpr  = np.std(tprs_interp, axis=0)

#     # Compute average AUROC from metrics_storage
#     mean_auroc = np.mean(metrics_storage[name]['auroc'])

#     # Plot
#     plt.plot(fpr_grid, mean_tpr,
#              label=f"{label_map[name]} (AUC={mean_auroc:.3f})")
#     plt.fill_between(fpr_grid,
#                      np.maximum(mean_tpr - std_tpr, 0),
#                      np.minimum(mean_tpr + std_tpr, 1),
#                      alpha=0.2)

# plt.plot([0,1], [0,1], 'k--', lw=1)
# plt.xlabel('False Positive Rate')
# plt.ylabel('True Positive Rate')
# plt.title('Mean Meta-ROC across replicates')
# plt.legend(loc='lower right')
# plt.grid(True)
# plt.tight_layout()
# plt.savefig('mean_meta_ROC_all_methods.png', dpi=300)
# plt.show()

# Mean ROC curve per method with a ±1 std band, interpolated on a common FPR grid.
plt.figure(figsize=(8,6))
fpr_grid = np.linspace(0,1,200)

handles = []
labels  = []

for name, curves in roc_curves.items():
    # A method whose result files were missing has no curves; np.vstack
    # would raise on the empty list, so leave it out of the figure.
    if not curves:
        print(f"Skipping mean ROC for {name}: no replicates available")
        continue

    # interpolate every replicate's TPR onto the shared grid
    arr = np.vstack([np.interp(fpr_grid, fpr, tpr) for fpr, tpr in curves])
    mean_tpr = arr.mean(axis=0)
    std_tpr  = arr.std(axis=0)

    # plot line
    line, = plt.plot(
        fpr_grid, mean_tpr,
        color=color_map[name],
        label=None  # legend entries are built by hand below
    )
    # plot shading, clipped to the valid [0, 1] range
    band = plt.fill_between(
        fpr_grid,
        np.maximum(mean_tpr-std_tpr, 0),
        np.minimum(mean_tpr+std_tpr, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None
    )

    # one legend entry per (line, band) pair
    handles.append((line, band))
    # NOTE(review): the AUC shown here is the stored p_inc-based AUROC, while
    # the plotted curves come from the scores in roc_curves — confirm that
    # mixing the two is intended.
    labels.append(f"{label_map[name]} (AUC={np.mean(metrics_storage[name]['auroc']):.3f})")

# diagonal chance line
diag, = plt.plot([0,1],[0,1],'k--', lw=1)
handles.append(diag)
labels.append('Chance')

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)
plt.tight_layout()

# draw legend with HandlerTuple for the (line,band) tuples
plt.legend(handles, labels, handler_map={tuple: HandlerTuple(ndivide=None)}, loc='lower right')

plt.savefig('mean_meta_ROC_all_methods.png', dpi=300)
plt.show()



# =============================================================================
# 8) Mean Accuracy vs Threshold at p(correct) over replicates
# =============================================================================
# plt.figure(figsize=(8,6))
# for name, all_acc in accuracy_curves.items():
#     # all_acc is a list of R arrays of length len(ths)
#     arr = np.vstack(all_acc)                # shape (R, len(ths))
#     mean_acc = arr.mean(axis=0)
#     std_acc  = arr.std(axis=0)

#     plt.plot(ths, mean_acc, label=label_map[name])
#     plt.fill_between(ths,
#                      np.maximum(mean_acc - std_acc, 0),
#                      np.minimum(mean_acc + std_acc, 1),
#                      alpha=0.2)

# plt.xlabel('Threshold on p(correct)')
# plt.ylabel('Accuracy')
# plt.title('Mean Accuracy vs p(correct) over replicates')
# plt.legend(loc='lower left')
# plt.grid(True)
# plt.tight_layout()
# plt.savefig('mean_accuracy_vs_pcorrect.png', dpi=300)
# plt.show()

# Mean accuracy-vs-threshold curve per method with a ±1 std band.
plt.figure(figsize=(8,6))

acc_handles = []
acc_labels  = []

for name, all_acc in accuracy_curves.items():
    # A method without stored replicates would make np.vstack raise.
    if not all_acc:
        print(f"Skipping mean accuracy curve for {name}: no replicates available")
        continue

    arr       = np.vstack(all_acc)               # one row per replicate, one column per threshold
    mean_acc  = arr.mean(axis=0)
    std_acc   = arr.std(axis=0)

    # plot mean accuracy curve
    line, = plt.plot(
        ths, mean_acc,
        color=color_map[name],
        label=None
    )
    # plot shading band, clipped to the valid [0, 1] range
    band = plt.fill_between(
        ths,
        np.maximum(mean_acc-std_acc, 0),
        np.minimum(mean_acc+std_acc, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None
    )

    # one handle for the tuple (line, band)
    acc_handles.append((line, band))
    acc_labels.append(label_map[name])

plt.xlabel('Threshold on p(correct)')
plt.ylabel('Accuracy')
plt.grid(True)
plt.tight_layout()

# draw legend: each tuple (curve+shade) as one entry
plt.legend(
    acc_handles,
    acc_labels,
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc='lower left'
)

plt.savefig('mean_accuracy_vs_pcorrect.png', dpi=300)
plt.show()


# =============================================================================
# 9) Average optimal thresholds over replicates
# =============================================================================
# Mean ± std of the per-replicate optimal thresholds for each method.
print("\n=== Average optimal p(correct) thresholds ===")
for name, thr_list in thr_storage.items():
    # An empty list would produce NaN plus a RuntimeWarning.
    if not thr_list:
        print(f"{name}: no replicates available")
        continue
    arr = np.array(thr_list)
    mean_thr = arr.mean()
    std_thr  = arr.std()
    print(f"{name}: threshold = {mean_thr:.3f} ± {std_thr:.3f}")


# =============================================================================
# 10) Plot averaged confusion‐matrix proportions with SE (updated colors)
# =============================================================================
import math

def _plot_mean_confusion(rate_store, cmap, filename_fmt):
    """Save one heatmap per method of the replicate-averaged confusion proportions.

    Args:
        rate_store: dict mapping a method name to a list of 2x2 normalized
            confusion matrices, one per replicate.
        cmap: matplotlib colormap name for the heatmap.
        filename_fmt: output path template containing a `{name}` placeholder.

    Each cell is annotated "mean±SE", where SE is the sample standard
    deviation (ddof=1) divided by the square root of the replicate count.
    """
    for name, mats in rate_store.items():
        # np.stack raises on an empty list (method without any replicate).
        if not mats:
            print(f"Skipping confusion plot for {name}: no replicates available")
            continue

        arr       = np.stack(mats, axis=0)              # one 2x2 matrix per replicate
        mean_rate = arr.mean(axis=0)
        if arr.shape[0] > 1:
            se_rate = arr.std(axis=0, ddof=1) / math.sqrt(arr.shape[0])
        else:
            # The SE is undefined for a single replicate; show NaN without a warning.
            se_rate = np.full_like(mean_rate, np.nan)

        # prepare annotations “mean±SE”
        annot = np.empty(mean_rate.shape, dtype=object)
        for i in range(2):
            for j in range(2):
                annot[i, j] = f"{mean_rate[i, j]:.3f}±{se_rate[i, j]:.3f}"

        fig, ax = plt.subplots(figsize=(4,4))
        sns.heatmap(
            mean_rate,
            annot=annot,
            fmt="",
            cmap=cmap,
            cbar=True,
            ax=ax,
            xticklabels=['Correct','Incorrect'],
            yticklabels=['Correct','Incorrect']
        )
        # draw orange border around entire matrix
        for spine in ax.spines.values():
            spine.set_edgecolor('orange')
            spine.set_linewidth(2)

        ax.set_xlabel('Pred'); ax.set_ylabel('True')
        fig.tight_layout()
        fig.savefig(filename_fmt.format(name=name), dpi=200)
        plt.show()


# 10) Confusion proportions at the optimal threshold (Oranges colormap)
_plot_mean_confusion(confusion_rates, 'Oranges', 'avg_cm_rates_se_{name}.png')

# =============================================================================
# 11) Averaged default‐threshold confusion proportions with SE (Blues colormap)
# =============================================================================
_plot_mean_confusion(confusion_rates_default, 'Blues', 'avg_cm_defaultthr_se_{name}.png')



# --- Generate LaTeX tables for both default and optimal thresholds ---
# helper to format mean ± standard error (SE) over replicates
def mean_se(a):
    """Format the values in `a` as "mean±SE" with three decimals.

    SE is the sample standard deviation (ddof=1) divided by sqrt(n).
    An empty input yields "n/a"; a single value yields "<mean>±n/a", because
    the standard error is undefined for fewer than two values.
    """
    arr = np.array(a, dtype=float)
    if arr.size == 0:
        return "n/a"
    if arr.size == 1:
        return f"{arr.mean():.3f}±n/a"
    return f"{arr.mean():.3f}±{arr.std(ddof=1)/np.sqrt(len(arr)):.3f}"

methods = ['MAP','DE','SMC_P1','SMC_P8','HMC_P1','HMC_P8']

def _print_latex_table(caption_line, label_line, metric_table):
    """Print one LaTeX table with a "mean±SE" row per entry of `methods`.

    `caption_line` and `label_line` are emitted verbatim; `metric_table` maps
    each method name to its lists of per-replicate metrics.
    """
    for line in (r"\begin{table}[H]", r"  \centering", caption_line, label_line,
                 r"  \begin{tabular}{ll|c|c|c|c}", r"    \toprule",
                 r"    \(P\) & Method & Precision & Recall & F1 & AUC-ROC \\",
                 r"    \midrule"):
        print(line)
    for name in methods:
        # First column: the part after '_' (e.g. 'P1'), or a dash for MAP/DE.
        p_col = name.split('_')[1] if '_' in name else '–'
        row = metric_table[name]
        cells = " & ".join(mean_se(row[field]) for field in ('prec', 'rec', 'f1', 'auroc'))
        print(f"    {p_col} & {label_map[name]} & {cells} \\\\")
    for line in (r"    \bottomrule", r"  \end{tabular}", r"\end{table}"):
        print(line)

# DEFAULT table
_print_latex_table(
    r"  \caption{Performance under the default decision threshold (0.5).}",
    r"  \label{tab:perf_default_avg}",
    metrics_default,
)
print()

# OPTIMAL table
_print_latex_table(
    r"  \caption{Performance at the optimal \(F_1\) decision threshold.}",
    r"  \label{tab:perf_optimal_avg}",
    metrics_storage,
)




