# Compute and plot posterior-sampled losses for baseline vs equivariant CNNs and compare with McAllester PAC-Bayes bound.

import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
from models import BaselineCNN
from models import EquivariantCNN
from dataset import RotatedMNISTDataset
from pacbayes_utils import (sweep_sigma_q_and_compute_curve, get_flat_params_from, set_flat_params_to, kl_diag_gaussians, mcallester_bound, estimate_gibbs_loss)


# -----------------
# Plotting function
# -----------------
def plot_posterior_error_overlay(
    model1, mu_flat1, sigma_q1, mu_p1,
    model2, mu_flat2, sigma_q2, mu_p2,
    sigma_p,
    train_loader, test_loader,
    device='cpu', S=200, delta=0.05,
    title=None, show=True, bins=60,
    figsize=(8,5), save_path="histogram_overlay.png",
    normalize=False
):
    """
    Plot posterior-sampled test losses of two models on a single set of axes.

    For each model the function computes the KL term, the Gibbs loss estimate
    on ``train_loader`` (fed into the McAllester bound) and the Gibbs loss
    estimate on ``test_loader`` (histogrammed). Each model gets a histogram,
    a dashed line at its mean test loss and a dash-dotted line at its bound,
    all drawn in the same colour.

    Parameters:
      - model1, mu_flat1, sigma_q1, mu_p1: first model, its flat posterior
        mean, posterior std and prior mean; passed unchanged to
        ``kl_diag_gaussians`` / ``estimate_gibbs_loss``.
      - model2, mu_flat2, sigma_q2, mu_p2: same quantities for the second model.
      - sigma_p: prior std shared by both models.
      - train_loader, test_loader: loaders used for the train and test loss
        estimates; ``len(train_loader.dataset)`` is the ``n`` of the bound.
      - device, S: forwarded to ``estimate_gibbs_loss`` (``S`` is presumably
        the number of posterior samples -- confirm in ``pacbayes_utils``).
      - delta: forwarded to ``mcallester_bound``.
      - title: ``None``, a string (figure title only), or a sequence whose
        first two entries name the models (used for the legend and, joined
        with " vs. ", for the figure title).
      - show: if True the figure is written to ``save_path``. Despite the
        name no window is opened; this gating is kept for backward
        compatibility with existing callers.
      - bins: number of equal-width histogram bins over [0, 1].
      - figsize: Matplotlib figure size.
      - save_path: output file; ``None`` disables saving.
      - normalize: if True, plot density (area under hist = 1) instead of counts.

    Returns (losses_test1, losses_test2, kl1, kl2, bound1, bound2, complexity1, complexity2)
    """

    colors = ("C0", "C1")
    label1, label2 = _overlay_model_labels(title)

    # Sample count entering the bound is the size of the training set.
    n = len(train_loader.dataset)

    kl1, bound1, complexity1, mean_loss_test1, losses_test1 = _overlay_model_stats(
        model1, mu_flat1, sigma_q1, mu_p1, sigma_p,
        train_loader, test_loader, n, device, S, delta, "first",
    )
    kl2, bound2, complexity2, mean_loss_test2, losses_test2 = _overlay_model_stats(
        model2, mu_flat2, sigma_q2, mu_p2, sigma_p,
        train_loader, test_loader, n, device, S, delta, "second",
    )

    # Fixed x-range for clarity (loss is in [0, 1]); shared bin edges keep the
    # two histograms directly comparable.
    xmin, xmax = 0, 1
    bin_edges = np.linspace(xmin, xmax, bins + 1)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    density = normalize
    per_model = (
        (losses_test1, mean_loss_test1, bound1, label1, colors[0]),
        (losses_test2, mean_loss_test2, bound2, label2, colors[1]),
    )

    # Histograms first, then the vertical lines, so the legend order is:
    # both histograms, followed by mean/bound per model.
    for losses, _, _, label, color in per_model:
        ax.hist(losses, bins=bin_edges, alpha=0.5, label=label,
                density=density, color=color)
    for _, mean_loss, bound, label, color in per_model:
        # Legend text is derived from the model label rather than hard-coded,
        # so it always agrees with the histogram entries.
        ax.axvline(mean_loss, color=color, linestyle='--', linewidth=2,
                   label=f"Mean test loss ({label}) = {mean_loss:.4f}")
        ax.axvline(bound, color=color, linestyle='-.', linewidth=2,
                   label=f"McAllester bound ({label}) = {bound:.4f}")

    ax.set_xlabel("Risk")
    ax.set_ylabel("Density" if density else "Count (posterior samples)")
    ax.set_xlim(xmin, xmax)
    ax.legend(loc='upper right')

    if isinstance(title, str):
        fig.suptitle(title)
    elif isinstance(title, (list, tuple)):
        fig.suptitle(" vs. ".join(str(part) for part in title))

    # Leave headroom at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # `show` has always controlled saving (not display); skip when no path is given.
    if show and save_path is not None:
        fig.savefig(save_path)

    return losses_test1, losses_test2, kl1, kl2, bound1, bound2, complexity1, complexity2


def _overlay_model_labels(title):
    """Return the two legend labels implied by ``title`` (defaults: Model 1 / Model 2)."""
    if isinstance(title, (list, tuple)) and len(title) >= 2:
        return str(title[0]), str(title[1])
    return "Model 1", "Model 2"


def _overlay_model_stats(model, mu_flat, sigma_q, mu_p, sigma_p,
                         train_loader, test_loader, n, device, S, delta, ordinal):
    """Compute (kl, bound, complexity, mean_test_loss, test_losses) for one model."""
    kl = kl_diag_gaussians(mu_flat, sigma_q, mu_p, sigma_p)

    print(f"Estimating empirical risk for the {ordinal} model")
    # Only the mean train loss enters the bound; stderr and samples are unused.
    mean_loss_train, _, _ = estimate_gibbs_loss(
        model, mu_flat, sigma_q, train_loader, device=device, S=S)
    bound, complexity = mcallester_bound(mean_loss_train, kl, n, delta=delta)

    print(f"Estimating risk for the {ordinal} model")
    mean_loss_test, _, losses_test = estimate_gibbs_loss(
        model, mu_flat, sigma_q, test_loader, device=device, S=S)

    return kl, bound, complexity, mean_loss_test, losses_test

if __name__ == "__main__":
    # Select which dataset directory (and the checkpoints stored in it) to use.
    translated = False
    path_mnist = "rotated_translated_mnist" if translated else "rotated_mnist"

    # load both models (state_dicts saved by training scripts)
    device = 'cpu'
    model_base = BaselineCNN().to(device)
    model_base.load_state_dict(torch.load(path_mnist + "/baseline_cnn.pt", map_location=device))
    model_eq = EquivariantCNN().to(device)
    model_eq.load_state_dict(torch.load(path_mnist + "/equivariant_cnn.pt", map_location=device))

    # dataset loaders for Monte-Carlo loss estimation
    train_ds = RotatedMNISTDataset(path_mnist + "/train.pt")
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
    test_ds = RotatedMNISTDataset(path_mnist + "/test.pt")
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

    # load prior means and set prior std
    # map_location keeps the priors on the same device as the models, matching
    # how the checkpoints above are loaded (the files may have been saved on GPU).
    mu_p_base = torch.load(path_mnist + "/prior_mu_baseline.pt", map_location=device)
    mu_p_eq = torch.load(path_mnist + "/prior_mu_equivariant.pt", map_location=device)
    sigma_p = 5e-2

    # posterior std candidates (only used by the optional sweeps below)
    sigmas = np.logspace(-3, 1, 15)

    mu_q_base = get_flat_params_from(model_base)
    # Optional: sweep for the posterior std minimising the bound instead of
    # using the fixed value below.
    # results_base = sweep_sigma_q_and_compute_curve(model_base, mu_q_base, mu_p_base, sigma_p, train_loader, device=device, sigmas=sigmas, S=10, n=len(train_ds), delta=0.05)
    # best_result_base = min(results_base, key=lambda x: x['bound'])
    # sigma_q_base = best_result_base['sigma_q']
    # print(f"Best sigma_q for the baseline = {sigma_q_base}, with bound = {best_result_base['bound']}")
    sigma_q_base = 5e-2

    mu_q_eq = get_flat_params_from(model_eq)
    # Optional: same sweep for the equivariant model.
    # results_eq = sweep_sigma_q_and_compute_curve(model_eq, mu_q_eq, mu_p_eq, sigma_p, train_loader, device=device, sigmas=sigmas, S=10, n=len(train_ds), delta=0.05)
    # best_result_eq = min(results_eq, key=lambda x: x['bound'])
    # sigma_q_eq = best_result_eq['sigma_q']
    # print(f"Best sigma_q for the equivariant model = {sigma_q_eq}, with bound = {best_result_eq['bound']}")
    sigma_q_eq = 5e-2

    S = 200  # posterior samples for histogram estimates

    # compute and save overlayed histograms + bounds
    losses1, losses2, kl1, kl2, bound1, bound2, complexity1, complexity2 = \
        plot_posterior_error_overlay(
            model_base, mu_q_base, sigma_q_base, mu_p_base,
            model_eq, mu_q_eq, sigma_q_eq, mu_p_eq,
            sigma_p,
            train_loader, test_loader,
            device=device, S=S, delta=0.05,
            title=("Baseline CNN", "Equivariant CNN"),
            bins=80,
            save_path=path_mnist+"/histogram_overlay.png"
        )

    # Summary: first pair of lines is the baseline model, second pair the
    # equivariant model. The ± term is the standard error over posterior samples.
    print(f"Mean loss: {losses1.mean():.4f} ± {losses1.std(ddof=1) / math.sqrt(len(losses1)):.4f}")
    print(f"KL(Q||P) = {kl1:.6f}, complexity = {complexity1:.6f}, McAllester bound = {bound1:.6f}")
    print(f"Mean loss: {losses2.mean():.4f} ± {losses2.std(ddof=1)/math.sqrt(len(losses2)):.4f}")
    print(f"KL(Q||P) = {kl2:.6f}, complexity = {complexity2:.6f}, McAllester bound = {bound2:.6f}")
