"""
Parallel SMC sampler with Pyro HMC mutation kernels for Bayesian inference on IMDB data using SBERT embeddings.
The forward model here is a slightly more complex MLP (one hidden layer).
"""

from scipy.io import savemat, loadmat
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
from torchtext.datasets import IMDB
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC
from torch.func import functional_call
import torch.multiprocessing as mp  # Use multiprocessing

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Standard deviations of the zero-mean Gaussians applied to weights (sigma_w)
# and biases (sigma_b); used for layer initialisation, particle initialisation
# and the prior term of the potential energy in this file.
sigma_w = 0.1
sigma_b = 0.01

######################################
# --- SimpleMLP Definition ---
######################################
class SimpleMLP(nn.Module):
    """One-hidden-layer perceptron (fc1 -> ReLU -> fc2) that returns raw logits."""

    def __init__(self, input_dim=768, hidden_dim=128, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        # Replace the default nn.Linear initialisation with zero-mean Gaussians:
        # std sigma_w for weights, std sigma_b for biases. Weights of both
        # layers are drawn first, then the biases, so the random stream is
        # consumed in a fixed order.
        layers = (self.fc1, self.fc2)
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Map a batch of inputs to class logits."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

class LogReg(nn.Module):
    """Single linear layer producing class logits (multinomial logistic regression)."""

    def __init__(self, input_dim=768, hidden_dim=128, num_classes=2):
        # hidden_dim is accepted but not used by this model.
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

        # Zero-mean Gaussian initialisation: std sigma_w for the weight
        # matrix, std sigma_b for the bias vector.
        linear = self.fc
        nn.init.normal_(linear.weight, mean=0, std=sigma_w)
        if linear.bias is not None:
            nn.init.normal_(linear.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Map a batch of inputs to class logits."""
        return self.fc(x)

######################################
# --- Helper Functions ---
######################################
def flatten_net(net):
    """Concatenate every parameter of ``net``, in ``net.parameters()`` order, into one 1-D tensor."""
    flat_views = tuple(param.view(param.numel()) for param in net.parameters())
    return torch.cat(flat_views)

def unflatten_params(flat, net):
    """
    Slice the 1-D tensor 'flat' into named tensors shaped like the parameters
    of 'net', walking 'net.named_parameters()' in order.

    Returns:
        dict mapping each parameter name to its reshaped slice of 'flat'.
    """
    named = {}
    start = 0
    for key, reference in net.named_parameters():
        stop = start + reference.numel()
        named[key] = flat[start:stop].view(reference.shape)
        start = stop
    return named

def init_particle(net, device):
    """
    Draw one flattened parameter vector for 'net' from zero-mean Gaussians:
    parameters whose name contains "bias" use std sigma_b, all others use
    std sigma_w. Draws are made on 'device' in 'net.named_parameters()' order.
    """
    draws = [
        torch.randn(param.shape, device=device) * (sigma_b if "bias" in name else sigma_w)
        for name, param in net.named_parameters()
    ]
    return torch.cat([draw.view(-1) for draw in draws])

def model_loss_func_ll(output, y_device, temp):
    """
    Return the cross-entropy between the logits 'output' and the labels
    'y_device', summed over the batch and multiplied by 'temp'.
    """
    targets = y_device.long().view(-1)  # labels as a flat int64 vector
    summed_ce = F.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net):
    """
    Compute the potential energy (negative joint log probability, up to an
    additive constant) of the parameters.

    The prior is a zero-mean Gaussian on every entry of 'params': std sigma_b
    for entries whose name contains "bias", std sigma_w otherwise. The
    likelihood term is the summed cross-entropy of 'net' evaluated with
    'params' on (x, y), scaled by the tempering factor 'temp'.

    Args:
        params: dict mapping parameter names of 'net' to tensors.
        x: input batch passed to 'net'.
        y: class labels for 'x'.
        temp: multiplier applied to the likelihood term.
        net: module whose forward pass is evaluated via functional_call.

    Returns:
        Scalar tensor: negative log prior plus tempered negative log likelihood.
    """
    # Iterate over whatever parameters were supplied instead of a fixed
    # fc1/fc2 key list, so any network in this file (e.g. LogReg) works.
    neg_log_prior = 0.0
    for name, tensor in params.items():
        sigma = sigma_b if "bias" in name else sigma_w
        neg_log_prior = neg_log_prior + (tensor ** 2).sum() / (2 * sigma ** 2)

    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """
    Run one HMC chain whose potential energy is model_loss_func at
    tempering factor 'temp', starting from the flat vector 'params_init'.
    Step-size and mass-matrix adaptation are switched off.

    Returns:
        (last_sample, acc): the last drawn flat parameter vector, and the most
        recent "acc. prob" entry reported by the kernel's logging() converted
        to float (0.0 if none was recorded).
    """
    recorded_acc = []

    def potential_fn(state):
        # The chain state is a dict with the single entry "params".
        named = unflatten_params(state["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net)

    def record_acceptance(kernel, params, stage, i):
        # Hook passed to MCMC; stores the kernel's reported acceptance value.
        info = kernel.logging()
        if "acc. prob" in info:
            recorded_acc.append(info["acc. prob"])

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()

    draws = sampler.get_samples()["params"]
    last_acc = recorded_acc[-1] if recorded_acc else 0.0
    return draws[-1], float(last_acc)

# --- Parallel Particle Update Function ---
def update_particle(params, net, x_train, y_train, step_size, L, temp, M):
    """
    Mutate a single particle and re-evaluate its loss; called once per
    particle from the worker pool.

    Runs hmc_update_particle with M samples, no warm-up, L leapfrog steps and
    tempering factor 'temp', then computes the untempered (temp=1) summed
    cross-entropy of the new parameters on the training data.

    Returns:
        (new_flat_params, loss_as_python_number, acceptance_value)
    """
    new_flat, accept = hmc_update_particle(
        net,
        x_train,
        y_train,
        params,
        num_samples=M,
        warmup_steps=0,
        step_size=step_size,
        num_steps=L,
        temp=temp,
    )
    logits = functional_call(net, unflatten_params(new_flat, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    if torch.is_tensor(nll):
        nll = nll.item()
    return new_flat, nll, accept

def compute_accuracy(net, particles, x_val, y_val):
    """
    Return the ensemble accuracy on (x_val, y_val).

    Each flat parameter vector in 'particles' is evaluated on 'x_val'; the
    resulting logits are averaged across particles and the argmax over
    dim 1 is compared with the labels.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        hits = (mean_logits.argmax(dim=1) == y_val.view(-1)).sum().item()
    return hits / y_val.size(0)

######################################
# --- SBERT Embedding + Caching Function ---
######################################
def _normalize_imdb_labels(labels):
    """Map the raw labels of one IMDB split to integer class ids in {0, 1}."""
    # NOTE(review): string labels 'neg'/'pos' are presumably what some
    # torchtext releases emit -- confirm against the installed version.
    string_ids = {"neg": 0, "pos": 1}
    if labels and all(isinstance(label, str) for label in labels):
        return [string_ids[label] for label in labels]

    ids = [int(label) for label in labels]
    # Decide once per split: only a split coded with 1/2 is shifted down.
    # A per-item {1: 0, 2: 1} lookup would collapse a 0/1-coded split to
    # all zeros, because 1 would map to 0 while 0 stays 0.
    if ids and set(ids) <= {1, 2}:
        return [label - 1 for label in ids]
    return ids


def _embed_imdb_split(sbert, split):
    """Embed every review of one IMDB split; return (features, labels) tensors."""
    samples = list(IMDB(split=split))
    labels = [label for label, _ in samples]
    texts = [text for _, text in samples]
    # Encode the whole split in one call so the encoder can batch the texts
    # instead of running one forward pass per review.
    embeddings = sbert.encode(texts, convert_to_numpy=True)
    features = torch.from_numpy(np.asarray(embeddings, dtype=np.float32))
    targets = torch.tensor(_normalize_imdb_labels(labels), dtype=torch.long)
    return features, targets


def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_train.pt",
    test_cache_path="imdb_embeddings_test.pt"
):
    """
    Loads the IMDB dataset and uses SBERT to embed each review.
    If both cache files exist, the embeddings are loaded from disk instead;
    otherwise both splits are embedded and both cache files are (re)written.

    Args:
        model_name: name passed to SentenceTransformer.
        train_cache_path: file holding the cached (X_train, y_train) tuple.
        test_cache_path: file holding the cached (X_test, y_test) tuple.

    Returns:
        X_train, y_train, X_test, y_test as PyTorch tensors (features are
        float32, labels are int64 when freshly computed).
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk...")
        # NOTE: torch.load unpickles the file; only point these paths at
        # caches produced by this script.
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()

    X_train, y_train = _embed_imdb_split(sbert, "train")
    X_test, y_test = _embed_imdb_split(sbert, "test")

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("SBERT embeddings computed and saved to disk.")

    return X_train, y_train, X_test, y_test

######################################
# --- SMC Sampler with HMC Mutations ---
######################################
def _tempering_weights(nll, temp_increment):
    """Return (shift, log-mean shifted weight, normalised weights, ESS) for one temperature increment."""
    log_w = -temp_increment * nll
    shift = np.max(log_w)  # subtracted before exponentiating to avoid overflow
    shifted = np.exp(log_w - shift)
    w = shifted / np.sum(shifted)
    ess = 1.0 / np.sum(w ** 2)
    return shift, np.log(np.mean(shifted)), w, ess


def _systematic_resample(w):
    """Return len(w) ancestor indices drawn from weights 'w' by systematic resampling."""
    n = len(w)
    positions = (np.arange(n) + np.random.uniform(0, 1)) / n
    # First index whose cumulative weight exceeds each position.
    ancestors = np.searchsorted(np.cumsum(w), positions, side='right')
    # The cumulative sum can round to just below the last position; clamp so
    # the lookup never runs past the final particle.
    return np.minimum(ancestors, n - 1)


def psmc_single(node, numwork, x_train, y_train, x_val, y_val, NetClass, N, step_size, L, M, trajectory):
    """
    Runs an SMC sampler over 1 replicate. Initializes N particles with
    init_particle and raises the tempering factor from 0 to 1, reweighting,
    resampling and applying HMC mutations (in a worker pool) at every step.

    Args:
        node: replicate index; also used as the torch and numpy seed.
        numwork: number of processes in the mutation pool.
        x_train, y_train: training inputs and labels used for the likelihood.
        x_val, y_val: validation inputs and labels used for the final predictions.
        NetClass: zero-argument callable returning the network module.
        N: number of particles.
        step_size: initial HMC step size (adapted between mutation attempts).
        L: initial number of leapfrog steps per HMC step.
        M: number of HMC samples per mutation.
        trajectory: target trajectory length; after each mutation attempt L is
            reset to int(trajectory / step size), clipped to [1, 100].

    Returns:
        times: wall-clock seconds spent in this replicate.
        predictions: array of validation logits, one slice per particle.
        particles_flat: array of shape (N, d) with the final flat particles.
        ZZ: accumulated log of the mean max-shifted incremental weight.
        KK: accumulated max shifts (presumably ZZ + KK is the log
            normalizing-constant estimate -- confirm with downstream use).
        count: number of tempering steps taken.
        Ls: sum over tempering steps of the value of L at the start of the step.
    """
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    # Instantiate the model.
    net = NetClass().to(device)
    net.eval()
    net.share_memory()  # Enable sharing the model across processes

    d = sum(p.numel() for p in net.parameters())

    t_start = time.time()

    # Particle initialization. 'llike' holds each particle's untempered summed
    # cross-entropy, i.e. its negative log-likelihood on the training data.
    particles = []
    llike = []
    for _ in range(N):
        params_init = init_particle(net, device)
        particles.append(params_init)
        output = functional_call(net, unflatten_params(params_init, net), x_train)
        llike.append(model_loss_func_ll(output, y_train, temp=1).item())
    llike = np.array(llike)

    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    step_size_cur = step_size

    # Tempering loop.
    while tempcurr < 1.0:
        t_start1 = time.time()

        # Start from the full remaining increment and halve it until the
        # effective sample size is acceptable.
        temp_increment = 1.0 - tempcurr
        lmax, log_mean_w, w, ess = _tempering_weights(llike, temp_increment)
        while ess < N / 2:
            temp_increment /= 2.0
            lmax, log_mean_w, w, ess = _tempering_weights(llike, temp_increment)

        proposed_temp = tempcurr + temp_increment
        if proposed_temp >= 1.0 and count == 0:
            proposed_temp = 0.5  # Ensure at least one HMC mutation.
            # The weights must match the increment actually taken; keeping the
            # dT=1.0 weights here would over-weight the likelihood and bias
            # the normalizing-constant accumulators.
            temp_increment = proposed_temp - tempcurr
            lmax, log_mean_w, w, ess = _tempering_weights(llike, temp_increment)
        print("Current temperature:", tempcurr, ", dT=", temp_increment, "Step=", count)
        ZZ += log_mean_w
        KK += lmax
        # NOTE(review): L may change between retried mutation attempts below;
        # only its value at the start of the step is accumulated here.
        Ls += L

        # Systematic resampling.
        ancestors = _systematic_resample(w)
        particles = [particles[j] for j in ancestors]
        llike = llike[ancestors]

        # --- Parallel Mutation Phase using torch.multiprocessing ---
        # 'particles' and 'llike' are only overwritten once an attempt is
        # accepted, so a rejected attempt restarts from the resampled state.
        mutation_success = False
        while not mutation_success:
            with mp.Pool(processes=numwork) as pool:
                results = pool.starmap(
                    update_particle,
                    [(params, net, x_train, y_train, step_size_cur, L, proposed_temp, M)
                     for params in particles]
                )
            # Unpack results.
            updated_particles, updated_llike, acc_rates = zip(*results)
            overall_acc = np.mean(acc_rates)
            print("Overall acceptance for mutation phase:", overall_acc, "Stepsize=", step_size_cur, "L=", L)
            if overall_acc < 0.4 or overall_acc > 0.95:
                # Acceptance far outside the target band: rescale the step
                # size and redo the whole mutation phase.
                if overall_acc < 0.4:
                    step_size_cur *= 0.7
                else:
                    step_size_cur *= 1.1
                print("Unsatisfactory acceptance. Restarting mutation phase with step size", step_size_cur, "L=", L)
            else:
                mutation_success = True
                # Accept the move, nudging the step size for the next step.
                if overall_acc < 0.6:
                    step_size_cur *= 0.7
                elif overall_acc > 0.8:
                    step_size_cur *= 1.1
                particles = list(updated_particles)
                llike = np.array(updated_llike)
            L = min(max(1, int(trajectory / step_size_cur)), 100)
        tempcurr = proposed_temp
        count += 1
        print(f'time for current temperature step = {time.time() - t_start1}')

    # Compute predictions on validation data.
    preds = [
        functional_call(net, unflatten_params(params, net), x_val)
        for params in particles
    ]

    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    particles_flat = torch.stack(particles).detach().cpu().numpy().reshape(N, d)
    # Let the class dimension follow the network output instead of assuming 2.
    predictions = torch.stack(preds).detach().cpu().numpy().reshape(N, x_val.size(0), -1)

    times = time.time() - t_start

    return times, predictions, particles_flat, ZZ, KK, count, Ls

if __name__ == '__main__':
    # Set the multiprocessing start method to spawn.
    mp.set_start_method('spawn', force=True)

    # Retrieve command-line arguments: 'node' seeds the run and tags the
    # output file, 'numwork' is the size of the mutation worker pool.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <node> <numwork>")
    node = int(sys.argv[1])
    numwork = int(sys.argv[2])

    # --- Data Loading ---
    train_cache = "imdb_embeddings_trainBig.pt"
    test_cache  = "imdb_embeddings_testBig.pt"
    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache
    )
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    input_dim = x_train.shape[1]
    print(input_dim)

    # --- Forward Model Definition ---
    def make_net():
        """Build the forward model sized to the loaded embeddings."""
        return SimpleMLP(input_dim=input_dim, hidden_dim=128, num_classes=2)

    # This instance is only used to count parameters for the log line and the
    # output filename; the sampler builds its own network from the same factory
    # so both always agree on the input width.
    net = make_net().to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    # --- SMC Sampler Parameters ---
    R = 1       # Number of replicates
    N = 32      # Number of particles per replicate
    step_size = 0.008
    L = 1       # Number of leapfrog steps per HMC step
    M = 1       # Number of HMC steps per mutation 
    P = 1       # Number of parallel SMC runs
    trajectory = 0.0

    start_time = time.time()
    times, predictions, particles_flat, Z_list, K_list, count_all, L_list = psmc_single(
        node=node,
        numwork=numwork,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        NetClass=make_net,  # Using SimpleMLP sized to the data
        N=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory
    )
    psmc_single_time = np.array(times)
    psmc_single_pred = np.array(np.stack(predictions, axis=0))
    psmc_single_x = np.array(np.stack(particles_flat, axis=0))
    Z = np.array(Z_list)
    K = np.array(K_list)
    count = np.array(count_all)
    Lsum = np.array(L_list)

    total_time = time.time() - start_time
    print(f"Total execution time: {total_time:.2f} seconds")

    # --- Save the Results ---
    savemat(f'BayesianNN_IMDB_psmc_SimpleMLP_scaleGaussian_d{d}_N{N}_M{M}_node{node}.mat', {
        'psmc_single_time': psmc_single_time,
        'psmc_single_pred': psmc_single_pred,
        'psmc_single_x': psmc_single_x,
        'Z': Z,
        'K': K,
        'count': count,
        'Lsum': Lsum
    })
