import jax
import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit
def gan_simulator(key, 
                  mean_params, 
                  n_sim, 
                  generator, 
                  state,
                  latent_dim,
                  sigma=1e-1):
    """Generate ``n_sim`` flattened samples from a generator fed perturbed latents.

    Latent vectors are drawn as ``N(0, 1) * sigma + mean_params`` and pushed
    through ``generator`` with ``jax.vmap`` (one latent per mapped call).

    Args:
        key: JAX PRNG key used for the latent noise.
        mean_params: Latent mean; broadcast against shape ``(n_sim, latent_dim)``
            (e.g. a scalar or a length-``latent_dim`` vector).
        n_sim: Number of samples to generate.
        generator: Callable invoked as ``generator(z, state)`` and returning a
            pair whose first element is the generated sample (the second is
            discarded). Presumably an Equinox module — confirm against callers.
        state: Extra generator input passed unbatched to every mapped call.
        latent_dim: Size of each latent vector.
        sigma: Scale of the latent noise around ``mean_params``.

    Returns:
        Array of shape ``(n_sim, features)``, where ``features`` is the number
        of elements in one generated sample (256 for a 16x16 output).
    """
    import math

    z_batch = jax.random.normal(key, (n_sim, latent_dim)) * sigma + mean_params
    # Each latent becomes a (latent_dim, 1, 1) tensor — presumably the
    # channels-first 1x1 input a transposed-conv generator expects; confirm.
    z_batch = z_batch.reshape(n_sim, latent_dim, 1, 1)
    # axis_name="batch" names the mapped axis so collectives inside the
    # generator (presumably batch-norm style layers — confirm) can reduce over
    # it; state is broadcast in (None) and its output is left unbatched (None).
    vmapped_gen = jax.vmap(generator,
                           axis_name="batch",
                           in_axes=(0, None),
                           out_axes=(0, None))
    sims, _ = vmapped_gen(z_batch, state)
    # Flatten each sample to its actual size instead of a hard-coded 16*16.
    # The size is computed explicitly rather than with -1, so that n_sim == 0
    # still reshapes cleanly.
    features = math.prod(sims.shape[1:])
    return sims.reshape(n_sim, features)

def mvt_norm_simulator(key, mean_params, n_sim, cov=None):
    """Draw ``n_sim`` samples from a multivariate normal centred on ``mean_params``.

    Args:
        key: JAX PRNG key. It is split once and the second half is used for
            sampling.
        mean_params: Mean vector; when ``cov`` is omitted its leading dimension
            sets the size of the default identity covariance.
        n_sim: Number of samples to draw.
        cov: Covariance matrix. Defaults to the identity matrix.

    Returns:
        Array of shape ``(n_sim, d)`` for a 1-D mean of length ``d``.
    """
    covariance = cov if cov is not None else jnp.identity(mean_params.shape[0])
    _, sample_key = jax.random.split(key)
    return jax.random.multivariate_normal(
        sample_key,
        mean=mean_params,
        cov=covariance,
        shape=(n_sim,),
    )


def univ_norm_simulator(key, mean_params, n_sim, cov=None):
    """Draw ``n_sim`` univariate normal samples shifted by ``mean_params``.

    Args:
        key: JAX PRNG key. It is split once and the second half is used for
            sampling.
        mean_params: Location; broadcast against a length-``n_sim`` vector.
        n_sim: Number of samples to draw.
        cov: Variance (not standard deviation) of the noise; defaults to 1.0.

    Returns:
        The samples as a single column via ``reshape(-1, 1)``. This is
        ``(n_sim, 1)`` when ``mean_params`` is a scalar or a length-1 or
        length-``n_sim`` vector.
    """
    variance = 1.0 if cov is None else cov
    _, sample_key = jax.random.split(key)
    std_normal = jax.random.normal(sample_key, shape=(n_sim,))
    shifted = std_normal * jnp.sqrt(variance) + mean_params
    return shifted.reshape(-1, 1)

def shifted_exp_simulator(key, theta, n_sim, rate=1.0):
    """Draw ``n_sim`` samples from an exponential distribution shifted by ``theta``.

    Each sample is ``theta + E / rate`` with ``E ~ Exponential(1)``, so the
    support starts at ``theta`` and the mean is ``theta + 1 / rate``.

    Args:
        key: JAX PRNG key.
        theta: Shift (location); broadcast against a length-``n_sim`` vector.
        n_sim: Number of samples to draw.
        rate: Rate (inverse scale) of the exponential part.

    Returns:
        Array of shape ``(n_sim, 1)``.
    """
    # The previous inverse-CDF form -log(U), with U ~ uniform[0, 1), returned
    # +inf whenever U was exactly 0. jax.random.exponential gives finite
    # Exp(1) draws, which are then rescaled by the rate.
    y = jax.random.exponential(key, shape=(n_sim,)) / rate
    return (theta + y).reshape(n_sim, 1)