#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
pHMC for Bayesian inference on filtered CIFAR-10 using SimpleMLP with MAP prior incorporation.

The code first performs MAP training (logistic regression version: hidden_dim=0, 10 classes)
on filtered CIFAR-10 embeddings (obtained via a pretrained ResNet-50) and then runs HMC sampling
with a potential function that includes a Gaussian prior centered at the MAP parameters.
"""

############################################
#             Imports
############################################
import os
import sys
import time
import random
import numpy as np
from scipy.io import savemat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torch.func import functional_call
import torch.multiprocessing as mp

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC

from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize

############################################
#       Device and Prior Settings
############################################
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The MAP regulariser and the HMC prior use separate Gaussian variances.
v = 0.2                                   # variance of the MAP regulariser
sigma_w_map = sigma_b_map = np.sqrt(v)    # std for MAP weights and MAP biases

s = 0.05 * v                              # variance of the HMC prior (5% of the MAP variance)
sigma_w_hmc = sigma_b_hmc = np.sqrt(s)    # std for HMC prior weights and biases

# Multiplier applied to the MAP parameters before they are used as the prior mean.
scale = 1

############################################
#       Filtered CIFAR-10 Dataset
############################################
class FilteredCIFAR10(Dataset):
    """
    CIFAR-10 restricted to the samples whose label is in ``allowed_labels``.

    Only the indices of the matching samples are stored; images are fetched
    (and ``transform`` is applied) lazily in ``__getitem__``. Eagerly caching
    every transformed image would keep the whole resized dataset in memory
    at once.

    Args:
        root: Directory passed to ``datasets.CIFAR10``.
        train: Selects the train or the test split.
        transform: Transform forwarded to ``datasets.CIFAR10``.
        download: Whether ``datasets.CIFAR10`` may download the data.
        allowed_labels: Iterable of integer class labels to keep.
    """
    def __init__(self, root, train, transform, download, allowed_labels):
        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        self.allowed_labels = allowed_labels
        # Filter on the raw label list so no image is decoded or transformed here;
        # a set gives O(1) membership tests.
        allowed = {int(label) for label in allowed_labels}
        self.indices = [i for i, label in enumerate(self.dataset.targets) if int(label) in allowed]

    def __getitem__(self, idx):
        # Map the filtered position back to the underlying dataset index.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

############################################
#  In-Domain Embedding Extraction using ResNet-50
############################################
def _extract_embeddings(feature_extractor, loader):
    """Run ``feature_extractor`` over every batch of ``loader``; return (features, targets) on CPU."""
    feature_batches, target_batches = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            # Flatten everything after the batch dimension into one vector per image.
            feature_batches.append(features.view(features.size(0), -1).cpu())
            target_batches.append(targets)
    return torch.cat(feature_batches, dim=0), torch.cat(target_batches, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt",
    allowed_labels=None
):
    """
    Return ResNet-50 embeddings of CIFAR-10 as ``(X_train, y_train, X_test, y_test)``.

    If both cache files exist they are loaded and returned as-is. Otherwise the
    embeddings are computed with a pretrained ResNet-50 (final fc layer removed),
    saved to the two cache paths, and returned.

    Args:
        train_cache_path: File holding / receiving the ``(X_train, y_train)`` tuple.
        test_cache_path: File holding / receiving the ``(X_test, y_test)`` tuple.
        allowed_labels: Class labels to keep; ``None`` means all ten classes.

    NOTE: the cache files do not record ``allowed_labels``; an existing cache is
    returned regardless of the labels requested.
    """
    if allowed_labels is None:
        # Avoid a shared mutable default argument.
        allowed_labels = list(range(10))

    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for whole CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for whole CIFAR-10...")
    # Upscale to 224 and normalise with the given per-channel mean/std.
    transform = transforms.Compose([
         transforms.Resize(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FilteredCIFAR10(root='./data', train=True, transform=transform, download=True, allowed_labels=allowed_labels)
    test_dataset  = FilteredCIFAR10(root='./data', train=False, transform=transform, download=True, allowed_labels=allowed_labels)
    # shuffle=False keeps the embedding order aligned with the dataset order.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # NOTE(review): newer torchvision releases prefer the `weights=` argument over
    # `pretrained=True`; kept as-is for compatibility with older installs.
    resnet50 = torchvision.models.resnet50(pretrained=True)
    # Remove the final fc layer.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _extract_embeddings(feature_extractor, train_loader)
    X_test, y_test = _extract_embeddings(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test

############################################
#         SimpleMLP Model Definition
############################################
class SimpleMLP(nn.Module):
    """
    Classifier over fixed-size feature vectors.

    With ``hidden_dim`` equal to ``None`` or ``0`` the model is a single linear
    layer (``fc``), i.e. multinomial logistic regression. Otherwise it is a
    one-hidden-layer ReLU network (``fc1`` -> ReLU -> ``fc2``). All weights are
    re-initialised from N(0, sigma_w_map^2) and all biases from N(0, sigma_b_map^2).
    """
    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        if hidden_dim is None or hidden_dim == 0:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        else:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)

        # Weights first, then biases, so the order of random draws is fixed.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w_map)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b_map)

    def forward(self, x):
        # The presence of `fc` identifies the single-layer variant.
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

############################################
#           Helper Functions
############################################
def flatten_net(net):
    """Concatenate all parameters of ``net``, in ``net.parameters()`` order, into one 1-D tensor."""
    pieces = []
    for tensor in net.parameters():
        pieces.append(tensor.view(-1))
    return torch.cat(pieces)

def unflatten_params(flat, net):
    """
    Split the 1-D tensor ``flat`` into a ``{name: tensor}`` dict shaped like ``net``'s parameters.

    Consecutive slices of ``flat`` are consumed in ``net.named_parameters()``
    order and viewed with the shape of the matching parameter.
    """
    named = {}
    start = 0
    for key, ref in net.named_parameters():
        stop = start + ref.numel()
        named[key] = flat[start:stop].view(ref.shape)
        start = stop
    return named


def init_particle(net, prior_params=None):
    """
    Draw a flattened initial parameter vector for ``net`` from the Gaussian HMC prior.

    Args:
        net: Module whose ``named_parameters()`` define the names, shapes and order.
        prior_params: Optional mapping ``name -> tensor`` giving the prior mean
            (e.g. the MAP parameters). When ``None`` the prior is zero-mean.

    Returns:
        1-D tensor on ``device`` with all sampled parameters concatenated in
        ``net.named_parameters()`` order.
    """
    pieces = []
    for name, param in net.named_parameters():
        # Biases and weights have separate prior standard deviations.
        sigma = sigma_b_hmc if "bias" in name else sigma_w_hmc
        noise = torch.randn(param.shape, device=device) * sigma
        if prior_params is None:
            # Zero-mean prior: the sample is just the scaled noise.
            sample = noise
        else:
            sample = prior_params[name].to(device) + noise
        pieces.append(sample.view(-1))
    return torch.cat(pieces)


def model_loss_func_ll(output, y, temp):
    """Return ``temp`` times the summed cross-entropy between logits ``output`` and labels ``y``."""
    targets = y.long().view(-1)
    summed_ce = F.cross_entropy(output, targets, reduction='sum')
    return summed_ce * temp

def model_loss_func(params, x, y, temp, net, prior_params):
    """
    Compute the potential energy (negative joint log probability, up to a constant).

    The potential is the sum of
      * a Gaussian prior penalty ``sum((p - mean)^2) / (2 * sigma^2)`` for every
        tensor in ``params`` (sigma is ``sigma_b_hmc`` for names containing
        "bias", ``sigma_w_hmc`` otherwise), and
      * the tempered summed cross-entropy of ``net`` evaluated with ``params``.

    Args:
        params: Mapping ``name -> tensor`` used in place of ``net``'s parameters.
        x, y: Inputs and integer class labels.
        temp: Multiplier applied to the cross-entropy term.
        net: Module evaluated through ``functional_call``.
        prior_params: Mapping ``name -> tensor`` with the prior means (the MAP
            parameters), or ``None`` for a zero-mean prior.
    """
    # Iterate over whatever parameters were supplied instead of a fixed key
    # list, so the hidden-layer variant of the network is handled too.
    neg_log_prior = 0.0
    for name, param_tensor in params.items():
        sigma = sigma_b_hmc if "bias" in name else sigma_w_hmc
        center = prior_params[name] if prior_params is not None else 0.0
        neg_log_prior += ((param_tensor - center)**2).sum() / (2 * sigma**2)
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp)
    return neg_log_prior + neg_log_likelihood

def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp, prior_params):
    """
    Run one short HMC chain started at ``params_init``.

    The potential is ``model_loss_func`` evaluated on the full
    ``(x_train, y_train)`` data. Step-size and mass-matrix adaptation are
    both switched off, so ``step_size`` and ``num_steps`` are used as given.

    Returns:
        ``(last_sample, acc)`` where ``last_sample`` is the final flat
        parameter vector of the chain and ``acc`` is the most recent
        "acc. prob" diagnostic reported by the kernel, as a float
        (``0.0`` if none was reported).
    """
    def potential_fn(params_dict):
        # Pyro passes a dict of sites; the single site "params" holds the flat vector.
        named_params = unflatten_params(params_dict["params"], net)
        return model_loss_func(named_params, x_train, y_train, temp, net, prior_params)

    latest_acc = 0.0

    def record_acceptance(kernel, params, stage, i):
        # Called by MCMC during the run; keep only the newest diagnostic.
        nonlocal latest_acc
        stats = kernel.logging()
        if "acc. prob" in stats:
            latest_acc = stats["acc. prob"]

    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()
    chain = sampler.get_samples()["params"]
    return chain[-1], float(latest_acc)

def softmax_np(x):
    """Softmax over the last axis of ``x``; the row maximum is subtracted first for numerical stability."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=-1, keepdims=True)
    return weights / normaliser

def compute_accuracy(net, particles, x_val, y_val):
    """
    Accuracy of the particle ensemble on ``(x_val, y_val)``.

    Each particle (a flat parameter vector) is plugged into ``net``; the
    logits are averaged over particles and the arg-max class is compared
    with the labels.

    Args:
        net: Module providing the architecture and parameter layout.
        particles: Sequence of flat parameter vectors.
        x_val, y_val: Evaluation inputs and integer labels. They are moved to
            the device of the particles, so CPU tensors may be passed even
            when the particles live on the GPU.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ValueError: if ``particles`` is empty.
    """
    particles = list(particles)
    if len(particles) == 0:
        raise ValueError("compute_accuracy needs at least one particle")
    with torch.no_grad():
        # Inputs must share a device with the parameters used in the forward pass.
        target_device = particles[0].device
        x_val = x_val.to(target_device)
        labels = y_val.view(-1).to(target_device)
        logits_list = []
        for params in particles:
            param_dict = unflatten_params(params, net)
            logits_list.append(functional_call(net, param_dict, x_val))
        ensemble_logits = torch.stack(logits_list).mean(dim=0)
        preds = ensemble_logits.argmax(dim=1)
        correct = (preds == labels).sum().item()
        return correct / y_val.size(0)

############################################
#    Main Execution: MAP Training then HMC Sampling and Analysis
############################################
if __name__ == '__main__':
    # Set the multiprocessing start method to 'spawn'.
    mp.set_start_method('spawn', force=True)

    # The seed is a required positional argument: it seeds every RNG below and
    # tags the output file name. Exit with a usage message instead of an
    # IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <seed>")
    seed = int(sys.argv[1])

    print(f'scale is {scale} and sigma is {sigma_b_hmc} and sigma_v is {sigma_b_map}')

    ############################################
    # Load In-Domain Embeddings (whole CIFAR-10: classes 0–9)
    ############################################
    allowed_labels = list(range(10))
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_train, y_train, X_test, y_test = create_resnet50_embedded_cifar10_dataset(
         train_cache_path=train_cache,
         test_cache_path=test_cache,
         allowed_labels=allowed_labels
    )
    print(f"Embedding dimension: {X_train.shape[1]}")

    n_total = X_train.size(0)        # should be 50 000 for CIFAR-10
    print(f"Total CIFAR-10 train embeddings: {n_total}")

    # NOTE: early stopping is disabled, so no validation split is held out:
    # the MAP fit uses the first 50 000 embeddings (the whole CIFAR-10 train set).
    n_train_es = 50000
    X_train_subset = X_train[:n_train_es]
    y_train_subset = y_train[:n_train_es]

    train_dataset = TensorDataset(X_train_subset, y_train_subset)
    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Seed every RNG used below (PyTorch, NumPy, Python, Pyro).
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    pyro.set_rng_seed(seed)

    ############################################
    # MAP Training on In-Domain Data (SimpleMLP)
    ############################################
    model_mlp = SimpleMLP(input_dim=X_train_subset.shape[1], hidden_dim=0, num_classes=10).to(device)
    optimizer_mlp = optim.Adam(model_mlp.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    max_epochs = 200
    n_train_points = len(train_loader.dataset)

    t_map_start = time.time()
    for epoch in range(max_epochs):
        model_mlp.train()
        running_loss = 0.0
        for inputs, labels in train_loader:                   # iterate mini-batches
            # The model lives on `device`; the batches must be moved there too,
            # otherwise the forward pass fails when `device` is a GPU.
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer_mlp.zero_grad()
            outputs = model_mlp(inputs)
            ce_loss = criterion(outputs, labels)

            # MAP regularizer (Gaussian penalty on weights and biases)
            reg_loss = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w_map**2) +
                        torch.sum(model_mlp.fc.bias**2)   / (2 * sigma_b_map**2))
            # The cross-entropy is a per-sample mean, so the regularizer is
            # divided by the total train size to keep the two terms comparable.
            loss = ce_loss + reg_loss / n_train_points
            loss.backward()
            optimizer_mlp.step()
            running_loss += loss.item()
        train_losses.append(running_loss / len(train_loader))
    total_map_time = time.time() - t_map_start
    print(f"MAP Total execution time: {total_map_time:.2f} seconds")

    # Extract MAP prior parameters and apply the global `scale` multiplier.
    prior_params = model_mlp.state_dict()
    print("Prior parameters extracted from MAP model.")

    for name in prior_params:
        prior_params[name] = scale * prior_params[name]

    ############################################
    # HMC Sampling with MAP Prior
    ############################################
    # HMC parameters (adjust as needed)
    step_size = 0.002
    # Number of leapfrog steps is trajectory / step_size, clamped to [1, 100];
    # with trajectory = 0.0 this always yields L = 1.
    trajectory = 0.0
    L = min(max(1, int(trajectory / step_size)), 100)
    N_samples = 1     # Number of samples to collect (chain length)
    warmup_steps = 0
    burnin_all = 200
    thin_all = burnin_all
    # Each outer iteration runs `basic` HMC transitions, so the per-iteration
    # burn-in / thinning counts are the totals divided by `basic`.
    basic = 2
    thin = int(thin_all / basic)
    burnin = thin
    temp = 1.0

    print("\nHMC Configuration:")
    print(f"  step_size = {step_size:.5f}")
    print(f"  L = {L}")
    print(f"  burnin = {burnin_all}, thin = {thin_all}, basic = {basic}")

    net = SimpleMLP(input_dim=X_train.shape[1], hidden_dim=0, num_classes=10).to(device)
    net.eval()
    d = sum(p.numel() for p in net.parameters())
    print(f"The dimension of the parameters in SimpleMLP is {d}")

    t_hmc_start = time.time()
    # initialize by sampling from the prior centered at the MAP fit
    params_init = init_particle(net, prior_params)
    params = params_init
    particles = []
    acc_rates = []
    Lsum = 0
    total_iterations = burnin + (N_samples - 1) * thin
    x_hmc, y_hmc = X_train.to(device), y_train.to(device)
    for i in range(total_iterations):
        updated_params, acc_rate = hmc_update_particle(
            net, x_hmc, y_hmc, params,
            num_samples=basic, warmup_steps=warmup_steps,
            step_size=step_size, num_steps=L, temp=temp,
            prior_params=prior_params
        )
        params = updated_params
        acc_rates.append(acc_rate)
        overall_acc = np.mean(acc_rates)
        print(f"Iteration {i+1:04d}: Avg. acceptance = {overall_acc:.4f}, step_size = {step_size:.5f}, L = {L}")
        # Crude step-size control on the running mean acceptance.
        if overall_acc < 0.6:
            step_size *= 0.7
        elif overall_acc > 0.8:
            step_size *= 1.1
        # Keep one sample at the end of burn-in and every `thin` iterations after.
        if i >= burnin - 1 and ((i - burnin + 1) % thin == 0):
            print(f"Collected sample at iteration {i+1}")
            particles.append(params)
        Lsum += L * basic
        L = min(max(1, int(trajectory / step_size)), 100)
    hmc_single_time = time.time() - t_hmc_start
    print(f"HMC Total execution time: {hmc_single_time:.2f} seconds")

    # evaluate HMC ensemble on the *test* (ID) set
    # The particles live on `device`; the test tensors must be there as well.
    X_test_dev, y_test_dev = X_test.to(device), y_test.to(device)
    accuracy_hmc = compute_accuracy(net, particles, X_test_dev, y_test_dev)
    print(f"HMC Validation Accuracy: {accuracy_hmc*100:.2f}%")

    # Convert particles to NumPy array.
    particles_tensor = torch.stack(particles).detach().cpu()
    hmc_single_x = particles_tensor.numpy().reshape(len(particles), d)
    # rebuild preds (raw logits) on test set; no gradients are needed here
    with torch.no_grad():
        preds = [
            functional_call(net, unflatten_params(particle, net), X_test_dev)
            for particle in particles
        ]
    preds_tensor = torch.stack(preds).detach().cpu()
    num_test = X_test.size(0)
    hmc_single_pred = preds_tensor.numpy().reshape(len(particles), num_test, 10)

    # Save HMC sampling results.
    savemat(f'BayesianNN_CIFAR_hmc_SimpleMLP_MAP_d{d}_thin{thin_all}_burnin{burnin_all}_node{seed}.mat', {
        'hmc_single_time': hmc_single_time,
        'hmc_single_pred': hmc_single_pred,
        'hmc_single_x': hmc_single_x,
        'Lsum': Lsum,
        'map_single_time': total_map_time,
    })
