import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global RNG so the initial mu and every sampled batch are
# reproducible from run to run.
np.random.seed(35)

# Hyper-parameters
M = 6               # value placed on the first coordinate of every initial mu_i
n = 5               # number of vectors mu_i
d = 3               # dimensionality of each mu_i
lr = 0.01           # gradient-descent step size
iterations = 20000  # number of gradient-descent iterations

# Every row mu_i starts at (M, 0, ..., 0) plus an independent Gaussian
# perturbation of scale 1e-7, so the n rows are not exactly identical.
mu = np.concatenate(([M], np.zeros(d - 1))) + 1e-7 * np.random.randn(n, d)

def compute_batch_gradient(mu, X):
    """Return the sample-averaged gradient for ``mu`` and the per-sample ``V``.

    For each sample ``x`` (a row of ``X``) the weights are
    ``phi_i = softmax_i(-||x - mu_i||^2 / 2)`` over the n centres and
    ``V(x) = sum_i phi_i mu_i``. The per-sample gradient for centre ``i`` is
    the sum of the five terms built below; the returned gradient is the mean
    of that quantity over the batch.

    NOTE(review): the objective these five terms differentiate is not stated
    in this file -- confirm the formula against its derivation.

    Args:
        mu: array of shape (n, d), the current centres.
        X:  array of shape (N, d), the batch of samples; N must be >= 1.

    Returns:
        grad_batch: array of shape (n, d), the MEAN over the N samples of the
            per-sample gradient (not the sum).
        V_batch: array of shape (N, d), V(x) for every sample.

    Raises:
        ValueError: if ``X`` contains no samples (the mean would be 0/0).
    """
    N = X.shape[0]  # batch size
    if N == 0:
        raise ValueError("X must contain at least one sample")

    # Differences x - mu_i: (N, n, d); squared distances: (N, n).
    diff = X[:, None, :] - mu[None, :, :]
    sq_dists = np.sum(diff**2, axis=2)

    # Softmax weights phi: (N, n). Softmax is shift-invariant, so subtract the
    # row maximum: every exponent is then <= 0 (exp cannot overflow) and the
    # largest term is exp(0) = 1 (the denominator is >= 1). Shifting by the
    # row mean instead can still overflow to inf/inf = nan when the distances
    # within a row differ widely.
    logits = -sq_dists / 2.0
    logits -= logits.max(axis=1, keepdims=True)
    exps = np.exp(logits)
    phi = exps / exps.sum(axis=1, keepdims=True)

    w = phi[:, :, None]  # (N, n, 1): phi_i broadcast over the d axis

    # V(x) = sum_i phi_i mu_i for every sample: (N, d).
    V_batch = phi @ mu

    # Inner products shared by several terms.
    dot_mu_mu = mu @ mu.T                                 # (n, n): <mu_i, mu_j>
    dot_V_mu = V_batch @ mu.T                             # (N, n): <V, mu_i>
    dot_V_V = np.sum(V_batch**2, axis=1, keepdims=True)   # (N, 1): ||V||^2

    # term1: phi_i * V
    term1 = w * V_batch[:, None, :]

    # term2: phi_i * sum_j <mu_i, mu_j> phi_j mu_j
    # (batched matmul: (1, n, n) @ (N, n, d) -> (N, n, d))
    A = w * mu[None, :, :]                                # (N, n, d): phi_j mu_j
    term2 = w * (dot_mu_mu[None, :, :] @ A)

    # term3: -2 phi_i <V, mu_i> V
    term3 = -2 * w * (dot_V_mu[:, :, None] * V_batch[:, None, :])

    # term4: -2 phi_i sum_j phi_j <V, mu_j> mu_j
    S = np.sum(w * dot_V_mu[:, :, None] * mu[None, :, :], axis=1)  # (N, d)
    term4 = -2 * w * S[:, None, :]

    # term5: 3 phi_i ||V||^2 V
    term5 = 3 * w * dot_V_V[:, :, None] * V_batch[:, None, :]

    grad_samples = term1 + term2 + term3 + term4 + term5  # (N, n, d)

    # Average (not sum) over the N samples.
    grad_batch = grad_samples.mean(axis=0)                # (n, d)

    return grad_batch, V_batch

# Run stochastic gradient descent on mu, recording loss and parameter history.
batch_size = 20000
print_every = 200  # progress is printed on iterations divisible by this
loss_history = []
mu_norm_history = []
mu_history = []

for step in range(iterations):
    # Fresh standard-normal batch each iteration.
    x_batch = np.random.randn(batch_size, d)
    grad_avg, V_batch = compute_batch_gradient(mu, x_batch)

    # Reported loss: mean Euclidean norm of V over the batch, evaluated at the
    # parameters *before* this iteration's update.
    loss = np.linalg.norm(V_batch, axis=1, ord=2).mean()

    # Plain gradient-descent update.
    mu = mu - lr * grad_avg

    if step % print_every == 0:
        grad_norm = np.linalg.norm(grad_avg)
        print(f"Iteration {step}: gradient norm = {grad_norm:.4f}")
        print("Current mu parameters:")
        print(mu)
        print("-"*40)
        print(loss)
        print(1/loss)
        print(f"Iteration {step:4d}  Loss: {loss}")

    # Per-row norms ||mu_i|| and a copy of mu, both after the update.
    mu_norm_history.append(np.linalg.norm(mu, axis=1))
    loss_history.append(loss)
    mu_history.append(mu.copy())
    
# Log-log plot of the recorded loss curve.
loss_array = np.array(loss_history, dtype=float)

# Number iterations from 1 so that log(iteration) is finite. A dedicated
# name is used so the hyper-parameter `iterations` is not overwritten.
step_numbers = np.arange(1, len(loss_array) + 1)
log_iter = np.log(step_numbers)
log_loss = np.log(loss_array)

plt.figure(figsize=(8, 4))
plt.plot(log_iter, log_loss)
plt.xlabel(r"$\log(\mathrm{Iteration})$", fontsize=16)
plt.ylabel(r"$\log(\mathcal{L})$", fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{M}-ddpm_loglog_loss.png")
plt.show()

# Plot ||mu_i|| for every centre over the course of training.
mu_norm_history = np.stack(mu_norm_history, axis=0)  # (num_iterations, n)

plt.figure(figsize=(8,5))
for i, norms in enumerate(mu_norm_history.T):
    # Brace the subscript so multi-digit indices (i >= 10) render as a single
    # subscript (\mu_{10}) rather than \mu_1 followed by a plain "0".
    plt.plot(norms, label=f"$\\|\\mu_{{{i}}}\\|$")

plt.xlabel("Iteration", fontsize=16)
plt.ylabel("$\\|\\mu_i\\|$", fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.legend(ncol=2, fontsize="small")
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{M}-mu_norms.png")
plt.show()
