'''
Edited from the VMoE code.

Each experiment requires a different hook, hence different files.
However, since Flax modules cannot be extended easily (or maybe I'm too dumb to figure it out),
you need to copy the entire module and make the necessary changes.

There are three required components:
- The `prior_mixer` function, which mixes the prior attention into the gate (router) logits.
- The custom `attn_hook`, which decides which intermediate result of the attention computation
    is returned alongside the attention weights.
- The custom `CustomRouter`, which decides where the prior-incorporating function is called.
    `CustomRouter` should call `prior_mixer` by default.
'''

import functools
import warnings
from typing import Optional, Tuple, Callable, Union
import jax
from jax import lax
import jax.numpy as jnp
import vmoe.moe
import vmoe.utils
from vmoe.nn import vit_moe, routing
import flax.linen as nn
from flax.linen.linear import DenseGeneral
from flax.linen.dtypes import promote_dtype
from flax.linen.normalization import LayerNorm
from flax.linen.attention import combine_masks
from flax.linen.module import Module, compact, merge_param
from flax.typing import Array, PRNGKey, Dtype, PrecisionLike


####################
# CHANGE THIS PART #
####################


def prior_mixer(
    gates_logits: jnp.ndarray,
    prior_attn: jnp.ndarray,
    softmax_temperature: float = 1,
    detach: bool = False,
):
    """Hook that is meant to mix a token-to-token prior into the router logits.

    NOTE(review): the actual mixing step (``prior_attn @ gates_logits``) is
    disabled, so this function currently returns ``gates_logits`` unchanged.
    Earlier revisions computed a re-weighted, re-normalised copy of
    ``prior_attn`` here and then discarded it; that dead computation has been
    removed. Re-enable the mixing deliberately if the prior is supposed to
    influence routing.

    Args:
      gates_logits: pre-softmax router logits of shape
        ``(num_groups, group_size, num_experts)``.
      prior_attn: rank-3 prior whose first two axes are
        ``(bsz, seqs_length)``; it must satisfy
        ``bsz * seqs_length == num_groups * group_size``. The in-file caller
        passes a softmax-normalised ``(bsz, seqs_length, seqs_length)``
        token-similarity matrix.
      softmax_temperature: currently unused; kept for interface
        compatibility with the disabled entropy-weighting experiment.
      detach: currently has no effect because the prior is not mixed in; it
        was intended to stop gradients flowing into ``prior_attn``.

    Returns:
      An array equal to ``gates_logits`` with its original shape.
    """
    bsz, seqs_length, _ = prior_attn.shape
    num_experts = gates_logits.shape[-1]

    # Regroup the logits per sequence. This is where the prior would be applied
    # (e.g. ``prior_attn @ per_sequence_logits``); the reshape also fails loudly
    # when the number of routed tokens disagrees with the prior's leading axes.
    per_sequence_logits = gates_logits.reshape(bsz, seqs_length, num_experts)

    # Restore the (num_groups, group_size, num_experts) layout the router expects.
    return per_sequence_logits.reshape(gates_logits.shape)


# This is almost identical to `dot_product_attention_weights` in flax.linen.attention; the only
# difference is that it additionally returns the post-softmax, pre-dropout attention matrix.
def attn_hook(
    query: Array,
    key: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
    force_fp32_for_softmax: bool = False,
    einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
    """Computes dot-product attention weights, with and without dropout.

    The scaled query/key scores are optionally biased and masked, then
    normalised with a softmax. The softmax output is kept aside before dropout
    is applied so that callers can use the clean attention matrix as a prior.

    Args:
      query: queries with shape ``[batch..., q_length, num_heads,
        qk_depth_per_head]``.
      key: keys with shape ``[batch..., kv_length, num_heads,
        qk_depth_per_head]``.
      bias: optional additive bias, broadcastable to ``[batch..., num_heads,
        q_length, kv_length]``.
      mask: optional boolean mask, broadcastable to ``[batch..., num_heads,
        q_length, kv_length]``. Positions where the mask is ``False`` are
        replaced by the most negative finite value of the compute dtype before
        the softmax.
      broadcast_dropout: if True, one dropout pattern of shape
        ``[1..., q_length, kv_length]`` is shared across batch and head axes.
      dropout_rng: PRNG key used to sample the dropout mask.
      dropout_rate: probability of dropping an attention weight.
      deterministic: if True, dropout is skipped.
      dtype: dtype of the computation (default: inferred from the inputs).
      precision: numerical precision passed to the einsum.
      module: if given, the softmax output is sown into this module's
        'intermediates' collection under the name 'attention_weights'.
      force_fp32_for_softmax: if True and the compute dtype is not float32,
        the softmax is evaluated on a float32 copy of the scores (the result
        is not cast back).
      einsum_dot_general: the dot_general to use in the einsum.

    Returns:
      A pair ``(attn_weights, pre_dropout_attn_weights)``, both of shape
      ``[batch..., num_heads, q_length, kv_length]``. The first has dropout
      applied (when enabled); the second is the plain softmax output. When
      dropout is not applied the two are the same array.
    """
    query, key = promote_dtype(query, key, dtype=dtype)
    dtype = query.dtype

    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # Scale the queries by 1/sqrt(depth) and score every (query, key) pair.
    # Resulting shape: (batch..., num_heads, q_length, kv_length).
    head_depth = query.shape[-1]
    scaled_query = query / jnp.sqrt(head_depth).astype(dtype)
    scores = jnp.einsum(
        "...qhd,...khd->...hqk",
        scaled_query,
        key,
        precision=precision,
        _dot_general=einsum_dot_general,
    )

    if bias is not None:
        scores = scores + bias
    if mask is not None:
        # Masked-out positions get the lowest finite value so that they end up
        # with (near) zero probability after the softmax.
        scores = jnp.where(mask, scores, jnp.finfo(dtype).min)

    softmax_in_fp32 = force_fp32_for_softmax and dtype != jnp.float32
    if softmax_in_fp32:
        clean_weights = jax.nn.softmax(scores.astype(jnp.float32))
    else:
        clean_weights = jax.nn.softmax(scores).astype(dtype)

    if module:
        module.sow("intermediates", "attention_weights", clean_weights)

    # Without dropout both returned entries are the very same array.
    if deterministic or not dropout_rate > 0.0:
        return clean_weights, clean_weights

    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
        # One dropout pattern shared across the batch and head dimensions.
        noise_shape = (1,) * (key.ndim - 2) + clean_weights.shape[-2:]
    else:
        noise_shape = clean_weights.shape
    keep = jax.random.bernoulli(dropout_rng, keep_prob, noise_shape)  # type: ignore
    # Rescale the kept entries by 1 / keep_prob.
    rescale = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)

    return clean_weights * rescale, clean_weights


# add a prior incorporator in the logit calculations
# given it overrides function signature, mypy type ignore is needed
class CustomRouter(routing.NoisyTopExpertsPerItemRouter):
    """Noisy top-experts-per-item router whose logits pass through a prior mixer.

    The gating logits produced by the dense layer are handed to ``attn_mixer``
    together with the prior attention before any softmax or auxiliary loss is
    computed.
    """

    # Callable invoked as ``attn_mixer(gates_logits, prior_attn)``; its result
    # replaces the gating logits.
    attn_mixer: Callable[[Array, Array], Array] = prior_mixer

    @nn.compact
    def __call__(  # type: ignore
        self, inputs: Array, prior_attn: Array
    ) -> Tuple[routing.BaseDispatcher, routing.Metrics]:
        """Builds the dispatcher and routing metrics for ``inputs``.

        Args:
          inputs: rank-3 array of items to route.
          prior_attn: prior forwarded unchanged to ``attn_mixer``.

        Returns:
          A ``(dispatcher, metrics)`` pair.
        """
        softmax_gates, routing_metrics = self._compute_gates_softmax_and_metrics(
            inputs, prior_attn, self.num_experts
        )
        return self._create_dispatcher(softmax_gates), routing_metrics

    @nn.nowrap
    def _compute_gates_softmax_and_metrics(  # type: ignore
        self, inputs: Array, prior_attn: Array, num_experts: int
    ) -> Tuple[Array, routing.Metrics]:
        """Computes the gate probabilities and the auxiliary-loss metrics.

        Args:
          inputs: rank-3 array of items to route.
          prior_attn: prior passed to ``attn_mixer`` along with the logits.
          num_experts: number of experts, i.e. the width of the gating layer.

        Returns:
          ``(gates_softmax, metrics)``. When the router is deterministic or
          ``noise_std`` is zero, the gates are the softmax of the mixed logits
          and the metrics additionally carry the raw ``inputs`` ("features")
          and the arg-max expert per item ("clusters"). Otherwise the gates are
          the softmax of the noise-perturbed logits and the metrics include the
          load loss.

        Raises:
          ValueError: if ``inputs`` is not rank 3, or if
            ``num_experts >= num_selected_experts >= 1`` does not hold.
        """
        if inputs.ndim != 3:
            raise ValueError(f"inputs.ndim must be 3, but it is {inputs.ndim}")
        if not num_experts >= self.num_selected_experts >= 1:
            raise ValueError(
                f"num_experts >= num_selected_experts >= 1, but got "
                f"num_experts = {num_experts} and "
                f"num_selected_experts = {self.num_selected_experts}."
            )

        compute_dtype = self.dtype or inputs.dtype
        # One gating logit per (item, expert) pair.
        gating_layer = nn.Dense(
            features=num_experts, use_bias=False, dtype=compute_dtype, name="dense"
        )
        gates_logits = gating_layer(inputs)

        # The only departure from the parent router: let the prior reshape the
        # logits before anything else is derived from them.
        gates_logits = self.attn_mixer(gates_logits, prior_attn)

        # Auxiliary losses are evaluated per group, hence the vmaps.
        gates_softmax = jax.nn.softmax(gates_logits)
        importance_loss = jax.vmap(self._importance_auxiliary_loss)(gates_softmax)

        if self.deterministic or self.noise_std == 0.0:
            # Noise-free routing: no load loss can be computed.
            gshard_loss = jax.vmap(self._gshard_auxiliary_loss)(gates_softmax)
            auxiliary_loss = routing._weighted_sum(
                (self.gshard_loss_weight, gshard_loss),
                (self.importance_loss_weight, importance_loss),
            )
            noise_free_metrics = {
                "auxiliary_loss": auxiliary_loss,
                "gshard_loss": gshard_loss,
                "importance_loss": importance_loss,
                "features": inputs,
                "clusters": jnp.argmax(gates_logits, axis=-1),
            }
            return gates_softmax, noise_free_metrics

        # Stochastic routing: perturb the logits with Gaussian noise whose
        # standard deviation is scaled by 1 / num_experts.
        noise_std = (1.0 / num_experts) * self.noise_std
        gaussian_noise = jax.random.normal(
            key=self.make_rng("gating"), shape=gates_logits.shape
        )
        gates_logits_noisy = gates_logits + noise_std * gaussian_noise
        gates_softmax_noisy = jax.nn.softmax(gates_logits_noisy)

        per_group_load_loss = functools.partial(
            self._load_auxiliary_loss,
            num_selected_experts=self.num_selected_experts,
            noise_std=noise_std,
        )
        load_loss = jax.vmap(per_group_load_loss)(gates_logits, gates_logits_noisy)
        gshard_loss = jax.vmap(self._gshard_auxiliary_loss)(gates_softmax_noisy)

        auxiliary_loss = routing._weighted_sum(
            (self.gshard_loss_weight, gshard_loss),
            (self.importance_loss_weight, importance_loss),
            (self.load_loss_weight, load_loss),
        )
        noisy_metrics = {
            "auxiliary_loss": auxiliary_loss,
            "gshard_loss": gshard_loss,
            "importance_loss": importance_loss,
            "load_loss": load_loss,
        }
        return gates_softmax_noisy, noisy_metrics


####################
# BOILERPLATE CODE #
####################

# the differences are the custom expert router and anywhere that uses the
# prior attention matrix.
class PriorMlpMoeBlock(vit_moe.MlpMoeBlock):
    """MoE MLP block that routes tokens with ``CustomRouter`` and a prior.

    The prior handed to the router is a softmax-normalised token-to-token
    similarity matrix computed from the block's own inputs.
    """

    @nn.nowrap
    def create_router(self) -> nn.Module:
        """Instantiates the ``CustomRouter`` from the ``router`` config dict."""
        router_kwargs = dict(num_experts=self.num_experts, **(self.router or {}))
        # By default, the router will be deterministic during inference. But we
        # allow to override it.
        router_kwargs.setdefault("deterministic", self.deterministic)
        # The router class is fixed to CustomRouter here. A `name` entry in the
        # config (presumably a router-class selector, as in upstream vmoe
        # configs - TODO confirm) would collide with the module name passed
        # below and raise a duplicate-keyword TypeError, so it is discarded.
        router_kwargs.pop("name", None)

        # Create instance of the router class.
        return CustomRouter(
            dtype=self.dtype,
            name="Router",
            **router_kwargs,
        )

    @nn.compact
    def __call__(self, inputs, prior_attn=None):
        """Applies the sparse MoE of MLP blocks to ``inputs``.

        Args:
          inputs: array of shape ``(num_seqs, seq_length, hidden_size)``.
          prior_attn: accepted for API compatibility with the encoder block,
            but currently IGNORED: it is overwritten with a similarity matrix
            derived from ``inputs`` (see below).

        Returns:
          ``(outputs, metrics)`` where ``outputs`` has shape
          ``(num_seqs, seq_length, output_dim)`` and ``metrics`` comes from the
          router.
        """
        # Validate the rank first so that a bad input fails with this message
        # rather than with an opaque transpose/matmul error.
        assert inputs.ndim == 3, f"Expected ndim = 3, but got shape {inputs.shape}"

        # NOTE(review): the incoming `prior_attn` (attention weights from the
        # encoder block) is deliberately replaced by the row-wise softmax of
        # the token Gram matrix, shape (num_seqs, seq_length, seq_length).
        token_similarity = inputs @ jnp.transpose(inputs, (0, 2, 1))
        prior_attn = jax.nn.softmax(token_similarity, axis=-1)

        # Reshape inputs from (num_seqs, seq_length, hidden_size) to
        # (num_groups, groups_size, hidden_size).
        inputs_shape = inputs.shape
        inputs = inputs.reshape(-1, self.group_size, inputs.shape[-1])
        dispatcher, metrics = self.create_router()(inputs, prior_attn)
        # Use the dispatcher to apply a MoE of MlpBlocks.
        mlp_moe_layer = vmoe.moe.sparse_moe_spmd(
            vit_moe.MlpBlock,
            has_aux=False,
            variable_axes={"params": 0, "intermediates": 0},
            split_rngs=self.create_split_rngs(),
        )(
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            deterministic=self.deterministic,
            name="Mlp",
        )
        outputs = mlp_moe_layer(dispatcher, inputs)
        # Reshape outputs from (num_groups, group_size, output_dim) to
        # (num_seqs, seqs_length, output_dim).
        outputs = outputs.reshape(*inputs_shape[:-1], outputs.shape[-1])
        return outputs, metrics


# the only difference is that it returns additional attn_weights
class PriorMultiHeadDotProductAttention(nn.MultiHeadDotProductAttention):
    """Multi-head attention that also returns the attention function's weights.

    The configured ``attention_fn`` must return a pair
    ``(attended_values, attention_weights)`` instead of a single array; the
    second element is passed through to the caller untouched.
    """

    @compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_k: Optional[Array] = None,
        inputs_v: Optional[Array] = None,
        *,
        inputs_kv: Optional[Array] = None,
        mask: Optional[Array] = None,
        deterministic: Optional[bool] = None,
        dropout_rng: Optional[PRNGKey] = None,
        sow_weights: bool = False,
    ):
        """Applies multi-head dot-product attention on the input data.

        Args:
          inputs_q: input queries.
          inputs_k: key inputs; defaults to ``inputs_q`` when None.
          inputs_v: value inputs; defaults to ``inputs_k`` when None.
          inputs_kv: deprecated way to pass the same array as keys and values;
            mutually exclusive with ``inputs_k`` / ``inputs_v``.
          mask: optional attention mask forwarded to ``attention_fn``.
          deterministic: merged with the module attribute of the same name;
            only consulted when ``dropout_rate > 0``.
          dropout_rng: optional PRNG key for attention dropout; when missing
            and dropout is active, one is drawn from the 'dropout' stream.
          sow_weights: if True, this module is passed to ``attention_fn`` as
            ``module`` so that it can sow the attention weights.

        Returns:
          ``(out, attn_weights)``: the output projection of the attended
          values, and the second element returned by ``attention_fn``.

        Raises:
          ValueError: on inconsistent ``inputs_k`` / ``inputs_v`` /
            ``inputs_kv`` combinations, or on a cache/query shape mismatch
            while decoding.
          RuntimeError: if ``attention_fn`` returns a single array instead of
            an ``(output, attention_weights)`` pair.
        """
        if inputs_kv is not None:
            if inputs_k is not None or inputs_v is not None:
                raise ValueError(
                    "If either `inputs_k` or `inputs_v` is not None, "
                    "`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` "
                    "and `inputs_v` must be None. We recommend using `inputs_k` and "
                    "`inputs_v` args, since `inputs_kv` will be deprecated soon. See "
                    "https://github.com/google/flax/discussions/3389 for more "
                    "information."
                )
            inputs_k = inputs_v = inputs_kv
            warnings.warn(
                "The inputs_kv arg will be deprecated soon. "
                "Use inputs_k and inputs_v instead. See "
                "https://github.com/google/flax/discussions/3389 "
                "for more information.",
                DeprecationWarning,
            )
        else:
            if inputs_k is None:
                if inputs_v is not None:
                    raise ValueError(
                        "`inputs_k` cannot be None if `inputs_v` is not None. "
                        "To have both `inputs_k` and `inputs_v` be the same value, pass in the "
                        "value to `inputs_k` and leave `inputs_v` as None."
                    )
                inputs_k = inputs_q
            if inputs_v is None:
                inputs_v = inputs_k
            elif inputs_v.shape[-1] == inputs_v.shape[-2]:
                # A square trailing shape suggests a mask was passed
                # positionally where values are expected.
                warnings.warn(
                    f"You are passing an array of shape {inputs_v.shape} "
                    "to the `inputs_v` arg, when you may have intended "
                    "to pass it to the `mask` arg. As of Flax version "
                    "0.7.4, the function signature of "
                    "MultiHeadDotProductAttention's `__call__` method "
                    "has changed to `__call__(inputs_q, inputs_k=None, "
                    "inputs_v=None, *, inputs_kv=None, mask=None, "
                    "deterministic=None)`. Use the kwarg `mask` instead. "
                    "See https://github.com/google/flax/discussions/3389 "
                    "and read the docstring for more information.",
                    DeprecationWarning,
                )

        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            f"Memory dimension ({qkv_features}) must be divisible by number of"
            f" heads ({self.num_heads})."
        )
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(
            DenseGeneral,
            axis=-1,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            precision=self.precision,
            dot_general=self.qkv_dot_general,
            dot_general_cls=self.qkv_dot_general_cls,
        )
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (
            dense(name="query")(inputs_q),
            dense(name="key")(inputs_k),
            dense(name="value")(inputs_v),
        )

        if self.normalize_qk:
            # Normalizing query and key projections stabilizes training with higher
            # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
            query = LayerNorm(
                name="query_ln",
                use_bias=False,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
            )(query)  # type: ignore[call-arg]
            key = LayerNorm(
                name="key_ln",
                use_bias=False,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
            )(key)  # type: ignore[call-arg]

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, key.shape, key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, value.shape, value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                (
                    *batch_dims,
                    max_length,
                    num_heads,
                    depth_per_head,
                ) = cached_key.value.shape
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        "expected query shape %s instead got %s."
                        % (expected_shape, query.shape)
                    )
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
                indices: tuple[Union[int, jax.Array], ...] = (zero,) * len(
                    batch_dims
                ) + (
                    cur_index,
                    zero,
                    zero,
                )
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value, indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length),
                    ),
                )

        if self.dropout_rate > 0.0:  # Require `deterministic` only if using dropout.
            m_deterministic = merge_param(
                "deterministic", self.deterministic, deterministic
            )
            if not m_deterministic and dropout_rng is None:
                dropout_rng = self.make_rng("dropout")
        else:
            m_deterministic = True

        # apply attention; `module` is only supplied when the weights should be
        # sown, so attention functions without that parameter keep working.
        attention_kwargs = dict(
            mask=mask,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=m_deterministic,
            dtype=self.dtype,
            precision=self.precision,
        )
        if sow_weights:
            attention_kwargs["module"] = self
        attention_result = self.attention_fn(
            query, key, value, **attention_kwargs
        )  # pytype: disable=wrong-keyword-args

        # The attention function must hand back (values, weights). A bare array
        # means a stock attention function was configured by mistake.
        if isinstance(attention_result, jnp.ndarray):
            raise RuntimeError(
                "attention_fn must return an (output, attention_weights) pair, "
                "but it returned a single array."
            )
        x, attn_weights = attention_result

        # back to the original inputs dimensions
        out = DenseGeneral(
            features=features,
            axis=(-2, -1),
            kernel_init=self.out_kernel_init or self.kernel_init,
            bias_init=self.out_bias_init or self.bias_init,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            dot_general=self.out_dot_general,
            dot_general_cls=self.out_dot_general_cls,
            name="out",  # type: ignore[call-arg]
        )(x)

        # TODO(review): check whether the DenseGeneral above could also be
        # applied to the value matrix.
        return out, attn_weights


# the only difference is in the MultiHeadDotProductAttention creation and the self.mlp_block call
class PriorEncoderBlock(vit_moe.EncoderBlock):
    """Encoder block that forwards the attention weights to its MLP block.

    Self-attention is performed with ``PriorMultiHeadDotProductAttention`` and
    ``prior_dot_product_attention``; the pre-dropout attention weights they
    return are passed as a second positional argument to ``mlp_block``.
    """

    @nn.compact
    def __call__(self, inputs):
        """Runs attention + MLP with residual connections.

        Args:
          inputs: the block input; it is layer-normalised, attended over, and
            added back as a residual.

        Returns:
          The block output array when ``mlp_block`` returns an array, or an
          ``(output, metrics)`` pair when ``mlp_block`` returns a pair.
        """
        # Attention Block.
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        # Only `inputs_q` is passed: keys and values default to the queries,
        # which is the same computation as the deprecated `inputs_kv=x` form
        # but does not go through the DeprecationWarning path.
        x, pre_dropout_attn_weights = PriorMultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=self.deterministic,
            dropout_rate=self.attention_dropout_rate,
            normalize_qk=self.attention_qk_norm,
            num_heads=self.num_heads,
            attention_fn=prior_dot_product_attention,
            name="SelfAttention",
        )(inputs_q=x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(x)
        x = x + inputs
        # MLP-MoE block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = self.mlp_block(dtype=self.dtype, deterministic=self.deterministic)(
            y, pre_dropout_attn_weights
        )
        # Dense MLP blocks return a plain array; MoE blocks also return metrics.
        if isinstance(y, jnp.ndarray):
            return x + y
        y, metrics = y
        return x + y, metrics


# the only difference is that the __call__() function takes in an additional unused parameter to
# match the API of PriorMlpMoeBlock.
class PriorMlpBlock(vit_moe.MlpBlock):
    """Dense MLP block accepting the same two-argument call as the MoE block."""

    @nn.compact
    def __call__(self, inputs, _):
        """Applies the parent MLP block to ``inputs``.

        The second positional argument exists only so that dense and MoE
        blocks can be called the same way; its value is never read.
        """
        mlp_output = super().__call__(inputs)
        return mlp_output


# the only difference is in the references of the block classes
class PriorEncoderMoe(vit_moe.EncoderMoe):
    """Encoder stack assembled from the prior-aware block classes."""

    @nn.compact
    def __call__(self, inputs):
        """Encodes ``inputs`` and collects the per-layer MoE metrics.

        Args:
          inputs: rank-3 array.

        Returns:
          ``(encoded, metrics)``. ``metrics`` maps ``"encoderblock_<i>"`` to
          the metrics of every MoE layer and ``"auxiliary_loss"`` to the sum of
          their auxiliary losses.
        """
        assert inputs.ndim == 3, f"Expected ndim = 3, but got shape {inputs.shape}"
        hidden = self.add_position_emb(inputs)
        hidden = nn.Dropout(
            rate=self.dropout_rate, deterministic=self.deterministic
        )(hidden)

        # MoE layers start from the dense MLP settings and override them with
        # whatever the `moe` config provides; `layers` selects the MoE layers.
        dense_mlp_params = {
            "mlp_dim": self.mlp_dim,
            "dropout_rate": self.dropout_rate,
        }
        moe_mlp_params = {**dense_mlp_params, **(self.moe or {})}
        moe_mlp_layers = moe_mlp_params.pop("layers", ())

        dense_mlp_cls = vmoe.utils.partialclass(
            PriorMlpBlock, **dense_mlp_params, name="Mlp"
        )
        moe_mlp_cls = vmoe.utils.partialclass(
            PriorMlpMoeBlock, **moe_mlp_params, name="Moe"
        )
        encoder_block_cls = vmoe.utils.partialclass(
            PriorEncoderBlock,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            attention_qk_norm=self.attention_qk_norm,
            deterministic=self.deterministic,
            dtype=self.dtype,
        )

        metrics = {}
        for layer_index in range(self.num_layers):
            layer_name = f"encoderblock_{layer_index}"
            uses_moe = layer_index in moe_mlp_layers
            encoder_layer = encoder_block_cls(
                name=layer_name,
                mlp_block=moe_mlp_cls if uses_moe else dense_mlp_cls,
            )
            layer_output = encoder_layer(hidden)
            if uses_moe:
                # MoE layers return their routing metrics next to the output.
                hidden, metrics[layer_name] = layer_output
            else:
                hidden = layer_output

        encoded = nn.LayerNorm(name="encoder_norm")(hidden)
        # Sum auxiliary losses from all blocks.
        total_auxiliary_loss = sum(
            layer_metrics["auxiliary_loss"] for layer_metrics in metrics.values()
        )
        metrics["auxiliary_loss"] = total_auxiliary_loss
        return encoded, metrics


# This is almost identical to `dot_product_attention` in flax.linen.attention; the only difference
# is that it calls the custom `attn_hook` and returns an additional attention matrix.
def prior_dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
    force_fp32_for_softmax: bool = False,
    einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
    """Dot-product attention that also returns the pre-dropout weights.

    The attention weights are obtained from ``attn_hook`` and used to combine
    the values; the second value returned by ``attn_hook`` is handed back to
    the caller unchanged.

    Args:
      query: queries with shape ``[batch..., q_length, num_heads,
        qk_depth_per_head]``.
      key: keys with shape ``[batch..., kv_length, num_heads,
        qk_depth_per_head]``.
      value: values with shape ``[batch..., kv_length, num_heads,
        v_depth_per_head]``.
      bias: optional bias for the attention weights, broadcastable to
        ``[batch..., num_heads, q_length, kv_length]``.
      mask: optional mask for the attention weights, broadcastable to
        ``[batch..., num_heads, q_length, kv_length]``.
      broadcast_dropout: use a broadcasted dropout along batch dims.
      dropout_rng: PRNG key to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: whether to skip dropout.
      dtype: the dtype of the computation (default: infer from inputs).
      precision: numerical precision of the computation, see
        ``jax.lax.Precision`` for details.
      module: forwarded to ``attn_hook``, which sows the attention weights
        into this module when it is not None.
      force_fp32_for_softmax: forwarded to ``attn_hook``.
      einsum_dot_general: the dot_general to use in einsum.

    Returns:
      A pair ``(result, pre_dropout_attn_weights)``: ``result`` has shape
      ``[batch..., q_length, num_heads, v_depth_per_head]`` and the weights
      have shape ``[batch..., num_heads, q_length, kv_length]``.
    """
    query, key, value = promote_dtype(query, key, value, dtype=dtype)
    dtype = query.dtype

    same_rank = key.ndim == query.ndim == value.ndim
    assert same_rank, "q, k, v must have same rank."
    same_batch_dims = query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
    assert same_batch_dims, "q, k, v batch dims must match."
    same_num_heads = query.shape[-2] == key.shape[-2] == value.shape[-2]
    assert same_num_heads, "q, k, v num_heads must match."
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."

    # Attention weights with dropout applied, plus the clean softmax output.
    attn_weights, pre_dropout_attn_weights = attn_hook(
        query,
        key,
        bias=bias,
        mask=mask,
        broadcast_dropout=broadcast_dropout,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        deterministic=deterministic,
        dtype=dtype,
        precision=precision,
        module=module,
        force_fp32_for_softmax=force_fp32_for_softmax,
        einsum_dot_general=einsum_dot_general,
    )

    # Weighted sum over the values for each query position.
    attended_values = jnp.einsum(
        "...hqk,...khd->...qhd",
        attn_weights,
        value,
        precision=precision,
        _dot_general=einsum_dot_general,
    )
    return attended_values, pre_dropout_attn_weights
