import jax.numpy as jnp
import jax.random as jrandom


# Function that generates a random point (mu0, mu1) on the manifold.

def generate_centroid_manifold(mu0_star, mu1_star, subkey0, subkey1, dim):
    """Sample a random pair of unit vectors ``(mu0, mu1)`` under orthogonality constraints.

    ``mu0`` is drawn from a standard normal, has its component along
    ``mu1_star`` removed, and is normalised.  ``mu1`` is drawn from a standard
    normal, has its component in span{``mu0_star``, ``mu0``} removed, and is
    normalised.  Hence ``mu0 ⊥ mu1_star``, ``mu1 ⊥ mu0_star`` and ``mu1 ⊥ mu0``.

    Args:
        mu0_star: Reference vector of shape ``(dim,)``; ``mu1`` is made
            orthogonal to it.
        mu1_star: Reference vector of shape ``(dim,)``; ``mu0`` is made
            orthogonal to it.  Must be non-zero.
        subkey0: JAX PRNG key used to draw ``mu0``.
        subkey1: JAX PRNG key used to draw ``mu1``.
        dim: Length of the sampled vectors.

    Returns:
        Tuple ``(mu0, mu1)`` of unit-norm arrays of shape ``(dim,)``.
    """
    # mu0: random direction with its mu1_star component removed.
    mu0 = jrandom.normal(subkey0, shape=(dim,))
    mu0 = mu0 - (jnp.dot(mu0, mu1_star) / jnp.dot(mu1_star, mu1_star)) * mu1_star
    mu0 = mu0 / jnp.linalg.norm(mu0)

    # mu1: random direction with its component in span{mu0_star, mu0} removed.
    mu1 = jrandom.normal(subkey1, shape=(dim,))
    A = jnp.stack([mu0_star, mu0], axis=1)  # shape: (dim, 2)
    # Use SVD-based least squares instead of solving the normal equations
    # (A.T @ A) x = A.T @ mu1.  When mu0_star and mu0 are (nearly) collinear
    # -- e.g. dim == 2 with mu0_star orthogonal to mu1_star -- A.T @ A is
    # singular and solve() silently returns non-finite values, whereas
    # lstsq still yields the projection onto A's column space.
    coeffs = jnp.linalg.lstsq(A, mu1)[0]
    mu1 = mu1 - A @ coeffs
    mu1 = mu1 / jnp.linalg.norm(mu1)

    return mu0, mu1
