"""
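Bayesian neural network on CIFAR-10: images are embedded with a pretrained ResNet-50, and the posterior over a 10-class SimpleMLP is sampled with tempered sequential Monte Carlo (SMC) using HMC mutations.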
"""

# All imports at the top
from scipy.io import savemat, loadmat
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.func import functional_call
#import concurrent.futures  # (no longer used in the mutation phase)
import torchvision.datasets as datasets
from torchvision import models
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC
import torch.multiprocessing as mp  # Use torch multiprocessing for parallel processing

# Compute device: the CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Standard deviations of the zero-mean Gaussians used to initialise the
# network parameters and as the prior in the HMC potential.
sigma_w = 0.01   # weight tensors
sigma_b = 0.001  # bias tensors

######################################
# CIFAR-10 Embedding Function using ResNet-50
######################################
def _embed_loader(feature_extractor, loader):
    """Embed every batch of ``loader`` with ``feature_extractor``.

    Returns:
        (features, targets): CPU tensors concatenated over all batches; the
        features are flattened to (num_examples, feature_dim).
    """
    feature_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))  # (batch, 2048, 1, 1) for ResNet-50
            feature_chunks.append(features.view(features.size(0), -1).cpu())
            target_chunks.append(targets)
    return torch.cat(feature_chunks, dim=0), torch.cat(target_chunks, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt"
):
    """Return CIFAR-10 ResNet-50 embeddings as (X_train, y_train, X_test, y_test).

    If both cache files exist, they are loaded with ``torch.load`` and returned.
    Otherwise the following steps run, and the results are saved to the two
    cache paths before being returned:

    1. CIFAR-10 is downloaded to ``./data``.
    2. Each image is resized (smaller edge to 224) and normalised with the
       ImageNet mean/std.
    3. The image goes through an ImageNet-pretrained ResNet-50 whose final fc
       layer has been removed.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for CIFAR-10...")

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # shuffle=False keeps the embedding order aligned with the dataset order.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # `pretrained=True` is deprecated since torchvision 0.13. IMAGENET1K_V1 is the
    # same set of weights exposed through the newer `weights=` API.
    if hasattr(models, "ResNet50_Weights"):
        resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    else:
        resnet50 = models.resnet50(pretrained=True)
    # Remove the final fc layer and keep everything up to (and including) avgpool.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _embed_loader(feature_extractor, train_loader)
    X_test, y_test = _embed_loader(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")

    return X_train, y_train, X_test, y_test

######################################
# SimpleMLP Definition
######################################
class SimpleMLP(nn.Module):
    """Single-hidden-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: width of the hidden layer.
        num_classes: number of output logits.

    Weights are re-initialised from N(0, sigma_w^2) and biases from
    N(0, sigma_b^2), using the module-level sigma values.
    """

    def __init__(self, input_dim=2048, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        layers = (self.fc1, self.fc2)
        # Draw order (both weights first, then both biases) matters: it fixes how
        # the global torch RNG stream is consumed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return the unnormalised class logits for ``x``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

######################################
# Helper Functions
######################################
def flatten_net(net):
    """Concatenate every parameter of ``net`` (in ``net.parameters()`` order) into one 1-D tensor."""
    flat_chunks = []
    for tensor in net.parameters():
        flat_chunks.append(tensor.view(-1))
    return torch.cat(flat_chunks)

def unflatten_params(flat, net):
    """Split the 1-D tensor ``flat`` into a ``{name: tensor}`` dict shaped like ``net``'s parameters.

    Slices are taken consecutively in ``net.named_parameters()`` order and
    reshaped with ``view``, so the returned tensors share storage with ``flat``.
    """
    offset = 0
    shaped = {}
    for pname, reference in net.named_parameters():
        end = offset + reference.numel()
        shaped[pname] = flat[offset:end].view(reference.shape)
        offset = end
    return shaped

def init_particle(net):
    """Draw a flat parameter vector for ``net`` from the zero-mean Gaussian prior.

    Tensors whose name contains "bias" are scaled by ``sigma_b``; all others by
    ``sigma_w``. Draws are made on ``device`` in ``net.named_parameters()``
    order, and the result is concatenated in that same order.
    """
    chunks = []
    for pname, reference in net.named_parameters():
        scale = sigma_b if "bias" in pname else sigma_w
        sample = torch.randn(reference.shape, device=device) * scale
        chunks.append(sample.view(-1))
    return torch.cat(chunks)

def model_loss_func_ll(output, y_device, temp):
    """Return ``temp`` times the summed cross-entropy of logits ``output`` against labels ``y_device``.

    Labels are cast to int64 and flattened before use, so both (B,) and (B, 1)
    label tensors are accepted.
    """
    targets = y_device.long().view(-1)
    summed_nll = F.cross_entropy(output, targets, reduction='sum')
    return summed_nll * temp

def model_loss_func(params, x, y, temp, net):
    """Return the tempered potential energy U = -log prior - temp * log likelihood (up to constants).

    Prior: an independent zero-mean Gaussian on every tensor in ``params``.
    Entries whose name contains "bias" use ``sigma_b``; all others use
    ``sigma_w``. This is the same rule ``init_particle`` samples from. Iterating
    over ``params`` (rather than a fixed key list) keeps the prior correct for
    any ``net`` architecture.

    Args:
        params: ``{name: tensor}`` dict accepted by ``functional_call(net, ...)``.
        x, y: inputs and class labels.
        temp: inverse temperature applied to the likelihood term only.
        net: module providing the forward computation.
    """
    neg_log_prior = 0.0
    for name, tensor in params.items():
        sigma = sigma_b if "bias" in name else sigma_w
        neg_log_prior += (tensor**2).sum() / (2 * sigma**2)
    logits = functional_call(net, params, x)
    # model_loss_func_ll returns temp * (summed cross-entropy) = -temp * log-likelihood.
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """Run one Pyro HMC chain on the tempered potential ``model_loss_func``.

    The chain starts at the flat vector ``params_init`` and uses a fixed
    ``step_size`` and ``num_steps`` leapfrog steps. Step-size and mass-matrix
    adaptation are disabled.

    Returns:
        (last_sample, acc_rate): the final flat parameter vector of the chain,
        and the last "acc. prob" value reported by ``HMC.logging()`` (0.0 if
        none was logged). NOTE(review): presumably Pyro reports a running mean
        of the acceptance probability here; verify against the Pyro version in
        use.
    """
    logged_acc = []

    def potential_fn(site_values):
        # Pyro passes {"params": flat_vector}; rebuild the named dict for functional_call.
        named = unflatten_params(site_values["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net)

    def record_acceptance(kernel, _params, _stage, _i):
        stats = kernel.logging()
        if "acc. prob" in stats:
            logged_acc.append(stats["acc. prob"])

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()

    chain = sampler.get_samples()["params"]
    final_acc = logged_acc[-1] if logged_acc else 0.0
    # float() normalises the logged value (it may not already be a Python float).
    return chain[-1], float(final_acc)

def update_particle(params, net, x_train, y_train, step_size, L, temp, M):
    """Mutate one particle with ``M`` HMC steps of ``L`` leapfrog steps each.

    This is a module-level function so it can be dispatched to a process pool.

    Returns:
        (new_params, nll, acc_rate):
        - new_params: the updated flat parameter vector.
        - nll: its untempered (temp=1) summed training cross-entropy, as a
          Python float.
        - acc_rate: the acceptance rate reported by ``hmc_update_particle``.
    """
    new_params, acceptance = hmc_update_particle(
        net, x_train, y_train, params,
        num_samples=M,
        warmup_steps=0,
        step_size=step_size,
        num_steps=L,
        temp=temp,
    )
    logits = functional_call(net, unflatten_params(new_params, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    if torch.is_tensor(nll):
        nll = nll.item()
    return new_params, nll, acceptance

def compute_accuracy(net, particles, x_val, y_val):
    """Return the ensemble classification accuracy on (``x_val``, ``y_val``).

    The logits of all particles are averaged, and the argmax of the mean logits
    is compared with the flattened labels. Runs without gradient tracking.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        hits = mean_logits.argmax(dim=1).eq(y_val.view(-1)).sum().item()
        return hits / y_val.size(0)

######################################
# SMC Sampler with HMC Mutations
######################################
def _tempering_weights(llike, temp_increment):
    """Compute normalised incremental importance weights for one tempering step.

    Args:
        llike: 1-D NumPy array of per-particle untempered negative log-likelihoods.
        temp_increment: proposed increase of the inverse temperature.

    Returns:
        (lwhat, lmax, w, ess):
        - lwhat: incremental log-weights ``-temp_increment * llike``.
        - lmax: their maximum, subtracted before exponentiating for numerical
          stability.
        - w: the normalised weights.
        - ess: the effective sample size ``1 / sum(w**2)``.
    """
    lwhat = -temp_increment * llike
    lmax = np.max(lwhat)
    w = np.exp(lwhat - lmax)
    w /= np.sum(w)
    ess = 1.0 / np.sum(w**2)
    return lwhat, lmax, w, ess


def _systematic_resample(particles, llike, w):
    """Systematically resample ``particles`` and their ``llike`` values according to weights ``w``.

    Draws a single uniform offset from NumPy's global RNG.

    Returns:
        (new_particles, new_llike): a list and a NumPy array.
    """
    n = len(particles)
    cumulative_w = np.cumsum(w)
    # Round-off can leave the last cumulative weight slightly below 1.0. The last
    # position could then run j past the end of the array.
    cumulative_w[-1] = 1.0
    positions = (np.arange(n) + np.random.uniform(0, 1)) / n
    new_particles = []
    new_llike = []
    j = 0
    for pos in positions:
        while pos >= cumulative_w[j]:
            j += 1
        new_particles.append(particles[j])
        new_llike.append(llike[j])
    return new_particles, np.array(new_llike)


def psmc_single(node, numwork, x_train, y_train, x_val, y_val, NetClass, N, step_size, L, M, trajectory):
    """Run one replicate of adaptive tempered SMC with parallel HMC mutations.

    Starting from ``N`` prior draws, the likelihood is tempered from inverse
    temperature 0 to 1. Each step:
    1. Choose the increment by halving it until the ESS is >= N/2 (the first
       step is capped at 0.5).
    2. Reweight and systematically resample the particles.
    3. Mutate every particle with HMC in a ``numwork``-process pool.

    The mutation phase is retried with a smaller/larger step size while the mean
    acceptance is outside [0.4, 0.95]. After every attempt, L is reset to
    ``clip(int(trajectory / step_size), 1, 100)``.

    Args:
        node: replicate id; used as the torch and NumPy RNG seed.
        NetClass: zero-argument callable returning the network module.

    Returns:
        (total_time, predictions, particles_flat, ZZ, KK, count, Ls):
        - predictions: per-particle validation logits, shape (N, num_val, 10).
        - particles_flat: final particles, shape (N, d).
        - ZZ + KK: the accumulated log normalising-constant estimate.
        - count: the number of tempering steps.
        - Ls: the sum of L used at each step.
    """
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    net = NetClass().to(device)
    net.eval()

    d = sum(p.numel() for p in net.parameters())
    num_val = x_val.size(0)
    num_classes = 10  # For CIFAR-10

    t_start = time.time()
    particles = []
    llike = []
    preds = []

    # Particle initialization: draw from the prior and record each particle's
    # untempered negative log-likelihood on the training set.
    for _ in range(N):
        params_init = init_particle(net)
        particles.append(params_init)
        param_dict = unflatten_params(params_init, net)
        output = functional_call(net, param_dict, x_train)
        loss_val = model_loss_func_ll(output, y_train, temp=1)
        llike.append(loss_val.item() if torch.is_tensor(loss_val) else loss_val)
    llike = np.array(llike)

    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    step_size_cur = step_size

    # Tempering loop.
    while tempcurr < 1.0:
        t_start1 = time.time()

        # Start from the largest possible increment and halve it until the
        # incremental weights keep an ESS of at least N/2.
        temp_increment = 1.0 - tempcurr
        lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
        while ess < N / 2:
            temp_increment /= 2.0
            lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)

        proposed_temp = tempcurr + temp_increment
        if proposed_temp >= 1.0 and count == 0:
            # Never jump from the prior straight to the posterior in one step.
            # Recompute the weights for the increment actually taken; otherwise
            # resampling and the evidence estimate (ZZ, KK) use the wrong step.
            proposed_temp = 0.5
            temp_increment = proposed_temp - tempcurr
            lwhat, lmax, w, _ = _tempering_weights(llike, temp_increment)
        print("Current temperature:", tempcurr, ", dT=", temp_increment, "Step=", count)
        ZZ += np.log(np.mean(np.exp(lwhat - lmax)))
        KK += lmax
        Ls += L

        particles, llike = _systematic_resample(particles, llike, w)

        old_particles = particles.copy()
        old_llike = llike.copy()

        mutation_success = False
        while not mutation_success:
            # --- Parallel Mutation Phase using torch.multiprocessing ---
            with mp.Pool(processes=numwork) as pool:
                results = pool.starmap(
                    update_particle,
                    [(params, net, x_train, y_train, step_size_cur, L, proposed_temp, M)
                     for params in particles]
                )
            updated_particles, updated_llike, acc_rates = zip(*results)
            overall_acc = np.mean(acc_rates)
            print("Overall acceptance for mutation phase:", overall_acc, "Stepsize=", step_size_cur, "L=", L)
            if overall_acc < 0.4 or overall_acc > 0.95:
                # Reject this mutation phase: adjust the step size and retry from
                # the resampled particles.
                if overall_acc < 0.4:
                    step_size_cur *= 0.7
                else:
                    step_size_cur *= 1.1
                print("Unsatisfactory acceptance. Restarting mutation phase with step size", step_size_cur, "L=", L)
                particles = old_particles.copy()
                llike = old_llike.copy()
            else:
                # Accept, and nudge the step size towards the 0.6-0.8 acceptance band.
                mutation_success = True
                if overall_acc < 0.6:
                    step_size_cur *= 0.7
                elif overall_acc > 0.8:
                    step_size_cur *= 1.1
                particles = list(updated_particles)
                llike = np.array(updated_llike)
            L = min(max(1, int(trajectory / step_size_cur)), 100)
        tempcurr = proposed_temp
        count += 1
        print(f'time={time.time() - t_start1}')

    # Per-particle validation logits.
    for params in particles:
        param_dict = unflatten_params(params, net)
        output = functional_call(net, param_dict, x_val)
        preds.append(output)

    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    particles_tensor = torch.stack(particles).detach().cpu()
    particles_flat = particles_tensor.numpy().reshape(N, d)
    preds_tensor = torch.stack(preds).detach().cpu()
    predictions = preds_tensor.numpy().reshape(N, num_val, num_classes)

    total_time = time.time() - t_start

    return total_time, predictions, particles_flat, ZZ, KK, count, Ls

######################################
# Main Execution
######################################
if __name__ == '__main__':
    import functools  # only the script entry point needs it

    # Start pool workers with 'spawn'; PyTorch's multiprocessing notes require
    # it when worker processes use CUDA.
    mp.set_start_method('spawn', force=True)
    if len(sys.argv) < 3:
        print("Usage: python {} <node> <numwork>".format(sys.argv[0]))
        sys.exit(1)
    node = int(sys.argv[1])      # replicate id; psmc_single also uses it as the RNG seed
    numwork = int(sys.argv[2])   # number of worker processes for the mutation pool

    # Load CIFAR-10 embeddings (computed via ResNet-50)
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_train, y_train, X_test, y_test = create_resnet50_embedded_cifar10_dataset(
        train_cache_path=train_cache,
        test_cache_path=test_cache
    )
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    print("Embedding dimension:", x_train.shape[1])

    # psmc_single builds its own network via NetClass(). Bind the same constructor
    # arguments used here so its parameter dimension matches `d` (and the file name).
    net_factory = functools.partial(
        SimpleMLP, input_dim=x_train.shape[1], hidden_dim=128, num_classes=10
    )
    net = net_factory().to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    # SMC Sampler Parameters
    N = 32       # Number of particles per replicate
    step_size = 0.0006
    L = 1        # Number of leapfrog steps per HMC step
    M = 5        # Number of HMC steps per mutation
    # Target trajectory length. After each mutation attempt psmc_single sets
    # L = clip(int(trajectory / step_size), 1, 100), so 0.0 keeps L at 1.
    trajectory = 0.0

    start_time = time.time()
    total_time, predictions, particles_flat, ZZ, KK, count, Ls = psmc_single(
        node=node,
        numwork=numwork,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        NetClass=net_factory,
        N=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory
    )
    psmc_single_time = np.array(total_time)
    psmc_single_pred = np.array(np.stack(predictions, axis=0))
    psmc_single_x = np.array(np.stack(particles_flat, axis=0))
    ZZ = np.array(ZZ)
    KK = np.array(KK)
    count = np.array(count)
    Lsum = np.array(Ls)

    total_exec_time = time.time() - start_time
    print(f"Total execution time: {total_exec_time:.2f} seconds")

    savemat(f'BayesianNN_CIFAR_psmc_SimpleMLP_scaleGaussian_d{d}_N{N}_M{M}_node{node}.mat', {
        'psmc_single_time': psmc_single_time,
        'psmc_single_pred': psmc_single_pred,
        'psmc_single_x': psmc_single_x,
        'Z': ZZ,
        'K': KK,
        'count': count,
        'Lsum': Lsum
    })
