import jax
import jax.numpy as np
import jax.random as random
from jax.scipy.linalg import expm

from nn.ntwide import SetupType
from .approxwide import get_nn_ntk

# Switch JAX to 64-bit floats for the whole process (presumably because the
# closed-form solutions below rely on matrix inverses and matrix exponentials).
jax.config.update("jax_enable_x64", True)

# Network callables; both stay None until wide_init() rebinds them, and
# wide_compute_base() reads them.
qnetwork_predict = qnetwork_grad = None

def wide_init(qparams_key, network_size, sigmaweights, sigmabiases, size_state):
  """Build the wide network, set up the module-level callables, return init params.

  Side effects: rebinds the module globals ``qnetwork_predict`` (the apply
  function returned by ``get_nn_ntk``) and ``qnetwork_grad`` (a batched
  per-sample gradient of the selected Q-value w.r.t. the parameters).

  Args:
    qparams_key: integer seed passed to ``random.PRNGKey`` for initialisation.
    network_size, sigmaweights, sigmabiases: forwarded unchanged to
      ``get_nn_ntk``.
    size_state: length of one state vector; the network is initialised for
      input shape ``(size_state,)``.

  Returns:
    The parameters produced by the network's init function.
  """
  global qnetwork_predict, qnetwork_grad

  init_fn, qnetwork_predict = get_nn_ntk(network_size, sigmaweights, sigmabiases)
  rng = random.PRNGKey(qparams_key)
  _, qparams = init_fn(rng, (size_state,))

  def selected_q(params, state, action):
    # Prediction for one state, indexed by `action`; the trailing [0] reduces
    # it to a scalar for jax.grad. Assumes a per-sample action is a length-1
    # index array (see take_along_axis in wide_compute_base) -- TODO confirm.
    return qnetwork_predict(params, state)[action][0]

  # Gradient w.r.t. the parameters (argnum 0); vmapped over states and
  # actions (axis 0) while the parameters are shared (None).
  qnetwork_grad = jax.vmap(jax.grad(selected_q, 0), (None, 0, 0))

  return qparams

def wide_mul_T(A, B):
  """Return the sum over leaves of ``A_leaf @ B_leaf.T`` for two parameter pytrees.

  Every leaf is flattened to 2-D while keeping its leading axis
  (``(n, ...) -> (n, -1)``). For pytrees of per-sample gradients the result
  is the Gram matrix of the flattened gradients (``Theta`` in ``wide_apply``).

  Args:
    A: pytree whose leaves share the leading dimension ``n_A``.
    B: pytree with the same structure as ``A``. Its leaves share the leading
      dimension ``n_B`` and have the same per-row size as the matching ``A``
      leaf.

  Returns:
    Array of shape ``(n_A, n_B)``.
  """
  def leaf_product(u, v):
    # Flatten everything after the leading (sample) axis, then u @ v.T.
    return u.reshape(u.shape[0], -1) @ v.reshape(v.shape[0], -1).T

  # Use jax.tree_util.tree_map, which works on every JAX version. The
  # jax.tree_map alias is deprecated and not available in newer releases.
  products = jax.tree_util.tree_map(leaf_product, A, B)
  return np.sum(np.array(jax.tree_util.tree_leaves(products)), axis=0)

def wide_mul_omega(delta, omega):
  """Map per-sample coefficients ``omega`` back to parameter space via ``delta``.

  Each leaf ``u`` of ``delta`` (leading axis ``n``) is flattened to ``(n, m)``
  and replaced by ``u_flat.T @ omega``. For ``omega`` of shape ``(n, k)`` the
  resulting leaf has shape ``(m, k)``.

  NOTE(review): result leaves stay flattened as ``(m, k)`` and are not
  reshaped to ``u.shape[1:]``. Confirm that callers expect this.

  Args:
    delta: pytree of per-sample quantities, each leaf with leading axis ``n``.
    omega: array whose first axis has length ``n``.

  Returns:
    Pytree with the structure of ``delta``.
  """
  def project(u):
    return u.reshape(u.shape[0], -1).T @ omega

  # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
  return jax.tree_util.tree_map(project, delta)

def wide_compute_base(qparams, s, a):
  """Evaluate the network, and its per-sample gradients, at (state, action) pairs.

  Requires ``wide_init`` to have run first, since it binds the module-level
  ``qnetwork_predict`` and ``qnetwork_grad`` used here.

  Args:
    qparams: network parameters.
    s: batch of states.
    a: integer index array used with ``take_along_axis`` on axis 1 of the
      network output (so it must have the same rank as that output).

  Returns:
    ``(q, grad_q)``: the gathered predictions and the pytree of per-sample
    gradients returned by ``qnetwork_grad``.
  """
  all_q = qnetwork_predict(qparams, s)
  chosen_q = np.take_along_axis(all_q, a, axis=1)
  per_sample_grads = qnetwork_grad(qparams, s, a)
  return chosen_q, per_sample_grads

def wide_apply(qparams, s, a, r, snew, done, anew, sstar, astar, gamma, t, lr, method=SetupType.SINGLE_Q, tau=0, return_omega=False, verbose=0):
  """Predict Q-values at query points after training a linearised network for time ``t``.

  The network is linearised around ``qparams``. The TD training dynamics of
  the linear model are solved in closed form with matrix exponentials, a
  continuous-time solution in ``lr * t``. The prediction at the query points
  is then ``q(sstar, astar) + grad_q(sstar, astar) . omega``, where
  ``omega`` is the parameter displacement expressed through the per-sample
  gradients of the training batch.

  Args:
    qparams: parameter pytree the network is linearised around.
    s, a: training states and action indices (``a`` gathers along axis 1 of
      the network output; see ``wide_compute_base``).
    r: rewards for the ``(s, a)`` transitions; enters as
      ``qz - r - gamma * qznew``.
    snew, anew: successor states and action indices.
    done: terminal flags. Where ``done == 1.``, the successor value and its
      gradient are zeroed. Presumably shaped ``(batch, 1)`` -- TODO confirm.
    sstar, astar: query states and action indices.
    gamma: discount factor.
    t: training time. ``t == 0`` returns the untrained prediction;
      ``t == np.inf`` replaces the matrix exponential by zero.
    lr: learning rate; scales the time evolution.
    method: ``SetupType.SINGLE_Q`` or ``SetupType.WITH_TARGET_Q``.
    tau: rate of the second (target) block in ``WITH_TARGET_Q``. It is used
      as a divisor there, so it must be non-zero. Unused by ``SINGLE_Q``.
    return_omega: if True, also return the parameter-displacement pytree.
    verbose: if >= 1, print intermediate matrices (``SINGLE_Q`` only).

  Returns:
    ``qlin``, or ``(qlin, omega)`` when ``return_omega`` is True.
    NOTE(review): at ``t == 0`` the omega leaves have the shapes of
    ``qparams``; otherwise they are the flattened ``(m, k)`` matrices built by
    ``wide_mul_omega``.

  Raises:
    ValueError: if ``t != 0`` and ``method`` is not a supported ``SetupType``.
  """
  # The query-point prediction is needed on every path, including t == 0.
  qstar, gradqstar = wide_compute_base(qparams, sstar, astar)

  if t == 0:
    # No training time has elapsed: the linearised model equals the base
    # network and the parameter displacement is zero. Returning here skips
    # the forward/gradient passes over the training batch.
    if return_omega:
      return qstar, jax.tree_util.tree_map(np.zeros_like, qparams)
    return qstar

  qz, gradqz = wide_compute_base(qparams, s, a)
  qznew, gradqznew = wide_compute_base(qparams, snew, anew)

  # Terminal transitions do not bootstrap: zero the successor value and gradient.
  is_terminal = done == 1.
  qznew = np.where(is_terminal, np.zeros_like(qznew), qznew)

  def mask_terminal(x):
    # Rank-2 leaves broadcast against `done` as-is. Other leaves get an extra
    # axis 2, so e.g. a (batch, 1) mask becomes (batch, 1, 1).
    mask = is_terminal if x.ndim == 2 else np.expand_dims(is_terminal, axis=2)
    # zeros_like keeps the leaf's dtype (np.zeros would default to float64).
    return np.where(mask, np.zeros_like(x), x)

  gradqznew = jax.tree_util.tree_map(mask_terminal, gradqznew)

  # TD error of the initial network on the training batch.
  error = qz - r - gamma * qznew

  if method == SetupType.SINGLE_Q:
    # Parameter gradient of the TD error (bootstrapping from the same network).
    delta = jax.tree_util.tree_map(lambda g, g_next: g - gamma * g_next, gradqz, gradqznew)
    Theta = wide_mul_T(delta, delta)

    inv_term = np.linalg.inv(Theta)
    if t == np.inf:
      # exp(-lr * t * Theta) is taken as zero in the t -> inf limit. Skip expm:
      # its argument would hold inf (and nan from inf * 0) entries.
      exp_term = np.zeros_like(Theta)
    else:
      exp_term = expm(-lr * t * Theta, max_squarings=32)
    if verbose >= 1:
      # jax.debug.print shows runtime values under jit (wide_apply_jit); a
      # plain print would only show tracers once, at trace time.
      jax.debug.print('exp_term {}', exp_term)
      jax.debug.print('inv_term {}', inv_term)
    omega_nodelta = - inv_term @ (np.identity(Theta.shape[0]) - exp_term) @ error

    qlin = qstar + wide_mul_T(gradqstar, delta) @ omega_nodelta
    if return_omega:
      return qlin, wide_mul_omega(delta, omega_nodelta)
    return qlin

  if method == SetupType.WITH_TARGET_Q:
    U = gradqz
    V = gradqznew
    U_UT = wide_mul_T(U, U)
    V_UT = wide_mul_T(V, U)

    sizeD = qz.shape[0]
    eye = np.identity(sizeD)

    inv_inner = np.linalg.pinv(U_UT - gamma * V_UT)
    # Generator of the coupled linear dynamics. The first block row is driven
    # by U U^T and gamma V U^T; the second relaxes toward the first at rate tau.
    generator = np.block([[-lr * U_UT, lr * gamma * V_UT],
                          [tau * eye, -tau * eye]])
    if t == np.inf:
      # E is taken as zero at t = inf (NOTE(review): assumes the coupled flow
      # is asymptotically stable). Skip expm of a matrix scaled by inf.
      E = np.zeros_like(generator)
    else:
      E = expm(generator * t, max_squarings=32)

    E_11 = E[:sizeD, :sizeD]
    E_21 = E[sizeD:, :sizeD]

    # tau is a divisor here; tau == 0 fails (or yields inf/nan for array scalars).
    omega_noU = inv_inner @ (E_11 - eye + gamma * lr / tau * V_UT @ E_21) @ error
    qlin = qstar + wide_mul_T(gradqstar, U) @ omega_noU

    if return_omega:
      return qlin, wide_mul_omega(U, omega_noU)
    return qlin

  # Previously an unknown method fell through and silently returned None.
  raise ValueError(f'unsupported method: {method!r}')

# Jitted entry point. Positional args 9..15 (gamma, t, lr, method, tau,
# return_omega, verbose) are static: t, method, return_omega and verbose drive
# Python-level branching inside wide_apply, and gamma, lr and tau are scalar
# coefficients. Each new combination of static values triggers a recompile.
wide_apply_jit = jax.jit(wide_apply, static_argnums=tuple(range(9, 16)))
