#!/usr/bin/env python3
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
import torch.optim as optim
from scipy.io import savemat
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC
from torch.func import functional_call
import sys

# For F1 and AUC-PR metrics.
from sklearn.metrics import f1_score, average_precision_score
from sklearn.preprocessing import label_binarize

############################################
# Device, plus the N(0, v) prior / initialization standard deviations (v = 0.1)
############################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One shared std sqrt(0.1) for biases, conv weights and fc weights.
sigma_b = sigma_conv = sigma_fc = np.sqrt(0.1)

############################################
# Standard deviations sqrt(s), s = 0.01, of the N(MAP, s) prior used in HMC
############################################
sigma_b_s = sigma_conv_s = sigma_fc_s = np.sqrt(0.01)

# Despite the name, the values are standard deviations (they are squared
# where the Gaussian penalty is formed).
s_variances = dict(
    sigma_b=sigma_b_s,
    sigma_conv=sigma_conv_s,
    sigma_fc=sigma_fc_s,
)

############################################
# Filtered MNIST Dataset (only digits 0-7)
############################################
class FilteredDataset(Dataset):
    """Eagerly materialized subset of `dataset` keeping only allowed labels.

    Every `(img, label)` sample of `dataset` is iterated once at construction
    time; samples whose `label` is contained in `allowed_labels` are stored
    in order in `self.data`.
    """

    def __init__(self, dataset, allowed_labels):
        kept = []
        for img, label in dataset:
            # Membership test uses `in` on the caller's container, unchanged.
            if label in allowed_labels:
                kept.append((img, label))
        self.data = kept

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

############################################
# Simple CNN Architecture (Filtered MNIST: 8 classes)
############################################
class SimpleCNN(nn.Module):
    """Conv(1->4, 3x3, pad 1) + ReLU + 2x2 max-pool + Linear(4*14*14 -> 8).

    The fc input size implies 1x28x28 inputs. All weights and biases are
    re-drawn from zero-mean normals using the module-level `sigma_conv`,
    `sigma_fc` (weights) and `sigma_b` (biases).
    """

    def __init__(self):
        super().__init__()
        # Construction order (conv, pool, fc) is kept so the default
        # initializers consume the global RNG in the same sequence.
        self.conv = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc   = nn.Linear(4 * 14 * 14, 8)

        # Draw order conv.weight, conv.bias, fc.weight, fc.bias is preserved
        # so a given seed yields the same initial parameters.
        for layer, weight_std in ((self.conv, sigma_conv), (self.fc, sigma_fc)):
            nn.init.normal_(layer.weight, mean=0, std=weight_std)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, mean=0, std=sigma_b)

    def forward(self, x):
        pooled = self.pool(F.relu(self.conv(x)))
        # Flatten all non-batch dimensions before the linear head.
        return self.fc(pooled.view(pooled.size(0), -1))

############################################
# Helper Functions: parameter (un)flattening, HMC potential, sampling, metrics
############################################
def unflatten_params(flat, net):
    """Split a flat parameter vector into a `{name: tensor}` dict shaped like `net`.

    Slices of `flat` are taken consecutively in `net.named_parameters()`
    order and viewed with each parameter's shape (views share storage with
    `flat`). Trailing entries of `flat` beyond the total parameter count are
    ignored.
    """
    shaped = {}
    offset = 0
    for pname, template in net.named_parameters():
        end = offset + template.numel()
        shaped[pname] = flat[offset:end].view(template.shape)
        offset = end
    return shaped

def init_particle_from_prior(prior_params, s_variances, device):
    """Draw one flat parameter vector from N(prior, sigma^2) per tensor.

    Args:
        prior_params: ordered mapping name -> mean tensor (e.g. a state_dict).
        s_variances: dict with 'sigma_b', 'sigma_conv', 'sigma_fc' standard
            deviations; names not listed below fall back to sigma = 1.0.
        device: device the means are moved to before sampling.

    Returns:
        1-D tensor: the concatenated draws, in `prior_params` iteration order.
    """
    # Parameter name -> key into s_variances.
    sigma_key_for = {
        "conv.bias": 'sigma_b',
        "fc.bias": 'sigma_b',
        "conv.weight": 'sigma_conv',
        "fc.weight": 'sigma_fc',
    }
    pieces = []
    for name, param in prior_params.items():
        key = sigma_key_for.get(name)
        scale = 1.0 if key is None else s_variances[key]
        center = param.to(device)
        pieces.append((torch.randn_like(center) * scale + center).view(-1))
    return torch.cat(pieces)

def model_loss_func_ll(output, y, temp):
    """Return `temp` times the summed cross-entropy of logits `output` vs labels `y`.

    `y` is cast to int64 and flattened to 1-D before the loss is computed.
    """
    targets = y.long().view(-1)
    return F.cross_entropy(output, targets, reduction='sum') * temp

def model_loss_func(params, x, y, temp, net, prior_params, s_variances):
    """Return the HMC potential: Gaussian penalty around the prior plus tempered CE.

    The penalty is sum((theta - mu)^2 / (2 sigma^2)) over the four SimpleCNN
    tensors, i.e. the negative log of the N(mu, sigma^2) prior up to a
    constant. The data term is `model_loss_func_ll` on the logits from
    running `net` with `params` substituted in.

    Args:
        params: dict name -> tensor with the four keys listed below.
        x, y: inputs and labels.
        temp: multiplier applied to the summed cross-entropy only.
        net: module whose forward is used via `functional_call`.
        prior_params: dict name -> prior-mean tensor.
        s_variances: dict of standard deviations ('sigma_b', 'sigma_conv',
            'sigma_fc').
    """
    sigma_key_for = {
        "conv.weight": 'sigma_conv',
        "conv.bias": 'sigma_b',
        "fc.weight": 'sigma_fc',
        "fc.bias": 'sigma_b',
    }
    neg_log_prior = 0.0
    for key in ("conv.weight", "conv.bias", "fc.weight", "fc.bias"):
        deviation = params[key] - prior_params[key]
        sd = s_variances[sigma_key_for[key]]
        neg_log_prior += (deviation ** 2).sum() / (2 * sd ** 2)
    data_term = model_loss_func_ll(functional_call(net, params, x), y, temp)
    return neg_log_prior + data_term

def compute_accuracy(net, particles, x_val, y_val):
    """Return the accuracy of the particle ensemble on (x_val, y_val).

    Each flat particle is unflattened into `net`'s parameter layout, its
    softmax probabilities are averaged across particles, and the argmax of
    the averaged probabilities is compared with `y_val` (flattened).
    """
    with torch.no_grad():
        member_probs = torch.stack([
            F.softmax(functional_call(net, unflatten_params(flat, net), x_val), dim=1)
            for flat in particles
        ])
        predicted = member_probs.mean(dim=0).argmax(dim=1)
        n_correct = predicted.eq(y_val.view(-1)).sum().item()
    return n_correct / y_val.size(0)

def softmax_np(x):
    """Numerically stable softmax of a NumPy array along its last axis."""
    # Subtracting the row max avoids overflow in exp without changing the result.
    exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
    totals = np.sum(exps, axis=-1, keepdims=True)
    return exps / totals

def hmc_update_particle(net, x_train, y_train, params_init,
                        num_samples, warmup_steps, step_size,
                        num_steps, temp, prior_params, s_variances):
    """Run a fixed-step Pyro HMC chain from `params_init` and return its last state.

    The potential is `model_loss_func` evaluated on the full (x_train,
    y_train) batch. Step size and mass matrix adaptation are disabled.

    Returns:
        (last_sample, acceptance): the final sampled flat parameter vector and
        the last value the kernel's `logging()` reported under "acc. prob",
        converted with `float()` (presumably Pyro reports it as a formatted
        string — confirm against the installed Pyro version).
    """
    # Acceptance statistic captured after every chain step via hook_fn.
    reported_acc = []

    def record_acceptance(kernel, params, stage, i):
        stats = kernel.logging()
        if "acc. prob" in stats:
            reported_acc.append(stats["acc. prob"])

    def potential_fn(params_dict):
        shaped = unflatten_params(params_dict["params"], net)
        return model_loss_func(shaped, x_train, y_train, temp, net,
                               prior_params, s_variances)

    # NOTE(review): target_accept_prob is presumably only consulted by
    # step-size adaptation, which is off here.
    kernel = HMC(
        potential_fn=potential_fn,
        step_size=step_size,
        num_steps=num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False,
        target_accept_prob=0.65,
    )
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        initial_params={"params": params_init},
        disable_progbar=True,
        hook_fn=record_acceptance,
    )
    sampler.run()
    chain = sampler.get_samples()["params"]
    last_acc = reported_acc[-1]
    return chain[-1], float(last_acc)

############################################
# Main Execution: MAP Training then HMC Sampling, Metrics, and Saving
############################################
if __name__ == '__main__':
    # Single required CLI argument: an integer seed, also used as the "node"
    # tag in the output file names.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} SEED")
    seed = int(sys.argv[1])

    ############################################
    # Data Loading and Filtering (Filtered MNIST: digits 0–7)
    ############################################
    N_total_train = 1200
    N_tr          = 1000
    N_val         = 200

    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    full_test_dataset  = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    allowed_labels      = list(range(8))
    filtered_train_pool = FilteredDataset(full_train_dataset, allowed_labels)
    filtered_test_pool  = FilteredDataset(full_test_dataset,  allowed_labels)

    # MAP train/val splits over the first N_total_train (= N_tr + N_val)
    # samples of the filtered training pool.
    filtered_train_total = Subset(filtered_train_pool, list(range(N_total_train)))
    map_train_dataset    = Subset(filtered_train_total, list(range(N_tr)))
    map_val_dataset      = Subset(
                              filtered_train_total,
                              list(range(N_tr, N_tr + N_val))
                          )

    train_loader = DataLoader(map_train_dataset, batch_size=64, shuffle=True)
    val_loader   = DataLoader(map_val_dataset,   batch_size=64, shuffle=False)

    # HMC conditions on all N_total_train samples (MAP train + MAP val),
    # loaded as one full batch.
    train_loader_full = DataLoader(
        filtered_train_total,
        batch_size=len(filtered_train_total),
        shuffle=False
    )
    x_train, y_train = next(iter(train_loader_full))
    x_train, y_train = x_train.to(device), y_train.to(device)

    ############################################
    # Prepare ID-Test Set (first 1000 of filtered test pool)
    ############################################
    test_id_dataset  = Subset(filtered_test_pool, list(range(1000)))
    test_loader_full = DataLoader(
        test_id_dataset,
        batch_size=len(test_id_dataset),
        shuffle=False
    )
    x_test, y_test = next(iter(test_loader_full))
    x_test, y_test = x_test.to(device), y_test.to(device)

    # Seed every RNG used from here on (MAP init, shuffling, HMC). Kept at
    # this position so results stay comparable with earlier runs.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    pyro.set_rng_seed(seed)

    ############################################
    # MAP Training (early stopping on map_val_dataset)
    ############################################
    model_cnn = SimpleCNN().to(device)
    optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=0.001)
    criterion_cnn = nn.CrossEntropyLoss()

    d = sum(p.numel() for p in model_cnn.parameters())
    print(f"The dimension of the parameters in SimpleCNN is {d}")

    def _map_penalty(model):
        """Return the N(0, sigma^2) Gaussian penalty on all SimpleCNN parameters.

        This is the negative log-prior up to a constant; callers divide it by
        the training-set size before adding it to the mean cross-entropy.
        """
        reg_c = torch.sum(model.conv.weight**2)/(2*sigma_conv**2) + \
                torch.sum(model.conv.bias**2)/(2*sigma_b**2)
        reg_f = torch.sum(model.fc.weight**2)/(2*sigma_fc**2) + \
                torch.sum(model.fc.bias**2)/(2*sigma_b**2)
        return reg_c + reg_f

    train_losses = []
    val_losses = []
    moving_avg_window = 10
    best_moving_avg = float('inf')
    patience = 5
    no_improve_count = 0

    start_time = time.time()

    for epoch in range(1000):
        running_loss = 0.0
        model_cnn.train()
        for imgs, labs in train_loader:
            optimizer_cnn.zero_grad()
            outputs = model_cnn(imgs.to(device))
            ce_loss = criterion_cnn(outputs, labs.to(device))
            loss = ce_loss + _map_penalty(model_cnn)/len(map_train_dataset)
            loss.backward()
            optimizer_cnn.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        # validation (same objective, no parameter updates)
        val_running = 0.0
        model_cnn.eval()
        with torch.no_grad():
            for imgs, labs in val_loader:
                outputs = model_cnn(imgs.to(device))
                ce_loss = criterion_cnn(outputs, labs.to(device))
                loss = ce_loss + _map_penalty(model_cnn)/len(map_train_dataset)
                val_running += loss.item()
        val_loss = val_running / len(val_loader)
        val_losses.append(val_loss)

        print(f"MAP Epoch {epoch+1}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}")

        # Early stopping: once a full window of epochs exists, stop after
        # `patience` consecutive epochs without a new best moving average of
        # the validation loss. The final (not best) weights are kept.
        if epoch >= moving_avg_window - 1:
            moving_avg = sum(val_losses[-moving_avg_window:]) / moving_avg_window
            if moving_avg < best_moving_avg:
                best_moving_avg = moving_avg
                no_improve_count = 0
            else:
                no_improve_count += 1
            if no_improve_count >= patience:
                print("MAP: Early stopping due to no improvement.")
                break

    total_map_time = time.time() - start_time
    print(f"MAP Total execution time: {total_map_time:.2f} seconds")

    # Extract the MAP model parameters to serve as the prior mean.
    prior_params = model_cnn.state_dict()
    print("Prior parameters extracted from the trained MAP model.")

    ############################################
    # HMC Sampling Using MAP prior
    ############################################
    step_size = 0.012
    # L = int(trajectory/step_size) clamped to [1, 100]; with trajectory = 0.0
    # every proposal uses a single leapfrog step.
    trajectory = 0.0
    L = min(max(1, int(trajectory/step_size)), 100)
    N_samples    = 1
    warmup_steps = 0
    burnin_all   = 160
    thin_all     = burnin_all
    basic        = 10                 # HMC transitions per outer iteration
    thin         = int(thin_all/basic)
    burnin       = thin

    net = SimpleCNN().to(device)
    net.eval()
    params = init_particle_from_prior(prior_params, s_variances, device)

    particles = []
    acc_rates = []
    Lsum = 0

    start_hmc = time.time()
    for i in range(burnin + (N_samples-1)*thin):
        params, acc = hmc_update_particle(
            net, x_train, y_train, params,
            num_samples=basic, warmup_steps=warmup_steps,
            step_size=step_size, num_steps=L, temp=1,
            prior_params=prior_params, s_variances=s_variances
        )
        acc_rates.append(acc)
        avg_acc = np.mean(acc_rates)
        print(f"Iteration {i+1}: Avg. Acc. = {avg_acc:.4f}, step_size = {step_size:.5f}, L = {L}")
        # Crude step-size control on the running mean acceptance.
        if avg_acc < 0.6:
            step_size *= 0.7
        elif avg_acc > 0.8:
            step_size *= 1.1

        # Keep a particle after burn-in and then every `thin` iterations.
        if i >= burnin-1 and ((i - burnin + 1) % thin == 0):
            particles.append(params)
        Lsum += L * basic
        L = min(max(1, int(trajectory/step_size)), 100)
    hmc_single_time = time.time() - start_hmc

    # Convert particles to numpy: (n_particles, d).
    d = sum(p.numel() for p in net.parameters())
    particles_tensor = torch.stack(particles).detach().cpu()
    hmc_single_x = particles_tensor.numpy().reshape(len(particles), d)

    # Per-particle ID-test logits: (n_particles, len(x_test), 8).
    with torch.no_grad():
        preds = [functional_call(net, unflatten_params(p, net), x_test)
                 for p in particles]
    hmc_single_pred = torch.stack(preds).detach().cpu().numpy()

    # Compute ID-test accuracy of the ensemble.
    accuracy = compute_accuracy(net, particles, x_test, y_test)
    print("ID-Test accuracy (ensemble):", accuracy)

    # Save HMC results with original variable names
    savemat(
        f'BayesianNN_MNIST_hmc_results_d{d}_train{N_tr}_val{N_val}_thin{thin_all}_burnin{burnin_all}_node{seed}.mat',
        {
            'hmc_single_time': hmc_single_time,
            'hmc_single_pred': hmc_single_pred,
            'hmc_single_x':    hmc_single_x,
            'Lsum':            Lsum
        }
    )
    print("HMC results saved.")

    ############################################
    # Compute Metrics (ID-Test and OOD)
    ############################################
    # In-Domain (ID-test): per-particle probabilities and their average.
    part_probs = softmax_np(hmc_single_pred)        # (n_particles, 1000, 8)
    ensemble_probs = part_probs.mean(axis=0)        # (1000, 8)
    y_test_np = y_test.cpu().numpy().flatten()
    nlls = -np.log(ensemble_probs[np.arange(len(y_test_np)), y_test_np] + 1e-12)
    # Mean over test points, matching the variable name and "Average NLL" label
    # (previously np.sum, which reported the total NLL).
    avg_nll_in = np.mean(nlls)
    total_ent = -np.sum(ensemble_probs * np.log(ensemble_probs + 1e-12), axis=1)
    avg_total_entropy_in = np.mean(total_ent)
    part_ent = -np.sum(part_probs * np.log(part_probs + 1e-12), axis=2)
    avg_part_ent = np.mean(part_ent, axis=0)
    # Epistemic part: total entropy minus mean per-particle entropy.
    avg_epistemic_in = np.mean(total_ent - avg_part_ent)
    f1_in = f1_score(y_test_np, ensemble_probs.argmax(axis=1), average='macro', zero_division=0)
    y_bin = label_binarize(y_test_np, classes=np.arange(8))
    aucpr_in = average_precision_score(y_bin, ensemble_probs, average='macro')

    # Out-Of-Domain: first 100 test images of digit 8 and of digit 9.
    od_idxs = []
    c8 = c9 = 0
    for idx, (img, lbl) in enumerate(full_test_dataset):
        if lbl == 8 and c8 < 100:
            od_idxs.append(idx); c8 += 1
        elif lbl == 9 and c9 < 100:
            od_idxs.append(idx); c9 += 1
        if c8 == 100 and c9 == 100:
            break
    od_dataset = Subset(full_test_dataset, od_idxs)
    od_loader  = DataLoader(od_dataset, batch_size=len(od_dataset), shuffle=False)
    od_imgs, od_lbls = next(iter(od_loader))
    od_imgs = od_imgs.to(device)
    with torch.no_grad():
        preds_od = [functional_call(net, unflatten_params(p, net), od_imgs)
                    for p in particles]
    preds_od = torch.stack(preds_od).detach().cpu().numpy()
    part_od = softmax_np(preds_od)
    ens_od = part_od.mean(axis=0)
    total_ent_od = -np.sum(ens_od * np.log(ens_od + 1e-12), axis=1)
    avg_total_entropy_od = np.mean(total_ent_od)
    part_ent_od = -np.sum(part_od * np.log(part_od + 1e-12), axis=2)
    avg_part_ent_od = np.mean(part_ent_od, axis=0)
    avg_epistemic_od = np.mean(total_ent_od - avg_part_ent_od)
    # NOTE: the 8-class head can only predict 0-7 while OOD labels are 8/9, so
    # this F1 is 0 by construction; kept for output-format compatibility.
    f1_od = f1_score(od_lbls.numpy(), ens_od.argmax(axis=1), average='macro', zero_division=0)
    aucpr_od = 0.0  # placeholder: not computed for OOD

    print("\nID-Test metrics:")
    print(f"Average NLL: {avg_nll_in:.4f}")
    print(f"Average Total Entropy: {avg_total_entropy_in:.4f}")
    print(f"Average Epistemic: {avg_epistemic_in:.8f}")
    print(f"F1 Score (macro): {f1_in:.4f}")
    print(f"AUC-PR (macro): {aucpr_in:.4f}")

    print("\nOOD metrics:")
    print(f"Average Total Entropy: {avg_total_entropy_od:.4f}")
    print(f"Average Epistemic: {avg_epistemic_od:.8f}")
    print(f"F1 Score (macro): {f1_od:.4f}")
    print(f"AUC-PR (macro): {aucpr_od:.4f}")

    # Save metrics with original variable names.
    metrics = {
        'avg_nll_in':           avg_nll_in,
        'avg_total_entropy_in': avg_total_entropy_in,
        'avg_epistemic_in':     avg_epistemic_in,
        'f1_in':                f1_in,
        'aucpr_in':             aucpr_in,
        'avg_total_entropy_od': avg_total_entropy_od,
        'avg_epistemic_od':     avg_epistemic_od,
        'f1_od':                f1_od,
        'aucpr_od':             aucpr_od,
        # Key name kept for compatibility; the value is the ID-test ensemble
        # accuracy, not a validation-set accuracy.
        'HMC_Validation_Accuracy': accuracy,
        'total_hmc_time':       hmc_single_time,
        'Lsum':                 Lsum,
        'num_particles':        len(particles)
    }
    savemat(
        f'BayesianNN_MNIST_hmc_metrics_d{d}_train{N_tr}_val{N_val}_thin{thin_all}_burnin{burnin_all}_node{seed}.mat',
        metrics
    )
    print("Metrics saved.")
