import math

import jax
import jax.numpy as jnp

def spsa_aml_iid(
    key,
    sim_fn,
    obs,
    theta0,
    iterations,
    n_sim,
    a,
    c,
    A = None,
    alpha = 1.0,
    gamma = 0.16666,
    bandwidth_ma = 0.9
):
    """Maximise a KDE-approximated log-likelihood of ``obs`` with SPSA.

    Each iteration perturbs ``theta`` along a random +/-1 direction and
    simulates ``n_sim`` draws at both perturbed parameters. It then scores
    every observation with ``kde_logdensity_modified`` (KDEs built from each
    simulation batch) and takes a gradient-ascent step along the SPSA
    gradient estimate.

    Args:
        key: JAX PRNG key; it is split every iteration.
        sim_fn: Simulator called as ``sim_fn(key, theta, n_sim)``. It is
            jitted with ``n_sim`` static, so its parameter must be named
            ``n_sim``. Its outputs are stacked with ``jnp.vstack`` and used as
            KDE centres. Presumably the output shape is ``(n_sim, d_obs)``
            -- TODO confirm.
        obs: Observations; the KDE is vmapped over the leading axis.
        theta0: Initial parameter vector (1-D array-like of length ``d``).
        iterations: Number of SPSA iterations.
        n_sim: Number of simulated draws per perturbed parameter.
        a: Scale of the step-size sequence ``a_k = a / (k + A) ** alpha``.
        c: Scale of the perturbation sequence ``c_k = c / k ** gamma``.
        A: Stability constant in the step-size denominator; defaults to
            ``int(0.1 * iterations)``.
        alpha: Decay exponent of the step size.
        gamma: Decay exponent of the perturbation size.
        bandwidth_ma: Weight on the previous bandwidth in the exponential
            moving average of KDE bandwidths; must lie in ``[0, 1]``.

    Returns:
        The final parameter vector as a float JAX array.

    Raises:
        ValueError: If ``bandwidth_ma`` lies outside ``[0, 1]``.
    """
    # Outside [0, 1] the "average" extrapolates and can yield negative bandwidths.
    if not 0.0 <= bandwidth_ma <= 1.0:
        raise ValueError(
            f"bandwidth_ma must lie in [0, 1], got {bandwidth_ma!r}"
        )

    # jnp.asarray also accepts lists/tuples; ``theta0.astype`` required an array.
    theta  = jnp.asarray(theta0, dtype=float)
    d      = theta.shape[0]
    A      = int(0.1 * iterations) if A is None else A
    h_prev = None

    # Build the jitted callables once, outside the loop. ``n_sim`` is static
    # because it fixes the simulator's output shape; jax.jit maps the
    # positional third argument onto ``static_argnames`` via sim_fn's signature.
    sim_fn_jit = jax.jit(sim_fn, static_argnames=("n_sim",))
    # Log-density of every observation (vmapped over obs's leading axis)
    # under the KDE built from one batch of simulations.
    kde_logdensity_batch = jax.jit(
        jax.vmap(
            jax.jit(kde_logdensity_modified),
            in_axes=(0, None, None),
        )
    )
    silverman_bandwidth_jit = jax.jit(silverman_bandwidth)
    rademacher_support = jnp.array([-1.0, 1.0])

    for k in range(1, iterations + 1):
        a_k = a / ((k + A) ** alpha)
        c_k = c / (k ** gamma)

        key, key_delta, key_p, key_m = jax.random.split(key, 4)
        delta   = jax.random.choice(key_delta, rademacher_support, shape=(d,))
        t_plus  = theta + c_k * delta
        t_minus = theta - c_k * delta

        # NOTE(review): the two sides use independent keys (no common random
        # numbers); sharing one key would usually reduce gradient variance.
        sims_plus  = sim_fn_jit(key_p, t_plus,  n_sim)
        sims_minus = sim_fn_jit(key_m, t_minus, n_sim)

        # One bandwidth from the pooled simulations, smoothed across
        # iterations, so both sides are scored with the same kernel width.
        h_est = silverman_bandwidth_jit(
            jnp.vstack([sims_plus, sims_minus])
        )
        h = h_est if h_prev is None else bandwidth_ma * h_prev + (1.0 - bandwidth_ma) * h_est
        h_prev = h

        ll_plus  = kde_logdensity_batch(obs, sims_plus,  h)
        ll_minus = kde_logdensity_batch(obs, sims_minus, h)

        # Summed per-observation log-density difference (iid log-likelihood).
        diff  = jnp.sum(ll_plus - ll_minus)
        # delta is +/-1, so 1 / delta == delta in the SPSA estimator.
        g_hat = delta * diff / (2.0 * c_k)

        # Gradient *ascent*: the objective is maximised.
        theta = theta + a_k * g_hat

    return theta

def silverman_bandwidth(samples):
    """Return Silverman's rule-of-thumb bandwidth for each coordinate.

    Computes ``h_j = (4 / (d + 2)) ** (1 / (d + 4)) * n ** (-1 / (d + 4)) * sigma_j``,
    where ``sigma_j`` is the sample standard deviation (``ddof=1``) of
    coordinate ``j``. A constant ``1e-8`` is added so that identical samples
    do not produce a zero bandwidth.

    Args:
        samples: Array of shape ``(n, d)``. A 1-D array of shape ``(n,)`` is
            treated as ``n`` scalar samples (``d = 1``).

    Returns:
        Bandwidths of shape ``(d,)`` for 2-D input, or a scalar for 1-D input.

    Raises:
        ValueError: If ``samples`` is not 1-D or 2-D, or has fewer than two
            rows (the ``ddof=1`` standard deviation would be NaN).
    """
    samples = jnp.asarray(samples)
    # Shapes are static under jax.jit, so these checks also run when traced.
    if samples.ndim not in (1, 2):
        raise ValueError(
            f"samples must be 1-D or 2-D, got shape {samples.shape}"
        )
    n = samples.shape[0]
    d = 1 if samples.ndim == 1 else samples.shape[1]
    if n < 2:
        raise ValueError(
            f"silverman_bandwidth needs at least 2 samples, got {n}"
        )

    sigma = jnp.std(samples, axis=0, ddof=1)
    factor = (4.0 / (d + 2.0)) ** (1.0 / (d + 4.0)) * n ** (-1.0 / (d + 4.0))
    return factor * sigma + 1e-8

    
def kde_logdensity_modified(
    x,
    samples,
    h,
    tiny = 1e-300,
):
    """Return the log of a per-coordinate-bandwidth KDE at ``x`` (modified kernel).

    With the bandwidth-scaled squared distance
    ``d2 = sum((samples - x) ** 2 / h ** 2, axis=1)``, the kernel is
    ``exp(-0.5 * d2)`` for ``d2 < 1`` and ``exp(-0.5 * sqrt(d2))`` otherwise.
    The two branches agree at ``d2 == 1``; the second has heavier tails. The
    density is the mean kernel value divided by ``prod(h)``. No
    ``(2 * pi) ** (-d / 2)`` normalising constant is included.

    The mean is taken in log space with ``logsumexp``, so the result stays
    finite when every individual kernel would underflow. (A linear-space mean
    gives ``log(0) = -inf`` there; under JAX's default float32 even
    ``1e-300`` itself rounds to 0.)

    Args:
        x: Query point, broadcast against each row of ``samples``
            (typically shape ``(d,)``).
        samples: Kernel centres, shape ``(n, d)`` with ``n >= 1``.
        h: Per-coordinate bandwidths, broadcast like ``x``; must be positive
            (``log(h)`` is taken).
        tiny: Floor on the mean kernel value, so the result is at least
            ``log(tiny) - sum(log(h))``. A Python number is converted to log
            space in double precision, so the default ``1e-300`` stays
            effective in float32.

    Returns:
        Scalar log-density.

    Raises:
        ValueError: If ``samples`` has no rows.
    """
    n = samples.shape[0]  # static under jit / vmap
    if n == 0:
        raise ValueError("samples must contain at least one row")

    inv_h_sq = 1.0 / (h * h)

    diff = samples - x
    d2   = jnp.sum((diff ** 2) * inv_h_sq, axis=1)
    log_kernels = -0.5 * jnp.where(d2 < 1.0, d2, jnp.sqrt(d2))

    # log(mean(exp(log_kernels))) without underflow.
    log_kmean = jax.nn.logsumexp(log_kernels) - math.log(n)

    if isinstance(tiny, (int, float)):
        # Take the log in Python double precision: converting 1e-300 to
        # float32 first would give 0 and a useless floor of -inf.
        log_floor = math.log(tiny) if tiny > 0 else -math.inf
    else:
        log_floor = jnp.log(tiny)
    log_kmean = jnp.maximum(log_kmean, log_floor)

    return log_kmean - jnp.sum(jnp.log(h))