import chex
import optax
import jax.numpy as jnp
import jax


def uniform_sphere_random(key, shape, dtype=jnp.float32):
    """Draw a random tensor whose entries, taken together, have unit L2 norm.

    A standard-normal sample of the requested shape is divided by the L2 norm
    of all of its entries; normalising an isotropic Gaussian this way gives a
    direction uniformly distributed on the unit sphere.

    Args:
        key: JAX PRNG key consumed by the normal draw.
        shape: Shape of the returned array.
        dtype: Floating dtype of the sample (defaults to ``jnp.float32``).

    Returns:
        An array of the given shape and dtype with overall L2 norm 1, or all
        zeros if the drawn sample has norm exactly 0.
    """
    gaussian = jax.random.normal(key, shape, dtype=dtype)
    length = jnp.linalg.norm(jnp.ravel(gaussian), ord=2)
    direction = gaussian / length
    # A zero-norm draw cannot be normalised; emit zeros rather than NaN/inf.
    return jnp.where(length == 0, jnp.zeros_like(gaussian), direction)


def apply_perturbed_normalized_gd(
        learning_rate: float,
        seed: int
) -> optax.GradientTransformation:
    """Build an optax transformation for perturbed, normalized gradient descent.

    Every incoming update tree is rescaled to unit global L2 norm. On each
    step whose count is a multiple of ``max(int(learning_rate ** -0.1), 1)``,
    a random perturbation is additionally added to every leaf: a unit-norm
    random direction drawn independently per leaf, multiplied by one shared
    radius sampled uniformly from ``[0, max(learning_rate ** 100, 2 ** -124))``.

    ``learning_rate`` only determines the perturbation period and radius; the
    returned updates are not multiplied by it.

    Args:
        learning_rate: Step size used to derive the perturbation period and
            the maximum perturbation radius.
        seed: Integer seed for the PRNG key stored in the optimizer state.

    Returns:
        An ``optax.GradientTransformation`` whose state is an
        ``optax.AddNoiseState`` (step count and PRNG key).
    """
    lr = jnp.ones([], dtype=jnp.float32) * learning_rate
    # Clamp the period to >= 1: for learning_rate > 1 the truncated power is 0
    # and the modulo in update_fn would otherwise divide by zero.
    train_freq = jnp.maximum(jnp.power(lr, -0.1).astype(jnp.int32), 1)
    radius = jnp.power(lr, 100).clip(2 ** (-124))  # to prevent floating point error

    def init_fn(params):
        """Create the initial state: zero step count and a key from ``seed``."""
        del params
        return optax.AddNoiseState(
            count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed))

    def update_fn(updates, state, params=None):
        """Normalize ``updates`` and, on trigger steps, add random noise."""
        del params
        leaves, treedef = jax.tree_util.tree_flatten(updates)
        count_inc = optax.safe_int32_increment(state.count)

        # Never consume state.rng_key directly: derive one key for the next
        # state, one for the radius draw, and one per leaf for the directions.
        keys = jax.random.split(state.rng_key, num=len(leaves) + 2)
        next_key, radius_key, leaf_keys = keys[0], keys[1], keys[2:]

        r = jax.random.uniform(radius_key, shape=[], minval=0., maxval=radius)
        noise = jax.tree_util.tree_map(
            lambda g, k: uniform_sphere_random(k, shape=g.shape, dtype=g.dtype),
            updates, jax.tree_util.tree_unflatten(treedef, list(leaf_keys)))
        trigger = jnp.squeeze(jnp.mod(state.count, train_freq) == 0)
        chex.assert_shape(trigger, ())

        def add_noise_fn(t, n):
            # Cast the scaled noise back to the leaf dtype so both branches of
            # lax.select share a dtype (r is float32 and may promote r * n).
            perturbed = t + (r * n).astype(t.dtype)
            return jax.lax.select(trigger, perturbed, t)

        g_norm = optax.global_norm(updates)
        # With an all-zero gradient the norm is 0; divide by 1 instead so the
        # normalized update is 0 rather than NaN (NaN norms still propagate).
        safe_norm = jnp.where(g_norm > 0, g_norm, jnp.ones_like(g_norm))
        updates = jax.tree_util.tree_map(lambda g: g / safe_norm, updates)
        updates = jax.tree_util.tree_map(add_noise_fn, updates, noise)
        return updates, optax.AddNoiseState(count=count_inc, rng_key=next_key)

    return optax.GradientTransformation(init_fn, update_fn)
