"""
Notation:
    B = batch size, T = sequence length, D = embed/hidden dimension, H = number of heads
"""

import warnings
from typing import Any, Dict, Tuple
import jax
from flax import linen as nn
from jax import numpy as jnp
import math
from flax.linen.dtypes import promote_dtype
from latte_trans.config import Config
from latte_trans.models.modules.layers import (
    S5Layer,
    DepthConv1D,
    RopeEmbeds,
    XPos,
    SlidingWindowAtt,
    RMSNorm,
)
from recurrentgemma.jax.layers import RGLRU

# Matmul precision handed to the einsum calls in this module that pass `precision=`.
PRECISION = jax.lax.Precision.DEFAULT  # alternative: jax.lax.Precision.HIGHEST
# Short alias for JAX's associative (parallel prefix) scan.
parallel_scan = jax.lax.associative_scan


def attention_product(q, k, v, mask=None) -> Tuple[jax.Array, jax.Array]:
    """
    Scaled dot-product attention.

    Args:
        q, k, v: arrays laid out as BHTD - Query, Key, Value.
        mask: optional array broadcastable to the (..., T, T) logits; entries
            where ``mask == 0`` are excluded (causal mask or bidirectional
            padding mask).
    Returns:
        values: attention-weighted sum of ``v`` (same leading dims as ``q``).
        attention: softmax weights over the key axis.
    """
    head_dim = q.shape[-1]
    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        # Fill with the most negative *finite* value of the logits' dtype.
        # A fixed -9e15 overflows to -inf in float16, and a row that is fully
        # masked would then produce NaN after the softmax.
        fill_value = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask == 0, fill_value, attn_logits)
    attention = jax.nn.softmax(attn_logits, axis=-1)
    values = jnp.matmul(attention, v)
    return values, attention


class CausalScanLatte(nn.Module):
    """
    Causal latent attention, computed with running-max rescaling so the
    exponentials stay numerically stable.

    Shape letters used in this class: B batch, T sequence length, H heads,
    L latent states per head, D value channels per head.
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def accumulate(carry, args):
        """
        Single step of the sequential scan driven by `mix_sequence`.
        Args:
            carry: (nu, alpha, prev_max) - running numerator (BHLD), running
                normaliser (BHL) and the max the two are currently scaled by.
            args: (Qs_t, K_t, V_t, max_t) - slices for the current time step.
        Returns:
            The updated carry and the step output y_t (BHD).
        """
        nu, alpha, last_max = carry
        q_t, k_t, v_t, new_max = args
        # Factor that moves the accumulated state from the old max to the new one.
        rescale = jnp.exp(-new_max + last_max)
        # Contribution of the current key, expressed relative to the new max.
        contrib = jnp.exp(k_t - new_max)
        alpha = jnp.einsum("BHL,BHL->BHL", alpha, rescale, precision=PRECISION)
        alpha += contrib
        nu = jnp.einsum("BHLD,BHL->BHLD", nu, rescale, precision=PRECISION)
        nu += jnp.einsum("BHL,BHD->BHLD", contrib, v_t, precision=PRECISION)

        out = jnp.einsum("BHL,BHLD->BHD", q_t / alpha, nu, precision=PRECISION)
        return ((nu, alpha, new_max), out)

    @staticmethod
    def accumulate3(carry, args):
        """
        Scan step that only carries the numerator; the rescale/add factors
        (and any normalisation of the queries) are expected to be computed
        by the caller outside the sequential loop.
        """
        state = carry
        q_t, v_t, rescale, contrib = args
        state = jnp.einsum("BHLD,BHL->BHLD", state, rescale)
        state += jnp.einsum("BHL,BHD->BHLD", contrib, v_t)
        out = jnp.einsum("BHL,BHLD->BHD", q_t, state)
        return state, out

    def inference(self, X, cache, Wq, Wk, Wv, o_proj, Q_drop):
        """
        Recurrent (one token per call) counterpart of the parallel forward.
        Args:
            X: jnp.array(BD) - embedding of the current token
            cache: Dict[str, Any] - previous recurrent state with keys
                "alpha", "nu" and "prev_max"; None starts a fresh state
            Wq, Wk, Wv, o_proj: projection matrices created in `__call__`
            Q_drop: dropout module applied to the latent distribution
        Returns:
            (cache, y): the updated state (the passed dict is updated in
            place) and the output for this token, jnp.array(BD).
        """
        nheads = self.config.nheads
        B = X.shape[0]
        D = X.shape[-1] // nheads
        H, L = nheads, self.config.L // nheads
        if cache is None:
            # start from an empty recurrent state
            cache = {
                "alpha": jnp.zeros(shape=(B, H, L), dtype=jnp.float32),
                "nu": jnp.zeros((B, H, L, D), dtype=jnp.float32),
                "prev_max": None,
            }

        q_logits = jnp.einsum("DL,BD->BL", Wq, X, precision=PRECISION).reshape(
            B, H, -1
        )
        Qs_t = Q_drop(jax.nn.softmax(q_logits, axis=-1))
        # values feed the numerator (nu)
        V_t = jnp.einsum("DM,BD->BM", Wv, X, precision=PRECISION).reshape(B, H, -1)
        # keys feed the normaliser (alpha)
        K_t = jnp.einsum("DL,BTD->TBL", Wk, X[None, ...], precision=PRECISION).reshape(
            B, H, -1
        )

        alpha, nu = cache["alpha"], cache["nu"]
        last_max = cache["prev_max"]
        # Neither 0 nor -inf works as an initial max; the first key does.
        if last_max is None:
            last_max = K_t
        new_max = jnp.maximum(last_max, K_t)
        rescale = jnp.exp(-new_max + last_max)
        contrib = jnp.exp(K_t - new_max)

        alpha = jnp.einsum("BHL,BHL->BHL", alpha, rescale, precision=PRECISION)
        alpha += contrib
        nu = jnp.einsum("BHLD,BHL->BHLD", nu, rescale, precision=PRECISION)
        nu += jnp.einsum("BHL,BHD->BHLD", contrib, V_t, precision=PRECISION)

        y = jnp.einsum("BHL,BHLD->BHD", Qs_t / alpha, nu, precision=PRECISION)
        y = y.reshape(B, -1) @ o_proj
        # store the state for the next token
        cache["alpha"] = alpha
        cache["nu"] = nu
        cache["prev_max"] = new_max

        return cache, y

    def mix_sequence(self, Q, K, V, Q_drop):
        """
        Sequential implementation based on `jax.lax.scan` over the time axis.
        Args:
            Q: jax.Array(T,B,H,L)
            K: jax.Array(T,B,H,L)
            V: jax.Array(T,B,H,D)
            Q_drop: dropout module applied to softmax(Q)
        Returns:
            jax.Array(B,T,H*D)
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]

        # Running max of the keys; used only for stability, so it is treated
        # as a constant (no gradient is also faster).
        running_max = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))

        alpha0 = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        nu0 = jnp.zeros((B, H, L, C), dtype=self.dtype)
        Qs = Q_drop(jax.nn.softmax(Q, axis=-1))

        _, y = jax.lax.scan(
            self.accumulate,
            unroll=self.unroll,
            init=(
                nu0,
                alpha0,
                K[0],
            ),
            xs=[Qs, K, V, running_max],
        )
        # TBHD -> BTHD, then merge the heads
        return y.transpose(1, 0, 2, 3).reshape(B, T, -1)

    def mix_sequence4(self, Q, K, V, Q_drop):
        """
        Parallel implementation based on an associative scan. It materialises
        the full (T,B,H,L,D) numerator tensor.
        Args:
            Q: jax.Array(T,B,H,L)
            K: jax.Array(T,B,H,L)
            V: jax.Array(T,B,H,D)
            Q_drop: dropout module applied to softmax(Q)
        Returns:
            jax.Array(B,T,H*D)
        """
        T, B, _, _ = V.shape
        Qs = Q_drop(jax.nn.softmax(Q, axis=-1))

        # Running max of the keys; used only for stability, so it is treated
        # as a constant (no gradient is also faster).
        running_max = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))
        # Per-step factor that moves state from the previous max to the
        # current one; the first step has nothing to rescale (exp(0) = 1).
        log_rescale = jnp.zeros_like(running_max)
        log_rescale = log_rescale.at[1:].set(-running_max[1:] + running_max[:-1])
        rescale = jnp.exp(log_rescale)  # TBHL
        contrib = jnp.exp(K - running_max)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", contrib, V)

        def bin_V(left, right):
            # Combine two scan elements (rescale, alpha, nu): the left state
            # is rescaled by the right element's factor before adding.
            rm_l, am_l, nu_l = left
            rm_r, am_r, nu_r = right
            merged_nu = nu_l * rm_r[..., None] + nu_r
            merged_alpha = am_l * rm_r + am_r
            return (rm_l * rm_r, merged_alpha, merged_nu)

        _, alpha, y = parallel_scan(bin_V, (rescale, contrib, nu))
        y = jnp.einsum("TBHL,TBHLD->BTHD", Qs / alpha, y)
        return y.reshape(B, T, -1)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD), or jnp.array(BD) when do_inference is set
            train: bool - enables dropout when True
            cache: recurrent state, only used when do_inference is set
            do_inference: bool - run the one-token recurrent path
        Returns:
            y: jnp.array(BTD) - transformed output sequence. With
            do_inference set, the tuple (cache, y) from `inference` instead.
        """
        cfg = self.config
        normal_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        Wk = self.param("Wk", normal_init, (cfg.hidden_dim, cfg.L))
        Wq = self.param("Wq", normal_init, (cfg.hidden_dim, cfg.L))
        Wv = self.param("Wv", normal_init, (cfg.hidden_dim, cfg.hidden_dim))
        # the output projection uses a std scaled down by sqrt(2 * nlayers)
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.normal(
                stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
            ),
            (cfg.hidden_dim, cfg.hidden_dim),
        )
        Wk, Wq, Wv, o_proj = promote_dtype(Wk, Wq, Wv, o_proj, dtype=self.dtype)
        Q_drop = nn.Dropout(cfg.dropout_att, deterministic=not train)
        if do_inference:
            return self.inference(
                X,
                cache=cache,
                Wq=Wq,
                Wk=Wk,
                Wv=Wv,
                o_proj=o_proj,
                Q_drop=Q_drop,
            )
        B, T, _ = X.shape
        H = cfg.nheads
        # multi head implementation: project, then split the last axis into heads
        V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, X).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, H, -1)
        # `mix_sequence` is the sequential alternative with the same signature.
        y = self.mix_sequence4(Q=Q, K=K, V=V, Q_drop=Q_drop)
        return y @ o_proj


class BidLatte(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads. It is applied
                to the key logits laid out as (L, B, T), so it must broadcast
                against that shape (e.g. a (B, T) mask). Truthy entries are
                kept. When None, no position is masked.
                Only used in bidirectional since we sum up to T, and pad needed for batching
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        Wk = self.param(
            "Wk",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        resid_drop = nn.Dropout(self.config.dropout, deterministic=not train)

        B, T, _ = X.shape
        H = self.config.nheads
        # multi head implementation
        V = jnp.einsum("DM,BTD->TBM", Wv, X, precision=PRECISION).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, X, precision=PRECISION).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, X, precision=PRECISION)

        # The mask is optional (it defaults to None); only mask when one is
        # given, matching the other bidirectional Latte layers in this file.
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        K = K.transpose(2, 1, 0).reshape(T, B, H, -1)  # LBT -> TBHL
        Qs = jax.nn.softmax(Q, axis=-1)  # T B H L
        Qs = Q_drop(Qs)
        # subtract the max over the sequence axis before exponentiating
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V, precision=PRECISION)
        # normalize over the sequence axis
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv, precision=PRECISION)
        y = jnp.einsum("TBHL,BHLD->BTHD", Qs, Kv, precision=PRECISION)
        y = y.reshape(B, T, -1)
        return resid_drop(y @ o_proj)


class BidLatteRopeConv(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.

    Queries and keys are computed from a depthwise-convolved input, and the
    values are rotated with a relative positional embedding (rope or xpos).
    """

    config: Config
    dtype: jnp.dtype

    def latte_attention(self, rot_embeds, Qs, K, V, attention_mask):
        """
        Args:
            rot_embeds: positional embedding module (an `XPos` instance takes
                the first branch, anything else is used through `apply_vapor`)
            Qs: TBHL - latent distribution per position
            K: LBT - key logits
            V: DBT - values
            attention_mask: optional mask broadcastable to K (LBT); truthy
                entries are kept
        Returns:
            y: BHTD
        """
        _, B, T = K.shape
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        K = K.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)  # TBHL
        V = V.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)  # TBHD

        # calc R^{-s}x_s
        if isinstance(rot_embeds, XPos):
            # XPos is called on BHTD: TBHD -> BHTD, rotate, then back to TBHD
            V = rot_embeds(V.transpose(1, 2, 0, 3), offset=0, downscale=True).transpose(
                2, 0, 1, 3
            )
        else:
            V = rot_embeds.apply_vapor(mat=V, neg=True)

        # subtract the max over the sequence axis before exponentiating
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V)
        # normalize over the sequence axis
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv)
        y = jnp.einsum("TBHL,BHLD->TBHD", Qs, Kv)

        # calc R^t \sum_l ...
        if isinstance(rot_embeds, XPos):
            # TBHD -> BHTD before rotating; the result is returned as is
            y = y.transpose(1, 2, 0, 3)
            y = rot_embeds(y, offset=0, downscale=False)
            return y
        y = rot_embeds.apply_vapor(mat=y, neg=False)
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads; it is applied
                to the key logits laid out as (L, B, T), so it must broadcast
                against that shape.
                Only used in bidirectional since we sum up to T, and pad needed for batching
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        Raises:
            ValueError: if `config.embed_type` is neither "rope" nor "xpos".
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
            )
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # further down when `rot_embeds` is first used.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'."
            )

        # latte attention
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)
        conv = DepthConv1D(
            nchannels=self.config.hidden_dim,
            out_channels=self.config.hidden_dim,
            kernel_size=9,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)

        # batch size, sequence length (the embedding size is not needed here)
        B, T, _ = X.shape

        # queries and keys see the convolved input, values the raw input
        Y = conv(X)
        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, self.config.nheads, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, Y)
        V = jnp.einsum("DM,BTD->MBT", Wv, X)
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^T p(s|l)v_s
        y = self.latte_attention(
            rot_embeds, Qs, K=K, V=V, attention_mask=attention_mask
        )  # BHTD
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return c_proj(y)


class BidLatteConv(nn.Module):
    """
    Bidirectional Latte layer (sums over the whole sequence "T" rather than
    up to "t", so no sequential implementation is needed). Queries and keys
    are computed from the output of an RGLRU layer applied to a linear
    projection of the input; values come from the raw input.
    """

    config: Config
    dtype: jnp.dtype

    def latte_attention(self, Qs, K, V, attention_mask):
        """
        Args:
            Qs: TBHL - latent distribution per position
            K: LBT - key logits
            V: DBT - values
            attention_mask: optional mask broadcastable to K (LBT); truthy
                entries are kept
        Returns:
            y: BHTD
        """
        _, B, T = K.shape
        nheads = self.config.nheads
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        keys = K.transpose(2, 1, 0).reshape(T, B, nheads, -1)
        vals = V.transpose(2, 1, 0).reshape(T, B, nheads, -1)

        # exponentiate relative to the max over the sequence axis
        weights = jnp.exp(keys - jnp.max(keys, axis=0, keepdims=True))

        weighted_vals = jnp.einsum("TBHL,TBHD->BHLD", weights, vals)
        # normalise by the sum over the sequence axis (BHL)
        norm = weights.sum(axis=0)
        weighted_vals = jnp.einsum("BHL,BHLD->BHLD", 1 / norm, weighted_vals)
        return jnp.einsum("TBHL,BHLD->BHTD", Qs, weighted_vals)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads; it is applied
                to the key logits laid out as (L, B, T), so it must broadcast
                against that shape
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        cfg = self.config
        normal_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)

        # latte attention; Wk/Wq act on the RGLRU output, whose width is L
        Wk = self.param("Wk", normal_init, (cfg.L, cfg.L))
        Wq = self.param("Wq", normal_init, (cfg.L, cfg.L))
        Wv = self.param("Wv", normal_init, (cfg.hidden_dim, cfg.hidden_dim))
        Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)

        # input projection feeding the recurrent layer
        lru_in = self.param("lru_in", normal_init, (cfg.hidden_dim, cfg.L))
        conv = RGLRU(
            width=cfg.L,
            num_heads=cfg.nheads,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=normal_init,
        )
        Q_drop = nn.Dropout(cfg.dropout_att, deterministic=not train)

        B, T, _ = X.shape
        # every batch row gets the positions 0..T-1
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        lru_input = jnp.einsum("DL,BTD->BTL", lru_in, X)
        Y, _ = conv(lru_input, pos_ids, return_cache=False)

        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, cfg.nheads, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, Y)
        V = jnp.einsum("DM,BTD->MBT", Wv, X)
        Qs = Q_drop(jax.nn.softmax(Q, axis=-1))

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^T p(s|l)v_s
        y = self.latte_attention(Qs, K=K, V=V, attention_mask=attention_mask)  # BHTD
        # BHTD -> BTHD, then merge the heads
        y = y.transpose(0, 2, 1, 3).reshape(B, T, -1)
        return c_proj(y)


class RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables consumed by `apply_rotary_pos_emb`."""

    # size of the rotated (per-head) feature dimension
    dim: int
    # not read by the current implementation
    max_position_embeddings: int = 2048
    # base of the geometric frequency progression
    base: int = 10000

    def setup(self):
        # one inverse frequency per pair of features: base^(-2i/dim)
        exponents = (
            jnp.arange(0, self.dim, 2, dtype=jnp.int32).astype(jnp.float32) / self.dim
        )
        self.inv_freq = 1.0 / (self.base**exponents)

    def __call__(self, x, position_ids, seq_len=None):
        """
        Args:
            x: [bs, num_attention_heads, seq_len, head_size]; only its dtype
                and (when position_ids is None) its axis 2 length are read
            position_ids: [bs, seq_len] positions, or None for 0..seq_len-1
                shared across the batch
            seq_len: unused
        Returns:
            (cos, sin), each [position_ids.shape[0], seq_len, dim], cast to
            the dtype of x
        """
        if position_ids is None:
            position_ids = jnp.arange(0, x.shape[2])[None, ...]

        n_rows = position_ids.shape[0]
        # Angles are computed in float32 since bfloat16 loses precision on
        # long contexts.
        # See https://github.com/huggingface/transformers/pull/29285
        inv_freq = self.inv_freq[None, :, None].astype(jnp.float32)
        inv_freq = jnp.broadcast_to(inv_freq, (n_rows, inv_freq.shape[1], 1))
        positions = position_ids[:, None, :].astype(jnp.float32)
        # (rows, dim/2, 1) @ (rows, 1, seq) -> (rows, dim/2, seq) -> (rows, seq, dim/2)
        freqs = (inv_freq @ positions).transpose(0, 2, 1)
        # duplicate so both halves of the feature dim share the same angles
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        # NOTE: position interpolation (https://arxiv.org/pdf/2306.15595) for
        # sequences longer than max_position_embeddings is not applied.
        out_dtype = x.dtype
        return jnp.cos(emb).astype(dtype=out_dtype), jnp.sin(emb).astype(dtype=out_dtype)


def rotate_half(x, neg=False):
    """
    Swap the two halves of the last axis and negate one of them.

    With `neg=False` (classic rope) the result is (-second, first); with
    `neg=True` it is (second, -first). For an odd-sized last axis the first
    half is the shorter one.
    """
    mid = x.shape[-1] // 2
    first, second = x[..., :mid], x[..., mid:]
    halves = (second, -first) if neg else (-second, first)
    return jnp.concatenate(halves, axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key arrays.

    Args:
        q: The query array.
        k: The key array.
        cos: The cosine part of the rotary embedding.
        sin: The sine part of the rotary embedding.
        position_ids: Deprecated and unused.
        unsqueeze_dim (int, defaults to 1): Axis inserted into `cos` and `sin`
            so that they broadcast against `q` and `k`. For example, with
            `cos`/`sin` of shape [batch_size, seq_len, head_dim]: use 1 when
            `q`/`k` are [batch_size, heads, seq_len, head_dim], and 2 when
            they are [batch_size, seq_len, heads, head_dim].
    Returns:
        Tuple (q_embed, k_embed) with the rotated query and key arrays.
    """
    cos_b = jnp.expand_dims(cos, unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, unsqueeze_dim)

    def _rotate(x):
        # x * cos + rotate_half(x) * sin
        return (x * cos_b) + (rotate_half(x) * sin_b)

    return _rotate(q), _rotate(k)


class RopeBidLatteMachSliding(nn.Module):
    """
    Bidirectional Latte combined with bidirectional sliding window attention.
    One latent state per head (index 0) selects the sliding window attention
    output, the remaining ones the global Latte attention.
    We sum to "T" instead of "t", so no sequential implementation is required.
    """

    config: Config
    dtype: jnp.dtype

    def latte_attention(self, Qs, K, V, attention_mask):
        """
        Args:
            Qs: TBHL - latent distribution per position
            K: LBT - key logits
            V: TBHD - values, already split into heads
            attention_mask: optional mask broadcastable to K (LBT); truthy
                entries are kept
        Returns:
            y: BHTD
        """
        _, B, T = K.shape
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        keys = K.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)

        # exponentiate relative to the max over the sequence axis
        weights = jnp.exp(keys - jnp.max(keys, axis=0, keepdims=True))

        weighted_vals = jnp.einsum("TBHL,TBHD->BHLD", weights, V)
        # normalise by the sum over the sequence axis (BHL)
        norm = weights.sum(axis=0)
        weighted_vals = jnp.einsum("BHL,BHLD->BHLD", 1 / norm, weighted_vals)
        return jnp.einsum("TBHL,BHLD->BHTD", Qs, weighted_vals)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads in the Latte
                branch; it is applied to the key logits laid out as (L, B, T),
                so it must broadcast against that shape. The sliding window
                branch is called with input_mask=None.
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        cfg = self.config
        nheads = cfg.nheads
        B, T, _ = X.shape
        normal_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        use_rope = cfg.embed_type == "rope"

        if use_rope:
            rotary_emb = RotaryEmbedding(
                dim=cfg.hidden_dim // nheads,
                max_position_embeddings=cfg.pos_embed_max_len,
                base=10000.0,
            )

        # self attention projections
        q_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=normal_init,
            name="q_proj",
        )
        k_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=normal_init,
            name="k_proj",
        )
        v_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=normal_init,
            name="v_proj",
        )

        # latte attention; Wq has `nheads` extra columns so that every head
        # gets one additional latent state for the sliding window branch
        Wk = self.param("Wk", normal_init, (cfg.L, cfg.L))
        Wq = self.param("Wq", normal_init, (cfg.L, cfg.L + nheads))
        Wk, Wq = promote_dtype(Wk, Wq, dtype=self.dtype)

        lru_in = nn.Dense(
            cfg.L,
            use_bias=False,
            kernel_init=normal_init,
            dtype=self.dtype,
            name="latte_lru_in",
        )
        latte_conv = RGLRU(
            width=cfg.L,
            num_heads=nheads,
            dtype=self.dtype,
            name="latte_rglru",
        )
        latte_lru_norm = RMSNorm(
            dtype=self.dtype,
            width=cfg.L,
            name="latte_lru_norm",
        )
        # output projection
        c_proj = nn.Dense(
            cfg.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=normal_init,
        )
        sliding_att = SlidingWindowAtt(
            window_size=cfg.att_block_len,
            exact_windowsize=True,
            causal=False,
        )

        # regularization
        attn_dropout = nn.Dropout(rate=cfg.dropout, deterministic=not train)
        Q_drop = nn.Dropout(cfg.dropout_att, deterministic=not train)

        def split_heads(t):
            # (B, T, hidden) -> (B, nh, T, hs)
            return t.reshape(B, T, nheads, -1).transpose(0, 2, 1, 3)

        q, k = q_proj(X), k_proj(X)
        v = v_proj(X)
        k = split_heads(k)
        q = split_heads(q)
        v = split_heads(v)

        if use_rope:
            cos, sin = rotary_emb(v, position_ids=None)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # sliding window branch, reshaped to BHTD
        window_out = sliding_att(q, k, v, input_mask=None, attn_dropout=attn_dropout)
        window_out = window_out.reshape(B, nheads, T, -1)

        # features for the Latte queries/keys: projection -> RGLRU -> RMSNorm
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        Y, _ = latte_conv(lru_in(X), pos_ids, return_cache=False)
        Y = latte_lru_norm(Y)

        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, nheads, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, Y)
        Qs = Q_drop(jax.nn.softmax(Q, axis=-1))

        # p(l=0|t) times the sliding window attention output
        local_att = jnp.einsum("TBH,BHTD->BHTD", Qs[:, :, :, 0], window_out)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^T p(s|l)v_s
        values_tbhd = v.transpose(2, 0, 1, 3)  # BHTD -> TBHD
        latte_att = self.latte_attention(
            Qs[:, :, :, 1:], K=K, V=values_tbhd, attention_mask=attention_mask
        )  # BHTD

        out = local_att + latte_att
        # BHTD -> BTHD, then merge the heads
        out = out.transpose(0, 2, 1, 3).reshape(B, T, -1)
        return c_proj(out)
