import jax
import jax.numpy as np
from jax.scipy.linalg import expm
from neural_tangents import stax
from enum import Enum

# Switch JAX to 64-bit floats/ints (JAX defaults to 32-bit). Presumably done
# for the numerical accuracy of the expm / pinv / eigh computations below.
jax.config.update("jax_enable_x64", True)

class SetupType(Enum):
  """Training setup whose NTK dynamics the prediction functions model.

  Selects which branch ``nt_wide_predict_mean_cov`` and ``nt_wide_predict_lr``
  take when building the TD kernel.
  """

  # One network used both for Q(X) and for the bootstrap term on X'.
  SINGLE_Q = 1
  # Presumably a separate target network tracking the online one at rate
  # ``tau`` (the WITH_TARGET_Q branches take a ``tau`` argument) — confirm.
  WITH_TARGET_Q = 2

def nt_wide_get_kernels_fn(network_size, sigmaweights, sigmabiases):
  """Build a closure that evaluates infinite-width kernels of a ReLU MLP.

  The architecture is ``Dense -> Relu`` for every hidden width in
  ``network_size[1:-1]``, followed by a readout ``Dense`` of width
  ``network_size[-1]``. ``network_size[0]`` is not read here.

  Args:
    network_size: sequence of layer widths (input, hidden..., output).
    sigmaweights: ``W_std`` of the hidden ``Dense`` layers.
    sigmabiases: ``b_std`` of the hidden ``Dense`` layers.

  Returns:
    ``get_kernels(s1, a1, s2, a2) -> (K, Th)``. These are the first two fields
    of the stax ``Kernel`` between ``s1`` and ``s2``: the NNGP kernel ``K`` and
    the NTK ``Th``. The action arguments are currently ignored.
  """
  hidden_layers = []
  for width in network_size[1:-1]:
    hidden_layers.extend(
        (stax.Dense(width, W_std=sigmaweights, b_std=sigmabiases), stax.Relu()))
  # NOTE(review): the readout layer uses the stax default W_std/b_std rather
  # than sigmaweights/sigmabiases — confirm this is intended.
  readout = stax.Dense(network_size[-1])

  # stax.serial returns (init_fn, apply_fn, kernel_fn); only the kernel is used.
  kernel_fn = stax.serial(*hidden_layers, readout)[2]

  def get_kernels(s1, a1, s2, a2):
    # TODO: actions a1/a2 are ignored; the kernel depends on states only.
    nngp, ntk, *_unused = kernel_fn(s1, s2).astuple()
    return nngp, ntk

  return get_kernels


def nt_wide_predict_mean_cov(network_size, sigmaweights, sigmabiases, s, a, r, snew, done, anew, sstar, astar, gamma, t, lr, method=SetupType.SINGLE_Q, tau=None):
  """Predict the mean and covariance of the network outputs at test points.

  Infinite-width NNGP (``K_*``) and NTK (``Th_*``) kernels are built with
  ``nt_wide_get_kernels_fn``. The outputs at X* are expressed as a linear map
  ``Z`` applied to the rewards after training for time ``t`` with learning rate
  ``lr`` on the transitions. ``Z`` depends on the setup chosen by ``method``.

  Notation used in comments: X = (s, a), X' = (snew, anew),
  X* = (sstar, astar), D = diag(gamma * (1 - done)).

  Args:
    network_size, sigmaweights, sigmabiases: forwarded to
      ``nt_wide_get_kernels_fn``.
    s, a: training states / actions.
    r: rewards; the returned mean is ``Z @ r``.
    snew, anew: next states / actions.
    done: terminal flags. Assumed to be an (N, 1) column so that
      ``gamma * (1 - done.T)`` is a (1, N) row; a 1-D ``done`` would broadcast
      incorrectly — TODO confirm with callers.
    sstar, astar: test states / actions.
    gamma: discount factor.
    t: training time; ``np.inf`` uses the t -> inf limit of the exponentials.
    lr: learning rate.
    method: ``SetupType.SINGLE_Q`` or ``SetupType.WITH_TARGET_Q``.
    tau: coefficient of the ``[tau*I, -tau*I]`` block row of the
      WITH_TARGET_Q linear system; required for that setup.

  Returns:
    ``(mean, cov)``: ``mean = Z @ r`` and the covariance matrix at X*.

  Raises:
    ValueError: if ``method`` is unknown, or ``tau`` is missing for
      ``SetupType.WITH_TARGET_Q``.
  """
  if method not in (SetupType.SINGLE_Q, SetupType.WITH_TARGET_Q):
    raise ValueError(f"Unknown method: {method!r}")
  if method == SetupType.WITH_TARGET_Q and tau is None:
    raise ValueError("tau must be given when method is SetupType.WITH_TARGET_Q")

  get_kernels = nt_wide_get_kernels_fn(network_size, sigmaweights, sigmabiases)

  # Row vector: `gamma_done * M` scales the columns of M (M @ D) and
  # `gamma_done.T * M` scales its rows (D @ M).
  gamma_done = gamma*(1-done.T)

  K_XX, Th_XX = get_kernels(s, a, s, a)
  K_XXnew, Th_XXnew = get_kernels(s, a, snew, anew)
  K_XXstar, Th_XXstar = get_kernels(s, a, sstar, astar)
  K_XnewXnew, Th_XnewXnew = get_kernels(snew, anew, snew, anew)
  K_XnewXstar, Th_XnewXstar = get_kernels(snew, anew, sstar, astar)
  K_XstarXstar, _ = get_kernels(sstar, astar, sstar, astar)

  # For t == inf, skip expm entirely: its scaled argument would contain
  # inf/NaN entries (0 * inf). The limit is taken as zero instead.
  converged = t == np.inf

  if method == SetupType.SINGLE_Q:
    # NTK of Q(X) - D Q(X'): Th_XX - Th_XX' D - D Th_X'X + D Th_X'X' D.
    Th = Th_XX - gamma_done * Th_XXnew - gamma_done.T * Th_XXnew.T + (gamma_done*gamma_done.T) * Th_XnewXnew

    Th_inv = np.linalg.pinv(Th)
    if converged:
      exponent = np.zeros_like(Th)
    else:
      exponent = expm(-lr * t * Th, max_squarings=32)

    Z = (Th_XXstar.T - gamma_done * Th_XnewXstar.T) @ Th_inv @ (np.identity(Th.shape[0]) - exponent)
  else:  # SetupType.WITH_TARGET_Q
    sizeD = Th_XX.shape[0]
    identity = np.identity(sizeD)
    # Bootstrap coupling Th_X'X D. Terminal transitions are masked via
    # gamma_done, consistently in the generator and in Z below.
    Th_boot = gamma_done * Th_XXnew.T

    generator = np.block([[-lr * Th_XX, lr * Th_boot],
                          [tau * identity, -tau * identity]])
    if converged:
      E = np.zeros_like(generator)
    else:
      E = expm(generator * t, max_squarings=32)

    E_11 = E[:sizeD, :sizeD]
    E_21 = E[sizeD:, :sizeD]

    Z = Th_XXstar.T @ np.linalg.pinv(Th_XX - Th_boot) @ (identity - E_11 - (lr / tau) * Th_boot @ E_21)

  # Cross term: NNGP covariance between Q(X) - D Q(X') and Q(X*).
  ZZ = Z @ (K_XXstar - gamma_done.T * K_XnewXstar)
  # NNGP covariance of Q(X) - D Q(X').
  K_residual = K_XX - gamma_done * K_XXnew - gamma_done.T * K_XXnew.T + gamma_done*gamma_done.T * K_XnewXnew

  mean = Z @ r
  cov = K_XstarXstar - ZZ - ZZ.T + Z @ K_residual @ Z.T

  return mean, cov

def nt_wide_predict_lr(network_size, sigmaweights, sigmabiases, s, a, _, snew, done, anew, gamma, method=SetupType.SINGLE_Q):
  """Return the step size ``2 / (lambda_min + lambda_max)`` of the TD NTK.

  The TD kernel ``Th`` on the training transitions is built the same way as in
  ``nt_wide_predict_mean_cov``. The step size is computed from the extreme
  eigenvalues of ``Th``. For gradient descent on a quadratic, this value
  minimizes the worst-case contraction factor.

  NOTE(review): the divergence threshold for that setting is
  ``2 / lambda_max``. Confirm which of the two "critical" is meant to denote.

  Args:
    network_size, sigmaweights, sigmabiases: forwarded to
      ``nt_wide_get_kernels_fn``.
    s, a: training states / actions. ``a`` is passed through, but the kernels
      from ``nt_wide_get_kernels_fn`` currently ignore actions.
    _: rewards; unused.
    snew, anew: next states / actions.
    done: terminal flags. Assumed to be an (N, 1) column — TODO confirm.
    gamma: discount factor.
    method: ``SetupType.SINGLE_Q`` or ``SetupType.WITH_TARGET_Q``.

  Returns:
    The step size as a Python float.

  Raises:
    ValueError: if ``method`` is unknown.
  """
  if method not in (SetupType.SINGLE_Q, SetupType.WITH_TARGET_Q):
    raise ValueError(f"Unknown method: {method!r}")

  get_kernels = nt_wide_get_kernels_fn(network_size, sigmaweights, sigmabiases)

  # Row vector: `gamma_done * M` scales the columns of M (M @ D).
  gamma_done = gamma*(1-done.T)

  _, Th_XX = get_kernels(s, a, s, a)
  _, Th_XXnew = get_kernels(s, a, snew, anew)

  if method == SetupType.SINGLE_Q:
    _, Th_XnewXnew = get_kernels(snew, anew, snew, anew)
    Th = Th_XX - gamma_done * Th_XXnew - gamma_done.T * Th_XXnew.T + (gamma_done*gamma_done.T) * Th_XnewXnew
    # Th is symmetric by construction, so the Hermitian solver applies.
    eigvals = np.linalg.eigh(Th)[0]
  else:  # SetupType.WITH_TARGET_Q
    Th = Th_XX - gamma_done * Th_XXnew.T
    # Th_XX - Th_X'X D is not symmetric. jnp.linalg.eigh symmetrizes its input
    # by default and would return eigenvalues of (Th + Th.T) / 2 instead.
    # Use the general solver and keep the real parts.
    # NOTE: older JAX versions implement eigvals on CPU only.
    eigvals = np.real(np.linalg.eigvals(Th))

  eig_min, eig_max = eigvals.min(), eigvals.max()

  lr_critical = 2 / (eig_min + eig_max)

  return lr_critical.item()
