
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC
import torchvision
import torchvision.transforms as transforms
#from torch.nn.utils.stateless import functional_call
from torch.func import functional_call
from scipy.io import savemat, loadmat
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import sys

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Simple CNN Definition ---
class SimpleCNN(nn.Module):
    """Small classifier: one 3x3 convolution, ReLU, 2x2 max-pool, one linear layer.

    The linear layer expects 4 * 14 * 14 input features, i.e. four feature maps
    of a 28x28 single-channel image after the 2x2 pooling, and emits 10 logits.
    """

    def __init__(self):
        super().__init__()
        # padding=1 with a 3x3 kernel and stride 1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(4 * 14 * 14, 10)

        # Overwrite the default initialisation layer by layer (conv first, then
        # fc; weight before bias) so the random draws happen in a fixed order.
        for layer in (self.conv, self.fc):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=0.1)

    def forward(self, x):
        """Return the (batch, 10) logits for a batch of images ``x``."""
        batch = x.size(0)
        pooled = self.pool(torch.relu(self.conv(x)))
        # Collapse channel and spatial axes into one feature axis per sample.
        return self.fc(pooled.view(batch, -1))

# --- Helper Functions ---
def flatten_net(net):
    """
    Concatenate every parameter of 'net' into one 1D tensor.

    Parameters are visited in the order given by net.parameters() and each
    one is viewed as a vector before concatenation.
    """
    pieces = []
    for tensor in net.parameters():
        pieces.append(tensor.view(-1))
    return torch.cat(pieces)

def unflatten_params(flat, net):
    """
    Split the 1D tensor 'flat' back into named, correctly shaped parameters.

    The names, order and shapes are taken from net.named_parameters(); each
    entry of the returned dict is a view into the matching slice of 'flat'.
    Elements of 'flat' beyond the total parameter count are ignored.
    """
    rebuilt = {}
    start = 0
    for key, reference in net.named_parameters():
        stop = start + reference.numel()
        rebuilt[key] = flat[start:stop].view(reference.shape)
        start = stop
    return rebuilt

def init_particle(net):
    """
    Draw a flattened parameter vector for 'net' from the prior.

    Parameters whose name contains "bias" are sampled from Normal(0, 0.1).
    Every other parameter is sampled from Normal(0, sqrt(2 / fan_in)), where
    fan_in is in_channels * kernel_h * kernel_w for 4D (convolution) weights,
    the number of input features for 2D (linear) weights, and the element
    count for any other shape.

    Each draw is created on the device and in the dtype of the parameter it
    stands in for, so the result can be used with 'net' wherever it lives.
    Draws are concatenated in net.named_parameters() order, which is the
    layout unflatten_params expects.

    Returns:
        1D tensor holding one value per parameter element of 'net'.
    """
    chunks = []
    for name, param in net.named_parameters():
        if "bias" in name:
            sigma = 0.1
        else:
            shape = param.shape
            if len(shape) == 4:  # convolutional weights
                fan_in = shape[1] * shape[2] * shape[3]
            elif len(shape) == 2:  # fully connected weights
                fan_in = shape[1]
            else:
                fan_in = param.numel()
            sigma = np.sqrt(2.0 / fan_in)
        # Sample next to the parameter rather than on a module-level device.
        draw = torch.randn(param.shape, device=param.device, dtype=param.dtype) * sigma
        chunks.append(draw.view(-1))
    return torch.cat(chunks)

def model_loss_func_ll(output, y_device, temp):
    """
    Return the cross-entropy between 'output' (logits) and the labels
    'y_device', summed over all samples and multiplied by 'temp'.

    The labels are cast to integer type and flattened to one dimension first.
    """
    targets = y_device.long().view(-1)
    summed_nll = F.cross_entropy(output, targets, reduction='sum')
    return summed_nll * temp

def model_loss_func(params, x, y, temp, net):
    """
    Compute the potential energy (negative joint log probability, up to an
    additive constant) of the parameters 'params' given data (x, y).

    Prior term: every entry of 'params' contributes sum(p**2) / (2 * sigma**2),
    the negative log density of a zero-mean Normal without its normalising
    constant. sigma is 0.1 for entries whose name contains "bias" and
    sqrt(2 / fan_in) otherwise, with fan_in derived from the tensor shape
    (4D: in_channels * kernel_h * kernel_w, 2D: input features, else numel).
    All entries of 'params' are covered, so the function is not tied to one
    particular network's parameter names.

    Likelihood term: summed cross-entropy of the logits produced by running
    'net' with 'params' on 'x', scaled by 'temp'.

    Args:
        params: dict mapping parameter names to tensors, as produced by
            unflatten_params.
        x: input batch passed to the network.
        y: class labels for 'x'.
        temp: multiplier applied to the likelihood term.
        net: module whose forward pass is evaluated via functional_call.

    Returns:
        Scalar tensor: prior term plus likelihood term.
    """
    neg_log_prior = 0.0
    for name, tensor in params.items():
        if "bias" in name:
            sigma = 0.1
        else:
            shape = tensor.shape
            if len(shape) == 4:
                fan_in = shape[1] * shape[2] * shape[3]
            elif len(shape) == 2:
                fan_in = shape[1]
            else:
                fan_in = tensor.numel()
            sigma = np.sqrt(2.0 / fan_in)
        # Out-of-place accumulation keeps every term intact for autograd.
        neg_log_prior = neg_log_prior + (tensor**2).sum()/2/sigma**2

    # Run the network with the supplied parameters instead of its own.
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """
    Run one Pyro HMC chain on the flattened parameters of 'net'.

    The chain targets the potential defined by model_loss_func on
    (x_train, y_train) with likelihood scaling 'temp', starts from the flat
    vector 'params_init', and uses a fixed 'step_size' and 'num_steps' (both
    step-size and mass-matrix adaptation are switched off).

    Returns:
        (last_sample, acc) where last_sample is the final entry of the chain's
        "params" samples and acc is float() of the most recent "acc. prob"
        value reported by kernel.logging().
        NOTE(review): presumably that value is Pyro's running mean acceptance
        probability - confirm against the Pyro HMC implementation.
    """
    recorded_acc = []

    def energy(state):
        # The chain state is a dict with the single site "params" (flat vector).
        named = unflatten_params(state["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net)

    def record_acceptance(kernel, params, stage, i):
        # Called by MCMC after each step; keep whatever acceptance value is logged.
        info = kernel.logging()
        if "acc. prob" in info:
            recorded_acc.append(info["acc. prob"])

    sampler_kernel = HMC(
        potential_fn=energy,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    chain = MCMC(
        sampler_kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    chain.run()

    last_sample = chain.get_samples()["params"][-1]
    return last_sample, float(recorded_acc[-1])


def compute_accuracy(net, particles, x_val, y_val):
    """
    Return the accuracy of the particle ensemble on (x_val, y_val).

    Each flat parameter vector in 'particles' is run through 'net'; the
    resulting logits are averaged across particles and the class with the
    largest mean logit is the prediction. The result is the fraction of
    predictions equal to the flattened labels.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        labels = y_val.view(-1)
        hits = (mean_logits.argmax(dim=1) == labels).sum().item()
    return hits / y_val.size(0)

if __name__ == '__main__':
    # Single-chain HMC driver: load MNIST subsets, run the chain with a simple
    # acceptance-driven step-size rule, evaluate the kept samples and save them.

    # The node index is the only command-line argument; it seeds the RNGs and
    # tags the output file.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} NODE")
    node = int(sys.argv[1])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Data Loading using torchvision.datasets.MNIST ---
    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    full_val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    N_tr = 1000  # Number of training samples to use.
    N_val = 1000  # Number of validation samples to use.

    # Take the first N_tr / N_val items of each split.
    train_indices = list(range(len(full_train_dataset)))
    train_subset = Subset(full_train_dataset, train_indices[:N_tr])

    val_indices = list(range(len(full_val_dataset)))
    val_subset = Subset(full_val_dataset, val_indices[:N_val])

    # One batch spanning the whole subset, so each split becomes a single tensor pair.
    train_loader_full = DataLoader(train_subset, batch_size=len(train_subset), shuffle=False)
    x_train, y_train = next(iter(train_loader_full))
    x_train, y_train = x_train.to(device), y_train.to(device)

    val_loader_full = DataLoader(val_subset, batch_size=len(val_subset), shuffle=False)
    x_val, y_val = next(iter(val_loader_full))
    x_val, y_val = x_val.to(device), y_val.to(device)

    net = SimpleCNN().to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in the CobbNN is {d}")

    # --- HMC Sampler Parameters ---
    step_size = 0.005
    L = 1      # Number of leapfrog steps per HMC step
    N = 1      # number of samples per chain
    warmup_steps = 0
    burnin = 40
    thin = burnin # thin=1 is with N=1, which is the parallel setting; thin=burnin is with N!=1, which is the serial setting
    basic = 1 # HMC samples drawn per outer iteration; must divide thin
    thin = int(thin/basic) # express thin in outer iterations
    burnin = thin # match
    trajectory=0.0  # L is reset to int(trajectory/step_size), clamped to [1, 100]; 0.0 keeps L at 1

    # single node HMC
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    # Print all parameters
    print("step_size =", step_size)
    print("L =", L)
    print("N =", N)
    print("warmup_steps =", warmup_steps)
    print("burnin =", burnin)
    print("thin =", thin)
    print("basic =", basic)
    print("trajectory =", trajectory)

    start_time = time.time()
    params_init = init_particle(net)
    params = params_init
    particles = []
    pred = []
    acc_rates = []
    Lsum = 0  # total number of leapfrog steps requested over all iterations

    for i in range(burnin+(N-1)*thin):
        # Each outer iteration restarts a short chain from the previous end point.
        updated_params, acc_rate = hmc_update_particle(
            net, x_train, y_train, params,
            num_samples=basic, warmup_steps=warmup_steps, step_size=step_size,
            num_steps=L, temp=1
        )
        params = updated_params
        acc_rates.append(acc_rate)
        overall_acc = np.mean(acc_rates)
        print(i+1, overall_acc, step_size, L)
        # Step-size rule driven by the running mean acceptance over all
        # iterations so far; it is applied on every iteration.
        if overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        # Keep the state at the end of burn-in and every 'thin' iterations after.
        if i >= burnin-1 and ((i-burnin+1)%thin==0):
            particles.append(params)
        Lsum += L
        L = min(max(1, int(trajectory/step_size)),100)

    # Validation logits of every kept particle; no gradients are needed here.
    with torch.no_grad():
        for particle in particles:
            param_dict = unflatten_params(particle, net)
            output = functional_call(net, param_dict, x_val)
            pred.append(output)

    hmc_single_time = time.time() - start_time
    print(f"Total execution time: {hmc_single_time:.2f} seconds")

    # NumPy copies for the .mat file (object dtype, as stored in the saved payload).
    hmc_single_pred_np = [p.detach().cpu().numpy() for p in pred]
    hmc_single_x_np = [p.detach().cpu().numpy() for p in particles]
    hmc_single_pred = np.array(hmc_single_pred_np, dtype=object)
    hmc_single_x = np.array(hmc_single_x_np, dtype=object)

    # Evaluate with the kept sample tensors directly: they are already flat and
    # live on the same device as net and x_val. Tensors rebuilt from the NumPy
    # copies would sit on the CPU and fail against a CUDA network.
    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)


    # --- Save the Results ---
    savemat(f'BayesianNN_MNIST_hmc_SimpleNN_kaiming_d{d}_train{N_tr}_val{N_val}_N{N}_thin{thin}_burnin{burnin}_node{node}.mat', {
        'hmc_single_time': hmc_single_time,
        'hmc_single_pred': hmc_single_pred,
        'hmc_single_x': hmc_single_x,
        'Lsum': Lsum,
    })
