import torch

# --------------------
# Online Updates
# --------------------
def online_update_pl(theta, ranking, X, H_pl, eta, B=1.0):
    """Run one pass of preconditioned updates over a Plackett-Luce ranking.

    Stage ``i`` of the ranking is treated as a softmax choice of
    ``ranking[i]`` among the still-available items ``ranking[i:]``. At each
    stage the curvature matrix is grown by ``eta`` times the choice
    covariance, and ``theta`` takes a Newton-style step
    ``-eta * H_tilde^{-1} grad``.

    Args:
        theta: Current parameter vector, shape ``(d,)``. Not modified in place.
        ranking: Sequence/tensor of row indices into ``X``. ``ranking[0]`` is
            the first item chosen.
        X: Item feature matrix, shape ``(n_items, d)``.
        H_pl: Initial curvature matrix, shape ``(d, d)``. It is cloned, so the
            caller's tensor is untouched. The accumulated matrix is not
            returned.
        eta: Step size. It scales both the curvature increment and the step.
        B: Radius of the L2 ball that the final ``theta`` is projected onto.

    Returns:
        The updated parameter vector, with L2 norm at most ``B``.

    Raises:
        Whatever ``torch.linalg.solve`` raises if ``H_tilde`` is singular
        (e.g. when ``H_pl`` is not positive definite).
    """
    H_tilde = H_pl.clone()

    # The last stage has a single candidate: zero gradient and zero curvature.
    for i in range(len(ranking) - 1):
        X_sub = X[ranking[i:]]  # remaining candidates; row 0 is the chosen one

        # softmax subtracts the max internally, so it cannot overflow and,
        # unlike clamping the utilities, it does not distort the probabilities.
        p = torch.softmax(X_sub @ theta, dim=0)  # (K-i,)

        mean_vec = p @ X_sub  # expected feature vector under p, (d,)
        centered = X_sub - mean_vec
        # Covariance of features under p == E[x x^T] - mean mean^T.
        H_tilde += eta * ((p[:, None] * centered).T @ centered)

        # Gradient of -log P(ranking[i] chosen | remaining items).
        grad = mean_vec - X[ranking[i]]
        # Rebind rather than subtract in place, so the caller's tensor is never mutated.
        theta = theta - eta * torch.linalg.solve(H_tilde, grad)

    # Project onto the L2 ball of radius B.
    norm = torch.linalg.vector_norm(theta)
    if norm > B:
        theta = theta * (B / norm)
    return theta

def online_update_rb(theta, pairs, X, H_rb, eta, B=1):
    """Run one pass of preconditioned logistic updates over pairwise comparisons.

    Each ``(winner, loser)`` pair contributes a logistic loss on the feature
    difference ``X[winner] - X[loser]``. Per pair, the curvature matrix grows
    by ``eta * s * (1 - s) * v v^T``. ``theta`` then moves by
    ``eta * pinv(H_tilde) @ grad``, where ``grad = -(1 - s) * v``.

    Args:
        theta: Current parameter vector, shape ``(d,)``. Rebound, not mutated.
        pairs: Iterable of ``(i, j)`` row-index pairs into ``X``. The first
            index is the item the loss pushes to score higher.
        X: Item feature matrix, shape ``(n_items, d)``.
        H_rb: Initial curvature matrix, shape ``(d, d)``. It is cloned and
            the accumulated matrix is not returned.
        eta: Step size for both the curvature increment and the step.
        B: Radius of the L2 ball that the final ``theta`` is projected onto.

    Returns:
        The updated parameter vector, with L2 norm at most ``B``.
    """
    H_tilde = H_rb.clone()

    for winner, loser in pairs:
        delta_x = X[winner] - X[loser]
        s = torch.sigmoid(delta_x @ theta)

        # Logistic curvature s(1-s) along the difference direction.
        H_tilde += eta * (s * (1 - s)) * torch.outer(delta_x, delta_x)

        # Gradient of -log sigmoid(delta_x . theta).
        gradient = -(1 - s) * delta_x
        # pinv tolerates a singular H_tilde (e.g. a zero initial matrix).
        theta = theta - eta * torch.linalg.pinv(H_tilde) @ gradient

    # Keep theta inside the L2 ball of radius B.
    radius = torch.norm(theta)
    if radius <= B:
        return theta
    return theta * (B / radius)

# --------------------
# Hessians
# --------------------
def pl_hessian(theta, X, ranking):
    """Return the Plackett-Luce negative log-likelihood Hessian for one ranking.

    Stage ``i`` is a softmax choice of ``ranking[i]`` among ``ranking[i:]``.
    Its Hessian contribution is the covariance of the candidates' features
    under the softmax probabilities. The result is the sum over stages.

    Args:
        theta: Parameter vector, shape ``(d,)``. It sets the output dtype/device.
        X: Item feature matrix, shape ``(n_items, d)``.
        ranking: Sequence/tensor of row indices into ``X``, first choice first.

    Returns:
        A symmetric positive-semidefinite ``(d, d)`` tensor. It is all zeros
        when the ranking has fewer than two items.
    """
    d = theta.shape[0]
    H = torch.zeros(d, d, dtype=theta.dtype, device=theta.device)

    # The final stage has a single candidate and contributes nothing.
    for i in range(len(ranking) - 1):
        X_sub = X[ranking[i:]]

        # Overflow-safe and exact. Clamping utilities would instead flatten
        # the probabilities once |utility| exceeds the clamp bound.
        p = torch.softmax(X_sub @ theta, dim=0)

        # Weighted covariance in centered form (== E[x x^T] - mu mu^T),
        # which avoids cancellation between two nearly equal terms.
        centered = X_sub - p @ X_sub
        H += (p[:, None] * centered).T @ centered

    return H

def rb_hessian(theta, X, pairs):
    """Return the summed logistic-loss Hessian over a set of pairwise comparisons.

    Each ``(i, j)`` pair contributes ``s * (1 - s) * v v^T``, where
    ``v = X[i] - X[j]`` and ``s = sigmoid(v . theta)``.

    Args:
        theta: Parameter vector, shape ``(d,)``.
        X: Item feature matrix, shape ``(n_items, d)``. It sets the output
            dtype/device.
        pairs: Iterable of ``(i, j)`` row-index pairs into ``X``.

    Returns:
        A ``(d, d)`` tensor. It is all zeros when ``pairs`` is empty.
    """
    dim = theta.shape[0]
    hessian = X.new_zeros((dim, dim))

    for first, second in pairs:
        v = X[first] - X[second]
        s = torch.sigmoid(v @ theta)
        # Logistic curvature weight times the rank-one direction matrix.
        hessian += (s * (1 - s)) * torch.outer(v, v)

    return hessian