#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 24 10:43:57 2025

Adapted for IMDB: SMC with MAP prior using SBERT embeddings and a SimpleMLP.
"""

import os
import sys
import time
import random
import numpy as np
from scipy.io import savemat
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.multiprocessing as mp  # For parallel mutations
from torch.func import functional_call

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC

from sentence_transformers import SentenceTransformer
#import torchtext; torchtext.disable_torchtext_deprecation_warning()
import torchtext
from torchtext.datasets import IMDB

# Set device and global prior standard deviations.
# CUDA is used when available, otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Separate hyperparameters for MAP regularization and HMC prior
# v is the variance behind sigma_w / sigma_b. Those two values are used both
# for the Gaussian initialisation inside SimpleMLP and for the zero-mean
# regulariser added to the loss during MAP training in the main script.
v = 0.025
sigma_w = np.sqrt(v)    # Std for weights (MAP)
sigma_b = np.sqrt(v)    # Std for biases  (MAP)

# s is the variance behind sigma_w_s / sigma_b_s (one tenth of v). Those are
# the stds of the Gaussian centred at the MAP estimate that init_particle
# samples from and that model_loss_func penalises deviations against.
s = 0.1*v
sigma_w_s = np.sqrt(s)   # HMC prior weights
sigma_b_s = np.sqrt(s)    # HMC prior biases
# Multiplier applied to the MAP parameters before they are used as the prior
# mean in the main script; 1 leaves them unchanged.
scale = 1
#print(f'scale is {scale} and sigma is {sigma_b_s} and sigma_v is {sigma_b}')

######################################
# --- SimpleMLP Definition ---
######################################
class SimpleMLP(nn.Module):
    """
    Linear classifier on fixed-size inputs, optionally with one ReLU hidden layer.

    With hidden_dim equal to 0 or None the model is plain logistic regression
    and owns a single layer named "fc" (parameters "fc.weight", "fc.bias").
    Otherwise it owns "fc1" and "fc2" (parameters "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias").

    All weights are drawn from N(0, sigma_w^2) and all biases from
    N(0, sigma_b^2), using the module-level sigma_w / sigma_b.
    """
    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        linear_only = hidden_dim is None or hidden_dim == 0
        if linear_only:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)

        # Weights of every layer are initialised first, biases afterwards;
        # this ordering fixes how the random stream is consumed for a seed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return the raw class logits for the batch `x`."""
        if not hasattr(self, 'fc'):
            # Hidden-layer variant: linear -> ReLU -> linear.
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        # Logistic-regression variant: a single linear map.
        return self.fc(x)

######################################
# --- Helper Functions ---
######################################
def unflatten_params(flat, net):
    """
    Split the 1-D tensor `flat` into per-parameter views shaped like `net`'s parameters.

    Parameters are consumed in the order of `net.named_parameters()`; the
    returned dict maps each parameter name to a view of the matching slice
    of `flat`.
    """
    named = {}
    start = 0
    for key, tensor in net.named_parameters():
        stop = start + tensor.numel()
        named[key] = flat[start:stop].view(tensor.shape)
        start = stop
    return named

# def init_particle(net, device):
#     """
#     Initialize a flattened parameter vector for `net` according to the Gaussian prior.
#     For biases use sigma_b and for weights sigma_w.
#     """
#     param_dict = {}
#     for name, param in net.named_parameters():
#         if "bias" in name:
#             param_dict[name] = torch.randn(param.shape, device=device) * sigma_b
#         else:
#             param_dict[name] = torch.randn(param.shape, device=device) * sigma_w
#     flat_params = []
#     for name, _ in net.named_parameters():
#         flat_params.append(param_dict[name].view(-1))
#     return torch.cat(flat_params)
def init_particle(net, device, prior_params):
    """
    Draw one flattened parameter vector from the Gaussian centred at `prior_params`.

    For every named parameter of `net`, the mean is `prior_params[name]`
    (moved to `device`) and the std is sigma_b_s when the name contains
    "bias" and sigma_w_s otherwise. The per-parameter draws are flattened
    and concatenated in `net.named_parameters()` order.
    """
    pieces = []
    for key, tensor in net.named_parameters():
        noise_std = sigma_b_s if "bias" in key else sigma_w_s
        center = prior_params[key].to(device)
        draw = center + torch.randn(tensor.shape, device=device) * noise_std
        pieces.append(draw.view(-1))
    return torch.cat(pieces)


def model_loss_func_ll(output, y, temp):
    """
    Return the summed cross-entropy of `output` against `y`, multiplied by `temp`.

    `y` is cast to integer class indices and flattened to 1-D before the loss
    is evaluated; the loss is summed (not averaged) over the batch.
    """
    targets = y.long().view(-1)
    summed_ce = nn.CrossEntropyLoss(reduction='sum')(output, targets)
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net, prior_params):
    """
    Compute the potential energy (negative joint log probability, up to an
    additive constant) under a Gaussian prior centred at `prior_params`.

    Args:
        params: dict mapping parameter names to tensors, as produced by
            unflatten_params; passed to functional_call in place of `net`'s
            own parameters.
        x: input batch fed to `net`.
        y: class labels for `x`.
        temp: tempering factor that multiplies the summed cross-entropy.
        net: module evaluated through functional_call.
        prior_params: dict of prior means; must contain every key of `params`.

    Returns:
        Scalar tensor: sum over parameters of ||p - mean||^2 / (2 sigma^2)
        plus temp times the summed cross-entropy.

    NOTE: For MAP training, the regularization uses sigma_w and sigma_b.
          For SMC inference, we use the SMC prior with sigma_w_s and sigma_b_s.
    """
    # Iterate over whatever parameters the network actually has instead of a
    # fixed ["fc.weight", "fc.bias"] list, so the hidden-layer variant is
    # penalised too (and no longer fails on a missing key). For the
    # logistic-regression network the keys and their order are unchanged.
    neg_log_prior = 0.0
    for name, param_tensor in params.items():
        # No-op when the prior already lives on the parameter's device.
        prior_tensor = prior_params[name].to(param_tensor.device)
        # Use the SMC-specific prior standard deviations here.
        sigma = sigma_b_s if "bias" in name else sigma_w_s
        neg_log_prior = neg_log_prior + ((param_tensor - prior_tensor) ** 2).sum() / (2 * sigma**2)

    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp, prior_params):
    """
    Run one Pyro HMC chain started at `params_init` and return its last state.

    The chain targets the potential defined by model_loss_func at tempering
    factor `temp`, with a fixed step size and number of leapfrog steps (both
    step-size and mass-matrix adaptation are switched off).

    Returns:
        (last_sample, acc_rate): the final flattened parameter vector of the
        chain, and the last "acc. prob" value reported by the kernel's
        logging() converted to float (0.0 if none was reported).
        NOTE(review): "acc. prob" is presumably Pyro's running mean
        acceptance probability -- confirm against the Pyro HMC documentation.
    """
    recorded_acc = []

    def capture_diagnostics(kernel, params, stage, i):
        # Called by MCMC through hook_fn; keep whatever the kernel reports.
        stats = kernel.logging()
        if "acc. prob" in stats:
            recorded_acc.append(stats["acc. prob"])

    def potential_fn(params_dict):
        # The chain state is a single flat vector stored under "params".
        named = unflatten_params(params_dict["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net, prior_params)

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=capture_diagnostics,
    )
    sampler.run()

    chain = sampler.get_samples()["params"]
    last_acc = recorded_acc[-1] if recorded_acc else 0.0
    return chain[-1], float(last_acc)

def update_particle(params, net, x_train, y_train, step_size, L, temp, M, prior_params, node):
    """
    Mutate one particle with HMC and evaluate its untempered loss.

    Seeds torch, numpy, random and pyro with `node`, runs an HMC chain of
    `M` samples with `L` leapfrog steps at tempering factor `temp`, then
    computes the summed cross-entropy of the resulting parameters on the
    training data at temp=1.

    Returns:
        (updated_params, loss, acc_rate): the new flattened parameter vector,
        its loss as a Python number, and the acceptance value reported by
        hmc_update_particle.
    """
    # Seed every RNG in the same order on each call so a given `node`
    # always yields the same mutation.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed, pyro.set_rng_seed):
        seed_fn(node)

    new_flat, acceptance = hmc_update_particle(
        net,
        x_train,
        y_train,
        params,
        num_samples=M,
        warmup_steps=0,
        step_size=step_size,
        num_steps=L,
        temp=temp,
        prior_params=prior_params,
    )

    logits = functional_call(net, unflatten_params(new_flat, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    nll_value = nll.item() if torch.is_tensor(nll) else nll
    return new_flat, nll_value, acceptance

def compute_accuracy(net, particles, x_val, y_val):
    """
    Return the accuracy of the particle ensemble on (x_val, y_val).

    Each flattened particle is evaluated through `net`; the logits are
    averaged across particles and the arg-max class of the averaged logits
    is compared with `y_val`. The result is the fraction of correct rows.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        predicted = mean_logits.argmax(dim=1)
        hits = (predicted == y_val.view(-1)).sum().item()
        fraction = hits / y_val.size(0)
    return fraction

######################################
# --- SBERT Embedding & Caching for IMDB ---
######################################
def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_trainBig.pt",
    test_cache_path="imdb_embeddings_testBig.pt",
):
    """
    Load the IMDB dataset and embed each review with SBERT, caching the result.

    If both cache files exist they are loaded from disk and returned as-is.
    Otherwise the train and test splits are read, embedded, saved to the two
    cache paths and returned.

    Args:
        model_name: SentenceTransformer model identifier.
        train_cache_path: file holding the cached (X_train, y_train) tuple.
        test_cache_path: file holding the cached (X_test, y_test) tuple.

    Returns:
        X_train, y_train, X_test, y_test as PyTorch tensors (float32
        embeddings and int64 labels when freshly computed).

    NOTE: the cache files are read with torch.load, which is pickle based;
    only point this at cache files you created yourself.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()

    # Map labels: some versions use (1,2); here we map to 0 and 1.
    # NOTE(review): older torchtext releases are believed to yield the strings
    # 'neg' / 'pos' instead; they are mapped as well so tensor construction
    # does not fail on string labels -- confirm for the installed version.
    # Labels not present in the map are passed through unchanged.
    label_map = {1: 0, 2: 1, "neg": 0, "pos": 1}

    def _embed_split(split):
        # Embed one IMDB split and return (embeddings, labels) tensors.
        labels, texts = [], []
        for label, text in IMDB(split=split):
            labels.append(label_map.get(label, label))
            texts.append(text)
        # Encode the whole split in one call so the model runs on batches
        # rather than one review at a time; rows come back in input order.
        # Values may differ from one-at-a-time encoding at float round-off level.
        embeddings = sbert.encode(texts, convert_to_numpy=True)
        features = torch.as_tensor(np.asarray(embeddings), dtype=torch.float32)
        targets = torch.tensor(labels, dtype=torch.long)
        return features, targets

    X_train, y_train = _embed_split("train")
    X_test, y_test = _embed_split("test")

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("SBERT embeddings computed and saved to disk.")

    return X_train, y_train, X_test, y_test

######################################
# --- SMC Sampler with HMC Mutations ---
######################################
def _tempering_weights(llike, temp_increment):
    """
    Return (lwhat, lmax, w, ess) for a tempering step of size `temp_increment`.

    `llike` holds the per-particle summed cross-entropy losses, so the
    unnormalised log-weights are -temp_increment * llike. `w` is normalised
    to sum to one and `ess` is 1 / sum(w^2).
    """
    lwhat = -temp_increment * llike
    lmax = np.max(lwhat)
    w = np.exp(lwhat - lmax)  # subtract the max before exponentiating
    w /= np.sum(w)
    ess = 1.0 / np.sum(w**2)
    return lwhat, lmax, w, ess


def _systematic_resample_indices(w, n):
    """
    Return `n` ancestor indices drawn by systematic resampling from weights `w`.

    A single uniform offset (from numpy's global RNG) places `n` evenly
    spaced positions in [0, 1); each position selects the first particle
    whose cumulative weight exceeds it.
    """
    positions = (np.arange(n) + np.random.uniform(0, 1)) / n
    cumulative_w = np.cumsum(w)
    indices = np.searchsorted(cumulative_w, positions, side="right")
    # Round-off can leave cumulative_w[-1] slightly below 1, in which case the
    # last positions would index past the end; clamp to the last particle.
    return np.minimum(indices, len(w) - 1)


def psmc_single(node, numwork, x_train, y_train, x_val, y_val, NetClass, N, step_size, L, M, trajectory, prior_params):
    """
    Runs an SMC sampler (single replicate) with N particles.

    Particles are initialised from the Gaussian centred at the MAP estimates
    (`prior_params`) with stds sigma_w_s / sigma_b_s, then moved from
    temperature 0 to 1 by adaptive tempering: each step picks the largest
    increment (by repeated halving) whose effective sample size is at least
    N / 2, resamples systematically, and mutates every particle with HMC in
    a pool of `numwork` worker processes.

    Args:
        node: seed for torch, numpy, random and pyro in this process.
        numwork: number of worker processes used for the mutations.
        x_train, y_train: data defining the likelihood.
        x_val, y_val: data used for the returned predictions and the printed
            validation accuracy.
        NetClass: network class, instantiated with no arguments.
        N: number of particles.
        step_size: initial HMC step size (adapted between mutation phases).
        L: initial number of leapfrog steps; after each mutation attempt it
            is recomputed as int(trajectory / step_size), clamped to [1, 100].
        M: number of HMC samples per mutation.
        trajectory: trajectory length used to recompute L.
        prior_params: dict of prior means keyed by parameter name.

    Returns:
        (t_elapsed, predictions, particles_flat, ZZ, KK, count, Ls) where
        predictions has shape (N, len(x_val), num_classes), particles_flat has
        shape (N, d), ZZ and KK accumulate log(mean(exp(lwhat - lmax))) and
        lmax over tempering steps, count is the number of tempering steps and
        Ls the accumulated L * M over those steps.
    """
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)
    random.seed(node)
    pyro.set_rng_seed(node)

    net = NetClass().to(device)
    net.eval()
    net.share_memory()  # Enable sharing across processes

    d = sum(p.numel() for p in net.parameters())

    t_start = time.time()
    particles = []
    llike = []

    # Particle initialization. Only loss values are needed here, so no
    # autograd graph is recorded.
    with torch.no_grad():
        for _ in range(N):
            params_init = init_particle(net, device, prior_params)
            particles.append(params_init)
            output = functional_call(net, unflatten_params(params_init, net), x_train)
            llike.append(model_loss_func_ll(output, y_train, temp=1).item())
    llike = np.array(llike)

    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    step_size_cur = step_size

    # One pool for the whole run: with the 'spawn' start method every new
    # pool has to start fresh interpreters, so it is not recreated per attempt.
    with mp.Pool(processes=numwork) as pool:
        # Tempering loop.
        while tempcurr < 1.0:
            temp_increment = 1.0 - tempcurr
            if count == 0:
                # The first step never goes beyond temperature 0.5. The cap is
                # applied before the weights are computed so that the weights,
                # the ZZ / KK terms and the mutation target all refer to the
                # same temperature.
                temp_increment = min(temp_increment, 0.5)
            lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
            while ess < N / 2:
                temp_increment /= 2.0
                lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
            proposed_temp = tempcurr + temp_increment
            print("Current temperature:", tempcurr, ", dT=", temp_increment, "Step=", count)
            ZZ += np.log(np.mean(np.exp(lwhat - lmax)))
            KK += lmax
            Ls += L * M

            # Systematic resampling.
            ancestors = _systematic_resample_indices(w, N)
            particles = [particles[j] for j in ancestors]
            llike = llike[ancestors]

            # --- Parallel Mutation Phase ---
            # `particles` / `llike` are only replaced once an attempt is
            # accepted, so a rejected attempt restarts from the same state.
            mutation_success = False
            while not mutation_success:
                # Seeds come from a wide range so that two particles (often
                # exact duplicates after resampling) are very unlikely to
                # share a seed and undergo the identical mutation.
                mutation_nodes = [random.randint(0, 2**31 - 1) for _ in range(len(particles))]
                results = pool.starmap(
                    update_particle,
                    [(params, net, x_train, y_train, step_size_cur, L, proposed_temp, M, prior_params, seed)
                     for params, seed in zip(particles, mutation_nodes)]
                )
                updated_particles, updated_llike, acc_rates = zip(*results)
                overall_acc = np.mean(acc_rates)
                print("Overall acceptance for mutation phase:", overall_acc, "Stepsize=", step_size_cur, "L=", L)
                if overall_acc < 0.4 or overall_acc > 0.95:
                    step_size_cur *= 0.7 if overall_acc < 0.4 else 1.1
                    print("Unsatisfactory acceptance. Restarting mutation phase with step size", step_size_cur, "L=", L)
                else:
                    mutation_success = True
                    if overall_acc < 0.6:
                        step_size_cur *= 0.7
                    elif overall_acc > 0.8:
                        step_size_cur *= 1.1
                    # Copy into memory owned by this process so the particles
                    # do not depend on storage shared by the worker processes.
                    particles = [p.detach().clone() for p in updated_particles]
                    llike = np.array(updated_llike)
                L = min(max(1, int(trajectory / step_size_cur)), 100)
            tempcurr = proposed_temp
            count += 1

    # Compute predictions on validation data.
    with torch.no_grad():
        preds = [functional_call(net, unflatten_params(params, net), x_val) for params in particles]

    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    particles_flat = torch.stack(particles).detach().cpu().numpy().reshape(N, d)
    # Stacking already yields (N, len(x_val), num_classes) for any class count.
    predictions = torch.stack(preds).detach().cpu().numpy()

    t_elapsed = time.time() - t_start
    return t_elapsed, predictions, particles_flat, ZZ, KK, count, Ls

######################################
# --- Main: MAP Training then SMC Inference ---
######################################
if __name__ == '__main__':
    # Script entry point: load (or compute) the SBERT embeddings, fit a
    # logistic-regression MAP estimate, then run the tempered SMC sampler
    # centred at that estimate and save the results to a .mat file.

    # Set multiprocessing start method.
    mp.set_start_method('spawn', force=True)

    print(f'scale is {scale} and sigma is {sigma_b_s} and sigma_v is {sigma_b}')

    # Retrieve command-line arguments.
    # Example usage: python script.py <node> <numwork>
    # `node` seeds every RNG and is part of the output file name;
    # `numwork` is the size of the worker pool used for the HMC mutations.
    # Both are required: a missing argument raises IndexError here.
    node = int(sys.argv[1])
    numwork = int(sys.argv[2])

    # --- Data Loading and SBERT Embedding ---
    train_cache = "imdb_embeddings_trainBig.pt"
    test_cache  = "imdb_embeddings_testBig.pt"
    # After loading SBERT embeddings:
    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache
    )

    # MAP pre-training split: first 20k train, next 5k early-stop val
    # NOTE(review): the split is positional and nothing is shuffled here; if
    # the cached embeddings keep the dataset's original order and that order
    # is grouped by label, the class balance of the two parts would be
    # skewed -- confirm against the cache contents.
    x_map_train = X_train[:20000].to(device)
    y_map_train = y_train[:20000].to(device)
    x_map_val   = X_train[20000:25000].to(device)
    y_map_val   = y_train[20000:25000].to(device)

    # SMC uses full train set; test split unchanged
    # (y_train is rebound to its on-device copy; the slices above were taken first.)
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    print(f"The number of train data is {len(y_train)}")
    print(f"The number of validation data is {len(y_val)}")

    # Seed all RNGs so MAP initialisation and mini-batch order depend only on `node`.
    torch.manual_seed(node)
    np.random.seed(node)
    random.seed(node)
    pyro.set_rng_seed(node)

    # --- MAP (Deterministic) Training of SimpleMLP ---
    # When setting input_dim for MAP training:
    input_dim = x_map_train.shape[1]
    # hidden_dim=0 selects the logistic-regression variant ("fc" layer only).
    model_mlp = SimpleMLP(input_dim=input_dim, hidden_dim=0, num_classes=2).to(device)
    optimizer = optim.Adam(model_mlp.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # `net` is only an alias of the MAP model, used here to count parameters.
    net = model_mlp
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    model_mlp.train()
    train_losses = []
    val_losses = []
    # Early stopping compares a moving average of the last
    # `moving_avg_window` validation losses and stops after `patience`
    # consecutive epochs without a new best average.
    moving_avg_window = 10
    best_moving_avg = float('inf')
    patience = 5
    no_improve_count = 0
    start_time = time.time()

    # DataLoaders for MAP pre-training
    train_dataset = torch.utils.data.TensorDataset(x_map_train, y_map_train)
    val_dataset   = torch.utils.data.TensorDataset(x_map_val,   y_map_val)
    train_loader  = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader    = DataLoader(val_dataset,   batch_size=64, shuffle=False)

    for epoch in range(1000):
        model_mlp.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model_mlp(inputs)
            ce_loss = criterion(outputs, labels)
            # Regularization term (MAP prior with zero mean and std sigma_w/sigma_b).
            reg_fc = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w**2) +
                      torch.sum(model_mlp.fc.bias**2) / (2 * sigma_b**2))
            reg_loss = reg_fc
            # The cross-entropy is a per-sample mean, so the prior term is
            # divided by the number of training samples to put both on the
            # same per-sample scale.
            loss = ce_loss + reg_loss / len(train_dataset)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean of the per-batch losses over the epoch.
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        model_mlp.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model_mlp(inputs)
                ce_loss = criterion(outputs, labels)
                # Same regularised objective as in training (still scaled by
                # the training-set size), so the two curves are comparable.
                reg_fc = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w**2) +
                           torch.sum(model_mlp.fc.bias**2) / (2 * sigma_b**2))
                reg_loss = reg_fc
                loss = ce_loss + reg_loss / len(train_dataset)
                val_running_loss += loss.item()
        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)

        print(f"MAP Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Early stopping only starts once a full window of epochs exists.
        if epoch >= moving_avg_window - 1:
            moving_avg = sum(val_losses[-moving_avg_window:]) / moving_avg_window
            if moving_avg < best_moving_avg:
                best_moving_avg = moving_avg
                no_improve_count = 0
            else:
                no_improve_count += 1
            if no_improve_count >= patience:
                print("MAP: Early stopping due to no improvement.")
                break

    total_map_time = time.time() - start_time
    print(f"MAP Total execution time: {total_map_time:.2f} seconds")

    # --- Extract Prior Parameters from the MAP Model ---
    # No best-epoch checkpoint is restored above, so these are the weights
    # from the last epoch that ran.
    prior_params = model_mlp.state_dict()
    print("Prior parameters extracted from the MAP trained SimpleMLP.")

    # Rescale the prior mean by the module-level `scale`.
    for name in prior_params:
        prior_params[name] = scale * prior_params[name]

    ############################################
    # --- SMC Inference using MAP Prior ---
    ############################################
    # SMC Sampler Parameters.
    N = 10       # Number of particles
    step_size = 0.022
    L = 1        # Number of leapfrog steps per HMC update
    M = 1       # Number of HMC samples per mutation
    # psmc_single recomputes L as int(trajectory / step_size) clamped to
    # [1, 100]; with 0.0 that always evaluates to 1.
    trajectory = 0.0  # Set trajectory scaling constant (can be tuned)

    start_time = time.time()
    (smc_time, predictions, particles_flat, Z, K, count, Lsum) = psmc_single(
        node=node,
        numwork=numwork,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        NetClass=SimpleMLP,
        N=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory,
        prior_params=prior_params
    )
    total_smc_time = time.time() - start_time
    print(f"SMC Total execution time: {total_smc_time:.2f} seconds")
    # Recompute the ensemble accuracy from the returned particle matrix. The
    # freshly constructed network only supplies parameter names and shapes;
    # its own random weights are replaced by each particle during evaluation.
    acc = compute_accuracy(SimpleMLP(input_dim=input_dim, hidden_dim=0, num_classes=2).to(device),
                           [torch.tensor(p, device=device) for p in particles_flat],
                           x_val, y_val)
    print("SMC Validation Accuracy:", acc)
    # Lsum is the L * M total accumulated over tempering steps by psmc_single.
    print("SMC epochs:", Lsum)

    # Save results.
    savemat(f'BayesianNN_IMDB_psmc_SimpleMLP_MAP_d{sum(p.numel() for p in model_mlp.parameters())}_N{N}_M{M}_node{node}.mat', {
        'psmc_single_time': smc_time,
        'psmc_single_pred': predictions,
        'psmc_single_x': particles_flat,
        'Z': Z,
        'K': K,
        'count': count,
        'Lsum': Lsum
    })
