"""
Uses CIFAR-10 data and the SimpleMLP model (from the transfer-learning code).
This code uses Pyro's basic HMC for Bayesian inference.
The SimpleMLP has one hidden layer; its input is the 2048-dimensional ResNet-50 embedding
of each image (input_dim=2048), and it has 10 output classes.
"""

# ===== All imports placed at the top =====
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
#import matplotlib.pyplot as plt
from scipy.io import savemat, loadmat
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision import models
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, HMC
from torch.func import functional_call
import torch.multiprocessing as mp

# Compute device: the default CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Standard deviations of the zero-mean Gaussian prior / initializer:
# sigma_w applies to weight tensors, sigma_b to bias tensors.
sigma_w = 1e-2
sigma_b = 1e-3

# ===== CIFAR-10 Embedding Function using ResNet-50 =====
def _load_pretrained_resnet50():
    """Return an ImageNet-pretrained ResNet-50 from torchvision.

    Prefers the ``weights=`` API (torchvision >= 0.13), where ``pretrained=True`` is
    deprecated. ``IMAGENET1K_V1`` is the checkpoint ``pretrained=True`` selected, so the
    extracted features are the same on old and new torchvision versions.
    """
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is None:
        # Older torchvision without the weights enums.
        return models.resnet50(pretrained=True)
    return models.resnet50(weights=weights_enum.IMAGENET1K_V1)


def _embed_loader(feature_extractor, loader):
    """Run every batch of ``loader`` through ``feature_extractor``.

    Returns:
        (features, labels): all flattened features and labels, concatenated on the CPU.
    """
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            features = features.view(features.size(0), -1)  # shape: (batch, 2048)
            feature_chunks.append(features.cpu())
            label_chunks.append(targets)
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)


def create_resnet50_embedded_cifar10_dataset(train_cache_path="cifar10_train_embeddings.pt",
                                             test_cache_path="cifar10_test_embeddings.pt"):
    """
    Loads CIFAR-10 and uses a pre-trained ResNet-50 to extract the final pooled 2048-dimensional features.
    These features and labels are cached to disk.

    If both cache files already exist they are loaded and returned; otherwise the
    embeddings are computed on ``device``, saved to the two cache paths, and returned.

    Args:
        train_cache_path: file holding the (X_train, y_train) tuple.
        test_cache_path: file holding the (X_test, y_test) tuple.

    Returns:
        (X_train, y_train, X_test, y_test) as CPU tensors.
    """
    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for CIFAR-10...")
    # Upscale to ResNet's 224-pixel input size and apply ImageNet normalization statistics.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    # shuffle=False keeps features aligned with the dataset's label order.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # Remove the final fc layer: keep up to the avgpool layer.
    resnet50 = _load_pretrained_resnet50()
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _embed_loader(feature_extractor, train_loader)
    X_test, y_test = _embed_loader(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test

# ===== SimpleMLP Definition =====
class SimpleMLP(nn.Module):
    """One-hidden-layer perceptron: fc1 -> ReLU -> fc2.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: number of hidden units.
        num_classes: number of output logits.

    Weights are initialized from N(0, sigma_w^2) and biases from N(0, sigma_b^2).
    """

    def __init__(self, input_dim=2048, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        layers = (self.fc1, self.fc2)
        # Weights first, then biases: this keeps the RNG draw order fixed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        """Return raw (unnormalized) class logits for ``x``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# ===== LogReg Definition (if needed) =====
class LogReg(nn.Module):
    """Multinomial logistic regression: a single linear layer producing class logits.

    ``hidden_dim`` is accepted only so the constructor signature matches SimpleMLP;
    it is not used. The weight is initialized from N(0, sigma_w^2) and the bias from
    N(0, sigma_b^2).
    """

    def __init__(self, input_dim=2048, hidden_dim=128, num_classes=10):
        super().__init__()
        linear = nn.Linear(input_dim, num_classes)
        nn.init.normal_(linear.weight, mean=0, std=sigma_w)
        if linear.bias is not None:
            nn.init.normal_(linear.bias, mean=0, std=sigma_b)
        self.fc = linear

    def forward(self, x):
        """Return raw class logits for ``x``."""
        return self.fc(x)

# ===== Helper Functions =====
def flatten_net(net):
    """Concatenate every parameter of ``net`` (in ``parameters()`` order) into one 1D tensor."""
    chunks = []
    for tensor in net.parameters():
        chunks.append(tensor.view(-1))
    return torch.cat(chunks)

def unflatten_params(flat, net):
    """
    Split the 1D tensor ``flat`` back into a {parameter name: tensor} dict shaped like ``net``.

    Slices are taken consecutively in ``named_parameters()`` order and returned as views
    of ``flat``. Any trailing elements beyond the total parameter count are ignored.
    """
    layout = [(name, tensor.shape, tensor.numel()) for name, tensor in net.named_parameters()]
    rebuilt = {}
    offset = 0
    for name, shape, count in layout:
        rebuilt[name] = flat[offset:offset + count].view(shape)
        offset += count
    return rebuilt

def init_particle(net, device):
    """
    Draw a flat parameter vector for ``net`` from the zero-mean Gaussian prior.

    Parameters whose name contains "bias" use std ``sigma_b``; all others use ``sigma_w``.
    Draws happen in ``named_parameters()`` order, on ``device``.
    """
    pieces = []
    for name, tensor in net.named_parameters():
        scale = sigma_b if "bias" in name else sigma_w
        sample = torch.randn(tensor.shape, device=device) * scale
        pieces.append(sample.view(-1))
    return torch.cat(pieces)

def model_loss_func_ll(output, y_device, temp):
    """
    Return ``temp`` times the summed (not averaged) cross-entropy of ``output`` logits
    against integer class labels ``y_device``.

    ``y_device`` is cast to long and flattened, so float labels or an (n, 1) column work.
    """
    targets = y_device.long().view(-1)
    summed_ce = nn.functional.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net):
    """
    Compute the HMC potential energy: negative log Gaussian prior plus ``temp`` times the
    summed cross-entropy (negative log-likelihood), up to additive constants.

    The prior term covers every parameter of ``net``. Names containing "bias" use std
    ``sigma_b``; all others use ``sigma_w``. Iterating the net's own parameter names
    (rather than a fixed key list) lets this work for any model, e.g. SimpleMLP or LogReg.

    Args:
        params: {parameter name: tensor} dict covering all of ``net``'s parameters.
        x: model inputs passed to ``net`` via ``functional_call``.
        y: integer class labels for ``x``.
        temp: multiplier applied to the likelihood term only.
        net: module that defines the parameter names and the forward computation.

    Returns:
        Scalar tensor potential energy.
    """
    neg_log_prior = 0.0
    for name, _ in net.named_parameters():
        sigma = sigma_b if "bias" in name else sigma_w
        # -log N(w; 0, sigma^2) without its normalizing constant.
        neg_log_prior = neg_log_prior + (params[name] ** 2).sum() / (2 * sigma ** 2)
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp=temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp):
    """
    Advance one flat parameter vector with a short, fixed-step Pyro HMC chain.

    The potential energy is ``model_loss_func`` over the full training set. Step size and
    mass matrix are not adapted.

    Returns:
        (last_sample, acc_rate): the final flat parameter vector drawn by the chain, and the
        last "acc. prob" entry reported by the kernel's ``logging()`` (0.0 if none), as float.
    """
    accept_history = []

    def energy(site_values):
        # MCMC passes {"params": flat_vector}; rebuild named tensors for functional_call.
        named = unflatten_params(site_values["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net)

    def record_acceptance(kernel, _sample, _stage, _iteration):
        # Installed as MCMC's hook_fn; reads the kernel's diagnostics after each iteration.
        # NOTE(review): the "acc. prob" value may be a formatted string, hence float() below.
        diagnostics = kernel.logging()
        if "acc. prob" in diagnostics:
            accept_history.append(diagnostics["acc. prob"])

    kernel = HMC(
        potential_fn=energy,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,  # presumably inert while step-size adaptation is off
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()
    chain = sampler.get_samples()["params"]
    last_rate = accept_history[-1] if accept_history else 0.0
    return chain[-1], float(last_rate)


def compute_accuracy(net, particles, x_val, y_val):
    """
    Return the fraction of ``x_val`` whose argmax ensemble prediction equals ``y_val``.

    The ensemble prediction averages the raw logits of every particle (each a flat
    parameter vector for ``net``). This is logit averaging, not probability averaging.
    No autograd graph is built.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        hits = mean_logits.argmax(dim=1).eq(y_val.view(-1)).sum().item()
    return hits / y_val.size(0)


# ===== Main Bayesian HMC Inference using CIFAR-10 and SimpleMLP =====
if __name__ == '__main__':

    # The integer node index seeds the RNGs and is embedded in the output filename.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} NODE_INDEX")
    node = int(sys.argv[1])
    print("Node", node)
    torch.manual_seed(node)
    np.random.seed(node)

    # Data Loading: use CIFAR-10 embeddings computed via ResNet-50.
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_train, y_train, X_test, y_test = create_resnet50_embedded_cifar10_dataset(
        train_cache_path=train_cache, test_cache_path=test_cache
    )
    x_train = X_train.to(device)
    y_train = y_train.to(device)
    x_val   = X_test.to(device)
    y_val   = y_test.to(device)

    print("Embedding dimension:", x_train.shape[1])

    # Forward Model Definition: SimpleMLP with embedding-dim input, 128 hidden units, 10 classes.
    net = SimpleMLP(input_dim=x_train.shape[1], hidden_dim=128, num_classes=10).to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    # --- HMC Sampler Parameters ---
    step_size = 0.0005
    L = 1               # Number of leapfrog steps per HMC step
    N = 1               # Number of samples (particles) to collect
    warmup_steps = 0
    burnin_all = 550    # HMC iterations before the first collected sample
    thin_all = burnin_all
    basic = 5           # HMC iterations per hmc_update_particle call; must divide burnin_all
    thin = int(burnin_all / basic)   # outer updates between collected samples
    burnin = thin       # Match burnin to thin in this setting (outer updates)
    trajectory = 0.0    # target step_size * L; 0.0 keeps L at its minimum of 1
    temp = 1.0          # multiplier on the likelihood term of the potential

    print("step_size =", step_size)
    print("L =", L)
    print("N =", N)
    print("warmup_steps =", warmup_steps)
    print("burnin =", burnin)
    print("thin =", thin)
    print("basic =", basic)
    print("trajectory =", trajectory)

    start_time = time.time()
    params_init = init_particle(net, device)
    params = params_init
    particles = []
    pred = []
    acc_rates = []
    Lsum = 0            # sum of L over outer updates (each update runs `basic` HMC iterations)

    total_steps = burnin + (N - 1) * thin
    for i in range(total_steps):
        updated_params, acc_rate = hmc_update_particle(
            net, x_train, y_train, params,
            num_samples=basic, warmup_steps=warmup_steps,
            step_size=step_size, num_steps=L, temp=temp
        )
        params = updated_params
        acc_rates.append(acc_rate)
        overall_acc = np.mean(acc_rates)
        print(i+1, overall_acc, step_size, L)
        # Heuristic step-size control on the running mean acceptance rate.
        if overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        if i >= burnin-1 and ((i - burnin + 1) % thin == 0):
            print("Collecting sample at step", i+1)
            particles.append(params)
        Lsum += L
        L = min(max(1, int(trajectory/step_size)), 100)

    # Validation logits for every collected particle; no autograd graph is needed.
    with torch.no_grad():
        for params in particles:
            param_dict = unflatten_params(params, net)
            pred.append(functional_call(net, param_dict, x_val))
    print(len(pred))

    hmc_single_time = time.time() - start_time
    print(f"Total execution time: {hmc_single_time:.2f} seconds")

    # NOTE(review): dtype=object yields element-wise object arrays, which savemat stores as
    # cell arrays. This is kept unchanged so existing consumers of the .mat file still work.
    hmc_single_pred_np = [p.detach().cpu().numpy() for p in pred]
    hmc_single_x_np = [p.detach().cpu().numpy() for p in particles]
    hmc_single_pred = np.array(hmc_single_pred_np, dtype=object)
    hmc_single_x = np.array(hmc_single_x_np, dtype=object)

    # Evaluate the particles directly. They are already flat parameter vectors on `device`,
    # so they match `net` and `x_val`. No per-scalar rebuild from the object array is needed.
    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print("Validation accuracy:", accuracy)

    savemat(f'BayesianNN_CIFAR_phmc_SimpleMLP_scaleGaussian_d{d}_N{N}_thin{thin_all}_burnin{burnin_all}_node{node}.mat', {
        'hmc_single_time': hmc_single_time,
        'hmc_single_pred': hmc_single_pred,
        'hmc_single_x': hmc_single_x,
        'Lsum': Lsum
    })
