#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Deep Ensembles (DE) for IMDB with R replicates (starting at 1).
For each replicate, an ensemble of models is trained and evaluated,
then metrics are averaged over replicates with standard errors computed.
In addition, the epoch count is recorded and its average (with SE) computed.
The DE framework is used for both in-domain (ID) and out-of-domain (OOD) evaluation,
as well as per-class (negative vs. positive) uncertainty analysis.
"""

import os, time, random, math, json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sentence_transformers import SentenceTransformer
import pyro
from sklearn.metrics import f1_score, average_precision_score
from scipy.io import savemat
import pandas as pd

# Render pandas floats with eight decimal places.
pd.set_option('display.float_format', lambda x: f"{x:.8f}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Variance shared by weights and biases; sigma_w / sigma_b are the matching
# standard deviations. They are used both for the normal initialisation in
# SimpleMLP and for the quadratic penalty added to the loss in the training loop.
v = 0.025
sigma_w = np.sqrt(v)
sigma_b = np.sqrt(v)

class SimpleMLP(nn.Module):
    """Linear classifier, or a one-hidden-layer ReLU MLP when ``hidden_dim`` is non-zero.

    All weights and biases are re-initialised from zero-mean normal
    distributions with standard deviations ``sigma_w`` and ``sigma_b``
    (module-level constants).
    """

    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        """Build the layers and apply the normal initialisation.

        Args:
            input_dim: Size of each input feature vector.
            hidden_dim: Width of the hidden layer; ``0`` selects the purely
                linear model exposed as ``self.fc``.
            num_classes: Number of output logits.
        """
        super().__init__()
        if hidden_dim != 0:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        else:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        # Weight first, then bias, layer by layer, so random draws happen in a
        # fixed order for a given seed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
            nn.init.normal_(layer.bias,   mean=0, std=sigma_b)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

def _encode_imdb_split(sbert, samples):
    """Embed one IMDB split and return ``(X, y)`` as float32 / long tensors.

    Args:
        sbert: Object exposing ``encode(list_of_texts, convert_to_numpy=True)``.
        samples: Sequence of ``(label, text)`` pairs.

    Raises:
        KeyError: If a label is not one of the recognised encodings.
    """
    # Older torchtext releases yield "neg"/"pos" strings; newer ones appear to
    # yield 1 (neg) / 2 (pos) integers — TODO confirm against the installed
    # torchtext version. Any other encoding still fails loudly with KeyError.
    label_map = {"pos": 1, "neg": 0, 2: 1, 1: 0}
    labels = [label_map[lbl] for lbl, _ in samples]
    texts = [txt for _, txt in samples]
    # Encode the whole split in one batched call (as the OOD texts are encoded
    # elsewhere in this file) instead of one forward pass per review.
    embeddings = sbert.encode(texts, convert_to_numpy=True)
    X = torch.tensor(embeddings, dtype=torch.float32)
    y = torch.tensor(labels, dtype=torch.long)
    return X, y


def create_sbert_embedded_imdb_dataset(model_name="all-mpnet-base-v2",
                                       train_cache_path="imdb_embeddings_trainBig.pt",
                                       test_cache_path="imdb_embeddings_testBig.pt"):
    """Return SBERT embeddings and binary labels for the IMDB train/test splits.

    When both cache files exist they are loaded and returned directly.
    Otherwise the IMDB splits are read via torchtext, every review is embedded
    with the given SentenceTransformer model, and the results are written to
    the two cache paths before being returned.

    Args:
        model_name: SentenceTransformer model identifier.
        train_cache_path: File holding the cached ``(X_train, y_train)`` tuple.
        test_cache_path: File holding the cached ``(X_test, y_test)`` tuple.

    Returns:
        ``(X_train, y_train, X_test, y_test)``; on a cache miss the ``X``
        tensors are float32 embeddings and the ``y`` tensors are long labels
        (1 = positive, 0 = negative).
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings...")
        X_train, y_train = torch.load(train_cache_path)
        X_test,  y_test  = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    # Imported lazily so torchtext is only required on a cache miss.
    from torchtext.datasets import IMDB
    X_train, y_train = _encode_imdb_split(sbert, list(IMDB(split="train")))
    X_test,  y_test  = _encode_imdb_split(sbert, list(IMDB(split="test")))

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test,  y_test),  test_cache_path)
    return X_train, y_train, X_test, y_test

def evaluate_ensemble(ensemble_models, loader):
    """Evaluate an ensemble on a labelled data loader.

    Member softmax outputs are averaged to form the ensemble prediction.
    Total uncertainty is the entropy of that averaged distribution; epistemic
    uncertainty is total entropy minus the mean entropy of the individual
    members.

    Args:
        ensemble_models: Sequence of modules mapping inputs to class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A 10-tuple ``(acc, f1, auc, tot_ent, corr_tot, corr_epi, inc_tot,
        inc_epi, epi_mean, nll)`` where ``f1`` is the macro F1 score, ``auc``
        is the average-precision score computed from the class-1 probability,
        the ``corr_*`` / ``inc_*`` values are mean total / epistemic
        uncertainty over correctly / incorrectly classified samples (0.0 when
        that subset is empty), and ``nll`` is the negative log-likelihood
        summed (not averaged) over all samples.
    """
    # Put every member in inference mode, as
    # compute_total_and_epistemic_entropy_ensemble does, so the result does
    # not depend on the mode the caller left the models in.
    for m in ensemble_models:
        m.eval()

    all_probs, all_ep, all_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in loader:
            batch_probs = [F.softmax(m(inputs), dim=1) for m in ensemble_models]
            stacked     = torch.stack(batch_probs, dim=0)   # (members, batch, classes)
            ens_p       = stacked.mean(0)
            # 1e-12 keeps log() finite when a probability is exactly zero.
            model_ent   = -torch.sum(stacked * torch.log(stacked + 1e-12), dim=2)
            total_ent   = -torch.sum(ens_p * torch.log(ens_p + 1e-12), dim=1)
            epi_unc     = total_ent - model_ent.mean(0)
            all_probs.append(ens_p)
            all_ep.append(epi_unc)
            all_labels.append(labels)

    P     = torch.cat(all_probs, dim=0)
    E     = torch.cat(all_ep,    dim=0)
    Y     = torch.cat(all_labels,dim=0)
    preds = P.argmax(dim=1)

    ent_all = -torch.sum(P * torch.log(P + 1e-12), dim=1)

    # Compute the correct / incorrect masks once and reuse them.
    correct     = preds == Y
    incorrect   = ~correct
    any_correct = bool(correct.any())
    any_wrong   = bool(incorrect.any())

    acc       = correct.float().mean().item()
    f1        = f1_score(Y.cpu(), preds.cpu(), average='macro')
    auc       = average_precision_score(Y.cpu(), P[:,1].cpu())
    tot_ent   = ent_all.mean().item()
    # Fall back to 0.0 when a subset is empty rather than reporting NaN.
    corr_tot  = ent_all[correct].mean().item()   if any_correct else 0.0
    inc_tot   = ent_all[incorrect].mean().item() if any_wrong   else 0.0
    corr_epi  = E[correct].mean().item()         if any_correct else 0.0
    inc_epi   = E[incorrect].mean().item()       if any_wrong   else 0.0
    epi_mean  = E.mean().item()
    nll       = -torch.log(P[torch.arange(len(Y)), Y] + 1e-12).sum().item()

    return acc, f1, auc, tot_ent, corr_tot, corr_epi, inc_tot, inc_epi, epi_mean, nll

def compute_per_class_uncertainties(ensemble_models, loader):
    """Return mean total and epistemic uncertainty grouped by true class (0 and 1).

    Total uncertainty is the entropy of the member-averaged softmax
    distribution; epistemic uncertainty is that entropy minus the mean entropy
    of the individual members.

    Args:
        ensemble_models: Sequence of modules mapping inputs to class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        dict with keys ``tot_ent_in_cls0``, ``epi_unc_in_cls0``,
        ``tot_ent_in_cls1`` and ``epi_unc_in_cls1``. A class with no samples
        in ``loader`` reports 0.0 for both values.
    """
    # Put every member in inference mode, as
    # compute_total_and_epistemic_entropy_ensemble does.
    for m in ensemble_models:
        m.eval()

    all_tot, all_ep, all_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in loader:
            batch_probs = [F.softmax(m(inputs), dim=1) for m in ensemble_models]
            stacked     = torch.stack(batch_probs, dim=0)   # (members, batch, classes)
            ens_p       = stacked.mean(0)
            # 1e-12 keeps log() finite when a probability is exactly zero.
            ent         = -torch.sum(ens_p * torch.log(ens_p + 1e-12), dim=1)
            member_ent  = -torch.sum(stacked * torch.log(stacked + 1e-12), dim=2)
            all_tot.append(ent)
            all_ep.append(ent - member_ent.mean(0))
            all_labels.append(labels)
    T = torch.cat(all_tot,    dim=0)
    E = torch.cat(all_ep,     dim=0)
    Y = torch.cat(all_labels, dim=0)

    results = {}
    for cls in [0, 1]:
        mask = (Y == cls)
        # Averaging an empty selection would yield NaN; report 0.0 instead,
        # as evaluate_ensemble does for empty correct / incorrect subsets.
        has_samples = bool(mask.any())
        results[f"tot_ent_in_cls{cls}"] = T[mask].mean().item() if has_samples else 0.0
        results[f"epi_unc_in_cls{cls}"] = E[mask].mean().item() if has_samples else 0.0
    return results

def compute_total_and_epistemic_entropy_ensemble(ensemble_models, x):
    """Return batch-mean total and epistemic uncertainty for the inputs ``x``.

    Each member is switched to eval mode and run on ``x`` without gradient
    tracking. Total uncertainty is the entropy of the member-averaged softmax
    distribution; epistemic uncertainty is that entropy minus the mean of the
    per-member entropies.

    Args:
        ensemble_models: Sequence of modules mapping ``x`` to class logits.
        x: Batch of inputs passed unchanged to every member.

    Returns:
        ``(mean_total_entropy, mean_epistemic_uncertainty)`` as Python floats.
    """
    eps = 1e-12  # keeps log() finite when a probability is exactly zero
    member_probs = []
    with torch.no_grad():
        for member in ensemble_models:
            member.eval()
            member_probs.append(F.softmax(member(x), dim=1))

    probs      = torch.stack(member_probs, dim=0)
    mean_probs = probs.mean(0)
    per_member_entropy = -torch.sum(probs * torch.log(probs + eps), dim=2)
    predictive_entropy = -torch.sum(mean_probs * torch.log(mean_probs + eps), dim=1)
    epistemic = predictive_entropy - per_member_entropy.mean(0)
    return predictive_entropy.mean().item(), epistemic.mean().item()

def load_jsonl_first_n(fn, n=100):
    """Load up to the first ``n`` JSON records from a JSON Lines file.

    Blank or whitespace-only lines are skipped and do not count towards
    ``n``, so a stray empty line no longer aborts the load.

    Args:
        fn: Path to a UTF-8 encoded file with one JSON document per line.
        n: Maximum number of records to return; values <= 0 yield ``[]``.

    Returns:
        List of decoded JSON values in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(fn, 'r', encoding='utf-8') as f:
        for line in f:
            if len(data) >= n:
                break
            if not line.strip():
                # Tolerate empty separator / trailing lines.
                continue
            data.append(json.loads(line))
    return data

# Script entry point: trains R independent deep ensembles on SBERT-embedded
# IMDB, evaluates each one in-domain and on several out-of-domain (OOD) text
# sets, then prints mean ± standard error over replicates and saves the
# per-replicate values to a .mat file.
if __name__=='__main__':
    R = 5               # number of independent replicates (ensembles)
    ensemble_size = 10  # number of members trained per replicate

    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset()
    # The first 20000 training rows are used for fitting; the remaining rows
    # form the validation set that drives early stopping. The test split is
    # used for in-domain evaluation.
    train_loader = DataLoader(TensorDataset(X_train[:20000].to(device), y_train[:20000].to(device)),
                              batch_size=64, shuffle=True)
    val_loader   = DataLoader(TensorDataset(X_train[20000:].to(device), y_train[20000:].to(device)),
                              batch_size=64, shuffle=False)
    test_loader  = DataLoader(TensorDataset(X_test.to(device), y_test.to(device)),
                              batch_size=64, shuffle=False)

    # Per-replicate accumulators; one value is appended per replicate.
    rep_epochs      = []
    rep_acc         = []; rep_f1         = []; rep_auc    = []
    rep_tot_ent     = []
    rep_corr_tot    = []; rep_corr_epi    = []
    rep_inc_tot     = []; rep_inc_epi     = []
    rep_epi_unc     = []; rep_nll      = []
    rep_tot_cls0    = []; rep_epi_cls0 = []
    rep_tot_cls1    = []; rep_epi_cls1 = []
    # Names of the OOD splits, in the same order as the embedding list zipped
    # against them inside the replicate loop.
    splits = ["reviews","meta","lipsum","full_reviews","full_meta"]
    rep_ood_tot = {s:[] for s in splits}
    rep_ood_epi = {s:[] for s in splits}

    # Prepare OOD embeddings (computed once and reused by every replicate).
    reviews_file           = os.path.join("data","Appliances.jsonl")
    meta_file              = os.path.join("data","meta_Appliances.jsonl")
    # Only the first 100 records of each file are used.
    rev_data               = load_jsonl_first_n(reviews_file, 100)
    met_data               = load_jsonl_first_n(meta_file,    100)
    # "reviews": the "text" field of each review record (empty string if absent).
    ood_reviews_texts      = [d.get("text","") for d in rev_data]
    # "meta": the title followed by every entry of "features", space-joined.
    ood_meta_texts         = [" ".join([d.get("title",""), *d.get("features",[])]) for d in met_data]
    # Fixed seed so the random lorem-ipsum snippets are the same on every run.
    random.seed(42); np.random.seed(42)
    lorem_words            = ("lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor "
                             "incididunt ut labore et dolore magna aliqua").split()
    # "lipsum": 100 snippets of 1-10 words sampled with replacement.
    ood_lipsum_texts       = [" ".join(random.choices(lorem_words, k=random.randint(1,10)))
                              for _ in range(100)]
    # "full_*": the whole record serialised back to a JSON string.
    ood_full_reviews_texts = [json.dumps(d, ensure_ascii=False) for d in rev_data]
    ood_full_meta_texts    = [json.dumps(d, ensure_ascii=False) for d in met_data]

    sbert_ood = SentenceTransformer("all-mpnet-base-v2")
    emb_ood_reviews      = sbert_ood.encode(ood_reviews_texts,      convert_to_tensor=True).to(device)
    emb_ood_meta         = sbert_ood.encode(ood_meta_texts,         convert_to_tensor=True).to(device)
    emb_ood_lipsum       = sbert_ood.encode(ood_lipsum_texts,       convert_to_tensor=True).to(device)
    emb_ood_full_reviews = sbert_ood.encode(ood_full_reviews_texts, convert_to_tensor=True).to(device)
    emb_ood_full_meta    = sbert_ood.encode(ood_full_meta_texts,    convert_to_tensor=True).to(device)

    # Replicates are numbered from 1.
    for r in range(1, R+1):
        print(f"\n===== Replicate {r} of {R} =====")
        torch.manual_seed(r); np.random.seed(r); random.seed(r); pyro.set_rng_seed(r)
        models, epochs_members = [], []

        for m in range(ensemble_size):
            # Each member reseeds every RNG with its own seed before the model
            # is constructed, so members differ in initialisation and in the
            # shuffling order of train_loader.
            seed = r*1000 + m
            print(f"  Training ensemble member {m+1}/{ensemble_size} (seed={seed})...")
            torch.manual_seed(seed); np.random.seed(seed); random.seed(seed); pyro.set_rng_seed(seed)

            model = SimpleMLP(input_dim=X_train.shape[1]).to(device)
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            crit = nn.CrossEntropyLoss()
            # best_ma: lowest moving-average validation loss seen so far;
            # no_imp: consecutive checks without improvement;
            # hist: validation loss of every epoch.
            best_ma, no_imp, hist = float('inf'), 0, []

            # Hard cap of 1000 epochs; early stopping normally exits sooner.
            for ep in range(1000):
                model.train()
                run_loss = 0.0
                for inp, lab in train_loader:
                    optimizer.zero_grad()
                    out = model(inp)
                    ce = crit(out, lab)
                    # Quadratic penalty sum(w^2)/(2*sigma^2) on weights and
                    # biases, i.e. the negative log-density of a zero-mean
                    # Gaussian with std sigma_w / sigma_b up to a constant.
                    # It reads model.fc, so it only works for the
                    # hidden_dim == 0 variant of SimpleMLP.
                    reg = (torch.sum(model.fc.weight**2)/(2*sigma_w**2)
                         + torch.sum(model.fc.bias**2)/(2*sigma_b**2))
                    # The penalty is divided by the number of training examples
                    # before being added to the batch cross-entropy.
                    loss = ce + reg/len(train_loader.dataset)
                    loss.backward()
                    optimizer.step()
                    run_loss += loss.item()
                # Average of the per-batch losses over the epoch.
                train_loss = run_loss / len(train_loader)

                model.eval()
                vloss = 0.0
                with torch.no_grad():
                    for inp, lab in val_loader:
                        out = model(inp)
                        ce = crit(out, lab)
                        # Same penalty and scaling as in training, so the
                        # validation loss is the same objective.
                        reg = (torch.sum(model.fc.weight**2)/(2*sigma_w**2)
                             + torch.sum(model.fc.bias**2)/(2*sigma_b**2))
                        vloss += (ce + reg/len(train_loader.dataset)).item()
                val_loss = vloss / len(val_loader)
                hist.append(val_loss)

                print(f"    Epoch {ep+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

                # Early stopping: from the 10th epoch on, track the mean of the
                # last 10 validation losses and stop after 5 consecutive
                # epochs in which that moving average did not reach a new minimum.
                if ep >= 9:
                    mov = sum(hist[-10:]) / 10
                    if mov < best_ma:
                        best_ma, no_imp = mov, 0
                    else:
                        no_imp += 1
                    if no_imp >= 5:
                        print(f"    Early stopping at epoch {ep+1}")
                        break

            # Record the number of epochs actually run. The model is kept in
            # its final state; no earlier checkpoint is restored.
            epochs_members.append(ep+1)
            models.append(model)

        # Mean number of training epochs across the members of this replicate.
        avg_epochs = np.mean(epochs_members)
        rep_epochs.append(avg_epochs)
        print(f"  Replicate {r}: Average epochs = {avg_epochs:.2f}")

        # In-domain evaluation on the test split.
        acc, f1, auc, tot_ent, corr_tot, corr_epi, inc_tot, inc_epi, epi, nll = evaluate_ensemble(models, test_loader)
        rep_acc.append(acc); rep_f1.append(f1); rep_auc.append(auc)
        rep_tot_ent.append(tot_ent)
        rep_corr_tot.append(corr_tot); rep_corr_epi.append(corr_epi)
        rep_inc_tot.append(inc_tot); rep_inc_epi.append(inc_epi)
        rep_epi_unc.append(epi); rep_nll.append(nll)
        print(f"  In-Domain → Acc={acc:.4f}, F1={f1:.4f}, AUC-PR={auc:.4f}")
        print(f"               TotEnt={tot_ent:.4f}, CorrTot={corr_tot:.4f}, CorrEpi={corr_epi:.4f}, IncTot={inc_tot:.4f}, IncEpi={inc_epi:.4f}, Epi={epi:.4f}, NLL={nll:.2f}")

        # Uncertainty broken down by true class on the same test split.
        pc = compute_per_class_uncertainties(models, test_loader)
        rep_tot_cls0.append(pc['tot_ent_in_cls0']); rep_epi_cls0.append(pc['epi_unc_in_cls0'])
        rep_tot_cls1.append(pc['tot_ent_in_cls1']); rep_epi_cls1.append(pc['epi_unc_in_cls1'])
        print(f"    Class0 → TotEnt={pc['tot_ent_in_cls0']:.4f}, Epi={pc['epi_unc_in_cls0']:.4f}; "
              f"Class1 → TotEnt={pc['tot_ent_in_cls1']:.4f}, Epi={pc['epi_unc_in_cls1']:.4f}")

        # OOD evaluation: the embedding list must stay in the same order as `splits`.
        for name, emb in zip(splits, [emb_ood_reviews, emb_ood_meta, emb_ood_lipsum, emb_ood_full_reviews, emb_ood_full_meta]):
            tot, epi_ood = compute_total_and_epistemic_entropy_ensemble(models, emb)
            rep_ood_tot[name].append(tot)
            rep_ood_epi[name].append(epi_ood)
            print(f"    OOD {name}: TotEnt={tot:.4f}, Epi={epi_ood:.4f}")

    # Aggregate the per-replicate values into arrays keyed by the variable
    # names written to the .mat file.
    metrics = {
        'replicate_epochs':                  np.array(rep_epochs),
        'replicate_accuracy':                np.array(rep_acc),
        'replicate_f1_score':                np.array(rep_f1),
        'replicate_auc_pr':                  np.array(rep_auc),
        'replicate_total_entropy':           np.array(rep_tot_ent),
        'replicate_total_entropy_in_class0': np.array(rep_tot_cls0),
        'replicate_total_entropy_in_class1': np.array(rep_tot_cls1),
        'replicate_correct_total_entropy':   np.array(rep_corr_tot),
        'replicate_correct_epistemic_entropy': np.array(rep_corr_epi),
        'replicate_incorrect_total_entropy': np.array(rep_inc_tot),
        'replicate_incorrect_epistemic_entropy': np.array(rep_inc_epi),
        'replicate_epistemic_entropy':       np.array(rep_epi_unc),
        'replicate_epistemic_in_class0':     np.array(rep_epi_cls0),
        'replicate_epistemic_in_class1':     np.array(rep_epi_cls1),
        'replicate_nll':                     np.array(rep_nll),
    }
    for s in splits:
        metrics[f'replicate_total_entropy_od_{s}'] = np.array(rep_ood_tot[s])
        metrics[f'replicate_epistemic_od_{s}']     = np.array(rep_ood_epi[s])

    # Helper to compute mean ± SE over replicates.
    def mean_se(lst):
        """Return ``(mean, standard error)`` using the sample std (ddof=1).

        NOTE: with a single value the ddof=1 standard deviation is undefined
        and NumPy yields NaN, so R == 1 produces NaN standard errors.
        """
        arr = np.array(lst)
        return arr.mean(), arr.std(ddof=1)/math.sqrt(len(arr))

    # Compute the summary statistics that are printed below. The
    # correct/incorrect breakdown and NLL are not summarised here; they are
    # only stored per replicate in `metrics`.
    epochs_mean, epochs_se = mean_se(rep_epochs)
    acc_mean, acc_se       = mean_se(rep_acc)
    f1_mean, f1_se         = mean_se(rep_f1)
    auc_mean, auc_se       = mean_se(rep_auc)
    tot_ent_mean, tot_ent_se = mean_se(rep_tot_ent)
    epi_mean, epi_se       = mean_se(rep_epi_unc)

    tot0_mean, tot0_se     = mean_se(rep_tot_cls0)
    epi0_mean, epi0_se     = mean_se(rep_epi_cls0)
    tot1_mean, tot1_se     = mean_se(rep_tot_cls1)
    epi1_mean, epi1_se     = mean_se(rep_epi_cls1)

    t_rev_mean, t_rev_se, = mean_se(rep_ood_tot['reviews'])
    e_rev_mean, e_rev_se  = mean_se(rep_ood_epi['reviews'])
    t_meta_mean, t_meta_se = mean_se(rep_ood_tot['meta'])
    e_meta_mean, e_meta_se = mean_se(rep_ood_epi['meta'])
    t_lipsum_mean, t_lipsum_se = mean_se(rep_ood_tot['lipsum'])
    e_lipsum_mean, e_lipsum_se = mean_se(rep_ood_epi['lipsum'])
    t_full_rev_mean, t_full_rev_se = mean_se(rep_ood_tot['full_reviews'])
    e_full_rev_mean, e_full_rev_se = mean_se(rep_ood_epi['full_reviews'])
    t_full_meta_mean, t_full_meta_se = mean_se(rep_ood_tot['full_meta'])
    e_full_meta_mean, e_full_meta_se = mean_se(rep_ood_epi['full_meta'])

    # Print the summary; accuracy is reported in percent, everything else raw.
    print(f"\n===== Summary over {R} replicates =====")
    print(f"Avg. Ensemble Epochs:                   Mean = {epochs_mean:.2f}, SE = {epochs_se:.8f}")
    print(f"In-Domain Accuracy (%):               Mean = {acc_mean*100:.2f}, SE = {acc_se*100:.8f}")
    print(f"In-Domain F1 Score (macro):           Mean = {f1_mean:.8f}, SE = {f1_se:.8f}")
    print(f"In-Domain AUC-PR (macro):             Mean = {auc_mean:.8f}, SE = {auc_se:.8f}")
    print(f"In-Domain Total Entropy:              Mean = {tot_ent_mean:.8f}, SE = {tot_ent_se:.8f}")
    print(f"In-Domain Epistemic Entropy:          Mean = {epi_mean:.8f}, SE = {epi_se:.8f}")

    print("\nBreakdown for In-Domain Predictions:")
    print(f"  Class 0 – Total Entropy:             Mean = {tot0_mean:.8f}, SE = {tot0_se:.8f}")
    print(f"  Class 0 – Epistemic Entropy:         Mean = {epi0_mean:.8f}, SE = {epi0_se:.8f}")
    print(f"  Class 1 – Total Entropy:             Mean = {tot1_mean:.8f}, SE = {tot1_se:.8f}")
    print(f"  Class 1 – Epistemic Entropy:         Mean = {epi1_mean:.8f}, SE = {epi1_se:.8f}")

    print("\nOOD Results (Total & Epistemic Entropy per split):")
    print(f"  reviews     – Total Entropy: Mean = {t_rev_mean:.8f}, SE = {t_rev_se:.8f}; Epistemic: Mean = {e_rev_mean:.8f}, SE = {e_rev_se:.8f}")
    print(f"  meta        – Total Entropy: Mean = {t_meta_mean:.8f}, SE = {t_meta_se:.8f}; Epistemic: Mean = {e_meta_mean:.8f}, SE = {e_meta_se:.8f}")
    print(f"  lipsum     – Total Entropy: Mean = {t_lipsum_mean:.8f}, SE = {t_lipsum_se:.8f}; Epistemic: Mean = {e_lipsum_mean:.8f}, SE = {e_lipsum_se:.8f}")
    print(f"  full_reviews         – Total Entropy: Mean = {t_full_rev_mean:.8f}, SE = {t_full_rev_se:.8f}; Epistemic: Mean = {e_full_rev_mean:.8f}, SE = {e_full_rev_se:.8f}")
    print(f"  full_meta            – Total Entropy: Mean = {t_full_meta_mean:.8f}, SE = {t_full_meta_se:.8f}; Epistemic: Mean = {e_full_meta_mean:.8f}, SE = {e_full_meta_se:.8f}")

    # Persist every per-replicate array for later analysis.
    savemat('imdb_de_metrics.mat', metrics)
    print("\nSaved all computed metrics to 'imdb_de_metrics.mat'.")
