import numpy as np
from scipy.special import expit

def online_update_pl(theta, ranking, X, H_pl, eta, B=1):
    """Run one pass of preconditioned online updates on a Plackett-Luce ranking.

    The ranking is processed stage by stage: at stage ``i`` the item
    ``ranking[i]`` is treated as chosen from the remaining set
    ``ranking[i:]``. For each stage, the softmax-weighted covariance of the
    remaining items' features (the Hessian of that stage's negative
    log-likelihood) is added, scaled by ``eta``, to a running copy of
    ``H_pl``. Then ``theta`` takes the step ``-eta * pinv(H_tilde) @ grad``.

    Args:
        theta: Current parameter vector of length ``d``. It is not modified;
            a float copy is updated and returned.
        ranking: Sequence of row indices into ``X``; ``ranking[0]`` is the
            first choice.
        X: Item feature matrix; ``X[k]`` is the feature vector of item ``k``.
        H_pl: ``(d, d)`` base preconditioner. It is not modified.
        eta: Step size.
        B: Radius of the L2 ball the result is rescaled onto.

    Returns:
        The updated parameter vector, rescaled so its L2 norm is at most ``B``.
    """
    # Work on copies so the caller's arrays are never mutated in place.
    theta = np.array(theta, dtype=float)
    H_tilde = np.array(H_pl, dtype=float)
    # A tuple would otherwise be interpreted as a multi-axis index into X.
    ranking = np.asarray(ranking)
    for i in range(len(ranking) - 1):
        X_sub = X[ranking[i:]]
        utilities = X_sub @ theta
        # Softmax is invariant to a constant shift; subtracting the max
        # prevents exp overflow (inf / inf -> nan) for large utilities.
        exp_utils = np.exp(utilities - np.max(utilities))
        p = exp_utils / np.sum(exp_utils)
        mean_vec = p @ X_sub
        weighted_outer = np.einsum('i,ij,ik->jk', p, X_sub, X_sub)
        H_tilde += eta * (weighted_outer - np.outer(mean_vec, mean_vec))
        # Gradient of -log P(ranking[i] is chosen from ranking[i:]).
        grad = mean_vec - X[ranking[i]]
        theta -= eta * np.linalg.pinv(H_tilde) @ grad
    norm = np.linalg.norm(theta)
    return theta if norm <= B else theta * (B / norm)

def online_update_rb(theta, pairs, X, H_rb, eta, B=1):
    """Run one pass of preconditioned online updates on pairwise comparisons.

    Each ``(i, j)`` in ``pairs`` contributes the logistic loss
    ``-log sigmoid((X[i] - X[j]) @ theta)``, so item ``i`` is treated as
    preferred over item ``j``. For each pair, that loss's Hessian, scaled by
    ``eta``, is added to a running copy of ``H_rb``. Then ``theta`` takes
    the step ``-eta * pinv(H_tilde) @ grad``.

    Args:
        theta: Current parameter vector of length ``d``. It is not modified;
            a float copy is updated and returned.
        pairs: Iterable of ``(i, j)`` row-index pairs into ``X``.
        X: Item feature matrix; ``X[k]`` is the feature vector of item ``k``.
        H_rb: ``(d, d)`` base preconditioner. It is not modified.
        eta: Step size.
        B: Radius of the L2 ball the result is rescaled onto.

    Returns:
        The updated parameter vector, rescaled so its L2 norm is at most ``B``.
    """
    # Work on copies so the caller's arrays are never mutated in place.
    theta = np.array(theta, dtype=float)
    H_tilde = np.array(H_rb, dtype=float)
    for i, j in pairs:
        diff = X[i] - X[j]
        sig = expit(diff @ theta)
        # Gradient of -log sigmoid(diff @ theta).
        grad = -(1 - sig) * diff
        sig_dot = sig * (1 - sig)
        H_tilde += eta * sig_dot * np.outer(diff, diff)
        theta -= eta * np.linalg.pinv(H_tilde) @ grad
    norm = np.linalg.norm(theta)
    return theta if norm <= B else theta * (B / norm)

def pl_hessian(theta, X, ranking):
    """Return the Plackett-Luce negative log-likelihood Hessian for one ranking.

    Sums over stages ``i = 0 .. len(ranking) - 2``. Each stage contributes
    the covariance of the rows ``X[ranking[i:]]`` under the softmax weights
    ``softmax(X[ranking[i:]] @ theta)``.

    Args:
        theta: Parameter vector of length ``d``.
        X: Item feature matrix; ``X[k]`` is the feature vector of item ``k``.
        ranking: Sequence of row indices into ``X``; ``ranking[0]`` is the
            first choice.

    Returns:
        A ``(d, d)`` float array. It is all zeros when ``ranking`` has fewer
        than two items.
    """
    d = len(theta)
    H = np.zeros((d, d))
    # A tuple would otherwise be interpreted as a multi-axis index into X.
    ranking = np.asarray(ranking)
    for i in range(len(ranking) - 1):
        X_sub = X[ranking[i:]]
        utilities = X_sub @ theta
        # Shift by the max before exp; the softmax is unchanged, but overflow
        # (which would make the whole Hessian nan) is avoided.
        exp_utils = np.exp(utilities - np.max(utilities))
        p = exp_utils / np.sum(exp_utils)
        mean_vec = p @ X_sub
        weighted_outer = np.einsum('i,ij,ik->jk', p, X_sub, X_sub)
        H += weighted_outer - np.outer(mean_vec, mean_vec)
    return H

def rb_hessian(theta, X, pairs):
    """Accumulate the logistic-loss Hessian over a collection of comparisons.

    For every ``(i, j)`` in ``pairs``, with ``delta = X[i] - X[j]`` and
    ``s = sigmoid(delta @ theta)``, the term ``s * (1 - s) * outer(delta, delta)``
    is added to the result.

    Args:
        theta: Parameter vector of length ``d``.
        X: Item feature matrix; ``X[k]`` is the feature vector of item ``k``.
        pairs: Iterable of ``(i, j)`` row-index pairs into ``X``.

    Returns:
        A ``(d, d)`` float array. It is all zeros when ``pairs`` is empty.
    """
    dim = len(theta)

    def _pair_curvature(first, second):
        # Hessian contribution of a single comparison.
        delta = X[first] - X[second]
        s = expit(delta @ theta)
        return s * (1 - s) * np.outer(delta, delta)

    return sum(
        (_pair_curvature(first, second) for first, second in pairs),
        np.zeros((dim, dim)),
    )

def generate_biased_contexts(N, d, noise_scale=0.1):
    """Draw ``N`` unit-norm rows clustered around one random unit direction.

    A direction ``v`` is drawn from NumPy's global RNG and normalized. Each
    row is ``v + noise_scale * n``, renormalized, where ``n`` is a Gaussian
    draw with its component along ``v`` removed. Because ``n`` is orthogonal
    to ``v``, every row has a positive projection onto ``v``.

    Args:
        N: Number of rows.
        d: Dimension of each row.
        noise_scale: Magnitude of the orthogonal perturbation.

    Returns:
        An ``(N, d)`` float array with unit-norm rows.
    """
    direction = np.random.randn(d)
    direction /= np.linalg.norm(direction)

    rows = []
    for _ in range(N):
        perturb = np.random.randn(d)
        # Strip the component along `direction` so the noise is orthogonal to it.
        perturb = perturb - (perturb @ direction) * direction
        candidate = direction + noise_scale * perturb
        rows.append(candidate / np.linalg.norm(candidate))
    return np.array(rows, dtype=float).reshape(N, d)

def generate_mostly_orthogonal_X(N, d, theta_star, noise_scale=0.01):
    """Build an ``(N, d)`` matrix whose leading rows are nearly orthogonal to ``theta_star``.

    The first ``max(N - d, 0)`` rows are random Gaussian directions with
    their component along ``theta_star`` removed. They are normalized,
    perturbed by Gaussian noise of standard deviation ``noise_scale``, and
    renormalized. The remaining ``min(N, d)`` rows are normalized Gaussian
    draws, i.e. random unit vectors. All randomness comes from NumPy's global
    RNG.

    Args:
        N: Number of rows.
        d: Dimension of each row.
        theta_star: Reference vector of length ``d``. It need not be
            unit-norm; only its direction is used. A zero vector means no
            component is removed.
        noise_scale: Standard deviation of the noise added to the orthogonal rows.

    Returns:
        An ``(N, d)`` float array with unit-norm rows.
    """
    theta_star = np.asarray(theta_star, dtype=float)
    theta_norm = np.linalg.norm(theta_star)
    # Project onto the unit direction; (r @ t) * t only removes the parallel
    # component when t has norm 1.
    direction = theta_star / theta_norm if theta_norm > 0 else np.zeros_like(theta_star)

    # Clip so N < d does not produce a negative start index (which would wrap
    # around and overwrite rows).
    n_orth = max(N - d, 0)
    X = np.zeros((N, d))
    for i in range(n_orth):
        while True:
            rand_vec = np.random.randn(d)
            x_i = rand_vec - (rand_vec @ direction) * direction
            norm_x = np.linalg.norm(x_i)
            # Resample in the degenerate case of a draw (almost) parallel to theta_star.
            if norm_x > 1e-15:
                break
        x_i /= norm_x
        # Add small Gaussian noise and renormalize.
        x_i_noisy = x_i + np.random.randn(d) * noise_scale
        X[i] = x_i_noisy / np.linalg.norm(x_i_noisy)

    for i in range(n_orth, N):
        rand_vec = np.random.randn(d)
        X[i] = rand_vec / np.linalg.norm(rand_vec)

    return X

def get_contexts(env, N, d, theta_star, C=None, noise_scale=0.1, seed=12345):
    """Generate the list of context matrices for a synthetic environment.

    NumPy's global RNG is reseeded with ``seed`` first. Results are therefore
    reproducible, but the global random state is reset as a side effect.

    Args:
        env: One of ``"fixed"``, ``"hard_instance"``, ``"varying"`` or
            ``"skewed_but_varying"``.
        N: Number of rows (items) per context matrix.
        d: Feature dimension.
        theta_star: Reference vector, used only by ``"hard_instance"``. If it
            is ``None`` there, a random unit vector is drawn.
        C: Number of context matrices. Required for every env except ``"fixed"``.
        noise_scale: Passed to ``generate_biased_contexts`` for
            ``"skewed_but_varying"``.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        A list of ``(N, d)`` arrays. It has one element for ``"fixed"`` and
        ``C`` elements otherwise.

    Raises:
        ValueError: If ``env`` is unknown, or ``C`` is missing for a
            multi-context env.
    """
    np.random.seed(seed)

    def _unit_rows(M):
        # Normalize each row to unit L2 norm.
        return M / np.linalg.norm(M, axis=1, keepdims=True)

    if env == "fixed":
        return [_unit_rows(np.random.randn(N, d))]

    if env not in ("hard_instance", "varying", "skewed_but_varying"):
        raise ValueError(f"Unknown environment type: {env}")
    if C is None:
        # Previously surfaced as an opaque TypeError from range(None).
        raise ValueError(f"C (number of contexts) is required for env={env!r}")

    if env == "hard_instance":
        if theta_star is None:
            theta_star = np.random.randn(d)
            theta_star /= np.linalg.norm(theta_star)
        return [
            _unit_rows(M)
            for M in [generate_mostly_orthogonal_X(N, d, theta_star) for _ in range(C)]
        ]

    if env == "varying":
        return [_unit_rows(M) for M in [np.random.randn(N, d) for _ in range(C)]]

    # env == "skewed_but_varying"
    return [generate_biased_contexts(N, d, noise_scale) for _ in range(C)]