import jax
import math


def dense_init(key, shape, dtype):
    """Draw a weight array uniformly from ``[-b, b)`` with ``b = 1/sqrt(shape[0])``.

    The call signature ``(key, shape, dtype)`` matches that of JAX initializer
    functions, so this can be passed wherever such an initializer is expected.

    Args:
        key: JAX PRNG key used for sampling.
        shape: Shape of the array to create. ``shape[0]`` sets the bound;
            presumably it is the fan-in, e.g. an ``(in_features,
            out_features)`` kernel. Confirm this for other layouts.
        dtype: Element dtype of the returned array.

    Returns:
        An array of the given ``shape`` and ``dtype`` with values in ``[-b, b)``.
    """
    fan_in = shape[0]
    bound = 1.0 / math.sqrt(fan_in)
    sampler = jax.nn.initializers.uniform(bound)
    # `uniform(bound)` samples from [0, bound). Doubling and then subtracting
    # `bound` re-centres the samples onto the symmetric interval [-bound, bound).
    return 2 * sampler(key, shape, dtype) - bound
