#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MAP Inference on IMDB with OOD Uncertainty Evaluation

This script performs two main tasks:
  1. IN-DOMAIN: Runs MAP training for IMDB using SBERT embeddings and a SimpleMLP.
     It runs R = 5 replicates (seeds 1,…,5), prints intermediate training/validation losses,
     and computes the following metrics (averaged over samples):
         - Accuracy
         - F1 Score (macro)
         - AUC-PR (using the probability for the positive class)
         - Average Total Entropy
         - Average Entropy for Correct predictions
         - Average Entropy for Incorrect predictions
         - Negative Log-Likelihood (NLL; summed over the test samples, not averaged)
  2. OOD EVALUATION: Computes OOD embeddings (from two JSONL files and generated lipsum texts)
     using the same SBERT model (all-mpnet-base-v2) so that the input dimensions match.
     For each replicate the trained in-domain model is applied to these OOD embeddings to compute
     the average predictive entropy (only total entropy is computed for OOD).

Finally, the in-domain and OOD results (averaged over replicates) are saved into a MATLAB .mat file.
"""

import os, time, random, math, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics import f1_score, average_precision_score, accuracy_score
from scipy.io import savemat
import pandas as pd

# Show pandas floats with 8 decimal places.
pd.options.display.float_format = '{:.8f}'.format

# Device & prior variance
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# v is the variance of the zero-mean Gaussian used for weights and biases;
# sigma_w / sigma_b are the matching standard deviations (both sqrt(v)).
v = 0.025
sigma_w = np.sqrt(v)
sigma_b = np.sqrt(v)

class SimpleMLP(nn.Module):
    """Linear classifier, or a one-hidden-layer ReLU MLP, with Gaussian-initialised parameters.

    When ``hidden_dim`` is ``0`` or ``None`` the network is a single linear
    layer stored as ``fc``; otherwise it is ``fc1 -> ReLU -> fc2``.  Every
    weight and bias is re-drawn from a zero-mean normal whose standard
    deviation is the module-level ``sigma_w`` (weights) or ``sigma_b`` (biases).
    """

    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if hidden_dim is None or hidden_dim == 0:
            # No hidden layer: plain (multinomial) logistic regression.
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        # Layers are all constructed first and re-initialised afterwards, in
        # order, so the random-number stream is consumed in a fixed sequence.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
            nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        single_layer = getattr(self, 'fc', None)
        if single_layer is not None:
            return single_layer(x)
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_trainBig.pt",
    test_cache_path="imdb_embeddings_testBig.pt"
):
    """Return SBERT embeddings and binary labels for the IMDB train and test splits.

    If both cache files exist they are loaded with ``torch.load`` and returned
    directly.  Otherwise the IMDB splits are read through torchtext, every
    review is embedded with the SentenceTransformer ``model_name``, and the
    resulting tensors are written to the two cache paths before returning.

    Args:
        model_name: SentenceTransformer model used when embeddings must be computed.
        train_cache_path: File holding the ``(X_train_full, y_train_full)`` tuple.
        test_cache_path: File holding the ``(X_test, y_test)`` tuple.

    Returns:
        ``(X_train_full, y_train_full, X_test, y_test)``.  On the compute path
        the ``X`` tensors are float32 and the ``y`` tensors are int64 with
        1 = positive and 0 = negative; on the cache path the tensors are
        returned exactly as stored.

    Raises:
        KeyError: If the dataset yields a label that is not in the label map.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk for IMDB...")
        X_train_full, y_train_full = torch.load(train_cache_path)
        X_test, y_test             = torch.load(test_cache_path)
        return X_train_full, y_train_full, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings for IMDB...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()
    # Imported lazily so torchtext is only required when the cache is missing.
    from torchtext.datasets import IMDB

    # Older torchtext releases yield the strings "pos"/"neg"; newer ones yield
    # the integers 2 (positive) / 1 (negative).  Accept both encodings.
    label_map = {"pos": 1, "neg": 0, 2: 1, 1: 0}

    def embed_split(split):
        """Embed one IMDB split and return ``(features, labels)`` tensors."""
        samples = list(IMDB(split=split))
        labels = [label_map[lbl] for lbl, _ in samples]
        texts = [txt for _, txt in samples]
        # One batched call instead of one call per review: encode() batches
        # internally and returns a single (n_samples, dim) array.
        embeddings = sbert.encode(texts, convert_to_numpy=True)
        features = torch.tensor(embeddings, dtype=torch.float32)
        return features, torch.tensor(labels, dtype=torch.long)

    X_train_full, y_train_full = embed_split("train")
    X_test, y_test = embed_split("test")

    torch.save((X_train_full, y_train_full), train_cache_path)
    torch.save((X_test,        y_test),        test_cache_path)
    print("SBERT embeddings for IMDB computed and saved to disk.")
    return X_train_full, y_train_full, X_test, y_test

def evaluate_model(model, loader):
    """Score ``model`` on every batch of ``loader`` and return summary metrics.

    The model is switched to eval mode and run without gradient tracking.

    Returns:
        A 7-tuple ``(accuracy, macro_f1, average_precision, mean_entropy,
        mean_entropy_correct, mean_entropy_incorrect, summed_nll)``:

        * ``average_precision`` scores the column-1 softmax probability
          against the labels;
        * the two conditional entropies are ``0`` when no prediction falls
          into the corresponding group;
        * ``summed_nll`` is the cross-entropy summed (not averaged) over all
          samples.
    """
    model.eval()
    logit_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            logit_chunks.append(model(batch_x))
            label_chunks.append(batch_y)
    logits = torch.cat(logit_chunks)
    targets = torch.cat(label_chunks)

    probs = F.softmax(logits, dim=1)
    predicted = probs.argmax(1)
    is_correct = predicted == targets
    is_wrong = predicted != targets

    accuracy = is_correct.float().mean().item()
    macro_f1 = f1_score(targets.cpu(), predicted.cpu(), average="macro")
    avg_precision = average_precision_score(targets.cpu(), probs[:, 1].cpu())

    # Predictive entropy per sample; the 1e-12 offset guards against log(0).
    entropy = -torch.sum(probs * torch.log(probs + 1e-12), dim=1)
    mean_entropy = entropy.mean().item()
    if is_correct.any():
        mean_entropy_correct = entropy[is_correct].mean().item()
    else:
        mean_entropy_correct = 0
    if is_wrong.any():
        mean_entropy_incorrect = entropy[is_wrong].mean().item()
    else:
        mean_entropy_incorrect = 0

    summed_nll = F.cross_entropy(logits, targets, reduction="sum").item()
    return (accuracy, macro_f1, avg_precision, mean_entropy,
            mean_entropy_correct, mean_entropy_incorrect, summed_nll)

# ── OOD HELPER: mean predictive entropy ────────────────────────────────────────
def compute_total_entropy(model, x):
    """Return the mean softmax entropy of ``model``'s predictions on ``x``.

    The model is put into eval mode and evaluated without gradient tracking.
    Entropy is computed per row of the softmax over dimension 1 and then
    averaged; the result is a Python float.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        # The 1e-12 offset keeps log() finite when a probability is exactly 0.
        per_sample_entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1)
        mean_entropy = per_sample_entropy.mean()
    return mean_entropy.item()
# ────────────────────────────────────────────────────────────────────────────────

print("Computing OOD embeddings using SBERT (all-mpnet-base-v2)...")
def load_jsonl_first_n(fn, n=100):
    """Parse and return up to the first ``n`` JSON records of a JSON Lines file.

    Args:
        fn: Path of a UTF-8 encoded file with one JSON document per line.
        n: Maximum number of records to return; ``n <= 0`` yields ``[]``.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(fn, encoding="utf-8") as f:
        for line in f:
            if len(records) >= n:
                break
            line = line.strip()
            # Blank lines (e.g. a trailing newline) carry no record and would
            # make json.loads raise, so skip them without counting them.
            if not line:
                continue
            records.append(json.loads(line))
    return records

# NOTE: everything in this section runs at import time (it is outside the
# ``__main__`` guard): it reads both JSONL files and loads the SBERT model.

# Paths of the two JSONL files used as out-of-domain (OOD) text sources.
reviews_file = os.path.join("data","Appliances.jsonl")
meta_file    = os.path.join("data","meta_Appliances.jsonl")
# Only the first 100 records of each file are used.
rev_data = load_jsonl_first_n(reviews_file, 100)
met_data= load_jsonl_first_n(meta_file,    100)

# OOD set "reviews": the "text" field of each review record ("" when absent).
ood_reviews_texts = [d.get("text","") for d in rev_data]
# OOD set "meta": each record's "title" followed by its "features" entries,
# joined by spaces.  Assumes "features" is a list of strings — TODO confirm.
ood_meta_texts    = [" ".join([d.get("title",""), *d.get("features",[])]) for d in met_data]

# Fixed seeds so the generated lorem-ipsum texts are the same on every run.
random.seed(42); np.random.seed(42)
lorem_words = ("lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor "
               "incididunt ut labore et dolore magna aliqua").split()
# OOD set "lipsum": 100 texts, each 1 to 10 words drawn with replacement from lorem_words.
ood_lipsum_texts = [" ".join(random.choices(lorem_words, k=random.randint(1,10))) for _ in range(100)]

# OOD sets "full_reviews" / "full_meta": each whole record serialised back to a JSON string.
ood_full_reviews_texts = [json.dumps(d, ensure_ascii=False) for d in rev_data]
ood_full_meta_texts    = [json.dumps(d, ensure_ascii=False) for d in met_data]

# Embed all five OOD text sets with the same SBERT model name used for the
# IMDB embeddings, and move the resulting tensors to the compute device.
sbert_ood = SentenceTransformer("all-mpnet-base-v2")
emb_ood_reviews      = sbert_ood.encode(ood_reviews_texts,      convert_to_tensor=True).to(device)
emb_ood_meta         = sbert_ood.encode(ood_meta_texts,         convert_to_tensor=True).to(device)
emb_ood_lipsum       = sbert_ood.encode(ood_lipsum_texts,       convert_to_tensor=True).to(device)
emb_ood_full_reviews = sbert_ood.encode(ood_full_reviews_texts, convert_to_tensor=True).to(device)
emb_ood_full_meta    = sbert_ood.encode(ood_full_meta_texts,    convert_to_tensor=True).to(device)

if __name__ == '__main__':
    # Entry point.  For each of R seeds: train a linear MAP classifier on the
    # IMDB embeddings with early stopping, evaluate it on the IMDB test split
    # (overall and per class), and record its mean predictive entropy on the
    # five OOD embedding sets.  Per-replicate values and their mean / standard
    # error are written to "imdb_map_metrics.mat" and printed.
    print("==== In-Domain IMDB MAP Training and Evaluation ====")

    # ── SLICE FROM THE FULL TRAIN ONLY ONCE ────────────────────────────────────────
    X_train_full, y_train_full, X_test, y_test = create_sbert_embedded_imdb_dataset(
        train_cache_path="imdb_embeddings_trainBig.pt",
        test_cache_path ="imdb_embeddings_testBig.pt"
    )
    # NOTE(review): the split is positional (first 20000 rows train, the rest
    # validation) with no shuffling.  If the cached rows are ordered by label
    # the validation set would be class-imbalanced — confirm the row order.
    X_train = X_train_full[:20000]
    y_train = y_train_full[:20000]
    X_val   = X_train_full[20000:]
    y_val   = y_train_full[20000:]
    # ────────────────────────────────────────────────────────────────────────────────

    x_train   = X_train.to(device)
    y_train   = y_train.to(device)
    x_val     = X_val.to(device)
    y_val     = y_val.to(device)
    x_test_id = X_test.to(device)
    y_test_id = y_test.to(device)

    print(f"Number of training examples:   {len(y_train)}")
    print(f"Number of validation examples: {len(y_val)}")
    print(f"Number of test examples (ID):  {len(y_test_id)}")

    train_loader = DataLoader(TensorDataset(x_train,   y_train),   batch_size=64, shuffle=True)
    val_loader   = DataLoader(TensorDataset(x_val,     y_val),     batch_size=64, shuffle=False)
    test_loader  = DataLoader(TensorDataset(x_test_id, y_test_id), batch_size=64, shuffle=False)
    n_train = len(train_loader.dataset)

    # Per-replicate containers (one entry appended per seed).
    accuracies = []; f1_scores = []; auc_prs = []
    total_entropies = []; correct_entropies = []; incorrect_entropies = []; nlls = []
    training_times  = []; epoch_counts = []

    ood_reviews_entropies      = []
    ood_meta_entropies         = []
    ood_lipsum_entropies       = []
    ood_full_reviews_entropies = []
    ood_full_meta_entropies    = []

    # ── per-class replicate containers ─────────────────────────────────────────────
    replicate_acc_class0       = []; replicate_acc_class1       = []
    replicate_f1_class0        = []; replicate_f1_class1        = []
    replicate_aucpr_class0     = []; replicate_aucpr_class1     = []
    replicate_tot_ent_class0   = []; replicate_tot_ent_class1   = []
    # ────────────────────────────────────────────────────────────────────────────────

    def prior_penalty(net):
        """Return sum(w^2)/(2*sigma_w^2) + sum(b^2)/(2*sigma_b^2) for ``net.fc``."""
        return (torch.sum(net.fc.weight**2) / (2*sigma_w**2) +
                torch.sum(net.fc.bias**2)   / (2*sigma_b**2))

    R = 5
    for seed in range(1, R+1):
        print(f"\nStarting replicate with seed = {seed}")
        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)

        # hidden_dim=0 selects the single linear layer ``model.fc`` that the
        # prior penalty above relies on.
        model     = SimpleMLP(input_dim=x_train.shape[1], hidden_dim=0, num_classes=2).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Early stopping tracks a moving average of the validation loss.
        moving_avg_window = 10
        best_moving_avg   = float('inf')
        no_improve_count  = 0
        start_time        = time.time()
        val_losses        = []

        for epoch in range(1000):
            model.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                ce_loss = criterion(outputs, labels)
                # The criterion is a per-sample mean, so the prior penalty is
                # divided by the training-set size to sit on the same scale.
                loss = ce_loss + prior_penalty(model) / n_train
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            train_loss = running_loss / len(train_loader)

            model.eval()
            val_running_loss = 0.0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    outputs = model(inputs)
                    ce_loss = criterion(outputs, labels)
                    # Same objective as in training, including the division
                    # by the *training* set size.
                    val_running_loss += (ce_loss + prior_penalty(model) / n_train).item()
            val_loss = val_running_loss / len(val_loader)
            val_losses.append(val_loss)
            print(f"Seed {seed} - Epoch {epoch+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

            # Once a full window is available, stop after 5 consecutive
            # epochs whose moving average does not beat the best seen so far.
            if epoch >= moving_avg_window-1:
                mov = sum(val_losses[-moving_avg_window:]) / moving_avg_window
                if mov < best_moving_avg:
                    best_moving_avg = mov
                    no_improve_count = 0
                else:
                    no_improve_count += 1
                if no_improve_count >= 5:
                    print(f"Early stopping at epoch {epoch+1} for seed {seed}.")
                    break

        epochs_completed = epoch + 1
        epoch_counts.append(epochs_completed)
        training_times.append(time.time() - start_time)
        print(f"Replicate seed {seed} completed {epochs_completed} epochs.")

        # The model evaluated here is the one from the last completed epoch.
        acc, f1v, aucv, tent, cent, icent, nll = evaluate_model(model, test_loader)
        print(f"  Accuracy      = {acc*100:.4f}%")
        print(f"  F1 Score      = {f1v:.8f}")
        print(f"  AUC-PR        = {aucv:.8f}")
        print(f"  Total Entropy = {tent:.8f}")
        print(f"  Correct Ent   = {cent:.8f}")
        print(f"  Incorrect Ent = {icent:.8f}")
        print(f"  NLL           = {nll:.8f}")

        accuracies.append(acc)
        f1_scores.append(f1v)
        auc_prs.append(aucv)
        total_entropies.append(tent)
        correct_entropies.append(cent)
        incorrect_entropies.append(icent)
        nlls.append(nll)

        # ── per-class computation ────────────────────────────────────────────────────
        model.eval()
        probs_list, labels_list = [], []
        with torch.no_grad():
            for inputs, labels in test_loader:
                out = model(inputs)
                p   = F.softmax(out, dim=1)
                probs_list.append(p)
                labels_list.append(labels)
        probs = torch.cat(probs_list, dim=0)
        true_labels = torch.cat(labels_list, dim=0)
        preds = probs.argmax(dim=1)
        ent_all = -torch.sum(probs * torch.log(probs + 1e-12), dim=1)
        labels_cpu = true_labels.cpu()
        preds_cpu  = preds.cpu()

        # Per-class accuracy and entropy are restricted to the samples whose
        # true label is that class.  F1 and AUC-PR are computed on the whole
        # test set with that class as the positive label: on a single-class
        # subset there are no false positives, so precision would be
        # trivially 1 and F1 would reduce to a function of recall.
        mask0 = (true_labels == 0)
        acc0  = (preds[mask0]==true_labels[mask0]).float().mean().item()
        f10   = f1_score(labels_cpu, preds_cpu, average='binary', pos_label=0)
        # pos_label=0 so that P(class 0) is scored against class-0 membership.
        auc0  = average_precision_score(labels_cpu, probs[:,0].cpu(), pos_label=0)
        te0   = ent_all[mask0].mean().item()

        mask1 = (true_labels == 1)
        acc1  = (preds[mask1]==true_labels[mask1]).float().mean().item()
        f11   = f1_score(labels_cpu, preds_cpu, average='binary', pos_label=1)
        auc1  = average_precision_score(labels_cpu, probs[:,1].cpu())
        te1   = ent_all[mask1].mean().item()

        replicate_acc_class0.append(acc0)
        replicate_acc_class1.append(acc1)
        replicate_f1_class0.append(f10)
        replicate_f1_class1.append(f11)
        replicate_aucpr_class0.append(auc0)
        replicate_aucpr_class1.append(auc1)
        replicate_tot_ent_class0.append(te0)
        replicate_tot_ent_class1.append(te1)
        # ─────────────────────────────────────────────────────────────────────────────

        # ── OOD (5 splits) ─────────────────────────────────────────────────────────
        ood_reviews_entropies.append(compute_total_entropy(model, emb_ood_reviews))
        ood_meta_entropies.append(compute_total_entropy(model, emb_ood_meta))
        ood_lipsum_entropies.append(compute_total_entropy(model, emb_ood_lipsum))
        ood_full_reviews_entropies.append(compute_total_entropy(model, emb_ood_full_reviews))
        ood_full_meta_entropies.append(compute_total_entropy(model, emb_ood_full_meta))
        # ─────────────────────────────────────────────────────────────────────────────

    # ── aggregate over replicates and save ───────────────────────────────────────
    def mean_se(lst):
        """Return ``(mean, standard error)`` of ``lst`` using the sample std (ddof=1)."""
        arr = np.array(lst)
        m   = arr.mean()
        s   = arr.std(ddof=1)/math.sqrt(len(arr))
        return m, s

    mean_acc, se_acc       = mean_se(accuracies)
    mean_f1, se_f1         = mean_se(f1_scores)
    mean_auc, se_auc       = mean_se(auc_prs)
    mean_tent, se_tent     = mean_se(total_entropies)
    mean_cent, se_cent     = mean_se(correct_entropies)
    mean_icent, se_icent   = mean_se(incorrect_entropies)
    mean_nll, se_nll       = mean_se(nlls)
    mean_epochs, se_epochs = mean_se(epoch_counts)
    mean_time, se_time     = mean_se(training_times)

    ood_means = {}
    for name, lst in [
        ("reviews",      ood_reviews_entropies),
        ("meta",         ood_meta_entropies),
        ("lipsum",       ood_lipsum_entropies),
        ("full_reviews", ood_full_reviews_entropies),
        ("full_meta",    ood_full_meta_entropies),
    ]:
        ood_means[name] = mean_se(lst)

    # ── per-class aggregated metrics ─────────────────────────────────────────────
    mean_acc0, se_acc0 = mean_se(replicate_acc_class0)
    mean_acc1, se_acc1 = mean_se(replicate_acc_class1)
    mean_f10, se_f10   = mean_se(replicate_f1_class0)
    mean_f11, se_f11   = mean_se(replicate_f1_class1)
    mean_auc0, se_auc0 = mean_se(replicate_aucpr_class0)
    mean_auc1, se_auc1 = mean_se(replicate_aucpr_class1)
    mean_te0, se_te0   = mean_se(replicate_tot_ent_class0)
    mean_te1, se_te1   = mean_se(replicate_tot_ent_class1)
    # ─────────────────────────────────────────────────────────────────────────────

    # The key names below are the variable names in the saved .mat file; they
    # are kept exactly as they were so existing readers keep working.
    all_metrics = {
        "R": R,
        "replicate_epochs":      np.array(epoch_counts),
        "replicate_accuracy":    np.array(accuracies),
        "replicate_f1_score":    np.array(f1_scores),
        "replicate_auc_pr":      np.array(auc_prs),
        "replicate_total_entropy":      np.array(total_entropies),
        "replicate_correct_entropy":    np.array(correct_entropies),
        "replicate_incorrect_entropy":  np.array(incorrect_entropies),
        "replicate_nll":                 np.array(nlls),
        # OOD replicate arrays
        "replicate_total_entropy_od_reviews":      np.array(ood_reviews_entropies),
        "replicate_total_entropy_od_meta":         np.array(ood_meta_entropies),
        "replicate_total_entropy_od_lipsum":       np.array(ood_lipsum_entropies),
        "replicate_total_entropy_od_full_reviews": np.array(ood_full_reviews_entropies),
        "replicate_total_entropy_od_full_meta":    np.array(ood_full_meta_entropies),
        # summary stats
        "accuracy_mean": mean_acc,  "accuracy_se": se_acc,
        "f1_mean":       mean_f1,   "f1_se":       se_f1,
        "auc_pr_mean":   mean_auc,  "auc_pr_se":   se_auc,
        "total_entropy_mean": mean_tent, "total_entropy_se": se_tent,
        "correct_entropy_mean": mean_cent, "correct_entropy_se": se_cent,
        "incorrect_entropy_mean": mean_icent, "incorrect_entropy_se": se_icent,
        "nll_mean": mean_nll, "nll_se": se_nll,
        "epochs_mean": mean_epochs, "epochs_se": se_epochs,
        "time_mean": mean_time, "time_se": se_time,
        # OOD summary
        **{f"ood_{k}_mean": ood_means[k][0] for k in ood_means},
        **{f"ood_{k}_se":   ood_means[k][1] for k in ood_means},
        # per-class arrays can also be added here if desired
    }

    # ── add per-class summary stats to the saved metrics ─────────────────────────
    all_metrics.update({
        # negative class (0)
        'accuracy_class0_mean':       mean_acc0,
        'accuracy_class0_stderr':     se_acc0,
        'f1_score_class0_mean':       mean_f10,
        'f1_score_class0_stderr':     se_f10,
        'aucpr_class0_mean':          mean_auc0,
        'aucpr_class0_stderr':        se_auc0,
        'total_entropy_class0_mean':  mean_te0,
        'total_entropy_class0_stderr':se_te0,
        # positive class (1)
        'accuracy_class1_mean':       mean_acc1,
        'accuracy_class1_stderr':     se_acc1,
        'f1_score_class1_mean':       mean_f11,
        'f1_score_class1_stderr':     se_f11,
        'aucpr_class1_mean':          mean_auc1,
        'aucpr_class1_stderr':        se_auc1,
        'total_entropy_class1_mean':  mean_te1,
        'total_entropy_class1_stderr':se_te1,
    })
    # ────────────────────────────────────────────────────────────────────────────

    savemat("imdb_map_metrics.mat", all_metrics)

    # print overall summary
    print(f"\n===== Summary over {R} replicates =====")
    print(f"MAP Epochs: Mean = {mean_epochs:.2f}, SE = {se_epochs:.8f}")
    print(f"MAP Test Accuracy: Mean = {mean_acc*100:.2f}%, SE = {se_acc*100:.8f}%")
    print(f"In-Domain NLL: Mean = {mean_nll:.8f}, SE = {se_nll:.8f}")
    print(f"In-Domain Total Entropy: Mean = {mean_tent:.8f}, SE = {se_tent:.8f}")
    print(f"In-Domain Correct Entropy: Mean = {mean_cent:.8f}, SE = {se_cent:.8f}")
    print(f"In-Domain Incorrect Entropy: Mean = {mean_icent:.8f}, SE = {se_icent:.8f}")
    print(f"In-Domain F1 Score: Mean = {mean_f1:.8f}, SE = {se_f1:.8f}")
    print(f"In-Domain AUC-PR: Mean = {mean_auc:.8f}, SE = {se_auc:.8f}")
    print("\nOOD Total Entropies:")
    for name,(m,s) in ood_means.items():
        print(f"  {name:12s}: Mean = {m:.8f}, SE = {s:.8f}")

    # print per-class summary
    print(f"\nPer-class In-Domain Results over {R} replicates:")
    print(f"  Class 0 (Negative) -  Acc: Mean = {mean_acc0:.4f}, SE = {se_acc0:.8f}; "
          f"F1: Mean = {mean_f10:.4f}, SE = {se_f10:.8f}; "
          f"AUC-PR: Mean = {mean_auc0:.4f}, SE = {se_auc0:.8f}; "
          f"Total Entropy: Mean = {mean_te0:.8f}, SE = {se_te0:.8f}")
    print(f"  Class 1 (Positive) -  Acc: Mean = {mean_acc1:.4f}, SE = {se_acc1:.8f}; "
          f"F1: Mean = {mean_f11:.4f}, SE = {se_f11:.8f}; "
          f"AUC-PR: Mean = {mean_auc1:.4f}, SE = {se_auc1:.8f}; "
          f"Total Entropy: Mean = {mean_te1:.8f}, SE = {se_te1:.8f}")