from typing import Type
from functools import partial
import math
import jax
from flax import linen as nn
from jax import numpy as jnp, lax
from .xPos import XPos
from .layers import RopeEmbeds
from latte_trans.config import Config


def apply_rotation(rel_pos, k, q):
    """
    Apply a positional-embedding callable to the key and query tensors.

    Args:
        rel_pos: positional-embedding callable. An ``XPos`` instance is called
            with ``offset=0`` and ``downscale=True`` for keys /
            ``downscale=False`` for queries; any other callable is invoked
            with the tensor as its only argument.
        k, q: key and query tensors. The callers in this file pass arrays
            laid out as (B, nh, T, hs).
    Returns:
        Tuple ``(k, q)`` -- keys first, then queries -- after the embedding
        has been applied to each.
    """
    if not isinstance(rel_pos, XPos):
        # Generic callable: the query is transformed before the key.
        rotated_q = rel_pos(q)
        rotated_k = rel_pos(k)
        return rotated_k, rotated_q

    # XPos: the key is transformed first and is the only one downscaled.
    rotated_k = rel_pos(k, offset=0, downscale=True)
    rotated_q = rel_pos(q, offset=0, downscale=False)
    return rotated_k, rotated_q


def attention(attn_dropout, q, k, v, mask):
    """
    Scaled dot-product attention with masking and dropout on the weights.

    Args:
        attn_dropout: callable applied to the attention weights after softmax.
        q, k, v: query, key and value arrays; the last two axes of ``q`` and
            ``k`` are (positions, head size), leading axes are batched.
        mask: array broadcastable to the score matrix; entries equal to 0
            mark positions that must not be attended to.
    Returns:
        Attention-weighted sum of ``v`` for every query position.
    """
    head_size = k.shape[-1]
    scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(head_size))

    # Masked positions get a very large negative score so that softmax assigns
    # them (numerically) zero weight.
    scores = jnp.where(mask == 0, -9e15, scores)
    weights = attn_dropout(jax.nn.softmax(scores, axis=-1))

    # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs)
    return jnp.matmul(weights, v)


class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention restricted by a lower-triangular (causal) mask.

    Attributes:
        config: model configuration; this module reads ``hidden_dim``,
            ``nheads``, ``nlayers``, ``initializer_range``, ``dropout_att``
            and ``dropout``.
        dtype: dtype forwarded to the ``nn.Dense`` projections.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """
        Args:
            X: (B, T, C) input sequence.
            train: when True the attention dropout is active.
            **kwargs: accepted and ignored.
        Returns:
            (B, T, C) output of the final projection.
        """
        cfg = self.config
        n_heads = cfg.nheads
        deterministic = not train

        # fused query/key/value projection for all heads
        qkv_proj = nn.Dense(
            3 * cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        # output projection, initialised with a depth-scaled standard deviation
        out_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
            ),
        )

        weights_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=deterministic)
        # The residual dropout layer is still constructed, but it is not
        # applied to the output below.
        resid_dropout = nn.Dropout(rate=cfg.dropout, deterministic=deterministic)

        batch, seq_len, width = X.shape
        head_size = width // n_heads

        # lower-triangular mask: position t may only attend to positions <= t
        causal = jnp.tril(jnp.ones(shape=(seq_len, seq_len)))
        causal = causal.reshape(1, 1, seq_len, seq_len).astype(jnp.bool_)

        q, k, v = jnp.split(qkv_proj(X), 3, axis=2)
        # (B, T, C) -> (B, nh, T, hs) for each of key, query and value
        k = k.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)
        q = q.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)
        v = v.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)

        context = attention(weights_dropout, q, k, v, mask=causal)
        # put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
        context = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, width)

        return out_proj(context)


################# Attention  #################
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries come from ``tgt``, keys and values
    from ``src``.

    Attributes:
        config: model configuration; this module reads ``hidden_dim``,
            ``nheads``, ``dropout_att`` and ``dropout``.
        dtype: dtype forwarded to every ``nn.Dense`` projection.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        src: jnp.array,
        src_att_mask: jnp.array,
        tgt: jnp.array,
        train: bool,
        **kwargs,
    ) -> jnp.array:
        """
        Attend from every target position to the unmasked source positions.

        Args:
            src: (B, S, C) source sequence, projected to keys and values.
            src_att_mask: (B, S) source mask; entries equal to 0 are excluded
                from attention.
            tgt: target sequence with T positions on axis 1, projected to
                queries.
            train: when True both dropout layers are active.
            **kwargs: accepted and ignored.
        Returns:
            (B, T, C) attended target representation after the output
            projection and residual dropout.
        """
        cfg = self.config

        # query / key / value projections for all heads at once
        Wq = nn.Dense(cfg.hidden_dim, use_bias=False, dtype=self.dtype)
        Wk = nn.Dense(cfg.hidden_dim, use_bias=False, dtype=self.dtype)
        Wv = nn.Dense(cfg.hidden_dim, use_bias=False, dtype=self.dtype)

        # output projection
        c_proj = nn.Dense(cfg.hidden_dim, use_bias=False, dtype=self.dtype)

        # regularization
        attn_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=not train)
        resid_dropout = nn.Dropout(rate=cfg.dropout, deterministic=not train)

        B, S, C = src.shape  # batch size, source length, embedding width
        T = tgt.shape[1]  # target length
        head_size = C // cfg.nheads

        # Only the source mask restricts attention here (no causal mask is
        # applied). Expand it (B, S) -> (B, 1, T, S); the head axis is left
        # to broadcasting inside `attention`.
        mask = jnp.broadcast_to(src_att_mask[:, None, None, :], shape=(B, 1, T, S))

        q, k, v = Wq(tgt), Wk(src), Wv(src)

        # The reshapes use the source width C, so the projections
        # (hidden_dim wide) must match it.
        q = q.reshape(B, T, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        k = k.reshape(B, S, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, S, hs)
        v = v.reshape(B, S, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, S, hs)

        y = attention(attn_dropout, q, k, v, mask=mask)
        # re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, C)

        # output projection
        return resid_dropout(c_proj(y))


class BidirectionalAttention(nn.Module):
    """
    Multi-head self-attention in which every position may attend to every
    unmasked position of the same sequence.

    Attributes:
        config: model configuration; this module reads ``hidden_dim``,
            ``nheads``, ``nlayers``, ``initializer_range``, ``dropout_att``
            and ``dropout``.
        dtype: dtype forwarded to every ``nn.Dense`` projection.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, src: jnp.array, train: bool, attention_mask: jnp.array, **kwargs
    ) -> jnp.array:
        """
        Args:
            src: (B, S, C) input sequence.
            train: when True both dropout layers are active.
            attention_mask: (B, S) mask; entries equal to 0 are excluded from
                attention.
            **kwargs: accepted and ignored.
        Returns:
            (B, S, C) output after projection and residual dropout.
        """
        cfg = self.config
        n_heads = cfg.nheads
        deterministic = not train

        # every projection shares width, bias setting and dtype
        dense = partial(nn.Dense, cfg.hidden_dim, use_bias=False, dtype=self.dtype)
        in_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        out_init = jax.nn.initializers.normal(
            stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
        )

        query_proj = dense(kernel_init=in_init)
        key_proj = dense(kernel_init=in_init)
        value_proj = dense(kernel_init=in_init)
        out_proj = dense(kernel_init=out_init)

        weights_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=deterministic)
        output_dropout = nn.Dropout(rate=cfg.dropout, deterministic=deterministic)

        batch, seq_len, width = src.shape
        head_size = width // n_heads

        # padding mask (B, S) -> (B, 1, S, S)
        pad_mask = jnp.broadcast_to(
            attention_mask[:, None, None, :], shape=(batch, 1, seq_len, seq_len)
        )

        q = query_proj(src)
        k = key_proj(src)
        v = value_proj(src)

        # (B, S, C) -> (B, nh, S, hs)
        q = q.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)
        k = k.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)
        v = v.reshape(batch, seq_len, n_heads, head_size).transpose(0, 2, 1, 3)

        context = attention(weights_dropout, q, k, v, mask=pad_mask)
        # heads back side by side: (B, nh, S, hs) -> (B, S, C)
        context = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, width)

        return output_dropout(out_proj(context))


class RopeCrossAtt(nn.Module):
    """
    Multi-head cross-attention with a rotary-style positional embedding
    applied to queries and keys.

    Attributes:
        config: model configuration; this module reads ``embed_type``,
            ``hidden_dim``, ``nheads``, ``nlayers``, ``initializer_range``,
            ``dropout_att``, ``dropout`` and, depending on ``embed_type``,
            ``pos_embed_max_len`` ("rope") or ``max_seq_len`` ("xpos").
        dtype: dtype forwarded to every ``nn.Dense`` projection.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        src: jnp.array,
        src_att_mask: jnp.array,
        tgt: jnp.array,
        train: bool,
        **kwargs,
    ) -> jnp.array:
        """
        Args:
            src: (B, S, C) source sequence, projected to keys and values.
            src_att_mask: (B, S) source mask; entries equal to 0 are excluded
                from attention.
            tgt: target sequence with T positions on axis 1, projected to
                queries.
            train: when True both dropout layers are active.
            **kwargs: accepted and ignored.
        Returns:
            (B, T, C) attended target representation after the output
            projection and residual dropout.
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        cfg = self.config

        if cfg.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=cfg.pos_embed_max_len,
                d_model=cfg.hidden_dim // cfg.nheads,
            )
        elif cfg.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=cfg.hidden_dim // cfg.nheads,
                scale_base=cfg.max_seq_len,
            )
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # when the embedding is first used.
            raise ValueError(
                f"Unsupported embed_type {cfg.embed_type!r}; expected 'rope' or 'xpos'."
            )

        # query / key / value projections for all heads at once
        Wq = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        Wk = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        Wv = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )

        # output projection, initialised with a depth-scaled standard deviation
        c_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
            ),
        )

        # regularization
        attn_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=not train)
        resid_dropout = nn.Dropout(rate=cfg.dropout, deterministic=not train)

        B, S, C = src.shape  # batch size, source length, embedding width
        T = tgt.shape[1]  # target length
        head_size = C // cfg.nheads

        # Only the source mask restricts attention here (no causal mask is
        # applied). Expand it (B, S) -> (B, 1, T, S).
        mask = jnp.broadcast_to(src_att_mask[:, None, None, :], shape=(B, 1, T, S))

        q, k, v = Wq(tgt), Wk(src), Wv(src)

        q = q.reshape(B, T, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        k = k.reshape(B, S, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, S, hs)
        v = v.reshape(B, S, cfg.nheads, head_size).transpose(
            0, 2, 1, 3
        )  # (B, nh, S, hs)

        # positional embedding on keys and queries only; values are untouched
        k, q = apply_rotation(rot_embeds, k, q)

        y = attention(attn_dropout, q, k, v, mask=mask)
        # re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, C)

        # output projection
        return resid_dropout(c_proj(y))


class RopeBidAtt(nn.Module):
    """
    Bidirectional multi-head self-attention with a rotary-style positional
    embedding applied to queries and keys.

    Attributes:
        config: model configuration; this module reads ``embed_type``,
            ``hidden_dim``, ``nheads``, ``nlayers``, ``initializer_range``,
            ``dropout_att``, ``dropout`` and, depending on ``embed_type``,
            ``pos_embed_max_len`` ("rope") or ``max_seq_len`` ("xpos").
        rot_embeds: declared for interface compatibility; ``__call__`` does
            not read it and always builds its own embedding module from
            ``config.embed_type``.
        dtype: dtype forwarded to every ``nn.Dense`` projection.
    """

    config: Config
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, src: jnp.array, train: bool, attention_mask: jnp.array, **kwargs
    ) -> jnp.array:
        """
        Args:
            src: (B, S, C) input sequence.
            train: when True both dropout layers are active.
            attention_mask: (B, S) mask; entries equal to 0 are excluded from
                attention.
            **kwargs: accepted and ignored.
        Returns:
            (B, S, C) output after projection and residual dropout.
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        cfg = self.config

        if cfg.embed_type == "rope":
            rotary = RopeEmbeds(
                n_pos=cfg.pos_embed_max_len,
                d_model=cfg.hidden_dim // cfg.nheads,
            )
        elif cfg.embed_type == "xpos":
            rotary = XPos(
                head_dim=cfg.hidden_dim // cfg.nheads,
                scale_base=cfg.max_seq_len,
            )
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # when the embedding is first used.
            raise ValueError(
                f"Unsupported embed_type {cfg.embed_type!r}; expected 'rope' or 'xpos'."
            )

        # query / key / value projections for all heads at once
        Wq = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        Wk = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        Wv = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )

        # output projection, initialised with a depth-scaled standard deviation
        c_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
            ),
        )

        # regularization
        attn_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=not train)
        resid_dropout = nn.Dropout(rate=cfg.dropout, deterministic=not train)

        B, S, C = src.shape  # batch size, sequence length, embedding width
        head_size = C // cfg.nheads

        # expand the padding mask (B, S) -> (B, 1, S, S)
        attention_mask = jnp.broadcast_to(
            attention_mask[:, None, None, :], shape=(B, 1, S, S)
        )

        q, k, v = Wq(src), Wk(src), Wv(src)

        # (B, S, C) -> (B, nh, S, hs)
        q = q.reshape(B, S, cfg.nheads, head_size).transpose(0, 2, 1, 3)
        k = k.reshape(B, S, cfg.nheads, head_size).transpose(0, 2, 1, 3)
        v = v.reshape(B, S, cfg.nheads, head_size).transpose(0, 2, 1, 3)

        # positional embedding on keys and queries only; values are untouched
        k, q = apply_rotation(rotary, k, q)

        y = attention(attn_dropout, q, k, v, mask=attention_mask)
        # re-assemble all head outputs side by side: (B, nh, S, hs) -> (B, S, C)
        y = y.transpose(0, 2, 1, 3).reshape(B, S, C)

        # output projection
        return resid_dropout(c_proj(y))


class CausalRope(nn.Module):
    """
    Causal multi-head self-attention with a rotary-style positional embedding
    applied to queries and keys.

    Attributes:
        config: model configuration; this module reads ``embed_type``,
            ``hidden_dim``, ``nheads``, ``initializer_range``,
            ``dropout_att``, ``dropout`` and, depending on ``embed_type``,
            ``pos_embed_max_len`` ("rope") or ``max_seq_len`` ("xpos").
        rot_embeds: declared for interface compatibility; ``__call__`` does
            not read it and always builds its own embedding module from
            ``config.embed_type``.
        dtype: dtype forwarded to the ``nn.Dense`` projections.
    """

    config: Config
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """
        Args:
            X: (B, T, C) input sequence.
            train: when True the attention dropout is active.
            **kwargs: accepted and ignored.
        Returns:
            (B, T, C) output of the final projection.
        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        cfg = self.config

        if cfg.embed_type == "rope":
            rotary = RopeEmbeds(
                n_pos=cfg.pos_embed_max_len,
                d_model=cfg.hidden_dim // cfg.nheads,
            )
        elif cfg.embed_type == "xpos":
            rotary = XPos(
                head_dim=cfg.hidden_dim // cfg.nheads,
                scale_base=cfg.max_seq_len,
            )
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # when the embedding is first used.
            raise ValueError(
                f"Unsupported embed_type {cfg.embed_type!r}; expected 'rope' or 'xpos'."
            )

        # fused query/key/value projection for all heads
        c_attn = nn.Dense(
            3 * cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        # Output projection. Unlike the other attention modules in this file,
        # its initialiser uses the plain initializer_range (no depth scaling).
        c_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )

        # regularization
        attn_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=not train)
        # The residual dropout layer is still constructed, but it is not
        # applied to the output below.
        resid_dropout = nn.Dropout(rate=cfg.dropout, deterministic=not train)

        B, T, C = X.shape  # batch size, sequence length, embedding width
        head_size = C // cfg.nheads

        # lower-triangular mask: position t may only attend to positions <= t
        bias = jnp.tril(jnp.ones(shape=(T, T))).reshape(1, 1, T, T)

        q, k, v = jnp.split(c_attn(X), 3, axis=2)
        # (B, T, C) -> (B, nh, T, hs)
        k = k.reshape(B, T, cfg.nheads, head_size).transpose(0, 2, 1, 3)
        q = q.reshape(B, T, cfg.nheads, head_size).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, cfg.nheads, head_size).transpose(0, 2, 1, 3)

        # positional embedding on keys and queries only; values are untouched
        k, q = apply_rotation(rotary, k, q)  # (B, nh, T, hs)

        att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(k.shape[-1]))

        # Every row keeps its diagonal entry, so -inf never fills a whole row
        # and the softmax stays finite.
        att = jnp.where(bias == 0, float("-inf"), att)
        att = jax.nn.softmax(att, axis=-1)
        att = attn_dropout(att)

        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(0, 2, 1, 3).reshape(B, T, C)

        # output projection
        return c_proj(y)


class ScanCausalSelfAttention(nn.Module):
    """
    Implement the memory efficient version of attention using scans.

    Queries are processed chunk by chunk with ``nn.scan``; for each query
    chunk the keys/values are split into chunks handled by ``nn.vmap`` and the
    per-chunk softmax numerators and denominators are recombined using a
    shared max-shift.

    Attributes:
        config: model configuration; this module reads ``hidden_dim``,
            ``nheads``, ``nlayers``, ``initializer_range``, ``dropout_att``
            and ``dropout``.
        query_chunk_attention: number of query positions handled per scan
            step.
        unroll: forwarded to ``nn.scan`` as its ``unroll`` argument.
        dtype: dtype forwarded to the ``nn.Dense`` projections.
    """

    config: Config
    query_chunk_attention: int = (
        1024  # Sub-sequence on which to perform normal self-attention
    )
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def _query_chunk_attention(
        drop_layer: Type["nn.Module"],
        query: jnp.array,
        key: jnp.array,
        value: jnp.array,
        bias: jnp.array,
        key_chunk_size: int = 4096,
        precision: Type["lax.Precision"] = lax.Precision.HIGHEST,
        dtype: Type["jnp.dtype"] = jnp.float32,
    ) -> jnp.array:
        """
        Attention of one query chunk against all keys, computed per key chunk.

        Args:
            drop_layer: dropout module applied to the softmax numerators.
            query: (T, B, H, D) query chunk.
            key: (num_kv, B, H, D) keys.
            value: (num_kv, B, H, F) values.
            bias: (T, num_kv) mask; entries equal to 0 are excluded.
            key_chunk_size: number of keys per chunk (capped at ``num_kv``).
            precision: forwarded to the value einsum.
            dtype: dtype used for the scores and accumulated values.
        Returns:
            (T, B, H, F) attention output for the query chunk.

        NOTE(review): key chunks start at multiples of ``key_chunk_size`` and
        are cut with ``lax.dynamic_slice``; this looks like it assumes
        ``num_kv`` is a multiple of the (capped) chunk size -- confirm.
        """
        num_kv, _, num_heads, k_features = key.shape
        T = query.shape[0]
        B = query.shape[1]
        v_features = value.shape[-1]
        key_chunk_size = min(key_chunk_size, num_kv)
        query = query / jnp.sqrt(k_features).astype(dtype)

        # Finite stand-in for -inf. With a causal mask and several key chunks
        # a query row can be fully masked inside one chunk; using -inf there
        # gives max = -inf and (-inf) - (-inf) = NaN, which survives the
        # multiplication by the zero mask. A finite value keeps the row at an
        # exact zero contribution instead.
        mask_value = jnp.finfo(dtype).min

        @partial(jax.checkpoint, prevent_cse=False)
        def summarize_chunk(drop_layer, bias, query, key, value):
            # Returns, for one key chunk: the un-normalised weighted values,
            # the sum of the un-normalised weights, and the row-wise max used
            # as the softmax shift.
            attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key).astype(dtype)
            attn_weights = jnp.where(
                bias[None, None, ...] == 0, mask_value, attn_weights
            )
            attn_weights = attn_weights.transpose(2, 0, 1, 3)  # bhqk->qbhk
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)

            # zero out masked positions explicitly
            exp_weights = jnp.einsum("qbhk,qk->qbhk", exp_weights, bias)
            # dropout applied only on the numerator to simulate dropout after softmax
            exp_weights_drop = drop_layer(exp_weights)
            exp_values = jnp.einsum(
                "vbhf,qbhv->qbhf", value, exp_weights_drop, precision=precision
            ).astype(dtype)
            return (
                exp_values,
                exp_weights.sum(axis=-1),
                max_score.reshape((query.shape[0], B, num_heads)),
            )

        def chunk_scanner(drop_layer, chunk_idx):
            # Cut the key/value/mask chunk that starts at `chunk_idx`.
            key_chunk = lax.dynamic_slice(
                key,
                (chunk_idx, 0, 0, 0),
                slice_sizes=(key_chunk_size, B, num_heads, k_features),
            )

            value_chunk = lax.dynamic_slice(
                value,
                (chunk_idx, 0, 0, 0),
                slice_sizes=(key_chunk_size, B, num_heads, v_features),
            )

            bias_chunk = lax.dynamic_slice(
                bias, (0, chunk_idx), slice_sizes=(T, key_chunk_size)
            )

            return summarize_chunk(
                drop_layer, bias_chunk, query, key_chunk, value_chunk
            )

        fn = nn.vmap(
            chunk_scanner,
            split_rngs={"params": False, "dropout": True},
        )
        chunk_values, chunk_weights, chunk_max = fn(
            drop_layer, jnp.arange(0, num_kv, key_chunk_size)
        )

        # Bring every chunk onto the same shift (the global row max) before
        # summing numerators and denominators.
        global_max = jnp.max(chunk_max, axis=0, keepdims=True)
        max_diffs = jnp.exp(chunk_max - global_max)
        chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
        chunk_weights *= max_diffs

        all_values = chunk_values.sum(axis=0)
        all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
        return all_values / all_weights

    def mefficient_attention(
        self,
        query: jnp.array,
        key: jnp.array,
        value: jnp.array,
        causal_mask: jnp.array,
        drop_layer: Type["nn.Module"],
        query_chunk_size: int = 1024,
        precision: Type["lax.Precision"] = jax.lax.Precision.HIGHEST,
        dtype: Type["jnp.dtype"] = jnp.float32,
    ):
        """
        Chunked attention over the whole query sequence.

        Args:
            query: (num_q, B, H, D) queries.
            key: keys, with positions on axis 0.
            value: values, with positions on axis 0.
            causal_mask: (num_q, num_q) mask; entries equal to 0 are excluded.
            drop_layer: dropout module applied to the attention weights.
            query_chunk_size: number of queries handled per scan step.
            precision: forwarded to ``_query_chunk_attention``.
            dtype: forwarded to ``_query_chunk_attention``.
        Returns:
            (num_q, B, H, F) attention output, F being ``value.shape[-1]``.

        NOTE: the final reshape only succeeds when ``num_q`` is at most
        ``query_chunk_size`` or a multiple of it.
        """
        num_q, B, num_heads, q_features = query.shape
        chunk_len = min(query_chunk_size, num_q)

        def chunk_scanner(drop_layer, chunk_idx, _):
            # `chunk_idx` is the scan carry: the first query row of this chunk.
            query_chunk = lax.dynamic_slice(
                query,
                (chunk_idx, 0, 0, 0),
                slice_sizes=(chunk_len, B, num_heads, q_features),
            )
            causal_mask_chunk = lax.dynamic_slice(
                causal_mask,
                (chunk_idx, 0),
                slice_sizes=(chunk_len, num_q),
            )

            return (
                chunk_idx + query_chunk_size,
                self._query_chunk_attention(
                    drop_layer,
                    query_chunk,
                    key,
                    value,
                    causal_mask_chunk,
                    precision=precision,
                    dtype=dtype,
                ),
            )

        fn = nn.scan(
            chunk_scanner,
            unroll=self.unroll,
            variable_broadcast="params",
            split_rngs={"params": False, "dropout": True},
            length=math.ceil(num_q / query_chunk_size),
        )
        _, res = fn(drop_layer, 0, None)
        # stacked per-chunk outputs -> one contiguous query axis
        return res.reshape(num_q, B, num_heads, value.shape[-1])

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """
        Sequential implementation of causal attention
        Args:
            X: jnp.array(BTD) - batch size (B), seq len (T), hidden dim (D)
            train: bool - dropout flag
        Returns:
            y: jnp.array(BTD) - transformed input sequence
        """
        # key, query, value projections for all heads, but in a batch
        c_attn = nn.Dense(
            3 * self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
                / math.sqrt(2 * self.config.nlayers)
            ),
        )

        # regularization: rates come from the config, as in the other
        # attention modules of this file (this class declares no `dropout`
        # field of its own).
        attn_dropout = nn.Dropout(rate=self.config.dropout_att, deterministic=not train)
        resid_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        B, T, C = (
            X.shape
        )  # batch size, sequence length, embedding dimensionality (n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        bias = jnp.tril(jnp.ones(shape=(T, T)))
        # calculate query, key, values for all heads and move the sequence axis to the front
        q, k, v = jnp.split(c_attn(X), 3, axis=-1)
        head_dim = C // self.config.nheads
        k = k.reshape(B, T, self.config.nheads, head_dim).transpose(
            1, 0, 2, 3
        )  # T B H head_dim
        q = q.reshape(B, T, self.config.nheads, head_dim).transpose(
            1, 0, 2, 3
        )  # T B H head_dim
        v = v.reshape(B, T, self.config.nheads, head_dim).transpose(
            1, 0, 2, 3
        )  # T B H head_dim

        y = (
            self.mefficient_attention(
                q,
                k,
                v,
                bias,
                attn_dropout,
                query_chunk_size=self.query_chunk_attention,
                precision=None,
            )
            .transpose(1, 0, 2, 3)
            .reshape(B, T, C)
        )

        return resid_dropout(c_proj(y))
