import os, random, pickle, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score, roc_curve, confusion_matrix
)
from sklearn.preprocessing import StandardScaler
from sentence_transformers import SentenceTransformer
from sklearn.metrics import log_loss
#from decimal import Decimal, getcontext
from matplotlib.legend_handler import HandlerTuple
import pyro
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from PIL import Image
import tarfile, urllib.request

# --------------------------
# 1) Parameters and device
# --------------------------
# Compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Number of replicates, deep-ensemble members per replicate, and the P values
# for which SMC/HMC results are looked up.
R = 5
ensemble_size = 10
select_P = [1, 8]
N_particles = 10

# Number of in-distribution examples taken from each embedding file.
N_id = 9000
TRAIN_META = N_id

# Number of OOD examples sampled from each of the three OOD sources.
n_close = n_corrupt = n_far = 3000

# Cache files: trained MAP/DE parameter vectors and OOD embeddings.
DATA_PATH = 'map_de_models.pkl'
OOD_CACHE_TRAIN = 'ood_embeddings_train.pt'
OOD_CACHE_EVAL = 'ood_embeddings_eval.pt'

# Small constant added inside logarithms when computing entropies.
eps = 1e-12

# Values that appear in the SMC/HMC result file names.
D_DIM = 20490
BURNIN = 200
THIN = 200
M_SMC = 4

# Gaussian prior standard deviations for weights and biases (variance 0.2).
sigma_w = sigma_b = np.sqrt(0.2)

# --------------------------
# 2) Model definition
# --------------------------
class SimpleMLP(nn.Module):
    """Linear classifier with an optional single ReLU hidden layer.

    When ``hidden_dim`` is ``None`` or ``0`` the model is one linear layer
    registered as ``fc`` (multinomial logistic regression). Otherwise it is
    ``fc1`` -> ReLU -> ``fc2``. ``forward`` returns raw logits.
    """

    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        wants_hidden = not (hidden_dim is None or hidden_dim == 0)
        if wants_hidden:
            # Two-layer network: input -> hidden -> classes.
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
        else:
            # Plain linear map from inputs to class logits.
            self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # The presence of ``fc`` identifies the no-hidden-layer variant.
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

# --------------------------
# 3) Utilities: unflatten & predict
# --------------------------
def unflatten(flat, net):
    """Split a flat parameter vector into tensors shaped like ``net``'s parameters.

    Consecutive slices of ``flat`` are assigned to the parameters in the order
    given by ``net.named_parameters()`` and viewed with each parameter's shape.

    Returns:
        dict mapping parameter name to the reshaped slice of ``flat``.
    """
    named_slices = {}
    start = 0
    for name, reference in net.named_parameters():
        stop = start + reference.numel()
        named_slices[name] = flat[start:stop].view(reference.shape)
        start = stop
    return named_slices

def softmax_np(x):
    """Softmax of a NumPy array along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    row_max = x.max(axis=-1, keepdims=True)
    weights = np.exp(x - row_max)
    normaliser = weights.sum(axis=-1, keepdims=True)
    return weights / normaliser

def predict_particles(x, particles, NetClass):
    """Evaluate class probabilities of ``x`` under every parameter vector.

    A single ``NetClass()`` instance (default constructor arguments) serves as
    a template; each flat vector in ``particles`` is reshaped onto its
    parameters and applied with ``functional_call``.

    Args:
        x: input tensor, expected to live on the module-level ``device``.
        particles: iterable of flat parameter vectors (NumPy arrays or
            tensors), each laid out in ``net.named_parameters()`` order.
        NetClass: model class that can be constructed without arguments.

    Returns:
        NumPy array obtained by stacking the per-particle softmax outputs
        along a new leading axis.
    """
    net = NetClass().to(device).eval()
    outs = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for flat in particles:
            # Cast to the input dtype so vectors stored in another precision
            # (e.g. float64 arrays) do not trigger a dtype mismatch.
            flat_t = torch.as_tensor(flat, dtype=x.dtype, device=device).reshape(-1)
            params = unflatten(flat_t, net)
            logits = functional_call(net, params, x)
            outs.append(F.softmax(logits, dim=-1).cpu().numpy())
    return np.stack(outs, axis=0)

def flatten_net(net):
    """Concatenate all parameters of ``net`` into one detached 1-D tensor.

    Parameters are taken in ``net.parameters()`` order. Each one is detached
    first so the result carries no autograd history and can be converted with
    ``.cpu().numpy()``; ``reshape`` is used so non-contiguous parameters are
    handled as well.
    """
    return torch.cat([p.detach().reshape(-1) for p in net.parameters()])

# --------------------------
#      Filtered CIFAR-10 Dataset Class
# --------------------------
class FilteredCIFAR10(Dataset):
    """CIFAR-10 restricted to the examples whose label is in ``allowed_labels``.

    Only the indices of the retained examples are stored. Images are read and
    transformed on access, so building the dataset does not hold every
    transformed image in memory at once.
    """

    def __init__(self, root, train, transform, download, allowed_labels):
        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        self.allowed_labels = allowed_labels
        keep = set(allowed_labels)
        # Filter on the raw integer targets; no image is decoded here.
        self.indices = [i for i, label in enumerate(self.dataset.targets) if label in keep]

    @property
    def data(self):
        # Kept for callers that read ``.data``: materialises every retained
        # (image, label) pair, which can be very large.
        return [self.dataset[i] for i in self.indices]

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
    
# --------------------------
#   ResNet-50 Embedding Extraction (In-Domain: Labels 0-9)
# --------------------------
def _embed_with_extractor(feature_extractor, loader):
    """Run every batch of ``loader`` through ``feature_extractor``.

    Returns a ``(features, targets)`` pair of CPU tensors; each feature map is
    flattened to one row per example.
    """
    feature_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            feature_chunks.append(features.view(features.size(0), -1).cpu())
            target_chunks.append(targets)
    return torch.cat(feature_chunks, dim=0), torch.cat(target_chunks, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt",
    allowed_labels=None
):
    """Return ResNet-50 embeddings of CIFAR-10, computing and caching if needed.

    Args:
        train_cache_path: file holding the cached ``(X_train, y_train)`` pair.
        test_cache_path: file holding the cached ``(X_test, y_test)`` pair.
        allowed_labels: class labels to keep; ``None`` means all ten classes.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` tensors.

    Note:
        When both cache files exist they are returned as-is, regardless of
        ``allowed_labels``.
    """
    if allowed_labels is None:
        allowed_labels = list(range(10))

    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for whole CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for whole CIFAR-10...")
    # Upscale to 224 px and apply the ImageNet normalisation statistics.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FilteredCIFAR10(root='./data', train=True, transform=transform, download=True, allowed_labels=allowed_labels)
    test_dataset = FilteredCIFAR10(root='./data', train=False, transform=transform, download=True, allowed_labels=allowed_labels)
    # shuffle=False keeps the embedding order aligned with the dataset order.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # ``weights=IMAGENET1K_V1`` is the non-deprecated spelling of
    # ``pretrained=True``.
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Remove the final fc layer to obtain 2048-dim features.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _embed_with_extractor(feature_extractor, train_loader)
    X_test, y_test = _embed_with_extractor(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test

# --------------------------
# 4) Extended 6D feature extractor
# --------------------------
def compute_features(probs):
    """Summarise per-particle class probabilities into uncertainty features.

    ``probs`` is indexed as (particle, example, class): statistics are taken
    over axis 0, class sums over axis 2.

    Returns:
        Tuple ``(H_tot, H_epi, mean_max, mean_diff, var_max, var_diff, preds)``
        with one entry per example in each array:
        entropy of the particle-averaged distribution; that entropy minus the
        mean per-particle entropy; mean and variance over particles of the top
        class probability; mean and variance over particles of the gap between
        the two largest probabilities; argmax of the averaged distribution.
    """
    avg_p = probs.mean(axis=0)
    preds = avg_p.argmax(axis=1)

    # Entropy of the averaged prediction and of each particle's prediction.
    H_tot = -(avg_p * np.log(avg_p + eps)).sum(axis=1)
    per_particle_entropy = -(probs * np.log(probs + eps)).sum(axis=2)
    H_epi = H_tot - per_particle_entropy.mean(axis=0)

    # Top probability and top-1 minus top-2 margin for every particle.
    top_prob = probs.max(axis=2)
    ascending = np.sort(probs, axis=2)
    margin = ascending[:, :, -1] - ascending[:, :, -2]

    return (
        H_tot,
        H_epi,
        top_prob.mean(axis=0),
        margin.mean(axis=0),
        top_prob.var(axis=0),
        margin.var(axis=0),
        preds,
    )

# --------------------------
# 5) Meta-labels: OOD or misclassified
# --------------------------
def make_labels(preds, true, is_id):
    """Build binary meta-labels: 1 for OOD inputs or wrong predictions, else 0.

    Wherever ``is_id`` is False the label is 1; elsewhere it is 1 exactly when
    ``preds`` differs from ``true``.
    """
    misclassified = (preds != true).astype(int)
    out_of_domain = ~is_id
    return np.where(out_of_domain, 1, misclassified)

# --------------------------
# 6) Load MAP & DE particles
# --------------------------
def _seed_all(seed):
    """Seed torch, NumPy, Python's ``random`` and Pyro with the same value."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    pyro.set_rng_seed(seed)


def _train_epoch(model, optimizer, criterion, loader):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    The loss is the cross-entropy plus the Gaussian-prior penalty on
    ``model.fc`` (variances ``sigma_w**2`` / ``sigma_b**2``), the penalty being
    divided by the total number of training examples.
    """
    model.train()
    n_examples = len(loader.dataset)
    running_loss = 0.0
    for features, labels in loader:
        # Batches come from CPU tensors; the model lives on ``device``.
        features = features.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(features)
        ce_loss = criterion(outputs, labels)
        reg_loss = (torch.sum(model.fc.weight**2) / (2 * sigma_w**2) +
                    torch.sum(model.fc.bias**2)   / (2 * sigma_b**2))
        loss = ce_loss + reg_loss / n_examples
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


if os.path.exists(DATA_PATH):
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # DATA_PATH file produced by this script.
    with open(DATA_PATH, 'rb') as fh:
        data = pickle.load(fh)
    all_map_parts = data['replicate_map_particles']
    # Flat list of length R * ensemble_size, ordered replicate by replicate.
    raw_de = data['replicate_de_flats']
else:
    # Load the full in-domain embeddings (labels 0-9).
    allowed_labels = list(range(10))
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_all_train, y_all_train, X_all_val, y_all_val = create_resnet50_embedded_cifar10_dataset(
            train_cache_path=train_cache,
            test_cache_path=test_cache,
            allowed_labels=allowed_labels
    )
    n_total = X_all_train.size(0)        # should be 50 000 for CIFAR-10
    print(f"Total CIFAR-10 train embeddings: {n_total}")

    # The first n_train_es embeddings are used for training, in stored order.
    # With 50 000 training embeddings nothing is held out, and no validation
    # or early stopping is performed: every model trains for max_epochs.
    n_train_es = 50000
    X_train_subset = X_all_train[:n_train_es]
    y_train_subset = y_all_train[:n_train_es]

    train_dataset = TensorDataset(X_train_subset, y_train_subset)
    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    max_epochs = 200
    input_dim = X_train_subset.shape[1]

    #  MAP: one regularised logistic-regression fit per replicate.
    map_particles = []
    for r in range(1, R+1):
        print(f"\n===== Starting replicate {r} =====\n")
        _seed_all(r)

        model_mlp = SimpleMLP(input_dim=input_dim, hidden_dim=0, num_classes=10).to(device)
        optimizer_mlp = optim.Adam(model_mlp.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(max_epochs):
            train_loss = _train_epoch(model_mlp, optimizer_mlp, criterion, train_loader)
            # Report the epoch-mean loss (the same quantity the DE log prints).
            print(f"Replicate {r}, MAP Epoch {epoch+1:04d}: Train Loss = {train_loss:.4f}")
        # detach() is required before converting parameters to NumPy.
        map_particles.append(flatten_net(model_mlp).detach().cpu().numpy())

    #  DE: ensemble_size independently seeded fits per replicate.
    de_flats = []
    for r in range(1, R+1):
        print(f"\n===== Starting replicate {r} =====\n")
        _seed_all(r)

        for m in range(ensemble_size):
            member_seed = r * 1000 + m
            _seed_all(member_seed)

            model_de = SimpleMLP(input_dim=input_dim, hidden_dim=0, num_classes=10).to(device)
            optimizer_de = optim.Adam(model_de.parameters(), lr=0.001)
            criterion_de = nn.CrossEntropyLoss()

            for epoch in range(max_epochs):
                train_loss = _train_epoch(model_de, optimizer_de, criterion_de, train_loader)
                print(f"Replicate {r}, DE Model (seed={member_seed}) Epoch {epoch+1}: Train Loss = {train_loss:.4f}")
            de_flats.append(flatten_net(model_de).detach().cpu().numpy())

    #  Cache the trained parameter vectors for later runs.
    with open(DATA_PATH, 'wb') as fh:
        pickle.dump({
            'replicate_map_particles': map_particles,
            'replicate_de_flats'     : de_flats
        }, fh)

    # Expose freshly trained models under the same names as the cached path,
    # so the evaluation code also works on a first run.
    all_map_parts = map_particles
    raw_de = de_flats

# Regroup the flat DE list into one list of ensemble_size vectors per replicate.
all_de_parts = [
    raw_de[i*ensemble_size:(i+1)*ensemble_size]
    for i in range(len(all_map_parts))
]

# --------------------------
# 7) Load embeddings: ID and two OOD splits
# --------------------------
# Load full ID dataset
# Meta-eval ID data: the first N_id CIFAR-10 test embeddings. The file is read
# once and reused below.
X_full, y_full = torch.load('cifar10_test_embeddings.pt', map_location='cpu')
X_id_full = X_full[:N_id].to(device)
y_id_full = y_full[:N_id].numpy().flatten()
# Meta-train ID data: the first N_id CIFAR-10 training embeddings.
X_train_full, y_train_full = torch.load('cifar10_train_embeddings.pt', map_location='cpu')
X_id_train = X_train_full[:N_id].to(device)
y_id_train = y_train_full[:N_id].numpy().flatten()
# The meta-eval tensors are the same test-set slices loaded above.
X_id_eval = X_id_full
y_id_eval = y_id_full

# --------------------------
# Helpers for loading/creating OOD ResNet-50 embeddings
# --------------------------
# --- 1) CIFAR-100 “close” OOD: classes not in CIFAR-10 ---
def sample_cifar100_not_in_cifar10(seed, n1, root='./data'):
    """Sample ``n1`` CIFAR-100 test images whose class name is not a CIFAR-10 class.

    Both test sets are downloaded to ``root`` if missing. Python's global
    ``random`` generator is seeded with ``seed`` before sampling without
    replacement.

    Returns:
        list of ``n1`` PIL images.
    """
    cifar10 = datasets.CIFAR10(root=root, train=False, download=True)
    cifar100 = datasets.CIFAR100(root=root, train=False, download=True)

    # Fine-class names of CIFAR-100 that do not appear among CIFAR-10's names.
    known_names = set(cifar10.classes)
    fine_names = cifar100.classes
    unseen_names = {name for name in fine_names if name not in known_names}

    # Positions of all test images carrying one of those names.
    candidates = []
    for position, label in enumerate(cifar100.targets):
        if fine_names[label] in unseen_names:
            candidates.append(position)

    random.seed(seed)
    picked = random.sample(candidates, n1)

    return [Image.fromarray(cifar100.data[position]) for position in picked]

# --- 2) CIFAR-10-C “corrupt” OOD: all corruption types/severities ---
def sample_cifar10c(seed, n2, root='./data'):
    """Sample ``n2`` images uniformly from all CIFAR-10-C corruption files.

    The archive is downloaded to ``root`` and extracted if necessary. Every
    ``.npy`` file in the extracted directory whose name does not contain
    ``'labels'`` is treated as a stack of images along its first axis.

    Files are visited in sorted name order so that a given ``seed`` selects
    the same images on every machine (``os.listdir`` order is arbitrary), and
    they are memory-mapped so that only the sampled images are read.

    Python's global ``random`` generator is seeded with ``seed``.

    Returns:
        list of ``n2`` PIL images.

    Raises:
        ValueError: if no corruption file is found or ``n2`` exceeds the
            number of available images.
    """
    url = 'https://zenodo.org/record/2535967/files/CIFAR-10-C.tar'
    tar_path = os.path.join(root, 'CIFAR-10-C.tar')
    extract_dir = os.path.join(root, 'CIFAR-10-C')

    # Download if needed
    if not os.path.exists(tar_path):
        os.makedirs(root, exist_ok=True)
        print("Downloading CIFAR-10-C (≈180 MB)...")
        urllib.request.urlretrieve(url, tar_path)

    # Extract if needed
    if not os.path.isdir(extract_dir):
        print("Extracting CIFAR-10-C...")
        # NOTE(security): extractall trusts member paths in the archive; this
        # is only acceptable because the archive comes from the fixed URL above.
        with tarfile.open(tar_path) as tar:
            tar.extractall(path=root)

    # Memory-map every corruption file (the labels file is skipped).
    fnames = sorted(
        fname for fname in os.listdir(extract_dir)
        if fname.endswith('.npy') and 'labels' not in fname
    )
    if not fnames:
        raise ValueError(f"no corruption .npy files found in {extract_dir}")
    arrays = [np.load(os.path.join(extract_dir, fname), mmap_mode='r') for fname in fnames]

    # ends[k] is the global index one past the last image of file k, i.e. the
    # indexing the files would have if concatenated along axis 0.
    ends = np.cumsum([len(arr) for arr in arrays])
    total = int(ends[-1])

    random.seed(seed)
    chosen = random.sample(range(total), n2)

    images = []
    for global_idx in chosen:
        file_idx = int(np.searchsorted(ends, global_idx, side='right'))
        file_start = int(ends[file_idx - 1]) if file_idx > 0 else 0
        pixels = np.asarray(arrays[file_idx][global_idx - file_start])
        images.append(Image.fromarray(pixels.astype(np.uint8)))
    return images

# --- 3) SVHN “far” OOD: street-view digits ---
def sample_svhn(seed, n3, root='./data'):
    """Sample ``n3`` images from the SVHN test split.

    The split is downloaded to ``root`` if missing and loaded without a
    transform, so each item's first element is returned unchanged. Python's
    global ``random`` generator is seeded with ``seed`` before sampling
    without replacement.
    """
    svhn = datasets.SVHN(root=root, split='test', download=True, transform=None)

    random.seed(seed)
    picked = random.sample(range(len(svhn)), n3)

    images = []
    for position in picked:
        image, _target = svhn[position]
        images.append(image)
    return images

def compute_cifar_ood_embeddings(device, seed, n_close, n_corrupt, n_far):
    """Sample three OOD image sets and embed them with a pretrained ResNet-50.

    Args:
        device: torch device on which the feature extractor runs.
        seed: seed forwarded to each of the three sampling helpers.
        n_close: number of CIFAR-100 images to sample.
        n_corrupt: number of CIFAR-10-C images to sample.
        n_far: number of SVHN images to sample.

    Returns:
        dict with keys ``'ood_close'``, ``'ood_corrupt'`` and ``'ood_far'``,
        each a CPU tensor with one flattened feature row per image.
    """
    # 1) Sample PIL images
    close_imgs   = sample_cifar100_not_in_cifar10(seed, n_close)
    corrupt_imgs = sample_cifar10c(seed, n_corrupt)
    far_imgs     = sample_svhn(seed, n_far)

    # 2) Preprocess + extract features
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])
    # ``weights=IMAGENET1K_V1`` is the non-deprecated spelling of
    # ``pretrained=True``.
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Drop the final fc layer so the pooled features are returned.
    feat_extractor = nn.Sequential(*list(resnet50.children())[:-1]).to(device).eval()

    def imgs_to_embs(img_list, batch_size=128):
        # Transform one batch at a time instead of stacking the whole split
        # first; the batches are the same consecutive groups of 128 images.
        feats = []
        with torch.no_grad():
            for start in range(0, len(img_list), batch_size):
                chunk = img_list[start:start + batch_size]
                batch = torch.stack([transform(img) for img in chunk]).to(device)
                out = feat_extractor(batch).view(batch.size(0), -1)
                feats.append(out.cpu())
        return torch.cat(feats, dim=0)

    return {
        'ood_close':   imgs_to_embs(close_imgs),
        'ood_corrupt': imgs_to_embs(corrupt_imgs),
        'ood_far':     imgs_to_embs(far_imgs),
    }

# Function to get or cache OOD embeddings
def get_cifar_ood(cache_path, device, seed, n_close, n_corrupt, n_far):
    """Return the OOD embedding dict, reading ``cache_path`` when it exists.

    On a cache miss the embeddings are computed with
    ``compute_cifar_ood_embeddings``, saved to ``cache_path`` and returned.
    A cached file is loaded onto the CPU and returned without checking that
    it matches the requested seed or sizes.
    """
    if not os.path.exists(cache_path):
        embeddings = compute_cifar_ood_embeddings(device, seed, n_close, n_corrupt, n_far)
        torch.save(embeddings, cache_path)
        return embeddings
    return torch.load(cache_path, map_location='cpu')


# Meta-train and meta-eval OOD sets are drawn with different seeds.
# NOTE(review): both seeds sample from the same source pools, so the two OOD
# sets can share images — confirm whether disjoint splits are required.
seed1 = 42
seed2 = 2
ood_train = get_cifar_ood(OOD_CACHE_TRAIN, device, seed1, n_close, n_corrupt, n_far)
ood_eval = get_cifar_ood(OOD_CACHE_EVAL, device, seed2, n_close, n_corrupt, n_far)

# Meta-train split: stack all OOD blocks, then append them to the ID examples.
# The OOD embeddings are CPU tensors; move them next to the ID tensor so that
# torch.cat does not mix devices when running on a GPU.
X_ood_train = torch.cat([ood_train[k] for k in ood_train], dim=0).to(X_id_train.device)
y_ood_train = np.full(len(X_ood_train), -1, dtype=int)   # OOD has no true class
is_id_train = np.zeros(len(X_ood_train), dtype=bool)

X_meta = torch.cat([X_id_train, X_ood_train], dim=0)
y_meta = np.concatenate([y_id_train, y_ood_train])
is_id_meta = np.concatenate([np.ones(len(y_id_train), bool), is_id_train])

# Meta-eval split, built the same way.
X_ood_eval = torch.cat([ood_eval[k] for k in ood_eval], dim=0).to(X_id_eval.device)
y_ood_eval = np.full(len(X_ood_eval), -1, dtype=int)
is_ood_part_eval = np.zeros(len(X_ood_eval), dtype=bool)

X_eval = torch.cat([X_id_eval, X_ood_eval], dim=0)
y_eval = np.concatenate([y_id_eval, y_ood_eval])
is_id_eval = np.concatenate([np.ones(len(y_id_eval), bool), is_ood_part_eval])

# --- pre-allocate the result containers -----------------------
methods_dict = ['MAP', 'DE'] + [f'SMC_P{P}' for P in select_P] + [f'HMC_P{P}' for P in select_P]

# For each method: per-replicate auc, f1, precision, recall at the F1-optimal
# threshold.
metrics = {
    name: {'auc': [], 'f1': [], 'precision': [], 'recall': []}
    for name in methods_dict
}
# Same metrics at the fixed 0.5 threshold.
default_metrics   = { name: {'precision':[], 'recall':[], 'f1':[], 'auc':[]} for name in methods_dict }

# ROC curves per replicate.
roc_curves = { name: [] for name in methods_dict }

# Threshold grid for the accuracy-vs-threshold curves.
thr_grid = np.linspace(0, 1, 201)
acc_curves = { name: [] for name in metrics.keys() }

# Normalised confusion matrices (tuned threshold / 0.5 threshold).
conf_rates = { name: [] for name in methods_dict }
default_conf_rates = { name: [] for name in methods_dict }

# --------------------------
# 8) Compute features & labels per method
# --------------------------

def load_bmc(pref, rep=None):
    """Load SMC (``pref='psmc'``) or HMC particle files for every P in ``select_P``.

    Result files are numbered ``rep * total + i + 1`` for ``i`` in
    ``range(total)``, where ``total`` is ``P`` for SMC and ``N_particles * P``
    for HMC. Files that do not exist are skipped.

    Args:
        pref: ``'psmc'`` for SMC files, anything else for HMC files.
        rep: zero-based replicate whose files are loaded; ``None`` pools the
            files of all ``R`` replicates.

    Returns:
        dict mapping P to the row-stacked particle array; a P with no files
        on disk is absent.
    """
    out = {}
    key = 'psmc_single_x' if pref == 'psmc' else 'hmc_single_x'
    reps = range(R) if rep is None else [rep]
    for P in select_P:
        arr = []
        total = P if pref == 'psmc' else N_particles * P
        for rr in reps:
            for i in range(total):
                node = rr * total + i + 1
                fn = (f'BayesianNN_CIFAR_MAP_{pref}_SimpleMLP_N{N_particles}_M{M_SMC}_node{node}.mat' if pref == 'psmc'
                      else f'BayesianNN_CIFAR_{pref}_SimpleMLP_MAP_d{D_DIM}_thin{THIN}_burnin{BURNIN}_node{node}.mat')
                if os.path.exists(fn):
                    mat = loadmat(fn)
                    arr.append(mat[key])
        if arr:
            out[P] = np.vstack(arr)
    return out


for r in range(R):
    map_parts = [ all_map_parts[r] ]
    de_parts  = list(all_de_parts[r])

    methods = {'MAP': map_parts, 'DE': de_parts}

    # Base-model predictions per method on each split.
    preds_train = {}
    preds_eval  = {}

    # Only this replicate's SMC/HMC files; pooling all replicates here would
    # give every replicate identical particles.
    smc = load_bmc('psmc', r)
    hmc = load_bmc('hmc', r)
    for P in select_P:
        if P in smc: methods[f'SMC_P{P}'] = smc[P]
        if P in hmc: methods[f'HMC_P{P}'] = hmc[P]

    feat_train, label_train = {}, {}
    feat_eval,  label_eval  = {}, {}
    for name, parts in methods.items():
        # Meta-train split: the first 6 outputs are meta-features, the 7th is
        # the base-model prediction.
        probs_t = predict_particles(X_meta, parts, SimpleMLP)
        feats_t = compute_features(probs_t)
        feat_train[name] = np.stack(feats_t[:6], axis=1)
        label_train[name] = make_labels(feats_t[6], y_meta, is_id_meta)
        preds_train[name] = feats_t[6]

        # Meta-eval split.
        probs_e = predict_particles(X_eval, parts, SimpleMLP)
        feats_e = compute_features(probs_e)
        feat_eval[name] = np.stack(feats_e[:6], axis=1)
        label_eval[name] = make_labels(feats_e[6], y_eval, is_id_eval)
        preds_eval[name] = feats_e[6]

    # --------------------------
    # 8.5) Standardize meta-features per method
    # --------------------------
    scalers = {}
    for name in methods:
        scaler = StandardScaler()
        # Fit on the meta-train features only, then apply to both splits.
        feat_train[name] = scaler.fit_transform(feat_train[name])
        feat_eval[name]  = scaler.transform(feat_eval[name])
        scalers[name] = scaler

    # --------------------------
    # 9) Train meta-classifiers
    # --------------------------
    clfs = {}
    for name in methods:
        clf = LogisticRegression(max_iter=1000)
        clf.fit(feat_train[name], label_train[name])
        clfs[name] = clf

    # --------------------------
    # 10) Evaluate: metrics, ROC and accuracy curves
    # --------------------------
    for name, clf in clfs.items():
        y_true = label_eval[name]
        prob   = clf.predict_proba(feat_eval[name])[:, 1]

        # Select the threshold with the best F1 on a 101-point grid.
        best_thr, best_f1 = 0, 0
        for thr in np.linspace(0, 1, 101):
            pr = (prob >= thr).astype(int)
            f = f1_score(y_true, pr, zero_division=0)
            if f > best_f1:
                best_f1, best_thr = f, thr
        tuned_pred = (prob >= best_thr)
        prec  = precision_score(y_true, tuned_pred, zero_division=0)
        rec   = recall_score(y_true, tuned_pred)
        auc   = roc_auc_score(y_true, prob)
        cm = confusion_matrix(y_true, tuned_pred, labels=[0,1])
        cm_rate = cm.astype(float) / cm.sum()   # normalized over total examples

        metrics[name]['auc'].append(auc)
        metrics[name]['f1'].append(best_f1)
        metrics[name]['precision'].append(prec)
        metrics[name]['recall'].append(rec)
        conf_rates[name].append(cm_rate)

        fpr, tpr, _ = roc_curve(y_true, prob)
        roc_curves[name].append((fpr, tpr))

        # Same metrics at the fixed threshold 0.5 (AUC does not depend on it).
        preds_def = (prob >= 0.5).astype(int)
        prec_def  = precision_score(y_true, preds_def, zero_division=0)
        rec_def   = recall_score(y_true, preds_def)
        f1_def    = f1_score(y_true, preds_def, zero_division=0)
        default_metrics[name]['precision'].append(prec_def)
        default_metrics[name]['recall'].append(rec_def)
        default_metrics[name]['f1'].append(f1_def)
        default_metrics[name]['auc'].append(auc)

        cm_def      = confusion_matrix(y_true, preds_def, labels=[0,1])
        cm_rate_def = cm_def.astype(float) / cm_def.sum()
        default_conf_rates[name].append(cm_rate_def)

        # Accuracy of the respond/abstain decision over thr_grid.
        # Meta-label 0 means "respond is desired", 1 means "abstain is desired".
        # A local name is used on purpose: rebinding the global ``y_meta``
        # here would corrupt the meta-train labels of later replicates.
        p_keep = 1.0 - prob
        should_respond = (y_true == 0)
        respond = p_keep[None, :] >= thr_grid[:, None]      # (len(thr_grid), N_eval)
        # Correct when the decision matches the desired one.
        acc_r = (respond == should_respond[None, :]).mean(axis=1)
        acc_curves[name].append(acc_r)


# print averaged metrics
# Report mean ± standard error over replicates, first at the F1-tuned
# threshold and then at the fixed 0.5 threshold.
_reported = ('auc', 'f1', 'precision', 'recall')

for name in methods_dict:
    print(f"=== {name} ===")
    for metric in _reported:
        values = np.asarray(metrics[name][metric])      # one entry per replicate
        centre = values.mean()
        spread = values.std(ddof=1) / np.sqrt(R)        # standard error
        print(f"  {metric.upper():9s}: {centre:.4f} ± {spread:.4f}")

print("\n=== Default‐threshold (thr=0.5) metrics ===")
for name in methods_dict:
    print(f"--- {name} @0.5 ---")
    for metric in _reported:
        values = np.asarray(default_metrics[name][metric])
        centre = values.mean()
        spread = values.std(ddof=1) / np.sqrt(R)
        print(f"  {metric.upper():9s}: {centre:.4f} ± {spread:.4f}")


# -----------plots---------
# Legend labels per method. Raw strings keep the backslash of the mathtext
# command ``\parallel`` without relying on an invalid escape sequence.
label_map = {
    'MAP':    'MAP',
    'DE':     'DE',
    'SMC_P1': r'SMC$_\parallel$ (P=1)',
    'SMC_P8': r'SMC$_\parallel$ (P=8)',
    'HMC_P1': r'HMC$_\parallel$ (P=1)',
    'HMC_P8': r'HMC$_\parallel$ (P=8)',
}
# Line colour per method.
color_map = {
    'MAP':    'blue',
    'DE':     'orange',
    'SMC_P1': 'red',
    'SMC_P8': 'darkred',
    'HMC_P1': 'green',
    'HMC_P8': 'olive'
}

# confusion matrix
# Mean confusion rates (F1-tuned threshold) with per-cell standard errors.
mean_cm = {}  # mean confusion rate
se_cm   = {}  # standard error per cell
annots  = {}  # "mean±se" strings used as heatmap annotations

for name, mats in conf_rates.items():
    if not mats:
        # No replicate produced results for this method (e.g. its particle
        # files were not found); np.stack would raise on an empty list.
        print(f"Skipping confusion matrix for {name}: no results collected.")
        continue
    arr = np.stack(mats, axis=0)                 # shape (n_replicates, 2, 2)
    mean = arr.mean(axis=0)                      # (2,2)
    se   = arr.std(axis=0, ddof=1) / np.sqrt(R)  # (2,2)
    mean_cm[name] = mean
    se_cm[name]   = se

    annot = np.empty(mean.shape, dtype=object)
    for i in (0,1):
        for j in (0,1):
            annot[i,j] = f"{mean[i,j]:.3f}±{se[i,j]:.3f}"
    annots[name] = annot

for name in methods_dict:
    if name not in mean_cm:
        continue
    fig, ax = plt.subplots(figsize=(4,4))
    sns.heatmap(
        mean_cm[name],
        annot=annots[name],
        fmt="",                  # annotations are already formatted strings
        cmap="Oranges",
        cbar=True,
        xticklabels=['Correct','Incorrect'],
        yticklabels=['Correct','Incorrect'],
        ax=ax
    )
    # draw orange border around entire matrix
    for spine in ax.spines.values():
        spine.set_edgecolor('orange')
        spine.set_linewidth(2)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    # save to disk:
    fig.savefig(f'cifar_avg_cm_rates_se_{name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so figures do not accumulate across methods.
    plt.close(fig)

# ————————————————
# Plot confusion matrices at thr=0.5
for name in methods_dict:
    if not default_conf_rates[name]:
        # No replicate produced results for this method; np.stack would raise
        # on an empty list.
        print(f"Skipping thr=0.5 confusion matrix for {name}: no results collected.")
        continue
    arr = np.stack(default_conf_rates[name], axis=0)   # (n_replicates, 2, 2)
    mean_def = arr.mean(axis=0)
    se_def   = arr.std(axis=0, ddof=1) / np.sqrt(R)
    # "mean±se" annotation for each cell
    ann_def = np.empty(mean_def.shape, dtype=object)
    for i in (0,1):
        for j in (0,1):
            ann_def[i,j] = f"{mean_def[i,j]:.3f}±{se_def[i,j]:.3f}"

    fig, ax = plt.subplots(figsize=(4,4))
    sns.heatmap(
        mean_def,
        annot=ann_def,
        fmt="",              # annotations are already formatted strings
        cmap="Blues",
        xticklabels=['Correct','Incorrect'],
        yticklabels=['Correct','Incorrect'],
        ax=ax
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    fig.savefig(f'cifar_avg_cm_rates_se_default_{name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so figures do not accumulate across methods.
    plt.close(fig)


# ROC plot
# Mean ROC curve per method: each replicate's curve is interpolated onto a
# common FPR grid, then averaged, with a ± standard-error band.
fpr_grid = np.linspace(0,1,200)
plt.figure(figsize=(8,6))

handles = []
labels  = []
for name in methods_dict:
    if not roc_curves[name]:
        # No replicate produced results for this method; np.vstack would
        # raise on an empty list.
        print(f"Skipping ROC curve for {name}: no results collected.")
        continue

    # Interpolated TPRs, one row per replicate.
    arr = np.vstack([np.interp(fpr_grid, fpr, tpr) for fpr, tpr in roc_curves[name]])

    mean_tpr = arr.mean(axis=0)
    se_tpr   = arr.std(axis=0, ddof=1) / np.sqrt(R)

    mean_auc = np.mean(metrics[name]['auc'])

    # Mean line plus shaded error band, clipped to [0, 1].
    line, = plt.plot(fpr_grid, mean_tpr, color=color_map[name], label=None)
    band = plt.fill_between(
        fpr_grid,
        np.maximum(mean_tpr - se_tpr, 0),
        np.minimum(mean_tpr + se_tpr, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None
    )
    # one legend handle for the (line, band) pair
    handles.append((line, band))
    labels.append(label_map[name])

diag, = plt.plot([0,1],[0,1],'k--', lw=1)
handles.append(diag)
labels.append('Chance')

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)
plt.tight_layout()
# HandlerTuple draws each (line, band) tuple as a single legend entry.
plt.legend(handles, labels, handler_map={tuple: HandlerTuple(ndivide=None)}, loc='lower right')
plt.savefig('cifar_mean_meta_ROC_all_methods.png',dpi=300,bbox_inches='tight',transparent=False)
plt.show()
plt.close()


# --------------------------
# 11) Two-level estimator accuracy vs threshold (using p_keep = 1 - p_meta_abstain)
# --------------------------
# One mean accuracy curve per method over thr_grid, with a shaded ± s.e. band.
plt.figure(figsize=(8,6))
acc_handles = []
acc_labels  = []
for name, curves in acc_curves.items():
    # Stack the stored curves, one row per curve.
    # NOTE(review): presumably one curve per replicate with len(thr_grid)
    # points each - confirm against the code that fills acc_curves.
    arr      = np.vstack(curves)
    n_reps   = arr.shape[0]
    mean_acc = arr.mean(axis=0)                          # mean over stacked curves
    # Standard error over the curves actually stacked, not the global R.
    se_acc   = arr.std(axis=0, ddof=1) / np.sqrt(n_reps)

    # Mean curve plus error band clipped to the valid [0, 1] accuracy range.
    line, = plt.plot(thr_grid, mean_acc, color=color_map[name], label=None)
    band = plt.fill_between(
        thr_grid,
        np.maximum(mean_acc - se_acc, 0),
        np.minimum(mean_acc + se_acc, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None
    )
    # Pair line and band so the legend shows them as a single entry.
    acc_handles.append((line, band))
    acc_labels.append(label_map[name])

plt.xlabel('Threshold on $p_{keep}$ (respond)')
plt.ylabel('Two-level Accuracy')
#plt.title('Mean Accuracy vs. Threshold (± s.e.)')
plt.grid(True)
plt.tight_layout()
# draw legend: each tuple (curve+shade) as one entry
plt.legend(
    acc_handles,
    acc_labels,
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc='lower right'
)
plt.savefig('cifar_mean_accuracy_vs_pcorrect.png',dpi=300,bbox_inches='tight',transparent=False)
plt.show()




# # 14) Base-model metrics on ID and OOD (cleaned)
# # --------------------------
# def compute_ece(probs, labels, n_bins=10):
#     """Compute Expected Calibration Error (ECE)."""
#     bins = np.linspace(0.0, 1.0, n_bins+1)
#     ece = 0.0
#     for i in range(n_bins):
#         mask = (probs >= bins[i]) & (probs < bins[i+1])
#         if np.any(mask):
#             acc_bin = (labels[mask] == (probs[mask] >= 0.5)).mean()
#             conf_bin = probs[mask].mean()
#             ece += np.abs(acc_bin - conf_bin) * mask.mean()
#     return ece
# # getcontext().prec = 8

# # Prepare a dict of lists to collect each replicate’s metrics
# base_metrics_reps = {
#     name: {
#       'ID_acc':   [],
#       'ID_nll':   [],
#       'ID_brier': [],
#       'ID_ece':   [],
#       'H_tot_ood':[],
#       'H_epi_ood':[]
#     }
#     for name in methods_dict
# }

# for r in range(R):
#     # 3a) pick this replicate’s particles
#     map_parts = [ all_map_parts[r] ]
#     de_parts  = list(all_de_parts[r])
    
#     methods = {'MAP': map_parts, 'DE': de_parts}

#     def load_bmc(pref):
#         out = {}
#         key = 'psmc_single_x' if pref=='psmc' else 'hmc_particles'
#         for P in select_P:
#             arr=[]
#             total = P if pref=='psmc' else N_particles * P
#             for r in range(R):
#                 for i in range(total):
#                     fn = (f'BayesianNN_CIFAR_{pref}_SimpleMLP_MAP_d{D_DIM}_'
#                         + (f'N{N_particles}_M{M_SMC}' if pref=='psmc' else 'N1_burnin'+str(BURNIN))
#                         + f'_node{r*total+i+1}.mat')
#                     if os.path.exists(fn):
#                         mat=loadmat(fn); arr.append(mat[key])
#             if arr: out[P]=np.vstack(arr)
#         return out
#     smc=load_bmc('psmc'); hmc=load_bmc('hmc')
#     for P in select_P:
#         if P in smc: methods[f'SMC_P{P}']=smc[P]
#         if P in hmc: methods[f'HMC_P{P}']=hmc[P]

#     # 3b) compute and append each metric
#     for name, parts in methods.items():
#         # — copy your existing code **verbatim** up through acc_id, nll_id, etc. —
#         probs_id     = predict_particles(X_id_eval, parts, SimpleMLP)
#         H_tot_id, H_epi_id, _, _, _, _, preds_id = compute_features(probs_id)
#         mean_id      = probs_id.mean(axis=0)
#         y_true_id    = y_id_eval
#         acc_id       = (preds_id == y_true_id).mean()
#         nll_id       = log_loss(y_true_id, mean_id, labels=[0,1])
#         brier_id     = np.mean((mean_id[:,1] - y_true_id)**2)
#         ece_id       = compute_ece(mean_id[:,1], y_true_id, n_bins=10)

#         probs_ood    = predict_particles(ood_eval, parts, SimpleMLP)
#         H_tot_ood, H_epi_ood, *_ = compute_features(probs_ood)

#         # 3c) now **append** to our storage lists
#         reps = base_metrics_reps[name]
#         reps['ID_acc'].append(acc_id)
#         reps['ID_nll'].append(nll_id)
#         reps['ID_brier'].append(brier_id)
#         reps['ID_ece'].append(ece_id)
#         reps['H_tot_ood'].append(float(H_tot_ood.mean()))
#         reps['H_epi_ood'].append(float(H_epi_ood.mean()))


# print("\n=== Base-model metrics over replicates ===")
# for name, stats in base_metrics_reps.items():
#     print(f"\n>> {name}:")
#     for metric, values in stats.items():
#         arr  = np.array(values, dtype=float)            # shape (R,)
#         mean = arr.mean()
#         se   = arr.std(ddof=1) / np.sqrt(R)             # standard error
#         print(f"   {metric:10s}: {mean:.4f} ± {se:.4f}")

# --- Generate LaTeX tables for both default and optimal thresholds ---
# helper to format mean ± standard error (s.e. = std(ddof=1) / sqrt(n))
# def mean_se(a):
#     arr = np.array(a)
#     return f"{arr.mean():.3f}±{arr.std(ddof=1)/np.sqrt(len(arr)):.3f}"

