#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
pSMC with MAP Prior on Whole CIFAR-10 (classes 0–9)
using ResNet-50 embeddings and SimpleMLP.

This script first trains a MAP model on CIFAR-10 embeddings (classes 0–9)
to extract a prior, then runs SMC inference (using HMC mutations) with that MAP prior.
After SMC inference, it reports the ensemble accuracy on the CIFAR-10 test
embeddings and saves the particles, predictions and run statistics to a .mat file.
"""

############################################
#                Imports
############################################
import os
import sys
import time
import random
import numpy as np
from scipy.io import savemat

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torch.func import functional_call
import torch.multiprocessing as mp

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC
from torchvision import models

############################################
#       Device and Prior Settings
############################################
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# v: variance behind the standard deviations used for the MAP weight
# initialisation and the MAP regulariser (weights and biases share it).
v = 0.2
sigma_w_map = sigma_b_map = np.sqrt(v)

# s: variance behind the standard deviations of the Gaussian prior used by
# SMC/HMC around the MAP estimate (weights and biases share it).
s = 0.05 * v
sigma_w_smc = sigma_b_smc = np.sqrt(s)

# Factor applied to the MAP parameters before they are used as the prior mean.
scale = 1

############################################
#       Whole CIFAR-10 Dataset
############################################
class FilteredCIFAR10(Dataset):
    """CIFAR-10 restricted to the samples whose label is in ``allowed_labels``.

    Only the indices of the retained samples are stored. Each image is fetched
    from the wrapped ``datasets.CIFAR10`` (which applies ``transform``) when it
    is indexed, so the transformed images are never all held in memory at once.

    Args:
        root, train, transform, download: forwarded unchanged to
            ``datasets.CIFAR10``.
        allowed_labels: iterable of hashable class labels to keep.
    """

    def __init__(self, root, train, transform, download, allowed_labels):
        self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        self.allowed_labels = allowed_labels
        # Filter on the label list alone so that no image is decoded or
        # transformed while the subset is being selected.
        self.indices = self._matching_indices(self.dataset.targets, allowed_labels)

    @staticmethod
    def _matching_indices(targets, allowed_labels):
        """Return the positions in ``targets`` whose label is in ``allowed_labels``."""
        allowed = set(allowed_labels)
        return [i for i, label in enumerate(targets) if label in allowed]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair of the ``idx``-th retained sample."""
        return self.dataset[self.indices[idx]]

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.indices)

############################################
#   In-Domain Embedding Extraction using ResNet-50
############################################
def _extract_embeddings(feature_extractor, loader):
    """Run every batch of ``loader`` through ``feature_extractor``.

    Returns a ``(features, targets)`` pair: the flattened per-sample features
    (moved to the CPU) and the targets, each concatenated along dimension 0.
    """
    feature_chunks, target_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            features = feature_extractor(inputs.to(device))
            # Flatten everything after the batch dimension.
            feature_chunks.append(features.view(features.size(0), -1).cpu())
            target_chunks.append(targets)
    return torch.cat(feature_chunks, dim=0), torch.cat(target_chunks, dim=0)


def create_resnet50_embedded_cifar10_dataset(
    train_cache_path="cifar10_train_embeddings.pt",
    test_cache_path="cifar10_test_embeddings.pt",
    allowed_labels=None
):
    """Return ResNet-50 embeddings of CIFAR-10, computing and caching them if needed.

    Args:
        train_cache_path: file holding the cached ``(X_train, y_train)`` pair.
        test_cache_path: file holding the cached ``(X_test, y_test)`` pair.
        allowed_labels: class labels to keep; ``None`` means all of 0-9.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.

    NOTE(review): when both cache files exist they are returned as-is, whatever
    ``allowed_labels`` is; the cache does not record which labels it was built
    for.
    """
    if allowed_labels is None:
        allowed_labels = list(range(10))

    if os.path.exists(train_cache_path) and os.path.exists(test_cache_path):
        print("Loading cached ResNet-50 embeddings for whole CIFAR-10...")
        X_train, y_train = torch.load(train_cache_path)
        X_test, y_test = torch.load(test_cache_path)
        return X_train, y_train, X_test, y_test

    print("Cached embeddings not found. Computing ResNet-50 embeddings for whole CIFAR-10...")
    # Resize to 224 and normalise with the per-channel statistics below
    # before feeding the pretrained network.
    transform = transforms.Compose([
         transforms.Resize(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
    ])
    train_dataset = FilteredCIFAR10(root='./data', train=True, transform=transform, download=True, allowed_labels=allowed_labels)
    test_dataset  = FilteredCIFAR10(root='./data', train=False, transform=transform, download=True, allowed_labels=allowed_labels)
    # shuffle=False keeps the embeddings in dataset order.
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    resnet50 = models.resnet50(pretrained=True)
    # Remove final fc layer to get 2048-dim features.
    feature_extractor = nn.Sequential(*list(resnet50.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    X_train, y_train = _extract_embeddings(feature_extractor, train_loader)
    X_test, y_test = _extract_embeddings(feature_extractor, test_loader)

    torch.save((X_train, y_train), train_cache_path)
    torch.save((X_test, y_test), test_cache_path)
    print("ResNet-50 embeddings computed and saved.")
    return X_train, y_train, X_test, y_test


############################################
#         SimpleMLP Model Definition
############################################
class SimpleMLP(nn.Module):
    """
    Classifier over ``num_classes`` classes.

    With ``hidden_dim`` equal to 0 or None the model is a single linear layer
    (``fc``), i.e. logistic regression; otherwise it is ``fc1`` -> ReLU -> ``fc2``.
    Weights are drawn from N(0, sigma_w_map^2) and biases from N(0, sigma_b_map^2).
    """
    def __init__(self, input_dim=2048, hidden_dim=0, num_classes=10):
        super().__init__()
        if hidden_dim is not None and hidden_dim != 0:
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, num_classes)
            layers = (self.fc1, self.fc2)
        else:
            self.fc = nn.Linear(input_dim, num_classes)
            layers = (self.fc,)
        # All weights are initialised first, then all biases.
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0, std=sigma_w_map)
        for layer in layers:
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b_map)

    def forward(self, x):
        """Return the class logits for ``x``."""
        if not hasattr(self, 'fc'):
            hidden = F.relu(self.fc1(x))
            return self.fc2(hidden)
        return self.fc(x)

############################################
#           Helper Functions
############################################
def flatten_net(net):
    """Concatenate all parameters of ``net``, in ``net.parameters()`` order, into one 1-D tensor."""
    flat_chunks = []
    for p in net.parameters():
        flat_chunks.append(p.view(-1))
    return torch.cat(flat_chunks)

def unflatten_params(flat, net):
    """Split the 1-D tensor ``flat`` into a ``{name: tensor}`` dict shaped like ``net``'s parameters.

    Slices are taken in ``net.named_parameters()`` order; each entry is a view
    of ``flat`` with the shape of the corresponding parameter.
    """
    named = {}
    start = 0
    for key, ref in net.named_parameters():
        stop = start + ref.numel()
        named[key] = flat[start:stop].view(ref.shape)
        start = stop
    return named

def init_particle(net, prior_params):
    """Draw one flattened parameter vector from the Gaussian SMC prior.

    For every named parameter of ``net`` the mean is ``prior_params[name]``
    (moved to ``device``) and the standard deviation is ``sigma_b_smc`` when the
    name contains "bias", ``sigma_w_smc`` otherwise. The draws are flattened and
    concatenated in ``net.named_parameters()`` order.
    """
    pieces = []
    for name, param in net.named_parameters():
        center = prior_params[name].to(device)
        std = sigma_b_smc if "bias" in name else sigma_w_smc
        draw = center + torch.randn(param.shape, device=device) * std
        pieces.append(draw.view(-1))
    return torch.cat(pieces)

def model_loss_func_ll(output, y, temp):
    """Return ``temp`` times the cross-entropy of ``output`` against ``y``, summed over samples.

    ``y`` is cast to integer class indices and flattened before the loss is taken.
    """
    targets = y.long().view(-1)
    summed_nll = nn.CrossEntropyLoss(reduction='sum')(output, targets)
    return summed_nll * temp

def model_loss_func(params, x, y, temp, net, prior_params):
    """
    Negative log joint probability (potential energy) using a Gaussian prior
    centered at the MAP parameters (prior_params).

    Args:
        params: ``{name: tensor}`` dict of the parameters to evaluate.
        x, y: inputs and labels the likelihood term is evaluated on.
        temp: factor multiplying the likelihood term.
        net: module evaluated with ``params`` via ``functional_call``.
        prior_params: ``{name: tensor}`` prior means; must contain every key
            of ``params``.

    Returns:
        Gaussian penalty of ``params`` around ``prior_params`` plus ``temp``
        times the summed cross-entropy.
    """
    neg_log_prior = 0.0
    # Walk over every entry of ``params`` instead of a fixed key list, so the
    # penalty also covers networks whose parameters are not named "fc.*".
    for name, param_tensor in params.items():
        sigma = sigma_b_smc if "bias" in name else sigma_w_smc
        neg_log_prior += ((param_tensor - prior_params[name])**2).sum() / (2 * sigma**2)
    logits = functional_call(net, params, x)
    neg_log_likelihood = model_loss_func_ll(logits, y, temp)
    return neg_log_prior + neg_log_likelihood

############################################
#    HMC & SMC Functions with Prior Incorporation
############################################
def hmc_update_particle(net, x_train, y_train, params_init, num_samples, warmup_steps, step_size, num_steps, temp, prior_params):
    """Run a fixed-step HMC chain from ``params_init`` and return its last sample.

    The potential energy is ``model_loss_func`` evaluated on the unflattened
    parameters with likelihood factor ``temp``. Step-size and mass-matrix
    adaptation are both switched off.

    Returns:
        ``(last_sample, acc_rate)``. ``acc_rate`` is the float of the last value
        the kernel's ``logging()`` reported under "acc. prob", or 0.0 if none
        was captured.
    """
    def potential_fn(params_dict):
        named = unflatten_params(params_dict["params"], net)
        return model_loss_func(named, x_train, y_train, temp, net, prior_params)

    recorded_acc = []

    def capture_diagnostics(kernel, params, stage, i):
        # Called by MCMC through hook_fn; stores the kernel's reported acceptance value.
        stats = kernel.logging()
        if "acc. prob" in stats:
            recorded_acc.append(stats["acc. prob"])

    kernel_cfg = dict(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        HMC(**kernel_cfg),
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=capture_diagnostics,
    )
    sampler.run()
    chain = sampler.get_samples()["params"]
    last_acc = recorded_acc[-1] if recorded_acc else 0.0
    return chain[-1], float(last_acc)

def update_particle(params, net, x_train, y_train, step_size, L, temp, M, prior_params, node):
    """Mutate one particle with HMC and re-evaluate its untempered loss.

    All RNGs (torch, numpy, random, pyro) are seeded with ``node`` first. The
    chain draws ``M`` samples with ``L`` leapfrog steps each and no warm-up.

    Returns:
        ``(updated_params, loss, acc_rate)`` where ``loss`` is the summed
        cross-entropy on the training data at temp=1.
    """
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed, pyro.set_rng_seed):
        seed_rng(node)

    new_flat, acc_rate = hmc_update_particle(
        net, x_train, y_train, params,
        num_samples=M,
        warmup_steps=0,
        step_size=step_size,
        num_steps=L,
        temp=temp,
        prior_params=prior_params,
    )

    logits = functional_call(net, unflatten_params(new_flat, net), x_train)
    nll = model_loss_func_ll(logits, y_train, temp=1)
    if torch.is_tensor(nll):
        nll = nll.item()
    return new_flat, nll, acc_rate

def compute_accuracy(net, particles, x_val, y_val):
    """Return the accuracy of the particle ensemble on ``(x_val, y_val)``.

    The logits of all particles are averaged and the arg-max of the average is
    compared with ``y_val``; the result is the fraction of matches.
    """
    with torch.no_grad():
        per_particle = [
            functional_call(net, unflatten_params(flat, net), x_val)
            for flat in particles
        ]
        mean_logits = torch.stack(per_particle).mean(dim=0)
        hits = (mean_logits.argmax(dim=1) == y_val.view(-1)).sum().item()
        return hits / y_val.size(0)

def _tempering_weights(llike, temp_increment):
    """Return ``(lwhat, lmax, w, ess)`` for raising the temperature by ``temp_increment``.

    ``lwhat`` are the unnormalised log-weights, ``lmax`` their maximum, ``w`` the
    normalised weights and ``ess`` the effective sample size ``1 / sum(w**2)``.
    """
    lwhat = -temp_increment * llike
    lmax = np.max(lwhat)
    w = np.exp(lwhat - lmax)
    w /= np.sum(w)
    ess = 1.0 / np.sum(w**2)
    return lwhat, lmax, w, ess


def _systematic_resample(w, N):
    """Return ``N`` ancestor indices drawn from weights ``w`` by systematic resampling."""
    cumulative_w = np.cumsum(w)
    # Round-off can leave the last cumulative weight slightly below 1, which
    # would push the search past the end of the array.
    cumulative_w[-1] = 1.0
    positions = (np.arange(N) + np.random.uniform(0, 1)) / N
    # First index j with positions[i] < cumulative_w[j].
    ancestors = np.searchsorted(cumulative_w, positions, side='right')
    return np.minimum(ancestors, len(w) - 1)


def psmc_single(node, numwork, x_train, y_train, x_val, y_val, NetClass, N, step_size, L, M, trajectory, prior_params):
    """Run one tempered SMC sampler with HMC mutations.

    Args:
        node: seed for all RNGs of this run.
        numwork: number of worker processes used for the mutation step.
        x_train, y_train: data the likelihood is evaluated on.
        x_val, y_val: data the final predictions and accuracy are computed on.
        NetClass: zero-argument callable building the network.
        N: number of particles.
        step_size: initial HMC step size (adapted between mutation attempts).
        L: initial number of leapfrog steps.
        M: number of HMC samples per mutation.
        trajectory: target trajectory length; ``L`` is reset to
            ``int(trajectory / step_size)`` clipped to [1, 100] after every attempt.
        prior_params: ``{name: tensor}`` prior means.

    Returns:
        ``(hmc_time, predictions, particles_flat, ZZ, KK, count, Ls)``:
        wall-clock seconds, per-particle logits on ``x_val`` as an array,
        the particles as an ``(N, d)`` array, the accumulated
        ``log(mean(exp(lwhat - lmax)))`` and ``lmax`` terms, the number of
        tempering steps, and the accumulated ``L * M`` per step.
    """
    print(f"\nNode: {node}")
    torch.manual_seed(node)
    np.random.seed(node)
    random.seed(node)
    pyro.set_rng_seed(node)

    net = NetClass().to(device)
    net.eval()

    # Particles are created on ``device``; the data must live there as well or
    # every forward pass fails with a device mismatch when a GPU is used.
    x_train, y_train = x_train.to(device), y_train.to(device)
    x_val, y_val = x_val.to(device), y_val.to(device)

    d = sum(p.numel() for p in net.parameters())

    t_start = time.time()
    particles = []
    llike = []
    with torch.no_grad():
        for _ in range(N):
            params_init = init_particle(net, prior_params)
            particles.append(params_init)
            output = functional_call(net, unflatten_params(params_init, net), x_train)
            llike.append(model_loss_func_ll(output, y_train, temp=1).item())
    llike = np.array(llike)

    tempcurr = 0.0
    ZZ = 0.0
    KK = 0.0
    count = 0
    Ls = 0
    step_size_cur = step_size

    # One pool serves every mutation phase instead of being rebuilt per attempt.
    with mp.Pool(processes=numwork) as pool:
        while tempcurr < 1.0:
            # Try to jump straight to temperature 1; halve the increment
            # until the effective sample size is at least N / 2.
            temp_increment = 1.0 - tempcurr
            lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
            halved = False
            while ess < N / 2:
                temp_increment /= 2.0
                halved = True
                lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)

            # Without halving the target is exactly 1.0 (avoids a round-off
            # sliver that would trigger one more, near-empty step).
            proposed_temp = tempcurr + temp_increment if halved else 1.0
            if proposed_temp >= 1.0 and count == 0:
                # The very first step is capped at 0.5. Weights and evidence
                # terms must be recomputed for the increment actually taken.
                proposed_temp = 0.5
                temp_increment = proposed_temp - tempcurr
                lwhat, lmax, w, ess = _tempering_weights(llike, temp_increment)
            print(f"Current temperature: {tempcurr:.4f}, dT = {temp_increment:.4f}, Step = {count}")
            ZZ += np.log(np.mean(np.exp(lwhat - lmax)))
            KK += lmax
            Ls += L * M

            # Systematic resampling.
            ancestors = _systematic_resample(w, N)
            particles = [particles[j] for j in ancestors]
            llike = llike[ancestors]

            # Mutation: repeat with an adjusted step size until the mean
            # acceptance lies in [0.4, 0.95]. ``particles``/``llike`` are only
            # rebound on success, so a failed attempt restarts from the
            # resampled set.
            mutation_success = False
            while not mutation_success:
                mutation_nodes = [random.randint(0, 2**10 - 1) for _ in range(len(particles))]
                results = pool.starmap(
                    update_particle,
                    [(params, net, x_train, y_train, step_size_cur, L, proposed_temp, M, prior_params, seed)
                     for params, seed in zip(particles, mutation_nodes)]
                )
                updated_particles, updated_llike, acc_rates = zip(*results)
                overall_acc = np.mean(acc_rates)
                print(f"Overall acceptance for mutation phase: {overall_acc:.4f}, Stepsize = {step_size_cur:.6f}, L = {L}")
                if overall_acc < 0.4 or overall_acc > 0.95:
                    if overall_acc < 0.4:
                        step_size_cur *= 0.7
                    else:
                        step_size_cur *= 1.1
                    print(f"Unsatisfactory acceptance. Restarting mutation phase with step size {step_size_cur:.6f}, L = {L}")
                else:
                    mutation_success = True
                    if overall_acc < 0.6:
                        step_size_cur *= 0.7
                    elif overall_acc > 0.8:
                        step_size_cur *= 1.1
                    # Copy into memory owned by this process: tensors returned
                    # by a worker may be backed by memory that worker owns and
                    # must stay usable after the pool is shut down.
                    particles = [p.detach().clone() for p in updated_particles]
                    llike = np.array(updated_llike)
                L = min(max(1, int(trajectory / step_size_cur)), 100)
            tempcurr = proposed_temp
            count += 1

    # Compute predictions on validation data.
    with torch.no_grad():
        preds = [functional_call(net, unflatten_params(params, net), x_val) for params in particles]

    accuracy = compute_accuracy(net, particles, x_val, y_val)
    print(f"\nValidation accuracy: {accuracy*100:.2f}%")
    hmc_time = time.time() - t_start
    print(f"HMC/SMC Total execution time: {hmc_time:.2f} seconds")

    particles_flat = torch.stack(particles).detach().cpu().numpy().reshape(N, d)
    # Stacked logits already have shape (N, number of validation samples, classes).
    predictions = torch.stack(preds).detach().cpu().numpy()

    return hmc_time, predictions, particles_flat, ZZ, KK, count, Ls

############################################
#              Main Execution
############################################
if __name__ == '__main__':
    # Set multiprocessing start method.
    mp.set_start_method('spawn', force=True)

    # Command line: <node> seeds every RNG and tags the output file,
    # <numwork> is the number of worker processes used by psmc_single.
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <node> <numwork>")
    node = int(sys.argv[1])
    numwork = int(sys.argv[2])

    print(f'scale is {scale} and sigma is {sigma_b_smc} and sigma_v is {sigma_b_map}')

    ############################################
    # Load In-Domain Embeddings (Classes 0-9)
    ############################################
    allowed_labels = list(range(10))
    train_cache = "cifar10_train_embeddings.pt"
    test_cache  = "cifar10_test_embeddings.pt"
    X_train, y_train, X_test, y_test = create_resnet50_embedded_cifar10_dataset(
         train_cache_path=train_cache,
         test_cache_path=test_cache,
         allowed_labels=allowed_labels
    )
    print(f"Embedding dimension: {X_train.shape[1]}")

    n_total = X_train.size(0)        # should be 50 000 for CIFAR-10
    print(f"Total CIFAR-10 train embeddings: {n_total}")

    # MAP training uses the first 50 000 embeddings, i.e. the whole CIFAR-10
    # training set. Nothing is held out: training always runs for max_epochs
    # (no early stopping), so a validation split would be unused.
    n_train_es = 50000
    X_train_subset = X_train[:n_train_es]
    y_train_subset = y_train[:n_train_es]

    train_dataset = TensorDataset(X_train_subset, y_train_subset)
    batch_size = 128
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    torch.manual_seed(node)
    np.random.seed(node)
    random.seed(node)
    pyro.set_rng_seed(node)

    ############################################
    # MAP Training on In-Domain Data (SimpleMLP)
    ############################################
    model_mlp = SimpleMLP(input_dim=X_train_subset.shape[1], hidden_dim=0, num_classes=10).to(device)
    optimizer_mlp = optim.Adam(model_mlp.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    max_epochs = 200
    n_train = len(train_loader.dataset)

    start_time = time.time()
    model_mlp.train()
    for _ in range(max_epochs):
        for inputs, labels in train_loader:                   # iterate mini-batches
            # The model lives on ``device``; the batches must be moved there too.
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer_mlp.zero_grad()
            outputs = model_mlp(inputs)
            ce_loss = criterion(outputs, labels)

            # MAP regularizer on this batch
            reg_loss = (torch.sum(model_mlp.fc.weight**2) / (2 * sigma_w_map**2) +
                        torch.sum(model_mlp.fc.bias**2)   / (2 * sigma_b_map**2))
            # ce_loss is a per-sample mean, so the regularizer is divided by
            # the total train size to put both terms on the same scale.
            loss = ce_loss + reg_loss / n_train
            loss.backward()
            optimizer_mlp.step()
    total_map_time = time.time() - start_time
    print(f"\nMAP Total execution time: {total_map_time:.2f} seconds")

    # Extract MAP parameters (multiplied by ``scale``) to serve as prior for SMC.
    prior_params = {name: scale * value for name, value in model_mlp.state_dict().items()}
    print("Prior parameters extracted from trained MAP model.")

    ############################################
    # Run SMC Inference with MAP Prior
    ############################################
    N = 10         # number of particles
    step_size = 0.07
    trajectory = 0.0
    L = min(max(1, int(trajectory / step_size)), 100)
    M = 4        # number of HMC samples per mutation

    start_time = time.time()
    t_elapsed, predictions, particles_flat, ZZ, KK, count, Ls = psmc_single(
        node=node,
        numwork=numwork,
        # Particles are created on ``device``, so the data is handed over there too.
        x_train=X_train.to(device),
        y_train=y_train.to(device),
        x_val=X_test.to(device),
        y_val=y_test.to(device),
        NetClass=SimpleMLP,
        N=N,
        step_size=step_size,
        L=L,
        M=M,
        trajectory=trajectory,
        prior_params=prior_params
    )
    total_smc_time = time.time() - start_time
    print(f"SMC Total execution time: {total_smc_time:.2f} seconds")

    # Compute SMC validation accuracy from the returned particle array.
    eval_net = SimpleMLP(input_dim=X_train.shape[1], hidden_dim=0, num_classes=10).to(device)
    smc_accuracy = compute_accuracy(
        eval_net,
        [torch.tensor(p).to(device) for p in particles_flat],
        X_test.to(device),
        y_test.to(device)
    )
    print(f"SMC Validation Accuracy: {smc_accuracy*100:.2f}%")
    print(f"SMC Lsum: {Ls}")

    ############################################
    # Save SMC Results and Analysis Metrics
    ############################################
    savemat(f'BayesianNN_CIFAR_MAP_psmc_SimpleMLP_N{N}_M{M}_node{node}.mat', {
        'psmc_single_time': t_elapsed,
        'psmc_single_pred': predictions,
        'psmc_single_x': particles_flat,
        'Z': ZZ,
        'K': KK,
        'count': count,
        'Lsum': Ls,
        'map_single_time': total_map_time,
        'smc_validation_accuracy': smc_accuracy
    })
