# variational_pacbayes.py
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.stateless import functional_call  # requires torch >= 2.0
import matplotlib.pyplot as plt

# keep your model imports
from models import BaselineCNN
from dataset import RotatedMNISTDataset
from pacbayes_utils import sweep_sigma_q_and_compute_curve  # still usable after VI

# -------------------------
# Utilities: flatten/set params, KL
# -------------------------
def get_flat_params_from(model, device='cpu', dtype=torch.float32):
    """
    Flatten every parameter of `model` into a single 1-D tensor.

    Parameters are visited in `model.named_parameters()` order. Each one is
    detached, copied to CPU, flattened in logical (row-major) order and cast to
    `dtype` before concatenation.

    Returns (flat, shapes, names):
      flat:   1-D tensor on `device` containing all parameter values
      shapes: list with each parameter's shape, in visiting order
      names:  list with each parameter's name, in visiting order
    A model without parameters yields an empty `flat` and empty lists.
    """
    shapes = []
    names = []
    chunks = []
    for name, p in model.named_parameters():
        shapes.append(p.shape)
        names.append(name)
        # reshape (not view): view(-1) raises on non-contiguous parameters
        # such as transposed or channels_last weights.
        chunks.append(p.detach().cpu().reshape(-1).to(dtype))

    if chunks:
        # clone so the result can never alias the model's own storage
        flat = torch.cat(chunks).clone()
    else:
        # torch.cat rejects an empty list
        flat = torch.empty(0, dtype=dtype)
    return flat.to(device), shapes, names

def unflatten_flat_to_tensors(flat, shapes, names, device='cpu'):
    """
    Split a flattened vector back into a {name: tensor_with_shape} dict.

    flat:   1-D tensor holding all values back to back
    shapes: sequence of shapes (tuples / torch.Size), one per entry
    names:  sequence of names, same length and order as `shapes`
    device: device the vector is moved to before slicing

    The returned tensors are views of slices of `flat` (after the `.to(device)`
    move), so gradients flow back to `flat` if it participates in autograd.
    Zero-dimensional shapes `()` are supported and consume one element.

    Raises ValueError if `shapes` and `names` differ in length, or if the
    shapes do not account for exactly `flat.numel()` elements.
    """
    if len(shapes) != len(names):
        raise ValueError(
            f"shapes and names must have the same length, got {len(shapes)} and {len(names)}"
        )

    flat = flat.to(device)
    # math.prod(()) == 1, which is the right element count for a 0-dim tensor
    sizes = [math.prod(shape) for shape in shapes]
    expected = sum(sizes)
    if expected != flat.numel():
        raise ValueError(
            f"shapes describe {expected} elements but flat vector has {flat.numel()}"
        )

    mapping = {}
    pointer = 0
    for shape, name, numel in zip(shapes, names, sizes):
        # pass the shape as one tuple: view(*shape) breaks when shape is ()
        mapping[name] = flat[pointer: pointer + numel].view(tuple(shape))
        pointer += numel
    return mapping

def kl_diag_gaussians_torch(mu_q, log_sigma_q, mu_p, sigma_p):
    """
    KL(Q||P) between diagonal Gaussians, summed over all dimensions.

    mu_q:        tensor (d,), posterior mean
    log_sigma_q: tensor (d,), log of the posterior std
    mu_p:        tensor broadcastable to (d,) or Python scalar, prior mean
    sigma_p:     tensor broadcastable to (d,) or Python scalar, prior std
    returns:     scalar tensor

    Per-element term:
        log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2
    """
    # Bring both prior quantities onto mu_q's device/dtype; as_tensor accepts
    # Python scalars as well as tensors.
    mu_p = torch.as_tensor(mu_p, dtype=mu_q.dtype, device=mu_q.device)
    sigma_p = torch.as_tensor(sigma_p, dtype=mu_q.dtype, device=mu_q.device)

    var_q = torch.exp(log_sigma_q) ** 2
    log_ratio = torch.log(sigma_p) - log_sigma_q  # log(sigma_p / sigma_q)
    quad = 0.5 * (var_q + (mu_q - mu_p) ** 2) / (sigma_p ** 2)
    return (log_ratio + quad - 0.5).sum()

# negative log likelihood per example for classification (cross-entropy)
def batch_negative_log_likelihood(model, param_dict, xb, yb):
    """
    Batch-averaged cross-entropy of `model` evaluated with substituted parameters.

    model:      nn.Module providing the architecture
    param_dict: mapping name -> tensor, used in place of the module's own parameters
    xb:         batch of inputs
    yb:         batch of labels
    returns:    mean negative log-likelihood over the batch
    """
    # positional inputs are handed to functional_call as a tuple
    outputs = functional_call(model, param_dict, (xb,))
    return nn.functional.cross_entropy(outputs, yb, reduction='mean')

# -------------------------
# Variational Inference training
# -------------------------
def vi_train(model, train_loader, mu_p_flat, sigma_p, device='cpu',
             epochs=10, lr=1e-3, S=1, init_log_sigma=-6.0, verbose=True):
    """
    Fit a diagonal Gaussian posterior over the model's flattened parameters by
    minimizing -ELBO = (n / batch_size) * E_q[batch NLL] + KL(q || p) with Adam.

    - model: nn.Module (the architecture); its current weights initialize mu
    - train_loader: DataLoader yielding (inputs, labels); `len(train_loader.dataset)`
      is used as the dataset size n
    - mu_p_flat: prior mean as a torch tensor with either 1 element or as many
      elements as the model has parameters (it is flattened internally)
    - sigma_p: prior std (scalar or tensor accepted by kl_diag_gaussians_torch)
    - device: device used for the model and the variational parameters
    - epochs: number of passes over train_loader
    - lr: Adam learning rate
    - S: number of reparameterized MC samples averaged per minibatch (>= 1)
    - init_log_sigma: initial value for every entry of log_sigma
    - verbose: print a header and one diagnostics line per epoch

    Returns a 4-tuple (mu, sigma, shapes, names):
      mu:     learned posterior mean, 1-D CPU tensor (detached)
      sigma:  learned posterior std, i.e. exp(log_sigma), 1-D CPU tensor (detached)
      shapes: parameter shapes, as returned by get_flat_params_from
      names:  parameter names, as returned by get_flat_params_from

    Raises ValueError if S < 1, if the dataset is empty, or if mu_p_flat has an
    incompatible number of elements.
    """
    if S < 1:
        raise ValueError(f"S must be >= 1, got {S}")
    n = len(train_loader.dataset)
    if n == 0:
        raise ValueError("train_loader.dataset is empty")

    model = model.to(device)
    # shapes and names are needed to map a flat sample back to named tensors
    init_flat, shapes, names = get_flat_params_from(model, device=device)
    d = init_flat.numel()

    # The prior mean is constant during training: move it to the device once.
    mu_p = mu_p_flat.to(device).reshape(-1)
    if mu_p.numel() not in (1, d):
        raise ValueError(
            f"mu_p_flat has {mu_p.numel()} elements but the model has {d} parameters"
        )

    # variational parameters: mu starts at the current weights, log_sigma at a constant
    mu = nn.Parameter(init_flat.clone().to(device))
    log_sigma = nn.Parameter(torch.full((d,), fill_value=init_log_sigma, device=device, dtype=mu.dtype))

    optimizer = optim.Adam([mu, log_sigma], lr=lr)

    if verbose:
        print(f"VI training: {d} params, dataset size n={n}, epochs={epochs}, S={S}, lr={lr}")

    for epoch in range(epochs):
        running_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            batch_size = xb.shape[0]

            # Monte Carlo estimate of the expected batch NLL under q(theta),
            # averaged over S reparameterized samples.
            sigma = torch.exp(log_sigma)
            mc_nll = 0.0
            for _ in range(S):
                eps = torch.randn_like(mu)
                theta_sample = mu + sigma * eps  # reparameterization trick
                param_dict = unflatten_flat_to_tensors(theta_sample, shapes, names, device=device)
                mc_nll = mc_nll + batch_negative_log_likelihood(model, param_dict, xb, yb)
            mc_nll = mc_nll / float(S)

            # scale the batch mean up to approximate the full-dataset NLL
            scaled_nll = (n / batch_size) * mc_nll

            kl = kl_diag_gaussians_torch(mu, log_sigma, mu_p, sigma_p)

            # variational objective to minimize
            loss = scaled_nll + kl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * batch_size

        epoch_loss = running_loss / n
        if verbose:
            # Diagnostics are recomputed from the current parameters so that
            # mean sigma and KL both describe the state after the last step.
            with torch.no_grad():
                mean_sigma = float(torch.exp(log_sigma).mean().item())
                current_kl = float(kl_diag_gaussians_torch(mu, log_sigma, mu_p, sigma_p).item())
            print(f"Epoch {epoch+1}/{epochs}  loss={epoch_loss:.4f}  mean_sigma={mean_sigma:.4e}  KL={current_kl:.4e}")

    # return learned variational params (detached); note the second item is sigma, not log_sigma
    return mu.detach().cpu(), torch.exp(log_sigma.detach()).cpu(), shapes, names


# -------------------------
# Example usage (main)
# -------------------------
if __name__ == "__main__":
    # NOTE(review): plot_posterior_error_distribution was called below without
    # being imported or defined anywhere in this file. Importing it from
    # pacbayes_utils is an assumption — confirm where the function lives.
    from pacbayes_utils import plot_posterior_error_distribution

    device = 'cpu'  # or 'cuda'
    # Number of MC samples. The original script passed an undefined `S` to the
    # plotting call; the VI value is reused here — adjust for evaluation if needed.
    S = 5

    model = BaselineCNN().to(device)
    model.load_state_dict(torch.load("rotated_mnist/baseline_cnn.pt", map_location=device))  # your initial weights

    # dataset
    train_ds = RotatedMNISTDataset("rotated_mnist/train.pt")
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4)

    # prior (map_location keeps the load working when the file was saved from another device)
    mu_p = torch.load("rotated_mnist/prior_mu_baseline.pt", map_location=device).flatten()  # make sure flattened
    sigma_p = 1.0

    # run VI
    mu_q_flat, sigma_q_vec, shapes, names = vi_train(
        model, train_loader,
        mu_p_flat=mu_p,
        sigma_p=sigma_p,
        device=device,
        epochs=25,
        lr=1e-3,
        S=S,               # you can increase S for lower-variance gradient estimates
        init_log_sigma=-6.0,
        verbose=True
    )

    # mu_q_flat is the learned posterior mean (1D tensor); sigma_q_vec is the
    # per-parameter learned std (1D tensor). For an isotropic posterior, use the
    # scalar sigma_q = sigma_q_vec.mean().item().
    learned_sigma_scalar = float(sigma_q_vec.mean().item())
    print("Learned avg sigma:", learned_sigma_scalar)

    # if you want to use your sweep function (it expects mu_flat as numpy or torch), call it:
    # (you may need to adjust sweep_sigma_q_and_compute_curve to accept mu_flat and sigma_q vector)
    # Example call with scalar sigma:
    # mu_flat_for_sweep = mu_q_flat.numpy()
    # sigmas = np.logspace(-6, 0, 20)
    # results = sweep_sigma_q_and_compute_curve(model, mu_flat_for_sweep, sigma_p, train_loader,
    #                                          device=device, sigmas=sigmas, S=200, n=len(train_ds), delta=0.05)

    # or use plot_posterior_error_distribution by passing the mu and a chosen sigma_q
    # (you will need to adapt your plot function to accept either vector sigma_q or scalar)
    losses, kl, bound, complexity = plot_posterior_error_distribution(
        model, mu_q_flat, sigma_q_vec, mu_p, sigma_p, train_loader, device=device, S=S, delta=0.05,
        title="Posterior error distribution"
    )

    print(f"Mean Gibbs loss: {losses.mean():.4f} ± {losses.std(ddof=1) / math.sqrt(len(losses)):.4f}")
    print(f"KL(Q||P) = {kl:.6f}, complexity = {complexity:.6f}, McAllester bound = {bound:.6f}")
