#!/usr/bin/env python3
import os
# Set environment variables to help avoid segmentation faults on macOS.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import sys
import time
import random
import json
import numpy as np
from scipy.io import loadmat, savemat
from scipy.special import softmax
from sklearn.metrics import accuracy_score, f1_score, average_precision_score, log_loss
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Global constant for numerical stability.
EPS = 1e-12

#####################################
# Model and Utility Functions
#####################################

class SimpleMLP(nn.Module):
    """
    Small feed-forward classifier over fixed-size input vectors.

    If ``hidden_dim`` is 0 or None the network is a single linear layer
    (logistic regression) whose parameters are named ``fc.*``. Otherwise it is
    Linear -> ReLU -> Linear with parameters ``fc1.*`` followed by ``fc2.*``.
    Every weight and bias is re-initialised from N(0, 0.5**2).
    """

    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if not hidden_dim:
            # Logistic-regression variant: one affine map straight to the logits.
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)

        # Draw all weights first and then all biases, so the global torch RNG
        # is consumed in a fixed, reproducible order.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=0.5)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=0.5)

    def forward(self, x):
        """Return unnormalised class logits for the batch ``x``."""
        if hasattr(self, 'fc'):
            return self.fc(x)
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

def compute_ece(probs, labels, n_bins=10):
    """
    Compute the binary top-label Expected Calibration Error (ECE).

    Args:
        probs: positive-class probabilities p(y=1 | x), array-like of length n.
        labels: 0/1 ground-truth labels, array-like of length n.
        n_bins: number of equal-width confidence bins on [0, 1].

    Each example is predicted as class 1 when p >= 0.5, with confidence
    max(p, 1 - p) (the probability of the predicted class). Then
    ECE = sum_b (|b| / n) * |accuracy(b) - mean confidence(b)|.
    The last bin is closed on the right, so confidence 1.0 is counted.

    Returns:
        The ECE as a float (0.0 for empty input).
    """
    probs = np.asarray(probs, dtype=float).ravel()
    labels = np.asarray(labels).ravel()

    preds = (probs >= 0.5).astype(int)
    # Accuracy must be compared with the confidence of the *predicted* class.
    # Comparing it with p(y=1) would penalise a perfectly calibrated model in
    # every bin below 0.5.
    confidences = np.where(preds == 1, probs, 1.0 - probs)
    correct = preds == labels

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lower, upper = bins[i], bins[i + 1]
        if i == n_bins - 1:
            # Closed upper edge so that fully confident predictions are not dropped.
            mask = (confidences >= lower) & (confidences <= upper)
        else:
            mask = (confidences >= lower) & (confidences < upper)
        if np.any(mask):
            acc_bin = correct[mask].mean()
            conf_bin = confidences[mask].mean()
            ece += np.abs(acc_bin - conf_bin) * mask.mean()
    return ece

def unflatten_params(flat, net):
    """
    Split a 1-D parameter vector into a ``{name: tensor}`` dict shaped like ``net``.

    Consecutive slices of ``flat`` are taken in ``net.named_parameters()``
    order, and each is reshaped (via ``.view``) to the matching parameter's
    shape. Entries of ``flat`` beyond the total parameter count are ignored.
    """
    named = list(net.named_parameters())

    # bounds[k]:bounds[k + 1] is the slice of `flat` that belongs to parameter k.
    bounds = [0]
    for _, tensor in named:
        bounds.append(bounds[-1] + tensor.numel())

    return {
        name: flat[bounds[k]:bounds[k + 1]].view(tensor.shape)
        for k, (name, tensor) in enumerate(named)
    }

def compute_indomain_metrics(softmax_samples, y_true):
    """
    Summarise ensemble predictions on the in-domain test set.

    Args:
        softmax_samples: per-file class probabilities, shape
            (n_files, n_examples, num_classes).
        y_true: ground-truth labels, shape (n_examples,). F1, AUC-PR, Brier
            and ECE treat column 1 as the positive class.

    Returns:
        dict with
          - "aggregated_probability": mean probability over files,
            shape (n_examples, num_classes)
          - "accuracy", "f1", "auc_pr": metrics of argmax(mean probability)
          - "nll": summed (not averaged) log loss of the mean probability
          - "brier": mean squared error between p(class 1) and y_true
          - "ece": calibration error of p(class 1) from compute_ece (10 bins)
          - "total_entropy": mean entropy of the mean probability
          - "epistemic_uncertainty": total entropy minus mean per-file entropy
          - "total_entropy_correct"/"_incorrect",
            "epistemic_uncertainty_correct"/"_incorrect": the same two
            quantities averaged over correctly / incorrectly predicted
            examples (NaN when that group is empty)
    """
    softmax_samples = np.asarray(softmax_samples)
    y_true = np.asarray(y_true)
    num_classes = softmax_samples.shape[-1]

    # Average predictions over files.
    p_avg = softmax_samples.mean(axis=0)  # (n_examples, num_classes)
    preds = np.argmax(p_avg, axis=1)

    acc = accuracy_score(y_true, preds)
    f1 = f1_score(y_true, preds)
    auc_pr = average_precision_score(y_true, p_avg[:, 1])
    # Explicit labels keep log_loss defined even if y_true holds only one class.
    nll = log_loss(y_true, p_avg, normalize=False, labels=list(range(num_classes)))

    brier = np.mean((p_avg[:, 1] - y_true) ** 2)
    ece = compute_ece(p_avg[:, 1], y_true, n_bins=10)

    # Entropy of the averaged prediction, per example.
    total_entropies = -np.sum(p_avg * np.log(p_avg + EPS), axis=1)  # (n_examples,)
    total_entropy = total_entropies.mean()

    # Entropy of each file's own prediction, averaged over files.
    sample_entropies = -np.sum(softmax_samples * np.log(softmax_samples + EPS), axis=2)  # (n_files, n_examples)
    avg_sample_entropies = sample_entropies.mean(axis=0)  # (n_examples,)
    avg_entropy = avg_sample_entropies.mean()

    # Epistemic uncertainty (mutual information): total minus expected per-file entropy.
    epistemic_uncertainty = total_entropy - avg_entropy
    epistemic_uncertainties = total_entropies - avg_sample_entropies

    correct_mask = preds == y_true
    incorrect_mask = ~correct_mask

    def _mean_over(values, mask):
        # NaN rather than a RuntimeWarning when a partition is empty.
        return values[mask].mean() if mask.any() else np.nan

    return {
        "aggregated_probability": p_avg,
        "accuracy": acc,
        "f1": f1,
        "auc_pr": auc_pr,
        "nll": nll,
        "brier": brier,
        "ece": ece,
        "total_entropy": total_entropy,
        "epistemic_uncertainty": epistemic_uncertainty,
        "total_entropy_correct": _mean_over(total_entropies, correct_mask),
        "total_entropy_incorrect": _mean_over(total_entropies, incorrect_mask),
        "epistemic_uncertainty_correct": _mean_over(epistemic_uncertainties, correct_mask),
        "epistemic_uncertainty_incorrect": _mean_over(epistemic_uncertainties, incorrect_mask),
    }

def aggregate_replicate(file_indices, results_dir, file_pattern):
    """
    Load and stack the HMC result files of one replicate.

    Each file ``os.path.join(results_dir, file_pattern.format(i=i))`` must contain:
      - 'hmc_pred': logits, shape (n_examples, num_classes) or
        (1, n_examples, num_classes),
      - 'hmc_particles': flattened network parameters,
      - 'Lsum': a single scalar (cumulative "epoch" count).
    Files that do not exist are reported with a warning and skipped.

    Returns:
      - softmax_array: (n_files_loaded, n_examples, num_classes) probabilities
      - avg_Lsum: mean Lsum across loaded files
      - particles_array: stacked particles (n_files_loaded, num_params)

    Raises:
      ValueError: if none of the requested files could be loaded.
    """
    softmax_list = []
    Lsum_list = []
    particles_list = []
    for i in file_indices:
        fname = os.path.join(results_dir, file_pattern.format(i=i))
        if not os.path.exists(fname):
            print(f"Warning: File {fname} not found.")
            continue
        data = loadmat(fname)

        pred = data["hmc_pred"]
        if pred.ndim == 3 and pred.shape[0] == 1:
            pred = np.squeeze(pred, axis=0)
        softmax_list.append(softmax(pred, axis=1))

        # loadmat returns MATLAB scalars as (1, 1) arrays; float() on an
        # ndim > 0 array is deprecated (NumPy >= 1.25), so extract explicitly.
        Lsum_list.append(np.asarray(data["Lsum"], dtype=float).item())

        particles_list.append(np.squeeze(data["hmc_particles"]))

    if not softmax_list:
        raise ValueError("No files loaded; please check file pattern and indices.")

    softmax_array = np.stack(softmax_list, axis=0)
    avg_Lsum = np.mean(Lsum_list)
    particles_array = np.stack(particles_list, axis=0)
    return softmax_array, avg_Lsum, particles_array

def compute_ood_entropies(particles, model, x_ood, device):
    """
    Evaluate the model under each flattened particle on OOD inputs.

    Args:
        particles: sequence/array of flattened parameter vectors, one per
            particle, laid out in ``model.named_parameters()`` order.
        model: module whose forward is evaluated via ``functional_call``.
        x_ood: input batch passed to the model (presumably a float tensor on
            ``device`` — must match where the particle tensors are created).
        device: torch device on which the particle tensors are created.

    Returns:
        (total_entropy, epistemic_uncertainty) as Python floats:
          - total_entropy: entropy of the particle-averaged prediction, averaged
            over OOD examples
          - epistemic_uncertainty: total_entropy minus the mean per-particle
            entropy

    Raises:
        ValueError: if ``particles`` is empty.
    """
    try:
        from torch.func import functional_call
    except ImportError:  # older torch versions
        from torch.nn.utils.stateless import functional_call

    if len(particles) == 0:
        raise ValueError("compute_ood_entropies requires at least one particle.")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        softmax_preds = []
        for particle in particles:
            particle_tensor = torch.as_tensor(particle, dtype=torch.float32, device=device).reshape(-1)
            param_dict = unflatten_params(particle_tensor, model)
            logits = functional_call(model, param_dict, x_ood)
            softmax_preds.append(torch.softmax(logits, dim=1))
        softmax_preds = torch.stack(softmax_preds, dim=0)  # (n_particles, n_examples, num_classes)

        agg_prob = softmax_preds.mean(dim=0)  # (n_examples, num_classes)
        total_entropies = -torch.sum(agg_prob * torch.log(agg_prob + EPS), dim=1)
        total_entropy = total_entropies.mean().item()

        sample_entropies = -torch.sum(softmax_preds * torch.log(softmax_preds + EPS), dim=2)
        avg_entropy = sample_entropies.mean(dim=0).mean().item()

    epistemic_uncertainty = total_entropy - avg_entropy
    return total_entropy, epistemic_uncertainty

############################################
# OOD Embedding Functions (using SBERT)
############################################

def load_jsonl_first_n(filename, n=100):
    """
    Read up to ``n`` JSON records from a JSON Lines file (UTF-8).

    Blank or whitespace-only lines are skipped and do not count toward ``n``.
    Without the skip, a blank line would make ``json.loads`` raise.
    The file is opened even when ``n <= 0``, in which case ``[]`` is returned.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    with open(filename, "r", encoding="utf-8") as handle:
        if n <= 0:
            return records
        for line in handle:
            if not line.strip():
                continue
            records.append(json.loads(line))
            if len(records) >= n:
                break
    return records

def compute_ood_embeddings(device, n_samples=100, data_dir="data", model_name="all-mpnet-base-v2"):
    """
    Embed several out-of-domain text sets with a SentenceTransformer model.

    The text sets are:
      - "ood_reviews": the "text" field of the first ``n_samples`` records of
        ``<data_dir>/Appliances.jsonl``
      - "ood_meta": "title" plus space-joined "features" of the first
        ``n_samples`` records of ``<data_dir>/meta_Appliances.jsonl``
      - "ood_lipsum": ``n_samples`` random lorem-ipsum strings of 1-10 words,
        generated from a private RNG seeded with 42 (global RNG untouched)
      - "ood_full_reviews" / "ood_full_meta": each full record serialised
        with ``json.dumps``

    Args:
        device: torch device the embeddings are moved to.
        n_samples: maximum number of records / lipsum strings per set.
        data_dir: directory holding the two JSONL files.
        model_name: SentenceTransformer model to load.

    Returns:
        dict mapping the names above to embedding tensors on ``device``.
    """
    reviews_data = load_jsonl_first_n(os.path.join(data_dir, "Appliances.jsonl"), n=n_samples)
    meta_data = load_jsonl_first_n(os.path.join(data_dir, "meta_Appliances.jsonl"), n=n_samples)

    # `or ""` / `or []` also cover fields that are present but null.
    ood_reviews_texts = [(entry.get("text") or "").strip() for entry in reviews_data]
    ood_meta_texts = []
    for entry in meta_data:
        title = entry.get("title") or ""
        features = " ".join(entry.get("features") or [])
        ood_meta_texts.append((title + " " + features).strip())

    # A dedicated generator seeded with 42 yields exactly the texts that
    # seeding the global `random` module did, without side effects on
    # callers' RNG state.
    rng = random.Random(42)
    lorem_words = ("lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt "
                   "ut labore et dolore magna aliqua").split()
    ood_lipsum_texts = []
    for _ in range(n_samples):
        length = rng.randint(1, 10)
        ood_lipsum_texts.append(" ".join(rng.choices(lorem_words, k=length)))

    ood_full_reviews_texts = [json.dumps(entry, ensure_ascii=False) for entry in reviews_data]
    ood_full_meta_texts = [json.dumps(entry, ensure_ascii=False) for entry in meta_data]

    sbert = SentenceTransformer(model_name)
    sbert.eval()
    texts_by_name = {
        "ood_reviews": ood_reviews_texts,
        "ood_meta": ood_meta_texts,
        "ood_lipsum": ood_lipsum_texts,
        "ood_full_reviews": ood_full_reviews_texts,
        "ood_full_meta": ood_full_meta_texts,
    }
    return {
        name: sbert.encode(texts, convert_to_tensor=True).to(device)
        for name, texts in texts_by_name.items()
    }

############################################
# Main Aggregation and Evaluation (Serial)
############################################

def _mean_and_se(values):
    """
    Return (mean, standard error of the mean) of a 1-D sequence, as floats.

    The SE is the sample standard deviation (ddof=1) divided by sqrt(n). With
    fewer than two values it is reported as 0.0, rather than the NaN (plus
    RuntimeWarning) that ``np.std(..., ddof=1)`` would give. NaNs in
    ``values`` are not skipped and propagate into the result.
    """
    arr = np.asarray(values, dtype=float).ravel()
    n = arr.size
    mean_val = float(np.mean(arr))
    se_val = float(np.std(arr, ddof=1) / np.sqrt(n)) if n > 1 else 0.0
    return mean_val, se_val


def _nan_mean_and_se(matrix):
    """
    Column-wise NaN-ignoring (mean, standard error) of a (n_replicates, k) array.

    Each column is reduced over its non-NaN entries only, so the SE divides by
    the number of replicates that actually produced a value. A column with no
    non-NaN entries yields NaN for both statistics.

    Returns:
        Two float arrays of shape (k,).
    """
    matrix = np.asarray(matrix, dtype=float)
    n_cols = matrix.shape[1]
    means = np.full(n_cols, np.nan)
    ses = np.full(n_cols, np.nan)
    for j in range(n_cols):
        column = matrix[:, j]
        valid = column[~np.isnan(column)]
        if valid.size:
            means[j], ses[j] = _mean_and_se(valid)
    return means, ses


def _per_class_entropies(softmax_array, y_true, classes=(0, 1)):
    """
    Per-class mean total predictive entropy and epistemic uncertainty.

    Args:
        softmax_array: per-file probabilities, (n_files, n_examples, num_classes).
        y_true: labels, (n_examples,).
        classes: class labels to report, in order.

    Returns:
        (class_totals, class_epis): lists with one value per entry of
        ``classes``. The value is NaN for a class absent from ``y_true``.
    """
    p_avg = softmax_array.mean(axis=0)
    total = -np.sum(p_avg * np.log(p_avg + EPS), axis=1)
    per_file = -np.sum(softmax_array * np.log(softmax_array + EPS), axis=2)
    epistemic = total - per_file.mean(axis=0)

    class_totals, class_epis = [], []
    for c in classes:
        mask = y_true == c
        class_totals.append(total[mask].mean() if mask.any() else np.nan)
        class_epis.append(epistemic[mask].mean() if mask.any() else np.nan)
    return class_totals, class_epis


def main():
    """
    Aggregate parallel-HMC result files and save one summary .mat per processor count.

    For each P_use in ``processor_list``, the result files are split into R
    consecutive replicates of N * P_use files each (1-based file indices).
    For each replicate the function computes:
      - in-domain metrics on the test embeddings,
      - a per-class breakdown of total/epistemic entropy,
      - total/epistemic entropy on five SBERT-embedded OOD text sets.
    The replicate values are summarised as mean and standard error, printed,
    and written to ``phmc_aggregated_results_P{P_use}.mat``.
    """
    # Hard-coded parameters.
    results_dir = "."
    file_pattern = "BayesianNN_IMDB_hmc_SimpleMLP_MAP_d1538_N1_burnin25_node{i}.mat"
    test_data = "imdb_embeddings_testBig.pt"
    n_ood = 100

    R = 5   # replicates per processor count
    N = 10  # result files per processor within one replicate
    processor_list = [1, 2, 4, 8]

    ood_names = ["reviews", "meta", "lipsum", "full_reviews", "full_meta"]
    indomain_metric_names = ["accuracy", "f1", "auc_pr", "nll", "brier", "ece", "total_entropy",
                             "epistemic_uncertainty",
                             "total_entropy_correct", "total_entropy_incorrect",
                             "epistemic_uncertainty_correct", "epistemic_uncertainty_incorrect",
                             "avg_Lsum"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.exists(test_data):
        raise FileNotFoundError(f"Test data file {test_data} not found.")
    X_test, y_test = torch.load(test_data, map_location="cpu")
    y_test_np = y_test.numpy().flatten()

    model = SimpleMLP(input_dim=X_test.shape[1], hidden_dim=0, num_classes=2).to(device)
    model.eval()

    print("Computing OOD embeddings using SBERT (all-mpnet-base-v2)...")
    ood_embeddings = compute_ood_embeddings(device, n_samples=n_ood)

    for P_use in processor_list:
        print(f"\nProcessing results for P_use = {P_use}")
        files_per_replicate = N * P_use
        replicate_metrics = []
        per_class_total_entropy_list = []
        per_class_epistemic_uncertainty_list = []
        avg_Lsum_replicates = []
        # One list of per-replicate values per OOD metric. The insertion order
        # (reviews total, reviews epistemic, meta total, ...) is the print order.
        ood_values = {}
        for name in ood_names:
            ood_values[f"ood_{name}_total_entropy"] = []
            ood_values[f"ood_{name}_epistemic_uncertainty"] = []

        for r in range(R):
            start_idx = r * files_per_replicate + 1
            end_idx = (r + 1) * files_per_replicate
            file_indices = list(range(start_idx, end_idx + 1))
            print(f"  Replicate {r+1}/{R}: processing files {start_idx} to {end_idx}")

            softmax_array, avg_Lsum, particles_array = aggregate_replicate(file_indices, results_dir, file_pattern)
            avg_Lsum_replicates.append(avg_Lsum)
            indomain = compute_indomain_metrics(softmax_array, y_test_np)
            indomain["avg_Lsum"] = avg_Lsum

            class_totals, class_epis = _per_class_entropies(softmax_array, y_test_np)
            per_class_total_entropy_list.append(class_totals)
            per_class_epistemic_uncertainty_list.append(class_epis)

            replicate = {"indomain": indomain}
            for name in ood_names:
                total, epi = compute_ood_entropies(particles_array, model, ood_embeddings[f"ood_{name}"], device)
                for metric, value in ((f"ood_{name}_total_entropy", total),
                                      (f"ood_{name}_epistemic_uncertainty", epi)):
                    ood_values[metric].append(value)
                    replicate[metric] = value
            replicate_metrics.append(replicate)

            print(f"    In-domain Accuracy: {indomain['accuracy']:.8f}, Avg Lsum: {avg_Lsum:.8f}")

        # The same mean/SE convention is used everywhere, so R == 1 gives SE 0.0 instead of NaN.
        mean_Lsum, se_Lsum = _mean_and_se(avg_Lsum_replicates)
        print(f"  P_use={P_use} Aggregated Lsum: mean = {mean_Lsum:.8f}, SE = {se_Lsum:.8f}")

        indomain_summary = {}
        for key in indomain_metric_names:
            mean_val, se_val = _mean_and_se([rep["indomain"][key] for rep in replicate_metrics])
            indomain_summary[key + "_mean"] = mean_val
            indomain_summary[key + "_se"] = se_val
            print(f"  P_use={P_use} In-domain {key}: mean = {mean_val:.8f}, SE = {se_val:.8f}")

        ood_summary = {}
        for name, values in ood_values.items():
            mean_val, se_val = _mean_and_se(values)
            ood_summary[name + "_mean"] = mean_val
            ood_summary[name + "_se"] = se_val
            print(f"  P_use={P_use} {name}: mean = {mean_val:.8f}, SE = {se_val:.8f}")

        # ----------------------------
        # Save aggregated results to .mat
        # ----------------------------
        results = {
            'mean_acc':        indomain_summary['accuracy_mean'],
            'stderr_acc':      indomain_summary['accuracy_se'],
            'mean_f1_id':      indomain_summary['f1_mean'],
            'stderr_f1_id':    indomain_summary['f1_se'],
            'mean_aucpr_id':   indomain_summary['auc_pr_mean'],
            'stderr_aucpr_id': indomain_summary['auc_pr_se'],
            'mean_nll':        indomain_summary['nll_mean'],
            'stderr_nll':      indomain_summary['nll_se'],
            'mean_brier':      indomain_summary['brier_mean'],
            'stderr_brier':    indomain_summary['brier_se'],
            'mean_ece':        indomain_summary['ece_mean'],
            'stderr_ece':      indomain_summary['ece_se'],
            'mean_tot_ent':    indomain_summary['total_entropy_mean'],
            'stderr_tot_ent':  indomain_summary['total_entropy_se'],
            'mean_epi':        indomain_summary['epistemic_uncertainty_mean'],
            'stderr_epi':      indomain_summary['epistemic_uncertainty_se'],
            'mean_Lsum':       indomain_summary['avg_Lsum_mean'],
            'stderr_Lsum':     indomain_summary['avg_Lsum_se'],

            # correct vs incorrect
            'mean_tot_corr':   indomain_summary['total_entropy_correct_mean'],
            'se_tot_corr':     indomain_summary['total_entropy_correct_se'],
            'mean_epi_corr':   indomain_summary['epistemic_uncertainty_correct_mean'],
            'se_epi_corr':     indomain_summary['epistemic_uncertainty_correct_se'],
            'mean_tot_inc':    indomain_summary['total_entropy_incorrect_mean'],
            'se_tot_inc':      indomain_summary['total_entropy_incorrect_se'],
            'mean_epi_inc':    indomain_summary['epistemic_uncertainty_incorrect_mean'],
            'se_epi_inc':      indomain_summary['epistemic_uncertainty_incorrect_se'],
        }

        # OOD per-category
        for name in ood_names:
            results[f'ood_{name}_mean_tot'] = ood_summary[f'ood_{name}_total_entropy_mean']
            results[f'ood_{name}_se_tot'] = ood_summary[f'ood_{name}_total_entropy_se']
            results[f'ood_{name}_mean_epi'] = ood_summary[f'ood_{name}_epistemic_uncertainty_mean']
            results[f'ood_{name}_se_epi'] = ood_summary[f'ood_{name}_epistemic_uncertainty_se']

        # Per-class in-domain entropies over replicates. NaN entries (class
        # absent) are excluded from both the mean and the SE denominator.
        mean_class_tot, se_class_tot = _nan_mean_and_se(per_class_total_entropy_list)
        mean_class_epi, se_class_epi = _nan_mean_and_se(per_class_epistemic_uncertainty_list)
        results['mean_class_tot_ent'] = mean_class_tot
        results['stderr_class_tot_ent'] = se_class_tot
        results['mean_class_epi_ent'] = mean_class_epi
        results['stderr_class_epi_ent'] = se_class_epi

        print(f"P_use={P_use}  Per-class total entropy   → Class 0: {mean_class_tot[0]:.8f} ± {se_class_tot[0]:.8f}, "
              f"Class 1: {mean_class_tot[1]:.8f} ± {se_class_tot[1]:.8f}")
        print(f"P_use={P_use}  Per-class epistemic unc. → Class 0: {mean_class_epi[0]:.8f} ± {se_class_epi[0]:.8f}, "
              f"Class 1: {mean_class_epi[1]:.8f} ± {se_class_epi[1]:.8f}")

        out_name = f'phmc_aggregated_results_P{P_use}.mat'
        savemat(out_name, results)
        print(f"--> Saved {out_name}")


if __name__ == "__main__":
    main()
