# Copyright 2021 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core Fast Attention Module for Flax.

This is a fork of:
https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py
"""

# pylint: disable=invalid-name, missing-function-docstring, line-too-long

import abc
from collections.abc import Iterable  # pylint: disable=g-importing-member
import functools
from absl import logging
import gin
import jax
from jax import lax
from jax import random
import jax.numpy as jnp

import numpy as onp

# Nonlinear mappings encoding different attention kernels.
# Each call registers an elementwise function with gin under a short 'j*' name
# so that it can be referenced from gin configs (presumably as the `kernel_fn`
# of `make_fast_generalized_attention` -- TODO confirm against the configs).
gin.external_configurable(jnp.cos, 'jcos')
gin.external_configurable(jnp.sin, 'jsin')
gin.external_configurable(jnp.tanh, 'jtanh')
gin.external_configurable(jax.nn.sigmoid, 'jsigmoid')
gin.external_configurable(
    lambda x: jax.nn.gelu(x, approximate=False), 'jgelu'
)  # Needs to be exact, although might be slower. See https://github.com/google/jax/issues/4428.
# 'jrequ': x^2 where x > 0 and 0 elsewhere (the boolean mask zeroes the rest).
gin.external_configurable(lambda x: x * x * (x > 0.0), 'jrequ')
gin.external_configurable(jnp.exp, 'jexp')
gin.external_configurable(lambda x: x, 'jidentity')
# 'jshiftedelu': exp(x) where x <= 0 and x + 1 where x > 0.
gin.external_configurable(
    lambda x: (jnp.exp(x)) * (x <= 0.0) + (x + 1.0) * (x > 0.0), 'jshiftedelu'
)  # Nonlinearity used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (https://arxiv.org/abs/2006.16236).


def nonnegative_softmax_kernel_feature_creator(data,
                                               projection_matrix,
                                               attention_dims_t,
                                               batch_dims_t,
                                               precision,
                                               is_query,
                                               normalize_data=True,
                                               eps=0.0001):
  """Builds positive random features approximating the softmax kernel.

  Args:
    data: array of queries or keys; features are computed along its last axis.
    projection_matrix: 2D random matrix whose rows define the features.
    attention_dims_t: tuple with the positions of the attention dimensions in
      `data`; only read for keys.
    batch_dims_t: tuple with the positions of the batch dimensions in `data`.
    precision: precision argument forwarded to `lax.dot_general`.
    is_query: True when `data` holds queries, False when it holds keys. It
      selects the axes over which the stabilizing maximum is taken.
    normalize_data: if True, `data` is scaled by d^{-1/4}, where d is the size
      of its last axis, before being projected.
    eps: small constant added to every feature for numerical stability.

  Returns:
    Array of nonnegative features; its last axis has one entry per row of
    `projection_matrix`.
  """
  feature_axis = data.ndim - 1

  # With w_norm = w * d^{-1/4} for w in {q, k} we get
  # e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}.
  if normalize_data:
    scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    scale = 1.0
  feature_scale = 1.0 / jnp.sqrt(projection_matrix.shape[0])

  # Replicate the projection matrix across the batch dimensions so that it can
  # take part in a batched dot product.
  tiled_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  tiled_projection = jnp.zeros(tiled_shape) + projection_matrix

  contracting_dims = ((feature_axis,), (tiled_projection.ndim - 1,))
  projected = lax.dot_general(
      scale * data,
      tiled_projection,
      (contracting_dims, (batch_dims_t, batch_dims_t)),
      precision=precision)

  # |x|^2 / 2 for the (normalized) inputs, kept broadcastable against
  # `projected`.
  half_sq_norm = jnp.sum(jnp.square(data), axis=feature_axis, keepdims=True)
  half_sq_norm = (half_sq_norm / 2.0) * scale * scale

  # Queries are stabilized per position (max over features only); keys use a
  # maximum shared across the attention dimensions as well.
  stabilizer_axes = (len(projected.shape) - 1,)
  if not is_query:
    stabilizer_axes = stabilizer_axes + attention_dims_t
  stabilizer = jnp.max(projected, axis=stabilizer_axes, keepdims=True)

  return feature_scale * (
      jnp.exp(projected - half_sq_norm - stabilizer) + eps)


def sincos_softmax_kernel_feature_creator(data,
                                          projection_matrix,
                                          attention_dims_t,
                                          batch_dims_t,
                                          precision,
                                          normalize_data=True):
  """Builds trigonometric (sin/cos) random features for softmax attention.

  Args:
    data: array of queries or keys; features are computed along its last axis.
    projection_matrix: 2D random matrix whose rows define the features.
    attention_dims_t: tuple with the positions of the attention dimensions in
      `data`.
    batch_dims_t: tuple with the positions of the batch dimensions in `data`.
    precision: precision argument forwarded to `lax.dot_general`.
    normalize_data: if True, `data` is scaled by d^{-1/4}, where d is the size
      of its last axis, before being projected.

  Returns:
    Array of features; its last axis holds the cosine features followed by the
    sine features, i.e. twice the number of rows of `projection_matrix`.
  """
  feature_axis = data.ndim - 1

  # exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
  # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
  if normalize_data:
    scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    scale = 1.0
  feature_scale = 1.0 / jnp.sqrt(projection_matrix.shape[0])

  # Replicate the projection matrix across the batch dimensions so that it can
  # take part in a batched dot product.
  tiled_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  tiled_projection = jnp.zeros(tiled_shape) + projection_matrix

  contracting_dims = ((feature_axis,), (tiled_projection.ndim - 1,))
  projected = lax.dot_general(
      scale * data,
      tiled_projection,
      (contracting_dims, (batch_dims_t, batch_dims_t)),
      precision=precision)
  trig_features = jnp.concatenate(
      (feature_scale * jnp.cos(projected), feature_scale * jnp.sin(projected)),
      axis=-1)

  # exp(|x|^2 / 2) factor of the (normalized) inputs. The maximum over the
  # attention dimensions is subtracted in the exponent for numerical stability.
  half_sq_norm = jnp.sum(jnp.square(data), axis=feature_axis, keepdims=True)
  half_sq_norm = (half_sq_norm / 2.0) * scale * scale
  shift = jnp.max(half_sq_norm, attention_dims_t, keepdims=True)
  weight = jnp.exp(half_sq_norm - shift)

  return trig_features * weight


def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t,
                                       precision, kernel_fn, kernel_epsilon,
                                       normalize_data):
  """Builds kernel features for fast generalized attention.

  Args:
    data: array of queries or keys; features are computed along its last axis.
    projection_matrix: 2D matrix used to project `data`, or None to apply
      `kernel_fn` to the (optionally normalized) data directly.
    batch_dims_t: tuple with the positions of the batch dimensions in `data`.
    precision: precision argument forwarded to `lax.dot_general`.
    kernel_fn: nonlinearity applied to the projected data.
    kernel_epsilon: constant added to every feature for numerical stability.
    normalize_data: if True, `data` is scaled by d^{-1/4}, where d is the size
      of its last axis.

  Returns:
    `kernel_fn` of the projected (or, without a projection, the scaled) data
    plus `kernel_epsilon`.
  """
  scale = (1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
           if normalize_data else 1.0)

  # No projection requested: the features are a pointwise map of the data.
  if projection_matrix is None:
    return kernel_fn(scale * data) + kernel_epsilon

  # Replicate the projection matrix across the batch dimensions so that it can
  # take part in a batched dot product.
  tiled_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  tiled_projection = jnp.zeros(tiled_shape) + projection_matrix
  contracting_dims = ((data.ndim - 1,), (tiled_projection.ndim - 1,))
  projected = lax.dot_general(
      scale * data,
      tiled_projection,
      (contracting_dims, (batch_dims_t, batch_dims_t)),
      precision=precision)
  return kernel_fn(projected) + kernel_epsilon


@gin.configurable
def make_fast_softmax_attention(qkv_dim,
                                renormalize_attention=True,
                                numerical_stabilizer=0.000001,
                                nb_features=256,
                                ortho_features=True,
                                ortho_scaling=0.0,
                                redraw_features=True,
                                unidirectional=False,
                                nonnegative_features=True,
                                lax_scan_unroll=1):
  """Builds a fast attention function approximating softmax attention.

  Args:
    qkv_dim: number of columns of the random projection matrix.
    renormalize_attention: forwarded to `FastAttentionviaLowRankDecomposition`.
    numerical_stabilizer: forwarded to `FastAttentionviaLowRankDecomposition`
      and, for nonnegative features, also used as the `eps` of the feature map.
    nb_features: number of rows of the random projection matrix.
    ortho_features: if True a `GaussianOrthogonalRandomMatrix` is used,
      otherwise a `GaussianUnstructuredRandomMatrix`.
    ortho_scaling: `scaling` argument of the orthogonal matrix; only used when
      `ortho_features` is True.
    redraw_features: forwarded to `FastAttentionviaLowRankDecomposition`.
    unidirectional: forwarded to `FastAttentionviaLowRankDecomposition`.
    nonnegative_features: if True the nonnegative softmax features are used,
      otherwise the sin/cos features.
    lax_scan_unroll: forwarded to `FastAttentionviaLowRankDecomposition`.

  Returns:
    The bound `dot_product_attention` method of the configured attention
    object.
  """
  logging.info(
      'Fast softmax attention: %s features and orthogonal=%s, renormalize=%s',
      nb_features, ortho_features, renormalize_attention)

  # Factory for the projection matrix; the PRNG key is supplied by the caller
  # of the factory.
  if not ortho_features:
    matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
                                       nb_features, qkv_dim)
  else:
    matrix_creator = functools.partial(
        GaussianOrthogonalRandomMatrix,
        nb_features,
        qkv_dim,
        scaling=ortho_scaling)

  def positive_feature_creator(data,
                               projection_matrix,
                               attention_dims_t,
                               batch_dims_t,
                               precision,
                               is_query,
                               normalize_data=True):
    return nonnegative_softmax_kernel_feature_creator(
        data,
        projection_matrix,
        attention_dims_t,
        batch_dims_t,
        precision,
        is_query,
        normalize_data=normalize_data,
        eps=numerical_stabilizer)

  def trigonometric_feature_creator(data,
                                    projection_matrix,
                                    attention_dims_t,
                                    batch_dims_t,
                                    precision,
                                    is_query,
                                    normalize_data=True):
    # The sin/cos feature map treats queries and keys identically.
    del is_query
    return sincos_softmax_kernel_feature_creator(
        data,
        projection_matrix,
        attention_dims_t,
        batch_dims_t,
        precision,
        normalize_data=normalize_data)

  if nonnegative_features:
    kernel_feature_creator = positive_feature_creator
  else:
    kernel_feature_creator = trigonometric_feature_creator

  fast_attention = FastAttentionviaLowRankDecomposition(
      matrix_creator,
      kernel_feature_creator,
      renormalize_attention=renormalize_attention,
      numerical_stabilizer=numerical_stabilizer,
      redraw_features=redraw_features,
      unidirectional=unidirectional,
      lax_scan_unroll=lax_scan_unroll)
  return fast_attention.dot_product_attention


@gin.configurable
def make_fast_generalized_attention(qkv_dim,
                                    renormalize_attention=True,
                                    numerical_stabilizer=0.0,
                                    nb_features=256,
                                    features_type='deterministic',
                                    kernel_fn=jax.nn.relu,
                                    kernel_epsilon=0.001,
                                    redraw_features=False,
                                    unidirectional=False,
                                    lax_scan_unroll=1):
  """Builds a fast generalized attention function.

  Args:
    qkv_dim: number of columns of the projection matrix (unused for
      'deterministic' features).
    renormalize_attention: forwarded to `FastAttentionviaLowRankDecomposition`.
    numerical_stabilizer: forwarded to `FastAttentionviaLowRankDecomposition`.
    nb_features: number of rows of the projection matrix (unused for
      'deterministic' features).
    features_type: 'ortho' for a `GaussianOrthogonalRandomMatrix`, 'iid' for a
      `GaussianUnstructuredRandomMatrix`, or 'deterministic' for no projection.
    kernel_fn: nonlinearity defining the attention kernel.
    kernel_epsilon: constant added to every kernel feature.
    redraw_features: forwarded to `FastAttentionviaLowRankDecomposition`.
    unidirectional: forwarded to `FastAttentionviaLowRankDecomposition`.
    lax_scan_unroll: forwarded to `FastAttentionviaLowRankDecomposition`.

  Returns:
    The bound `dot_product_attention` method of the configured attention
    object.

  Raises:
    ValueError: if `features_type` is not one of the supported values.
  """
  logging.info('Fast generalized attention.: %s features and renormalize=%s',
               nb_features, renormalize_attention)

  if features_type == 'deterministic':
    # No projection: the kernel function is applied to the data directly.
    matrix_creator = None
  elif features_type == 'iid':
    matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
                                       nb_features, qkv_dim)
  elif features_type == 'ortho':
    matrix_creator = functools.partial(
        GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
  else:
    raise ValueError('Unknown feature value type')

  def kernel_feature_creator(data,
                             projection_matrix,
                             attention_dims_t,
                             batch_dims_t,
                             precision,
                             is_query,
                             normalize_data=False):
    # The generalized feature map needs neither the attention dimensions nor
    # the query/key distinction.
    del attention_dims_t, is_query
    return generalized_kernel_feature_creator(
        data,
        projection_matrix,
        batch_dims_t,
        precision,
        kernel_fn,
        kernel_epsilon,
        normalize_data)

  fast_attention = FastAttentionviaLowRankDecomposition(
      matrix_creator,
      kernel_feature_creator,
      renormalize_attention=renormalize_attention,
      numerical_stabilizer=numerical_stabilizer,
      redraw_features=redraw_features,
      unidirectional=unidirectional,
      lax_scan_unroll=lax_scan_unroll)
  return fast_attention.dot_product_attention


class RandomMatrix(metaclass=abc.ABCMeta):
  r"""Abstract class providing a method for constructing 2D random arrays.

  Subclasses must implement `get_2d_array`. The metaclass is declared with the
  Python 3 `metaclass=` keyword: a `__metaclass__` class attribute is ignored
  by Python 3, which would leave `@abc.abstractmethod` unenforced.
  """

  @abc.abstractmethod
  def get_2d_array(self):
    """Returns the constructed 2D array; must be overridden by subclasses."""
    raise NotImplementedError('Abstract method')


class GaussianUnstructuredRandomMatrix(RandomMatrix):
  """Random matrix whose entries are drawn independently with `random.normal`."""

  def __init__(self, nb_rows, nb_columns, key):
    """Stores the matrix shape and the PRNG key used to sample it.

    Args:
      nb_rows: number of rows of the matrix.
      nb_columns: number of columns of the matrix.
      key: JAX PRNG key passed to `random.normal`.
    """
    self.key = key
    self.nb_rows, self.nb_columns = nb_rows, nb_columns

  def get_2d_array(self):
    """Returns a freshly sampled (nb_rows, nb_columns) array."""
    shape = (self.nb_rows, self.nb_columns)
    return random.normal(self.key, shape)


class GaussianOrthogonalRandomMatrix(RandomMatrix):
  r"""Class providing a method to create Gaussian orthogonal matrix.

  The matrix is assembled from square blocks whose rows are orthonormal (the
  transposed Q factors of QR decompositions of Gaussian matrices); each row is
  then rescaled according to `scaling`.
  """

  def __init__(self, nb_rows, nb_columns, key, scaling=0):
    """Stores the matrix configuration.

    Args:
      nb_rows: number of rows of the matrix.
      nb_columns: number of columns of the matrix.
      key: JAX PRNG key used to sample the matrix.
      scaling: 0 to rescale every row by the norm of an independent Gaussian
        row, 1 to rescale every row by sqrt(nb_columns).
    """
    self.nb_rows = nb_rows
    self.nb_columns = nb_columns
    self.key = key
    self.scaling = scaling

  def get_2d_array(self):
    """Returns the (nb_rows, nb_columns) block-orthogonal random array.

    Raises:
      ValueError: if `scaling` is neither 0 nor 1.
    """

    def orthogonal_block(block_key):
      # Square block with orthonormal rows: transposed Q factor of a Gaussian
      # matrix.
      gaussian = random.normal(block_key, (self.nb_columns, self.nb_columns))
      q, _ = jnp.linalg.qr(gaussian)
      return jnp.transpose(q)

    nb_full_blocks = int(self.nb_rows / self.nb_columns)
    remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns

    block_list = []
    rng = self.key
    for _ in range(nb_full_blocks):
      rng, rng_input = jax.random.split(rng)
      block_list.append(orthogonal_block(rng_input))
    if remaining_rows > 0:
      # The last, partial block keeps only as many rows as are still missing.
      rng, rng_input = jax.random.split(rng)
      block_list.append(orthogonal_block(rng_input)[0:remaining_rows])
    final_matrix = jnp.vstack(block_list)

    if self.scaling == 0:
      # NOTE(review): `self.key` is also the root of the keys used for the
      # blocks above; kept as is so that results for a given key do not change.
      multiplier = jnp.linalg.norm(
          random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
    elif self.scaling == 1:
      multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
    else:
      # The attribute is `scaling`; reading `_scaling` here used to raise
      # AttributeError instead of the intended ValueError.
      raise ValueError('Scaling must be one of {0, 1}. Was %s' % self.scaling)

    return jnp.matmul(jnp.diag(multiplier), final_matrix)


class FastAttention(metaclass=abc.ABCMeta):
  r"""Abstract class providing a method for fast attention.

  Class is responsible for providing a method <dot_product_attention> for fast
  approximate attention. The metaclass is declared with the Python 3
  `metaclass=` keyword: a `__metaclass__` class attribute is ignored by
  Python 3, which would leave `@abc.abstractmethod` unenforced.
  """

  @abc.abstractmethod
  def dot_product_attention(self,
                            query,
                            key,
                            value,
                            dtype=jnp.float32,
                            bias=None,
                            axis=None,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
                            deterministic=False,
                            precision=None):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying fast approximate dot-product
    attention. It calculates the attention weights given query and key and
    combines the values using the attention weights. This function supports
    multi-dimensional inputs.

    Args:
      query: queries for calculating attention with shape of [batch_size, dim1,
        dim2, ..., dimN, num_heads, mem_channels].
      key: keys for calculating attention with shape of [batch_size, dim1, dim2,
        ..., dimN, num_heads, mem_channels].
      value: values to be used in attention with shape of [batch_size, dim1,
        dim2,..., dimN, num_heads, value_channels].
      dtype: the dtype of the computation (default: float32)
      bias: bias for the attention weights. This can be used for incorporating
        autoregressive mask, padding mask, proximity bias.
      axis: axes over which the attention is applied.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.

    Returns:
      Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
    """
    raise NotImplementedError('Abstract method')


def _numerator(z_slice_shape, precision, unroll=1):
  """Builds the prefix-sum attention numerator with a hand-written VJP.

  Args:
    z_slice_shape: shape of the zero-initialized scan carry, which accumulates
      the outer products of key and value slices.
    precision: forwarded as `precision` to every `jnp.einsum` call.
    unroll: forwarded as `unroll` to `lax.scan`.

  Returns:
    A `jax.custom_vjp` function mapping `(qs, ks, vs)` to the stacked per-step
    outputs `W`. The inputs are consumed by `lax.scan`, i.e. step by step along
    their leading axis.
  """

  def fwd(qs, ks, vs):
    # Forward pass. Returns the outputs together with the residuals used by
    # `bwd`: the final carry and the inputs.

    def body(p, qkv):
      (q, k, v) = qkv
      # Running sum of the outer products of k and v over the steps so far.
      p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
      # Output of this step: q contracted with the running sum over `m`.
      X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
      return p, X_slice

    init_value = jnp.zeros(z_slice_shape)
    p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
    return W, (p, qs, ks, vs)

  def bwd(pqkv, W_ct):
    # Backward pass: a reverse scan that returns the cotangents of qs, ks, vs.

    def body(carry, qkv_xct):
      # `p` is the forward running sum at the current step, `p_ct` the
      # cotangent accumulated for it from this and all later steps.
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
      # Undo this step's update so that `p` equals the running sum of the
      # previous forward step; this avoids storing every intermediate sum.
      p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    p, qs, ks, vs = pqkv
    # Start from the final forward carry and a zero cotangent for it.
    _, (qs_ct, ks_ct, vs_ct) = lax.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct),
        reverse=True,
        unroll=unroll)
    return qs_ct, ks_ct, vs_ct

  @jax.custom_vjp
  def _numerator_impl(qs, ks, vs):
    # Primal computation: the forward pass without the residuals.
    W, _ = fwd(qs, ks, vs)
    return W

  _numerator_impl.defvjp(fwd, bwd)

  return _numerator_impl


def _denominator(t_slice_shape, precision, unroll=1):
  """Builds the prefix-sum attention denominator with a hand-written VJP.

  Args:
    t_slice_shape: shape of the zero-initialized scan carry, which accumulates
      the key slices.
    precision: forwarded as `precision` to every `jnp.einsum` call.
    unroll: forwarded as `unroll` to `lax.scan`.

  Returns:
    A `jax.custom_vjp` function mapping `(qs, ks)` to the stacked per-step
    outputs `R`. The inputs are consumed by `lax.scan`, i.e. step by step along
    their leading axis.
  """

  def fwd(qs, ks):
    # Forward pass. Returns the outputs together with the residuals used by
    # `bwd`: the inputs and the final carry.

    def body(p, qk):
      q, k = qk
      # Running sum of the key slices over the steps so far.
      p += k
      # Output of this step: inner product of q with the running sum.
      x = jnp.einsum('...m,...m->...', q, p, precision=precision)
      return p, x

    p = jnp.zeros(t_slice_shape)
    p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
    return R, (qs, ks, p)

  def bwd(qkp, R_ct):
    # Backward pass: a reverse scan that returns the cotangents of qs and ks.

    def body(carry, qkx):
      # `p` is the forward running sum at the current step, `p_ct` the
      # cotangent accumulated for it from this and all later steps.
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
      # k enters the running sum additively, so its cotangent is `p_ct`.
      k_ct = p_ct
      # Undo this step's update so that `p` equals the running sum of the
      # previous forward step; this avoids storing every intermediate sum.
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    qs, ks, p = qkp
    # Start from the final forward carry and a zero cotangent for it.
    _, (qs_ct, ks_ct) = lax.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, R_ct),
        reverse=True,
        unroll=unroll)
    return (qs_ct, ks_ct)

  @jax.custom_vjp
  def _denominator_impl(qs, ks):
    # Primal computation: the forward pass without the residuals.
    R, _ = fwd(qs, ks)
    return R

  _denominator_impl.defvjp(fwd, bwd)

  return _denominator_impl


class FastAttentionviaLowRankDecomposition(FastAttention):
  r"""Class providing a method for fast attention via low rank decomposition.

  Queries and keys are mapped to feature tensors Q^{'} and K^{'} by
  `kernel_feature_creator` (e.g. with random feature maps). Attention is then
  evaluated as Q^{'}((K^{'})^{T}V), optionally renormalized, so that the
  attention matrix Q^{'}(K^{'})^{T} itself is never materialized.
  """

  def __init__(self,
               matrix_creator,
               kernel_feature_creator,
               renormalize_attention,
               numerical_stabilizer,
               redraw_features,
               unidirectional,
               lax_scan_unroll=1):  # For optimal GPU performance, set to 16.
    """Stores the configuration and draws the initial projection matrix.

    Args:
      matrix_creator: callable taking a `key` keyword argument and returning an
        object with a `get_2d_array` method, or None for no projection matrix.
      kernel_feature_creator: callable invoked as `(data, projection_matrix,
        attention_dims_t, batch_dims_t, precision, is_query)` that returns the
        kernel features of queries / keys.
      renormalize_attention: whether the result is divided by the partition
        function R.
      numerical_stabilizer: threshold and offset used to keep R away from zero
        before it is inverted.
      redraw_features: whether the projection matrix is redrawn on every call
        of `dot_product_attention`.
      unidirectional: whether prefix-sum (causal) attention is computed.
      lax_scan_unroll: `unroll` argument of the scans used for unidirectional
        attention.
    """
    rng = random.PRNGKey(0)
    self.matrix_creator = matrix_creator
    self.projection_matrix = self.draw_weights(rng)
    self.kernel_feature_creator = kernel_feature_creator
    self.renormalize_attention = renormalize_attention
    self.numerical_stabilizer = numerical_stabilizer
    self.redraw_features = redraw_features
    self.unidirectional = unidirectional
    self.lax_scan_unroll = lax_scan_unroll

  def draw_weights(self, key):
    """Returns a new projection matrix, or None without a `matrix_creator`.

    Args:
      key: JAX PRNG key; it is split and the first half seeds the matrix.
    """
    if self.matrix_creator is None:
      return None
    matrixrng, _ = random.split(key)
    projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
    return projection_matrix

  def dot_product_attention(self,
                            query,
                            key,
                            value,
                            dtype=jnp.float32,
                            bias=None,
                            axis=None,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
                            deterministic=False,
                            precision=None):
    """Computes fast approximate dot-product attention.

    `dtype`, `bias`, `broadcast_dropout`, `dropout_rng`, `dropout_rate` and
    `deterministic` are accepted for compatibility with the `FastAttention`
    interface but are not used by this implementation.

    Args:
      query: queries of shape [batch_size, dim1, ..., dimN, num_heads,
        mem_channels].
      key: keys with the same number of dimensions as `query`, the same batch
        size and the same number of channels.
      value: values whose shape matches `key` except for the last axis.
      dtype: unused.
      bias: unused.
      axis: axis or axes over which attention is applied; an int or any
        iterable of ints. Defaults to all axes between the batch axis and the
        last two axes.
      broadcast_dropout: unused.
      dropout_rng: unused.
      dropout_rate: unused.
      deterministic: unused.
      precision: numerical precision of the computation, see
        `jax.lax.Precision` for details.

    Returns:
      Array laid out like `query`, with `value`'s channels on the last axis.

    Raises:
      ValueError: if an attention axis is not strictly between the batch axis
        and the last two axes.
    """
    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1] and
            query.shape[-1] == key.shape[-1])
    if axis is None:
      axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
      axis = (axis,)
    # `axis` is concatenated with tuples below, which fails for other iterable
    # types such as lists; normalize it once.
    axis = tuple(axis)
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
      if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
        raise ValueError('Attention axis must be between the batch '
                         'axis and the last-two axes.')
    n = key.ndim

    # Constructing projection tensor.
    if self.redraw_features:
      # The seed is derived from the sum of the query entries, so the redrawn
      # matrix depends on the input. Note that this overwrites
      # `self.projection_matrix` on every call.
      # TODO(kchoro): Get rid of the constant below.
      query_seed = lax.convert_element_type(
          jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
      rng = random.PRNGKey(query_seed)
      self.projection_matrix = self.draw_weights(rng)

    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
    # q, k & v -> (bs, <non-attention dims>, num_heads, <attention dims>,
    # channels)
    qk_perm = batch_dims + axis + (n - 1,)
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    value = value.transpose(qk_perm)
    # Positions of the batch / attention dimensions after the transposition.
    batch_dims_t = tuple(range(len(batch_dims)))
    attention_dims_t = tuple(
        range(len(batch_dims),
              len(batch_dims) + len(axis)))

    # Constructing tensors Q^{'} and K^{'}.
    query_prime = self.kernel_feature_creator(query, self.projection_matrix,
                                              attention_dims_t, batch_dims_t,
                                              precision, True)
    key_prime = self.kernel_feature_creator(key, self.projection_matrix,
                                            attention_dims_t, batch_dims_t,
                                            precision, False)

    if self.unidirectional:
      # Only the first attention dimension is scanned.
      # NOTE(review): the scan carry shapes below contain no attention
      # dimensions, so this looks limited to a single attention axis -- confirm.
      index = attention_dims_t[0]
      z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
          key_prime.shape[-1],) + (value.shape[-1],)

      numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
      W = numerator_fn(
          jnp.moveaxis(query_prime, index, 0),
          jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0))

      # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
      W = jnp.moveaxis(W, 0, index)

      if not self.renormalize_attention:
        # Unidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Unidirectional, normalized attention.
        t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
            key_prime.shape[-1],)
        denominator_fn = _denominator(t_slice_shape, precision,
                                      self.lax_scan_unroll)
        R = denominator_fn(
            jnp.moveaxis(query_prime, index, 0),
            jnp.moveaxis(key_prime, index, 0))

        R = jnp.moveaxis(R, 0, index)
    else:
      # Feature axis of Q^{'} and the matching feature axis of Z.
      contract_query = tuple(
          range(len(batch_dims) + len(axis),
                len(batch_dims) + len(axis) + 1))
      contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
      # Constructing Z = (K^{'})^{T}V
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      Z = lax.dot_general(
          key_prime,
          value,
          ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
          precision=precision)
      # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
      # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      # W (bs,  <non-attention dims>, num_heads, <attention dims>, channels_v)
      W = lax.dot_general(
          query_prime,
          Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)),
          precision=precision)
      if not self.renormalize_attention:
        # Bidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Bidirectional, normalized attention.
        # All-ones tensor of shape (bs, <non-attention dims>, num_heads,
        # <attention dims>); its attention dimensions sit at the same
        # positions as those of K^{'}.
        thick_all_ones = jnp.ones(key.shape[0:-1])
        # Construct T = (K^{'})^{T} 1_L
        # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        T = lax.dot_general(
            key_prime,
            thick_all_ones, ((attention_dims_t, attention_dims_t),
                             (batch_dims_t, batch_dims_t)),
            precision=precision)

        # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
        # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
        # T   (bs, <non-attention dims>, num_heads, channels_m)
        # All dimensions of T except the last one are batch dimensions.
        R = lax.dot_general(
            query_prime,
            T, (((query_prime.ndim - 1,), (T.ndim - 1,)),
                (batch_dims_t, batch_dims_t)),
            precision=precision)

    # Push entries of R that are within `numerical_stabilizer` of zero away
    # from it before inverting.
    R = R + 2 * self.numerical_stabilizer * (
        jnp.abs(R) <= self.numerical_stabilizer)
    R = jnp.reciprocal(R)
    R = jnp.expand_dims(R, len(R.shape))
    # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
    # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
    result = W * R
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    result = result.transpose(perm_inv)
    return result


def _invert_perm(perm):
  """Returns the inverse of the permutation `perm` as a tuple.

  Entry `perm[i]` of the result is `i`, so transposing an array by `perm` and
  then by the result restores the original axis order.
  """
  inverse = [0] * len(perm)
  for position, target in enumerate(perm):
    inverse[target] = position
  return tuple(inverse)
