from jax import numpy as jnp
from flax import linen as nn
from flax.linen import initializers as nni

def muP(
        d: int,
        width: int,
        depth: int
):
    """Build a ReLU MLP with a scalar output in a muP-style parametrization.

    The network has ``depth - 1`` hidden ``nn.Dense`` layers of ``width``
    units (named ``linear_0`` ... ``linear_{depth-2}``), each followed by
    ReLU. It ends with a bias-free ``nn.Dense(1)`` readout named ``head``.

    Every kernel, including the first hidden layer and the head, is drawn
    from a normal distribution with std ``sqrt(2 / width)``. The
    width/input-dependent scaling is carried by explicit multipliers instead:

    * The first hidden layer's pre-activation is multiplied by
      ``sqrt(width / d)``. Its effective weight std is therefore
      ``sqrt(2 / width) * sqrt(width / d) = sqrt(2 / d)``.
    * The head output is multiplied by ``1 / sqrt(width)``.

    Args:
        d: Input dimension. It is only used for the first-layer multiplier.
            The ``nn.Dense`` layers infer their fan-in from the input itself,
            so callers are expected to feed inputs whose last axis has size
            ``d``.
        width: Number of units in each hidden layer. It is also the width
            used for the init std and the readout multiplier.
        depth: Total number of ``nn.Dense`` layers, counting the head. With
            ``depth == 1`` there are no hidden layers: the head acts directly
            on the input and the first-layer multiplier is never applied.

    Returns:
        An unbound Flax module instance. Initialize it with
        ``model.init(key, x)`` and apply it with ``model.apply(params, x)``.

    Raises:
        ValueError: If ``d`` or ``width`` is not positive, or ``depth < 1``.
    """
    # Validate before computing any scale factors:
    # - zero d/width would otherwise surface as a ZeroDivisionError;
    # - negative values would yield NaN init stds;
    # - depth < 1 would silently build the same model as depth == 1.
    if d <= 0:
        raise ValueError(f"d must be a positive integer, got {d!r}")
    if width <= 0:
        raise ValueError(f"width must be a positive integer, got {width!r}")
    if depth < 1:
        raise ValueError(f"depth must be at least 1, got {depth!r}")

    # Init std shared by every kernel (hidden layers and head alike).
    std = jnp.sqrt(2 / width)
    # muP multipliers. They are constant for a given (d, width), so they are
    # computed once here rather than on every forward pass.
    in_mult = jnp.sqrt(width / d)
    out_mult = jnp.sqrt(1 / width)
    sigma = nn.relu

    class muP(nn.Module):
        @nn.compact
        def __call__(self, x):
            for i in range(depth - 1):
                x = nn.Dense(width, kernel_init=nni.normal(std), use_bias=True, name=f'linear_{i}')(x)
                # Only the input layer gets the sqrt(width/d) multiplier.
                # It is applied after the bias add, so it scales the bias too.
                mult = in_mult if i == 0 else 1.
                x = sigma(mult * x)
            x = nn.Dense(1, kernel_init=nni.normal(std), use_bias=False, name="head")(x)
            # 1/sqrt(width) readout multiplier.
            return out_mult * x

    return muP()