import jax.numpy as jnp

def make_boolean_vector(N, first=True, ratio=0.3):
    """Return a length-``N`` boolean vector containing one contiguous block of ``True``.

    The block has ``k = ceil(N * ratio)`` entries, clamped to ``[0, N]``. It is
    placed at the start of the vector when ``first`` is true, otherwise at the end.
    All remaining entries are ``False``.

    Args:
        N: Length of the vector (expected to be a non-negative Python int).
        first: If True, the ``True`` block occupies the leading entries;
            otherwise it occupies the trailing entries.
        ratio: Fraction of entries to set to ``True``. Values that give
            ``k <= 0`` yield an all-``False`` vector; values that give
            ``k >= N`` yield an all-``True`` vector.

    Returns:
        A ``jnp`` array of shape ``(N,)`` and dtype ``bool``.
    """
    # jnp.ceil (rather than math.ceil) is kept so k matches previous results;
    # it evaluates N * ratio in JAX's default float dtype before rounding up.
    k = int(jnp.ceil(N * ratio))
    # Clamp k: with k == 0 the old `vec.at[-k:]` was `vec.at[0:]` and marked
    # *every* entry when first=False, and a negative k wrapped around.
    k = min(max(k, 0), N)
    idx = jnp.arange(N)
    # Index comparison yields a bool mask with no slice-sign edge cases.
    return idx < k if first else idx >= N - k

def get_conditional_mask(theta_o, x_o, first=True, ratio=0.3):
    """Build a boolean mask over the concatenation of ``theta_o`` and ``x_o``.

    The first ``theta_o.shape[0]`` entries are always ``False``. The remaining
    ``x_o.shape[0]`` entries come from ``make_boolean_vector``, which marks a
    contiguous block (leading if ``first`` else trailing) sized by ``ratio``.
    NOTE(review): presumably ``True`` marks the entries to condition on;
    confirm against callers.

    Args:
        theta_o: Array whose leading dimension gives the length of the
            all-``False`` prefix.
        x_o: Array whose leading dimension gives the length of the masked
            suffix.
        first: Forwarded to ``make_boolean_vector``.
        ratio: Forwarded to ``make_boolean_vector``.

    Returns:
        A 1-D boolean array of length ``theta_o.shape[0] + x_o.shape[0]``.
    """
    n_theta = theta_o.shape[0]
    n_x = x_o.shape[0]
    return jnp.concatenate(
        (
            jnp.zeros((n_theta,), dtype=bool),
            make_boolean_vector(n_x, first=first, ratio=ratio),
        ),
        axis=0,
    )
