import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import numpy as np
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax import vmap
import jax
from jax.nn.initializers import glorot_normal, normal, orthogonal


def lstm_forward(params, inputs, c_state, h_state):
    """Advance an LSTM cell by one step.

    Args:
        params: dict with ``"input_kernel"``, ``"recurrent_kernel"``, ``"bias"``
            and a scalar ``"forget_gate_bias"`` (as built by
            ``init_lstm_params``). Both kernels and the bias have a last axis of
            width ``4 * cell_size``.
        inputs: current input; the last axis is contracted with the input kernel.
        c_state: previous cell state.
        h_state: previous hidden state; contracted with the recurrent kernel.

    Returns:
        Tuple ``(new_cell, output_state)``.
    """
    kernel_in = params["input_kernel"]
    kernel_rec = params["recurrent_kernel"]
    pre_act = jnp.dot(inputs, kernel_in) + jnp.dot(h_state, kernel_rec) + params["bias"]

    # The fused pre-activation holds four equal-width chunks, in this order:
    # candidate, input gate, forget gate, output gate.
    cand_pre, in_pre, forget_pre, out_pre = jnp.split(pre_act, 4, axis=-1)

    candidate = jnp.tanh(cand_pre)
    gate_in = jax.nn.sigmoid(in_pre)
    # Extra scalar shift on the forget gate only.
    gate_forget = jax.nn.sigmoid(forget_pre + params["forget_gate_bias"])
    gate_out = jax.nn.sigmoid(out_pre)

    cell_next = c_state * gate_forget + candidate * gate_in
    return cell_next, jnp.tanh(cell_next) * gate_out


def init_lstm_params(rng, in_size, cell_size):
    """Create the parameter dict consumed by ``lstm_forward``.

    Args:
        rng: JAX PRNG key; it is split into a carry key and three subkeys.
        in_size: width of each input vector.
        cell_size: width of the cell/hidden state.

    Returns:
        Tuple ``(rng, params)``: a fresh carry key and a dict holding a
        Glorot-normal input kernel ``(in_size, 4*cell_size)``, a normal-initialised
        recurrent kernel ``(cell_size, 4*cell_size)``, a zero bias
        ``(4*cell_size,)`` and a constant forget-gate bias of 5.
    """
    rng, k_input, k_recurrent, k_bias = jax.random.split(rng, 4)
    gates_width = 4 * cell_size  # candidate + input/forget/output gates

    params = dict(
        input_kernel=glorot_normal()(k_input, (in_size, gates_width)),
        recurrent_kernel=normal()(k_recurrent, (cell_size, gates_width)),
        bias=jax.nn.initializers.zeros(k_bias, (gates_width,)),
        forget_gate_bias=5,
    )
    return rng, params


def init_ode_rnn_params(rng, in_size, state_size, hidden_layers):
    """Initialise the MLP that defines the ODE-RNN vector field.

    The network maps the concatenation ``[state, input]`` (width
    ``state_size + in_size``) through the widths in ``hidden_layers`` to an
    output of width ``state_size``. Weights are Glorot-normal, biases are zero.

    Args:
        rng: JAX PRNG key.
        in_size: width of the per-step input.
        state_size: width of the ODE state; also the output width.
        hidden_layers: iterable of hidden-layer widths. It is not modified.

    Returns:
        Tuple ``(rng, params)``: a fresh carry key and a list of
        ``(weight, bias)`` tuples, one per dense layer, where weight has shape
        ``(fan_in, fan_out)`` and bias has shape ``(fan_out,)``.
    """
    # A new list is built, so the caller's ``hidden_layers`` is untouched.
    # The first layer takes state + input; the last layer emits the state size.
    layers = [state_size + in_size, *hidden_layers, state_size]

    # Each split yields len(layers) keys: one carry key followed by one key per
    # weight matrix (first split) or per bias vector (second split).
    rng, *w_rng = jax.random.split(rng, num=len(layers))
    rng, *b_rng = jax.random.split(rng, num=len(layers))

    params = [
        (
            glorot_normal()(w_key, (fan_in, fan_out)),  # weights
            jax.nn.initializers.zeros(b_key, (fan_out,)),  # bias
        )
        for w_key, b_key, fan_in, fan_out in zip(
            w_rng, b_rng, layers[:-1], layers[1:]
        )
    ]
    return rng, params


def nn_forward(x, t, params, input):
    """Evaluate the ODE vector field: a dense MLP applied to ``[x, input]``.

    The signature follows ``odeint``'s ``func(y, t, *args)`` convention. ``t``
    is accepted but not used by the body.

    Args:
        x: current ODE state.
        t: integration time (unused).
        params: sequence of ``(weight, bias)`` pairs, one per dense layer.
        input: per-step input, concatenated after the state on the last axis.

    Returns:
        The output of the final dense layer. Every earlier layer is followed by
        tanh; the final layer is linear. With no layers, the concatenation
        itself is returned.
    """
    h = jnp.concatenate([x, input], axis=-1)
    # Hidden layers: affine map followed by tanh.
    for w, b in params[:-1]:
        h = jnp.tanh(jnp.dot(h, w) + b)
    # Output layer (present only when params is non-empty): affine, no activation.
    for w, b in params[-1:]:
        h = jnp.dot(h, w) + b
    return h


# Shared problem dimensions for both models.
batch_size = 128  # sequences per batch (leading axis of inputs and states)
in_size = 8  # input features per time step
state_size = 64  # width of the LSTM cell/hidden state and of the ODE state
hidden_layers = [64, 64]  # hidden widths of the ODE vector-field MLP


def ode_rnn_fn(params, inputs, state):
    """Run the ODE-RNN over a sequence and return its final state.

    At every step the state is integrated with ``odeint`` under the vector
    field ``nn_forward`` (conditioned on the current input) from t=0 to t=1.
    The value at t=1 becomes the next state.

    Args:
        params: dict whose ``"ode_rnn"`` entry holds the MLP parameters.
        inputs: input sequence, scanned over its leading axis.
        state: initial state.

    Returns:
        The state after the last input.
    """
    ode_params = params["ode_rnn"]

    def step(h, x_t):
        time_span = jnp.array([0.0, 1.0])
        trajectory = odeint(nn_forward, h, time_span, ode_params, x_t)
        h_next = trajectory[-1]  # row 0 is t=0, row 1 (last) is t=1
        return h_next, h_next

    final_state, _ = jax.lax.scan(step, state, inputs)
    return final_state


def mm_rnn_fn(params, inputs, c_state, h_state):
    """Run the mixed-memory RNN over a sequence and return the final cell state.

    Each step applies the LSTM cell (``params["lstm"]``) to update ``(c, h)``.
    The resulting hidden state is then integrated with ``odeint`` under
    ``nn_forward`` (``params["ode_rnn"]``, conditioned on the same input) from
    t=0 to t=1. The cell state is not passed through the ODE.

    Args:
        params: dict with ``"lstm"`` and ``"ode_rnn"`` parameter entries.
        inputs: input sequence, scanned over its leading axis.
        c_state: initial LSTM cell state.
        h_state: initial hidden state.

    Returns:
        The cell state after the last input; the final hidden state is discarded.
    """
    lstm_params = params["lstm"]
    ode_params = params["ode_rnn"]

    def step(carry, x_t):
        c_prev, h_prev = carry
        time_span = jnp.array([0.0, 1.0])
        c_next, h_lstm = lstm_forward(lstm_params, x_t, c_prev, h_prev)
        # Keep only the t=1 row of the ODE trajectory.
        h_next = odeint(nn_forward, h_lstm, time_span, ode_params, x_t)[-1]
        return (c_next, h_next), h_next

    (c_final, _), _ = jax.lax.scan(step, (c_state, h_state), inputs)
    return c_final


# Per-sequence Jacobians of each model's returned state with respect to its
# third argument (the initial state for the ODE-RNN, the initial cell state for
# the mmRNN). They are vectorised over the leading axis of every argument except
# the shared params, then JIT-compiled.
ode_jac_fn = jax.jacrev(ode_rnn_fn, argnums=2)
batched_ode_jac = jax.jit(vmap(ode_jac_fn, in_axes=(None, 0, 0)))

mm_jac_fn = jax.jacrev(mm_rnn_fn, argnums=2)
batched_mm_jac = jax.jit(vmap(mm_jac_fn, in_axes=(None, 0, 0, 0)))


def get_mean_grad_norm(seq_len, num_seeds=3):
    """Collect per-unit Jacobian row norms for both models at one sequence length.

    Despite the name, no averaging happens here. Every individual norm is
    returned so the caller can compute mean and spread.

    For each seed in ``range(num_seeds)`` the function does the following:
    - draws fresh model parameters;
    - draws a batch of uniform(-1, 1) inputs of shape
      ``(batch_size, seq_len, in_size)``;
    - evaluates, per sequence, the Jacobian of each model's final memory state
      (ODE-RNN state / mmRNN cell state) with respect to its all-zero initial
      value.

    Args:
        seq_len: number of time steps per input sequence.
        num_seeds: how many PRNG seeds (0, 1, ..., num_seeds - 1) to aggregate
            over. Defaults to 3.

    Returns:
        Tuple ``(ode_norms, mm_norms)`` of flat NumPy arrays. Each holds
        ``num_seeds * batch_size * state_size`` L2 norms, one per
        (seed, sequence, output unit).

    Raises:
        ValueError: if ``num_seeds`` is less than 1.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be >= 1, got {num_seeds}")

    ode_norms = []
    mm_norms = []
    for rng_key in range(num_seeds):  # aggregate over multiple initial seeds
        rng = jax.random.PRNGKey(rng_key)
        params = {}
        rng, params["ode_rnn"] = init_ode_rnn_params(
            rng, in_size, state_size, hidden_layers
        )
        rng, params["lstm"] = init_lstm_params(rng, in_size, state_size)
        x_train = jax.random.uniform(
            rng, shape=(batch_size, seq_len, in_size), minval=-1, maxval=1
        )
        # Both models start from an all-zero memory state (and, for the
        # mmRNN, an all-zero hidden state).
        init_state = jnp.zeros((batch_size, state_size))
        init_state_h = jnp.zeros((batch_size, state_size))

        ode_jac = batched_ode_jac(params, x_train, init_state)
        mm_jac = batched_mm_jac(params, x_train, init_state, init_state_h)

        # Row i of each per-sequence Jacobian is the gradient of output unit i
        # with respect to the initial state; take its L2 norm.
        ode_norms.append(jnp.linalg.norm(ode_jac, axis=-1))
        mm_norms.append(jnp.linalg.norm(mm_jac, axis=-1))

    ode_norms = np.array(jnp.stack(ode_norms).flatten())
    mm_norms = np.array(jnp.stack(mm_norms).flatten())
    return ode_norms, mm_norms


seq_lens = [1, 5, 10, 25, 50]

# One (ode_norms, mm_norms) pair per sequence length, regrouped into two
# tuples whose order matches ``seq_lens``.
ode_norms, mm_norms = zip(*[get_mean_grad_norm(s) for s in seq_lens])

# get_mean_grad_norm already returns NumPy arrays, so summarise them with
# NumPy instead of copying them back into JAX.
for seq_len, ode_n, mm_n in zip(seq_lens, ode_norms, mm_norms):
    print(f"### Seq len = {seq_len} ###")
    print(f"   odeRNN = {np.mean(ode_n)} +- {np.std(ode_n)}")
    print(f"   mmRNN  = {np.mean(mm_n)} +- {np.std(mm_n)}")
