"""
Use Pyro’s basic HMC for Bayesian inference on IMDB data using SBERT embeddings.
Here we use a slightly more complex MLP (one hidden layer) as the forward model.
"""

import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
from torchtext.datasets import IMDB
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from scipy.io import savemat, loadmat

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC

# For functional calls with given parameters
from torch.func import functional_call

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Standard deviations of the zero-mean Gaussians used for weights (sigma_w)
# and biases (sigma_b). They scale the layer initialisation, the initial HMC
# particle in init_particle, and the quadratic prior term in model_loss_func.
sigma_w = 0.1
sigma_b = 0.01

######################################
# --- SimpleMLP Definition ---
######################################
class SimpleMLP(nn.Module):
    """Two-layer perceptron: ``fc1`` -> ReLU -> ``fc2``, returning raw logits.

    Weights are re-drawn from N(0, sigma_w^2) and biases from N(0, sigma_b^2),
    using the module-level ``sigma_w`` / ``sigma_b``.
    """

    def __init__(self, input_dim=384, hidden_dim=128, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        # Overwrite the default nn.Linear initialisation with zero-mean
        # Gaussians.  The order (both weights, then both biases) determines
        # how the RNG stream is consumed, so it is kept fixed.
        gaussian_inits = (
            (self.fc1.weight, sigma_w),
            (self.fc2.weight, sigma_w),
            (self.fc1.bias, sigma_b),
            (self.fc2.bias, sigma_b),
        )
        for tensor, std in gaussian_inits:
            if tensor is not None:
                nn.init.normal_(tensor, mean=0, std=std)

    def forward(self, x):
        """Map a batch of inputs to unnormalised class scores."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
    
class LogReg(nn.Module):
    """Multinomial logistic regression: one linear layer producing raw logits.

    ``hidden_dim`` is accepted but unused (the signature mirrors
    ``SimpleMLP``'s).  The weight is re-drawn from N(0, sigma_w^2) and the
    bias from N(0, sigma_b^2) using the module-level ``sigma_w`` / ``sigma_b``.
    """

    def __init__(self, input_dim=384, hidden_dim=128, num_classes=2):
        super().__init__()
        layer = nn.Linear(input_dim, num_classes)

        # Replace the default nn.Linear initialisation: weight first, then bias.
        nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        if layer.bias is not None:
            nn.init.normal_(layer.bias, mean=0, std=sigma_b)

        self.fc = layer

    def forward(self, x):
        """Return the unnormalised class scores for a batch of inputs."""
        return self.fc(x)

######################################
# --- Helper Functions ---
######################################
def flatten_net(net):
    """Flatten all parameters of a network into a single 1D tensor.

    Parameters are concatenated in ``net.parameters()`` order, each laid out
    in row-major order.

    Args:
        net: an ``nn.Module`` with at least one parameter.

    Returns:
        A 1D tensor holding every parameter value.
    """
    # reshape() rather than view(): view(-1) raises on non-contiguous
    # parameters (e.g. a transposed weight), whereas reshape() returns a view
    # when it can and copies only when it must.
    return torch.cat([p.reshape(-1) for p in net.parameters()])

def unflatten_params(flat, net):
    """
    Split the 1D tensor 'flat' back into named, correctly shaped parameters.

    Walks ``net.named_parameters()`` in order, taking as many consecutive
    entries of 'flat' as each parameter has elements and viewing that slice
    with the parameter's shape.  Returns a ``{name: tensor}`` dict whose
    tensors are views into 'flat'.
    """
    rebuilt = {}
    start = 0
    for key, reference in net.named_parameters():
        stop = start + reference.numel()
        rebuilt[key] = flat[start:stop].view(reference.shape)
        start = stop
    return rebuilt

def init_particle(net, device):
    """
    Draw a fresh flattened parameter vector for 'net' from the Gaussian prior.

    Each parameter (in ``net.named_parameters()`` order) is sampled as a
    standard normal on 'device' and scaled by ``sigma_b`` when its name
    contains "bias", otherwise by ``sigma_w``.  The draws are concatenated
    into one 1D tensor.
    """
    draws = []
    for name, param in net.named_parameters():
        scale = sigma_b if "bias" in name else sigma_w
        sample = torch.randn(param.shape, device=device) * scale
        draws.append(sample.view(-1))
    return torch.cat(draws)

def model_loss_func_ll(output, y_device, temp):
    """
    Summed cross-entropy between logits and integer labels, scaled by 'temp'.

    'y_device' is cast to long and flattened before use, so labels may be
    given in any shape or numeric dtype.
    """
    targets = y_device.long().view(-1)
    summed_nll = nn.functional.cross_entropy(output, targets, reduction='sum')
    return summed_nll * temp

def model_loss_func(params, x, y, temp, net):
    """
    Compute the potential energy (negative unnormalised log posterior).

    The prior is a zero-mean Gaussian with standard deviation ``sigma_b`` for
    parameters whose name contains "bias" and ``sigma_w`` for all others; the
    likelihood term is the summed cross-entropy scaled by 'temp'.

    Args:
        params: dict mapping each parameter name of 'net' to a tensor.
        x: input batch passed to the network.
        y: integer class labels for 'x'.
        temp: multiplier applied to the likelihood term.
        net: module whose forward pass is evaluated with 'params'.

    Returns:
        Scalar tensor: negative log prior (up to a constant) plus the scaled
        negative log likelihood.

    Raises:
        KeyError: if 'params' lacks an entry for one of net's parameters.
    """
    # Take the names from the network itself instead of a hard-coded list, so
    # the same potential works for any architecture (e.g. LogReg as well as
    # SimpleMLP).  For SimpleMLP the order is unchanged.
    neg_log_prior = 0.0
    for name, _ in net.named_parameters():
        sigma = sigma_b if "bias" in name else sigma_w
        neg_log_prior = neg_log_prior + (params[name] ** 2).sum() / (2 * sigma ** 2)

    # Likelihood term: run the network with the supplied parameters.
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """
    Run one short HMC chain started at 'params_init'.

    The potential is model_loss_func evaluated on the training data; step
    size and mass matrix adaptation are both switched off.

    Returns:
        (last_sample, acc_prob): the final flat parameter vector of the chain
        and, as a float, the "acc. prob" entry of the kernel's logging()
        output seen at the last hook call (0.0 if it was never reported).
    """
    reported_acc = []

    def energy(state):
        # Pyro hands over {"params": flat_tensor}; rebuild the named tensors.
        named = unflatten_params(state["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net)

    def record_acc_prob(kernel, _samples, _stage, _index):
        # Invoked by MCMC through hook_fn; only the kernel diagnostics are used.
        stats = kernel.logging()
        if "acc. prob" in stats:
            reported_acc.append(stats["acc. prob"])

    kernel = HMC(
        potential_fn=energy,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False, adapt_mass_matrix=False,
        target_accept_prob=0.65
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acc_prob,
    )
    sampler.run()

    draws = sampler.get_samples()["params"]
    last_acc = reported_acc[-1] if reported_acc else 0.0
    return draws[-1], float(last_acc)

def compute_accuracy(net, particles, x_val, y_val):
    """
    Accuracy of the particle ensemble on (x_val, y_val).

    Every flat particle is evaluated on 'x_val', the resulting logits are
    averaged across particles, and the argmax of that average is compared
    with the labels.  Returns the fraction of correct predictions as a float.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        hits = (mean_logits.argmax(dim=1) == y_val.view(-1)).sum().item()
        fraction = hits / y_val.size(0)
    return fraction

######################################
# --- SBERT Embedding + Caching Function ---
######################################
def _embed_split(sbert, split_data, label_map):
    """Encode every (label, text) pair of one split; return (X, y) tensors."""
    embeddings, labels = [], []
    for label, text in split_data:
        labels.append(label_map.get(label, label))
        # One 1-D vector per review; its length depends on the SBERT model.
        # NOTE(review): reviews are still encoded one at a time, as before.
        embeddings.append(sbert.encode(text, convert_to_numpy=True))

    # Pack into a single ndarray first: torch.tensor() on a Python list of
    # ndarrays converts element by element and is extremely slow.
    features = torch.tensor(np.asarray(embeddings, dtype=np.float32), dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)
    return features, targets


def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_train.pt",
    test_cache_path="imdb_embeddings_test.pt"
):
    """
    Loads the IMDB dataset and uses SBERT to embed each review.
    If both cache files exist, the embeddings are loaded from disk instead
    and neither the SBERT model nor the dataset is touched.

    Args:
        model_name: SentenceTransformer model used when embeddings must be
            computed.
        train_cache_path: file holding the cached ``(X_train, y_train)`` tuple.
        test_cache_path: file holding the cached ``(X_test, y_test)`` tuple.

    Returns:
        X_train, y_train, X_test, y_test as PyTorch tensors (float32 features,
        long labels when freshly computed).
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()

    # Map labels: some versions use 1 and 2, map them to 0 and 1.
    # Any other label value is passed through unchanged.
    label_map = {1: 0, 2: 1}

    X_train, y_train = _embed_split(sbert, IMDB(split="train"), label_map)
    X_test, y_test = _embed_split(sbert, IMDB(split="test"), label_map)

    # Save embeddings for future runs
    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("SBERT embeddings computed and saved to disk.")

    return X_train, y_train, X_test, y_test

######################################
# --- Main Bayesian HMC Inference on IMDB ---
######################################
if __name__ == '__main__':
    # The node index seeds the RNGs (reproducibility) and tags the output
    # file.  Fail with a usage message instead of a bare IndexError when it
    # is missing; extra arguments are still tolerated.
    if len(sys.argv) < 2:
        print("Usage: python {} <node>".format(sys.argv[0]))
        sys.exit(1)
    node = int(sys.argv[1])
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Data Loading ---
    train_cache = "imdb_embeddings_trainBig.pt"
    test_cache  = "imdb_embeddings_testBig.pt"
    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache
    )
    # Use the test set as validation for ensemble accuracy.
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    # --- Forward Model Definition ---
    net = SimpleMLP(input_dim=x_train.shape[1], hidden_dim=128, num_classes=2).to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    # --- HMC Sampler Parameters ---
    step_size = 0.005
    L = 1               # Number of leapfrog steps per HMC step
    N = 1               # Number of samples to collect
    warmup_steps = 0
    burnin_all = 45
    thin_all = burnin_all
    basic = 1           # Must divide burnin; basic is the number of samples per HMC update
    thin = int(burnin_all / basic)
    burnin = thin       # Match burnin to thin in this setting
    trajectory = 0.0
    temp = 1.0

    print("step_size =", step_size)
    print("L =", L)
    print("N =", N)
    print("warmup_steps =", warmup_steps)
    print("burnin =", burnin)
    print("thin =", thin)
    print("basic =", basic)
    print("trajectory =", trajectory)

    start_time = time.time()
    params = init_particle(net, device)
    particles = []
    acc_rates = []
    Lsum = 0

    # --- Sampling loop: one short HMC run per iteration ---
    total_steps = burnin + (N - 1) * thin
    for i in range(total_steps):
        params, acc_rate = hmc_update_particle(
            net, x_train, y_train, params,
            num_samples=basic, warmup_steps=warmup_steps,
            step_size=step_size, num_steps=L, temp=temp
        )
        acc_rates.append(acc_rate)
        overall_acc = np.mean(acc_rates)
        print(i+1, overall_acc, step_size, L)
        # Crude step-size tuning driven by the mean acceptance rate so far.
        if overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        # Keep the state at the end of burn-in and every 'thin' steps after.
        if i >= burnin-1 and ((i - burnin + 1) % thin == 0):
            print("Collecting sample at step", i+1)
            particles.append(params)
        Lsum += L
        # Leapfrog count for the next iteration, clamped to [1, 100].
        L = min(max(1, int(trajectory/step_size)), 100)

    # --- Predictions of every collected particle on the validation set ---
    pred = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for particle in particles:
            param_dict = unflatten_params(particle, net)
            pred.append(functional_call(net, param_dict, x_val))
    print(len(pred))

    hmc_single_time = time.time() - start_time
    print(f"Total execution time: {hmc_single_time:.2f} seconds")

    # NOTE(review): dtype=object is kept so the saved .mat layout matches
    # earlier runs -- TODO confirm whether downstream readers rely on it.
    hmc_single_pred_np = [p.detach().cpu().numpy() for p in pred]
    hmc_single_x_np = [p.detach().cpu().numpy() for p in particles]
    hmc_single_pred = np.array(hmc_single_pred_np, dtype=object)
    hmc_single_x = np.array(hmc_single_x_np, dtype=object)

    # --- Accuracy Computation ---
    # Score the collected particles directly: they are already flat tensors
    # on the same device as x_val.  Rebuilding them from the object-dtype
    # numpy array created one tensor per scalar and left the result on the
    # CPU, which mismatches x_val when running on CUDA.
    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    # --- Save the Results ---
    savemat(f'BayesianNN_IMDB_phmc_SimpleMLP_scaleGaussian_d{d}_N{N}_thin{thin_all}_burnin{burnin_all}_node{node}.mat', {
        'hmc_single_time': hmc_single_time,
        'hmc_single_pred': hmc_single_pred,
        'hmc_single_x': hmc_single_x,
        'Lsum': Lsum,
    })
