from functools import partial, reduce

import jax
import pyt
from jax import jacrev, jit
from jax import numpy as np
from jax import random, vmap


def _init_layer(in_dim, hidden_dim, key, W_scale, B_scale, W_init, B_init):
    w_key, b_key = random.split(key)
    w = W_scale * W_init(w_key, (hidden_dim, in_dim))
    b = B_scale * B_init(b_key, (hidden_dim,))
    return w, b


def MLP_init(sizes, key, W_scale, B_scale, W_init=random.normal, B_init=random.normal):
    """Initialise every layer of an MLP whose widths are given by ``sizes``.

    Consecutive entries of ``sizes`` form the (input, output) dimensions of
    one layer, so ``len(sizes) - 1`` layers are created.

    Args:
        sizes: Sequence of layer widths, input width first.
        key: PRNG key from which the per-layer keys are derived.
        W_scale: Multiplier applied to every layer's sampled weights.
        B_scale: Multiplier applied to every layer's sampled biases.
        W_init: Callable ``(key, shape) -> array`` used for weights.
        B_init: Callable ``(key, shape) -> array`` used for biases.

    Returns:
        A ``pyt.Params`` built from the list of per-layer ``(w, b)`` tuples.
    """
    # One key per entry of ``sizes`` is drawn; ``zip`` stops at the shortest
    # input, so the final key is never consumed.
    layer_keys = random.split(key, num=len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes, sizes[1:], layer_keys):
        layers.append(
            _init_layer(fan_in, fan_out, layer_key, W_scale, B_scale, W_init, B_init)
        )
    return pyt.Params(layers)


def _predict(params, x):
    x = x.reshape(-1)
    for w, b in params.params[:-1]:
        x = jax.nn.relu(np.dot(w, x) + b)
    w, b = params.params[-1]
    x = np.dot(w, x) + b
    return x


# Batched, compiled forward pass. ``_predict`` takes exactly two positional
# arguments, so ``in_axes`` carries two entries: ``params`` is shared by the
# whole batch (``None``) and ``x`` is mapped over its leading axis (``0``).
# There is no third argument, hence nothing to mark as static for ``jit``.
MLP_predict = jit(vmap(_predict, in_axes=(None, 0), out_axes=0))

# Reverse-mode Jacobian of the single-example prediction. ``jacrev``
# differentiates with respect to the first positional argument by default,
# which here is ``params``.
jac_predict = jacrev(_predict)


@partial(jit, static_argnums=(3,))
def single_NTK(params, x, y, jac):
    """Combine the Jacobians at two inputs into one kernel value.

    Args:
        params: Model parameters forwarded to ``jac``.
        x: First input.
        y: Second input.
        jac: Callable ``(params, input) -> jacobian``. It is a static argument
            of the jitted function, so it must be hashable.

    Returns:
        ``pyt.vdot`` of the Jacobian at ``x`` with the Jacobian at ``y``
        (presumably an inner product over the parameter structure; confirm
        against ``pyt``).
    """
    jac_at_x = jac(params, x)
    jac_at_y = jac(params, y)
    return pyt.vdot(jac_at_x, jac_at_y)


# Map ``single_NTK`` over the leading axis of ``x`` only; ``params``, ``y`` and
# the static ``jac`` callable are shared by every mapped call.
batched_x_NTK = jit(
    vmap(single_NTK, in_axes=(None, 0, None, None)),
    static_argnums=(3,),
)

# Additionally map over the leading axis of ``y``. The outer map contributes
# the first output axis, so entry ``[i, j]`` pairs ``y[i]`` with ``x[j]``.
batched_NTK = jit(
    vmap(batched_x_NTK, in_axes=(None, None, 0, None)),
    static_argnums=(3,),
)


@partial(jit, static_argnums=(2,))
def NTK_fn(params, x, jac):
    """Evaluate ``batched_NTK`` for a batch of inputs against itself.

    Args:
        params: Model parameters forwarded to ``jac``.
        x: Batch of inputs; it is passed as both the ``x`` and ``y`` argument
            of ``batched_NTK``.
        jac: Jacobian callable. It is a static argument of the jitted
            function, so it must be hashable.

    Returns:
        Whatever ``batched_NTK`` returns for the pair ``(x, x)``.
    """
    gram = batched_NTK(params, x, x, jac)
    return gram


@jit
def NTK(params, x):
    """Evaluate ``NTK_fn`` on ``x`` using ``jac_predict`` as the Jacobian.

    Args:
        params: Model parameters.
        x: Batch of inputs.

    Returns:
        The result of ``NTK_fn(params, x, jac_predict)``.
    """
    kernel = NTK_fn(params, x, jac_predict)
    return kernel
