from neural_tangents import stax

def get_nn_ntk(network_size, sigmaweights, sigmabiases):
  """Assemble a fully-connected ReLU network from neural_tangents ``stax`` layers.

  Args:
    network_size: Sequence of layer widths. Entries ``[1:-1]`` are used as
      hidden-layer widths and the last entry as the output width; the first
      entry is not read by this function.
    sigmaweights: Forwarded as ``W_std`` to every hidden ``stax.Dense`` layer.
    sigmabiases: Forwarded as ``b_std`` to every hidden ``stax.Dense`` layer.

  Returns:
    The ``(init_fn, apply_fn)`` pair produced by ``stax.serial``. The third
    value returned by ``stax.serial`` is discarded.
  """
  def hidden_block(width):
    # One hidden stage: an NTK-parameterized dense layer followed by a ReLU.
    dense = stax.Dense(
      width, W_std=sigmaweights, b_std=sigmabiases, parameterization="ntk"
    )
    return dense, stax.Relu()

  # Flatten the (dense, relu) pairs into one ordered layer list.
  hidden = [
    layer
    for width in network_size[1:-1]
    for layer in hidden_block(width)
  ]

  # NOTE(review): the output layer is built with stax.Dense defaults; it does
  # not receive sigmaweights / sigmabiases. Confirm this is intentional.
  readout = stax.Dense(network_size[-1])

  init_fn, apply_fn, _unused = stax.serial(*hidden, readout)
  return init_fn, apply_fn