"""
Relative positional embeddings added for each layer of Latte.
Notation:
    B = batch size, T = sequence length, D = embed/hidden dimension, H = number of heads
"""

from typing import Any, Dict
from typing import Tuple
import jax
from flax import linen as nn
from jax import numpy as jnp
from flax.linen.dtypes import promote_dtype
from .layers import Conv1D, DepthConv1D, S5Layer, RopeEmbeds, RMSNorm
from recurrentgemma.jax.layers import RGLRU
from latte_trans.config import Config
from .xPos import XPos
import math

# Short alias for JAX's associative scan, used by the `mix_sequence4` methods.
parallel_scan = jax.lax.associative_scan
# Matmul precision forwarded to the einsum calls that pass `precision=PRECISION`.
PRECISION = jax.lax.Precision.DEFAULT  # alternative: jax.lax.Precision.HIGHEST


def apply_rope(rel_pos, mat):
    """
    Rotate ``mat`` with precomputed sin/cos tables (batch-first layout).

    Args:
        rel_pos: jnp.array(T, D)
            Positional table; the first half of the last axis holds the sin
            values and the second half the cos values. Each half is repeated
            twice so that it spans the full last axis of ``mat``.
        mat: jnp.array(B, H, T, D)
            Input to rotate.
    Returns:
        jnp.array(B, H, T, D): ``cos * mat + sin * swapped`` where ``swapped``
        is ``[-mat[..., 1::2], mat[..., 0::2]]`` concatenated on the last axis.
    """
    sin_half, cos_half = jnp.split(rel_pos, indices_or_sections=2, axis=-1)
    sin_full = jnp.tile(sin_half, reps=(1, 2))
    cos_full = jnp.tile(cos_half, reps=(1, 2))

    # Negated odd lanes first, then the even lanes unchanged.
    odd_lanes = mat[..., 1::2]
    even_lanes = mat[..., 0::2]
    swapped = jnp.concatenate([-odd_lanes, even_lanes], axis=-1)

    # The (T, D) tables are broadcast over the batch and head axes.
    cos_term = jnp.einsum("BHTD,TD->BHTD", mat, cos_full)
    sin_term = jnp.einsum("BHTD,TD->BHTD", swapped, sin_full)
    return cos_term + sin_term


def apply_vapor(rel_pos, mat, neg=False):
    """
    Rotate ``mat`` with precomputed sin/cos tables (time-first layout).

    Args:
        rel_pos: jnp.array(T, D)
            Positional table; the first half of the last axis holds the sin
            values and the second half the cos values. Each half is repeated
            twice so that it spans the full last axis of ``mat``.
        mat: jnp.array(T, B, H, D)
            Input to rotate.
        neg: bool
            When True the sign of the sin term is flipped, i.e. the inverse
            rotation R^{-s} is applied instead of R^{s}.
    Returns:
        jnp.array(T, B, H, D): ``cos * mat + sin * swapped``.
    """
    sin_half, cos_half = jnp.split(rel_pos, indices_or_sections=2, axis=-1)
    sin_full = jnp.tile(sin_half, reps=(1, 2))
    cos_full = jnp.tile(cos_half, reps=(1, 2))

    odd_lanes = mat[..., 1::2]
    even_lanes = mat[..., 0::2]
    # The two branches differ only in which group of lanes is negated.
    if neg:
        swapped = jnp.concatenate([odd_lanes, -even_lanes], axis=-1)
    else:
        swapped = jnp.concatenate([-odd_lanes, even_lanes], axis=-1)

    # The (T, D) tables are broadcast over the batch and head axes.
    cos_term = jnp.einsum("TBHD,TD->TBHD", mat, cos_full)
    sin_term = jnp.einsum("TBHD,TD->TBHD", swapped, sin_full)
    return cos_term + sin_term


class RotCausalScanLatte(nn.Module):
    """
    Numerically stable causal latent attention.

    Notation: B = batch size, T = sequence length, H = number of heads,
    L = latent dimension per head, D = value dimension per head.

    Attributes:
        config: model configuration. This class reads ``embed_type``,
            ``pos_embed_max_len``, ``max_seq_len``, ``hidden_dim``, ``nheads``,
            ``L``, ``nlayers``, ``initializer_range``, ``dropout_att`` and
            ``dropout`` from it.
        unroll: unroll factor forwarded to ``jax.lax.scan`` in ``mix_sequence``.
        dtype: dtype the parameters are promoted to and in which the scan
            state of ``mix_sequence`` is created.
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def accumulate(carry, args):
        """One sequential step of the stabilised recurrence (``jax.lax.scan`` body).

        Args:
            carry: tuple ``(nu, alpha, prev_max)`` with shapes (B,H,L,D),
                (B,H,L) and (B,H,L): running weighted value sum, running
                normaliser and the maximum both are currently scaled by.
            args: tuple ``(Qs_t, curr_alph, V_t, c_mx)`` for the current step:
                query weights (B,H,L), key logits (B,H,L), values (B,H,D) and
                the new running maximum (B,H,L).
        Returns:
            ``((nu, alpha, c_mx), y)`` with ``y`` of shape (B,H,D).
        """
        nu, alpha, prev_max = carry
        Qs_t, curr_alph, V_t, c_mx = args
        # Factor that moves the old state from the previous maximum to the new one.
        revert_maxi = jnp.exp(-c_mx + prev_max)
        # Weight of the current key relative to the new maximum.
        add_maxi = jnp.exp(curr_alph - c_mx)

        alpha = jnp.einsum("BHL,BHL->BHL", alpha, revert_maxi)
        alpha += add_maxi
        nu = jnp.einsum("BHLD,BHL->BHLD", nu, revert_maxi)
        nu += jnp.einsum("BHL,BHD->BHLD", add_maxi, V_t)
        y = jnp.einsum("BHL,BHLD->BHD", Qs_t / alpha, nu)
        return ((nu, alpha, c_mx), y)

    @staticmethod
    def accumulate3(carry, args):
        """Scan body variant with the rescaling factors precomputed outside the loop.

        Only the value accumulator is carried. No division by a normaliser
        happens here, so any normalisation has to be folded into ``Qs_t`` by
        the caller. Not referenced by the other methods of this class.

        Args:
            carry: ``nu`` of shape (B,H,L,D).
            args: tuple ``(Qs_t, V_t, revert_maxi, add_maxi)`` with shapes
                (B,H,L), (B,H,D), (B,H,L) and (B,H,L).
        Returns:
            ``(nu, y)`` with ``y`` of shape (B,H,D).
        """
        nu = carry
        Qs_t, V_t, revert_maxi, add_maxi = args
        nu = jnp.einsum("BHLD,BHL->BHLD", nu, revert_maxi)
        nu += jnp.einsum("BHL,BHD->BHLD", add_maxi, V_t)
        y = jnp.einsum("BHL,BHLD->BHD", Qs_t, nu)
        return nu, y

    def mix_sequence(self, Q, K, V, Q_drop, rot_embeds):
        """Sequential (``jax.lax.scan``) implementation of causal latent attention.

        Args:
            Q: jax.Array(T,B,H,L) query logits; softmax is taken over L.
            K: jax.Array(T,B,H,L) key logits.
            V: jax.Array(T,B,H,D) values.
            Q_drop: callable applied to the softmaxed queries (dropout).
            rot_embeds: currently unused, see the NOTE below.
        Returns:
            ``(y, Qs)``: ``y`` of shape (B, T, H*D) and the (dropped-out)
            softmaxed queries of shape (T,B,H,L).
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # The running max is only a stabiliser: treat it as a constant (no grad is faster).
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)
        # NOTE: applying the rotary embedding to V before the scan (R^{-s} x_s)
        # and to y after it (R^{t} sum_l ...) is currently disabled, which is
        # why `rot_embeds` is not referenced here.

        _, y = jax.lax.scan(
            self.accumulate,
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                # maxi[0] == K[0], so the first rescaling factor is exp(0) = 1.
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD -> BTHD, then merge the heads.
        y = y.transpose(1, 0, 2, 3)
        y = y.reshape(B, T, -1)
        return y, Qs

    def mix_sequence4(self, Q, K, V, Q_drop, rot_embeds):
        """Fastest O(TLD) implementation, based on an associative (parallel) scan.

        Args:
            Q: jax.Array(T,B,H,L) query logits; softmax is taken over L.
            K: jax.Array(T,B,H,L) key logits.
            V: jax.Array(T,B,H,D) values.
            Q_drop: callable applied to the softmaxed queries (dropout).
            rot_embeds: currently unused, see the NOTE below.
        Returns:
            ``(y, Qs)``: ``y`` of shape (B, T, H*D) and the (dropped-out)
            softmaxed queries of shape (T,B,H,L).
        """
        T, B, _, _ = V.shape
        # NOTE: applying the rotary embedding to V before the scan (R^{-s} x_s)
        # and to y after it (R^{t} sum_l ...) is currently disabled, which is
        # why `rot_embeds` is not referenced here.
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        maxi = jax.lax.cummax(K, axis=0)
        # The running max is only a stabiliser: treat it as a constant (no grad is faster).
        maxi = jax.lax.stop_gradient(maxi)
        # Per-step factor exp(max_{t-1} - max_t); the first step keeps exp(0) = 1.
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(left, right):
            """Combine two scan elements ``(rescale, normaliser, value sum)``."""
            rm_left, am_left, nu_left = left
            rm_right, am_right, nu_right = right
            nu_out = nu_left * rm_right[..., None] + nu_right
            alpha_out = am_left * rm_right + am_right
            return (rm_left * rm_right, alpha_out, nu_out)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Qs / alpha, y)

        # TBHD -> BTHD, then merge the heads.
        y = y.transpose(1, 0, 2, 3)
        y = y.reshape(B, T, -1)
        return y, Qs

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            train: bool - Constant used for dropout
            cache: forwarded to ``self.inference`` when ``do_inference`` is set.
            do_inference: bool - delegate to ``self.inference`` instead of
                running the parallel scan.
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        # Default so the name is always bound: an `embed_type` other than
        # "rope"/"xpos" used to raise UnboundLocalError at the mix_sequence4
        # call below, even though the embedding is not applied there.
        rot_embeds = None
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
                dtype=self.dtype,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
                dtype=self.dtype,
            )
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        # The output projection uses a std scaled down by sqrt(2 * nlayers).
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.normal(
                stddev=self.config.initializer_range
                / math.sqrt(2 * self.config.nlayers)
            ),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Wk, Wq, Wv, o_proj = promote_dtype(Wk, Wq, Wv, o_proj, dtype=self.dtype)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        V_drop = None  # dropout on the values is disabled
        resid_drop = nn.Dropout(self.config.dropout, deterministic=not train)
        if do_inference:
            # NOTE(review): `inference` is not defined in the visible part of
            # this class; confirm it is provided before enabling this path.
            return self.inference(
                X,
                cache=cache,
                Wq=Wq,
                Wk=Wk,
                Wv=Wv,
                o_proj=o_proj,
                Q_drop=Q_drop,
                V_drop=V_drop,
                resid_drop=resid_drop,
            )

        B, T, _ = X.shape
        H = self.config.nheads
        # Multi-head implementation: project, move time first, then split the heads.
        V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, X).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, H, -1)

        y, _ = self.mix_sequence4(Q=Q, K=K, V=V, Q_drop=Q_drop, rot_embeds=rot_embeds)
        return y @ o_proj


class RotConvAllCausalScanLatte(RotCausalScanLatte):
    """
    Convolve the entire input before applying latte with projections.

    Q, K and V are all projected from the output of a kernel-size-3 ``Conv1D``
    over the input; the sequence mixing itself is inherited from
    ``RotCausalScanLatte.mix_sequence4``.
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            train: bool - Constant used for dropout
            cache: forwarded to ``self.inference`` when ``do_inference`` is set.
            do_inference: bool - delegate to ``self.inference`` instead of
                running the parallel scan.
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        # Default so the name is always bound: an `embed_type` other than
        # "rope"/"xpos" used to raise UnboundLocalError at the mix_sequence4
        # call below, even though the embedding is not applied there.
        rot_embeds = None
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
                dtype=self.dtype,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
                dtype=self.dtype,
            )
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        # The output projection uses a std scaled down by sqrt(2 * nlayers).
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.normal(
                stddev=self.config.initializer_range
                / math.sqrt(2 * self.config.nlayers)
            ),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        conv = Conv1D(
            nchannels=self.config.hidden_dim,
            out_channels=self.config.hidden_dim,
            kernel_size=3,
            dtype=self.dtype,
        )
        Wk, Wq, Wv, o_proj = promote_dtype(Wk, Wq, Wv, o_proj, dtype=self.dtype)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        V_drop = None  # dropout on the values is disabled
        resid_drop = nn.Dropout(self.config.dropout, deterministic=not train)
        if do_inference:
            # NOTE(review): `inference` is not defined in the visible part of
            # this class hierarchy; confirm it is provided before enabling
            # this path. Note that `conv` is not applied on this branch.
            return self.inference(
                X,
                cache=cache,
                Wq=Wq,
                Wk=Wk,
                Wv=Wv,
                o_proj=o_proj,
                Q_drop=Q_drop,
                V_drop=V_drop,
                resid_drop=resid_drop,
            )

        B, T, _ = X.shape
        H = self.config.nheads
        # Multi-head implementation: every projection reads the convolved input.
        Y = conv(X)
        V = jnp.einsum("DM,BTD->TBM", Wv, Y).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, Y).reshape(T, B, H, -1)

        y, Qs = self.mix_sequence4(Q=Q, K=K, V=V, Q_drop=Q_drop, rot_embeds=rot_embeds)
        # Expose the query weights and key logits for inspection.
        self.sow("intermediates", "Qs", Qs)
        self.sow("intermediates", "K", K)
        return y @ o_proj


class MyDropout(nn.Module):
    """Dropout variant that renormalises the result along the last axis.

    Unlike ``flax.linen.Dropout`` the kept entries are not scaled by
    ``1 / (1 - rate)``. Instead the dropped entries are set to zero, ``1e-8``
    is added to every entry, and the result is divided by its sum over the
    last axis, so each slice along the last axis sums to one (dropped entries
    end up tiny but not exactly zero).

    .. note::
      When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
      to include an RNG seed named ``'dropout'`` (or pass ``rng`` explicitly).

    Attributes:
      rate: the dropout probability.  (_not_ the keep rate!)
      deterministic: default for the ``deterministic`` call argument; used
        when that argument is left as ``None``.
      rng_collection: the rng collection name to use when requesting an rng key.
    """

    rate: float
    deterministic: bool = None
    rng_collection: str = "dropout"

    @nn.compact
    def __call__(
        self,
        inputs,
        deterministic: bool = None,
        rng: jax.random.PRNGKey = None,
    ):
        """Applies a random dropout mask to the input and renormalises it.

        Args:
          inputs: the inputs that should be randomly masked.
          deterministic: if true, no mask is applied and the inputs are
            returned as is. When ``None`` the module attribute of the same
            name is used; if that is ``None`` as well, dropout is applied.
          rng: an optional PRNGKey used as the random key, if not specified, one
            will be generated using ``make_rng`` with the ``rng_collection`` name.

        Returns:
          ``inputs`` unchanged when dropout is inactive, zeros when
          ``rate == 1.0``, otherwise the masked inputs renormalised to sum to
          one along the last axis.
        """
        # Honour the module-level setting; previously it was silently ignored.
        if deterministic is None:
            deterministic = self.deterministic

        if (self.rate == 0.0) or deterministic:
            return inputs

        # Prevent gradient NaNs in 1.0 edge-case.
        if self.rate == 1.0:
            return jnp.zeros_like(inputs)

        keep_prob = 1.0 - self.rate
        if rng is None:
            rng = self.make_rng(self.rng_collection)
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        # The epsilon keeps the sum below non-zero even if a whole slice is dropped.
        masked = jax.lax.select(mask, inputs, jnp.zeros_like(inputs)) + 1e-8
        masked /= masked.sum(axis=-1, keepdims=True)
        return masked


class RotConvQKCausalScanLatte(nn.Module):
    """
    Pre-mix only the inputs of Q and K before applying latte with projections.

    The input is projected to width ``config.L`` by ``lru_in``, passed through
    an ``RGLRU`` recurrence and an ``RMSNorm``; Q and K are projected from that
    result, while V is projected directly from the input.

    Attributes:
        config: model configuration.
        unroll: not used by the methods of this class.
        rot_embeds: fallback embedding object used when ``config.embed_type``
            is neither "rope" nor "xpos".
        dtype: dtype the parameters and submodules operate in.
    """

    config: Config
    unroll: int = 100
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    def mix_sequence4(self, Qs, K, V, rot_embeds):
        """Fastest O(TLD) implementation, based on an associative (parallel) scan.

        Args:
            Qs: jax.Array(T,B,H,L) query weights, already softmaxed over L.
            K: jax.Array(T,B,H,L) key logits.
            V: jax.Array(T,B,H,D) values.
            rot_embeds: not referenced by this method.
        Returns:
            ``(y, Qs)``: ``y`` of shape (B, T, H*D) and the unchanged ``Qs``.
        """
        T, B, _, _ = V.shape

        maxi = jax.lax.cummax(K, axis=0)
        # The running max is only a stabiliser: treat it as a constant (no grad is faster).
        maxi = jax.lax.stop_gradient(maxi)
        # Per-step factor exp(max_{t-1} - max_t); the first step keeps exp(0) = 1.
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(left, right):
            """Combine two scan elements ``(rescale, normaliser, value sum)``."""
            rm_left, am_left, nu_left = left
            rm_right, am_right, nu_right = right
            nu_out = nu_left * rm_right[..., None] + nu_right
            alpha_out = am_left * rm_right + am_right
            return (rm_left * rm_right, alpha_out, nu_out)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Qs / alpha, y)
        # TBHD -> BTHD, then merge the heads.
        y = y.transpose(1, 0, 2, 3)
        y = y.reshape(B, T, -1)
        return y, Qs

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            train: bool - accepted for interface compatibility; not used here
                because the query dropout is disabled.
            cache: accepted for interface compatibility; not used.
            do_inference: accepted for interface compatibility; not used.
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        # Fall back to the module field so the name is always bound: an
        # `embed_type` other than "rope"/"xpos" used to raise
        # UnboundLocalError at the mix_sequence4 call below.
        rot_embeds = self.rot_embeds
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
                dtype=self.dtype,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
                dtype=self.dtype,
            )
        # Wk / Wq act on the RG-LRU output (width config.L), hence (L, L)
        # instead of (hidden_dim, L).
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.L, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.L, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        # NOTE: this output projection uses the plain initializer_range, without
        # the 1 / sqrt(2 * nlayers) scaling.
        o_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )

        # Input projection feeding the recurrence (hidden_dim -> config.L).
        lru_in = self.param(
            "lru_in",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        conv = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
        )

        Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)
        # NOTE: dropout on the query weights is currently disabled; the module
        # is still constructed, but it is not applied to Qs below.
        Q_drop = nn.Dropout(self.config.dropout_att)  # MyDropout(rate=0.1)

        B, T, _ = X.shape
        H = self.config.nheads
        # Multi-head implementation: V comes straight from the input.
        V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, H, -1)
        # Positions 0..T-1 for every batch element, shape (B, T).
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        Y, _ = conv(jnp.einsum("DL,BTD->BTL", lru_in, X), pos_ids, return_cache=False)
        Y = RMSNorm(self.config.L, dtype=self.dtype)(Y)

        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, Y).reshape(T, B, H, -1)

        Qs = jax.nn.softmax(Q, axis=-1)
        y, _ = self.mix_sequence4(Qs=Qs, K=K, V=V, rot_embeds=rot_embeds)
        return o_proj(y)


class RotBidLatte(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.

    Attributes:
        config: model configuration. This class reads ``embed_type``,
            ``pos_embed_max_len``, ``max_seq_len``, ``hidden_dim``, ``nheads``,
            ``L``, ``nlayers``, ``initializer_range`` and ``dropout_att``.
        dtype: dtype the parameters are promoted to.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: boolean-like mask used to ignore pads; it is
                applied to the key logits while they have shape (L, B, T), so
                it must broadcast against that (presumably (B, T) - TODO
                confirm against the caller). Only needed in the bidirectional
                case since we sum up to T and padding is needed for batching.
                When None, no masking is applied.
            train: bool - Constant used for dropout
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        # NOTE(review): unlike the causal classes, no `dtype` is forwarded to
        # the embedding modules here; left unchanged because the embedding is
        # actually applied in this class.
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
            )
        else:
            # The embedding is applied below, so fail with a clear message
            # instead of an UnboundLocalError later on.
            raise ValueError(
                f"Unsupported embed_type: {self.config.embed_type!r}; "
                'expected "rope" or "xpos"'
            )
        B, T, _ = X.shape
        H = self.config.nheads

        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        # The shape is an argument of `self.param`, not of the initializer; the
        # parameter also has to exist before it is passed to `promote_dtype`.
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.normal(
                stddev=self.config.initializer_range
                / math.sqrt(2 * self.config.nlayers)
            ),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Wk, Wq, Wv, o_proj = promote_dtype(Wk, Wq, Wv, o_proj, dtype=self.dtype)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)

        # Multi-head implementation.
        V = jnp.einsum("DM,BTD->TBM", Wv, X, precision=PRECISION).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, X, precision=PRECISION).reshape(T, B, H, -1)
        # K is kept as (L, B, T) until the padding mask has been applied.
        K = jnp.einsum("DL,BTD->LBT", Wk, X, precision=PRECISION)

        # calc R^{-s} x_s
        V = rot_embeds.apply_vapor(mat=V, neg=True)

        if attention_mask is not None:
            # Masked positions get a very negative logit so exp() maps them to ~0.
            K = jnp.where(attention_mask, K, -9e15)
        # LBT -> TBL, then split the heads.
        K = K.transpose(2, 1, 0).reshape(T, B, H, -1)
        Qs = jax.nn.softmax(Q, axis=-1)  # T B H L
        Qs = Q_drop(Qs)
        # Subtract the max over time before exponentiating, for stability.
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V, precision=PRECISION)
        # Normalise by the sum of the key weights over time.
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv, precision=PRECISION)
        y = jnp.einsum("TBHL,BHLD->TBHD", Qs, Kv, precision=PRECISION)

        # calc R^t \sum_l ...
        y = rot_embeds.apply_vapor(mat=y, neg=False)
        # TBHD -> BTHD, then merge the heads.
        y = y.transpose(1, 0, 2, 3)
        y = y.reshape(B, T, -1)
        return y @ o_proj
