# TODO delete additional defs and comments/renew comments
"""
Variational diagonal-Gaussian posterior + training routine for PAC-Bayes-style objective.

Usage sketch:
    1. Train a point-estimate model on the PRIOR split to get prior mean mu_P (optional).
    2. Train your model on TRAIN split to get an initial MAP estimate mu_map (optional initialization).
    3. Create a base model instance (same architecture), wrap it with VariationalWrapper:
         from models import BaselineCNN
         base_model = BaselineCNN()
         var = VariationalWrapper(base_model, init_mu=mu_init, init_rho=-6.0)
    4. Train with train_variational(...) using train_loader, providing prior_mu and sigma_p.
    5. After training, evaluate Gibbs 0-1 loss and compute true McAllester bound using utilities.

Notes:
 - The objective minimized during training is a cross-entropy surrogate of hat_Rn_Q plus (lambda / n) * KL(Q||P);
   see train_variational.
 - After training, estimate hat_Rn_Q (the 0-1 Gibbs loss) with many posterior samples and plug it into the
   McAllester bound.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import device
from torch.utils.data import DataLoader
import numpy as np

# ---------------------------
# Helpers: flatten/shape utilities
# ---------------------------
def shapes_and_sizes_from_model(model):
    """
    Collect the shape and element count of every parameter of `model`.

    Parameters are visited in the order produced by `model.parameters()`.
    Returns a pair (shapes, sizes): `shapes` is a list of tuples and `sizes`
    a list of ints, with one entry per parameter.
    """
    params = list(model.parameters())
    shapes = [tuple(p.size()) for p in params]
    sizes = [int(p.numel()) for p in params]
    return shapes, sizes

def flatten_tensors(tensor_list):
    """
    Concatenate the given tensors into a single detached 1D tensor.

    Each tensor is detached and flattened in row-major order, then the pieces
    are concatenated in the order given. `reshape` is used instead of `view`
    so that non-contiguous tensors (e.g. transposed weights) are accepted.
    An empty input yields an empty 1D tensor instead of raising.
    """
    pieces = [t.detach().reshape(-1) for t in tensor_list]
    if not pieces:
        # torch.cat refuses an empty list; return a well-defined empty vector.
        return torch.empty(0)
    return torch.cat(pieces)

def unflatten_vector_to_tensors(vec, shapes, device=None):
    """
    Split a flat vector into a list of tensors with the given shapes.

    Entries of `vec` are consumed in order, one chunk per shape.

    vec: tensor whose total number of elements equals the sum of the shape sizes.
    shapes: iterable of shapes (tuples / torch.Size); the empty shape () is
        supported and yields a 0-dim tensor.
    device: optional target device; when None the chunks stay on vec's device.

    Raises ValueError if the number of elements in `vec` does not match `shapes`.
    """
    shapes = [tuple(s) for s in shapes]
    # np.prod(()) == 1, so scalar (0-dim) shapes consume exactly one element.
    sizes = [int(np.prod(s)) for s in shapes]
    flat = vec.reshape(-1)
    if sum(sizes) != flat.numel():
        raise ValueError(
            f"vec has {flat.numel()} elements but shapes require {sum(sizes)}"
        )

    tensors = []
    offset = 0
    for shape, size in zip(shapes, sizes):
        # reshape(shape) (not view(*shape)) also works for shape == ().
        piece = flat[offset:offset + size].reshape(shape)
        if device is not None:
            piece = piece.to(device)
        tensors.append(piece)
        offset += size
    return tensors

# ---------------------------
# Variational wrapper
# ---------------------------
class VariationalWrapper(nn.Module):
    """
    Wrap a 'base_model' architecture and represent each parameter as a Gaussian
    with learnable mean mu and learnable rho (we map rho -> sigma via softplus).
    You can optionally initialize mu from a provided flat vector (init_mu).

    mu and rho are stored as flat 1D Parameters of length D (total number of
    parameters of base_model, in `base_model.parameters()` order).
    """
    def __init__(self, base_model, init_mu=None, init_rho=-6.0, device='cpu'):
        """
        base_model: an instance of the architecture (e.g., BaselineCNN())
        init_mu: tensor with as many elements as base_model has parameters (optional).
                 If None, the current parameters of base_model are used.
        init_rho: initial value for rho; a Python scalar or a tensor that can be
                  expanded to shape [D].
        device: device on which mu and rho are created.

        Raises ValueError if base_model has no parameters or init_mu has the wrong size.
        """
        super().__init__()
        self.base_model = base_model  # provides the architecture / forward code
        self.device = device

        # record shapes/sizes in parameters() order
        params = list(base_model.parameters())
        if not params:
            raise ValueError("base_model has no parameters to wrap")
        self.shapes = [tuple(p.size()) for p in params]
        self.sizes = [int(p.numel()) for p in params]
        self.D = int(sum(self.sizes))

        # Initialize mu: from init_mu or from base model's current parameters.
        if init_mu is None:
            # torch.cat copies, so mu never aliases the base model's storage.
            mu_init = torch.cat([p.detach().reshape(-1) for p in params]).float()
        else:
            mu_init = init_mu.detach().clone().float().reshape(-1)
            if mu_init.numel() != self.D:
                raise ValueError(
                    f"init_mu size mismatch: got {mu_init.numel()}, expected {self.D}"
                )
        self.mu = nn.Parameter(mu_init.to(device))             # mean vector

        if torch.is_tensor(init_rho):
            rho_init = (
                init_rho.detach()
                .to(device=self.mu.device, dtype=self.mu.dtype)
                .expand(self.D)
                .clone()
            )
        else:
            rho_init = torch.full(
                (self.D,), float(init_rho), dtype=self.mu.dtype, device=self.mu.device
            )
        self.rho = nn.Parameter(rho_init)                      # rho -> sigma via softplus

    def sigma(self):
        # use softplus to ensure positive std; add small eps for numerical stability
        return F.softplus(self.rho) + 1e-8

    def sample_flat(self, S=1):
        """
        Sample S flat weight vectors from the diagonal Gaussian posterior Q
        using the reparameterization w = mu + eps * sigma, eps ~ N(0, I).
        Returns tensor of shape [S, D] that is differentiable w.r.t. mu and rho.
        """
        sigma = self.sigma()
        eps = torch.randn((S, self.D), device=self.mu.device, dtype=self.mu.dtype)
        return self.mu.unsqueeze(0) + eps * sigma.unsqueeze(0)  # [S, D]

    def _unflatten(self, flat_vec):
        """Split a length-D vector into tensors shaped like base_model's parameters."""
        flat = flat_vec.reshape(-1)
        if flat.numel() != self.D:
            raise ValueError(
                f"flat_vec has {flat.numel()} elements, expected {self.D}"
            )
        chunks = torch.split(flat, self.sizes)
        return [chunk.reshape(shape) for chunk, shape in zip(chunks, self.shapes)]

    def load_flat_to_base(self, flat_vec):
        """
        Given a tensor flat_vec with D elements, write its contents into
        base_model.parameters() in the same order as the recorded shapes.
        The copy is done without tracking gradients.
        """
        with torch.no_grad():
            for p, t in zip(self.base_model.parameters(), self._unflatten(flat_vec)):
                # copy_ converts to the parameter's own device and dtype
                p.copy_(t)

    def forward_with_sample(self, x, flat_vec):
        """
        Run base_model's forward on x using the weights in flat_vec and return
        the output.

        The weights are substituted functionally rather than copied in place,
        so the output stays connected to flat_vec in the autograd graph: when
        flat_vec comes from sample_flat, gradients of a loss on the output
        reach mu and rho. base_model's stored parameters are left untouched;
        use load_flat_to_base to overwrite them explicitly.
        """
        # Local import keeps the file's import block unchanged; fall back to
        # the older location of functional_call on PyTorch < 2.0.
        try:
            from torch.func import functional_call
        except ImportError:
            from torch.nn.utils.stateless import functional_call

        substituted = {
            # match each parameter's device/dtype (differentiable cast)
            name: t.to(device=p.device, dtype=p.dtype)
            for (name, p), t in zip(
                self.base_model.named_parameters(), self._unflatten(flat_vec)
            )
        }
        return functional_call(self.base_model, substituted, (x,))

    def kl_diag_with_prior(self, prior_mu=None, sigma_p=1.0):
        """
        Compute KL(Q||P) between diagonal Gaussians where
        Q = N(mu, diag(sigma_q^2)), P = N(prior_mu, sigma_p^2 I).
        prior_mu: tensor with D elements (if None, assumed zero)
        sigma_p: positive scalar std
        Returns scalar (float tensor).

        Raises ValueError if sigma_p is not positive or prior_mu has the wrong size.
        """
        sigma_p = float(sigma_p)
        if sigma_p <= 0.0:
            raise ValueError(f"sigma_p must be positive, got {sigma_p}")

        mu_q = self.mu
        sigma_q = self.sigma()
        if prior_mu is None:
            mu_p = torch.zeros_like(mu_q)
        else:
            # flatten so that an oddly-shaped prior cannot silently broadcast
            mu_p = prior_mu.to(mu_q.device).float().reshape(-1)
            if mu_p.numel() != mu_q.numel():
                raise ValueError(
                    f"prior_mu size mismatch: got {mu_p.numel()}, expected {mu_q.numel()}"
                )
        var_q = sigma_q ** 2
        var_p = sigma_p ** 2
        # per-dim KL: log(sigma_p/sigma_q) + (var_q + (mu_q-mu_p)^2)/(2 var_p) - 1/2
        log_ratio = torch.log(sigma_p / sigma_q)
        quad = (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p)
        return (log_ratio + quad - 0.5).sum()

# ---------------------------
# Training routine
# ---------------------------
def train_variational(
    variational_model: VariationalWrapper,
    train_loader: DataLoader,
    epochs=20,
    lr=1e-3,
    sigma_p=1.0,
    prior_mu=None,
    lambda_coef=1.0,
    S=1,
    device='cpu',
    log_every=100,
    normalize_ce_by_logC=True
):
    """
    Train variational posterior by minimizing:
        L = E_{w ~ Q}[CE(w)]_empirical  +  (lambda_coef / n) * KL(Q||P)
    where CE(w) is cross-entropy (surrogate). We use Monte-Carlo sampling with S samples per minibatch.

    Arguments:
      variational_model: VariationalWrapper (holds base_model and variational params)
      train_loader: DataLoader for training set (must be training split). Each batch must be
                    a sequence whose first two entries are (inputs, labels); extra entries
                    (e.g. rotation angles) are ignored.
      epochs: number of passes over train_loader
      lr: Adam learning rate (only mu and rho are optimized)
      sigma_p: prior standard deviation
      prior_mu: 1D tensor prior mean (or None -> zero)
      lambda_coef: weight on KL term (controls regularization strength)
      S: number of posterior samples per minibatch (often 1 is OK); must be >= 1
      device: device on which the model and batches are placed
      log_every: print running statistics every this many batches (0/None disables)
      normalize_ce_by_logC: if True, divide the cross-entropy by log(C), where C is the
                    number of classes taken from train_loader.dataset.labels when available
                    and assumed to be 10 otherwise

    Returns the trained variational_model.
    Raises ValueError if S < 1.
    """
    if S < 1:
        raise ValueError(f"S must be >= 1, got {S}")

    vm = variational_model
    vm.to(device)
    optimizer = torch.optim.Adam([vm.mu, vm.rho], lr=lr)

    n_train = len(train_loader.dataset)

    logC = 1.0
    if normalize_ce_by_logC:
        try:
            num_classes = int(torch.as_tensor(train_loader.dataset.labels).max().item()) + 1
        except (AttributeError, TypeError, ValueError, RuntimeError):
            # dataset exposes no usable `.labels`; fall back to 10 classes (MNIST)
            num_classes = 10
        # log(1) == 0 would turn the normalization into a division by zero
        logC = math.log(max(num_classes, 2))

    # move the prior mean once instead of on every KL evaluation
    if prior_mu is not None:
        prior_mu = prior_mu.to(device)

    for epoch in range(1, epochs + 1):
        vm.train()
        total_loss = 0.0
        total_batches = 0
        for batch_idx, batch in enumerate(train_loader):
            imgs = batch[0].to(device)
            labels = batch[1].to(device)

            # Monte Carlo estimate of expected CE: sample S weight vectors and average CE
            sampled_ws = vm.sample_flat(S=S)  # [S, D]
            batch_ce = 0.0
            for flat_w in sampled_ws:
                preds = vm.forward_with_sample(imgs, flat_w)
                ce = F.cross_entropy(preds, labels, reduction='mean')  # standard CE
                if normalize_ce_by_logC:
                    ce = ce / logC
                batch_ce = batch_ce + ce
            batch_ce = batch_ce / float(S)

            # KL term (full KL scaled by lambda_coef / n)
            kl = vm.kl_diag_with_prior(prior_mu=prior_mu, sigma_p=sigma_p)
            kl_term = (lambda_coef / float(n_train)) * kl

            loss = batch_ce + kl_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += float(loss.item())
            total_batches += 1

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"Epoch {epoch} Batch {batch_idx+1}: loss={total_loss/total_batches:.6f} batch_ce={batch_ce.item():.6f} kl={kl.item():.3e}")

        avg_loss = total_loss / max(1, total_batches)
        # logging per epoch; no graph is needed for the reported KL value
        with torch.no_grad():
            kl_val = vm.kl_diag_with_prior(prior_mu=prior_mu, sigma_p=sigma_p).item()
        print(f"[Epoch {epoch}] avg_loss={avg_loss:.6f} KL={kl_val:.3e}")

    return vm

# ---------------------------
# Evaluation helpers: estimate Gibbs 0-1 loss and compute McAllester bound
# ---------------------------
def estimate_gibbs_error_01(variational_model: VariationalWrapper, data_loader: DataLoader, sigma_p=None,
                            prior_mu=None, S=200, device='cpu'):
    """
    Estimate empirical Gibbs 0-1 loss: average 0-1 loss over data and posterior samples.

    For each of S weight vectors drawn as mu + eps * sigma, the whole data_loader is
    evaluated and the fraction of misclassified examples is recorded.

    variational_model: VariationalWrapper providing mu, sigma() and forward_with_sample
    data_loader: DataLoader whose batches are sequences starting with (inputs, labels);
                 extra entries are ignored
    sigma_p, prior_mu: accepted for call-site compatibility; not used by this function
    S: number of posterior samples (must be >= 1)
    device: device on which the model and batches are placed

    The model is switched to eval mode and autograd is disabled for the duration of the
    call; the previous train/eval mode is restored afterwards.

    Returns hat_Rn_Q (float) and a numpy array of per-sample losses (length S).
    Raises ValueError if S < 1 or data_loader yields no examples.
    """
    if S < 1:
        raise ValueError(f"S must be >= 1, got {S}")

    vm = variational_model
    vm.to(device)

    was_training = vm.training
    vm.eval()  # no dropout noise / batch-statistics updates while measuring the 0-1 loss
    losses = []
    try:
        with torch.no_grad():
            mu_flat = vm.mu.detach().to(device)
            sigma_q = vm.sigma().detach()  # invariant across samples
            for _ in range(S):
                w_sample = mu_flat + torch.randn_like(mu_flat) * sigma_q
                # evaluate w_sample on whole dataset
                incorrect = 0
                total = 0
                for batch in data_loader:
                    imgs = batch[0].to(device)
                    labels = batch[1].to(device)
                    preds = vm.forward_with_sample(imgs, w_sample).argmax(dim=1)
                    incorrect += (preds != labels).sum().item()
                    total += labels.size(0)
                if total == 0:
                    raise ValueError("data_loader yielded no examples")
                losses.append(incorrect / total)
    finally:
        vm.train(was_training)

    losses = np.array(losses)
    return float(losses.mean()), losses

def kl_qp_total(variational_model: VariationalWrapper, prior_mu=None, sigma_p=1.0):
    """
    Return KL(Q||P) of the variational model as a plain Python float.

    Thin convenience wrapper around variational_model.kl_diag_with_prior; prior_mu and
    sigma_p are forwarded unchanged.
    """
    kl_tensor = variational_model.kl_diag_with_prior(prior_mu=prior_mu, sigma_p=sigma_p)
    return float(kl_tensor.detach().item())

def mcallester_bound_from_estimates(hat_Rn_Q, kl, n, delta=0.05):
    """
    Compute a McAllester-style PAC-Bayes bound from estimated quantities.

    The complexity term is sqrt((kl + log(n / delta)) / (2 * max(1, n - 1))), clamped at
    zero before the square root, and the bound is hat_Rn_Q + complexity.

    hat_Rn_Q: empirical Gibbs risk
    kl: KL(Q||P)
    n: number of training examples (>= 1)
    delta: confidence parameter, 0 < delta <= 1

    Returns (bound, complexity).
    Raises ValueError if n < 1 or delta is outside (0, 1].
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if not 0.0 < delta <= 1.0:
        raise ValueError(f"delta must be in (0, 1], got {delta}")

    # max(1, n - 1) keeps the denominator positive for n == 1
    denominator = 2.0 * max(1, (n - 1))
    term = (kl + math.log(n / delta)) / denominator
    complexity = math.sqrt(max(0.0, term))
    return hat_Rn_Q + complexity, complexity



def debug_variational_stats(vm: VariationalWrapper, prior_mu=None, sigma_p=1.0):
    """
    Print summary statistics of the variational posterior and its KL to the prior.

    Reports the dimension, squared norm and mean of mu, mean/min/max of sigma, the total
    KL (sum of per-dimension terms against N(prior_mu, sigma_p^2 I), with prior_mu taken
    as zero when None), and the ten dimensions with the largest per-dimension KL.
    Nothing is returned.
    """
    mean_q = vm.mu.detach()
    std_q = vm.sigma().detach()

    # prior mean defaults to the zero vector
    if prior_mu is None:
        mean_p = torch.zeros_like(mean_q)
    else:
        mean_p = prior_mu.to(mean_q.device).float()

    # per-dim KL: log(sigma_p/sigma_q) + (var_q + (mu_q-mu_p)^2)/(2 var_p) - 1/2
    prior_var = float(sigma_p**2)
    log_ratio = torch.log(sigma_p / std_q)
    quadratic = (std_q**2 + (mean_q - mean_p)**2) / (2.0 * prior_var) - 0.5
    kl_per_dim = (log_ratio + quadratic).detach().cpu().numpy()
    largest = np.argsort(-kl_per_dim)[:10]

    print("D =", mean_q.numel(),
          "mu_norm2=", float((mean_q**2).sum().item()),
          "mu_mean=", float(mean_q.mean().item()),
          "sigma_mean=", float(std_q.mean().item()),
          "sigma_min=", float(std_q.min().item()),
          "sigma_max=", float(std_q.max().item()),
          "KL_total=", float(kl_per_dim.sum()))
    print("Top 10 per-dim KL contributors (values, idx):")
    for dim in largest:
        print(dim, kl_per_dim[dim])




# ---------------------------
# Example usage outline (to run in your script / notebook):
# ---------------------------
if __name__ == "__main__":
    # Example pipeline. NOTE: this block runs whenever the file is executed as a script
    # and needs the project-local modules and data files referenced below.
    from models import BaselineCNN
    from generate_rotated_mnist import RotatedMNISTDataset
    from torch.utils.data import DataLoader

    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'

    # 1) Prepare data
    train_ds = RotatedMNISTDataset("rotated_mnist/train.pt")
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)

    # 2) Optionally load prior mean (trained on prior split)
    # map_location makes the load work even if the tensor was saved from another device.
    # NOTE(review): torch.load unpickles the file; only load files from a trusted source.
    prior_mu = torch.load("rotated_mnist/prior_mu_baseline.pt", map_location=device)  # expected: 1D tensor of length D
    sigma_p = 1.0

    # 3) Build base model and variational wrapper
    base_model = BaselineCNN()
    # Optionally initialize mu to point estimate (e.g., MAP on train)
    # mu_init = get_flat_params_from_map_model(...)
    var = VariationalWrapper(base_model, init_mu=None, init_rho=-1.0, device=device)
    # debug_variational_stats(var, prior_mu, sigma_p)

    # 4) Train variational posterior
    var = train_variational(var, train_loader, epochs=40, lr=1e-4, sigma_p=sigma_p, prior_mu=prior_mu,
                            lambda_coef=1e-4, S=1, device=device, log_every=200)

    # 5) Evaluate Gibbs 0-1 loss and McAllester bound (both on the training split)
    hat_Rn_Q, losses = estimate_gibbs_error_01(var, train_loader, prior_mu=prior_mu, S=200, device=device)
    kl_val = kl_qp_total(var, prior_mu=prior_mu, sigma_p=sigma_p)
    bound, complexity = mcallester_bound_from_estimates(hat_Rn_Q, kl_val, n=len(train_ds), delta=0.05)
    print("hat_Rn_Q:", hat_Rn_Q, "KL:", kl_val, "bound:", bound)
