#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
PSMC Analysis with In-Domain and OOD Evaluation for BayesianNN_IMDB_psmc_SimpleMLP_MAP.

This script assumes that:
  - Replicate files (each containing 32 particles) are stored in the current directory.
    Their names follow the pattern: 
      BayesianNN_IMDB_psmc_SimpleMLP_MAP_d*_N32_M5_node*.mat
  - The in-domain test embeddings are stored in "imdb_embeddings_testBig.pt",
    containing a tuple (X_test, y_test).
  - OOD evaluation is performed on several datasets derived from JSONL files (in folder "data")
    and on Lipsum texts via SBERT.
  
For each processor configuration P in [1, 2, 4, 8], R = 5 replicates are assumed.
Each replicate is formed by grouping P files (so that each replicate provides 32*P predictions
and particles). For each replicate, the script computes the following in-domain metrics:
  - Accuracy, F1, AUC-PR, negative log-likelihood (NLL, summed over test examples), Brier score, and ECE;
  - Total predictive entropy and epistemic uncertainty (total entropy minus the average per-particle entropy);
  - The total and epistemic entropies computed separately for correct and incorrect predictions,
    and separately for each true class.
It also computes the total predictive entropy and epistemic uncertainty on several OOD datasets.
Aggregated statistics (mean and standard error, with eight-decimal precision) are printed and saved
to "psmc_aggregated_results_P{P}.mat" for each P.
"""

# ------------------ Environment Setup ------------------
# The environment variables below are assigned before `torch` is imported;
# presumably so the tokenizer / OpenMP / MKL backends read them when they
# initialise — keep this ordering when editing the imports.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import torch
# Limit PyTorch's CPU intra-op parallelism to a single thread.
torch.set_num_threads(1)

# ------------------ Standard Imports ------------------
import glob
import json
import random
import numpy as np
import scipy.io as sio
from scipy.special import softmax
from sklearn.metrics import accuracy_score, f1_score, average_precision_score, log_loss
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from scipy.io import savemat

# Numerical stability constant: added to probabilities inside log() calls in
# the entropy computations so that log(0) is never evaluated.
EPS = 1e-12

############################################
# Model and Utility Functions
############################################

class SimpleMLP(nn.Module):
    """
    Linear classifier with an optional single hidden layer.

    When ``hidden_dim`` is 0 or None the network is one linear layer stored as
    ``fc`` (logistic regression). Otherwise it is ``fc1 -> ReLU -> fc2``.
    All weights and biases are re-initialised from a normal distribution with
    mean 0 and standard deviation 0.5.
    """
    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if hidden_dim is not None and hidden_dim != 0:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        else:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)

        # All weights are drawn first, then all biases; this fixes the order
        # in which the random number generator is consumed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=0.5)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=0.5)

    def forward(self, x):
        # The presence of `fc` identifies the single-layer configuration.
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

def unflatten_params(flat, net):
    """
    Slice the 1-D tensor 'flat' into per-parameter tensors for 'net'.

    Parameters are consumed in the order yielded by net.named_parameters();
    each slice is viewed with the corresponding parameter's shape. Returns a
    dict mapping parameter name -> tensor view into 'flat'. Elements of 'flat'
    beyond the total parameter count are left unused.
    """
    named = list(net.named_parameters())
    rebuilt = {}
    start = 0
    for name, template in named:
        stop = start + template.numel()
        rebuilt[name] = flat[start:stop].view(template.shape)
        start = stop
    return rebuilt

def compute_ece(probs, labels, n_bins=10):
    """
    Compute the Expected Calibration Error (ECE) of binary predictions.

    Examples are grouped into ``n_bins`` equal-width bins over [0, 1] by their
    predicted positive-class probability. In each non-empty bin the observed
    fraction of positive labels is compared with the mean predicted
    probability; the ECE is the absolute gap averaged over bins, weighted by
    the fraction of examples falling in each bin.

    Args:
        probs: 1-D array-like of predicted probabilities for class 1.
        labels: 1-D array-like of true labels; entries equal to 1 count as
            positive, everything else as negative.
        n_bins: Number of equal-width probability bins.

    Returns:
        The ECE as a Python float; 0.0 when the input is empty.
    """
    probs = np.asarray(probs, dtype=float).ravel()
    positives = (np.asarray(labels).ravel() == 1)
    if probs.size == 0:
        return 0.0

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Bin i covers [edges[i], edges[i+1]). The index is capped so that
    # p == 1.0 lands in the last bin instead of being dropped.
    bin_idx = np.minimum(np.digitize(probs, edges[1:]), n_bins - 1)

    ece = 0.0
    for b in range(n_bins):
        in_bin = (bin_idx == b)
        if not in_bin.any():
            continue
        # Compare like with like: empirical P(y=1) against mean predicted
        # P(y=1). Comparing thresholded accuracy against P(y=1) would
        # penalise confident, correct negative predictions.
        observed_pos = positives[in_bin].mean()
        predicted_pos = probs[in_bin].mean()
        ece += abs(observed_pos - predicted_pos) * in_bin.mean()
    return float(ece)

############################################
# In-Domain and OOD Metrics Functions
############################################

def compute_indomain_metrics(softmax_samples, y_true):
    """
    Compute in-domain classification and uncertainty metrics.

    softmax_samples is indexed as [sample, example, class]; the per-sample
    predictions are averaged over the first axis to form the ensemble
    prediction, and class index 1 is treated as the positive class.

    Returns a dict with:
      - accuracy, f1, auc_pr, nll (log loss summed over examples, since
        normalize=False), brier, ece
      - total_entropy: mean entropy of the averaged prediction
      - epistemic_uncertainty: mean of (entropy of the averaged prediction
        minus the average per-sample entropy)
      - the same two quantities restricted to correctly / incorrectly
        predicted examples (NaN when the subset is empty)
    """
    def _entropy(p):
        # Entropy along the class axis of a [example, class] array.
        return -np.sum(p * np.log(p + EPS), axis=1)

    def _masked_mean(vec, mask):
        # Mean over the selected examples, NaN if none are selected.
        return vec[mask].mean() if np.any(mask) else np.nan

    mean_probs = np.mean(softmax_samples, axis=0)
    y_hat = np.argmax(mean_probs, axis=1)
    pos_prob = mean_probs[:, 1]

    predictive_entropy = _entropy(mean_probs)
    member_entropy = np.array([_entropy(p) for p in softmax_samples])
    mutual_info = predictive_entropy - member_entropy.mean(axis=0)

    is_right = (y_hat == y_true)
    is_wrong = ~is_right

    metrics = {}
    metrics["accuracy"] = accuracy_score(y_true, y_hat)
    metrics["f1"] = f1_score(y_true, y_hat)
    metrics["auc_pr"] = average_precision_score(y_true, pos_prob)
    metrics["nll"] = log_loss(y_true, mean_probs, normalize=False)
    metrics["brier"] = np.mean((pos_prob - y_true)**2)
    metrics["ece"] = compute_ece(pos_prob, y_true, n_bins=10)
    metrics["total_entropy"] = predictive_entropy.mean()
    metrics["epistemic_uncertainty"] = mutual_info.mean()
    metrics["total_entropy_correct"] = _masked_mean(predictive_entropy, is_right)
    metrics["epistemic_uncertainty_correct"] = _masked_mean(mutual_info, is_right)
    metrics["total_entropy_incorrect"] = _masked_mean(predictive_entropy, is_wrong)
    metrics["epistemic_uncertainty_incorrect"] = _masked_mean(mutual_info, is_wrong)
    return metrics

def aggregate_replicate_psmc(file_list):
    """
    Load PSMC result files and pool their particles into one replicate.

    Each .mat file must contain "psmc_single_pred" (softmax is applied over
    its last axis), "psmc_single_x" (one flattened parameter vector per row)
    and a scalar "Lsum". Files are presumably 32 particles each with
    predictions shaped (32, n_examples, num_classes); this is not checked here.

    Args:
        file_list: Non-empty sequence of .mat file paths.

    Returns:
        - softmax_array: per-file softmax predictions concatenated on axis 0
        - (mean_Lsum, se_Lsum): mean and standard error of "Lsum" across
          files; the standard error is 0.0 when only one file is given
        - particles_array: per-file particles concatenated on axis 0

    Raises:
        ValueError: if file_list is empty, or "Lsum" is not a single value.
    """
    if len(file_list) == 0:
        raise ValueError("aggregate_replicate_psmc requires at least one replicate file.")

    softmax_list = []
    Lsum_list = []
    particles_list = []
    for f in file_list:
        data = sio.loadmat(f)
        softmax_list.append(softmax(data["psmc_single_pred"], axis=-1))
        # loadmat returns scalars as 2-D arrays; .item() extracts the single
        # element explicitly (float() on an ndim > 0 array is deprecated).
        Lsum_list.append(float(np.asarray(data["Lsum"]).item()))
        particles_list.append(data["psmc_single_x"])

    soft_arr = np.concatenate(softmax_list, axis=0)
    parts = np.concatenate(particles_list, axis=0)
    Ls = np.array(Lsum_list)
    mean_Ls = np.mean(Ls)
    se_Ls = np.std(Ls, ddof=1) / np.sqrt(len(Ls)) if len(Ls) > 1 else 0.0
    return soft_arr, (mean_Ls, se_Ls), parts

def compute_ood_entropies(particles, model, x_ood, device):
    """
    Evaluate predictive uncertainty of a particle ensemble on x_ood.

    Each entry of 'particles' is a flattened parameter vector; it is converted
    to a float32 tensor on 'device', unflattened against 'model', and used for
    a stateless forward pass on x_ood.

    Returns (total_entropy, epistemic), both Python floats:
      - total_entropy: entropy of the particle-averaged softmax, averaged
        over the examples in x_ood
      - epistemic: total_entropy minus the mean per-particle entropy
    """
    try:
        from torch.func import functional_call
    except ImportError:
        # Fallback location of functional_call when torch.func is unavailable.
        from torch.nn.utils.stateless import functional_call

    def _particle_probs(flat_values):
        theta = torch.tensor(flat_values, dtype=torch.float32, device=device).view(-1)
        weights = unflatten_params(theta, model)
        return torch.softmax(functional_call(model, weights, x_ood), dim=1)

    probs = torch.stack([_particle_probs(p) for p in particles], dim=0)
    mean_probs = probs.mean(dim=0)

    total_ent = -torch.sum(mean_probs * torch.log(mean_probs + EPS), dim=1).mean().item()
    per_particle_ent = -torch.sum(probs * torch.log(probs + EPS), dim=2)
    epistemic = total_ent - per_particle_ent.mean().item()
    return total_ent, epistemic

############################################
# OOD Embedding Functions (using SBERT)
############################################

def load_jsonl_first_n(filename, n=100):
    """
    Parse at most the first n lines of a UTF-8 JSON Lines file.

    Returns a list with one decoded JSON value per line read, in file order.
    """
    parsed = []
    with open(filename, "r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle):
            if line_no >= n:
                return parsed
            parsed.append(json.loads(raw))
    return parsed

def compute_ood_embeddings(device, n_samples=100):
    """
    Build SBERT embeddings for the out-of-distribution text sets.

    Reads up to n_samples records from "data/Appliances.jsonl" and
    "data/meta_Appliances.jsonl", generates n_samples pseudo-random
    lorem-ipsum strings (global `random` and NumPy seeds are set to 42), and
    encodes five text collections with the "all-mpnet-base-v2" model.

    Returns a dict of tensors moved to 'device', keyed by "ood_reviews",
    "ood_meta", "ood_lipsum", "ood_full_reviews" and "ood_full_meta".
    """
    review_records = load_jsonl_first_n("data/Appliances.jsonl", n_samples)
    meta_records = load_jsonl_first_n("data/meta_Appliances.jsonl", n_samples)

    # Plain-text views: review body, and title plus space-joined feature list.
    review_texts = [rec.get("text", "") for rec in review_records]
    meta_texts = []
    for rec in meta_records:
        combined = rec.get("title", "") + " " + " ".join(rec.get("features", []))
        meta_texts.append(combined.strip())

    # Random filler text: 1 to 10 words sampled with replacement.
    random.seed(42)
    np.random.seed(42)
    vocab = ("lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua").split()
    lipsum_texts = []
    for _ in range(n_samples):
        n_words = random.randint(1, 10)
        lipsum_texts.append(" ".join(random.choices(vocab, k=n_words)))

    # Whole records serialised back to JSON strings.
    raw_reviews = [json.dumps(rec, ensure_ascii=False) for rec in review_records]
    raw_meta = [json.dumps(rec, ensure_ascii=False) for rec in meta_records]

    sbert = SentenceTransformer("all-mpnet-base-v2")
    sbert.eval()

    def _embed(texts):
        return sbert.encode(texts, convert_to_tensor=True).to(device)

    embeddings = {}
    embeddings["ood_reviews"] = _embed(review_texts)
    embeddings["ood_meta"] = _embed(meta_texts)
    embeddings["ood_lipsum"] = _embed(lipsum_texts)
    embeddings["ood_full_reviews"] = _embed(raw_reviews)
    embeddings["ood_full_meta"] = _embed(raw_meta)
    return embeddings

############################################
# Main Execution
############################################

# OOD dataset identifiers in reporting order; each one indexes
# ood_embeddings["ood_<name>"] and prefixes the saved result keys.
_OOD_NAMES = ("reviews", "meta", "lipsum", "full_reviews", "full_meta")

# In-domain metrics summarised across replicates, in printing order.
_INDOMAIN_METRIC_NAMES = (
    "accuracy", "f1", "auc_pr", "nll", "brier", "ece",
    "total_entropy", "epistemic_uncertainty",
    "total_entropy_correct", "epistemic_uncertainty_correct",
    "total_entropy_incorrect", "epistemic_uncertainty_incorrect",
    "avg_Lsum", "se_Lsum",
)


def _mean_se(values):
    """
    Return (mean, standard error) of a 1-D collection, ignoring NaN entries.

    The standard error uses the sample standard deviation (ddof=1) divided by
    the square root of the number of non-NaN values. With exactly one valid
    value the standard error is 0.0; with none, both results are NaN.
    """
    arr = np.asarray(values, dtype=float).ravel()
    valid = arr[~np.isnan(arr)]
    if valid.size == 0:
        return np.nan, np.nan
    mean = valid.mean()
    if valid.size < 2:
        return mean, 0.0
    return mean, valid.std(ddof=1) / np.sqrt(valid.size)


def main():
    """
    Aggregate PSMC replicate files for each processor count and save summaries.

    For every P in the processor list, up to five replicates are formed from
    consecutive groups of P files (taken from the sorted glob result).
    In-domain and OOD metrics are computed per replicate, summarised as
    mean / standard error across replicates, printed, and written to
    "psmc_aggregated_results_P{P}.mat" in the current working directory.

    Raises:
        ValueError: if no replicate file matches the naming pattern.
        FileNotFoundError: if the in-domain test embedding file is missing.
    """
    results_dir = "."
    test_file = "imdb_embeddings_testBig.pt"
    # NOTE(review): the module docstring describes files named
    # "..._N32_M5_node*.mat", but this pattern matches "_N10_M1_" — confirm
    # which set of runs is intended.
    file_pattern = os.path.join(
        results_dir, "BayesianNN_IMDB_psmc_SimpleMLP_MAP_d*_N10_M1_node*.mat"
    )
    replicate_files_all = sorted(glob.glob(file_pattern))
    if not replicate_files_all:
        raise ValueError("No replicate .mat files found; check naming pattern.")
    print(f"Found {len(replicate_files_all)} replicate files.")

    n_replicates = 5
    processor_list = [1, 2, 4, 8]

    # Load in-domain test data; only the labels and the feature dimension
    # of X_test are used below.
    if not os.path.exists(test_file):
        raise FileNotFoundError(f"Test file '{test_file}' not found.")
    X_test, y_test = torch.load(test_file, map_location="cpu")
    y_test_np = y_test.numpy().flatten()
    mask0 = (y_test_np == 0)
    mask1 = (y_test_np == 1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The model only provides parameter names/shapes; its weights are replaced
    # by each particle inside compute_ood_entropies.
    model = SimpleMLP(input_dim=X_test.shape[1], hidden_dim=0, num_classes=2).to(device)
    model.eval()

    print("Computing OOD embeddings using SBERT...")
    ood_embeddings = compute_ood_embeddings(device, n_samples=100)

    for P in processor_list:
        print(f"\nProcessing P = {P}")
        files_per_replicate = P
        replicate_results = []
        per_class_total_entropy_list = []
        per_class_epistemic_uncertainty_list = []
        ood_total_lists = {name: [] for name in _OOD_NAMES}
        ood_epistemic_lists = {name: [] for name in _OOD_NAMES}

        for r in range(n_replicates):
            group_files = replicate_files_all[
                r * files_per_replicate:(r + 1) * files_per_replicate
            ]
            if len(group_files) < files_per_replicate:
                print(f"Skipping replicate {r+1} for P={P} (insufficient files).")
                continue

            soft_arr, (avg_Lsum, se_Lsum), parts = aggregate_replicate_psmc(group_files)
            indomain = compute_indomain_metrics(soft_arr, y_test_np)
            indomain["avg_Lsum"] = avg_Lsum
            indomain["se_Lsum"] = se_Lsum

            # Per-class breakdown of total entropy and epistemic uncertainty.
            p_avg = np.mean(soft_arr, axis=0)
            total_ent_vec = -np.sum(p_avg * np.log(p_avg + EPS), axis=1)
            sample_entropies = np.array([
                -np.sum(probs * np.log(probs + EPS), axis=1)
                for probs in soft_arr
            ])
            epistemic_vec = total_ent_vec - sample_entropies.mean(axis=0)
            per_class_total_entropy_list.append([
                total_ent_vec[mask0].mean() if mask0.any() else np.nan,
                total_ent_vec[mask1].mean() if mask1.any() else np.nan,
            ])
            per_class_epistemic_uncertainty_list.append([
                epistemic_vec[mask0].mean() if mask0.any() else np.nan,
                epistemic_vec[mask1].mean() if mask1.any() else np.nan,
            ])

            # OOD entropies for every dataset, using this replicate's particles.
            replicate_entry = {"indomain": indomain}
            for name in _OOD_NAMES:
                ood_total, ood_epi = compute_ood_entropies(
                    parts, model, ood_embeddings[f"ood_{name}"], device
                )
                ood_total_lists[name].append(ood_total)
                ood_epistemic_lists[name].append(ood_epi)
                replicate_entry[f"ood_{name}_total_entropy"] = ood_total
                replicate_entry[f"ood_{name}_epistemic_uncertainty"] = ood_epi
            replicate_results.append(replicate_entry)

            print(f"  Replicate {r+1}/{n_replicates}: In-domain Accuracy = {indomain['accuracy']:.8f}, "
                  f"Avg Lsum = {avg_Lsum:.8f}, SE Lsum = {se_Lsum:.8f}")

        # Summarize in-domain metrics across replicates.
        indomain_summary = {}
        for key in _INDOMAIN_METRIC_NAMES:
            m, s = _mean_se([rep["indomain"][key] for rep in replicate_results])
            indomain_summary[key + "_mean"] = m
            indomain_summary[key + "_se"] = s
            print(f"  P={P} In-domain {key}: mean = {m:.8f}, SE = {s:.8f}")

        # Summarize OOD metrics across replicates.
        ood_summary = {}
        for name in _OOD_NAMES:
            for suffix, store in (
                ("total_entropy", ood_total_lists),
                ("epistemic_uncertainty", ood_epistemic_lists),
            ):
                metric = f"ood_{name}_{suffix}"
                m, s = _mean_se(store[name])
                ood_summary[metric + "_mean"] = m
                ood_summary[metric + "_se"] = s
                print(f"  P={P} {metric}: mean = {m:.8f}, SE = {s:.8f}")

        # Build the dict written to the .mat file.
        results = {
            'mean_acc':        indomain_summary['accuracy_mean'],
            'stderr_acc':      indomain_summary['accuracy_se'],
            'mean_f1_id':      indomain_summary['f1_mean'],
            'stderr_f1_id':    indomain_summary['f1_se'],
            'mean_aucpr_id':   indomain_summary['auc_pr_mean'],
            'stderr_aucpr_id': indomain_summary['auc_pr_se'],
            'mean_nll':        indomain_summary['nll_mean'],
            'stderr_nll':      indomain_summary['nll_se'],
            'mean_tot_ent':    indomain_summary['total_entropy_mean'],
            'stderr_tot_ent':  indomain_summary['total_entropy_se'],
            'mean_epi':        indomain_summary['epistemic_uncertainty_mean'],
            'stderr_epi':      indomain_summary['epistemic_uncertainty_se'],
            'mean_Lsum':       indomain_summary['avg_Lsum_mean'],
            # Standard error of the per-replicate average Lsum across
            # replicates, matching 'mean_Lsum' (not the SE of the SEs).
            'stderr_Lsum':     indomain_summary['avg_Lsum_se'],
            'mean_brier':      indomain_summary['brier_mean'],
            'stderr_brier':    indomain_summary['brier_se'],
            'mean_ece':        indomain_summary['ece_mean'],
            'stderr_ece':      indomain_summary['ece_se'],

            'mean_tot_corr':   indomain_summary['total_entropy_correct_mean'],
            'se_tot_corr':     indomain_summary['total_entropy_correct_se'],
            'mean_epi_corr':   indomain_summary['epistemic_uncertainty_correct_mean'],
            'se_epi_corr':     indomain_summary['epistemic_uncertainty_correct_se'],
            'mean_tot_inc':    indomain_summary['total_entropy_incorrect_mean'],
            'se_tot_inc':      indomain_summary['total_entropy_incorrect_se'],
            'mean_epi_inc':    indomain_summary['epistemic_uncertainty_incorrect_mean'],
            'se_epi_inc':      indomain_summary['epistemic_uncertainty_incorrect_se'],
        }

        # OOD per-category summaries.
        for name in _OOD_NAMES:
            results[f'ood_{name}_mean_tot'] = ood_summary[f'ood_{name}_total_entropy_mean']
            results[f'ood_{name}_se_tot']   = ood_summary[f'ood_{name}_total_entropy_se']
            results[f'ood_{name}_mean_epi'] = ood_summary[f'ood_{name}_epistemic_uncertainty_mean']
            results[f'ood_{name}_se_epi']   = ood_summary[f'ood_{name}_epistemic_uncertainty_se']

        # Per-class summary: one (mean, SE) pair per class column. reshape
        # keeps the (n_replicates, 2) layout even when no replicate ran.
        arr_tot = np.array(per_class_total_entropy_list, dtype=float).reshape(-1, 2)
        arr_epi = np.array(per_class_epistemic_uncertainty_list, dtype=float).reshape(-1, 2)
        class_tot = [_mean_se(arr_tot[:, c]) for c in range(2)]
        class_epi = [_mean_se(arr_epi[:, c]) for c in range(2)]
        mean_class_tot = np.array([m for m, _ in class_tot])
        se_class_tot   = np.array([s for _, s in class_tot])
        mean_class_epi = np.array([m for m, _ in class_epi])
        se_class_epi   = np.array([s for _, s in class_epi])

        # Print per-class results
        print(f"  P={P} Per-class total entropy    → Class 0: {mean_class_tot[0]:.8f} ± {se_class_tot[0]:.8f}, "
              f"Class 1: {mean_class_tot[1]:.8f} ± {se_class_tot[1]:.8f}")
        print(f"  P={P} Per-class epistemic unc.  → Class 0: {mean_class_epi[0]:.8f} ± {se_class_epi[0]:.8f}, "
              f"Class 1: {mean_class_epi[1]:.8f} ± {se_class_epi[1]:.8f}")

        # Save to .mat
        results['mean_class_tot_ent']   = mean_class_tot
        results['stderr_class_tot_ent'] = se_class_tot
        results['mean_class_epi_ent']   = mean_class_epi
        results['stderr_class_epi_ent'] = se_class_epi

        savemat(f'psmc_aggregated_results_P{P}.mat', results)
        print(f"--> Saved psmc_aggregated_results_P{P}.mat")


if __name__ == "__main__":
    main()
