
import numpy as np
import time
import torch
import torchvision
import sys
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC
from torch.func import functional_call
from scipy.io import savemat, loadmat
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torch.multiprocessing as mp

# Place tensors on CUDA when a GPU is visible to torch; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- Simple CNN Definition ---
class SimpleCNN(nn.Module):
    """Small classifier for 1x28x28 images producing 10 logits.

    Architecture: conv(1 -> 4, 3x3, padding 1) -> ReLU -> 2x2 max-pool
    -> flatten (4 * 14 * 14 features) -> linear -> 10 outputs.

    Weights are Kaiming-normal initialised (mode='fan_out', ReLU gain) and
    biases are drawn from Normal(0, 0.1).
    """

    def __init__(self):
        super().__init__()
        # A single input channel maps to 4 feature maps; padding=1 keeps the 28x28 size.
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        # 2x2 pooling with stride 2 halves each side: 28x28 -> 14x14.
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(4 * 14 * 14, 10)

        # Re-initialise conv first, then fc (weight before bias in each layer).
        for layer in (self.conv, self.fc):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=0.1)

    def forward(self, x):
        """Map a batch of images to class logits."""
        hidden = self.pool(F.relu(self.conv(x)))
        return self.fc(hidden.view(hidden.size(0), -1))

# --- Helper Functions ---
def flatten_net(net):
    """Concatenate every parameter of ``net`` into one 1-D tensor.

    Tensors are visited in ``net.parameters()`` order, the same order in which
    ``unflatten_params`` slices a flat vector back apart.
    """
    flat_views = [tensor.view(-1) for tensor in net.parameters()]
    return torch.cat(flat_views)

def unflatten_params(flat, net):
    """Split a flat parameter vector into a ``{name: tensor}`` dict shaped like ``net``.

    Consecutive slices of ``flat`` are assigned to ``net.named_parameters()``
    in order and reshaped (as views) to each parameter's shape, so the result
    can be passed straight to ``torch.func.functional_call``.

    Args:
        flat: 1-D tensor whose length equals the total parameter count of ``net``.
        net: module providing the parameter names, order and shapes.

    Returns:
        dict mapping parameter name to a view into ``flat``.

    Raises:
        ValueError: if ``flat`` does not have exactly as many elements as ``net``
            has parameters. Without this check, extra trailing entries would be
            silently ignored and a short vector would fail deep inside ``view``.
    """
    named = list(net.named_parameters())
    expected = sum(param.numel() for _, param in named)
    if flat.numel() != expected:
        raise ValueError(
            f"flat parameter vector has {flat.numel()} elements, "
            f"but the network has {expected} parameters"
        )

    param_dict = {}
    pointer = 0
    for name, param in named:
        numel = param.numel()
        param_dict[name] = flat[pointer:pointer + numel].view(param.shape)
        pointer += numel
    return param_dict

def init_particle(net):
    """Draw one flattened parameter vector for ``net`` from the prior.

    Parameters whose name contains "bias" are drawn from Normal(0, 0.1). Every
    other tensor is drawn from Normal(0, sqrt(2 / fan_in)) (Kaiming normal), where
    fan_in is C_in * kh * kw for 4-D conv weights, in_features for 2-D linear
    weights, and the element count for any other shape.

    Draws are made one tensor at a time in ``net.named_parameters()`` order on the
    module-level ``device``, and concatenated in that same order.
    """
    chunks = []
    for name, param in net.named_parameters():
        shape = param.shape
        if "bias" in name:
            std = 0.1
        else:
            if param.dim() == 4:  # convolutional weights
                fan_in = shape[1] * shape[2] * shape[3]
            elif param.dim() == 2:  # fully connected weights
                fan_in = shape[1]
            else:
                fan_in = param.numel()
            std = np.sqrt(2.0 / fan_in)
        chunks.append((torch.randn(shape, device=device) * std).view(-1))
    return torch.cat(chunks)

def model_loss_func_ll(output, y_device, temp):
    """Return ``temp`` times the summed cross-entropy of ``output`` against ``y_device``.

    The labels are cast to int64 and flattened to 1-D first, so label tensors of
    shape (n,) or (n, 1), integer or float, are accepted.
    """
    targets = y_device.long().view(-1)
    summed_ce = F.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net):
    """Return the tempered potential energy of ``params`` on the batch ``(x, y)``.

    The potential is

        sum_k ||theta_k||^2 / (2 * sigma_k^2)  +  temp * CE_sum(net(x; theta), y)

    i.e. the negative log of a zero-mean Gaussian prior plus ``temp`` times the
    summed cross-entropy (the negative log-likelihood), with additive constants
    dropped. Only the likelihood is tempered.

    Prior scales: sigma = 0.1 for every entry whose name contains "bias"; otherwise
    the Kaiming scale sigma = sqrt(2 / fan_in). Here fan_in is C_in * kh * kw for
    4-D weights, in_features for 2-D weights, and numel for any other shape.

    Args:
        params: mapping from parameter name to tensor (e.g. the output of
            ``unflatten_params``). Every entry contributes to the prior, so any
            architecture is covered, not just SimpleCNN.
        x: network inputs.
        y: class labels (cast to int64 by ``model_loss_func_ll``).
        temp: likelihood tempering exponent.
        net: module evaluated with ``params`` via ``functional_call``.

    Returns:
        Scalar tensor potential energy.
    """
    neg_log_prior = 0.0
    # Iterate over every supplied tensor rather than a fixed SimpleCNN key list,
    # so parameters of other networks are not left without a prior.
    for name, tensor in params.items():
        if "bias" in name:
            sigma = 0.1
        else:
            shape = tensor.shape
            if tensor.dim() == 4:
                fan_in = shape[1] * shape[2] * shape[3]
            elif tensor.dim() == 2:
                fan_in = shape[1]
            else:
                fan_in = tensor.numel()
            sigma = np.sqrt(2.0 / fan_in)
        neg_log_prior += (tensor ** 2).sum() / (2 * sigma ** 2)

    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """Run a fixed-step HMC chain from ``params_init`` and return its last state.

    The potential is ``model_loss_func`` at tempering ``temp``, evaluated on
    ``(x_train, y_train)``. The flat state is stored under the single site name
    "params". Step-size and mass-matrix adaptation are both disabled.

    Returns:
        (last_sample, acc_prob): the final flat parameter vector of the chain and
        the "acc. prob" entry from the kernel's most recent ``logging()`` call.
        ``float()`` normalises whatever type ``logging()`` reports.
    """
    acc_history = []

    def record_acceptance(kernel, params, stage, i):
        # Used as MCMC's hook_fn; store the kernel's acceptance statistic each time it fires.
        stats = kernel.logging()
        if "acc. prob" in stats:
            acc_history.append(stats["acc. prob"])

    def potential_fn(params_dict):
        theta = unflatten_params(params_dict["params"], net)
        return model_loss_func(theta, x_train, y_train, temp, net)

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()

    chain = sampler.get_samples()["params"]
    return chain[-1], float(acc_history[-1])

def compute_accuracy(net, particles, x_val, y_val):
    """Return the fraction of ``x_val`` the particle ensemble classifies correctly.

    Each particle's logits come from ``functional_call``. The ensemble prediction
    is the argmax of the per-class mean of those logits: logits, not softmax
    probabilities, are averaged. Evaluated under ``torch.no_grad()``.
    """
    with torch.no_grad():
        stacked = torch.stack([
            functional_call(net, unflatten_params(theta, net), x_val)
            for theta in particles
        ])
        predicted = stacked.mean(dim=0).argmax(dim=1)
        n_correct = (predicted == y_val.view(-1)).sum().item()
        return n_correct / y_val.size(0)

# --- Parallel Particle Update Function ---
def update_particle(params, net, x_train, y_train, step_size, L, temp, M):
    """Mutate one particle with ``M`` HMC iterations and re-score it.

    This is the ``Pool.starmap`` target of the mutation phase, which is why it is
    a module-level function.

    Returns:
        (new_params, nll, acc_rate): the last HMC state, its untempered summed
        cross-entropy on ``(x_train, y_train)`` (as a Python number), and the
        acceptance rate reported by ``hmc_update_particle``.
    """
    new_params, acc_rate = hmc_update_particle(
        net, x_train, y_train, params,
        num_samples=M, warmup_steps=0, step_size=step_size,
        num_steps=L, temp=temp,
    )
    logits = functional_call(net, unflatten_params(new_params, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    if torch.is_tensor(nll):
        nll = nll.item()
    return new_params, nll, acc_rate

# --- SMC Sampler with HMC Mutations ---
def _tempered_weights(llike, temp_increment):
    """Compute normalised incremental weights for raising the temperature by ``temp_increment``.

    ``llike`` holds each particle's untempered summed cross-entropy (a negative
    log-likelihood), so the log incremental weight is ``-temp_increment * llike``.

    Returns:
        (w, ess, lwhat, lmax): normalised weights, their effective sample size
        1 / sum(w**2), the unnormalised log weights, and their maximum (subtracted
        before exponentiating for numerical stability).
    """
    lwhat = -temp_increment * llike
    lmax = np.max(lwhat)
    w = np.exp(lwhat - lmax)
    w /= np.sum(w)
    ess = 1.0 / np.sum(w ** 2)
    return w, ess, lwhat, lmax


def _systematic_resample_indices(w):
    """Return ancestor indices drawn by systematic resampling from normalised weights ``w``.

    Consumes exactly one ``np.random.uniform(0, 1)`` draw.
    """
    n = len(w)
    cumulative_w = np.cumsum(w)
    # Round-off can leave the last cumulative weight just below 1, letting a
    # position fall past the end of the array; pin it to exactly 1.
    cumulative_w[-1] = 1.0
    positions = (np.arange(n) + np.random.uniform(0, 1)) / n
    # For each position, the first j with positions[i] < cumulative_w[j].
    ancestors = np.searchsorted(cumulative_w, positions, side='right')
    # (n - 1 + u) / n can round up to exactly 1.0 when u is close to 1.
    return np.minimum(ancestors, n - 1)


def _mutation_phase(particles, net, numwork, x_train, y_train, step_size, L, temp, M, trajectory):
    """Apply one parallel HMC move to every particle, retuning step size and L.

    The whole phase is retried from the same input particles while the mean
    acceptance rate is below 0.4 (step size *= 0.7) or above 0.95 (step size *= 1.1).
    On success the step size is nudged toward the 0.6-0.8 band. After every
    attempt, L = min(max(1, int(trajectory / step_size)), 100).

    Returns:
        (particles, llike, step_size, L): mutated particles, their untempered
        negative log-likelihoods, and the tuned step size and leapfrog count.
    """
    while True:
        with mp.Pool(processes=numwork) as pool:
            results = pool.starmap(
                update_particle,
                [(params, net, x_train, y_train, step_size, L, temp, M)
                 for params in particles],
            )
        updated_particles, updated_llike, acc_rates = zip(*results)
        overall_acc = np.mean(acc_rates)
        print("Overall acceptance for mutation phase:", overall_acc, "Stepsize=", step_size, "L=", L)
        rejected = overall_acc < 0.4 or overall_acc > 0.95
        if rejected:
            step_size *= 0.7 if overall_acc < 0.4 else 1.1
            print("Unsatisfactory acceptance. Restarting mutation phase with step size", step_size, "L=", L)
        elif overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        L = min(max(1, int(trajectory / step_size)), 100)
        if not rejected:
            return list(updated_particles), np.array(updated_llike), step_size, L


def psmc_single(node, numwork, x_train, y_train, x_val, y_val, NetClass, N, step_size, L, M, trajectory):
    """Run one replicate of an adaptive tempered SMC sampler with HMC mutations.

    N particles are drawn from the prior (``init_particle``) and the likelihood
    exponent is raised from 0 to 1. Each step takes the increment 1 - current,
    halved until the incremental weights have ESS >= N/2. It then resamples
    systematically and mutates every particle with ``update_particle`` in a
    ``numwork``-process pool. The first step stops at temperature 0.5, so at
    least one HMC mutation happens below temperature 1.

    Args:
        node: seed for the torch and numpy global RNGs (also printed).
        numwork: number of worker processes in the mutation pool.
        x_train, y_train: training inputs and labels.
        x_val, y_val: validation inputs and labels.
        NetClass: zero-argument ``nn.Module`` constructor defining the architecture.
        N: number of particles.
        step_size: initial HMC step size (retuned from acceptance rates).
        L: initial leapfrog steps per HMC iteration; after each mutation
            attempt it becomes min(max(1, int(trajectory / step_size)), 100).
        M: HMC iterations per particle per mutation phase.
        trajectory: target integration length used to recompute L.

    Returns:
        times_elapsed: wall-clock seconds for the replicate.
        predictions: (N, n_val, n_out) array of per-particle network outputs on
            ``x_val`` (logits for SimpleCNN).
        particles_flat: (N, d) array of final flattened particles.
        ZZ, KK: running sums of log(mean(exp(lwhat - lmax))) and of lmax;
            ZZ + KK is the log-evidence estimate.
        count: number of tempering steps (= successful mutation phases).
        Ls: sum of the L in effect at the start of each tempering step.
    """
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    # Instantiate the model; it only supplies the architecture and parameter shapes.
    net = NetClass().to(device)
    net.eval()
    net.share_memory()  # Allow sharing between processes

    d = sum(p.numel() for p in net.parameters())

    t_start = time.time()

    # Draw N particles from the prior and record their untempered NLL.
    particles = []
    llike = []
    with torch.no_grad():
        for _ in range(N):
            params_init = init_particle(net)
            particles.append(params_init)
            output = functional_call(net, unflatten_params(params_init, net), x_train)
            llike.append(model_loss_func_ll(output, y_train, temp=1).item())
    llike = np.array(llike)

    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    step_size_cur = step_size

    # Tempering loop.
    while tempcurr < 1.0:
        t_start1 = time.time()

        # Largest increment (1 - tempcurr, halved as needed) keeping ESS >= N/2.
        temp_increment = 1.0 - tempcurr
        w, ess, lwhat, lmax = _tempered_weights(llike, temp_increment)
        while ess < N / 2:
            temp_increment /= 2.0
            w, ess, lwhat, lmax = _tempered_weights(llike, temp_increment)

        if count == 0 and tempcurr + temp_increment >= 1.0:
            # Ensure at least one HMC mutation before temperature 1 by stopping at
            # 0.5. The weights must be recomputed for the increment actually taken.
            # Otherwise particles are reweighted for the full likelihood while the
            # temperature only advances to 0.5, and ZZ/KK double-count [0.5, 1].
            temp_increment = 0.5 - tempcurr
            w, ess, lwhat, lmax = _tempered_weights(llike, temp_increment)
        proposed_temp = tempcurr + temp_increment

        print("Current temperature:", tempcurr, ", dT=", temp_increment, "Step=", count)
        ZZ += np.log(np.mean(np.exp(lwhat - lmax)))
        KK += lmax
        Ls += L

        ancestors = _systematic_resample_indices(w)
        particles = [particles[j] for j in ancestors]
        llike = llike[ancestors]

        particles, llike, step_size_cur, L = _mutation_phase(
            particles, net, numwork, x_train, y_train,
            step_size_cur, L, proposed_temp, M, trajectory,
        )

        tempcurr = proposed_temp
        count += 1
        print(f'time for current temperature step = {time.time() - t_start1}')

    # Per-particle outputs on the validation data.
    with torch.no_grad():
        preds = [
            functional_call(net, unflatten_params(params, net), x_val)
            for params in particles
        ]

    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    particles_flat = torch.stack(particles).detach().cpu().numpy().reshape(N, d)
    preds_np = torch.stack(preds).detach().cpu().numpy()
    # Take the output width from the network itself instead of assuming 10 classes.
    predictions = preds_np.reshape(N, x_val.size(0), preds_np.shape[-1])

    times_elapsed = time.time() - t_start

    return times_elapsed, predictions, particles_flat, ZZ, KK, count, Ls

if __name__ == '__main__':
    # Worker processes receive tensors that may live on CUDA, and CUDA cannot be
    # used in forked children, so use the 'spawn' start method.
    mp.set_start_method('spawn', force=True)

    # Command line: NODE (RNG seed / replicate id, also used in the output file
    # name) and NUM_WORKERS (processes in the mutation pool).
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} NODE NUM_WORKERS")
    node = int(sys.argv[1])
    numwork = int(sys.argv[2])

    # --- Data Loading using torchvision.datasets.MNIST ---
    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    full_val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    N_tr = 1000  # Number of training samples to use (the first N_tr of the train split).
    N_val = 1000  # Number of validation samples to use (the first N_val of the test split).

    train_subset = Subset(full_train_dataset, list(range(min(N_tr, len(full_train_dataset)))))
    val_subset = Subset(full_val_dataset, list(range(min(N_val, len(full_val_dataset)))))

    # Load each subset as one full batch and move it to `device`.
    train_loader_full = DataLoader(train_subset, batch_size=len(train_subset), shuffle=False)
    x_train, y_train = next(iter(train_loader_full))
    x_train, y_train = x_train.to(device), y_train.to(device)

    val_loader_full = DataLoader(val_subset, batch_size=len(val_subset), shuffle=False)
    x_val, y_val = next(iter(val_loader_full))
    x_val, y_val = x_val.to(device), y_val.to(device)

    net = SimpleCNN().to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in the SimpleCNN is {d}")

    # --- SMC Sampler Parameters ---
    N = 32            # Number of particles per replicate
    step_size = 0.02  # Initial HMC step size (retuned during the run)
    L = 1             # Initial number of leapfrog steps per HMC step
    M = 1             # Number of HMC steps per mutation
    trajectory = 0.0  # Target trajectory length; L is recomputed from it after each mutation attempt

    # Print all parameters
    print("step_size =", step_size)
    print("L =", L)
    print("N =", N)
    print("M =", M)
    print("trajectory =", trajectory)

    start_time = time.time()
    times_elapsed, predictions, particles_flat, Z, K, count, Lsum = psmc_single(
        node=node,
        numwork=numwork,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        NetClass=SimpleCNN,
        N=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory
    )
    psmc_single_time = np.array(times_elapsed)
    # psmc_single already returns NumPy arrays of the final shapes.
    psmc_single_pred = np.asarray(predictions)
    psmc_single_x = np.asarray(particles_flat)

    total_time = time.time() - start_time
    print(f"Total execution time: {total_time:.2f} seconds")

    # --- Save the Results ---
    savemat(f'BayesianNN_MNIST_psmc_SimpleNN_kaiming_d{d}_train{N_tr}_val{N_val}_N{N}_M{M}_node{node}.mat', {
        'psmc_single_time': psmc_single_time,
        'psmc_single_pred': psmc_single_pred,
        'psmc_single_x': psmc_single_x,
        'Z': Z,
        'K': K,
        'count': count,
        'Lsum': Lsum
    })
