import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
from models import BaselineCNN
from models import EquivariantCNN
from dataset import RotatedMNISTDataset
from pacbayes_utils import sweep_sigma_q_and_compute_curve

# -------------------------
# Utilities: flatten/set params, KL, predict
# -------------------------
def get_flat_params_from(model):
    """
    Return all parameters of `model` concatenated into one 1-D tensor.

    Parameters are visited in `model.parameters()` order. The result is a
    detached copy, so writing to it does not touch the model.
    """
    flattened = [param.detach().view(-1) for param in model.parameters()]
    joined = torch.cat(flattened)
    return joined.clone().detach()

def set_flat_params_to(model, flat_params):
    """
    Overwrite the parameters of `model` in place from a flat vector.

    `flat_params` is consumed in `model.parameters()` order: each parameter
    takes the next `numel()` entries, reshaped to its own shape and moved to
    its own device. After all parameters are written, an assertion checks
    that the whole vector was consumed.
    """
    offset = 0
    with torch.no_grad():
        for param in model.parameters():
            end = offset + param.numel()
            chunk = flat_params[offset:end]
            param.copy_(chunk.view_as(param).to(param.device))
            offset = end
    assert offset == flat_params.numel()

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """
    KL(Q || P) for diagonal Gaussians:
      Q ~ N(mu_q, diag(sigma_q^2))
      P ~ N(mu_p, diag(sigma_p^2))

    Per coordinate the divergence is
      0.5 * (log(var_p / var_q) + var_q / var_p + (mu_p - mu_q)^2 / var_p - 1)
    and the result is the sum over the last dimension.

    Parameters:
      - mu_q, mu_p: means (tensors or anything `torch.as_tensor` accepts).
      - sigma_q, sigma_p: standard deviations, either Python scalars or
        tensors broadcastable against the means.

    Returns a tensor with the dtype and device of `mu_q` (0-dim for 1-D means).

    The computation is done entirely in torch, so tensor-valued sigmas work on
    any device and stay differentiable.
    """
    mu_q = torch.as_tensor(mu_q)
    if not mu_q.is_floating_point():
        mu_q = mu_q.to(torch.get_default_dtype())
    like_mu_q = {"dtype": mu_q.dtype, "device": mu_q.device}

    mu_p = torch.as_tensor(mu_p, **like_mu_q)
    var_q = torch.as_tensor(sigma_q, **like_mu_q) ** 2
    var_p = torch.as_tensor(sigma_p, **like_mu_q) ** 2

    # Combine the variance-only terms first: they are O(1) and largely cancel,
    # so adding the (often tiny) mean term afterwards keeps its precision.
    variance_part = torch.log(var_p / var_q) + var_q / var_p - 1.0
    mean_part = (mu_p - mu_q) ** 2 / var_p

    kl = 0.5 * (variance_part + mean_part).sum(dim=-1)
    return kl

def predict_with_weights(model, flat_w, inputs, device='cpu'):
    """
    Load `flat_w` into `model`, run it on `inputs`, and return predicted classes.

    The model is switched to eval mode and evaluated without gradient
    tracking. Predictions are the argmax over dim 1 of the model output,
    returned on the CPU. The model keeps the weights from `flat_w` afterwards.
    """
    set_flat_params_to(model, flat_w)
    model.eval()
    batch = inputs.to(device)
    with torch.no_grad():
        scores = model(batch)
        predicted = torch.argmax(scores, dim=1)
    return predicted.cpu()

# Estimate per-sample Gibbs loss (returns array of S losses)
def estimate_gibbs_losses_array(model, mu_flat, sigma_q, data_loader, device='cpu', S=200):
    """
    Monte-Carlo estimate of the 0-1 loss under the posterior N(mu_flat, sigma_q^2 I).

    For each of the S samples, a weight vector `mu_flat + sigma_q * eps` with
    eps ~ N(0, I) is loaded into `model`, and the misclassification rate over
    the whole `data_loader` is recorded.

    Parameters:
      - model: network whose parameters are overwritten for each sample.
      - mu_flat: flat tensor of posterior means, in `model.parameters()` order.
      - sigma_q: posterior standard deviation multiplied onto the noise.
      - data_loader: iterable yielding (inputs, labels) batches.
      - device: device on which the model and data are evaluated.
      - S: number of posterior samples.

    Returns a numpy array of length S holding one error rate per sample.

    The model is left in eval mode. Its parameter values from before the call
    are restored on exit, even if evaluation raises.
    """
    model = model.to(device)
    mu_flat = mu_flat.to(device)

    # Snapshot the current weights so the caller's model is not left holding
    # the last random sample.
    saved_params = [p.detach().clone() for p in model.parameters()]

    model.eval()
    losses = []
    try:
        for s in range(S):
            if s % 50 == 0:
                print(s)
            eps = torch.randn_like(mu_flat) * sigma_q
            # The sample is fixed for the whole pass over the data, so load it
            # once here rather than once per batch.
            set_flat_params_to(model, mu_flat + eps)

            total = 0
            incorrect = 0
            with torch.no_grad():
                for imgs, labels in data_loader:
                    labels = labels.to(device)
                    preds = model(imgs.to(device)).argmax(dim=1)
                    incorrect += (preds != labels).sum().item()
                    total += labels.size(0)
            losses.append(incorrect / total)
    finally:
        with torch.no_grad():
            for p, saved in zip(model.parameters(), saved_params):
                p.copy_(saved)
    return np.array(losses)

def mcallester_complexity(kl, n, delta=0.05):
    """
    Complexity term sqrt((kl + log(n / delta)) / (2n - 1)) of the PAC-Bayes bound.

    Parameters:
      - kl: KL divergence between posterior and prior.
      - n: number of training examples.
      - delta: confidence parameter.

    Returns a float. A non-positive value under the square root yields 0.0.
    """
    numerator = kl + math.log(n / delta)
    ratio = numerator / (2.0 * n - 1)
    # Clamp at zero so a negative `kl` cannot push the square root out of domain.
    if ratio > 0.0:
        return math.sqrt(ratio)
    return math.sqrt(0.0)

# -------------------------
# Plotting function (main)
# -------------------------
def plot_posterior_error_distribution(model, mu_flat, sigma_q, mu_p, sigma_p, train_loader, test_loader,
                                      device='cpu', S=200, delta=0.05,
                                      title=None, show=True, bins=30):
    """
    Samples the posterior S times and plots the resulting test-error histogram:
      - histogram of the Gibbs 0-1 loss on `test_loader` across posterior samples
      - dashed vertical: mean test loss
      - dash-dotted vertical: McAllester bound (mean training loss + complexity)

    Parameters:
      - model, mu_flat, sigma_q: network and the posterior N(mu_flat, sigma_q^2 I).
      - mu_p, sigma_p: prior mean and standard deviation used for the KL term.
      - train_loader: data for the empirical risk; `len(train_loader.dataset)`
        is used as the sample size n in the bound.
      - test_loader: data used for the plotted histogram and its mean.
      - S, delta: number of posterior samples and confidence parameter.
      - title: optional plot title.
      - show: if True, the figure is written to "histogram.png".
      - bins: number of histogram bins.

    Returns (losses_array, kl, bound, complexity), where `losses_array` holds
    the per-sample losses on the TRAINING data (the ones entering the bound).
    """
    # compute KL
    kl = kl_diag_gaussians(mu_flat, sigma_q, mu_p, sigma_p)
    # compute array of losses (one per posterior sample, for empirical risk)
    losses = estimate_gibbs_losses_array(model, mu_flat, sigma_q, train_loader, device=device, S=S)
    mean_loss = float(losses.mean())
    n = len(train_loader.dataset)
    complexity = mcallester_complexity(kl, n, delta=delta)
    bound = mean_loss + complexity

    # compute array of losses to approximate the true risk
    losses_test = estimate_gibbs_losses_array(model, mu_flat, sigma_q, test_loader, device=device, S=S)
    # The mean must come from the test losses; it is drawn and labelled as the
    # mean test loss on top of the test histogram.
    mean_loss_test = float(losses_test.mean())

    # plot
    plt.figure(figsize=(8,5))
    plt.hist(losses_test, bins=bins, alpha=0.7)
    plt.axvline(mean_loss_test, linestyle='--', linewidth=2, label=f"Mean test loss = {mean_loss_test:.4f}")
    plt.axvline(bound, linestyle='-.', linewidth=2, label=f"McAllester bound = {bound:.4f}")
    plt.xlabel("Empirical Gibbs 0-1 loss")
    plt.ylabel("Count (posterior samples)")
    plt.title(title or "Posterior error distribution and McAllester PAC-Bayes bound")
    plt.legend(loc='upper right')

    # annotation: KL, complexity and hyperparams
    txt = f"KL={kl:.2e}\ncomplexity={complexity:.4f}\nn={n}, delta={delta}\nσ_q={sigma_q}, σ_p={sigma_p}\nS={S}"
    plt.gcf().text(0.02, 0.02, txt, fontsize=9, va='center')

    if show:
        # The figure is saved to disk rather than displayed interactively.
        plt.savefig("histogram.png")

    return losses, kl, bound, complexity


def _posterior_summary(model, mu_q, sigma_q, mu_p, sigma_p,
                       train_loader, test_loader, n, device, S, delta):
    """Compute KL, McAllester bound and sampled test losses for one model."""
    kl = kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p)
    losses_train = estimate_gibbs_losses_array(model, mu_q, sigma_q, train_loader, device=device, S=S)
    complexity = mcallester_complexity(kl, n, delta=delta)
    bound = float(losses_train.mean()) + complexity
    losses_test = estimate_gibbs_losses_array(model, mu_q, sigma_q, test_loader, device=device, S=S)
    # Reference point: test loss when sampling weights from the prior instead.
    losses_test_prior = estimate_gibbs_losses_array(model, mu_p, sigma_p, test_loader, device=device, S=S)
    return {
        "kl": kl,
        "complexity": complexity,
        "bound": bound,
        "losses_test": losses_test,
        "mean_test": float(losses_test.mean()),
        "mean_test_prior": float(losses_test_prior.mean()),
    }


def plot_posterior_error_overlay(
    model1, mu_flat1, sigma_q1, mu_p1,
    model2, mu_flat2, sigma_q2, mu_p2,
    sigma_p,
    train_loader, test_loader,
    device='cpu', S=200, delta=0.05,
    title=None, show=True, bins=30,
    figsize=(8,5), save_path="histogram_overlay.png",
    normalize=True
):
    """
    Plot posterior-sampled Gibbs 0-1 test losses for two models in the same axes.

    Parameters:
      - model1, mu_flat1, sigma_q1, mu_p1: first network, its posterior mean
        and standard deviation, and its prior mean.
      - model2, mu_flat2, sigma_q2, mu_p2: the same for the second network.
      - sigma_p: prior standard deviation shared by both models.
      - train_loader: data for the empirical risk; `len(train_loader.dataset)`
        is used as the sample size n in the bound.
      - test_loader: data for the histograms and the mean lines.
      - S, delta: number of posterior samples and confidence parameter.
      - title: a string (figure title) or a pair of names; a pair labels the
        two histograms and is joined with " vs. " for the figure title.
      - show: if True, the figure is written to `save_path`.
      - bins, figsize: histogram bin count and figure size.
      - normalize: if True, plot density (area under hist = 1) instead of counts.

    Each model gets its histogram plus three vertical lines: mean test loss
    under the posterior (dashed), mean test loss under the prior (dotted) and
    the McAllester bound (dash-dotted). The line labels always name the models
    "baseline" and "equivariant", in that order.

    Returns (losses_test1, losses_test2, kl1, kl2, bound1, bound2, complexity1, complexity2)
    """
    colors = ("C0", "C1")
    n = len(train_loader.dataset)

    # Model 1 is fully evaluated before model 2 so the random draws are
    # consumed in a fixed order.
    stats1 = _posterior_summary(model1, mu_flat1, sigma_q1, mu_p1, sigma_p,
                                train_loader, test_loader, n, device, S, delta)
    stats2 = _posterior_summary(model2, mu_flat2, sigma_q2, mu_p2, sigma_p,
                                train_loader, test_loader, n, device, S, delta)

    has_names = isinstance(title, (list, tuple))
    hist_labels = (title[0] if has_names else "Model 1",
                   title[1] if has_names else "Model 2")

    # plotting
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    density = normalize

    # plot histograms with semi-transparency; specify colors so they differ
    for stats, color, hist_label in zip((stats1, stats2), colors, hist_labels):
        ax.hist(stats["losses_test"], bins=bins, alpha=0.5, label=hist_label,
                density=density, color=color)

    # vertical lines: means and bounds, colored to match histograms
    for stats, color, tag in zip((stats1, stats2), colors, ("baseline", "equivariant")):
        ax.axvline(stats["mean_test"], color=color, linestyle='--', linewidth=2,
                   label=f"Mean test loss ({tag}) = {stats['mean_test']:.4f}")
        ax.axvline(stats["mean_test_prior"], color=color, linestyle=':', linewidth=2,
                   label=f"Mean test loss ({tag} prior) = {stats['mean_test_prior']:.4f}")
        ax.axvline(stats["bound"], color=color, linestyle='-.', linewidth=2,
                   label=f"McAllester bound ({tag}) = {stats['bound']:.4f}")

    ax.set_xlabel("Empirical Gibbs 0-1 loss")
    ax.set_ylabel("Density" if density else "Count (posterior samples)")
    # An error rate lies in [0, 1]; fix the axis to that range.
    ax.set_xlim(0, 1)
    ax.legend(loc='upper right')

    if isinstance(title, str):
        fig.suptitle(title)
    elif has_names:
        fig.suptitle(" vs. ".join(title))

    # Leave room at the top for the figure title.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    if show:
        fig.savefig(save_path)

    return (stats1["losses_test"], stats2["losses_test"],
            stats1["kl"], stats2["kl"],
            stats1["bound"], stats2["bound"],
            stats1["complexity"], stats2["complexity"])

if __name__ == "__main__":
    # Selects the dataset directory, which also holds the checkpoints and priors.
    translated = True
    path_mnist = "rotated_translated_mnist" if translated else "rotated_mnist"

    # load both models
    device = 'cpu'
    model_base = BaselineCNN().to(device)
    model_base.load_state_dict(torch.load(path_mnist + "/baseline_cnn.pt", map_location=device))
    model_eq = EquivariantCNN().to(device)
    model_eq.load_state_dict(torch.load(path_mnist + "/equivariant_cnn.pt", map_location=device))

    # dataset
    train_ds = RotatedMNISTDataset(path_mnist + "/train.pt")
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
    test_ds = RotatedMNISTDataset(path_mnist + "/test.pt")
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

    # prior
    # Map the stored prior means onto `device`, like the model checkpoints, so
    # they are on the same device as the posterior means they are compared to.
    mu_p_base = torch.load(path_mnist + "/prior_mu_baseline.pt", map_location=device)
    mu_p_eq = torch.load(path_mnist + "/prior_mu_equivariant.pt", map_location=device)
    sigma_p = 5e-2

    # posterior
    # Grid for the optional sigma_q sweep below (currently disabled).
    sigmas = np.logspace(-3, 1, 15)

    mu_q_base = get_flat_params_from(model_base)
    # results_base = sweep_sigma_q_and_compute_curve(model_base, mu_q_base, mu_p_base, sigma_p, train_loader, device=device, sigmas=sigmas, S=10, n=len(train_ds), delta=0.05)
    # best_result_base = min(results_base, key=lambda x: x['bound'])
    # sigma_q_base = best_result_base['sigma_q']
    # print(f"Best sigma_q for the baseline = {sigma_q_base}, with bound = {best_result_base['bound']}")
    sigma_q_base = 5e-2

    mu_q_eq = get_flat_params_from(model_eq)
    # results_eq = sweep_sigma_q_and_compute_curve(model_eq, mu_q_eq, mu_p_eq, sigma_p, train_loader, device=device, sigmas=sigmas, S=10, n=len(train_ds), delta=0.05)
    # best_result_eq = min(results_eq, key=lambda x: x['bound'])
    # sigma_q_eq = best_result_eq['sigma_q']
    # print(f"Best sigma_q for the equivariant model = {sigma_q_eq}, with bound = {best_result_eq['bound']}")
    sigma_q_eq = 5e-2

    S = 100  # number of posterior samples

    # Single-model variant of the plot (disabled):
    # losses, kl, bound, complexity = plot_posterior_error_distribution(
    #     model_base, mu_q_base, sigma_q_base, mu_p_base, sigma_p, train_loader, test_loader, device=device, S=S, delta=0.05,
    #     title="Posterior error distribution"
    # )

    losses1, losses2, kl1, kl2, bound1, bound2, complexity1, complexity2 = \
        plot_posterior_error_overlay(
            model_base, mu_q_base, sigma_q_base, mu_p_base,
            model_eq, mu_q_eq, sigma_q_eq, mu_p_eq,
            sigma_p,
            train_loader, test_loader,
            device=device, S=S, delta=0.05,
            title=("Baseline CNN", "Equivariant CNN"),
            bins=30,
            save_path=path_mnist+"/histogram_overlay_prior.png"
        )

    # Summary on stdout: first the baseline model, then the equivariant model.
    # The "±" value is the standard error of the mean over posterior samples.
    print(f"Mean Gibbs loss: {losses1.mean():.4f} ± {losses1.std(ddof=1) / math.sqrt(len(losses1)):.4f}")
    print(f"KL(Q||P) = {kl1:.6f}, complexity = {complexity1:.6f}, McAllester bound = {bound1:.6f}")
    print(f"Mean Gibbs loss: {losses2.mean():.4f} ± {losses2.std(ddof=1)/math.sqrt(len(losses2)):.4f}")
    print(f"KL(Q||P) = {kl2:.6f}, complexity = {complexity2:.6f}, McAllester bound = {bound2:.6f}")
