import os, random, pickle, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score, roc_curve, confusion_matrix
)
from sklearn.preprocessing import StandardScaler
from sentence_transformers import SentenceTransformer
from sklearn.metrics import log_loss
#from decimal import Decimal, getcontext
from matplotlib.legend_handler import HandlerTuple

# --------------------------
# 1) Parameters and device
# --------------------------
device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
R           = 5          # number of independent replicates evaluated
ensemble_sz = 10         # deep-ensemble size (not referenced in the code shown here)
select_P    = [1, 8]     # P values of the SMC / HMC runs to evaluate
N_particles = 10         # used in SMC file names and in the HMC node count (N_particles * P)
N_id        = 25000      # ID examples taken from each IMDB split (train -> meta-train, test -> meta-eval)
TRAIN_META  = N_id       # meta-train uses the first N_id ID training examples
# OOD budget per split: compute_ood_embeddings draws N_ood // 5 items from each of 5 sources
N_ood       = 25000
DATA_PATH   = 'map_de_models.pkl'
OOD_CACHE_TRAIN = 'ood_embeddings_train2.pt'
OOD_CACHE_EVAL  = 'ood_embeddings_eval2.pt'
eps         = 1e-12      # floor added inside log() when computing entropies
# for BMC file naming
D_DIM       = 1538       # presumably SimpleMLP()'s flat parameter count: 768 * 2 + 2
BURNIN      = 25
M_SMC       = 1

# --------------------------
# 2) Model definition
# --------------------------
class SimpleMLP(nn.Module):
    """Linear classifier, or a one-hidden-layer ReLU MLP when ``hidden_dim`` is truthy.

    Parameter registration order (``fc`` alone, or ``fc1`` then ``fc2``)
    defines the order that ``unflatten`` uses to map flat particle vectors
    onto this module, so it must stay fixed.
    """

    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if not hidden_dim:
            # single linear layer: logits = W x + b
            self.fc = nn.Linear(input_dim, num_classes)
            return
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        linear_head = getattr(self, 'fc', None)
        if linear_head is not None:
            return linear_head(x)
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# --------------------------
# 3) Utilities: unflatten & predict
# --------------------------
def unflatten(flat, net):
    """Split a flat parameter tensor into a ``{name: tensor}`` dict shaped like ``net``.

    Slices are taken consecutively in ``net.named_parameters()`` order and
    returned as views of ``flat``; any trailing elements beyond the total
    parameter count are ignored.
    """
    named = list(net.named_parameters())
    sizes = [tmpl.numel() for _, tmpl in named]
    chunks = torch.split(flat[:sum(sizes)], sizes)
    return {pname: chunk.view(tmpl.shape) for (pname, tmpl), chunk in zip(named, chunks)}

def softmax_np(x):
    """Numerically stable softmax of a NumPy array along its last axis."""
    peak = x.max(axis=-1, keepdims=True)   # shift so the largest exponent is 0
    weights = np.exp(x - peak)
    return weights / weights.sum(axis=-1, keepdims=True)

def predict_particles(x, particles, NetClass):
    """Evaluate ``NetClass`` under each particle's weights and return softmax outputs.

    Parameters
    ----------
    x : torch.Tensor
        Model input. Each particle is moved to ``x``'s device and cast to
        ``x``'s dtype so the weights always match the input (particles read
        from ``.mat`` files or NumPy are typically float64).
    particles : iterable
        Flat parameter vectors (NumPy arrays or tensors), one per particle,
        laid out in ``NetClass().named_parameters()`` order.
    NetClass : type
        ``nn.Module`` subclass constructible with no arguments.

    Returns
    -------
    np.ndarray
        ``np.stack`` of the per-particle softmax outputs, particle axis first.
    """
    # The template net only supplies parameter names/shapes to unflatten;
    # its own randomly initialised weights are never used.
    net = NetClass().to(x.device).eval()
    outs = []
    with torch.no_grad():
        for flat in particles:
            # as_tensor accepts both arrays and tensors (torch.tensor warns on
            # tensors); matching x.dtype avoids a float64/float32 matmul error.
            flat_t = torch.as_tensor(flat, dtype=x.dtype, device=x.device)
            params = unflatten(flat_t, net)
            logits = functional_call(net, params, x)
            outs.append(F.softmax(logits, dim=-1).cpu().numpy())
    return np.stack(outs, axis=0)

# --------------------------
# 4) Extended 6D feature extractor
# --------------------------
def compute_features(probs):
    """Summarise a particle ensemble's predictions into per-example meta-features.

    Parameters
    ----------
    probs : np.ndarray
        Rank-3 class probabilities indexed (particle, example, class), as
        produced by ``predict_particles``. At least two classes are required.

    Returns
    -------
    tuple
        ``(H_tot, H_epi, mean_max, mean_diff, var_max, var_diff, preds)``,
        each with one entry per example:
        H_tot     - entropy of the particle-averaged distribution
        H_epi     - H_tot minus the mean per-particle entropy (mutual information)
        mean_max  - mean over particles of the top-class probability
        mean_diff - mean over particles of the top-1 minus top-2 margin
        var_max   - variance over particles of the top-class probability
        var_diff  - variance over particles of that margin
        preds     - argmax class of the particle-averaged distribution
    """
    def _entropy(p, axis):
        # eps keeps log() finite for zero probabilities
        return -(p * np.log(p + eps)).sum(axis=axis)

    avg_probs = probs.mean(axis=0)
    total_entropy = _entropy(avg_probs, 1)
    epistemic = total_entropy - _entropy(probs, 2).mean(axis=0)

    top1 = probs.max(axis=2)
    ranked = np.sort(probs, axis=2)
    margin = ranked[:, :, -1] - ranked[:, :, -2]

    return (
        total_entropy,
        epistemic,
        top1.mean(axis=0),
        margin.mean(axis=0),
        top1.var(axis=0),
        margin.var(axis=0),
        avg_probs.argmax(axis=1),
    )

# --------------------------
# 5) Meta-labels: OOD or misclassified
# --------------------------
def make_labels(preds, true, is_id):
    """Build binary meta-labels: 1 for OOD or misclassified ID examples, else 0.

    ``is_id`` is expected to be a boolean mask (it is negated with ``~``).
    """
    wrong = (preds != true).astype(int)
    is_ood = ~is_id
    return np.where(is_ood, 1, wrong)

# --------------------------
# 6) Load MAP & DE particles
# --------------------------
with open(DATA_PATH, 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; DATA_PATH must be
    # a trusted, locally produced file.
    data = pickle.load(f)
# per-replicate particle collections, indexed by replicate r in the loops below
all_map_parts = data['replicate_map_particles']
all_de_parts = data['replicate_de_flats']

# --------------------------
# 7) Load embeddings: ID and two OOD splits
# --------------------------
# ID test split: read from disk once and reused for the meta-eval set below.
X_full, y_full = torch.load('imdb_embeddings_testBig.pt', map_location='cpu')
X_id_full = X_full[:N_id].to(device)
y_id_full = y_full[:N_id].numpy().flatten()
# meta-train uses the original IMDB training split
X_train_full, y_train_full = torch.load('imdb_embeddings_trainBig.pt', map_location='cpu')
X_id_train = X_train_full[:N_id].to(device)
y_id_train = y_train_full[:N_id].numpy().flatten()
# meta-eval uses the first N_id ID test examples (the same rows as X_id_full)
X_id_eval = X_id_full
y_id_eval = y_id_full.copy()

# --------------------------
# Helpers for loading/creating OOD SBERT embeddings
# --------------------------
def load_jsonl_range(filename, start, n):
    """Parse the JSONL lines with 0-based indices in ``[start, start + n)``.

    Reading stops as soon as the end of the window is reached; lines before
    ``start`` are skipped without being parsed.
    """
    stop = start + n
    records = []
    with open(filename, 'r', encoding='utf-8') as handle:
        for idx, raw in enumerate(handle):
            if idx >= stop:
                break
            if idx >= start:
                records.append(json.loads(raw))
    return records

# Compute OOD embeddings from JSONL sources (reviews, meta, lipsum, full)
def compute_ood_embeddings(device, n_samples=10000, start=0):
    """
    Compute SBERT embeddings of five OOD text sources, ``n_samples // 5`` items each.

    Sources (dict keys, in this insertion order):
      'ood_reviews'      - the 'text' field of data/Appliances.jsonl records
      'ood_meta'         - title + joined 'features' of data/meta_Appliances.jsonl records
      'ood_lipsum'       - random 1-10 word lorem-ipsum strings
      'ood_full_reviews' - the full review records serialised as JSON strings
      'ood_full_meta'    - the full meta records serialised as JSON strings

    Review and meta records come from JSONL lines [start, start + n_samples // 5).
    The lorem-ipsum generator is seeded with ``42 + start``, so splits drawn at
    different offsets get different texts while ``start=0`` reproduces the
    original seed-42 sequence.

    Returns a dict of source name -> embedding tensor on ``device``.
    """
    per_source = int(n_samples) // 5
    rev = load_jsonl_range('data/Appliances.jsonl', start, per_source)
    meta = load_jsonl_range('data/meta_Appliances.jsonl', start, per_source)
    texts_rev = [e.get('text', '') for e in rev]
    texts_meta = [(e.get('title', '') + ' ' + ' '.join(e.get('features', []))).strip()
                  for e in meta]

    # A local, offset-dependent RNG: a fixed global seed made the lipsum texts
    # identical across the meta-train and meta-eval splits (eval leaking train).
    rng = random.Random(42 + start)
    lorem = ('lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod '
             'tempor incididunt ut labore et dolore magna aliqua').split()
    texts_lip = [' '.join(rng.choices(lorem, k=rng.randint(1, 10)))
                 for _ in range(per_source)]

    full_rev = [json.dumps(e, ensure_ascii=False) for e in rev]
    full_meta = [json.dumps(e, ensure_ascii=False) for e in meta]

    sbert = SentenceTransformer('all-mpnet-base-v2')
    sbert.eval()

    def encode(lst):
        return sbert.encode(lst, convert_to_tensor=True).to(device)

    return {
        'ood_reviews':      encode(texts_rev),
        'ood_meta':         encode(texts_meta),
        'ood_lipsum':       encode(texts_lip),
        'ood_full_reviews': encode(full_rev),
        'ood_full_meta':    encode(full_meta),
    }


# Function to get or cache OOD embeddings
# --------------------------
def get_ood(cache_path, start=0):
    """Return all OOD embeddings for one split, concatenated along dim 0 on ``device``.

    If ``cache_path`` exists it is loaded as-is (``start`` is then ignored);
    otherwise the embeddings are computed from JSONL offset ``start`` with
    ``n_samples=N_ood`` and a CPU copy is saved to ``cache_path``.
    """
    if not os.path.exists(cache_path):
        fresh = compute_ood_embeddings(device, n_samples=N_ood, start=start)
        torch.save({src: emb.cpu() for src, emb in fresh.items()}, cache_path)
        return torch.cat(list(fresh.values()), dim=0)
    cached = torch.load(cache_path, map_location='cpu')
    return torch.cat([emb.to(device) for emb in cached.values()], dim=0)

def _combine_split(X_id, y_id, ood):
    """Stack ID rows followed by OOD rows.

    Returns ``(X, y, is_id)``: OOD rows get class label -1 and ``is_id=False``.
    """
    X = torch.cat([X_id, ood], dim=0)
    y = np.concatenate([y_id, np.full(len(ood), -1)])
    is_id = np.concatenate([np.ones(len(y_id), bool), np.zeros(len(ood), bool)])
    return X, y, is_id


# meta-train OOD: JSONL offset 0 when computed fresh (else read from its cache)
ood_train = get_ood(OOD_CACHE_TRAIN, start=0)
# meta-eval OOD: JSONL offset N_ood when computed fresh, disjoint from the train rows
ood_eval = get_ood(OOD_CACHE_EVAL, start=N_ood)

# combine splits (each is built exactly once)
X_meta, y_meta, is_id_meta = _combine_split(X_id_train, y_id_train, ood_train)
X_eval, y_eval, is_id_eval = _combine_split(X_id_eval, y_id_eval, ood_eval)


# --- pre-allocate per-method metric storage -----------------------
# Every method that may appear: MAP, DE, then SMC and HMC for each selected P.
methods_dict = ['MAP', 'DE'] + [f'{kind}_P{P}' for kind in ('SMC', 'HMC') for P in select_P]

# Each metric name maps to one list per method, with one entry per replicate.
metrics = {
    name: {key: [] for key in ('auc', 'f1', 'precision', 'recall')}
    for name in methods_dict
}

roc_curves = {name: [] for name in methods_dict}   # per-replicate (fpr, tpr)
conf_rates = {name: [] for name in methods_dict}   # per-replicate normalised 2x2 matrices

thr_grid = np.linspace(0, 1, 101)                 # shared threshold grid
acc_curves = {name: [] for name in metrics}       # per-replicate accuracy over thr_grid

# --------------------------
# 8) Compute features & labels per method
# --------------------------

def load_bmc(pref, rep):
    """Load SMC ('psmc') or HMC ('hmc') particles belonging to replicate ``rep``.

    For each P in ``select_P`` the replicate owns ``total`` consecutive node
    files with 1-based indices ``rep*total + 1 .. (rep+1)*total``, where
    ``total = P`` for 'psmc' and ``N_particles * P`` for 'hmc'. Missing files
    are skipped; a P with no files found is omitted from the result.

    Returns a dict P -> ``np.vstack`` of the arrays stored under
    'psmc_single_x' (SMC) or 'hmc_particles' (HMC).
    """
    key = 'psmc_single_x' if pref == 'psmc' else 'hmc_particles'
    run_tag = f'N{N_particles}_M{M_SMC}' if pref == 'psmc' else 'N1_burnin' + str(BURNIN)
    out = {}
    for P in select_P:
        total = P if pref == 'psmc' else N_particles * P
        arr = []
        # Only this replicate's nodes are read, so replicates stay independent
        # and the ± s.e. reported across replicates reflects real variation.
        for i in range(total):
            fn = (f'BayesianNN_IMDB_{pref}_SimpleMLP_MAP_d{D_DIM}_'
                  + run_tag + f'_node{rep*total+i+1}.mat')
            if os.path.exists(fn):
                arr.append(loadmat(fn)[key])
        if arr:
            out[P] = np.vstack(arr)
    return out


for r in range(R):
    # MAP contributes one particle; DE contributes this replicate's ensemble.
    map_parts = [all_map_parts[r]]
    de_parts = list(all_de_parts[r])
    methods = {'MAP': map_parts, 'DE': de_parts}

    smc = load_bmc('psmc', r)
    hmc = load_bmc('hmc', r)
    for P in select_P:
        if P in smc:
            methods[f'SMC_P{P}'] = smc[P]
        if P in hmc:
            methods[f'HMC_P{P}'] = hmc[P]

    # containers for base predictions, meta-features and meta-labels
    preds_train, preds_eval = {}, {}
    feat_train, label_train = {}, {}
    feat_eval, label_eval = {}, {}
    for name, parts in methods.items():
        # training split: first 6 outputs are meta-features, the 7th is the
        # ensemble-mean argmax prediction used to build the meta-label
        feats_t = compute_features(predict_particles(X_meta, parts, SimpleMLP))
        feat_train[name] = np.stack(feats_t[:6], axis=1)
        label_train[name] = make_labels(feats_t[6], y_meta, is_id_meta)
        preds_train[name] = feats_t[6]

        # evaluation split
        feats_e = compute_features(predict_particles(X_eval, parts, SimpleMLP))
        feat_eval[name] = np.stack(feats_e[:6], axis=1)
        label_eval[name] = make_labels(feats_e[6], y_eval, is_id_eval)
        preds_eval[name] = feats_e[6]

    # --------------------------
    # 8.5) Standardize meta-features per method (fit on train only)
    # --------------------------
    scalers = {}
    for name in methods:
        scaler = StandardScaler()
        feat_train[name] = scaler.fit_transform(feat_train[name])
        feat_eval[name] = scaler.transform(feat_eval[name])
        scalers[name] = scaler

    # --------------------------
    # 9) Train meta-classifiers
    # --------------------------
    clfs = {}
    for name in methods:
        clf = LogisticRegression(max_iter=1000)
        clf.fit(feat_train[name], label_train[name])
        clfs[name] = clf

    # --------------------------
    # 10) Evaluate: metrics, ROC and accuracy-vs-threshold curves
    # --------------------------
    for name, clf in clfs.items():
        y_true = label_eval[name]                          # 1 = abstain desired, 0 = respond desired
        prob = clf.predict_proba(feat_eval[name])[:, 1]    # P(abstain)

        # threshold on thr_grid maximising F1 (first maximum wins)
        best_thr, best_f1 = 0, 0
        for thr in thr_grid:
            f = f1_score(y_true, (prob >= thr).astype(int), zero_division=0)
            if f > best_f1:
                best_f1, best_thr = f, thr
        y_hat = (prob >= best_thr).astype(int)

        prec = precision_score(y_true, y_hat, zero_division=0)
        rec = recall_score(y_true, y_hat, zero_division=0)
        auc = roc_auc_score(y_true, prob)
        cm = confusion_matrix(y_true, y_hat, labels=[0, 1])
        cm_rate = cm.astype(float) / cm.sum()   # normalised over all examples

        metrics[name]['auc'].append(auc)
        metrics[name]['f1'].append(best_f1)
        metrics[name]['precision'].append(prec)
        metrics[name]['recall'].append(rec)
        conf_rates[name].append(cm_rate)

        fpr, tpr, _ = roc_curve(y_true, prob)
        roc_curves[name].append((fpr, tpr))

        # Two-level accuracy vs threshold on p_keep = P(respond) = 1 - P(abstain).
        # A local name is used for the meta-labels: the module-level y_meta holds
        # the meta-train class labels and is re-read by make_labels every replicate.
        p_keep = 1.0 - prob
        should_respond = (y_true == 0)
        # respond[i, j]: respond on example j at threshold thr_grid[i]; a decision
        # is correct exactly when it matches should_respond
        respond = p_keep[None, :] >= thr_grid[:, None]
        acc_r = (respond == should_respond[None, :]).mean(axis=1)
        acc_curves[name].append(acc_r)


# print averaged metrics: mean ± standard error over the replicates that produced data
for name in methods_dict:
    print(f"=== {name} ===")
    for metric in ['auc', 'f1', 'precision', 'recall']:
        arr = np.array(metrics[name][metric])   # one value per replicate
        if arr.size == 0:
            # no particle files were found for this method in any replicate
            print(f"  {metric.upper():9s}: n/a (no replicates)")
            continue
        mean = arr.mean()
        se = arr.std(ddof=1) / np.sqrt(arr.size) if arr.size > 1 else float('nan')
        print(f"  {metric.upper():9s}: {mean:.8f} ± {se:.8f}")

# -----------plots---------
# Legend labels per method. Raw strings keep the LaTeX backslashes literal;
# a plain '\p' is an invalid escape sequence and warns on newer Pythons.
label_map = {
    'MAP':    'MAP',
    'DE':     'DE',
    'SMC_P1': r'SMC$_\parallel$ (P=1)',
    'SMC_P8': r'SMC$_\parallel$ (P=8)',
    'HMC_P1': r'HMC$_\parallel$ (P=1)',
    'HMC_P8': r'HMC$_\parallel$ (P=8)',
}
# Line colour per method (shared by the ROC and accuracy plots)
color_map = {
    'MAP':    'blue',
    'DE':     'orange',
    'SMC_P1': 'red',
    'SMC_P8': 'darkred',
    'HMC_P1': 'green',
    'HMC_P8': 'olive',
}

# confusion matrix: mean ± s.e. of the per-replicate normalised confusion rates
mean_cm = {}  # mean confusion rate
se_cm   = {}  # standard error per cell
annots  = {}  # "mean±se" cell annotations

for name, mats in conf_rates.items():
    if not mats:
        # no particles for this method in any replicate; np.stack([]) would raise
        continue
    arr = np.stack(mats, axis=0)                             # (n_rep, 2, 2)
    mean = arr.mean(axis=0)                                  # (2, 2)
    se = arr.std(axis=0, ddof=1) / np.sqrt(arr.shape[0])     # (2, 2)
    mean_cm[name] = mean
    se_cm[name] = se
    annots[name] = np.array(
        [[f"{mean[i, j]:.3f}±{se[i, j]:.3f}" for j in (0, 1)] for i in (0, 1)],
        dtype=object,
    )

for name in methods_dict:
    if name not in mean_cm:
        continue
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(
        mean_cm[name],
        annot=annots[name],
        fmt="",
        cmap="Oranges",          # or another palette
        cbar=True,
        xticklabels=['Correct', 'Incorrect'],
        yticklabels=['Correct', 'Incorrect'],
        ax=ax,
    )
    # draw orange border around entire matrix
    for spine in ax.spines.values():
        spine.set_edgecolor('orange')
        spine.set_linewidth(2)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    #ax.set_title(f'{name}\nMean Confusion Rate ± s.e.')
    fig.tight_layout()
    # save to disk:
    fig.savefig(f'imdb_avg_cm_rates_se_{name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    #plt.close(fig)


# ROC plot: mean meta-ROC across replicates with a ± s.e. band
fpr_grid = np.linspace(0, 1, 200)
plt.figure(figsize=(8, 6))

handles = []
labels  = []
for name in methods_dict:
    if not roc_curves[name]:
        # no particles for this method in any replicate; np.vstack([]) would raise
        continue
    # (n_rep x len(fpr_grid)) TPRs interpolated onto a shared FPR grid
    arr = np.vstack([np.interp(fpr_grid, fpr, tpr) for fpr, tpr in roc_curves[name]])
    mean_tpr = arr.mean(axis=0)
    se_tpr   = arr.std(axis=0, ddof=1) / np.sqrt(arr.shape[0])

    # mean AUC over replicates (computed for reference; not shown in the legend)
    mean_auc = np.mean(metrics[name]['auc'])

    # line + shaded error band
    line, = plt.plot(fpr_grid, mean_tpr, color=color_map[name], label=None)
    band = plt.fill_between(
        fpr_grid,
        np.maximum(mean_tpr - se_tpr, 0),
        np.minimum(mean_tpr + se_tpr, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None,
    )
    # one legend entry for the (line, band) pair
    handles.append((line, band))
    labels.append(label_map[name])

diag, = plt.plot([0, 1], [0, 1], 'k--', lw=1)
handles.append(diag)
labels.append('Chance')

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
#plt.title('Mean Meta-ROC across replicates')
plt.grid(True)
plt.tight_layout()
# HandlerTuple draws each (line, band) tuple as a single legend entry
plt.legend(handles, labels, handler_map={tuple: HandlerTuple(ndivide=None)}, loc='lower right')
plt.savefig('imdb_mean_meta_ROC_all_methods.png', dpi=300, bbox_inches='tight', transparent=False)
plt.show()


# --------------------------
# 11) Two-level estimator accuracy vs threshold (using p_keep = 1 - p_meta_abstain)
# --------------------------
plt.figure(figsize=(8, 6))
acc_handles = []
acc_labels  = []
for name, curves in acc_curves.items():
    if not curves:
        # no particles for this method in any replicate; np.vstack([]) would raise
        continue
    arr      = np.vstack(curves)                                  # (n_rep, len(thr_grid))
    mean_acc = arr.mean(axis=0)                                   # mean over replicates
    se_acc   = arr.std(axis=0, ddof=1) / np.sqrt(arr.shape[0])    # standard error

    # mean curve plus a shaded ± s.e. band
    line, = plt.plot(thr_grid, mean_acc, color=color_map[name], label=None)
    band = plt.fill_between(
        thr_grid,
        np.maximum(mean_acc - se_acc, 0),
        np.minimum(mean_acc + se_acc, 1),
        facecolor=line.get_color(),
        alpha=0.2,
        label=None,
    )
    # one legend entry for the (line, band) pair
    acc_handles.append((line, band))
    acc_labels.append(label_map[name])

plt.xlabel('Threshold on $p_{keep}$ (respond)')
plt.ylabel('Two-level Accuracy')
#plt.title('Mean Accuracy vs. Threshold (± s.e.)')
plt.grid(True)
plt.tight_layout()
# draw legend: each tuple (curve+shade) as one entry
plt.legend(
    acc_handles,
    acc_labels,
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc='lower right',
)
plt.savefig('imdb_mean_accuracy_vs_pcorrect.png', dpi=300, bbox_inches='tight', transparent=False)
plt.show()




# 14) Base-model metrics on ID and OOD (cleaned)
# --------------------------
def compute_ece(probs, labels, n_bins=10):
    """Compute the top-label Expected Calibration Error (ECE) of a binary classifier.

    Parameters
    ----------
    probs : array-like
        Predicted probability of class 1 for each example.
    labels : array-like
        True 0/1 labels.
    n_bins : int
        Number of equal-width confidence bins over [0, 1].

    Returns
    -------
    float
        sum over bins of |accuracy - mean confidence| weighted by the bin's share
        of examples. Confidence is that of the predicted class, max(p, 1 - p),
        so accuracy and confidence always refer to the same prediction.
        Returns 0.0 for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    predicted_pos = probs >= 0.5
    correct = (labels == predicted_pos)
    conf = np.where(predicted_pos, probs, 1.0 - probs)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        # the last bin is closed on the right so confidence 1.0 is counted
        upper = (conf <= hi) if i == n_bins - 1 else (conf < hi)
        mask = (conf >= lo) & upper
        if np.any(mask):
            ece += np.abs(correct[mask].mean() - conf[mask].mean()) * mask.mean()
    return float(ece)

# Per-method storage for base-model metrics: each metric name maps to a list
# that receives one value per replicate.
base_metrics_reps = {
    name: {
        key: []
        for key in ('ID_acc', 'ID_nll', 'ID_brier', 'ID_ece', 'H_tot_ood', 'H_epi_ood')
    }
    for name in methods_dict
}

def load_bmc(pref, rep):
    """Load SMC ('psmc') or HMC ('hmc') particles belonging to replicate ``rep``.

    For each P in ``select_P`` the replicate owns ``total`` consecutive node
    files with 1-based indices ``rep*total + 1 .. (rep+1)*total``, where
    ``total = P`` for 'psmc' and ``N_particles * P`` for 'hmc'. Missing files
    are skipped; a P with no files found is omitted from the result.

    Returns a dict P -> ``np.vstack`` of the arrays stored under
    'psmc_single_x' (SMC) or 'hmc_particles' (HMC).
    """
    key = 'psmc_single_x' if pref == 'psmc' else 'hmc_particles'
    run_tag = f'N{N_particles}_M{M_SMC}' if pref == 'psmc' else 'N1_burnin' + str(BURNIN)
    out = {}
    for P in select_P:
        total = P if pref == 'psmc' else N_particles * P
        arr = []
        # Only this replicate's nodes are read, so replicates stay independent
        # and the ± s.e. reported across replicates reflects real variation.
        for i in range(total):
            fn = (f'BayesianNN_IMDB_{pref}_SimpleMLP_MAP_d{D_DIM}_'
                  + run_tag + f'_node{rep*total+i+1}.mat')
            if os.path.exists(fn):
                arr.append(loadmat(fn)[key])
        if arr:
            out[P] = np.vstack(arr)
    return out


for r in range(R):
    # this replicate's particles, keyed by method name
    map_parts = [all_map_parts[r]]
    de_parts = list(all_de_parts[r])
    methods = {'MAP': map_parts, 'DE': de_parts}

    smc = load_bmc('psmc', r)
    hmc = load_bmc('hmc', r)
    for P in select_P:
        if P in smc:
            methods[f'SMC_P{P}'] = smc[P]
        if P in hmc:
            methods[f'HMC_P{P}'] = hmc[P]

    for name, parts in methods.items():
        # ID test split: accuracy, NLL, Brier score and ECE of the
        # particle-averaged predictive distribution
        probs_id = predict_particles(X_id_eval, parts, SimpleMLP)
        H_tot_id, H_epi_id, _, _, _, _, preds_id = compute_features(probs_id)
        mean_id = probs_id.mean(axis=0)
        y_true_id = y_id_eval
        acc_id = (preds_id == y_true_id).mean()
        nll_id = log_loss(y_true_id, mean_id, labels=[0, 1])
        brier_id = np.mean((mean_id[:, 1] - y_true_id) ** 2)
        ece_id = compute_ece(mean_id[:, 1], y_true_id, n_bins=10)

        # OOD eval split: mean total and epistemic entropy
        probs_ood = predict_particles(ood_eval, parts, SimpleMLP)
        H_tot_ood, H_epi_ood, *_ = compute_features(probs_ood)

        reps = base_metrics_reps[name]
        reps['ID_acc'].append(acc_id)
        reps['ID_nll'].append(nll_id)
        reps['ID_brier'].append(brier_id)
        reps['ID_ece'].append(ece_id)
        reps['H_tot_ood'].append(float(H_tot_ood.mean()))
        reps['H_epi_ood'].append(float(H_epi_ood.mean()))


print("\n=== Base-model metrics over replicates ===")
for name, stats in base_metrics_reps.items():
    print(f"\n>> {name}:")
    for metric, values in stats.items():
        arr = np.array(values, dtype=float)   # one value per replicate that ran
        if arr.size == 0:
            # no particle files were found for this method in any replicate
            print(f"   {metric:10s}: n/a (no replicates)")
            continue
        mean = arr.mean()
        se = arr.std(ddof=1) / np.sqrt(arr.size) if arr.size > 1 else float('nan')
        print(f"   {metric:10s}: {mean:.8f} ± {se:.8f}")

# --- Generate LaTeX tables for both default and optimal thresholds ---
# helper to format mean ± standard error
def mean_se(a):
    """Format the values in ``a`` as "mean±s.e." with three decimals.

    The standard error is the sample standard deviation (ddof=1) divided by
    sqrt(len(a)).
    """
    values = np.array(a)
    center = values.mean()
    spread = values.std(ddof=1) / np.sqrt(len(values))
    return f"{center:.3f}±{spread:.3f}"

