import jax
import jax.numpy as jnp
import jax.random as jrandom

# z0, z1: two vectors of dimension dim
def generate_points_mixture_jax(L, sigma, dim, samples, z0, z1, key):
    """Sample points from an equal-weight two-component mixture around z0 and z1.

    Every one of the ``samples * L`` points is independently assigned to
    ``z1`` (Bernoulli draw is True) or ``z0`` (draw is False) with
    probability 0.5. When ``sigma == 0`` the point is exactly the chosen
    centroid; otherwise it is drawn from a multivariate normal whose mean is
    that centroid and whose covariance is ``sigma**2 * I``.

    Args:
        L: Number of points per sample (second output axis). Coerced with
            ``int()``.
        sigma: Standard deviation of each component. It is tested with
            ``== 0`` in a Python ``if``, so pass a plain number rather than
            a traced value.
        dim: Dimension of each point (last output axis). Coerced with
            ``int()``.
        samples: Number of samples (first output axis). Coerced with
            ``int()``.
        z0: Centroid selected when the Bernoulli draw is False; a vector of
            dimension ``dim``.
        z1: Centroid selected when the Bernoulli draw is True; a vector of
            dimension ``dim``.
        key: JAX PRNG key used for all random draws.

    Returns:
        Array of shape ``(samples, L, dim)``.
    """
    # All three counts end up inside array shapes, which must be Python ints;
    # coerce them uniformly so float-valued counts (e.g. 1e3) are accepted
    # for every count, not only for `samples`.
    samples = int(samples)
    L = int(L)
    dim = int(dim)

    # Keep the 4-way split (first key unused) so that existing seeds keep
    # producing the same draws as before.
    _, subkey1, subkey2, subkey3 = jrandom.split(key, 4)

    # One fair coin flip per point: True -> z1, False -> z0.
    choices = jrandom.bernoulli(subkey1, p=0.5, shape=(samples, L))

    if sigma == 0:
        # Dirac mixture: points sit exactly on z0 or z1, so just broadcast
        # the centroids to the output shape and select.
        points_z0 = jnp.broadcast_to(z0, (samples, L, dim))
        points_z1 = jnp.broadcast_to(z1, (samples, L, dim))
        points = jnp.where(choices[..., None], points_z1, points_z0)
    else:
        # Isotropic covariance shared by both components.
        cov = sigma**2 * jnp.eye(dim)

        # Draw a candidate from each component for every point, then keep
        # the one matching the coin flip.
        noise_z0 = jrandom.multivariate_normal(
            subkey2, mean=z0, cov=cov, shape=(samples, L)
        )
        noise_z1 = jrandom.multivariate_normal(
            subkey3, mean=z1, cov=cov, shape=(samples, L)
        )
        points = jnp.where(choices[..., None], noise_z1, noise_z0)

    return points  # shape (samples, L, dim)