#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Deep Ensembles (DE) for IMDB with R replicates (starting at 1).
For each replicate, an ensemble of models is trained and evaluated,
then metrics are averaged over replicates with standard errors computed.
In addition, the epoch count is recorded and its average (with SE) computed.
The DE framework is used for both in-domain (ID) and out-of-domain (OOD) evaluation,
as well as per-class (negative vs. positive) uncertainty analysis.
"""

import os, time, random, math, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sentence_transformers import SentenceTransformer
import pyro
from sklearn.metrics import f1_score, average_precision_score
from scipy.io import savemat
import pandas as pd

# Render pandas floats with eight decimal places.
pd.set_option('display.float_format', lambda x: f"{x:.8f}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Variance shared by weights and biases. Its square root is used as the std of
# the normal initialisation in SimpleMLP and in the 1/(2*sigma^2) factor of the
# L2 penalty added to the loss in the training loop.
v = 0.025
sigma_w = np.sqrt(v)
sigma_b = np.sqrt(v)

class SimpleMLP(nn.Module):
    """Linear classifier, or a one-hidden-layer ReLU MLP, producing class logits.

    With ``hidden_dim == 0`` the network is the single linear map ``fc``;
    otherwise it is ``fc1 -> ReLU -> fc2``.  After construction every weight
    is re-drawn from a zero-mean normal with std ``sigma_w`` and every bias
    from a zero-mean normal with std ``sigma_b`` (module-level constants).
    """

    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if hidden_dim != 0:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        else:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        # Replace PyTorch's default initialisation, weight first then bias,
        # one layer after the other.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
            nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return the unnormalised logits for the batch ``x``."""
        # The presence of ``fc`` identifies the no-hidden-layer variant.
        if not hasattr(self, 'fc'):
            return self.fc2(F.relu(self.fc1(x)))
        return self.fc(x)

def create_sbert_embedded_imdb_dataset(model_name="all-mpnet-base-v2",
                                       train_cache_path="imdb_embeddings_trainBig.pt",
                                       test_cache_path="imdb_embeddings_testBig.pt"):
    """Return SBERT embeddings and binary labels for the IMDB train/test splits.

    If both cache files exist they are loaded and returned as stored.
    Otherwise the IMDB reviews are read through torchtext, embedded with the
    SentenceTransformer ``model_name``, written to the two cache paths and
    returned.

    Args:
        model_name: SentenceTransformer model used to embed the reviews.
        train_cache_path: File holding the cached ``(X_train, y_train)`` tuple.
        test_cache_path: File holding the cached ``(X_test, y_test)`` tuple.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.  When freshly computed, the
        ``X`` tensors are float32 embeddings and the ``y`` tensors are long
        labels with 0 = negative and 1 = positive.

    Raises:
        KeyError: If torchtext yields a label outside the known encodings.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings...")
        # NOTE: torch.load unpickles the file; only point these paths at
        # cache files this script produced itself.
        X_train, y_train = torch.load(train_cache_path)
        X_test,  y_test  = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    from torchtext.datasets import IMDB
    # Older torchtext releases yield the strings "neg"/"pos"; newer ones yield
    # the integers 1 (negative) / 2 (positive).  Accept both encodings.
    label_map = {"pos": 1, "neg": 0, 2: 1, 1: 0}

    def _embed_split(split):
        """Embed one IMDB split and return ``(embeddings, labels)`` tensors."""
        labels, texts = [], []
        for lbl, txt in IMDB(split=split):
            labels.append(label_map[lbl])
            texts.append(txt)
        # Encode the whole split in one call so the encoder can batch the
        # texts, and build the tensor from a single ndarray rather than from a
        # Python list of per-review arrays (which PyTorch warns is very slow).
        embeddings = np.asarray(sbert.encode(texts, convert_to_numpy=True))
        return (torch.tensor(embeddings, dtype=torch.float32),
                torch.tensor(labels, dtype=torch.long))

    X_train, y_train = _embed_split("train")
    X_test,  y_test  = _embed_split("test")
    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test,  y_test),  test_cache_path)
    return X_train, y_train, X_test, y_test

def evaluate_ensemble(ensemble_models, loader):
    """Evaluate the averaged ensemble prediction on a labelled data loader.

    Every member is switched to eval mode first, so the result does not depend
    on the mode the caller left the models in (this matches
    ``compute_total_and_epistemic_entropy_ensemble``).

    Args:
        ensemble_models: Iterable of ``nn.Module`` members mapping a batch of
            inputs to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        A 10-tuple ``(acc, f1, auc, tot_ent, None, None, epi, None, None, nll)``:
        accuracy, macro F1, average precision of the class-1 probability,
        mean entropy of the ensemble-mean distribution, mean epistemic
        uncertainty (entropy of the mean minus mean member entropy), and the
        negative log-likelihood SUMMED over all samples.  The ``None`` slots
        are placeholders that keep the tuple layout fixed.
    """
    eps = 1e-12  # keeps log() finite when a probability is exactly zero

    for member in ensemble_models:
        member.eval()

    all_probs, all_ep, all_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in loader:
            # stacked: (members, batch, classes)
            stacked = torch.stack(
                [F.softmax(member(inputs), dim=1) for member in ensemble_models],
                dim=0)
            ens_p = stacked.mean(0)
            total_ent = -torch.sum(ens_p * torch.log(ens_p + eps), dim=1)
            member_ent = -torch.sum(stacked * torch.log(stacked + eps), dim=2)
            all_probs.append(ens_p)
            all_ep.append(total_ent - member_ent.mean(0))
            all_labels.append(labels)

    P = torch.cat(all_probs, dim=0)
    E = torch.cat(all_ep, dim=0)
    Y = torch.cat(all_labels, dim=0)

    preds = P.argmax(dim=1)
    acc = (preds == Y).float().mean().item()
    f1 = f1_score(Y.cpu(), preds.cpu(), average='macro')
    auc = average_precision_score(Y.cpu(), P[:, 1].cpu())
    tot_ent = (-torch.sum(P * torch.log(P + eps), dim=1)).mean().item()
    # Probability assigned to the true class of every sample.
    true_class_p = P[torch.arange(len(Y)), Y]
    nll = -torch.log(true_class_p + eps).sum().item()
    return acc, f1, auc, tot_ent, None, None, E.mean().item(), None, None, nll

def compute_per_class_uncertainties(ensemble_models, loader):
    """Average total and epistemic uncertainty separately for true class 0 and 1.

    Every member is switched to eval mode first, so the statistics are
    deterministic even if a member was left in train mode (this matches
    ``compute_total_and_epistemic_entropy_ensemble``).

    Args:
        ensemble_models: Iterable of ``nn.Module`` members mapping a batch of
            inputs to class logits.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Dict with keys ``tot_ent_in_cls{c}`` and ``epi_unc_in_cls{c}`` for
        ``c`` in ``{0, 1}``: the mean entropy of the ensemble-mean
        distribution and the mean epistemic uncertainty (entropy of the mean
        minus mean member entropy) over samples whose label is ``c``.  A class
        with no samples in ``loader`` yields ``nan``.
    """
    eps = 1e-12  # keeps log() finite when a probability is exactly zero

    for member in ensemble_models:
        member.eval()

    all_probs, all_ep, all_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in loader:
            # stacked: (members, batch, classes)
            stacked = torch.stack(
                [F.softmax(member(inputs), dim=1) for member in ensemble_models],
                dim=0)
            ens_p = stacked.mean(0)
            total_ent = -torch.sum(ens_p * torch.log(ens_p + eps), dim=1)
            member_ent = -torch.sum(stacked * torch.log(stacked + eps), dim=2)
            all_probs.append(ens_p)
            all_ep.append(total_ent - member_ent.mean(0))
            all_labels.append(labels)

    P = torch.cat(all_probs, dim=0)
    E = torch.cat(all_ep, dim=0)
    Y = torch.cat(all_labels, dim=0)
    ent_all = -torch.sum(P * torch.log(P + eps), dim=1)

    results = {}
    for cls in [0, 1]:
        mask = (Y == cls)
        results[f"tot_ent_in_cls{cls}"] = ent_all[mask].mean().item()
        results[f"epi_unc_in_cls{cls}"] = E[mask].mean().item()
    return results

def compute_total_and_epistemic_entropy_ensemble(ensemble_models, x):
    """Return the batch-mean total and epistemic uncertainty of an ensemble.

    Each member is put into eval mode and applied to ``x``.  Total uncertainty
    is the entropy of the member-averaged class distribution; epistemic
    uncertainty is that entropy minus the average entropy of the individual
    members.  Both are averaged over the batch and returned as Python floats
    ``(total, epistemic)``.
    """
    def _entropy(p, dim):
        # Small constant keeps log() finite for zero probabilities.
        return -torch.sum(p * torch.log(p + 1e-12), dim=dim)

    with torch.no_grad():
        # nn.Module.eval() returns the module itself, so it can be chained.
        member_probs = torch.stack(
            [F.softmax(member.eval()(x), dim=1) for member in ensemble_models],
            dim=0)

    mean_probs = member_probs.mean(0)
    predictive_entropy = _entropy(mean_probs, 1)
    expected_member_entropy = _entropy(member_probs, 2).mean(0)
    disagreement = predictive_entropy - expected_member_entropy
    return predictive_entropy.mean().item(), disagreement.mean().item()

def load_jsonl_first_n(fn, n=100):
    """Return the first ``n`` JSON records of a JSON-Lines file.

    Blank or whitespace-only lines are skipped instead of being handed to the
    JSON parser, so stray empty lines no longer abort the load.  For files
    without such lines the result is the same as parsing the first ``n``
    lines.

    Args:
        fn: Path of the UTF-8 encoded ``.jsonl`` file.
        n: Maximum number of records to return; ``n <= 0`` yields ``[]``.

    Returns:
        List of decoded JSON values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(fn, 'r', encoding='utf-8') as f:
        for line in f:
            if len(data) >= n:
                break
            if not line.strip():
                continue
            data.append(json.loads(line))
    return data

if __name__=='__main__':
    # Experiment size: R independent replicates, each of which trains a fresh
    # ensemble of `ensemble_size` members.
    R = 5
    ensemble_size = 80

    # The first 20000 training rows are used for fitting, the remaining
    # training rows for validation / early stopping, and the test split for
    # in-domain evaluation. Tensors are moved to `device` once, up front.
    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset()
    train_loader = DataLoader(
        TensorDataset(X_train[:20000].to(device), y_train[:20000].to(device)),
        batch_size=64, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_train[20000:].to(device), y_train[20000:].to(device)),
        batch_size=64, shuffle=False
    )
    test_loader = DataLoader(
        TensorDataset(X_test.to(device), y_test.to(device)),
        batch_size=64, shuffle=False
    )

    # Per-replicate accumulators (one entry appended per replicate).
    rep_epochs = []
    rep_acc    = []; rep_f1    = []; rep_auc    = []
    rep_tot_ent= []; rep_epi_unc= []
    rep_tot_cls0 = []; rep_epi_cls0 = []
    rep_tot_cls1 = []; rep_epi_cls1 = []
    # Names of the OOD splits; the order must match the embedding list that is
    # zipped against it in the OOD evaluation below.
    splits = ["Reviews_selected","Meta_selected","Lipsum_generated","Reviews_full","Meta_full"]
    rep_ood_tot = {s:[] for s in splits}
    rep_ood_epi = {s:[] for s in splits}

    # Prepare OOD embeddings from the first 100 records of each JSONL file.
    reviews_file = os.path.join("data","Appliances.jsonl")
    meta_file    = os.path.join("data","meta_Appliances.jsonl")
    rev_data     = load_jsonl_first_n(reviews_file, 100)
    met_data     = load_jsonl_first_n(meta_file,    100)
    # "selected" splits use only chosen text fields of each record.
    ood_reviews_texts      = [d.get("text","") for d in rev_data]
    ood_meta_texts         = [" ".join([d.get("title",""), *d.get("features",[])]) for d in met_data]
    # Fixed seed so the generated lorem-ipsum texts are the same on every run.
    random.seed(42); np.random.seed(42)
    lorem_words = ("lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor "
                   "incididunt ut labore et dolore magna aliqua").split()
    # 100 pseudo-sentences of 1 to 10 words sampled with replacement.
    ood_lipsum_texts       = [" ".join(random.choices(lorem_words, k=random.randint(1,10)))
                              for _ in range(100)]
    # "full" splits embed the whole record serialised as a JSON string.
    ood_full_reviews_texts = [json.dumps(d, ensure_ascii=False) for d in rev_data]
    ood_full_meta_texts    = [json.dumps(d, ensure_ascii=False) for d in met_data]

    # Embed all OOD texts once; the embeddings are reused by every replicate.
    sbert_ood = SentenceTransformer("all-mpnet-base-v2")
    emb_ood_reviews      = sbert_ood.encode(ood_reviews_texts,      convert_to_tensor=True).to(device)
    emb_ood_meta         = sbert_ood.encode(ood_meta_texts,         convert_to_tensor=True).to(device)
    emb_ood_lipsum       = sbert_ood.encode(ood_lipsum_texts,       convert_to_tensor=True).to(device)
    emb_ood_full_reviews = sbert_ood.encode(ood_full_reviews_texts, convert_to_tensor=True).to(device)
    emb_ood_full_meta    = sbert_ood.encode(ood_full_meta_texts,    convert_to_tensor=True).to(device)

    # Replicates are numbered from 1 to R inclusive.
    for r in range(1, R+1):
        print(f"\n===== Replicate {r} of {R} =====")
        torch.manual_seed(r); np.random.seed(r); random.seed(r); pyro.set_rng_seed(r)
        models, epochs_members = [], []

        for m in range(ensemble_size):
            # Each member gets its own seed derived from the replicate and
            # member index, so initialisation and batch order differ per member.
            seed = r*1000 + m
            print(f"  Training ensemble member {m+1}/{ensemble_size} (seed={seed})...")
            torch.manual_seed(seed); np.random.seed(seed); random.seed(seed); pyro.set_rng_seed(seed)

            model = SimpleMLP(input_dim=X_train.shape[1]).to(device)
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            crit = nn.CrossEntropyLoss()
            # best_ma: best moving-average validation loss seen so far;
            # no_imp: consecutive epochs without improvement; hist: val losses.
            best_ma, no_imp, hist = float('inf'), 0, []

            # At most 1000 epochs; early stopping usually ends training sooner.
            for ep in range(1000):
                # Training step: cross-entropy plus an L2 penalty on the single
                # linear layer `fc`, scaled by 1/(number of training samples).
                model.train()
                run_loss = 0.0
                for inp, lab in train_loader:
                    optimizer.zero_grad()
                    out = model(inp)
                    ce = crit(out, lab)
                    reg = (torch.sum(model.fc.weight**2)/(2*sigma_w**2)
                         + torch.sum(model.fc.bias**2)/(2*sigma_b**2))
                    loss = ce + reg/len(train_loader.dataset)
                    loss.backward()
                    optimizer.step()
                    run_loss += loss.item()
                # Mean of the per-batch losses.
                train_loss = run_loss / len(train_loader)

                # Validation step with the same penalty (still divided by the
                # training-set size, so both losses use the same objective).
                model.eval()
                vloss = 0.0
                with torch.no_grad():
                    for inp, lab in val_loader:
                        out = model(inp)
                        ce = crit(out, lab)
                        reg = (torch.sum(model.fc.weight**2)/(2*sigma_w**2)
                             + torch.sum(model.fc.bias**2)/(2*sigma_b**2))
                        vloss += (ce + reg/len(train_loader.dataset)).item()
                val_loss = vloss / len(val_loader)
                hist.append(val_loss)

                print(f"    Epoch {ep+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

                # Early stopping, active from the 10th epoch: track the moving
                # average of the last 10 validation losses and stop after 5
                # consecutive epochs in which it did not improve.
                if ep >= 9:
                    mov = sum(hist[-10:]) / 10
                    if mov < best_ma:
                        best_ma, no_imp = mov, 0
                    else:
                        no_imp += 1
                    if no_imp >= 5:
                        print(f"    Early stopping at epoch {ep+1}")
                        break

            # `ep` is the zero-based index of the last epoch run. The member is
            # kept in its final state; no best-epoch checkpoint is restored.
            epochs_members.append(ep+1)
            models.append(model)

        avg_epochs = np.mean(epochs_members)
        rep_epochs.append(avg_epochs)
        print(f"  Replicate {r}: Average epochs = {avg_epochs:.2f}")

        # In-domain evaluation on the test split. The NLL is printed here but
        # is not stored in the per-replicate metrics.
        acc, f1, auc, tot_ent, _, _, epi, _, _, nll = evaluate_ensemble(models, test_loader)
        rep_acc.append(acc); rep_f1.append(f1); rep_auc.append(auc)
        rep_tot_ent.append(tot_ent); rep_epi_unc.append(epi)
        print(f"  In-Domain → Acc={acc:.4f}, F1={f1:.4f}, AUC-PR={auc:.4f}, TotEnt={tot_ent:.4f}, Epi={epi:.4f}, NLL={nll:.2f}")

        # Per-class (true label 0 vs. 1) uncertainty on the test split.
        pc = compute_per_class_uncertainties(models, test_loader)
        rep_tot_cls0.append(pc['tot_ent_in_cls0']); rep_epi_cls0.append(pc['epi_unc_in_cls0'])
        rep_tot_cls1.append(pc['tot_ent_in_cls1']); rep_epi_cls1.append(pc['epi_unc_in_cls1'])
        print(f"    Class0 → TotEnt={pc['tot_ent_in_cls0']:.4f}, Epi={pc['epi_unc_in_cls0']:.4f}; "
              f"Class1 → TotEnt={pc['tot_ent_in_cls1']:.4f}, Epi={pc['epi_unc_in_cls1']:.4f}")

        # OOD evaluation: each split is passed through the ensemble as a
        # single batch.
        for name, emb in zip(splits,
            [emb_ood_reviews, emb_ood_meta, emb_ood_lipsum, emb_ood_full_reviews, emb_ood_full_meta]):
            tot, epi = compute_total_and_epistemic_entropy_ensemble(models, emb)
            rep_ood_tot[name].append(tot)
            rep_ood_epi[name].append(epi)
            print(f"    OOD {name}: TotEnt={tot:.4f}, Epi={epi:.4f}")

    # Aggregate & save: only the per-replicate arrays go into the .mat file;
    # the means and standard errors computed below are printed, not saved.
    metrics = {
        'replicate_epochs':             np.array(rep_epochs),
        'replicate_accuracy':           np.array(rep_acc),
        'replicate_f1_in':              np.array(rep_f1),
        'replicate_aucpr_in':           np.array(rep_auc),
        'replicate_total_entropy_in':   np.array(rep_tot_ent),
        'replicate_epistemic_in':       np.array(rep_epi_unc),
        'replicate_total_entropy_in_class0': np.array(rep_tot_cls0),
        'replicate_epistemic_in_class0':     np.array(rep_epi_cls0),
        'replicate_total_entropy_in_class1': np.array(rep_tot_cls1),
        'replicate_epistemic_in_class1':     np.array(rep_epi_cls1),
    }
    for s in splits:
        metrics[f'replicate_total_entropy_od_{s}'] = np.array(rep_ood_tot[s])
        metrics[f'replicate_epistemic_od_{s}']     = np.array(rep_ood_epi[s])

    # -- Compute mean +/- SE over replicates ----------------------------------
    def mean_se(lst):
        """Return (mean, standard error) of `lst`, using the sample std (ddof=1)."""
        arr = np.array(lst)
        m   = arr.mean()
        s   = arr.std(ddof=1) / math.sqrt(len(arr))
        return m, s

    # In-domain
    epochs_mean,     epochs_se     = mean_se(rep_epochs)
    acc_mean,        acc_se        = mean_se(rep_acc)
    f1_mean,         f1_se         = mean_se(rep_f1)
    auc_mean,        auc_se        = mean_se(rep_auc)
    tot_ent_mean,    tot_ent_se    = mean_se(rep_tot_ent)
    epi_ent_mean,    epi_ent_se    = mean_se(rep_epi_unc)

    # Per-class
    tot0_mean, tot0_se = mean_se(rep_tot_cls0)
    epi0_mean, epi0_se = mean_se(rep_epi_cls0)
    tot1_mean, tot1_se = mean_se(rep_tot_cls1)
    epi1_mean, epi1_se = mean_se(rep_epi_cls1)

    # OOD splits: name -> (total mean, total SE, epistemic mean, epistemic SE)
    ood_summary = {}
    for name in splits:
        t_mean, t_se = mean_se(rep_ood_tot[name])
        e_mean, e_se = mean_se(rep_ood_epi[name])
        ood_summary[name] = (t_mean, t_se, e_mean, e_se)

    # -- Print MAP-style summary ----------------------------------------------
    print(f"\n===== Summary over {R} replicates =====")
    print(f"Avg. Ensemble Epochs:                   Mean = {epochs_mean:.2f}, SE = {epochs_se:.8f}")
    # Accuracy and its SE are reported in percent.
    print(f"In-Domain Accuracy (%):               Mean = {acc_mean*100:.2f}, SE = {acc_se*100:.8f}")
    print(f"In-Domain F1 Score (macro):           Mean = {f1_mean:.8f}, SE = {f1_se:.8f}")
    print(f"In-Domain AUC-PR (macro):             Mean = {auc_mean:.8f}, SE = {auc_se:.8f}")
    print(f"In-Domain Total Entropy:              Mean = {tot_ent_mean:.8f}, SE = {tot_ent_se:.8f}")
    print(f"In-Domain Epistemic Entropy:          Mean = {epi_ent_mean:.8f}, SE = {epi_ent_se:.8f}")

    print("\nBreakdown for In-Domain Predictions:")
    print(f"  Class 0 – Total Entropy:             Mean = {tot0_mean:.8f}, SE = {tot0_se:.8f}")
    print(f"  Class 0 – Epistemic Entropy:         Mean = {epi0_mean:.8f}, SE = {epi0_se:.8f}")
    print(f"  Class 1 – Total Entropy:             Mean = {tot1_mean:.8f}, SE = {tot1_se:.8f}")
    print(f"  Class 1 – Epistemic Entropy:         Mean = {epi1_mean:.8f}, SE = {epi1_se:.8f}")

    print("\nOOD Results (Total & Epistemic Entropy per split):")
    for name,(tm, ts, em, es) in ood_summary.items():
        print(f"  {name:20s} – Total Entropy: Mean = {tm:.8f}, SE = {ts:.8f}; "
              f"Epistemic: Mean = {em:.8f}, SE = {es:.8f}")
    # --------------------------------------------------------------------------

    # Write the per-replicate arrays to a MATLAB .mat file in the working dir.
    savemat('imdb_de_metrics.mat', metrics)
    print("\nSaved all computed metrics to 'imdb_de_metrics.mat'.")

    
