#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
HMC Inference on IMDB using SBERT embeddings and a SimpleMLP,
with a MAP prior (extracted from MAP training) used in the potential.
This code re‐uses the same model and dataset configuration as the SMC version.
"""

import os
import sys
import time
import random
import numpy as np
from scipy.io import savemat
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC
from torch.func import functional_call

from sentence_transformers import SentenceTransformer
#import torchtext; torchtext.disable_torchtext_deprecation_warning()
import torchtext
from torchtext.datasets import IMDB

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MAP stage: variance `v` of the zero-mean Gaussian used both to initialise
# SimpleMLP and to regularise it during MAP training. Weights and biases
# share the same standard deviation.
v = 0.025
sigma_w = sigma_b = np.sqrt(v)

# HMC stage: variance `s` (one tenth of `v`) of the Gaussian prior that is
# centred at the MAP parameters. Again shared by weights and biases.
s = 0.1*v
sigma_w_s = sigma_b_s = np.sqrt(s)

# Multiplier applied to the MAP parameters before they are used as the prior mean.
scale = 1
#print(f'scale is {scale} and sigma is {sigma_b_s} and sigma_v is {sigma_b}')

######################################
# --- SimpleMLP Definition ---
######################################
class SimpleMLP(nn.Module):
    """
    Linear classifier with an optional single hidden layer.

    With hidden_dim equal to 0 or None the module is plain logistic regression
    and owns one layer named "fc". Otherwise it owns "fc1" (input -> hidden)
    and "fc2" (hidden -> classes) with a ReLU in between.

    All weights are drawn from N(0, sigma_w^2) and all biases from
    N(0, sigma_b^2), using the module-level sigma_w / sigma_b.
    """
    def __init__(self, input_dim=768, hidden_dim=0, num_classes=2):
        super().__init__()
        if hidden_dim is not None and hidden_dim != 0:
            # Hidden-layer variant: parameters are "fc1.weight", "fc1.bias",
            # "fc2.weight" and "fc2.bias".
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
            # Weights of both layers are drawn first, then the biases, so the
            # sequence of random draws is fc1.weight, fc2.weight, fc1.bias, fc2.bias.
            for layer in layers:
                nn.init.normal_(layer.weight, mean=0, std=sigma_w)
            for layer in layers:
                if layer.bias is not None:
                    nn.init.normal_(layer.bias, mean=0, std=sigma_b)
        else:
            # Logistic-regression variant: parameters are "fc.weight" and "fc.bias".
            self.fc = nn.Linear(input_dim, num_classes)
            nn.init.normal_(self.fc.weight, mean=0, std=sigma_w)
            if self.fc.bias is not None:
                nn.init.normal_(self.fc.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return the class logits for input batch `x`."""
        if not hasattr(self, 'fc'):
            # Hidden-layer variant: linear -> ReLU -> linear.
            return self.fc2(F.relu(self.fc1(x)))
        # Logistic-regression variant: a single linear map.
        return self.fc(x)

######################################
# --- Helper Functions ---
######################################
def unflatten_params(flat, net):
    """
    Split the 1-D tensor `flat` into per-parameter tensors for `net`.

    Parameters are consumed in the order of `net.named_parameters()`; each
    slice is viewed with the shape of the corresponding parameter. Returns a
    dict mapping parameter name -> tensor view into `flat`.
    """
    rebuilt = {}
    start = 0
    for name, reference in net.named_parameters():
        stop = start + reference.numel()
        rebuilt[name] = flat[start:stop].view(reference.shape)
        start = stop
    return rebuilt

# def init_particle(net, device):
#     """
#     Initialize a flattened parameter vector for `net` according to the Gaussian prior.
#     For biases use sigma_b and for weights sigma_w.
#     """
#     param_dict = {}
#     for name, param in net.named_parameters():
#         if "bias" in name:
#             param_dict[name] = torch.randn(param.shape, device=device) * sigma_b
#         else:
#             param_dict[name] = torch.randn(param.shape, device=device) * sigma_w
#     flat_params = []
#     for name, _ in net.named_parameters():
#         flat_params.append(param_dict[name].view(-1))
#     return torch.cat(flat_params)
def init_particle(net, device, prior_params):
    """
    Draw one flattened parameter vector for `net` from the Gaussian prior
    centred at `prior_params` (the MAP parameters).

    For every named parameter of `net` the mean is `prior_params[name]` moved
    to `device`; the standard deviation is sigma_b_s for names containing
    "bias" and sigma_w_s otherwise. The draws are concatenated in
    `net.named_parameters()` order.
    """
    pieces = []
    for name, param in net.named_parameters():
        center = prior_params[name].to(device)
        spread = sigma_b_s if "bias" in name else sigma_w_s
        noise = torch.randn(param.shape, device=device)
        pieces.append((center + noise * spread).view(-1))
    return torch.cat(pieces)

def flatten_state_params(state, net):
    """
    Concatenate the entries of `state` (e.g. the MAP parameters) into one
    1-D tensor, following the order of `net.named_parameters()`.
    """
    return torch.cat([state[name].view(-1) for name, _ in net.named_parameters()])

def model_loss_func_ll(output, y, temp):
    """
    Return the cross-entropy between logits `output` and labels `y`, summed
    over the batch and multiplied by `temp`.

    `y` is cast to long and flattened to 1-D before the loss is evaluated.
    """
    targets = y.long().view(-1)
    summed_ce = F.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net, prior_params):
    """
    Compute the potential energy (negative joint log probability, up to an
    additive constant) for the parameters `params` of `net`.

    The energy is the sum of
      * a Gaussian prior term centred at `prior_params` (the MAP parameters),
        using sigma_b_s for names containing "bias" and sigma_w_s otherwise, and
      * the summed cross-entropy of `net` on (x, y), scaled by `temp`.

    The prior covers every named parameter of `net`, so the same function works
    for the logistic-regression model ("fc.weight", "fc.bias") and for the
    hidden-layer model ("fc1.*", "fc2.*"). For the logistic-regression model the
    terms are accumulated in the same order as before (weight, then bias).

    Args:
        params: dict mapping parameter name -> tensor (e.g. from unflatten_params).
        x, y: inputs and labels the likelihood is evaluated on.
        temp: multiplier applied to the likelihood term.
        net: module whose forward pass and parameter names are used.
        prior_params: dict mapping parameter name -> prior mean tensor.

    Returns:
        Scalar tensor holding the potential energy.
    """
    neg_log_prior = 0.0
    for name, _ in net.named_parameters():
        # Use the HMC standard deviations for computing the prior energy.
        sigma = sigma_b_s if "bias" in name else sigma_w_s
        deviation = params[name] - prior_params[name]
        neg_log_prior = neg_log_prior + (deviation ** 2).sum() / (2 * sigma**2)
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp, prior_params):
    """
    Run one Pyro HMC chain started at `params_init` and return its last sample.

    The potential is model_loss_func evaluated on (x_train, y_train) with the
    Gaussian prior centred at `prior_params`. Step-size and mass-matrix
    adaptation are both switched off, so `step_size` and `num_steps` are used
    as given.

    Returns:
        (last_sample, acc_rate): the final flattened parameter vector of the
        chain, and the last value the kernel logged under "acc. prob"
        converted to float (0.0 if nothing was logged).
    """
    logged_acc = []

    def energy(state):
        # The chain state carries the flat vector under the key "params".
        named = unflatten_params(state["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net, prior_params)

    def record_acceptance(kernel, params, stage, i):
        # Called by MCMC through hook_fn; keeps whatever the kernel reports as "acc. prob".
        info = kernel.logging()
        if "acc. prob" in info:
            logged_acc.append(info["acc. prob"])

    kernel = HMC(
        potential_fn=energy,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False, adapt_mass_matrix=False,
        target_accept_prob=0.65
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()

    chain = sampler.get_samples()["params"]
    last_acc = logged_acc[-1] if logged_acc else 0.0
    return chain[-1], float(last_acc)

def compute_accuracy(net, particles, x_val, y_val):
    """
    Return the accuracy of the particle ensemble on (x_val, y_val).

    Each entry of `particles` is a flat parameter vector. The logits of all
    particles are averaged, the arg-max class is taken as the prediction, and
    the fraction of predictions equal to `y_val` is returned.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        predicted = mean_logits.argmax(dim=1)
        n_correct = (predicted == y_val.view(-1)).sum().item()
    return n_correct / y_val.size(0)

######################################
# --- SBERT Embedding & Caching for IMDB ---
######################################
def _embed_imdb_split(encoder, examples, label_map):
    """Embed every (label, text) pair of one split and return (X, y) tensors."""
    embeddings, labels = [], []
    for label, text in examples:
        labels.append(label_map.get(label, label))
        embeddings.append(encoder.encode(text, convert_to_numpy=True))
    # Convert through a single ndarray: building a tensor directly from a list
    # of ndarrays is very slow and triggers a PyTorch warning.
    features = torch.from_numpy(np.asarray(embeddings, dtype=np.float32))
    targets = torch.tensor(labels, dtype=torch.long)
    return features, targets


def create_sbert_embedded_imdb_dataset(
    model_name="all-mpnet-base-v2",
    train_cache_path="imdb_embeddings_trainBig.pt",
    test_cache_path="imdb_embeddings_testBig.pt"
):
    """
    Loads the IMDB dataset and uses SBERT to embed each review.
    If both cache files exist, the embeddings are loaded from disk and neither
    the SBERT model nor the raw dataset is touched. Otherwise the embeddings
    are computed and written to the two cache paths.

    Args:
        model_name: SentenceTransformer model used to embed the reviews.
        train_cache_path: file holding the cached (X_train, y_train) tuple.
        test_cache_path: file holding the cached (X_test, y_test) tuple.

    Returns:
        X_train, y_train, X_test, y_test as PyTorch tensors
        (float32 embeddings, long labels when freshly computed).
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached SBERT embeddings from disk...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing SBERT embeddings...")
    sbert = SentenceTransformer(model_name)
    sbert.eval()

    # Map labels to 0 (negative) and 1 (positive). Depending on the torchtext
    # version the raw labels are either the integers 1/2 or the strings
    # 'neg'/'pos'; labels not in the map are passed through unchanged.
    label_map = {1: 0, 2: 1, "neg": 0, "pos": 1}

    X_train, y_train = _embed_imdb_split(sbert, IMDB(split="train"), label_map)
    X_test, y_test = _embed_imdb_split(sbert, IMDB(split="test"), label_map)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("SBERT embeddings computed and saved to disk.")

    return X_train, y_train, X_test, y_test

######################################
# --- Main: MAP Training then HMC Inference ---
######################################
if __name__ == '__main__':
    # Script entry point: (1) load SBERT-embedded IMDB, (2) MAP-train a logistic-regression
    # SimpleMLP, (3) run HMC with a Gaussian prior centred at the MAP parameters,
    # (4) evaluate the collected samples and save the results to a .mat file.

    # Run index taken from the first command-line argument. It seeds all RNGs below and
    # is part of the output file name.
    node = int(sys.argv[1]) # Raises IndexError/ValueError if the argument is missing or not an integer.

    print(f'scale is {scale} and sigma is {sigma_b_s} and sigma_v is {sigma_b}')
    
    # --- Data Loading and SBERT Embedding ---
    train_cache = "imdb_embeddings_trainBig.pt"
    test_cache  = "imdb_embeddings_testBig.pt"
    X_train, y_train, X_test, y_test = create_sbert_embedded_imdb_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache
    )

    # MAP pre-training split: first 20k train, next 5k early-stop val
    # NOTE(review): the slices take rows in stored order. If the cached embeddings are
    # grouped by label, both subsets will be class-imbalanced — confirm the row order.
    x_map_train = X_train[:20000].to(device)
    y_map_train = y_train[:20000].to(device)
    x_map_val   = X_train[20000:25000].to(device)
    y_map_val   = y_train[20000:25000].to(device)

    # HMC uses the full train set (including the rows used for MAP validation);
    # the test split is what the rest of the script calls x_val / y_val.
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    print(f"The number of train data is {len(y_train)}")
    print(f"The number of validation data is {len(y_val)}")

    # Seed every RNG in use (torch, numpy, random, pyro) with the run index for reproducibility.
    torch.manual_seed(node)
    np.random.seed(node)
    random.seed(node)
    pyro.set_rng_seed(node)
    
    # --- MAP (Deterministic) Training of SimpleMLP ---
    # The input dimension is read from the embeddings; hidden_dim=0 selects logistic regression.
    input_dim = x_map_train.shape[1]
    model_mlp = SimpleMLP(input_dim=input_dim, hidden_dim=0, num_classes=2).to(device)
    optimizer = optim.Adam(model_mlp.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # `net` is only an alias of model_mlp, used here to count parameters. The eval() call
    # is superseded by model_mlp.train() right below.
    net = model_mlp
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")
    
    model_mlp.train()
    train_losses = []
    val_losses = []
    # Early stopping: compare the mean of the last `moving_avg_window` validation losses
    # with the best such mean so far; stop after `patience` epochs without improvement.
    moving_avg_window = 10
    best_moving_avg = float('inf')
    patience = 5
    no_improve_count = 0
    start_time = time.time()
    
    # DataLoaders for MAP pre-training
    train_dataset = torch.utils.data.TensorDataset(x_map_train, y_map_train)
    val_dataset   = torch.utils.data.TensorDataset(x_map_val,   y_map_val)
    train_loader  = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader    = DataLoader(val_dataset,   batch_size=64, shuffle=False)
    
    for epoch in range(1000):
        model_mlp.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model_mlp(inputs)
            ce_loss = criterion(outputs, labels)
            # Regularization term (MAP prior with zero mean and std sigma_w/sigma_b).
            reg_fc = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w**2) +
                      torch.sum(model_mlp.fc.bias**2) / (2 * sigma_b**2))
            reg_loss = reg_fc
            # The criterion uses its default reduction (batch mean), so the prior term is
            # divided by the number of training examples before being added to it.
            loss = ce_loss + reg_loss / len(train_dataset)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Average of the per-batch losses over the epoch.
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        
        model_mlp.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model_mlp(inputs)
                ce_loss = criterion(outputs, labels)
                reg_fc = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w**2) +
                           torch.sum(model_mlp.fc.bias**2) / (2 * sigma_b**2))
                reg_loss = reg_fc
                # Same objective as in training; the prior term is still scaled by the
                # size of the training set, not the validation set.
                loss = ce_loss + reg_loss / len(train_dataset)
                val_running_loss += loss.item()
        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)
        
        print(f"MAP Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # The moving average is only defined once a full window of epochs is available.
        if epoch >= moving_avg_window - 1:
            moving_avg = sum(val_losses[-moving_avg_window:]) / moving_avg_window
            if moving_avg < best_moving_avg:
                best_moving_avg = moving_avg
                no_improve_count = 0
            else:
                no_improve_count += 1
            if no_improve_count >= patience:
                # No checkpoint is restored: the parameters at the stopping epoch are kept.
                print("MAP: Early stopping due to no improvement.")
                break
    
    total_map_time = time.time() - start_time
    print(f"MAP Total execution time: {total_map_time:.2f} seconds")
    
    # --- Extract Prior Parameters from the MAP Model ---
    prior_params = model_mlp.state_dict()
    print("Prior parameters extracted from the MAP trained SimpleMLP.")

    # Multiply every entry by the module-level `scale`; each product is a new tensor
    # stored back into the dict under the same name.
    for name in prior_params:
        prior_params[name] = scale * prior_params[name]
    
    ############################################
    # --- HMC Inference using MAP Prior ---
    ############################################

    # HMC Sampler Parameters.
    step_size = 0.018
    trajectory = 0.0   # Trajectory length; with 0.0 the formula below always yields L = 1
    L = min(max(1, int(trajectory/step_size)), 100)               # Number of leapfrog steps per HMC update, clamped to [1, 100]
    N_samples = 1      # Number of posterior samples to collect (after burnin)
    warmup_steps = 0
    burnin_all = 25
    thin_all = burnin_all  # thinning period (set to burnin_all for a serial chain)
    basic = 1              # basic step count (should divide thin)
    thin = int(thin_all / basic)
    burnin = thin         # number of burnin iterations (set equal to thin)
    temp = 1.0          # Likelihood temperature
    
    # Initialize the HMC chain with a random draw from the Gaussian prior centred at the
    # MAP parameters (init_particle), not at the MAP point itself. The commented-out
    # line below is the alternative that starts exactly at the MAP parameters.
    #params = flatten_state_params(prior_params, model_mlp)
    params = init_particle(model_mlp, device, prior_params)
    particles = []
    acc_rates = []
    Lsum = 0
    
    # One sample is collected at the end of burn-in, then one every `thin` steps.
    total_steps = burnin + (N_samples - 1) * thin
    print("Starting HMC sampling...")
    start_time = time.time()
    for i in range(total_steps):
        # Each iteration runs a fresh one-sample HMC chain started at the current state.
        updated_params, acc_rate = hmc_update_particle(
            model_mlp, x_train, y_train, params,
            num_samples=1, warmup_steps=warmup_steps,
            step_size=step_size, num_steps=L, temp=temp,
            prior_params=prior_params
        )
        params = updated_params
        acc_rates.append(acc_rate)
        # Mean of the acceptance values reported by all iterations so far.
        overall_acc = np.mean(acc_rates)
        print(f"Step {i+1}: Acc rate = {overall_acc:.3f}, step_size = {step_size:.5f}, L = {L}")
        # Adapt step size based on acceptance rate: shrink below 0.6, grow above 0.8.
        if overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        if i >= burnin - 1 and ((i - burnin + 1) % thin == 0):
            print("Collecting sample at step", i+1)
            particles.append(params)
        # Accumulate the leapfrog steps spent, then recompute L for the new step size.
        Lsum += L * basic
        L = min(max(1, int(trajectory/step_size)), 100)
    
    hmc_total_time = time.time() - start_time
    print(f"HMC Total execution time: {hmc_total_time:.2f} seconds")
    
    # --- Accuracy Computation ---
    # Compute ensemble accuracy on the test split (x_val / y_val).
    # (Each sample is a flat vector; we use unflatten_params to get the dictionary.)
    accuracy = compute_accuracy(model_mlp, particles, x_val, y_val)
    print("HMC Validation Accuracy (ensemble):", accuracy)
    # Lsum is the accumulated number of leapfrog steps over all iterations.
    print("HMC epochs:", Lsum)
    
    # Prepare predictions (for each sample, run the forward model on x_val)
    preds = []
    for params in particles:
        param_dict = unflatten_params(params, model_mlp)
        output = functional_call(model_mlp, param_dict, x_val)
        preds.append(output)
    # Stack the per-sample logits along a new leading dimension and move them to the CPU.
    preds_tensor = torch.stack(preds).detach().cpu()
    
    # Save results. The file name encodes the parameter count, N_samples, burnin_all and node.
    savemat(f'BayesianNN_IMDB_hmc_SimpleMLP_MAP_d{sum(p.numel() for p in model_mlp.parameters())}_N{N_samples}_burnin{burnin_all}_node{node}.mat', {
        'hmc_total_time': hmc_total_time,
        'hmc_pred': preds_tensor.numpy(),
        'hmc_particles': np.stack([p.cpu().numpy() for p in particles]),
        'Lsum': Lsum,
    })
