'''
This script is a slight modification of the GP inference code from the "neural tangents" library.
'''
from neural_tangents.predict import *
from neural_tangents.predict import _get_cho_solve, _get_axes, _get_first, _get_attr

class Gaussian(NamedTuple):
    """A `(mean, covariance, k_inv_y)` convenience namedtuple.

    `k_inv_y` defaults to `None`, so the tuple can still be built from just a
    mean and a covariance.
    """
    # Posterior mean.
    mean: np.ndarray
    # Posterior covariance.
    covariance: np.ndarray
    # Result of applying the train-train kernel solve to the train targets (as
    # produced by `gp_inference`); `None` when it was not supplied.
    k_inv_y: Optional[np.ndarray] = None

def gp_inference(
    k_train_train,
    y_train: np.ndarray,
    diag_reg: float = 0.,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1,)):
  r"""Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.

  NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - exact
  distribution of an infinite ensemble of infinitely wide NNs trained with
  gradient flow for infinite time. NTKGP - posterior of a GP (Gaussian process)
  with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this
  can correspond to infinite ensembles of infinitely wide NNs as well).

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as a Cholesky factorization
  of `k_train_train.nngp` or `k_train_train.ntk` (or both) is performed and
  cached for future invocations.

  Whenever a covariance is requested, the returned `Gaussian` additionally
  carries `k_inv_y`, i.e. the result of applying the (regularized) train-train
  kernel solve to `y_train`.

  Args:
    k_train_train:
      train-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
      `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
      for arguments provided to the returned `predict_fn` function. For
      example, if you request to compute posterior test [only] NTK covariance in
      future `predict_fn` invocations, `k_train_train` must contain both `ntk`
      and `nngp` kernels.
    y_train:
      train targets.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.
    trace_axes:
      `f(x_train)` axes such that `k_train_train`,
      `k_test_train`[, and `k_test_test`] lack these pairs of dimensions and
      are to be interpreted as :math:`\Theta \otimes I`, i.e. block-diagonal
      along `trace_axes`. These can be specified either to save space and
      compute, or to even improve approximation accuracy of the infinite-width
      or infinite-samples limit, since in these limits the covariance along
      channel / feature / logit axes indeed converges to a constant-diagonal
      matrix. However, if you target linearized dynamics of a specific
      finite-width network, `trace_axes=()` will yield most accurate result.

  Returns:
    A function of signature `predict_fn(get, k_test_train, k_test_test)`
    computing 'posterior' Gaussian distribution (mean or mean, covariance and
    `k_inv_y`) on a given test set.
  """
  even, odd, first, last = _get_axes(_get_first(k_train_train))
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)

  # At most two distinct kernels ('nngp' and 'ntk') are ever solved against,
  # hence the cache size of 2.
  @lru_cache(2)
  def solve(g: str):
    """Return the cached solver for the `g` train-train kernel."""
    k_dd = _get_attr(k_train_train, g)
    return _get_cho_solve(k_dd, diag_reg, diag_reg_absolute_scale)

  @lru_cache(2)
  def k_inv_y(g: str):
    """Return the cached solve of the `g` train-train kernel against `y_train`."""
    return solve(g)(y_train, trace_axes)

  @utils.get_namedtuple('Gaussians')
  def predict_fn(get: Optional[Get] = None,
                 k_test_train=None,
                 k_test_test=None
                 ) -> Dict[str, Union[np.ndarray, Gaussian]]:
    """`test`-set posterior given respective covariance matrices.

    Args:
      get:
        string, the mode of the Gaussian process, either "nngp", "ntk", "ntkgp",
        (see https://arxiv.org/abs/2007.05864) or a tuple, or `None`. If `None`
        then both `nngp` and `ntk` predictions are returned.
      k_test_train:
        test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
        `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
        for arguments provided to the returned `predict_fn` function. For
        example, if you request to compute posterior test [only] NTK covariance,
        `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
        returns predictions on the training set. Note that train-set outputs are
        always `N(y_train, 0)` and mostly returned for API consistency.
      k_test_test:
        test-test kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
        `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
        for arguments provided to the returned `predict_fn` function. Provide
        if you want to compute test-test posterior covariance.
        `k_test_test=None` means to not compute it. If `k_test_train is None`,
        pass any non-`None` value (e.g. `True`) if you want to get
        non-regularized (`diag_reg=0`) train-train posterior covariance. Note
        that non-regularized train-set outputs will always be the zero-variance
        Gaussian `N(y_train, 0)` and mostly returned for API consistency. For
        regularized train-set posterior outputs according to a positive
        `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
        `k_test_test=nngp_train_train`.

    Returns:
      Either a `Gaussian('mean', 'covariance', 'k_inv_y')` namedtuple (when
      `k_test_test is not None`) or just the `mean` of the GP posterior on the
      `test` set.

    Raises:
      ValueError: if the NTK test-set covariance is requested but
        `k_train_train` or `k_test_train` lacks an `nngp` attribute, or if an
        unsupported mode reaches the covariance computation.
    """
    if get is None:
      get = ('nngp', 'ntk')

    out = {}

    for g in get:
      # 'ntkgp' is a GP whose kernel is the NTK, so it reads the `ntk` entries.
      k = g if g != 'ntkgp' else 'ntk'
      k_dd = _get_attr(k_train_train, k)
      k_td = None if k_test_train is None else _get_attr(k_test_train, k)

      if k_td is None:
        # Train set predictions.
        y = y_train.astype(k_dd.dtype)
      else:
        _k_inv_y = k_inv_y(k)
        # Test set predictions.
        y = np.tensordot(k_td, _k_inv_y, (odd, first))
        y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

      if k_test_test is not None:
        if k_td is None:
          # Train-set posterior: zero covariance. `Gaussian` has three fields,
          # so `k_inv_y` is supplied here too (taken from the cached solve);
          # `_k_inv_y` is not bound on this path.
          out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype), k_inv_y(k))
        else:
          if (g == 'ntk' and
              (not hasattr(k_train_train, 'nngp') or
               not hasattr(k_test_train, 'nngp'))):
            raise ValueError(
                'If `"ntk" in get`, and `k_test_test is not None`, '
                'and `k_test_train is not None`, i.e. you request the '
                'NTK posterior covariance on the test set, you need '
                'both NTK and NNGP train-train and test-train matrices '
                'contained in `k_test_train` and `k_train_train`. '
                'Hence they must be `namedtuple`s with `nngp` and '
                '`ntk` attributes.')

          #  kernel of wide NN at initialization
          g_init = 'nngp' if g != 'ntkgp' else 'ntk'

          k_td_g_inv_y = solve(k)(_get_attr(k_test_train, g_init), even)
          k_tt = _get_attr(k_test_test, g_init)

          if g in ('nngp', 'ntkgp'):
            # Standard GP posterior covariance: k_tt - k_td K^{-1} k_dt.
            cov = np.tensordot(k_td, k_td_g_inv_y, (odd, first))
            cov = k_tt - utils.zip_axes(cov)
            out[g] = Gaussian(y, cov, _k_inv_y)

          elif g == 'ntk':
            # NTK-trained ensemble covariance: a quadratic term in the NNGP
            # train-train kernel, minus a symmetrized cross term, plus k_tt.
            term_1 = solve(g)(k_td, even)
            cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                               (odd, first))
            cov = np.tensordot(term_1, cov, (first, first))

            term_2 = np.tensordot(k_td, k_td_g_inv_y, (odd, first))
            term_2 += np.moveaxis(term_2, first, last)
            cov = utils.zip_axes(cov - term_2) + k_tt
            out[g] = Gaussian(y, cov, _k_inv_y)

          else:
            raise ValueError(g)

      else:
        # Mean-only request.
        out[g] = y

    return out

  return predict_fn
