# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module are applicable to any JAX functions of proper
signatures (not only those from :obj:`~neural_tangents.stax`).

NNGP and NTK are computed using :obj:`~neural_tangents.empirical_nngp_fn`,
:obj:`~neural_tangents.empirical_ntk_fn`, or
:obj:`~neural_tangents.empirical_kernel_fn` (for both). The kernels have a very
specific output shape convention that may be unexpected. Further, NTK has
multiple implementations that may perform differently depending on the task.
Please read individual functions' docstrings.

For details, please see "`Fast Finite Width Neural Tangent Kernel
<https://arxiv.org/abs/2206.08720>`_".

Example:
  >>> from jax import random
  >>> import neural_tangents as nt
  >>> from neural_tangents import stax
  >>> #
  >>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  >>> x_train = random.normal(key1, (20, 32, 32, 3))
  >>> y_train = random.uniform(key1, (20, 10))
  >>> x_test = random.normal(key2, (5, 32, 32, 3))
  >>> #
  >>> # A narrow CNN.
  >>> init_fn, f, _ = stax.serial(
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Relu(),
  >>>     stax.Conv(32, (3, 3)),
  >>>     stax.Flatten(),
  >>>     stax.Dense(10)
  >>> )
  >>> #
  >>> _, params = init_fn(key3, x_train.shape)
  >>> #
  >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
  >>> # network is iid along the batch axis, no BatchNorm. Use default
  >>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
  >>> kernel_fn = nt.empirical_kernel_fn(
  >>>     f, trace_axes=(-1,), vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
  >>> #
  >>> # (5, 20) jnp.ndarray test-train NNGP/NTK
  >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
  >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
  >>> #
  >>> # Full kernel: not reducing over logits. Use structured derivatives
  >>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
  >>> # typically faster computation and lower memory cost.
  >>> kernel_fn = nt.empirical_kernel_fn(
  >>>     f, trace_axes=(), vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  >>> #
  >>> # (5, 20, 10, 10) jnp.ndarray test-train NNGP/NTK namedtuple.
  >>> k_test_train = kernel_fn(x_test, x_train, None, params)
  >>> #
  >>> # A wide FCN with lots of parameters and many (`100`) outputs.
  >>> init_fn, f, _ = stax.serial(
  >>>     stax.Flatten(),
  >>>     stax.Dense(1024),
  >>>     stax.Relu(),
  >>>     stax.Dense(1024),
  >>>     stax.Relu(),
  >>>     stax.Dense(100)
  >>> )
  >>> #
  >>> _, params = init_fn(key3, x_train.shape)
  >>> #
  >>> # Use ntk-vector products
  >>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
  >>> # network has many parameters relative to the cost of forward pass,
  >>> # large outputs.
  >>> ntk_fn = nt.empirical_ntk_fn(
  >>>     f, vmap_axes=0,
  >>>     implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  >>> #
  >>> # (5, 5) jnp.ndarray test-test NTK
  >>> ntk_test_test = ntk_fn(x_test, None, params)
  >>> #
  >>> # Compute only output variances:
  >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
  >>> #
  >>> # (20,) jnp.ndarray train-train diagonal NNGP
  >>> nngp_train_train_diag = nngp_fn(x_train, None, params)
"""

import enum
import functools
import operator
from typing import Callable, Iterable, KeysView, Optional, TypeVar, Union
import warnings

import jax
from jax import core
from jax import eval_shape
from jax import jacobian
from jax import jvp
from jax import lax
from jax import linear_transpose
from jax import vjp
from jax import vmap

from jax.core import Jaxpr
from jax.core import JaxprEqn
from jax.core import Literal
from jax.core import ShapedArray
from jax.core import Value
from jax.core import Var

from jax.extend import linear_util as lu

from jax.interpreters import ad
from jax.interpreters.ad import UndefinedPrimal
from jax.interpreters.ad import Zero

import jax.numpy as jnp

from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_reduce
from jax.tree_util import tree_structure
from jax.tree_util import tree_transpose
from jax.tree_util import tree_unflatten
from jax.util import safe_map as map
from jax.util import safe_zip as zip

import numpy as np

from .utils import rules
from .utils import utils
from .utils.typing import ApplyFn
from .utils.typing import Axes
from .utils.typing import EmpiricalGetKernelFn
from .utils.typing import EmpiricalKernelFn
from .utils.typing import PyTree
from .utils.typing import VMapAxes
from .utils.typing import VMapAxisTriple


# LINEARIZATION AND TAYLOR EXPANSION


def linearize(f: ApplyFn, params: PyTree) -> ApplyFn:
  """Builds `f_lin`, the first-order Taylor approximation of `f` at `params`.

  Example:
    >>> # Compute the MSE of the first order Taylor series of a function.
    >>> f_lin = linearize(f, params)
    >>> mse = jnp.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)

  Args:
    f:
      The function to linearize. Expected signature is
      `f(params, *args, **kwargs)`, where `params` is a `PyTree` and the
      returned value is a `PyTree` as well.

    params:
      The point (a `PyTree` of parameters) around which the Taylor series is
      taken. Any structure compatible with the JAX tree operations works.

  Returns:
    A function `f_lin(new_params, *args, **kwargs)` with the same signature as
    `f`, evaluating the first-order Taylor series of `f` about `params`.
  """
  def f_lin(p, *args, **kwargs):
    def f_of_params(theta):
      # `f` with all non-parameter arguments frozen.
      return f(theta, *args, **kwargs)

    # Displacement from the expansion point to the queried parameters.
    delta = _sub(p, params)

    # A single JVP yields both `f(params)` and the directional derivative
    # of `f` at `params` along `delta`.
    f_at_params, first_order_term = jvp(f_of_params, (params,), (delta,))
    return _add(f_at_params, first_order_term)

  return f_lin


def taylor_expand(f: ApplyFn, params: PyTree, degree: int) -> ApplyFn:
  """Builds `f_tayl`, the Taylor approximation of `f` of order `degree`.

  Example:
    >>> # Compute the MSE of the third order Taylor series of a function.
    >>> f_tayl = taylor_expand(f, params, 3)
    >>> mse = jnp.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)

  Args:
    f:
      The function to Taylor-expand. Expected signature is
      `f(params, *args, **kwargs)`, where `params` is a `PyTree` and the
      returned value is a `PyTree` as well.

    params:
      The point (a `PyTree` of parameters) around which the Taylor series is
      taken. Any structure compatible with the JAX tree operations works.

    degree:
      The degree of the Taylor expansion.

  Returns:
    A function `f_tayl(new_params, *args, **kwargs)` with the same signature
    as `f`, evaluating the `degree`-order Taylor series of `f` about `params`.
  """
  def taylorize_r(fn, point, direction, max_order, order):
    """Recursively sums the Taylor terms of `fn` from `order` to `max_order`.

    Each level differentiates `fn` once more along `direction` and divides the
    deeper result by `order + 1`; nesting these divisions produces the
    factorial denominators of the series.
    """
    if order == max_order:
      return fn(point)

    def directional_derivative(q):
      # Only the tangent output of the JVP is needed here.
      return jvp(fn, (q,), (direction,))[1]

    next_order = order + 1
    higher_terms = taylorize_r(
        directional_derivative, point, direction, max_order, next_order)
    return _add(fn(point), _div(higher_terms, next_order))

  def f_tayl(p, *args, **kwargs):
    def f_of_params(theta):
      # `f` with all non-parameter arguments frozen.
      return f(theta, *args, **kwargs)

    return taylorize_r(f_of_params, params, _sub(p, params), degree, 0)

  return f_tayl


# NNGP


def empirical_nngp_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = ()
) -> EmpiricalKernelFn:
  """Returns a function drawing a single sample of the NNGP of a network `f`.

  The Neural Network Gaussian Process (NNGP) kernel is defined as
  :math:`f(X_1) f(X_2)^T`, i.e. the outer product of the function outputs.

  .. warning::
    The resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`,
    subject to the `trace_axes` and `diagonal_axes` parameters. These make
    certain assumptions about the outputs `f(x)` that may only hold in the
    infinite width / infinite number of samples limit, or may not apply to your
    architecture. For the most precise results in the context of linearized
    training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain a kernel of shape exactly
    `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels have terms measuring the covariance between
  the outputs. These cross-terms are ignored here and each output is treated
  separately. Please raise an issue if this feature is important to you.

  Args:
    f:
      the function whose NNGP is computed. Expected signature is
      `f(params, x, **kwargs)`, where `params` and `x` are `PyTree`s, and `f`
      returns a `PyTree`.

    trace_axes:
      output axes to trace the output kernel over, i.e. only the trace of the
      covariance along the respective pair of axes is computed (one pair for
      each axis in `trace_axes`). This saves space and compute when only the
      trace is of interest, and also improves approximation accuracy when the
      covariance along these pairs of axes is known to converge to a
      `constant * identity matrix` in the limit of interest (e.g. infinite
      width or infinite `n_samples`). A common use case is the channel /
      feature / logit axis, since activation slices along such an axis are
      i.i.d. and the covariance along the respective pair of axes indeed
      converges to a constant-diagonal matrix in the infinite width or infinite
      `n_samples` limit.
      Also related to "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

    diagonal_axes:
      output axes to diagonalize the output kernel over, i.e. only the diagonal
      of the covariance along the respective pair of axes is computed (one pair
      for each axis in `diagonal_axes`). This saves space and compute when
      off-diagonal values along these axes are not needed, and also improves
      approximation accuracy when their limiting value is known theoretically,
      e.g. when they vanish in the limit of interest (e.g. infinite width or
      infinite `n_samples`). If the on-diagonal values are further known to
      converge to the same constant in that limit, list these axes in
      `trace_axes` instead, to save even more compute and gain even more
      accuracy. A common use case is computing the variance (instead of the
      covariance) along certain axes.
      Also related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  Returns:
     A function drawing a single sample of the NNGP of the network `f`.
  """
  def nngp_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> PyTree:
    """Draws a single sample of the empirical NNGP.

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        shape matching `f(x1)` on `trace_axes` and `diagonal_axes`.

      params:
        A `PyTree` of parameters at which the network outputs are evaluated.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` are split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function, and these are passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` is split into two different (if `x1 != x2`)
        or identical (if `x1 == x2`) rng keys. See the `_read_key` function
        for more details.

    Returns:
      A single sample of the empirical NNGP. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

    out1 = f(params, x1, **kwargs1)
    if utils.all_none(x2):
      # No second batch: reuse the first outputs instead of re-evaluating `f`.
      out2 = out1
    else:
      out2 = f(params, x2, **kwargs2)

    def contract(out1: jnp.ndarray, out2: jnp.ndarray) -> jnp.ndarray:
      # Normalize by the total size of the traced-over axes.
      return (_dot_general(out1, out2, trace_axes, diagonal_axes) /
              utils.size_at(out1, trace_axes))

    # Each output leaf is handled independently; cross-terms are ignored.
    return tree_map(contract, out1, out2)

  return nngp_fn


# NTK


class NtkImplementation(enum.IntEnum):
  """Implementation method of the underlying finite width NTK computation.

  Below is a very brief summary of each method. For details, please see "`Fast
  Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_".

  Attributes:
    AUTO:
      (or `0`) evaluates FLOPs of all other methods at compilation time,
      and selects the fastest method. However, at this time it only works
      correctly on TPUs, and on CPU/GPU can return wrong results, which is why
      it is not the default. TODO(romann): revisit based on http://b/202218145.

    JACOBIAN_CONTRACTION:
      (or `1`) computes the NTK as the outer product of two Jacobians, each
      computed using reverse-mode Autodiff (vector-Jacobian products, VJPs).
      When JITted, the contraction is performed in a layerwise fashion, so that
      entire Jacobians aren't necessarily instantiated in memory at once, and
      the memory usage of the method can be lower than memory needed to
      instantiate the two Jacobians. This method is best suited for networks
      with small outputs (such as scalar outputs for binary classification or
      regression, as opposed to 1000 ImageNet classes), and an expensive
      forward pass relative to the number of parameters (such as CNNs, where
      forward pass reuses a small filter bank many times). It is also the
      most reliable method, since its implementation is simplest, and
      reverse-mode Autodiff is most commonly used and well tested elsewhere.
      For this reason it is set as the default.

    NTK_VECTOR_PRODUCTS:
      (or `2`) computes the NTK as a sequence of NTK-vector products, similarly
      to how a Jacobian is computed as a sequence of Jacobian-vector products
      (JVPs) or vector-Jacobian products (VJPs). This amounts to using both
      forward (JVPs) and reverse (VJPs) mode Autodiff, and allows to eliminate
      the Jacobian contraction at the expense of additional forward passes.
      Therefore this method is recommended for networks with a cheap forward
      pass relative to the number of parameters (e.g. fully-connected networks,
      where each parameter matrix is used only once in the forward pass), and
      networks with large outputs (e.g. 1000 ImageNet classes). Memory
      requirements of this method are same as :attr:`JACOBIAN_CONTRACTION`
      (`1`). Due to reliance on forward-mode Autodiff, this method is slightly
      more prone to JAX and XLA bugs than :attr:`JACOBIAN_CONTRACTION` (`1`),
      but overall is quite simple and reliable.

    STRUCTURED_DERIVATIVES:
      (or `3`) uses a custom JAX interpreter to compute the NTK more
      efficiently than other methods. It traverses the computational graph of a
      function in the same order as during reverse-mode Autodiff, but instead
      of computing VJPs, it directly computes MJJMPs,
      "matrix-Jacobian-Jacobian-matrix" products, which arise in the
      computation of an NTK. Each MJJMP computation relies on the structure in
      the Jacobians, hence the name. This method can be dramatically faster
      (up to several orders of magnitude) than other methods on fully-connected
      networks, and is usually faster or equivalent on CNNs, Transformers, and
      other architectures, but exact speedup (e.g. from no speedup to 10X)
      depends on each specific setting. It can also use less memory than other
      methods. In our experience it consistently outperforms other methods in
      most settings. However, its implementation is significantly more complex
      (hence bug-prone), and it doesn't yet support functions using more exotic
      JAX primitives (e.g. :obj:`jax.checkpoint`, parallel collectives such as
      :obj:`jax.lax.psum`, compiled loops like :obj:`jax.lax.scan`, etc.), which
      is why it is highly-recommended to try, but not set as the default yet.
  """
  AUTO = 0
  JACOBIAN_CONTRACTION = 1
  NTK_VECTOR_PRODUCTS = 2
  STRUCTURED_DERIVATIVES = 3


# Default value of the `implementation` argument of `empirical_ntk_fn`.
DEFAULT_NTK_IMPLEMENTATION = NtkImplementation.JACOBIAN_CONTRACTION
"""Default user-facing empirical NTK implementation.

We default to `JACOBIAN_CONTRACTION` since it's the most straightforward and
reliable method, virtually guaranteed to compute the correct result.
"""


_DEFAULT_TESTING_NTK_IMPLEMENTATION = NtkImplementation.STRUCTURED_DERIVATIVES
"""Default empirical NTK implementation used in `tests`.

We default to `STRUCTURED_DERIVATIVES` since it is the fastest but also most
complex method, hence benefiting from additional testing against infinite-width
results.
"""


# Default value of the `_j_rules` argument of `empirical_ntk_fn`.
_DEFAULT_NTK_J_RULES: bool = True
"""Says whether to use custom Jacobian rules in `STRUCTURED_DERIVATIVES` (`3`).

Useful for debugging and testing. Theoretically should be set to `True`, but if
some Jacobian rule is implemented suboptimally, trying out `False` could improve
performance.
"""


# Default value of the `_s_rules` argument of `empirical_ntk_fn`.
_DEFAULT_NTK_S_RULES: bool = True
"""Says whether to use structure rules in `STRUCTURED_DERIVATIVES` (`3`).

Useful for debugging and testing. In practice should be set to `True`, and
setting it to `False` can lead to dramatic deterioration of performance.
"""


# Default value of the `_fwd` argument of `empirical_ntk_fn`.
_DEFAULT_NTK_FWD: Optional[bool] = None
"""Says whether to use forward mode in `STRUCTURED_DERIVATIVES` (`3`) Jacobians.

Useful for debugging and testing, but for best performance should be set to
`None`, i.e. to selecting forward or reverse mode AD automatically based on
input/output sizes.
"""


def _empirical_auto_ntk_fn(**kwargs) -> EmpiricalGetKernelFn:
  """Compute NTK by automatically selecting the best implementation.

  Every non-`AUTO` member of `NtkImplementation` is instantiated through
  `empirical_ntk_fn(**kwargs, implementation=...)`, and the one reporting the
  fewest FLOPs is used. The choice is memoized per input shapes.

  Returns wrong FLOPS on CPU and GPU when JITting.

  TODO(romann): revisit based on http://b/202218145.
  """
  # Maps a tuple-tree of input shapes to the NTK function selected for them.
  cache = {}

  def ntk_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> jnp.ndarray:
    """Computes a single sample of the automatic empirical NTK.

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        shape matching `f(x1)` on `trace_axes` and `diagonal_axes`.

      params:
        A `PyTree` of parameters about which the neural tangent kernel is
        computed.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` are split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function, and these are passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` is split into two different (if `x1 != x2`)
        or identical (if `x1 == x2`) rng keys. See the `_read_key` function
        for more details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    shapes = _to_tuple_tree(
        tree_map(jnp.shape, (x1, x2, params, apply_fn_kwargs)))

    # Fast path: an implementation was already selected for these shapes.
    if shapes in cache:
      return cache[shapes](x1, x2, params, **apply_fn_kwargs)

    fastest_fn, fewest_flops = None, np.inf

    for implementation in NtkImplementation:
      if implementation == NtkImplementation.AUTO:
        # `AUTO` is the selector itself, not a candidate.
        continue

      candidate_fn = empirical_ntk_fn(**kwargs, implementation=implementation)
      flops = _get_flops(candidate_fn, True, x1, x2, params, **apply_fn_kwargs)
      print(f'impl={implementation}, flops={flops}')

      # Strict comparison: on ties the earlier implementation is kept.
      if flops < fewest_flops:
        fewest_flops, fastest_fn = flops, candidate_fn

    if fastest_fn is None:
      raise ValueError('This should not happen.')

    cache[shapes] = fastest_fn
    return fastest_fn(x1, x2, params, **apply_fn_kwargs)

  return ntk_fn


def _jacobian_contraction_ntk_fn(
    f: ApplyFn,
    trace_axes: Axes,
    diagonal_axes: Axes,
    vmap_axes: VMapAxes,
    **kwargs
) -> EmpiricalKernelFn:
  """Compute NTK by directly instantiating Jacobians and contracting.

  Args:
    f:
      the function whose NTK is computed.

    trace_axes:
      output axes to trace the kernel over.

    diagonal_axes:
      output axes to take the kernel diagonal over.

    vmap_axes:
      forwarded to `_get_args` to derive the axes the Jacobian computation is
      `vmap`-ed over.

    **kwargs:
      accepted and ignored by this implementation.

  Returns:
    A function computing a single sample of the empirical NTK.
  """

  def sum_and_contract(fx, j1, j2):
    """Contracts two Jacobian PyTrees into the NTK of one output leaf `fx`."""
    out_ndim = fx.ndim
    normalizer = utils.size_at(fx, trace_axes)

    diag_axes = utils.canonicalize_axis(diagonal_axes, out_ndim)
    traced_axes = utils.canonicalize_axis(trace_axes, out_ndim)

    def contract(x, y):
      # Axes past the first `out_ndim` ones belong to the parameter leaf and
      # are always summed over, together with the traced output axes.
      param_axes = list(range(out_ndim, x.ndim))
      summed_axes = traced_axes + param_axes
      return _dot_general(x, y, summed_axes, diag_axes) / normalizer

    # One contribution per parameter leaf; the NTK is their sum.
    per_leaf = tree_map(contract, j1, j2)
    return tree_reduce(operator.add, per_leaf)

  def ntk_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> jnp.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        shape matching `f(x1)` on `trace_axes` and `diagonal_axes`.

      params:
        A `PyTree` of parameters about which the neural tangent kernel is
        computed.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` are split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function, and these are passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` is split into two different (if `x1 != x2`)
        or identical (if `x1 == x2`) rng keys. See the `_read_key` function
        for more details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    args1, args2, fx1, _, fx_axis, keys, kw_axes, x_axis = _get_args(
        f, apply_fn_kwargs, params, vmap_axes, x1, x2)

    def jacobian_fn(x, *args):
      # Rebuild the keyword arguments from their flattened positional form.
      named_args = dict(zip(keys, args))
      f_params = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **named_args)
      return jacobian(f_params)(params)

    # Vectorize only if at least one input / keyword argument has a vmap axis.
    if not (utils.all_none(x_axis) and utils.all_none(kw_axes)):
      in_axes = [x_axis]
      in_axes += [kw_axes[k] if k in kw_axes else None for k in keys]
      jacobian_fn = vmap(jacobian_fn, in_axes=in_axes, out_axes=fx_axis)

    j1 = jacobian_fn(x1, *args1)
    if utils.all_none(x2):
      # Symmetric case: reuse the first Jacobian.
      j2 = j1
    else:
      j2 = jacobian_fn(x2, *args2)

    return tree_map(sum_and_contract, fx1, j1, j2)

  return ntk_fn


def _ntk_vector_products_ntk_fn(
    f: ApplyFn,
    trace_axes: Axes,
    diagonal_axes: Axes,
    vmap_axes: VMapAxes,
    **kwargs
) -> EmpiricalKernelFn:
  """Compute NTK via NTK-vector products.

  Args:
    f:
      the function whose NTK is computed.

    trace_axes:
      output axes to trace the kernel over.

    diagonal_axes:
      output axes to take the kernel diagonal over.

    vmap_axes:
      forwarded to `_get_args` to derive the axes the computation is
      `vmap`-ed over.

    **kwargs:
      accepted and ignored by this implementation.

  Returns:
    A function computing a single sample of the empirical NTK.
  """

  def ntk_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> jnp.ndarray:
    """Computes a single sample of the empirical NTK with NTK-vector products.

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        shape matching `f(x1)` on `trace_axes` and `diagonal_axes`.

      params:
        A `PyTree` of parameters about which the neural tangent kernel is
        computed.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` are split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function, and these are passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` is split into two different (if `x1 != x2`)
        or identical (if `x1 == x2`) rng keys. See the `_read_key` function
        for more details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args(
        f, apply_fn_kwargs, params, vmap_axes, x1, x2)

    def get_ntk(x1, x2, *args):
      f1, f2 = _get_f1_f2(f, keys, x_axis, fx_axis, kw_axes, args, x1, x2)

      def delta_vjp_jvp(delta):
        # Pull `delta` back to parameter space through `f2` (VJP), then push
        # the resulting cotangents forward through `f1` (JVP).
        _, f2_vjp = vjp(f2, params)
        _, pushed_forward = jvp(f1, (params,), f2_vjp(delta))
        return pushed_forward

      out1_struct = eval_shape(f1, params)
      out2_struct = eval_shape(f2, params)

      # Apply the transposed NTK-vector product to every basis element
      # produced by `_std_basis`, batched with `vmap`.
      basis = _std_basis(out1_struct)
      transposed_product = linear_transpose(delta_vjp_jvp, out2_struct)
      flat_ntk = vmap(transposed_product)(basis)

      def unravel(fx12):
        return _unravel_array_into_pytree(out1_struct, 0, fx12)

      return _diagonal(tree_map(unravel, flat_ntk), out1_struct)

    # Vectorize only if at least one input / keyword argument has a vmap axis.
    if not (utils.all_none(x_axis) and utils.all_none(kw_axes)):
      if utils.all_none(x2):
        x2 = x1

      kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys]
      unmapped = [None] * len(kw_in_axes)

      # The inner `vmap` maps over `x1` and its keyword arguments, the outer
      # one over `x2` and its keyword arguments.
      in_axes1 = [x_axis, None] + kw_in_axes + unmapped
      in_axes2 = [None, x_axis] + unmapped + kw_in_axes

      over_x1 = vmap(get_ntk, in_axes=in_axes1, out_axes=fx_axis)
      get_ntk = vmap(over_x1,
                     in_axes=in_axes2,
                     out_axes=_add(fx_axis, _ndim(fx1)))

    def reduce_axes(x):
      return _trace_and_diagonal(x, trace_axes, diagonal_axes)

    full_ntk = get_ntk(x1, x2, *args1, *args2)
    return tree_map(reduce_axes, full_ntk)

  return ntk_fn


def _structured_derivatives_ntk_fn(
    f: ApplyFn,
    trace_axes: Axes,
    diagonal_axes: Axes,
    vmap_axes: VMapAxes,
    _j_rules: bool,
    _s_rules: bool,
    _fwd: Optional[bool]
) -> EmpiricalKernelFn:
  """Compute NTK by using structured derivatives.

  Args:
    f:
      the function whose NTK is computed.

    trace_axes:
      output axes to trace the kernel over.

    diagonal_axes:
      output axes to take the kernel diagonal over.

    vmap_axes:
      forwarded to `_get_args` to derive the axes the derivative computation is
      `vmap`-ed over.

    _j_rules:
      forwarded to `_get_df_dys_and_dy_dws`.

    _s_rules:
      forwarded to `_get_df_dys_and_dy_dws`.

    _fwd:
      forwarded to `_get_df_dys_and_dy_dws`.

  Returns:
    A function computing a single sample of the empirical NTK.
  """

  def sum_and_contract(
      fx1: jnp.ndarray,
      fx2: jnp.ndarray,
      fx_axis,
      df_dys_1: list[Union[jnp.ndarray, Zero]],
      df_dys_2: list[Union[jnp.ndarray, Zero]],
      dy_dws_1: list[tuple[jnp.ndarray, rules.Structure]],
      dy_dws_2: list[tuple[jnp.ndarray, rules.Structure]],
      dtype: jnp.dtype
  ):
    """Assembles the NTK of one output leaf from its derivative factors.

    Every pair of `(df_dy, dy_dw)` factors from side 1 and side 2 is contracted
    with a single `jnp.einsum` call; the results are summed, divided by the
    size of `fx1` along `trace_axes`, and broadcast to the shape returned by
    `_ntk_shape`.
    """
    ndim = fx1.ndim
    size = utils.size_at(fx1, trace_axes)

    _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim)
    _trace_axes = utils.canonicalize_axis(trace_axes, ndim)

    def contract(df_dys_1, df_dys_2, dy_dws_1, dy_dws_2):
      # Scalar accumulator; it takes on the einsum result shape through
      # broadcasting in `+=`.
      ntk = jnp.zeros((), dtype=dtype)

      # All pairs of factors from side 1 and side 2 contribute a term.
      for df_dy_1, dy_dw_1_ in zip(df_dys_1, dy_dws_1):
        for df_dy_2, dy_dw_2_ in zip(df_dys_2, dy_dws_2):

          dy_dw_1: jnp.ndarray
          s1: rules.Structure
          dy_dw_1, s1 = dy_dw_1_

          dy_dw_2: jnp.ndarray
          s2: rules.Structure
          dy_dw_2, s2 = dy_dw_2_

          # A symbolic zero on either side makes the whole term vanish.
          if isinstance(dy_dw_1, Zero) or isinstance(dy_dw_2, Zero):
            continue

          # Initial einsum index lists for both `df_dy` operands and for the
          # output; the lists are edited in place below.
          df_dy_dims_1, df_dy_dims_2, out_dims = _get_dims(
              df_dy_1,
              df_dy_2,
              ndim,
              _trace_axes,
              _diagonal_axes
          )

          if len(s1.out_trace) != len(s2.out_trace):
            raise NotImplementedError('Different number of trace_axes 1/2.')

          # Give each pair of structure `out_trace` axes one shared einsum
          # index on both `df_dy` operands. The modulo wraps the axis id into
          # the range of the axes past the first `ndim` ones.
          for i, (id_1, id_2) in enumerate(zip(s1.out_trace, s2.out_trace)):
            axis_id = df_dy_1.ndim + df_dy_2.ndim + i
            y_axis_1 = id_1 % (df_dy_1.ndim - ndim)
            y_axis_2 = id_2 % (df_dy_2.ndim - ndim)
            df_dy_dims_1[ndim + y_axis_1] = axis_id
            df_dy_dims_2[ndim + y_axis_2] = axis_id

          # `dy_dw` operands start out with unique negative einsum indices.
          dy_dw_dims_1 = list(range(-dy_dw_1.ndim, 0))
          dy_dw_dims_2 = list(range(-dy_dw_2.ndim, 0))

          if fx_axis is not None:
            # Move the leading axis of `df_dy` to position `fx_axis`, and tie
            # the first axis of `dy_dw` to the einsum index found there.
            df_dy_1 = jnp.moveaxis(df_dy_1, 0, fx_axis)
            df_dy_2 = jnp.moveaxis(df_dy_2, 0, fx_axis)

            dy_dw_dims_1[0] = df_dy_dims_1[fx_axis]
            dy_dw_dims_2[0] = df_dy_dims_2[fx_axis]
            # The first `dy_dw` axis is taken; remaining axes start at 1.
            ix_1, ix_2 = 1, 1

          else:
            ix_1, ix_2 = 0, 0

          if len(s1.out_diagonal) != len(s2.out_diagonal):
            raise NotImplementedError('Different number of diagonal_axes 1/2.')

          # Give each pair of structure `out_diagonal` axes one einsum index
          # shared by all four operands.
          for i, (id_1, id_2) in enumerate(zip(s1.out_diagonal,
                                               s2.out_diagonal)):
            # TODO(romann): compute based on array dimensions.
            axis_shift = -100_000  # Huge axis shift to ensure unique axis ids.

            axis_id = (-axis_shift -df_dy_1.ndim - df_dy_2.ndim - dy_dw_1.ndim
                       - dy_dw_2.ndim - i)

            df_dy_dims_1[ndim + id_1] = axis_id
            dy_dw_dims_1[ix_1 + id_1] = axis_id

            df_dy_dims_2[ndim + id_2] = axis_id
            dy_dw_dims_2[ix_2 + id_2] = axis_id

          # For every axis of `df_dy_1` past the first `ndim` ones that is not
          # listed in the structure's trace / diagonal / broadcast axes, tie
          # the `dy_dw_1` axis at `ix_1` to the same einsum index. `ix_1`
          # advances on every iteration, tied or not.
          for i in range(ndim, df_dy_1.ndim):
            if i - ndim not in (s1.out_trace +
                                s1.out_diagonal +
                                s1.out_broadcast):
              dy_dw_dims_1[ix_1] = df_dy_dims_1[i]
            ix_1 += 1

          # Same linking for side 2.
          for i in range(ndim, df_dy_2.ndim):
            if i - ndim not in (s2.out_trace +
                                s2.out_diagonal +
                                s2.out_broadcast):
              dy_dw_dims_2[ix_2] = df_dy_dims_2[i]
            ix_2 += 1

          _check_einsum_no_broadcast(
              arrays=[df_dy_1, dy_dw_1, dy_dw_2, df_dy_2],
              dims=[df_dy_dims_1, dy_dw_dims_1, dy_dw_dims_2, df_dy_dims_2]
          )

          # `jnp.einsum` in its interleaved (operand, index-list) form, with
          # the output index list passed last.
          ntk_l = jnp.einsum(
              df_dy_1, df_dy_dims_1,
              dy_dw_1, dy_dw_dims_1,
              dy_dw_2, dy_dw_dims_2,
              df_dy_2, df_dy_dims_2,
              out_dims
          )
          ntk += ntk_l

      return ntk

    # Apply `contract` to matching groups of factors (an empty list, or a list
    # starting with an array, is treated as a leaf) and sum the results,
    # starting from a scalar zero.
    ntk = tree_reduce(
        operator.add,
        tree_map(
            contract,
            df_dys_1, df_dys_2, dy_dws_1, dy_dws_2,
            is_leaf=
            lambda x: (x == [] or (isinstance(x, list) and
                                   isinstance(x[0], jnp.ndarray)))),
        jnp.zeros((), dtype)
    )
    ntk /= size
    ntk_shape = _ntk_shape(fx1.shape, fx2.shape, trace_axes, diagonal_axes)
    ntk = jnp.broadcast_to(ntk, ntk_shape)  # if ntk is 0.
    return ntk

  def ntk_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      params: PyTree,
      **apply_fn_kwargs
  ) -> jnp.ndarray:
    """Computes a single sample of the structured derivatives NTK.

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.

      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function which will be passed to `apply_fn`. In particular, the rng key
        in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis = _get_args(
        f, apply_fn_kwargs, params, vmap_axes, x1, x2)

    def j_fn(x, *args):
      # Rebuild the keyword arguments from their flattened positional form.
      _kwargs = {k: v for k, v in zip(keys, args)}
      fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs)
      df_dys, dy_dws = _get_df_dys_and_dy_dws(fn=fx, params=params,
                                              _j_rules=_j_rules,
                                              _s_rules=_s_rules, _fwd=_fwd)
      return df_dys, dy_dws

    # Vectorize only if at least one input / keyword argument has a vmap axis.
    # The mapped axis is always placed first (`out_axes=0`); `sum_and_contract`
    # moves it to `fx_axis` with `jnp.moveaxis`.
    if not utils.all_none(x_axis) or not utils.all_none(kw_axes):
      in_axes = [x_axis] + [kw_axes[k] if k in kw_axes else None for k in keys]
      j_fn = vmap(j_fn, in_axes=in_axes, out_axes=0)

    df_dys_1, dy_dws_1 = j_fn(x1, *args1)
    # Symmetric case (`x2` is all `None`): reuse the side-1 factors.
    df_dys_2, dy_dws_2 = j_fn(x2, *args2) if not utils.all_none(x2) else (
        df_dys_1, dy_dws_1)

    fx_axis, dtype = _get_fx_axis_and_dtype(fx1, fx_axis, params)
    # `dy_dws_*` and `dtype` are bound once via `partial`; the remaining
    # arguments are mapped leaf-by-leaf over the tree structure of `fx1`.
    ntk = tree_map(
        functools.partial(
            sum_and_contract,
            dy_dws_1=dy_dws_1,
            dy_dws_2=dy_dws_2,
            dtype=dtype),
        fx1,
        fx2,
        fx_axis,
        df_dys_1,
        df_dys_2,
    )

    return ntk

  return ntk_fn


# Dispatch table: maps every `NtkImplementation` member to the private factory
# in this module that builds the corresponding NTK function.
_implementation_to_ntk_fn = {
    NtkImplementation.AUTO: _empirical_auto_ntk_fn,
    NtkImplementation.JACOBIAN_CONTRACTION: _jacobian_contraction_ntk_fn,
    NtkImplementation.NTK_VECTOR_PRODUCTS: _ntk_vector_products_ntk_fn,
    NtkImplementation.STRUCTURED_DERIVATIVES: _structured_derivatives_ntk_fn,
}


def empirical_ntk_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = (),
    vmap_axes: VMapAxes = None,
    implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
    _j_rules: bool = _DEFAULT_NTK_J_RULES,
    _s_rules: bool = _DEFAULT_NTK_S_RULES,
    _fwd: Optional[bool] = _DEFAULT_NTK_FWD,
) -> EmpiricalKernelFn:
  r"""Returns a function drawing a single sample of the NTK of a network `f`.

  The Neural Tangent Kernel is defined as :math:`J(X_1) J(X_2)^T`, where
  :math:`J` is the Jacobian :math:`df/dparams` of shape
  `full_output_shape + params.shape`.

  For best performance:
  1) pass `x2=None` if `x1 == x2`;
  2) prefer square batches (i.e. `x1.shape == x2.shape`);
  3) make sure to set `vmap_axes` correctly;
  4) try different `implementation` values.

  .. warning::
    The resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`,
    subject to the `trace_axes` and `diagonal_axes` parameters. These make
    certain assumptions about the outputs `f(x)` that may only be true in the
    infinite width / infinite number of samples limit, or may not apply to
    your architecture. For the most precise results in the context of
    linearized training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain the kernel exactly of
    shape `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels will have terms measuring the covariance
  between the outputs. Here, we ignore these cross-terms and consider each
  output separately. Please raise an issue if this feature is important to you.

  Args:
    f:
      the function whose NTK we are computing. It should have the signature
      `f(params, x, **kwargs)` where `params` is a `PyTree`, `x` is a `PyTree`,
      and `f` should also return a `PyTree`.

    trace_axes:
      output axes to trace the output kernel over, i.e. to compute only the
      trace of the covariance along the respective pair of axes (one pair for
      each axis in `trace_axes`). This saves space and compute if you are only
      interested in the respective trace, and also improves approximation
      accuracy if you know that the covariance along these pairs of axes
      converges to a `constant * identity matrix` in the limit of interest
      (e.g. infinite width or infinite `n_samples`). A common use case is the
      channel / feature / logit axis, since activation slices along such axis
      are i.i.d. and the respective covariance along the respective pair of
      axes indeed converges to a constant-diagonal matrix in the infinite
      width or infinite `n_samples` limit.
      Also related to "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

    diagonal_axes:
      output axes to diagonalize the output kernel over, i.e. to compute only
      the diagonal of the covariance along the respective pair of axes (one
      pair for each axis in `diagonal_axes`). This saves space and compute if
      off-diagonal values along these axes are not needed, and also improves
      approximation accuracy if their limiting value is known theoretically,
      e.g. if they vanish in the limit of interest (e.g. infinite width or
      infinite `n_samples`). If you further know that on-diagonal values
      converge to the same constant in your limit of interest, you should
      specify these axes in `trace_axes` instead, to save even more compute
      and gain even more accuracy. A common use case is computing the variance
      (instead of covariance) along certain axes.
      Also related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

    vmap_axes:
      A triple of `(in_axes, out_axes, kwargs_axes)` passed to `vmap` to
      evaluate the empirical NTK in parallel over these axes. Precisely,
      providing this argument implies that `f(params, x, **kwargs)` equals a
      concatenation along `out_axes` of `f` applied to slices of `x` and
      `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it
      certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes`
      over `x` (along `in_axes`) and those arguments in `**kwargs` that are
      present in `kwargs_axes.keys()` (along `kwargs_axes.values()`).

      For example, if `_, f, _ = nt.stax.Aggregate()`, `f` is called via
      `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns
      `pattern`, and outputs of `f` are all batched along the leading `0`
      dimension, and each output `f(params, x, pattern=pattern)[i]` only
      depends on the inputs `x[i]` and `pattern[i]`. In this case, we can
      pass `vmap_axes=(0, 0, dict(pattern=0))` to specify along which
      dimensions inputs, outputs, and keyword arguments are batched
      respectively.

      This allows us to evaluate Jacobians much more efficiently. If
      `vmap_axes` is not a triple, it is interpreted as
      `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example, a very
      common use case is `vmap_axes=0` for a neural network with leading (`0`)
      batch dimension, both for inputs and outputs, and no interactions between
      different elements of the batch (e.g. no BatchNorm, and, in the case of
      `nt.stax`, also no Dropout). However, if there is interaction between
      batch elements or no concept of a batch axis at all, `vmap_axes` must be
      set to `None`, to avoid wrong (and potentially silent) results.

    implementation:
      An :class:`NtkImplementation` value (or an :class:`int` `0`, `1`, `2`,
      or `3`). See the :class:`NtkImplementation` docstring for details.

    _j_rules:
      Internal debugging parameter, applicable only when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      custom Jacobian rules for intermediary primitive `dy/dw` computations for
      MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use
      JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom
      Jacobian rules (`True`) are expected to be not worse, and sometimes better
      than automated alternatives, but in case of a suboptimal implementation
      setting it to `False` could improve performance.

    _s_rules:
      Internal debugging parameter, applicable only when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      efficient MJJMP rules for structured `dy/dw` primitive Jacobians. In
      practice should be set to `True`, and setting it to `False` can lead to
      dramatic deterioration of performance.

    _fwd:
      Internal debugging parameter, applicable only when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations,
      `False` to always use :obj:`jax.vjp`. `None` to decide automatically
      based on input/output sizes. Applicable when `_j_rules=False`, or when a
      primitive does not have a Jacobian rule. Should be set to `None` for best
      performance.

  Returns:
    A function `ntk_fn` that computes the empirical NTK.
  """
  # Select the factory registered for the requested implementation, then let
  # it build the actual `ntk_fn` from the user-provided settings.
  make_ntk_fn = _implementation_to_ntk_fn[implementation]
  return make_ntk_fn(
      f=f,
      trace_axes=trace_axes,
      diagonal_axes=diagonal_axes,
      vmap_axes=vmap_axes,
      _j_rules=_j_rules,
      _s_rules=_s_rules,
      _fwd=_fwd,
  )


# JOINT NNGP/NTK KERNEL FUNCTION


def empirical_kernel_fn(
    f: ApplyFn,
    trace_axes: Axes = (-1,),
    diagonal_axes: Axes = (),
    vmap_axes: VMapAxes = None,
    implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
    _j_rules: bool = _DEFAULT_NTK_J_RULES,
    _s_rules: bool = _DEFAULT_NTK_S_RULES,
    _fwd: Optional[bool] = _DEFAULT_NTK_FWD,
) -> EmpiricalGetKernelFn:
  r"""Returns a function that computes single draws from NNGP and NT kernels.

  .. warning::
    The resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)`,
    subject to the `trace_axes` and `diagonal_axes` parameters. These make
    certain assumptions about the outputs `f(x)` that may only be true in the
    infinite width / infinite number of samples limit, or may not apply to
    your architecture. For the most precise results in the context of
    linearized training dynamics of a specific finite-width network, set both
    `trace_axes=()` and `diagonal_axes=()` to obtain the kernel exactly of
    shape `zip(f(x1).shape, f(x2).shape)`.

  For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in
  principle the empirical kernels will have terms measuring the covariance
  between the outputs. Here, we ignore these cross-terms and consider each
  output separately. Please raise an issue if this feature is important to you.

  Args:
    f:
      the function whose kernel(s) (NNGP and/or NTK) we are computing. It
      should have the signature `f(params, x, **kwargs)` where `params` is a
      `PyTree`, `x` is a `PyTree`, and `f` should also return a `PyTree`.

    trace_axes:
      output axes to trace the output kernel over, i.e. to compute only the
      trace of the covariance along the respective pair of axes (one pair for
      each axis in `trace_axes`). This saves space and compute if you are only
      interested in the respective trace, and also improves approximation
      accuracy if you know that the covariance along these pairs of axes
      converges to a `constant * identity matrix` in the limit of interest
      (e.g. infinite width or infinite `n_samples`). A common use case is the
      channel / feature / logit axis, since activation slices along such axis
      are i.i.d. and the respective covariance along the respective pair of
      axes indeed converges to a constant-diagonal matrix in the infinite
      width or infinite `n_samples` limit.
      Also related to "contracting dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

    diagonal_axes:
      output axes to diagonalize the output kernel over, i.e. to compute only
      the diagonal of the covariance along the respective pair of axes (one
      pair for each axis in `diagonal_axes`). This saves space and compute if
      off-diagonal values along these axes are not needed, and also improves
      approximation accuracy if their limiting value is known theoretically,
      e.g. if they vanish in the limit of interest (e.g. infinite width or
      infinite `n_samples`). If you further know that on-diagonal values
      converge to the same constant in your limit of interest, you should
      specify these axes in `trace_axes` instead, to save even more compute
      and gain even more accuracy. A common use case is computing the variance
      (instead of covariance) along certain axes.
      Also related to "batch dimensions" in XLA terms.
      (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

    vmap_axes:
      applicable only to NTK.

      A triple of `(in_axes, out_axes, kwargs_axes)` passed to `vmap` to
      evaluate the empirical NTK in parallel over these axes. Precisely,
      providing this argument implies that `f(params, x, **kwargs)` equals a
      concatenation along `out_axes` of `f` applied to slices of `x` and
      `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it
      certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes`
      over `x` (along `in_axes`) and those arguments in `**kwargs` that are
      present in `kwargs_axes.keys()` (along `kwargs_axes.values()`).

      For example, if `_, f, _ = nt.stax.Aggregate()`, `f` is called via
      `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns
      `pattern`, and outputs of `f` are all batched along the leading `0`
      dimension, and each output `f(params, x, pattern=pattern)[i]` only
      depends on the inputs `x[i]` and `pattern[i]`. In this case, we can
      pass `vmap_axes=(0, 0, dict(pattern=0))` to specify along which
      dimensions inputs, outputs, and keyword arguments are batched
      respectively.

      This allows us to evaluate Jacobians much more efficiently. If
      `vmap_axes` is not a triple, it is interpreted as
      `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example, a very
      common use case is `vmap_axes=0` for a neural network with leading (`0`)
      batch dimension, both for inputs and outputs, and no interactions between
      different elements of the batch (e.g. no BatchNorm, and, in the case of
      `nt.stax`, also no Dropout). However, if there is interaction between
      batch elements or no concept of a batch axis at all, `vmap_axes` must be
      set to `None`, to avoid wrong (and potentially silent) results.

    implementation:
      Applicable only to NTK, an :class:`NtkImplementation` value (or an
      :class:`int` `0`, `1`, `2`, or `3`). See the :class:`NtkImplementation`
      docstring for details.

    _j_rules:
      Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      custom Jacobian rules for intermediary primitive `dy/dw` computations for
      MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to `False` to use
      JVPs or VJPs, via JAX's :obj:`jax.jacfwd` or :obj:`jax.jacrev`. Custom
      Jacobian rules (`True`) are expected to be not worse, and sometimes better
      than automated alternatives, but in case of a suboptimal implementation
      setting it to `False` could improve performance.

    _s_rules:
      Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      efficient MJJMP rules for structured `dy/dw` primitive Jacobians. In
      practice should be set to `True`, and setting it to `False` can lead to
      dramatic deterioration of performance.

    _fwd:
      Internal debugging parameter, applicable only to NTK when
      `implementation` is :attr:`~NtkImplementation.STRUCTURED_DERIVATIVES`
      (`3`) or :attr:`~NtkImplementation.AUTO` (`0`). Set to `True` to allow
      :obj:`jax.jvp` in intermediary primitive Jacobian `dy/dw` computations,
      `False` to always use :obj:`jax.vjp`. `None` to decide automatically
      based on input/output sizes. Applicable when `_j_rules=False`, or when a
      primitive does not have a Jacobian rule. Should be set to `None` for best
      performance.

  Returns:
    A function to draw a single sample of the NNGP and NTK empirical kernels
    of a given network `f`.
  """
  # Settings shared by both kernels; the remaining ones only concern the NTK.
  shared_kwargs = {
      'f': f,
      'trace_axes': trace_axes,
      'diagonal_axes': diagonal_axes,
  }

  kernel_fns = {
      'nngp': empirical_nngp_fn(**shared_kwargs),
      'ntk': empirical_ntk_fn(
          **shared_kwargs,
          vmap_axes=vmap_axes,
          implementation=implementation,
          _j_rules=_j_rules,
          _s_rules=_s_rules,
          _fwd=_fwd,
      ),
  }

  @utils.get_namedtuple('EmpiricalKernel')
  def kernel_fn(
      x1: PyTree,
      x2: Optional[PyTree],
      get: Union[None, str, tuple[str, ...]],
      params: PyTree,
      **apply_fn_kwargs
  ) -> PyTree:
    """Computes a single sample of the empirical kernel of type `get`.

    Args:
      x1:
        first batch of inputs.

      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.

      get:
        type of the empirical kernel. `get=None` means `get=("nngp", "ntk")`.
        Can be a string (`"nngp"`) or a tuple of strings (`("ntk", "nngp")`).

      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.

      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function which will be passed to `apply_fn`. In particular, the rng key
        in `apply_fn_kwargs` will be split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical kernel. The shape is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.

      If `get` is a string, returns the requested `jnp.ndarray`. If `get` is a
      tuple, returns an `EmpiricalKernel` namedtuple containing the
      requested information.
    """
    requested = ('nngp', 'ntk') if get is None else get

    # Evaluate every requested kernel on the same inputs and parameters.
    samples = {}
    for name in requested:
      samples[name] = kernel_fns[name](x1, x2, params, **apply_fn_kwargs)

    return _dict_of_tree_to_tree_of_dict(samples, requested)

  return kernel_fn


# NTK-VECTOR PRODUCT FUNCTION


def empirical_ntk_vp_fn(
    f: ApplyFn,
    x1: PyTree,
    x2: Optional[PyTree],
    params: PyTree,
    **apply_fn_kwargs
) -> Callable[[PyTree], PyTree]:
  """Returns an NTK-vector product function.

  The returned function computes the NTK-vector product without instantiating
  the NTK. It has the runtime equivalent to `(N1 + N2)` forward passes through
  `f`, and memory equivalent to evaluating a vector-Jacobian product of `f`.

  For details, please see section L of "`Fast Finite Width Neural Tangent Kernel
  <https://arxiv.org/abs/2206.08720>`_".

  Example:
    >>> from jax import random
    >>> import neural_tangents as nt
    >>> from neural_tangents import stax
    >>> #
    >>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
    >>> x1 = random.normal(k1, (20, 32, 32, 3))
    >>> x2 = random.normal(k2, (10, 32, 32, 3))
    >>> #
    >>> # Define a forward-pass function `f`.
    >>> init_fn, f, _ = stax.serial(
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Relu(),
    >>>     stax.Conv(32, (3, 3)),
    >>>     stax.Flatten(),
    >>>     stax.Dense(10)
    >>> )
    >>> #
    >>> # Initialize parameters.
    >>> _, params = init_fn(k3, x1.shape)
    >>> #
    >>> # NTK-vp function. Can/should be JITted.
    >>> ntk_vp_fn = empirical_ntk_vp_fn(f, x1, x2, params)
    >>> #
    >>> # Cotangent vector
    >>> cotangents = random.normal(k4, f(params, x2).shape)
    >>> #
    >>> # NTK-vp output
    >>> ntk_vp = ntk_vp_fn(cotangents)
    >>> #
    >>> # Output has same shape as `f(params, x1)`.
    >>> assert ntk_vp.shape == f(params, x1).shape

  Args:
    f:
      forward-pass function of signature `f(params, x, **apply_fn_kwargs)`.

    x1:
      first batch of inputs.

    x2:
      second batch of inputs. `x2=None` means `x2=x1`.

    params:
      A `PyTree` of parameters about which we would like to compute the neural
      tangent kernel.

    **apply_fn_kwargs:
      keyword arguments passed to `f`. `apply_fn_kwargs` will be split into
      `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` function
      which will be passed to `f`. In particular, the rng key in
      `apply_fn_kwargs` will be split into two different (if `x1!=x2`) or same
      (if `x1==x2`) rng keys. See the `_read_key` function for more details.

  Returns:
    An NTK-vector product function accepting a `PyTree` of cotangents of shape
    and structure of `f(params, x2)`, and returning the NTK-vector product of
    shape and structure of `f(params, x1)`.
  """
  # `vmap_axes=None`: no batching assumptions are made about `f` here. The
  # abstract outputs returned by `_get_args` are not needed.
  args1, args2, _, _, fx_axis, keys, kw_axes, x_axis = _get_args(
      f, apply_fn_kwargs, params, None, x1, x2)

  # `f1` / `f2` are functions of `params` only, closed over `x1` / `x2`.
  f1, f2 = _get_f1_f2(f, keys, x_axis, fx_axis, kw_axes, args1 + args2, x1, x2)

  def ntk_vp_fn(cotangents: PyTree) -> PyTree:
    """Computes a single empirical NTK-vector product.

    Args:
      cotangents:
        a `PyTree` of cotangents. Must have the same shape and tree structure
        as `f(params, x2)`.

    Returns:
      A single NTK-vector product of shape and tree structure of
      `f(params, x1)`.
    """
    # Pull the output cotangents back to parameter space through `f2` ...
    _, pullback = vjp(f2, params)
    param_tangents = pullback(cotangents)
    # ... and push the resulting parameter tangents forward through `f1`.
    _, ntk_vp = jvp(f1, (params,), param_tangents)
    return ntk_vp

  return ntk_vp_fn


# INTERNAL UTILITIES


def _trace_and_diagonal(
    ntk: jnp.ndarray,
    trace_axes: Axes,
    diagonal_axes: Axes
) -> jnp.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.

    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the  respective pairs of axes from the `ntk`.

    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis). The values are divided by
    `contract_size`, which is derived from the sizes of the traced axes.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel.')

  # The first half of the `ntk` axes belongs to the first output, the second
  # half to the second output.
  output_ndim = ntk.ndim // 2

  # NOTE(review): the index arithmetic below presumably relies on
  # `utils.canonicalize_axis` returning non-negative axes in ascending order -
  # confirm against its implementation.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  # Presumably the product of the sizes of the traced axes (taken from the
  # first half of the shape) - TODO confirm against `utils.size_at`.
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace out the paired axes starting from the last one. Each `jnp.trace`
  # removes one axis from each half, so after `i` traces the second half has
  # moved `i` positions to the left.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = jnp.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  # Take diagonals along the paired axes. `jnp.diagonal` removes both axes and
  # appends the diagonal as a new trailing axis, so positions are corrected for
  # the `i` diagonal pairs already removed and for the traced axes.
  for i, d in enumerate(diagonal_axes):
    axis1 = d - i
    axis2 = output_ndim + d - 2 * i - n_trace
    for c in trace_axes:
      if c < d:
        # A traced axis located before `d` shifted both positions by one.
        axis1 -= 1
        axis2 -= 1
    ntk = jnp.diagonal(ntk, axis1=axis1, axis2=axis2)

  # Presumably interleaves the remaining axes of the two halves, leaving the
  # `n_diag` trailing diagonal axes in place - TODO confirm against
  # `utils.zip_axes`.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  # Move the trailing diagonal axes to their final positions in the result.
  res_diagonal_axes = _get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = jnp.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size


def _dict_of_tree_to_tree_of_dict(
    out_dict: dict[str, PyTree],
    get: tuple[str, ...]
) -> PyTree:
  # If the elements of an output dict are tuples then change the representation
  # to be a tuple of dicts instead. This occurs when the output of a network is
  # a parallel layer.
  return tree_map(lambda *x: dict((g, v) for g, v in zip(get, x)),
                  *[out_dict[g] for g in get])


def _get_f_params(
    f: Callable,
    x: PyTree,
    x_axis: PyTree,
    fx_axis: PyTree,
    kw_axes: dict[str, PyTree],
    **apply_fn_kwargs
) -> Callable[[PyTree], PyTree]:
  """Binds inputs and keyword arguments of `f`, leaving only `params` free.

  Args:
    f: function called as `f(params, x, **apply_fn_kwargs)`.
    x: inputs; dimensions are expanded along `x_axis` before calling `f`.
    x_axis: axes along which `x` is expanded (passed to `_expand_dims`).
    fx_axis: axes squeezed from the output of `f` (passed to `_squeeze`).
    kw_axes: per-keyword axes; only keywords present here are expanded.
    **apply_fn_kwargs: keyword arguments forwarded to `f`.

  Returns:
    A function of `params` only that evaluates `f` on the expanded inputs and
    squeezes the result along `fx_axis`.
  """
  x = _expand_dims(x, x_axis)

  # Expand only the keyword arguments that have an axis specified.
  expanded_kwargs = {}
  for name, value in apply_fn_kwargs.items():
    if name in kw_axes:
      value = _expand_dims(value, kw_axes[name])
    expanded_kwargs[name] = value

  def _f(p: PyTree) -> PyTree:
    return _squeeze(f(p, x, **expanded_kwargs), fx_axis)

  return _f


def _get_args(
    f: Callable,
    apply_fn_kwargs: dict[str, PyTree],
    params: PyTree,
    vmap_axes: VMapAxes,
    x1: PyTree,
    x2: PyTree
):
  """Prepares per-input arguments, abstract outputs, and canonical `vmap` axes.

  Args:
    f: function called as `f(params, x, **kwargs)`.
    apply_fn_kwargs: keyword arguments for `f`, split per input by
      `utils.split_kwargs`.
    params: parameters passed to `f`.
    vmap_axes: axes specification, canonicalized by `_canonicalize_axes`.
    x1: first inputs.
    x2: second inputs; when `utils.all_none(x2)` holds, the abstract output for
      `x1` is reused.

  Returns:
    A tuple `(args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis)`, where
    `args1` / `args2` are the keyword argument values for `x1` / `x2` ordered
    as in `keys`, and `fx1` / `fx2` are the results of `eval_shape` on `f`.
  """
  kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

  # Only shapes and dtypes of the outputs are needed here.
  fx1 = eval_shape(f, params, x1, **kwargs1)
  if utils.all_none(x2):
    fx2 = fx1
  else:
    fx2 = eval_shape(f, params, x2, **kwargs2)

  x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1)

  # Positional versions of the keyword arguments, ordered consistently.
  keys = apply_fn_kwargs.keys()
  args1 = tuple(kwargs1[name] for name in keys)
  args2 = tuple(kwargs2[name] for name in keys)
  return args1, args2, fx1, fx2, fx_axis, keys, kw_axes, x_axis


def _get_f1_f2(
    f: Callable,
    keys: KeysView[str],
    x_axis: PyTree,
    fx_axis: PyTree,
    kw_axes: dict[str, PyTree],
    args: tuple,
    x1: PyTree,
    x2: Optional[PyTree]
) -> tuple[Callable[[PyTree], PyTree], Callable[[PyTree], PyTree]]:
  """Builds the functions of `params` for the first and second inputs.

  Args:
    f: function called as `f(params, x, **kwargs)`.
    keys: names of the keyword arguments of `f`.
    x_axis: axes forwarded to `_get_f_params`.
    fx_axis: axes forwarded to `_get_f_params`.
    kw_axes: per-keyword axes forwarded to `_get_f_params`.
    args: keyword argument values for `x1` followed by those for `x2`; the
      first half is paired with `x1`, the second half with `x2`.
    x1: first inputs.
    x2: second inputs; when `utils.all_none(x2)` holds, `f1` is reused as `f2`.

  Returns:
    A pair `(f1, f2)` of functions of `params`.
  """
  half = len(args) // 2
  kwargs1 = dict(zip(keys, args[:half]))
  kwargs2 = dict(zip(keys, args[half:]))

  f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs1)
  if utils.all_none(x2):
    return f1, f1

  f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes, **kwargs2)
  return f1, f2


# Type variable constrained to either a concrete array (`jnp.ndarray`) or an
# abstract `ShapedArray`, so a helper can declare that it returns the same kind
# of object it was given.
_ArrayOrShape = TypeVar('_ArrayOrShape', jnp.ndarray, ShapedArray)


def _check_einsum_no_broadcast(
    arrays: list[jnp.ndarray],
    dims: list[list[int]]
):
  """Check that all matching einsum contracting axis sizes are equal.

  Einsum allows silent broadcasting, and this function helps ensure it doesn't
  happen.
  """
  for idx_1, (a1, dims_1) in enumerate(zip(arrays, dims)):
    if len(set(dims_1)) != len(dims_1):
      raise ValueError(f'Dimensions {idx_1} contain duplicate axes: '
                       f'{dims_1}.')

    for ax_1, dim_1 in enumerate(dims_1):
      sz_idx_1 = a1.shape[ax_1]
      for idx_2, (a2, dims_2) in enumerate(zip(arrays, dims)):
        if dim_1 in dims_2:
          ax_2 = dims_2.index(dim_1)
          sz_idx_2 = a2.shape[ax_2]
          if sz_idx_2 != sz_idx_1:
            raise ValueError(f'Arrays {idx_1} and {idx_2} mismatch '
                             f'sizes at {ax_1} and {ax_2}: '
                             f'{sz_idx_1} != {sz_idx_2}')


def _expand_dims_array(x: _ArrayOrShape, axis: int) -> _ArrayOrShape:
  def expand(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.expand_dims(x, axis)

  if isinstance(x, ShapedArray):
    return eval_shape(expand, x)

  if isinstance(x, jnp.ndarray):
    return expand(x)

  raise TypeError(type(x), x)


def _expand_dims(
    x: Union[None, PyTree, UndefinedPrimal],
    axis: Optional[PyTree]
) -> Optional[PyTree]:
  """Expands dimensions of every leaf of `x` along the matching `axis` entry.

  Args:
    x: a tree of arrays; `None` and `UndefinedPrimal` are returned as is.
    axis: a tree of axes matching `x`; `None` leaves `x` unchanged.

  Returns:
    `x` with `_expand_dims_array` applied leaf-wise, or `x` itself when there
    is nothing to expand.
  """
  nothing_to_expand = (
      axis is None or x is None or isinstance(x, UndefinedPrimal))
  if nothing_to_expand:
    return x
  return tree_map(_expand_dims_array, x, axis)


def _add(x: Optional[PyTree], y: Optional[PyTree]) -> Optional[PyTree]:
  if x is None or y is None:
    return None
  return tree_map(operator.add, x, y)


def _sub(x: PyTree, y: PyTree) -> PyTree:
  return tree_map(operator.sub, x, y)


def _div(x: PyTree, y: int) -> PyTree:
  return tree_map(lambda x: x / y, x)


def _squeeze(x: PyTree, axis: Optional[PyTree]) -> PyTree:
  if axis is None:
    return x

  def squeeze(
      x: jnp.ndarray,
      axis: Union[None, int, tuple[int, ...]]
  ) -> jnp.ndarray:
    """`np.squeeze` analog working with 0-sized axes."""
    if isinstance(axis, int):
      axis = (axis,)

    non_zero_axes = tuple()
    shift = 0

    for a in sorted(axis):
      if x.shape[a - shift] == 0:
        new_shape = x.shape[:a] + x.shape[a + 1:]
        if utils.size_at(new_shape) == 0:
          x = x.reshape(new_shape)
        else:
          x = jnp.zeros(new_shape, x.dtype)

        shift += 1
      else:
        non_zero_axes += (a - shift,)

    return jnp.squeeze(x, non_zero_axes)

  return tree_map(squeeze, x, axis)


def _ndim(x: PyTree) -> PyTree:
  return tree_map(lambda x: x.ndim, x)


def _mod(
    x: Optional[PyTree],
    y: PyTree
) -> PyTree:
  if x is None:
    return None
  return tree_map(operator.mod, x, y)


def _diagonal(ntk: PyTree, fx: PyTree) -> PyTree:
  ntk_flat, _ = tree_flatten(ntk)
  fx_flat, fx_tree = tree_flatten(fx)
  n = len(fx_flat)
  diag = [ntk_flat[i * (n + 1)] for i in range(n)]
  return tree_unflatten(fx_tree, diag)


def _canonicalize_axes(
    vmap_axes: Optional[VMapAxes],
    x: PyTree,
    fx: PyTree,
    **kwargs
) -> VMapAxisTriple:
  """Expands `vmap_axes` into per-leaf, non-negative axes.

  Args:
    vmap_axes: either a triple `(in_axes, out_axes, kwargs_axes)`, or a single
      value used for both inputs and outputs (with no keyword argument axes).
    x: inputs; an integer input axis is replicated for every leaf of `x`.
    fx: outputs; an integer output axis is replicated for every leaf of `fx`.
    **kwargs: keyword arguments; an integer keyword axis is replicated for
      every leaf of `kwargs`.

  Returns:
    A triple `(x_axis, fx_axis, kw_axes)` with every axis taken modulo the
    number of dimensions of the respective leaf. `None` axes stay `None`.
  """
  if isinstance(vmap_axes, tuple) and len(vmap_axes) == 3:
    x_axis, fx_axis, kw_axes = vmap_axes
  else:
    x_axis = fx_axis = vmap_axes
    kw_axes = {}

  def broadcast(axis, tree):
    # A single integer applies to every leaf of the respective tree.
    if not isinstance(axis, int):
      return axis
    return tree_map(lambda _: axis, tree)

  x_axis = broadcast(x_axis, x)
  fx_axis = broadcast(fx_axis, fx)
  kw_axes = broadcast(kw_axes, kwargs)

  # Map negative axes to their non-negative equivalents.
  x_axis = _mod(x_axis, _ndim(x))
  fx_axis = _mod(fx_axis, _ndim(fx))
  kw_axes = _mod(kw_axes, {name: _ndim(kwargs[name]) for name in kw_axes})
  return x_axis, fx_axis, kw_axes


def _to_tuple_tree(x: PyTree) -> tuple:
  """Replace all lists and dictionaries with tuples in a PyTree for hashing."""
  if isinstance(x, (tuple, list)):
    return tuple(_to_tuple_tree(x_i) for x_i in x)

  if isinstance(x, dict):
    return tuple((k, _to_tuple_tree(v)) for k, v in sorted(x.items()))

  return x


def _ntk_shape(fx1_shape, fx2_shape, trace_axes: Axes, diagonal_axes: Axes):
  """Computes the NTK shape from the shapes of the two outputs.

  Args:
    fx1_shape: shape of the first output.
    fx2_shape: shape of the second output.
    trace_axes: output axes absent from the result.
    diagonal_axes: output axes present only once in the result.

  Returns:
    A tuple where every remaining axis contributes its size from `fx1_shape`
    followed by its size from `fx2_shape`, and every diagonal axis contributes
    a single size. Sizes along trace and diagonal axes are asserted to match.
  """
  trace_axes = utils.canonicalize_axis(trace_axes, fx1_shape)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, fx1_shape)

  sizes = []
  for i, (size1, size2) in enumerate(zip(fx1_shape, fx2_shape)):
    if i in trace_axes:
      # Traced axes are contracted over, hence do not appear in the result.
      assert size1 == size2
    elif i in diagonal_axes:
      assert size1 == size2
      sizes.append(size1)
    else:
      sizes.extend((size1, size2))

  return tuple(sizes)


def _get_dims(
    df_dy_1: jnp.ndarray,
    df_dy_2: jnp.ndarray,
    ndim: int,
    trace_axes: Axes,
    diagonal_axes: Axes
) -> tuple[list[int], list[int], list[int]]:
  df_dy_dims_1 = list(range(df_dy_1.ndim))
  df_dy_dims_2 = list(range(df_dy_1.ndim, df_dy_1.ndim + df_dy_2.ndim))

  out_dims = []

  for i in range(ndim):
    if i in trace_axes:
      assert df_dy_1.shape[i] == df_dy_2.shape[i]
      df_dy_dims_2[i] = df_dy_dims_1[i]

    elif i in diagonal_axes:
      assert df_dy_1.shape[i] == df_dy_2.shape[i]
      df_dy_dims_2[i] = df_dy_dims_1[i]
      out_dims += [df_dy_dims_1[i]]

    else:
      out_dims += [df_dy_dims_1[i], df_dy_dims_2[i]]

  return df_dy_dims_1, df_dy_dims_2, out_dims


def _is_abstract_array(x) -> bool:
  return isinstance(x, jnp.ndarray) or isinstance(
      getattr(x, 'aval', None), core.ShapedArray)


def _vmap(f: Callable, in_axes, out_axes, squeeze_out: bool = True) -> Callable:
  """An expand-then-squeeze `vmap` for `f` expecting/returning batch dims.

  Args:
    f: function to map.
    in_axes: input axes passed to `vmap`.
    out_axes: output axes passed to `vmap`.
    squeeze_out: whether to squeeze the outputs after mapping.

  Returns:
    A function that expands its arguments along `in_axes` shifted by one
    (`None` and `-1` entries are kept as is), applies `vmap(f)`, and, if
    `squeeze_out`, squeezes the outputs along `out_axes` shifted the same way.
  """
  def shift_axes(axes):
    return tree_map(lambda a: a if a in (None, -1) else a + 1, axes)

  expand_axes = shift_axes(in_axes)

  @utils.wraps(f)
  def f_vmapped(*args):
    expanded_args = tree_map(
        _expand_dims, args, expand_axes, is_leaf=_is_abstract_array)
    out = vmap(f, in_axes, out_axes)(*expanded_args)
    if not squeeze_out:
      return out
    return _squeeze(out, shift_axes(out_axes))

  return f_vmapped


def _get_fx_axis_and_dtype(fx, fx_axis, params: PyTree):
  if fx_axis is None:
    fx_axis = tree_map(lambda x: None, fx)
  # Set the default type to be the least common type ancestor.
  dtypes, _ = tree_flatten(tree_map(jnp.dtype, params))
  if not dtypes:
    dtype = None
  else:
    dtype = functools.reduce(jnp.promote_types, dtypes)
  return fx_axis, dtype


def _unravel_dfs(dfs: PyTree, params: PyTree, y: PyTree) -> PyTree:
  """Regroups `dfs` so that the tree structure of `y` comes first.

  NOTE(review): the precise layout depends on `_unravel_array_into_pytree`,
  which is defined elsewhere - confirm the description below against it.

  Args:
    dfs:
      a tree of arrays; each leaf is unravelled along its axis `0` into a tree
      with the structure of `y`.

    params:
      a tree used to build the outer structure passed to `tree_transpose`.

    y:
      a tree defining the structure that is moved to the outside.

  Returns:
    A tree with the structure of `y` on the outside. If `dfs` has no leaves,
    every leaf of `y` is replaced with the (leafless) `dfs` tree.
  """
  # Split axis 0 of every leaf into a tree shaped like `y`.
  dfs = tree_map(functools.partial(_unravel_array_into_pytree, y, 0), dfs)

  if tree_structure(dfs).num_leaves > 0:
    # Swap the nesting order: the outer structure is `params` with a list per
    # leaf, the inner structure is that of `y`; after the transpose `y` is
    # outermost.
    dfs = tree_transpose(tree_structure(tree_map(lambda x, y: [x] * len(y),
                                                 params,
                                                 dfs)),
                         tree_structure(y), dfs)

  if tree_structure(dfs).num_leaves == 0:
    # Nothing to transpose: replicate the empty tree for every leaf of `y`.
    dfs = tree_map(lambda x: dfs, y)
  return dfs


class _MODE(enum.Enum):
  """`F` - final output; `Y` - intermediary pre-activations; `W` - weights."""
  # Mode used when building the pullback that yields `df/dy`.
  DF_DY = 'DF_DY'
  # Mode used when building the pullback that yields `dy/dw`.
  DY_DW = 'DY_DW'


def _get_df_dys_and_dy_dws(
    fn: Callable[[PyTree], PyTree],
    params: PyTree,
    _j_rules: bool,
    _s_rules: bool,
    _fwd: Optional[bool]
) -> tuple[PyTree, PyTree]:
  """Computes primitive output cotangents (`df/dy`) and Jacobians (`dy/dw`).

  Args:
    fn: function of `params`.
    params: PyTree at which `fn` is linearized.
    _j_rules: forwarded to `_get_primals_out_and_pullback`.
    _s_rules: forwarded to `_get_primals_out_and_pullback`.
    _fwd: forwarded to `_get_primals_out_and_pullback`.

  Returns:
    A `(df_dys, dy_dws)` tuple.
  """
  # `df/dy`: pull a standard basis of the output space back through `fn`.
  primals_out, df_dy_pullback = _get_primals_out_and_pullback(
      fn, _MODE.DF_DY, _j_rules, _s_rules, _fwd, params)
  output_basis = _std_basis(primals_out)
  df_dys = vmap(df_dy_pullback)(output_basis)[0]
  df_dys = _unravel_dfs(df_dys, params, primals_out)

  # `dy/dw`: values of `primals_out` don't matter, they are only placeholders.
  _, dy_dw_pullback = _get_primals_out_and_pullback(
      fn, _MODE.DY_DW, _j_rules, _s_rules, _fwd, params)
  dy_dws = dy_dw_pullback(primals_out)[0]

  return df_dys, dy_dws


def _get_primals_out_and_pullback(
    fn: Callable[[PyTree], PyTree],
    mode: _MODE,
    _j_rules: bool,
    _s_rules: bool,
    _fwd: Optional[bool],
    *primals_in: PyTree
) -> tuple[PyTree, Callable]:
  """Adapted from `jax.interpreters.ad`.

  Return outputs of `fn` and the "pullback" function, which is similar to the
  regular pullback function (computing cotangents to `primals_in` given output
  cotangents), but collects and returns other quantities.

  Args:
    fn: function to linearize.
    mode: which quantity the pullback collects, see `_MODE`.
    _j_rules: forwarded to `_backward_pass`.
    _s_rules: forwarded to `_backward_pass`.
    _fwd: forwarded to `_backward_pass`.
    *primals_in: inputs at which `fn` is linearized.

  Returns:
    A `(primals_out, pullback_fn)` tuple. `pullback_fn` accepts output
    cotangents and returns the result of `_backward_pass`, unflattened into the
    structure of `primals_in`.
  """
  flat_primals_in, in_treedef = tree_flatten(primals_in)
  flat_fn, out_treedef_thunk = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(fn), in_treedef)

  # TODO(romann): handle call primitives more gracefully.
  with jax.disable_jit():
    flat_primals_out, pvals, jaxpr, consts = ad.linearize(
        flat_fn, *flat_primals_in, has_aux=False)

  primals_out = tree_unflatten(out_treedef_thunk(), flat_primals_out)

  def pullback_fn(*cts_in: PyTree):
    flat_cts, _ = tree_flatten(cts_in)
    # Only outputs that are not known constants receive a cotangent.
    unknown_cts = tuple(
        ct for ct, pval in zip(flat_cts, pvals) if not pval.is_known())
    placeholder_primals = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    collected = _backward_pass(
        jaxpr,
        mode=mode,
        consts=consts,
        primals_in=placeholder_primals,
        cotangents_in=unknown_cts,
        _j_rules=_j_rules,
        _s_rules=_s_rules,
        _fwd=_fwd)
    return tree_unflatten(in_treedef, collected)

  return primals_out, pullback_fn


def _backward_pass(
    jaxpr: Jaxpr,
    mode: _MODE,
    consts: list[Value],
    primals_in: list[UndefinedPrimal],
    cotangents_in: tuple[jnp.ndarray, ...],
    _j_rules: bool,
    _s_rules: bool,
    _fwd: Optional[bool]
) -> Union[list[list[Union[jnp.ndarray, Zero]]],
           list[list[tuple[jnp.ndarray, rules.Structure]]]]:
  """Similar to and adapted from `jax.interpreters.ad.backward_pass`.

  Traverses the computational graph in the same order as the above, but collects
  and returns _not_ the cotangents wrt `jaxpr.invars`, but rather primitive
  output cotangents (`df/dy`) and Jacobians (`dy/dw`). Precisely:

    `mode=_MODE.DF_DY`: cotangents wrt outputs of equations where `jaxpr.invars`
    are inputs.

    `mode=_MODE.DY_DW`: Jacobians (of outputs wrt inputs that are within
    `jaxpr.invars`) of equations to which `jaxpr.invars` are inputs. Jacobians
    are accompanied by their `rules.Structure` metadata.

  The above are then efficiently contracted with each other elsewhere to compute
  the NTK.

  Args:
    jaxpr: the computation to traverse.
    mode: which quantity to collect, see above.
    consts: values bound to `jaxpr.constvars`.
    primals_in: values bound to `jaxpr.invars`.
    cotangents_in: cotangents bound to `jaxpr.outvars`.
    _j_rules: forwarded to `_get_jacobian`.
    _s_rules: forwarded to the `rules` structure helpers; if `True`, equations
      are additionally passed through `_trim_eqn`.
    _fwd: forwarded to `_get_jacobian`.

  Returns:
    A list with one entry per variable in `jaxpr.invars`. Each entry is a list
    of cotangents (`mode=_MODE.DF_DY`) or of `(dy_dw, structure)` tuples
    (`mode=_MODE.DY_DW`).

  Raises:
    ValueError: if `mode` is not recognized.
    TypeError: in `_MODE.DY_DW` mode, if a cotangent is not a `ShapedArray`.
  """

  # Cotangents are consumed on read; missing entries are symbolic zeros.
  def read_cotangent(v: Var) -> Union[jnp.ndarray, Zero]:
    return ct_env.pop(v, Zero(v.aval))

  # Only defined (non-`UndefinedPrimal`) values end up in `primal_env`, see
  # `_write_primal`.
  primal_env: dict[Var, jnp.ndarray] = {}
  map(functools.partial(_write_primal, primal_env), jaxpr.constvars, consts)
  map(functools.partial(_write_primal, primal_env), jaxpr.invars, primals_in)

  ct_env: dict[Var, jnp.ndarray] = {}
  ctx = ad.source_info_util.transform_name_stack('transpose')
  with ctx:
    map(functools.partial(_write_cotangent, 'outvars', ct_env),
        jaxpr.outvars, cotangents_in)

    # List of `df_dy`s or `dy_dw`s for each variable in `jaxpr.invars`.
    outs = [[] for _ in jaxpr.invars]

    if mode == _MODE.DY_DW:
      invar_to_structure = rules.get_structure_cache(jaxpr, _s_rules=_s_rules)
      # An empty set makes `do_write_cotangents` below always `False`, i.e. no
      # cotangents are propagated in this mode.
      vars_needing_cts_in = set()
    elif mode == _MODE.DF_DY:
      vars_needing_cts_in = _get_vars_needing_cts_in(jaxpr)
    else:
      raise ValueError(f'Unrecognized mode {mode}.')

    # Equations are visited in reverse order.
    for eqn in jaxpr.eqns[::-1]:
      # Do regular backprop.
      cts_in, invals = _backprop_step(
          eqn=eqn,
          primal_env=primal_env,
          ct_env=ct_env,
          read_cotangent=read_cotangent,
          do_write_cotangents=any(
              not isinstance(i, Literal) and i in vars_needing_cts_in
              for i in eqn.invars
          )
      )

      # Compute `df_dy`s or `dy_dw`s.
      # NOTE(review): `cts_in` and `eqn` are rebound inside this loop, so if an
      # equation has several inputs within `jaxpr.invars`, later iterations see
      # the already modified values - confirm this is intended.
      for i_eqn, eq_invar in enumerate(eqn.invars):
        if eq_invar in jaxpr.invars:
          i_jaxpr = jaxpr.invars.index(eq_invar)
          inval = invals[i_eqn].aval

          if mode == _MODE.DF_DY:
            if not isinstance(cts_in, Zero):
              # For `reshape`, the cotangent is stored in the shape of the
              # equation input rather than of its output.
              if eqn.primitive == lax.reshape_p:
                cts_in = cts_in.reshape(inval.shape)
              cts_in = cts_in.astype(inval.dtype)
            outs[i_jaxpr] += [cts_in]

          elif mode == _MODE.DY_DW:
            structure = rules.get_structure(
                eqn=eqn,
                invals=[v.aval for v in eqn.invars],
                idx=i_eqn,
                _s_rules=_s_rules
            )
            structure &= invar_to_structure[eq_invar]

            # Only the abstract value (shape and dtype) of the cotangent is
            # used below.
            if eqn.primitive == lax.reshape_p:
              cts_in = ShapedArray(inval.shape, inval.dtype)
            elif hasattr(cts_in, 'aval'):
              cts_in = cts_in.aval

            trimmed_invals = _trim_invals(invals, structure)
            if not isinstance(cts_in, ShapedArray):
              raise TypeError(cts_in)
            trimmed_cts_in = _trim_cotangents(cts_in, structure)

            if _s_rules:
              eqn = _trim_eqn(eqn, i_eqn, trimmed_invals, trimmed_cts_in)

            def j_fn(invals):
              return _get_jacobian(eqn=eqn,
                                   cts_in=trimmed_cts_in,
                                   invals=invals,
                                   idx=i_eqn,
                                   _fwd=_fwd,
                                   _j_rules=_j_rules)

            # Wrap `j_fn` in one `_vmap` per pair of diagonal axes; inputs that
            # are `UndefinedPrimal` are not mapped over (`None` axis).
            for in_d, out_d in zip(structure.in_diagonal,
                                   structure.out_diagonal):
              in_axes = [
                  None
                  if isinstance(invals[ix], UndefinedPrimal)
                  else i
                  for ix, i in enumerate(in_d)]
              j_fn = _vmap(j_fn, in_axes=(in_axes,), out_axes=out_d)

            dy_dw = j_fn(trimmed_invals)
            outs[i_jaxpr] += [(dy_dw, structure)]

          else:
            raise ValueError(f'Unrecognized mode {mode}.')

  # If output contains any of `primals_in`, this "identity" primitive is not
  # present in `jaxpr.eqns`. Below we treat this case by passing `cotangents_in`
  # as `df_dy`, and an identity matrix as `dy_dw`.
  for i_in, v_out in enumerate(jaxpr.outvars):
    for i_eqn, v in enumerate(jaxpr.invars):
      if v == v_out:
        if mode == _MODE.DF_DY:
          # If no cotangent is left in `ct_env` for `v`, its abstract value is
          # stored in place of a cotangent.
          if v in ct_env:
            df_dy = cotangents_in[i_in]
          else:
            df_dy = v.aval

          outs[i_eqn] += [df_dy]
          break

        elif mode == _MODE.DY_DW:
          # Identity function
          structure = rules.get_id_structure(v.aval, _s_rules)
          structure &= invar_to_structure[v]

          # Identity Jacobian
          trimmed_invals = _trim_invals([UndefinedPrimal(v.aval)], structure)
          if not isinstance(v.aval, ShapedArray):
            raise TypeError(v.aval)
          trimmed_cts_in = _trim_cotangents(v.aval, structure)
          dy_dw = _get_jacobian(
              eqn=None,
              cts_in=trimmed_cts_in,
              invals=trimmed_invals,
              idx=0,
              _j_rules=_j_rules,
              _fwd=_fwd,
          )
          outs[i_eqn] += [(dy_dw, structure)]

        else:
          raise ValueError(f'Unrecognized mode {mode}.')

  return outs


def _get_vars_needing_cts_in(jaxpr: Jaxpr) -> set[Var]:
  """Get a set of variables that need cotangents for structured derivatives.

  Specifically, returns variables which are outputs of equations to which
  `jaxpr.invars` are inputs. Cotangents `df/dy` to these variables are needed
  elsewhere to compute the NTK.
  """
  need_cts: set[Var] = set()

  def visit(vs: set[Var]):
    if len(vs) == 0:
      return

    next_visit = set()

    for e in jaxpr.eqns:
      if any(v in e.invars for v in vs):
        for o in e.outvars:
          if o not in need_cts:
            need_cts.add(o)
            next_visit.add(o)

    visit(next_visit)

  visit(set(jaxpr.invars))

  # `invars` don't need cotangents in `STRUCTURED_DERIVATIVES` mode.
  assert all(i not in need_cts for i in jaxpr.invars)
  return need_cts


def _backprop_step(
    eqn: JaxprEqn,
    primal_env: dict[Var, jnp.ndarray],
    ct_env: dict[Var, jnp.ndarray],
    read_cotangent: Callable[[Var], Union[jnp.ndarray, Zero]],
    do_write_cotangents: bool = True
) -> tuple[Union[jnp.ndarray, Zero], list[Union[jnp.ndarray, UndefinedPrimal]]]:
  """Adapted from `jax.interpreters.ad`.

  Reads the inputs and the output cotangent of `eqn` and, optionally,
  propagates the cotangent to the inputs of `eqn`.

  Args:
    eqn: equation to process; must have exactly one output.
    primal_env: mapping from variables to their known values.
    ct_env: mapping from variables to cotangents; updated in place when
      `do_write_cotangents` is `True`.
    read_cotangent: function returning the cotangent for a variable.
    do_write_cotangents: whether to compute the VJP of `eqn` and write the
      resulting cotangents for `eqn.invars` into `ct_env`.

  Returns:
    A `(cts_in, invals)` tuple: the cotangent of the single output of `eqn` and
    the list of its input values.

  Raises:
    NotImplementedError: if `eqn` does not have exactly one output.
  """
  invals = [_read_primal(primal_env, v) for v in eqn.invars]
  out_cts = [read_cotangent(v) for v in eqn.outvars]

  if len(out_cts) != 1:
    raise NotImplementedError(
        f'Primitives with multiple outputs are not supported. '
        f'Please file a bug at '
        f'https://github.com/google/neural-tangents/issues. '
        f'Got {len(eqn.outvars)} outputs for {eqn}, with input '
        f'cotangents {out_cts}.')

  cts_in = out_cts[0]

  if do_write_cotangents:
    cts_out = _eqn_vjp_fn(eqn, cts_in, *invals)
    if cts_out is Zero:
      # Expand the bare `Zero` class into one symbolic zero per input.
      cts_out = [Zero(v.aval) for v in eqn.invars]
    map(functools.partial(_write_cotangent, eqn.primitive, ct_env),
        eqn.invars, cts_out)

  return cts_in, invals


def _trim_cotangents(
    cts_in: ShapedArray,
    structure: rules.Structure
) -> ShapedArray:
  """Trims the structured output axes of `cts_in` to length `1`.

  Args:
    cts_in: abstract cotangent; only used for shape and dtype.
    structure: its `out_trace`, `out_broadcast` and `out_diagonal` axes are
      concatenated and trimmed via `_trim_axis`.

  Returns:
    A `ShapedArray` with the listed axes of length `1`.
  """
  axes_to_trim = (
      structure.out_trace + structure.out_broadcast + structure.out_diagonal)
  trimmed: ShapedArray = _trim_axis(cts_in, axes_to_trim)
  return trimmed


def _trim_invals(
    invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    structure: rules.Structure,
) -> list[Union[jnp.ndarray, UndefinedPrimal]]:
  """Trims the structured axes of `invals` to length `1` via `_trim_axis`.

  Args:
    invals: equation inputs; the list itself is not modified.
    structure: describes which axes of which inputs to trim (trace, broadcast
      and diagonal axes).

  Returns:
    A new list, where the affected entries are replaced by trimmed ones.
  """
  result = list(invals)

  # Trace axes.
  for idx in structure.in_trace_idxs:
    result[idx] = _trim_axis(result[idx], structure.in_trace)

  # Input broadcast axes, applied one at a time to a single input.
  for axis in structure.in_broadcast:
    result[structure.in_broadcast_idx] = _trim_axis(
        result[structure.in_broadcast_idx], axis)

  # Output broadcast axes, applied to every listed input.
  for axis in structure.out_broadcast:
    for idx in structure.out_broadcast_idxs:
      result[idx] = _trim_axis(result[idx], axis)

  # Diagonal axes are only trimmed for `UndefinedPrimal` inputs.
  for idx in range(len(result)):
    diagonal_axes = sorted(
        (axes[idx] for axes in structure.in_diagonal if axes[idx] is not None),
        reverse=True)
    for axis in diagonal_axes:
      if isinstance(result[idx], UndefinedPrimal):
        result[idx] = _trim_axis(result[idx], axis)

  return result  # pytype: disable=bad-return-type  # jax-ndarray


def _trim_eqn(
    eqn: JaxprEqn,
    idx: int,
    trimmed_invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    trimmed_cts_in: ShapedArray
) -> JaxprEqn:
  """Adjusts parameters of `eqn` to match the trimmed inputs and cotangents.

  Args:
    eqn: equation whose parameters may need to be updated.
    idx: index of the input of interest, forwarded to the parameters rule.
    trimmed_invals: trimmed equation inputs.
    trimmed_cts_in: trimmed output cotangent.

  Returns:
    `eqn` with parameters replaced by the output of the matching rule in
    `rules.EQN_PARAMS_RULES`, or `eqn` itself if there is no such rule.
  """
  if eqn.primitive not in rules.EQN_PARAMS_RULES:
    return eqn

  # Rules receive abstract values in place of `UndefinedPrimal`s.
  inval_avals = [
      v.aval if isinstance(v, UndefinedPrimal) else v for v in trimmed_invals]

  # The rule is given a copy of the equation parameters to modify.
  new_params = rules.EQN_PARAMS_RULES[eqn.primitive](
      params=dict(eqn.params),
      idx=idx,
      trimmed_invals=inval_avals,
      trimmed_cts_in=trimmed_cts_in
  )
  return eqn.replace(params=new_params)


def _trim_axis(
    x: Union[UndefinedPrimal, ShapedArray, jnp.ndarray],
    axis: Union[int, tuple[int, ...]],
) -> Union[UndefinedPrimal, ShapedArray]:
  """Trim `axis` of `x` to be of length `1`. `x` is only used for shape."""
  if isinstance(axis, int):
    axis = (axis,)

  if isinstance(x, UndefinedPrimal):
    return UndefinedPrimal(_trim_axis(x.aval, axis))

  if isinstance(x, (ShapedArray, jnp.ndarray)):
    return ShapedArray([1 if i in axis else x.shape[i]
                        for i in range(x.ndim)], dtype=x.dtype)

  raise TypeError(type(x), x)


def _eqn_jvp_fn(
    eqn: Optional[JaxprEqn],
    idx: int,
    tangents: jnp.ndarray,
    *invals
) -> jnp.ndarray:
  """Perform a JVP for `eqn`."""
  if eqn is None:
    # Identity function
    return tangents

  new_tangents = []
  new_invals = []

  for i_dx, i in enumerate(invals):
    if i_dx == idx:
      inval = jnp.zeros(i.aval.shape, i.aval.dtype)
      tangent = tangents
    else:
      inval = i
      aval = i.aval if hasattr(i, 'aval') else ShapedArray(i.shape, i.dtype)
      tangent = Zero(aval)
      if isinstance(inval, (UndefinedPrimal, ShapedArray)):
        inval = jnp.zeros(aval.shape, aval.dtype)

    new_invals.append(inval)
    new_tangents.append(tangent)

  jvp_fn = ad.primitive_jvps[eqn.primitive]
  out = jvp_fn(new_invals, new_tangents, **eqn.params)[1]
  if isinstance(out, list) and len(out) == 1:
    return out[0]
  elif isinstance(out, jax.Array):
    return out
  raise TypeError(out, type(out))


def _eqn_vjp_fn(
    eqn: Optional[JaxprEqn],
    cts_in: jnp.ndarray,
    *invals
) -> tuple[jnp.ndarray, ...]:
  """Perform a VJP for `eqn`. Adapted from `jax.interpreters.ad`."""
  if eqn is None:
    # Identity function
    return cts_in,

  name_stack = (ad.source_info_util.current_name_stack() +
                eqn.source_info.name_stack)
  with ad.source_info_util.user_context(eqn.source_info.traceback,
                                        name_stack=name_stack):
    if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
      cts_in_avals = [v.aval for v in eqn.outvars]
      params = dict(eqn.params)
      call_jaxpr = params.pop('call_jaxpr')
      cts_out = ad.get_primitive_transpose(eqn.primitive)(
          params, call_jaxpr, invals, cts_in, cts_in_avals, ())
    elif eqn.primitive in ad.reducing_transposes:
      cts_out = ad.reducing_transposes[eqn.primitive](
          (), (cts_in,), *invals, **eqn.params)
    else:
      cts_out = ad.get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                          **eqn.params)
  return cts_out


def _get_jacobian(
    eqn: Optional[JaxprEqn],
    cts_in: ShapedArray,
    invals: list[Union[jnp.ndarray, UndefinedPrimal]],
    idx: int,
    _j_rules: bool,
    _fwd: Optional[bool],
) -> Union[jnp.ndarray, Zero]:
  """Get the (structured) `eqn` output Jacobian wrt `eqn.invars[idx]`.

  Args:
    eqn: equation to differentiate; `None` stands for the identity function.
    cts_in: abstract output cotangent; only used for shape and dtype.
    invals: equation inputs.
    idx: index of the input to differentiate with respect to.
    _j_rules: if `True`, a custom rule from `rules.JACOBIAN_RULES` is used
      whenever one is available for the primitive.
    _fwd: selects forward (`True`) or reverse (`False`) mode for the vanilla
      Jacobian evaluation; `None` lets `_get_fwd` decide based on shapes.

  Returns:
    The Jacobian (or a symbolic `Zero`) of shape
    `cts_in.shape + invals[idx].aval.shape`.
  """
  primitive = None if eqn is None else eqn.primitive

  inval_shape = invals[idx].aval.shape
  cts_in_shape = cts_in.shape
  dy_dw_shape = cts_in_shape + inval_shape

  has_rule = primitive in rules.JACOBIAN_RULES
  if not has_rule:
    # Warn regardless of whether custom rules are enabled.
    warnings.warn(f'No Jacobian rule found for {primitive}.')

  if has_rule and _j_rules:
    # Custom Jacobian rule.
    invals_j = [v.aval if isinstance(v, UndefinedPrimal) else v for v in invals]
    dy_dw = rules.JACOBIAN_RULES[primitive](eqn, idx, invals_j, cts_in)

  else:
    # Vanilla Jacobian evaluation.
    use_fwd = _get_fwd(_fwd, cts_in_shape, inval_shape)  # pytype: disable=wrong-arg-types  # always-use-return-annotations

    if use_fwd:
      # Forward mode: one JVP per input basis vector.
      basis_source = invals[idx].aval
      out_axes = -1

      def jac_fn(tangents):
        return _eqn_jvp_fn(eqn, idx, tangents, *invals)

    else:
      # Reverse mode: one VJP per output basis vector.
      basis_source = cts_in
      out_axes = 0

      def jac_fn(cotangents):
        return _eqn_vjp_fn(eqn, cotangents, *invals)[idx]

    dy_dw = vmap(jac_fn, out_axes=out_axes)(_std_basis(basis_source))

    if isinstance(dy_dw, Zero):
      dy_dw = Zero(ShapedArray(dy_dw_shape, cts_in.dtype))
    else:
      dy_dw = dy_dw.reshape(dy_dw_shape)

  dy_dw_shape_ = dy_dw.aval.shape if isinstance(dy_dw, Zero) else dy_dw.shape  # pytype:disable=attribute-error
  assert dy_dw_shape_ == dy_dw_shape, (dy_dw_shape_, dy_dw_shape)
  return dy_dw


def _write_cotangent(
    prim: core.Primitive,
    ct_env: dict[Var, jnp.ndarray],
    v: Var,
    ct: Union[jnp.ndarray, Zero]
):
  """Adapted from `jax.interpreters.ad`.

  Accumulates the cotangent `ct` for variable `v` in `ct_env`.

  Args:
    prim: primitive that produced `ct`; only used in assertion messages.
    ct_env: mapping from variables to cotangents, updated in place.
    v: variable to accumulate the cotangent for.
    ct: cotangent; `None` and symbolic `Zero`s are ignored.
  """
  assert ct is not Zero, (prim, v.aval)

  # Nothing to accumulate for missing / zero cotangents or literals.
  if ct is None or type(v) is Literal or type(ct) is Zero:
    return

  if v in ct_env:
    ct_env[v] = ad.add_tangents(ct_env[v], ct)
  else:
    ct_env[v] = ct

  if ad.config.jax_enable_checks:
    # Check that the accumulated cotangent is compatible with `v.aval`.
    ct_aval = core.get_aval(ct_env[v])
    joined_aval = core.lattice_join(
        v.aval, ct_aval).strip_weak_type().strip_named_shape()
    assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (
        prim, v.aval, ct_aval)


def _read_primal(
    env: dict[Var, jnp.ndarray],
    v: Union[Var, Literal],
) -> Union[jnp.ndarray, UndefinedPrimal]:
  """Reads the value of `v`.

  Args:
    env: mapping from variables to their known values.
    v: a variable or a literal.

  Returns:
    The value of the literal, the value stored in `env` for `v`, or an
    `UndefinedPrimal` carrying the abstract value of `v` if `v` is not in `env`.
  """
  if type(v) is Literal:
    return v.val

  aval = v.aval
  if type(aval) is core.DShapedArray:
    # Substitute dimensions given by variables with their values from `env`.
    resolved_shape = tuple(
        env[d] if type(d) is core.Var else d for d in aval.shape)
    aval = aval.update(shape=resolved_shape)

  return env.get(v, UndefinedPrimal(aval))


def _write_primal(
    env: dict[Var, jnp.ndarray],
    v: Var,
    val: Union[jnp.ndarray, UndefinedPrimal]
):
  if not ad.is_undefined_primal(val):
    env[v] = val  # pytype: disable=container-type-mismatch  # jax-ndarray


def _get_fwd(
    _fwd: Optional[bool],
    cts_in_shape: tuple[int, ...],
    inval_shape: tuple[int, ...]
) -> bool:
  if _fwd is None:
    out_size = np.prod(cts_in_shape)
    in_size = np.prod(inval_shape)
    _fwd = out_size > in_size
  return _fwd


def _get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
  e = jax.jit(f).lower(*a, **kw)
  if optimize:
    analysis = e.compile().cost_analysis()[0]
  else:
    analysis = e.cost_analysis()
  return analysis['flops']


def _std_basis(pytree: PyTree) -> PyTree:
  """Similar to `jax.api._std_basis` without host-side ops.

  Args:
    pytree: PyTree of arrays.

  Returns:
    An identity matrix of size equal to the total number of elements in
    `pytree`, with axis `1` unravelled into the structure of `pytree`.
  """
  leaves = tree_flatten(pytree)[0]
  total_size = sum(jnp.size(leaf) for leaf in leaves)
  basis_dtype = jax.dtypes.result_type(*leaves)
  identity = jnp.eye(total_size, dtype=basis_dtype)
  return _unravel_array_into_pytree(pytree, 1, identity)


def _unravel_array_into_pytree(
    pytree: PyTree,
    axis: int,
    arr: jnp.ndarray
) -> PyTree:
  """Similar to `jax.api._unravel_array_into_pytree` without host-side ops."""
  leaves, treedef = tree_flatten(pytree)
  if arr.ndim > 0:
    axis %= arr.ndim
  shapes = [arr.shape[:axis] + jnp.shape(l) + arr.shape[axis + 1:]
            for l in leaves]
  parts = jnp.split(arr, np.cumsum([jnp.size(l) for l in leaves[:-1]]), axis)
  reshaped_parts = [jnp.reshape(x, shape) for x, shape in zip(parts, shapes)]
  return tree_unflatten(treedef, reshaped_parts)


def _get_res_batch_dims(
    contracting_dims: Iterable[int],
    batch_dims: Iterable[int]
) -> list[int]:
  res_batch_dims = [2 * b - i for i, b in enumerate(batch_dims)]
  for i, b in enumerate(batch_dims):
    for c in contracting_dims:
      if b > c:
        res_batch_dims[i] -= 2
  return res_batch_dims


def _dot_general(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    contracting_dims: Axes,
    batch_dims: Axes,
    precision=None
) -> jnp.ndarray:
  """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but preserves the dimension order in the output. See XLA's
   `DotGeneral<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

  Args:
    lhs: array.
    rhs: array, must have the same dimensionality as `lhs`.
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
  if lhs.ndim != rhs.ndim:
    raise ValueError(f'`lhs` and `rhs` must have the same dimensionality, got'
                     f'`lhs.ndim == {lhs.ndim}` and `rhs.ndim == {rhs.ndim}`.')

  contracting_dims = utils.canonicalize_axis(contracting_dims, lhs)
  batch_dims = utils.canonicalize_axis(batch_dims, lhs)

  n_batch_dims = len(batch_dims)
  leading_batch_dims = range(n_batch_dims)

  dimension_numbers = ((contracting_dims, contracting_dims),
                       (batch_dims, batch_dims))

  prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
  prod = utils.zip_axes(prod, n_batch_dims)

  res_batch_dims = _get_res_batch_dims(contracting_dims, batch_dims)
  prod = jnp.moveaxis(prod, leading_batch_dims, res_batch_dims)
  return prod
