import jax
import jax.numpy as jnp
import jax.random as jrandom


def generate_points_mixture_jax_3(L, sigma, dim, samples, z0, z1, z2, key, logits=None):
    """Sample points from a three-component isotropic Gaussian mixture.

    Each of the ``samples * L`` points independently picks one of the
    centroids ``z0``, ``z1``, ``z2``. When ``sigma == 0`` the point is the
    chosen centroid itself (a Dirac mixture). Otherwise it is drawn from a
    Gaussian centred on that centroid with covariance ``sigma**2 * I``.

    Args:
        L: Number of points per sample (second output axis).
        sigma: Noise standard deviation. It is compared against 0 with a
            Python ``if``, so it must be a concrete value (not a traced value
            inside ``jax.jit``).
        dim: Size of the covariance matrix. Only used when ``sigma != 0``;
            presumably equal to the centroids' length (TODO confirm at callers).
        samples: Number of independent samples (first output axis). It is
            converted with ``int``, so e.g. ``2.0`` is accepted.
        z0, z1, z2: Component means, assumed to be 1-D arrays of length
            ``dim``.
        key: JAX PRNG key.
        logits: Optional unnormalised log-weights for the three components,
            with last axis of size 3. ``None`` (default) picks each component
            uniformly at random.

    Returns:
        Array of shape ``(samples, L, dim)``, assuming 1-D centroids of
        length ``dim``.

    Raises:
        ValueError: If ``logits`` is given and does not describe exactly
            three components.
    """
    samples = int(samples)
    # Keep the 5-way split so each subkey feeds the same draw for a given key.
    # The leading key is not needed here.
    _, k_choice, k_noise0, k_noise1, k_noise2 = jrandom.split(key, 5)

    # Component index in {0, 1, 2} for every point, shape (samples, L).
    if logits is None:
        choices = jrandom.randint(k_choice, shape=(samples, L), minval=0, maxval=3)
    else:
        logits = jnp.asarray(logits)
        if logits.ndim == 0 or logits.shape[-1] != 3:
            raise ValueError(
                f"logits must have a last axis of size 3, got shape {logits.shape}"
            )
        choices = jrandom.categorical(k_choice, logits, shape=(samples, L))

    if sigma == 0:
        # Dirac mixture: every point sits exactly on its chosen centroid.
        centroids = jnp.stack([z0, z1, z2], axis=0)  # (3, dim)
        return centroids[choices]

    cov = sigma**2 * jnp.eye(dim)

    # Draw a candidate from every component, then keep the chosen one.
    # All three are drawn so the random stream matches for a given key.
    candidates = jnp.stack(
        [
            jrandom.multivariate_normal(k, mean=z, cov=cov, shape=(samples, L))
            for k, z in ((k_noise0, z0), (k_noise1, z1), (k_noise2, z2))
        ],
        axis=2,
    )  # (samples, L, 3, dim)

    # (samples, L, 1, 1): selects along the component axis and broadcasts
    # across the trailing point dimension.
    idx = choices[..., None, None]
    # take_along_axis leaves a size-1 component axis; squeeze it away.
    return jnp.take_along_axis(candidates, idx, axis=2).squeeze(axis=2)

# Example configuration: three one-hot centroids in a 6-dimensional space.
key = jrandom.PRNGKey(1)
dim = 6
samples = 2
L = 3

# Uniform (all-zero) logits over the three mixture components.
logits = jnp.array([0.0] * 3)

# Middle coordinate index, as an integer JAX array (3 for dim=6).
d2 = jnp.floor(dim / 2).astype(int)

key, sk0, sk1, sk2 = jrandom.split(key, 4)

z0 = jnp.zeros(dim).at[dim - 1].set(1)  # one-hot on the last coordinate
z1 = jnp.zeros(dim).at[0].set(1)        # one-hot on the first coordinate
z2 = jnp.zeros(dim).at[d2].set(1)       # one-hot on the middle coordinate

# Example usage:
#   a = generate_points_mixture_jax_3(L, 0.01, dim, samples, z0, z1, z2, key)
#   print(a.shape)  # (samples, L, dim)
