from typing import Dict, Any
from functools import partial
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
from .layers import Conv1D, DepthConv1D, S5Layer, RopeEmbeds, RMSNorm
from .xPos import XPos
from latte_trans.config import Config
from .layers import SlidingWindowAtt
from recurrentgemma.jax.modules import RecurrentBlock
from recurrentgemma.jax.layers import RGLRU

# Alias for JAX's associative (parallel prefix) scan. The latte_attention4
# methods below use it to combine running (rescale, normaliser, numerator)
# triples along the time axis instead of a sequential jax.lax.scan.
parallel_scan = jax.lax.associative_scan


def apply_rope(rel_pos, mat):
    """Rotate ``mat`` with precomputed sinusoidal position factors.

    Args:
        rel_pos: jnp.array of shape (T, D). The first D/2 columns hold the
            sine factors and the last D/2 columns hold the cosine factors.
        mat: jnp.array of shape (B, H, T, D), the matrix to rotate.

    Returns:
        jnp.array of shape (B, H, T, D) equal to
        ``cos * mat + sin * rotate_half(mat)``, with sin/cos broadcast over the
        batch and head axes.

    NOTE(review): sin/cos are tiled half-by-half (the D/2 factors repeated
    twice), while the rotate-half partner is built from the odd/even lanes of
    ``mat``. Confirm this pairing matches the layout the caller expects.
    """
    half_sin, half_cos = jnp.split(rel_pos, indices_or_sections=2, axis=-1)
    # (T, D/2) -> (T, D): repeat the D/2 factors along the feature axis.
    sin_full = jnp.tile(half_sin, reps=(1, 2))
    cos_full = jnp.tile(half_cos, reps=(1, 2))

    evens = mat[..., 0::2]
    odds = mat[..., 1::2]
    partner = jnp.concatenate([-odds, evens], axis=-1)

    cos_term = jnp.einsum("BHTD,TD->BHTD", mat, cos_full)
    sin_term = jnp.einsum("BHTD,TD->BHTD", partner, sin_full)
    return cos_term + sin_term


def apply_vapor(rel_pos, mat, neg=False):
    """Rotate a time-major ``mat`` by R^{t} (or R^{-t} when ``neg``).

    Args:
        rel_pos: jnp.array of shape (T, D). The first D/2 columns hold the
            sine factors and the last D/2 columns hold the cosine factors.
        mat: jnp.array of shape (T, B, H, D), the matrix to rotate.
        neg: bool. When True the sign of the rotate-half partner is flipped,
            which gives the inverse rotation R^{-t}.

    Returns:
        jnp.array of shape (T, B, H, D), cast back to ``mat.dtype``.

    NOTE(review): as in ``apply_rope``, sin/cos are tiled half-by-half while
    the partner uses odd/even lanes. Confirm this pairing is intended.
    """
    half_sin, half_cos = jnp.split(rel_pos, indices_or_sections=2, axis=-1)
    sin_full = jnp.tile(half_sin, reps=(1, 2))
    cos_full = jnp.tile(half_cos, reps=(1, 2))

    evens, odds = mat[..., 0::2], mat[..., 1::2]
    partner = (
        jnp.concatenate([odds, -evens], axis=-1)
        if neg
        else jnp.concatenate([-odds, evens], axis=-1)
    )

    cos_term = jnp.einsum("TBHD,TD->TBHD", mat, cos_full)
    sin_term = jnp.einsum("TBHD,TD->TBHD", partner, sin_full)
    # The einsums may promote (e.g. a float32 rel_pos with a half-precision
    # mat), so cast back to the input dtype.
    return (cos_term + sin_term).astype(mat.dtype)


def accumulate(carry, args):
    """Single ``jax.lax.scan`` step of max-stabilised Latte attention.

    Args:
        carry: tuple ``(nu, alpha, prev_max)`` with shapes (B, H, L, D),
            (B, H, L) and (B, H, L). These are the running numerator, the
            running normaliser, and the max they are expressed relative to.
        args: tuple ``(Qs_t, K_t, V_t, max_t)`` with shapes (B, H, L),
            (B, H, L), (B, H, D) and (B, H, L).

    Returns:
        ``((nu_t, alpha_t, max_t), y_t)`` where ``y_t`` has shape (B, H, D) and
        equals ``sum_l Qs_t[l] / alpha_t[l] * nu_t[l]``.
    """
    nu_prev, alpha_prev, max_prev = carry
    q_t, k_t, v_t, max_t = args

    # Re-express the previous state relative to the new max.
    rescale = jnp.exp(-max_t + max_prev)
    weight = jnp.exp(k_t - max_t)

    alpha_t = jnp.einsum("BHL,BHL->BHL", alpha_prev, rescale) + weight
    nu_t = jnp.einsum("BHLD,BHL->BHLD", nu_prev, rescale) + jnp.einsum(
        "BHL,BHD->BHLD", weight, v_t
    )
    y_t = jnp.einsum("BHL,BHLD->BHD", q_t / alpha_t, nu_t)
    return (nu_t, alpha_t, max_t), y_t


def accumulate3(carry, args):
    """Lean scan step for Latte attention.

    The rescale factors, the per-step weights and any normalisation are
    precomputed outside the sequential scan. Only the numerator update is
    carried.

    Args:
        carry: running numerator ``nu`` of shape (B, H, L, D).
        args: tuple ``(Qs_t, V_t, revert_maxi, add_maxi)`` with shapes
            (B, H, L), (B, H, D), (B, H, L) and (B, H, L).

    Returns:
        ``(nu_t, y_t)`` with ``y_t = sum_l Qs_t[l] * nu_t[l]`` of shape
        (B, H, D).
    """
    q_t, v_t, rescale, weight = args
    nu_t = jnp.einsum("BHLD,BHL->BHLD", carry, rescale) + jnp.einsum(
        "BHL,BHD->BHLD", weight, v_t
    )
    return nu_t, jnp.einsum("BHL,BHLD->BHD", q_t, nu_t)


class CausalRopeLatteMachiattoChunk(nn.Module):
    """Chunked causal softmax attention mixed with Latte attention.

    Each head predicts ``L/H + 1`` mixture weights per position (softmax of the
    ``Wq`` projection of a convolved input). Weight 0 scales exact causal
    softmax attention computed inside fixed chunks of ``config.att_block_len``
    positions. The remaining weights scale Latte attention over the whole
    causal prefix (``latte_attention``), which uses rotary (RoPE or XPos)
    embeddings on the values and outputs.
    """

    config: Config
    # Unroll factor forwarded to ``jax.lax.scan`` in ``latte_attention``.
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def attention(rot_embeds, attn_dropout, q, k, v):
        """Causal softmax attention within a single chunk.

        Args:
            rot_embeds: ``XPos`` instance or RoPE module, applied to ``q`` and
                ``k`` of this chunk only.
            attn_dropout: dropout module applied to the attention weights.
            q, k, v: Tensor(B, H, T, D) where ``T`` is the chunk length.

        Returns:
            Tensor(B, H, T, D).
        """
        T = q.shape[2]
        # causal mask to ensure that attention is only applied to the left in the input sequence
        mask = jnp.tril(jnp.ones(shape=(T, T))).reshape(1, 1, T, T)

        if isinstance(rot_embeds, XPos):
            k = rot_embeds(k, offset=0, downscale=True)
            q = rot_embeds(q, offset=0, downscale=False)
        else:
            k = rot_embeds(k)
            q = rot_embeds(q)

        # manual implementation of scaled dot-product attention
        att = (q @ jnp.swapaxes(k, -2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
        # Large negative fill instead of -inf; the diagonal is never masked,
        # so every row keeps at least one unmasked logit.
        att = jnp.where(mask == 0, -9e15, att)
        att = jax.nn.softmax(att, axis=-1)

        att = attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        return y

    def latte_attention(self, rot_embeds, Qs, K, V):
        """Latte attention over the full causal prefix, via ``jax.lax.scan``.

        Values are rotated by R^{-s} before the scan and outputs by R^{t} after
        it, so the result depends on relative positions.

        Args:
            rot_embeds: ``XPos`` instance or RoPE module exposing
                ``apply_vapor``.
            Qs: Tensor(T, B, H, L) mixture weights for the latent states.
            K: Tensor(T, B, H, L) latent-state logits.
            V: Tensor(T, B, H, C) values.

        Returns:
            Tensor(B, H, T, C).
        """
        T, B, H, C = V.shape
        L = Qs.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
        # calc R^{-s}x_s
        if isinstance(rot_embeds, XPos):
            # TBHD -> BHTD for XPos, then back to TBHD
            V = rot_embeds(V.transpose(1, 2, 0, 3), offset=0, downscale=True).transpose(
                2, 0, 1, 3
            )
        else:
            V = rot_embeds.apply_vapor(mat=V, neg=True)

        _, y = jax.lax.scan(
            accumulate,
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # calc R^t \sum_l ...
        if isinstance(rot_embeds, XPos):
            y = y.transpose(1, 2, 0, 3)  # TBHD -> BHTD
            return rot_embeds(y, offset=0, downscale=False)
        y = rot_embeds.apply_vapor(mat=y, neg=False)
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @nn.compact
    def __call__(self, X, train=False):
        """Apply the layer.

        Args:
            X: Tensor(B, T, C) with C == config.hidden_dim. T must be a
                multiple of ``config.att_block_len`` (the chunk reshape
                requires it).
            train: enables dropout when True.

        Returns:
            Tensor(B, T, config.hidden_dim).

        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
            )
        else:
            # Fail fast: both attention branches below need rot_embeds.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'."
            )

        # self attention
        c_attn = nn.Dense(
            3 * self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        # latte attention
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L + self.config.nheads),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)
        conv = Conv1D(
            nchannels=self.config.hidden_dim,
            out_channels=self.config.hidden_dim,
            kernel_size=3,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )

        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        B, T, C = X.shape  # batch size, sequence length, embedding dim
        H = self.config.nheads
        att_dim = C // H

        # query/key/value for all heads, head axis moved forward: (B, H, T, hs)
        q, k, v = jnp.split(c_attn(X), 3, axis=2)
        k = k.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        q = q.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)

        # split the time axis into chunks: (B, H, N, att_block_len, hs)
        k = k.reshape(B, H, -1, self.config.att_block_len, att_dim)
        q = q.reshape(B, H, -1, self.config.att_block_len, att_dim)
        v = v.reshape(B, H, -1, self.config.att_block_len, att_dim)

        Y = conv(X)
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, H, -1)
        # multi head implementation
        V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, Y).reshape(T, B, H, -1)

        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        # vmap the per-chunk attention over the chunk axis
        fn = partial(self.attention, rot_embeds, attn_dropout)
        p_s_l0 = jax.vmap(fn, in_axes=2, out_axes=2)(q, k, v)
        p_s_l0 = p_s_l0.reshape(B, H, T, att_dim)  # BHNLD -> BHTD
        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        # Weight with the softmax probability Qs (not the raw logit Q) so the
        # chunk term and the Latte terms form a proper mixture.
        causal_att = jnp.einsum("TBH,BHTD->BHTD", Qs[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        # This class defines ``latte_attention`` (there is no latte_attention4).
        latte_att = self.latte_attention(rot_embeds, Qs[:, :, :, 1:], K=K, V=V)  # BHTD
        y = causal_att + latte_att
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return c_proj(y)


class CausalRopeLatteMachiattoSliding(nn.Module):
    """Sliding-window softmax attention mixed with Latte attention.

    Mixture weights (softmax over ``L/H + 1`` states per head) are projected
    from an RG-LRU pass over the input. Weight 0 scales exact causal
    sliding-window attention (window ``config.att_block_len``). The remaining
    weights scale Latte attention over the whole prefix, evaluated with a
    parallel scan (``latte_attention4``).
    """

    config: Config
    # Unroll factor for jax.lax.scan in ``latte_attention``.
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    def latte_attention4(self, rot_embeds, Q, K, V):
        """Latte attention evaluated with a parallel (associative) scan.

        The per-step numerator of shape (T, B, H, L, D) is materialised.

        Args:
            rot_embeds: unused; kept for interface parity with
                ``latte_attention``.
            Q: jax.Array(T, B, H, L) mixture weights.
            K: jax.Array(T, B, H, L) latent-state logits.
            V: jax.Array(T, B, H, D) values.

        Returns:
            jax.Array(B, H, T, D).
        """
        # Running max over time for stability; treated as a constant
        # (no gradient), which is also faster.
        maxi = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))
        # Factor that re-expresses step t-1 (relative to max_{t-1}) relative
        # to max_t; it is exp(0) = 1 at t = 0.
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(A, B):
            # Compose two segments of  state_t = state_{t-1} * rescale_t + inc_t.
            rmA, amA, nuA = A
            rmB, amB, nuB = B
            nu = nuA * rmB[..., None] + nuB
            alpha = amA * rmB + amB
            return (rmA * rmB, alpha, nu)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Q / alpha, y)
        return y.transpose(1, 2, 0, 3)

    def latte_attention(self, rot_embeds, Qs, K, V):
        """Sequential (``jax.lax.scan``) counterpart of ``latte_attention4``.

        Not used by ``__call__``. ``rot_embeds`` is unused.

        Args:
            Qs: Tensor(T, B, H, L) mixture weights.
            K: Tensor(T, B, H, L) latent-state logits.
            V: Tensor(T, B, H, C) values.

        Returns:
            Tensor(B, H, T, C).
        """
        _, B, H, C = V.shape
        L = Qs.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)

        _, y = jax.lax.scan(
            accumulate,
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @nn.compact
    def __call__(self, X, train=False):
        """Apply the layer.

        Args:
            X: Tensor(B, T, C) with C == config.hidden_dim.
            train: enables dropout when True.

        Returns:
            Tensor(B, T, config.hidden_dim). The mixture weights are also sown
            into the "intermediates" collection under "Qs".

        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=512,  # self.config.max_seq_len,
            )
        else:
            # Fail fast: sliding attention below needs rot_embeds.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'."
            )

        # NOTE: submodules/params are created in the original order so that
        # Flax auto-names and parameter RNG streams are unchanged.
        init = jax.nn.initializers.normal(stddev=self.config.initializer_range)

        # self attention
        q_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="q_proj",
        )
        k_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="k_proj",
        )
        v_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="v_proj",
        )

        # latte attention: projections of the L-dimensional RG-LRU output
        Wk = self.param("Wk", init, (self.config.L, self.config.L))
        Wq = self.param(
            "Wq", init, (self.config.L, self.config.L + self.config.nheads)
        )
        Wk, Wq = promote_dtype(Wk, Wq, dtype=self.dtype)

        lru_in = self.param("lru_in", init, (self.config.hidden_dim, self.config.L))
        lru = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
        )

        # regularization
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        attn_dropout = nn.Dropout(self.config.dropout_att, deterministic=not train)
        B, T, C = X.shape  # batch size, sequence length, embedding dim
        H = self.config.nheads
        att_dim = C // H

        # query/key/value for all heads, head axis moved forward: (B, H, T, hs)
        q, k = q_proj(X), k_proj(X)
        v = v_proj(X)
        k = k.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        q = q.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)

        sliding_attention = SlidingWindowAtt(
            window_size=self.config.att_block_len, exact_windowsize=True, causal=True
        )
        p_s_l0 = sliding_attention(
            q,
            k,
            v,
            input_mask=None,
            attn_dropout=attn_dropout,
            rot_embeds=rot_embeds,
        )

        # Recurrent features feeding the Latte mixture weights and keys.
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)  # (B, T)
        lru_out, _ = lru(
            jnp.einsum("DL,BTD->BTL", lru_in, X), pos_ids, return_cache=False
        )
        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, lru_out).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, lru_out).reshape(T, B, H, -1)
        Q = jax.nn.softmax(Q, axis=-1)
        Q = Q_drop(Q)

        p_s_l0 = p_s_l0.reshape(B, H, T, att_dim)
        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        causal_att = jnp.einsum("TBH,BHTD->BHTD", Q[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        v = v.transpose(2, 0, 1, 3)  # BHTD -> TBHD
        latte_att = self.latte_attention4(rot_embeds, Q[:, :, :, 1:], K=K, V=v)  # BHTD
        y = causal_att + latte_att
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)

        self.sow("intermediates", "Qs", Q)
        return c_proj(y)


class CausalRopeLatteMachiattoSlidingSWARGLRU(nn.Module):
    """Ablation by removing latte.

    Causal sliding-window softmax attention whose queries and keys are
    projected from an RG-LRU (+ RMSNorm) pass over the input. Values are
    projected from the raw input. The ``latte_attention*`` helpers are kept
    for interface parity but are not used by ``__call__``.
    """

    config: Config
    # Unroll factor for jax.lax.scan in ``latte_attention``.
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    def latte_attention4(self, rot_embeds, Q, K, V):
        """Latte attention evaluated with a parallel (associative) scan.

        The per-step numerator of shape (T, B, H, L, D) is materialised.

        Args:
            rot_embeds: unused; kept for interface parity.
            Q: jax.Array(T, B, H, L) mixture weights.
            K: jax.Array(T, B, H, L) latent-state logits.
            V: jax.Array(T, B, H, D) values.

        Returns:
            jax.Array(B, H, T, D).
        """
        # Running max over time for stability; treated as a constant.
        maxi = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))
        # Rescale factor from max_{t-1} to max_t; exp(0) = 1 at t = 0.
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(A, B):
            # Compose two segments of  state_t = state_{t-1} * rescale_t + inc_t.
            rmA, amA, nuA = A
            rmB, amB, nuB = B
            nu = nuA * rmB[..., None] + nuB
            alpha = amA * rmB + amB
            return (rmA * rmB, alpha, nu)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Q / alpha, y)
        return y.transpose(1, 2, 0, 3)

    def latte_attention(self, rot_embeds, Qs, K, V):
        """Sequential (``jax.lax.scan``) counterpart of ``latte_attention4``.

        ``rot_embeds`` is unused.

        Args:
            Qs: Tensor(T, B, H, L) mixture weights.
            K: Tensor(T, B, H, L) latent-state logits.
            V: Tensor(T, B, H, C) values.

        Returns:
            Tensor(B, H, T, C).
        """
        _, B, H, C = V.shape
        L = Qs.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)

        _, y = jax.lax.scan(
            accumulate,
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @nn.compact
    def __call__(self, X, train=False):
        """Apply the layer.

        Args:
            X: Tensor(B, T, C) with C == config.hidden_dim.
            train: enables dropout when True.

        Returns:
            Tensor(B, T, config.hidden_dim).

        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=512,  # self.config.max_seq_len,
            )
        else:
            # Fail fast: sliding attention below needs rot_embeds.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'."
            )

        init = jax.nn.initializers.normal(stddev=self.config.initializer_range)

        # self attention
        q_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="q_proj",
        )
        k_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="k_proj",
        )
        v_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="v_proj",
        )

        lru_in = self.param("lru_in", init, (self.config.hidden_dim, self.config.L))
        lru = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
        )

        # regularization
        # NOTE(review): Q_drop is unused in this ablation. It is still
        # constructed first so that attn_dropout keeps the same auto-generated
        # module name (and presumably the same dropout RNG stream) as before;
        # confirm before removing.
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)  # noqa: F841
        attn_dropout = nn.Dropout(self.config.dropout_att, deterministic=not train)
        B, T, C = X.shape  # batch size, sequence length, embedding dim
        H = self.config.nheads
        att_dim = C // H

        # Recurrent features that feed the queries and keys.
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)  # (B, T)
        Y, _ = lru(jnp.einsum("DL,BTD->BTL", lru_in, X), pos_ids, return_cache=False)
        Y = RMSNorm(self.config.L, dtype=self.dtype)(Y)

        # queries/keys from the recurrent features, values from the raw input;
        # head axis moved forward: (B, H, T, hs)
        Q, K = q_proj(Y), k_proj(Y)
        v = v_proj(X)
        K = K.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        Q = Q.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)

        sliding_attention = SlidingWindowAtt(
            window_size=self.config.att_block_len, exact_windowsize=True, causal=True
        )
        p_s_l0 = sliding_attention(
            Q,
            K,
            v,
            input_mask=None,
            attn_dropout=attn_dropout,
            rot_embeds=rot_embeds,
        )

        y = p_s_l0.reshape(B, H, T, att_dim)
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)

        # NOTE(review): unlike the Latte variants, "Qs" here records the
        # sliding-window queries (B, H, T, att_dim), not mixture weights.
        self.sow("intermediates", "Qs", Q)
        return c_proj(y)


class CausalRopeLatteMachiattoSlidingOld(CausalRopeLatteMachiattoChunk):
    """Sliding-window attention mixed with the inherited scan-based Latte attention.

    Weight 0 of the per-head softmax mixture scales exact causal
    sliding-window attention (window ``config.att_block_len``). The remaining
    weights scale ``CausalRopeLatteMachiattoChunk.latte_attention``, which
    applies the rotary embeddings to values and outputs.
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X, train=False):
        """Apply the layer.

        Args:
            X: Tensor(B, T, C) with C == config.hidden_dim.
            train: enables dropout when True.

        Returns:
            Tensor(B, T, config.hidden_dim).

        Raises:
            ValueError: if ``config.embed_type`` is neither "rope" nor "xpos".
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=512,  # self.config.max_seq_len,
            )
        else:
            # Fail fast: both attention branches below need rot_embeds.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'."
            )

        init = jax.nn.initializers.normal(stddev=self.config.initializer_range)

        # self attention
        q_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="q_proj",
        )
        k_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="k_proj",
        )
        v_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
            name="v_proj",
        )

        # latte attention: projections of the hidden_dim-wide block input
        Wk = self.param("Wk", init, (self.config.hidden_dim, self.config.L))
        Wq = self.param(
            "Wq", init, (self.config.hidden_dim, self.config.L + self.config.nheads)
        )
        Wk, Wq = promote_dtype(Wk, Wq, dtype=self.dtype)

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=init,
        )

        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        B, T, C = X.shape  # batch size, sequence length, embedding dim
        H = self.config.nheads
        att_dim = C // H

        # query/key/value for all heads, head axis moved forward: (B, H, T, hs)
        q, k, v = q_proj(X), k_proj(X), v_proj(X)
        k = k.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        q = q.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H, att_dim).transpose(0, 2, 1, 3)

        sliding_attention = SlidingWindowAtt(
            window_size=self.config.att_block_len, exact_windowsize=True, causal=True
        )

        # The convolution that used to produce ``Y`` is disabled, so ``Y`` was
        # undefined (NameError). Project the block input directly; its width
        # (hidden_dim) matches the first axis of Wq/Wk.
        Q = jnp.einsum("DL,BTD->TBL", Wq, X).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, H, -1)
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        p_s_l0 = sliding_attention(
            q,
            k,
            v,
            input_mask=None,
            attn_dropout=attn_dropout,
            rot_embeds=rot_embeds,
        )
        p_s_l0 = p_s_l0.reshape(B, H, T, att_dim)
        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        # Weight with the softmax probability Qs, not the raw logit Q.
        causal_att = jnp.einsum("TBH,BHTD->BHTD", Qs[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        v = v.transpose(2, 0, 1, 3)  # BHTD -> TBHD
        # The parent class provides ``latte_attention`` (there is no
        # latte_attention4 on CausalRopeLatteMachiattoChunk).
        latte_att = self.latte_attention(rot_embeds, Qs[:, :, :, 1:], K=K, V=v)  # BHTD
        y = causal_att + latte_att
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return c_proj(y)


# class CausalRopeLatteMachiattoSlidingOld(CausalRopeLatteMachiattoChunk):
#     config: Config
#     unroll: int = 100
#     dtype: jnp.dtype = jnp.float32

#     @nn.compact
#     def __call__(self, X, train=False):
#         if self.config.embed_type == "rope":
#             rot_embeds = RopeEmbeds(
#                 n_pos=self.config.pos_embed_max_len,
#                 d_model=self.config.hidden_dim // self.config.nheads,
#             )
#         elif self.config.embed_type == "xpos":
#             rot_embeds = XPos(
#                 head_dim=self.config.hidden_dim // self.config.nheads,
#                 scale_base=self.config.max_seq_len,
#             )
#         # self attention
#         c_attn = nn.Dense(
#             3 * self.config.hidden_dim,
#             use_bias=False,
#             dtype=self.dtype,
#             kernel_init=jax.nn.initializers.normal(
#                 stddev=self.config.initializer_range
#             ),
#         )
#         # latte attention
#         Wk = self.param(
#             "Wk",
#             jax.nn.initializers.normal(stddev=self.config.initializer_range),
#             (self.config.hidden_dim, self.config.L),
#         )
#         Wq = self.param(
#             "Wq",
#             jax.nn.initializers.normal(stddev=self.config.initializer_range),
#             (self.config.hidden_dim, self.config.L + self.config.nheads),
#         )
#         # Wv = self.param(
#         #     "Wv",
#         #     jax.nn.initializers.normal(stddev=self.config.initializer_range),
#         #     (self.config.hidden_dim, self.config.hidden_dim),
#         # )
#         # Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)
#         Wk, Wq = promote_dtype(Wk, Wq, dtype=self.dtype)
#         conv = Conv1D(
#             nchannels=self.config.hidden_dim,
#             out_channels=self.config.hidden_dim,
#             kernel_size=3,
#             dtype=self.dtype,
#         )

#         # output projection
#         c_proj = nn.Dense(
#             self.config.hidden_dim,
#             use_bias=False,
#             dtype=self.dtype,
#             kernel_init=jax.nn.initializers.normal(
#                 stddev=self.config.initializer_range
#             ),
#             # kernel_init=jax.nn.initializers.normal(
#             #     stddev=self.config.initializer_range
#             #     / math.sqrt(2 * self.config.nlayers)
#             # ),
#         )

#         # regularization
#         attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
#         Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
#         B, T, C = (
#             X.shape
#         )  # batch size, sequence length, embedding dimensionality (n_embd)

#         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
#         q, k, v = jnp.split(c_attn(X), 3, axis=2)
#         att_dim = C // self.config.nheads
#         k = k.reshape(B, T, self.config.nheads, att_dim).transpose(
#             0, 2, 1, 3
#         )  # (B, nh, T, hs)
#         q = q.reshape(B, T, self.config.nheads, att_dim).transpose(
#             0, 2, 1, 3
#         )  # (B, nh, T, hs)
#         v = v.reshape(B, T, self.config.nheads, att_dim).transpose(
#             0, 2, 1, 3
#         )  # (B, nh, T, hs)

#         sliding_attention = SlidingWindowAtt(
#             window_size=self.config.att_block_len, exact_windowsize=True, causal=True
#         )
#         Y = conv(X)
#         # multi head implementation
#         Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, self.config.nheads, -1)
#         K = jnp.einsum("DL,BTD->TBL", Wk, Y).reshape(T, B, self.config.nheads, -1)
#         # V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, self.config.nheads, -1)
#         Qs = jax.nn.softmax(Q, axis=-1)
#         Qs = Q_drop(Qs)

#         # jax.vmap(fun, in_axes=0, out_axes=0,
#         p_s_l0 = sliding_attention(
#             q,
#             k,
#             v,
#             input_mask=None,
#             attn_dropout=attn_dropout,
#             rot_embeds=rot_embeds,
#         )
#         p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, att_dim)  # BHNLD -> BHTD
#         # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
#         causal_att = jnp.einsum("TBH,BHTD->BHTD", Q[:, :, :, 0], p_s_l0)

#         # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
#         # v = V
#         v = v.transpose(2, 0, 1, 3)  # .reshape(T, B, self.config.nheads, -1)
#         latte_att = self.latte_attention4(rot_embeds, Qs[:, :, :, 1:], K=K, V=v)  # BHTD
#         y = causal_att + latte_att
#         y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
#         y = y.reshape(B, T, -1)
#         return c_proj(y)
