import numpy as np
from scipy.stats import wasserstein_distance
from scipy.integrate import quad
import matplotlib.pyplot as plt
import torch
import ot


def _mean_w2_to_target(draw_samples, target_points, n_runs):
    """Average the empirical 2-Wasserstein distance to a fixed target cloud.

    Args:
        draw_samples: Zero-argument callable returning a pair whose second
            element is a torch tensor of samples.
        target_points: 2-D NumPy array of reference samples.
        n_runs: Number of independent draws to average over.

    Returns:
        The mean over ``n_runs`` of sqrt(EMD with squared-Euclidean cost).
    """
    target_weights = np.full(len(target_points), 1.0 / len(target_points))
    total = 0.0
    for _ in range(n_runs):
        _, x = draw_samples()
        source_points = x.cpu().numpy()
        # Uniform weights sized from the batch actually returned, so the
        # sampler is not required to emit exactly len(target_points) points.
        source_weights = np.full(len(source_points), 1.0 / len(source_points))
        cost = ot.dist(source_points, target_points, metric='euclidean') ** 2
        w2_sq = ot.emd2(source_weights, target_weights, cost)
        total += np.sqrt(w2_sq) / n_runs
    return total


def get_avg_wasserstein(
    model,
    fk_sampler,
    st_sampler,
    base_width,
    beta,
    batch_size=4096,
    n_runs=50,
    st_n_iterations=100,
):
    """Compare two samplers against Metropolis-Hastings reference samples.

    Reference samples are drawn with ``metropolis_hastings_nd`` targeting
    ``log_p_tilde``. For each sampler the empirical W2 distance to that
    reference is averaged over ``n_runs`` independent draws, and both averages
    are printed.

    Args:
        model: Module switched to eval mode via ``model.eval()``; it is left
            in eval mode when the function returns.
        fk_sampler: Object whose ``sample()`` returns ``(_, samples)``.
        st_sampler: Object whose ``sample(n_iterations=..., batch_size=...)``
            returns ``(_, samples)``.
        base_width: Width parameter forwarded to the target density.
        beta: Reward weight forwarded to the target density.
        batch_size: Number of reference samples, and batch size requested
            from ``st_sampler``.
        n_runs: Number of sampler draws averaged per sampler.
        st_n_iterations: ``n_iterations`` passed to ``st_sampler.sample``.

    Returns:
        Tuple ``(fk_w2, st_w2)`` of averaged Wasserstein-2 distances.
    """
    model.eval()

    # Reference samples, shape (batch_size, 1) because x0 has one coordinate.
    samples_target = metropolis_hastings_nd(
        log_p=log_p_tilde,
        x0=np.array([1.5]),
        beta=beta,
        base_width=base_width,
        n_samples=batch_size,
    )

    # Scope the no-grad region so autograd is re-enabled for the caller;
    # calling __enter__() without __exit__() would leave it off for good.
    with torch.no_grad():
        fk_w2 = _mean_w2_to_target(fk_sampler.sample, samples_target, n_runs)
        print("FK_wasserstein", fk_w2, base_width)

        st_w2 = _mean_w2_to_target(
            lambda: st_sampler.sample(
                n_iterations=st_n_iterations,
                batch_size=batch_size,
            ),
            samples_target,
            n_runs,
        )
        print("SPT_wasserstein", st_w2)

    # Return averaged wasserstein distances
    return fk_w2, st_w2


def marginal_prob_std(t, sigma):
    """
    Returns the standard deviation of our perturbation,
    sqrt((sigma**(2t) - 1) / (2 * log(sigma))).

    Args:
        t: Time(s); a tensor or anything ``torch.tensor`` accepts. Integer or
            boolean input is promoted to the default floating dtype so the
            result is not truncated when cast back.
        sigma: Noise scale; a Python/NumPy number or a tensor.

    Returns:
        Tensor on ``t``'s device with ``t``'s (floating) dtype.
    """
    # Ensure tensor, preserve device & dtype
    if not torch.is_tensor(t):
        t = torch.tensor(t)
    if not (t.is_floating_point() or t.is_complex()):
        # An integer dtype here would make the final .to(dtype=...) cast the
        # standard deviation back to an integer.
        t = t.to(torch.get_default_dtype())
    device = t.device
    dtype = t.dtype

    # Use the torch logarithm for tensor sigma; NumPy for plain numbers.
    log_sigma = torch.log(sigma) if torch.is_tensor(sigma) else np.log(sigma)

    sigma_t = sigma ** (2 * t)
    return torch.sqrt((sigma_t - 1.0) / (2 * log_sigma)).to(
        device=device, dtype=dtype
    )


def diff_coeff(t, sigma):
    """
    Diffusion coefficient g(t) = sigma**t.

    Non-tensor ``t`` is wrapped with ``torch.tensor`` before exponentiation.
    """
    exponent = t if torch.is_tensor(t) else torch.tensor(t)
    return sigma**exponent


def p_base(x, sigma_target=0.2):
    """
    Unnormalized two-bump base density.

    The first coordinate carries two Gaussian bumps, weight 0.9 centred at
    -1.5 and weight 0.1 centred at +1.5; any remaining coordinates are
    penalised isotropically around 0 with the same variance.

    Args:
        x: Either a scalar (Python number, NumPy scalar, 0-d array), treated
            as the first coordinate of a 1-D point, or an array-like whose
            last axis indexes coordinates.
        sigma_target: Standard deviation shared by both bumps.

    Returns:
        The density value(s), with the last axis of ``x`` reduced away.
    """
    sigma2 = sigma_target**2
    if np.ndim(x) == 0:
        # Any scalar, not just the builtin float: ints and NumPy scalars
        # cannot be indexed with x[..., 0].
        xd = x
        other_sq = 0.0
    else:
        if not hasattr(x, "shape"):
            # Plain sequences (lists/tuples) need array indexing below.
            x = np.asarray(x, dtype=float)
        # first dimension
        xd = x[..., 0]
        if x.shape[-1] > 1:
            other_sq = np.sum(x[..., 1:] ** 2, axis=-1)
        else:
            other_sq = 0.0

    # two Gaussians along first dimension, other dims isotropic
    gauss1 = 0.9 * np.exp(-((xd + 1.5) ** 2 + other_sq) / (2 * sigma2))
    gauss2 = 0.1 * np.exp(-((xd - 1.5) ** 2 + other_sq) / (2 * sigma2))

    return gauss1 + gauss2


def R(x, base_width):
    """
    Two-peak reward function.

    A peak of height 2.0 sits at first-coordinate +1.5 with squared-distance
    scale ``base_width``; a peak of height 1.5 sits at -1.5 with fixed scale
    3.0. Remaining coordinates contribute their squared norm to both peaks.

    Args:
        x: Either a scalar (Python number, NumPy scalar, 0-d array), treated
            as the first coordinate of a 1-D point, or an array-like whose
            last axis indexes coordinates.
        base_width: Divisor of the squared distance in the +1.5 peak.

    Returns:
        The reward value(s), with the last axis of ``x`` reduced away.
    """
    if np.ndim(x) == 0:
        # Any scalar, not just the builtin float: ints and NumPy scalars
        # cannot be indexed with x[..., 0].
        xd = x
        other_sq = 0.0
    else:
        if not hasattr(x, "shape"):
            # Plain sequences (lists/tuples) need array indexing below.
            x = np.asarray(x, dtype=float)
        xd = x[..., 0]
        if x.shape[-1] > 1:
            other_sq = np.sum(x[..., 1:] ** 2, axis=-1)
        else:
            other_sq = 0.0

    # global peak (low-mass Gaussian)
    global_peak = 2.0 * np.exp(-((xd - 1.5) ** 2 + other_sq) / base_width)
    # local peak (high-mass Gaussian)
    local_peak = 1.5 * np.exp(-((xd + 1.5) ** 2 + other_sq) / 3.0)

    return global_peak + local_peak


def log_p_tilde(
    x,
    beta,
    base_width,
    reward_fn,
    sigma_target=0.2,
):
    """
    Unnormalized log-density of the reward-tilted target.

    Computes log(p_base(x, sigma_target)) + beta * reward_fn(x, base_width).
    """
    log_prior = np.log(p_base(x, sigma_target))
    tilt = beta * reward_fn(x, base_width)
    return log_prior + tilt


def unnormalized_pdf(x, beta, base_width):
    """
    Unnormalized tilted density: p_base(x) * exp(beta * R(x, base_width)).

    Uses the default ``sigma_target`` of ``p_base``.
    """
    prior = p_base(x)
    weight = np.exp(beta * R(x, base_width))
    return prior * weight


def normalization_constant(beta, base_width, L=10.0):
    """
    Numerically integrate ``unnormalized_pdf`` over [-L, L].

    The interval is split into three pieces so that the window
    [1.5 - base_width, 1.5 + base_width] is integrated on its own with a much
    larger subdivision limit; the three partial integrals are summed.
    """
    spike = 1.5
    eps = base_width  # half-width of the separately integrated window

    def integrand(x):
        return unnormalized_pdf(x, beta, base_width)

    window_lo = spike - eps
    window_hi = spike + eps

    left, _ = quad(integrand, -L, window_lo)
    # Window around the spike gets a much higher subdivision limit.
    middle, _ = quad(integrand, window_lo, window_hi, limit=200000)
    right, _ = quad(integrand, window_hi, L)

    return left + middle + right


def metropolis_hastings_nd(
    log_p,
    x0,
    beta,
    n_samples,
    base_width,
    proposal_std_small=0.05,
    proposal_std_large=2.0,
    burn_in=10000,
    reward_fn=R,
):
    """
    Random-walk Metropolis sampler with a two-scale Gaussian proposal.

    Each step proposes a Gaussian move whose standard deviation is
    ``proposal_std_small`` with probability 0.9 and ``proposal_std_large``
    otherwise, then accepts or rejects it with the Metropolis rule using
    ``log_p(x, beta, base_width, reward_fn=reward_fn)``.

    The first ``burn_in`` states are discarded; the following ``n_samples``
    states are collected. The overall acceptance rate is printed.

    Returns:
        ``np.array`` of the collected states, one row per kept step.
    """
    current = np.array(x0, dtype=float)
    dim = current.shape[0]
    current_log_p = log_p(current, beta, base_width, reward_fn=reward_fn)

    total_steps = n_samples + burn_in
    kept = []
    n_accepted = 0

    for step in range(total_steps):
        # Small vs. large jump (scale is drawn before the Gaussian noise).
        if np.random.rand() < 0.9:
            step_std = proposal_std_small
        else:
            step_std = proposal_std_large
        candidate = current + step_std * np.random.randn(dim)

        candidate_log_p = log_p(
            candidate, beta, base_width, reward_fn=reward_fn
        )

        # Metropolis acceptance in log space.
        log_ratio = candidate_log_p - current_log_p
        if np.log(np.random.rand()) < log_ratio:
            current = candidate
            current_log_p = candidate_log_p
            n_accepted += 1

        # Record the state only once burn-in is over.
        if step >= burn_in:
            kept.append(current.copy())

    print("Acceptance rate:", n_accepted / total_steps)
    return np.array(kept)
