"""
A new architecture that is more complex than the vanilla transformer. It includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type, List
from functools import partial
import einops
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
from ..layers import DepthConv1D, S5Layer, Embedder, Conv1D
from ..layers import SlidingWindowAtt
from latte_trans.config import LMTaskConfig
from .types import GemmaMachiattoCache, LatteCache
from recurrentgemma.jax.layers import RGLRU

# Module-level alias for JAX's associative (parallel prefix) scan.
parallel_scan = jax.lax.associative_scan


class FlaxGemmaRMSNorm(nn.Module):
    """Gemma-style RMS normalisation over the last axis.

    The learnable parameter is stored under the name ``"scale"`` and is
    applied as ``(1 + scale)``. It is initialised to ones, so the effective
    gain right after initialisation is 2.

    Attributes:
        config: task configuration; ``config.hidden_dim`` is the default width.
        width: optional explicit size of the normalised axis.
        dtype: dtype of the returned activations.
    """

    config: LMTaskConfig
    width: int = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        feature_dim = self.config.hidden_dim if self.width is None else self.width
        # Fixed epsilon; the config value (rms_norm_eps) is deliberately not read.
        self.epsilon = 1e-6
        self.weight = self.param(
            "scale", lambda _, shape: jnp.ones(shape), feature_dim
        )

    def __call__(self, hidden_states):
        """Normalise ``hidden_states`` by its root-mean-square along axis -1."""
        # The mean of squares is accumulated in float32 for stability.
        squared = jnp.power(jnp.asarray(hidden_states, dtype=jnp.float32), 2)
        mean_square = squared.mean(-1, keepdims=True)
        # `jnp.sqrt` + division is used instead of `jax.lax.rsqrt`, which does
        # not match `torch.rsqrt`.
        normed = hidden_states / jnp.sqrt(mean_square + self.epsilon)
        gain = 1 + self.weight.astype(self.dtype)
        return gain * normed.astype(self.dtype)


class Gemma2RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings.

    Attributes:
        dim: size of the rotated (head) dimension.
        max_position_embeddings: stored for reference; not used in the
            computation (no position interpolation is applied).
        base: base of the geometric frequency progression.
    """

    dim: int
    max_position_embeddings: int = 2048
    base: int = 10000

    def setup(self):
        exponents = (
            jnp.arange(0, self.dim, 2, dtype=jnp.int32).astype(jnp.float32)
            / self.dim
        )
        self.inv_freq = 1.0 / (self.base**exponents)

    def __call__(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` of shape ``(rows, seq, dim)`` in ``x.dtype``.

        Args:
            x: reference tensor laid out as [bs, heads, seq_len, head_size];
                only ``x.shape[2]`` and ``x.dtype`` are read.
            position_ids: integer positions ``(rows, seq)``, or ``None`` to use
                ``0 .. x.shape[2] - 1`` with a single row.
            seq_len: accepted for interface compatibility and ignored.
        """
        if position_ids is None:
            position_ids = jnp.arange(0, x.shape[2])[None, ...]

        n_rows = position_ids.shape[0]
        # Everything is kept in float32: bfloat16 loses precision on long
        # contexts (see https://github.com/huggingface/transformers/pull/29285).
        inv_freq = self.inv_freq[None, :, None].astype(jnp.float32)
        inv_freq = jnp.broadcast_to(inv_freq, (n_rows, inv_freq.shape[1], 1))
        positions = position_ids[:, None, :].astype(jnp.float32)

        # (rows, dim/2, 1) @ (rows, 1, seq) -> (rows, dim/2, seq) -> (rows, seq, dim/2)
        angles = (inv_freq @ positions).transpose(0, 2, 1)
        emb = jnp.concatenate((angles, angles), axis=-1)

        out_dtype = x.dtype
        return jnp.cos(emb).astype(dtype=out_dtype), jnp.sin(emb).astype(
            dtype=out_dtype
        )


def rotate_half(x, neg=False):
    """Swap the two halves of the last axis of ``x``, negating one of them.

    With ``neg=False`` (classic RoPE) the result is ``[-second, first]``;
    with ``neg=True`` it is ``[second, -first]``.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    halves = (second, -first) if neg else (-second, first)
    return jnp.concatenate(halves, axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding.

    Args:
        q: query array.
        k: key array.
        cos: cosine table of the rotary embedding.
        sin: sine table of the rotary embedding.
        position_ids: deprecated and unused.
        unsqueeze_dim: axis at which a singleton dimension is inserted into
            ``cos`` and ``sin`` so they broadcast against ``q`` and ``k``.
            For tables shaped [batch, seq_len, head_dim], use 1 when q/k are
            [batch, heads, seq_len, head_dim] and 2 when they are
            [batch, seq_len, heads, head_dim].

    Returns:
        Tuple ``(q_embed, k_embed)`` with the rotated query and key.
    """
    cos_b = jnp.expand_dims(cos, unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


def apply_vapor(x, cos, sin, neg=False, unsqueeze_dim=1):
    """Apply a rotary rotation to a single time-major tensor.

    Args:
        x: array laid out as (T, B, H, D).
        cos: cosine table of the rotary embedding.
        sin: sine table of the rotary embedding.
        neg: forwarded to ``rotate_half``; selects which half is negated.
        unsqueeze_dim: axis at which a singleton dimension is inserted into
            ``cos`` and ``sin`` so they broadcast against ``x`` once it has
            been moved to (B, H, T, D). With tables shaped
            [batch, seq_len, head_dim] the default of 1 adds the heads axis.

    Returns:
        The rotated array, back in (T, B, H, D) layout.
    """
    # (T, B, H, D) -> (B, H, T, D) so the tables broadcast over heads.
    batch_major = x.transpose(1, 2, 0, 3)
    cos_b = jnp.expand_dims(cos, unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, unsqueeze_dim)
    rotated = (batch_major * cos_b) + (rotate_half(batch_major, neg=neg) * sin_b)
    # (B, H, T, D) -> (T, B, H, D)
    return rotated.transpose(2, 0, 1, 3)


def repeat_kv(hidden_states: jax.Array, n_rep: int) -> jax.Array:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the
    input (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), with the copies of
    one head stored next to each other. The input is returned unchanged when
    ``n_rep == 1``.
    """
    # Unpacking doubles as a rank check: the input must be 4-D.
    _batch, _kv_heads, _slen, _head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    return jnp.repeat(hidden_states, n_rep, axis=1)


class SlidingWindowAtt:
    """Block-local (sliding window) attention over already-projected Q, K, V.

    Not a Flax module: it owns no parameters. The sequence is cut into blocks
    of ``window_size`` tokens and every query block attends to its own block
    plus the previous one (and the next one in the bidirectional case), so
    each query scores 2W keys (3W when bidirectional). Pass
    ``exact_windowsize=True`` to additionally mask keys that are more than
    ``window_size`` positions away from the query.

    Args:
        window_size: number of previous tokens a query may attend to. In the
            bidirectional case the value is halved (W/2 on each side).
        exact_windowsize: mask keys farther than the window from the query.
        causal: if True, keys located after the query are masked.
        attn_logit_softcapping: if not None, logits are soft-capped with
            ``cap * tanh(logits / cap)``.
        query_pre_attn_scalar: logits are multiplied by
            ``query_pre_attn_scalar ** -0.5``. Defaults to 256; pass ``None``
            to scale by the head dimension of Q instead.
    """

    def __init__(
        self,
        window_size: int,
        exact_windowsize: bool = True,
        causal=True,
        attn_logit_softcapping=None,
        query_pre_attn_scalar=256,
    ):
        self.attn_logit_softcapping = attn_logit_softcapping
        self.exact_windowsize = exact_windowsize
        self.causal = causal
        self.query_pre_attn_scalar = query_pre_attn_scalar
        # Bidirectional attention spends half of the window on each side.
        self.window_size = window_size if causal else window_size // 2

    @staticmethod
    def look_around(x, backward=1, forward=0, pad_value=-1, dim=2):
        """Concatenate each window (axis 1) with its neighbouring windows.

        ``x`` is padded along axis 1 with ``backward`` windows in front and
        ``forward`` windows behind (filled with ``pad_value``); the shifted
        copies are concatenated along ``dim``.
        """
        n_windows = x.shape[1]
        pad_width = len(x.shape) * [(0, 0)]
        pad_width[1] = (backward, forward)
        padded_x = jnp.pad(x, pad_width=pad_width, constant_values=pad_value)
        shifted = [
            padded_x[:, start : (start + n_windows), ...]
            for start in range(backward + forward + 1)
        ]
        return jnp.concatenate(shifted, axis=dim)

    @staticmethod
    def pad_to_multiple(x, multiple, dim=-1, value=0):
        """Right-pad axis ``dim`` of ``x`` up to the next multiple of ``multiple``.

        Returns:
            ``(padded, array)`` where ``padded`` is False (and ``array`` is
            ``x`` itself) when no padding was necessary.
        """
        remainder = -x.shape[dim] % multiple
        if remainder == 0:
            return False, x
        pad_width = len(x.shape) * [(0, 0)]
        pad_width[dim] = (0, remainder)
        return True, jnp.pad(x, pad_width=pad_width, constant_values=value)

    def __call__(self, Q, K, V, input_mask, attn_dropout):
        """Run windowed attention.

        Args:
            Q, K, V: jax.Array(B, H, T, D).
            input_mask: optional jax.Array(B, T) marking valid key positions
                (1/True = attend, 0/False = ignore), or None.
            attn_dropout: callable applied to the attention probabilities.

        Returns:
            jax.Array(B, H, T, D)

        Raises:
            ValueError: if the merged batch*heads size of Q is not a multiple
                of the batch size of ``input_mask``.
        """
        pad_value = -1
        mask_value = -9e15
        window = self.window_size
        forward = 0 if self.causal else 1

        # Merge the leading (batch, head) axes into one for convenience.
        lead_shape = Q.shape[:-2]
        Q, K, V = (t.reshape((-1,) + t.shape[-2:]) for t in (Q, K, V))

        # Auto-pad so the length is divisible by the window size; the padding
        # is dropped again before returning.
        orig_seq_len = Q.shape[1]
        Q, K, V = (
            self.pad_to_multiple(t, window, dim=-2, value=pad_value)[1]
            for t in (Q, K, V)
        )
        merged_batch, T, dim_head = Q.shape
        windows = T // window

        # (b, w * n, d) -> (b, w, n, d)
        bq, bk, bv = (
            t.reshape(t.shape[0], windows, window, t.shape[-1]) for t in (Q, K, V)
        )
        # Prepend the previous window (and append the next one when
        # bidirectional) so the first token of a block still sees W keys.
        bk = self.look_around(bk, backward=1, forward=forward, pad_value=pad_value)
        bv = self.look_around(bv, backward=1, forward=forward, pad_value=pad_value)

        scalar = (
            dim_head
            if self.query_pre_attn_scalar is None
            else self.query_pre_attn_scalar
        )
        sim = jnp.einsum("bhie,bhje->bhij", bq, bk) * scalar**-0.5

        if self.attn_logit_softcapping is not None:
            sim = sim / self.attn_logit_softcapping
            sim = jax.nn.tanh(sim)
            sim = sim * self.attn_logit_softcapping

        # Absolute positions of queries and of the keys visible to them;
        # keys coming from look_around padding carry position `pad_value`.
        b_t = jnp.arange(T).reshape(1, windows, window)
        bq_k = self.look_around(b_t, backward=1, forward=forward, pad_value=pad_value)
        bq_t = b_t[..., :, None]
        bq_k = bq_k[..., None, :]
        pad_mask = bq_k == pad_value

        if self.causal:
            causal_mask = bq_t < bq_k
            if self.exact_windowsize:
                causal_mask = causal_mask | (bq_t > (bq_k + window))
            sim = jnp.where(causal_mask, mask_value, sim)

        # bidirectional case
        if not self.causal and self.exact_windowsize:
            window_mask = ((bq_k - window) > bq_t) | (bq_t > (bq_k + window))
            sim = jnp.where(window_mask, mask_value, sim)

        # Keys introduced by look_around padding are always masked.
        sim = jnp.where(pad_mask, mask_value, sim)

        if input_mask is not None:
            # Window the key-side mask exactly like the keys themselves.
            # Integers are used so that padding with 0 means "ignore".
            key_mask = (jnp.asarray(input_mask) != 0).astype(jnp.int32)
            _, key_mask = self.pad_to_multiple(key_mask, window, dim=-1, value=0)
            key_mask = key_mask.reshape(-1, windows, window)
            key_mask = self.look_around(
                key_mask, backward=1, forward=forward, pad_value=0
            )
            key_mask = key_mask[:, :, None, :]
            # Q/K/V were flattened batch-major over (batch, heads), so each
            # mask row has to be repeated once per head.
            mask_batch = key_mask.shape[0]
            if merged_batch % mask_batch != 0:
                raise ValueError(
                    f"input_mask batch size {mask_batch} does not divide the "
                    f"merged batch*heads size {merged_batch}"
                )
            key_mask = jnp.repeat(key_mask, merged_batch // mask_batch, axis=0)
            sim = jnp.where(key_mask != 0, sim, mask_value)

        attn = jax.nn.softmax(sim, axis=-1)
        attn = attn_dropout(attn)
        out = jnp.einsum("bhij,bhje->bhie", attn, bv)
        # (b, w, n, d) -> (b, w * n, d), then drop the auto-padding.
        out = out.reshape(out.shape[0], T, out.shape[-1])
        out = out[:, :orig_seq_len, :]
        return out.reshape(lead_shape + out.shape[-2:])


def accumulate(carry, args):
    """Scan body: advance the stabilised Latte recurrence by one time step.

    Args:
        carry: ``(nu, alpha, prev_max)`` with shapes (B,H,L,D), (B,H,L), (B,H,L).
        args: ``(Qs_t, K_t, V_t, max_t)`` for the current step with shapes
            (B,H,L), (B,H,L), (B,H,D), (B,H,L); ``max_t`` is the running
            maximum of the keys up to and including this step.

    Returns:
        ``((nu, alpha, max_t), y)`` where ``y`` has shape (B,H,D).
    """
    nu_prev, alpha_prev, max_prev = carry
    q_t, k_t, v_t, max_t = args

    # Factor that re-expresses the old state relative to the new maximum,
    # and the weight of the current key relative to that maximum.
    rescale = jnp.exp(-max_t + max_prev)
    weight = jnp.exp(k_t - max_t)

    alpha_new = jnp.einsum("BHL,BHL->BHL", alpha_prev, rescale) + weight
    nu_new = jnp.einsum("BHLD,BHL->BHLD", nu_prev, rescale) + jnp.einsum(
        "BHL,BHD->BHLD", weight, v_t
    )
    out_t = jnp.einsum("BHL,BHLD->BHD", q_t / alpha_new, nu_new)
    return ((nu_new, alpha_new, max_t), out_t)


class CausalRopeLatteMachiattoSliding(nn.Module):
    """Token mixer that sums a local and a global (latent) attention branch.

    The local branch is RoPE sliding-window attention over the input. The
    global branch is Latte latent attention whose queries/keys are computed
    from the input after a linear map, an RG-LRU layer and an RMS norm. The
    softmaxed latent query weights mix both branches: entry 0 of every head
    weights the sliding-window output, the remaining entries weight the Latte
    states. The sum is passed through the output projection ``out_proj``.

    Attributes:
        config: task configuration (head counts, dims, window size, dropout...).
        attn_logit_softcapping: soft-cap value forwarded to the sliding attention.
        unroll: ``unroll`` argument of the ``jax.lax.scan`` in ``latte_scan``.
        dtype: computation dtype of the dense layers and scan state.
    """

    config: LMTaskConfig
    attn_logit_softcapping: float = 50.0
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X,
        train=False,
        cache: GemmaMachiattoCache = None,
        do_inference=False,
        **kwargs,
    ):
        """Apply the mixer.

        Args:
            X: input activations of shape (B, T, C).
            train: enables the dropout layers when True.
            cache: recurrent state from the previous decoding step; only read
                when ``do_inference`` is True. ``None`` starts a fresh state.
            do_inference: if True, only the last token of ``X`` is processed,
                using ``cache`` for the context.
            **kwargs: accepted and ignored (e.g. ``attention_mask`` and
                ``position_ids`` passed by the decoder layer).

        Returns:
            ``(out, new_cache)``; ``new_cache`` is ``None`` unless
            ``do_inference`` is True. ``out`` is (B, T, hidden_dim), with
            T == 1 in inference mode.
        """
        nheads = self.config.nheads
        # Fall back to standard multi-head attention when the number of
        # key/value heads is unset (falsy).
        if self.config.num_key_value_heads:
            num_key_value_heads = self.config.num_key_value_heads
        else:
            num_key_value_heads = self.config.nheads
        if self.config.head_dim:
            head_dim = self.config.head_dim
        else:
            head_dim = self.config.hidden_dim // self.config.nheads

        # How many query heads share one key/value head.
        num_key_value_groups = self.config.nheads // num_key_value_heads

        rotary_emb = Gemma2RotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=self.config.pos_embed_max_len,
            base=10000.0,
        )

        sliding_att = SlidingWindowAtt(
            window_size=self.config.att_block_len,
            exact_windowsize=True,
            causal=True,
            attn_logit_softcapping=self.attn_logit_softcapping,
        )

        # key, query, value projections for all heads, but in a batch
        q_proj = nn.Dense(
            nheads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="q_proj",
        )

        k_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="k_proj",
        )
        v_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="v_proj",
        )
        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="out_proj",
        )

        # Latte latent key projection: config.L features in total, split over heads.
        Wk = nn.Dense(
            self.config.L,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_Wk",
        )
        # Latte latent query projection: one extra logit per head (L + nheads
        # features), which later becomes the weight of the sliding-window branch.
        Wq = nn.Dense(
            self.config.L + self.config.nheads,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_Wq",
        )

        # Input projection feeding the RG-LRU layer.
        lru_in = nn.Dense(
            self.config.L,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_lru_in",
        )

        # RG-LRU layer from recurrentgemma, applied before the latent projections.
        latte_conv = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
            name="latte_rglru",
        )
        latte_lru_norm = FlaxGemmaRMSNorm(
            config=self.config,
            dtype=self.dtype,
            width=self.config.L,
            name="latte_lru_norm",
        )
        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        # batch size, sequence length, embedding dimensionality (n_embd)
        if do_inference:
            # Only the newest token is processed; earlier context comes from the cache.
            X = X[:, -1:, :]
            if cache is None:
                cache = self.init_mach_cache(X)
                window_att = X
            else:
                window_att = cache.window_att
                # Small Caveat: window size means w prev tokens (total = w + 1)
                if window_att.shape[1] > self.config.att_block_len:
                    # Drop the oldest token so at most w + 1 tokens are kept
                    # after the new one is appended.
                    window_att = window_att[:, 1:, :]
                window_att = jnp.concat([window_att, X], axis=1)
                # The sliding attention below runs over the whole cached window.
                X = window_att

        ### Sliding Attention ###
        B, T, C = X.shape
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k = q_proj(X), k_proj(X)
        v = v_proj(X)
        k = k.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        q = q.reshape(B, T, nheads, -1).transpose(0, 2, 1, 3)  # (B, nh, T, hs)
        v = v.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)

        # With position_ids=None the rotary positions are 0..T-1 of the current
        # X (in inference mode: positions inside the cached window).
        cos, sin = rotary_emb(v, position_ids=None)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Expand key/value heads so they match the number of query heads.
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)
        p_s_l0 = sliding_att(q, k, v, input_mask=None, attn_dropout=attn_dropout)
        p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, -1)  # -> BHTD

        #### Latte ###
        if do_inference:
            # eliminate data from local window
            p_s_l0 = p_s_l0[:, :, -1:, :]
            X = X[:, -1:, :]
            T = 1
            pos_ids = cache.positions
            # NOTE(review): presumably RGLRU takes (x, segment positions, cache)
            # and returns (output, new cache) - confirm against recurrentgemma.
            X, new_conv_cache = latte_conv(
                lru_in(X), pos_ids, cache=cache.conv, return_cache=True
            )
        else:
            # One row of positions 0..T-1 per batch element.
            pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
            X, _ = latte_conv(lru_in(X), pos_ids, return_cache=False)

        X = latte_lru_norm(X)
        # multi head implementation
        q = Wq(X)
        k = Wk(X)
        # Move time first and split the features over heads: (T, B, H, -1).
        # Per head q has one more entry than k (the Wq layer has nheads extra
        # features) - assumes config.L is divisible by nheads.
        q = jnp.einsum("BTL->TBL", q).reshape(T, B, self.config.nheads, -1)
        k = jnp.einsum("BTL->TBL", k).reshape(T, B, self.config.nheads, -1)

        # Distribution over [sliding-window branch, latent states] for every token.
        q = jax.nn.softmax(q, axis=-1)
        q = Q_drop(q)

        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        causal_att = jnp.einsum("TBH,BHTD->BHTD", q[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        # (B, H, T, D) -> T B H D
        v = v.transpose(2, 0, 1, 3)
        if do_inference:
            # eliminate data from local window
            v = v[-1:]
            latte_att, new_latte_cache = self.latte_inference(
                cache=cache.latte, Q=q[:, :, :, 1:], K=k, V=v
            )
        else:
            # BHTD
            latte_att = self.latte_scan(q[:, :, :, 1:], K=k, V=v)
        v = causal_att + latte_att
        v = v.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        v = v.reshape(B, T, -1)
        v = c_proj(v)
        if do_inference:
            new_cache = GemmaMachiattoCache(
                positions=pos_ids + 1,
                conv=new_conv_cache,
                latte=new_latte_cache,
                window_att=window_att,
            )
            return v, new_cache

        return v, None

    def latte_scan(self, Qs, K, V):
        """Latte attention computed sequentially with ``jax.lax.scan``.

        The scan body is ``accumulate`` (wrapped in ``jax.checkpoint``); the
        running maximum of ``K`` over time is used for numerical stability.

        Args:
            Qs: jax.Array(T,B,H,L) - softmaxed latent query weights
            K: jax.Array(T,B,H,L)
            V: jax.Array(T,B,H,D)
        Returns:
            jax.Array(B,H,T,D)
        """
        T, B, H, C = V.shape
        L = Qs.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
        # The carry is (nu, alpha, previous max); the first key serves as the
        # initial "previous max".
        _, y = jax.lax.scan(
            jax.checkpoint(accumulate),
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)

    def latte_parallel(self, Q, K, V):
        """Latte attention computed with a parallel (associative) scan.

        Alternative to ``latte_scan`` that materialises the per-step
        (T,B,H,L,D) contributions and combines them with ``parallel_scan``.

        Args:
            Q: jax.Array(T,B,H,L)
            K: jax.Array(T,B,H,L)
            V: jax.Array(T,B,H,D)
        Returns:
            jax.Array(B,H,T,D)
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)
        # revert maxi: factor that moves the state from the previous running
        # maximum to the current one (exp(0) = 1 for the first step)
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(A, B):
            # Combine two scan elements (rescale factor, alpha, nu): the left
            # element is rescaled by the right element's factor, then added.
            rmA, amA, nuA = A
            rmB, amB, nuB = B
            nu = nuA * rmB[..., None] + nuB
            alpha = amA * rmB + amB
            return (rmA * rmB, alpha, nu)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Q / alpha, y)
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)

    def latte_inference(
        self, cache: LatteCache, Q: jax.Array, K: jax.Array, V: jax.Array
    ):
        """
        Recurrent version of forward, expects one token at a time
        Args:
            cache: LatteCache with the previous recurrent state, or None to
                start from a freshly initialised state
            Q: jax.Array(1,B,H,L) - softmaxed latent query weights
            K: jax.Array(1,B,H,L)
            V: jax.Array(1,B,H,D)
        Returns:
            (y, new_cache) with y: jax.Array(B,H,1,D)
        """
        K_t, Qs_t, V_t = K[0], Q[0], V[0]
        if cache is None:
            # initialise hidden state
            cache = self.init_latte_cache(K_t, D=V.shape[-1])

        alpha = cache.alpha
        nu = cache.nu
        prev_max = cache.prev_max

        # Same update as `accumulate`, with the running maximum computed on the fly.
        c_max = jnp.maximum(prev_max, K_t)
        revert_maxi = jnp.exp(-c_max + prev_max)
        add_maxi = jnp.exp(K_t - c_max)

        alpha = jnp.einsum("BHL,BHL->BHL", alpha, revert_maxi)
        alpha += add_maxi
        nu = jnp.einsum("BHLD,BHL->BHLD", nu, revert_maxi)
        nu += jnp.einsum("BHL,BHD->BHLD", add_maxi, V_t)

        y = jnp.einsum("BHL,BHLD->BHD", Qs_t / alpha, nu)
        # BHD -> BHTD
        y = y[:, :, None, :]
        # store the updated hidden state
        new_cache = LatteCache(alpha=alpha, nu=nu, prev_max=c_max)
        return y, new_cache

    def init_mach_cache(self, X):
        """Create the cache for the first decoding step.

        Positions start at zero with shape (B, 1); the window, RG-LRU and
        Latte states are left as ``None``.
        """
        B = X.shape[0]
        positions = jnp.zeros(shape=(B, 1), dtype=jnp.int32)
        window_att = None
        conv = None
        latte = None
        return GemmaMachiattoCache(
            positions=positions, conv=conv, latte=latte, window_att=window_att
        )

    def init_latte_cache(self, Kt, D):
        """Create a zeroed Latte state; the first key is the initial running max.

        Args:
            Kt: jax.Array(B,H,L) - key of the first token
            D: size of the value head dimension
        """
        B = Kt.shape[0]
        # Latent states per head.
        H, L = self.config.nheads, self.config.L // self.config.nheads
        # Created with the default dtype (no dtype is passed).
        alpha = jnp.zeros(shape=(B, H, L))
        nu = jnp.zeros((B, H, L, D))
        prev_max = Kt
        return LatteCache(alpha=alpha, nu=nu, prev_max=prev_max)


class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(up(x) * gelu(gate(x)))``.

    All three projections are bias-free dense layers; the activation is the
    tanh-approximated GELU.
    """

    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        init = jax.nn.initializers.normal(self.config.initializer_range)
        make_dense = partial(
            nn.Dense, use_bias=False, dtype=self.dtype, kernel_init=init
        )
        self.activation_fn = partial(jax.nn.gelu, approximate=True)
        self.gate_proj = make_dense(self.config.intermediate_dim)
        self.down_proj = make_dense(self.config.hidden_dim)
        self.up_proj = make_dense(self.config.intermediate_dim)

    def __call__(self, hidden_states):
        """Map (..., hidden_dim) activations to (..., hidden_dim)."""
        expanded = self.up_proj(hidden_states)
        gate = self.activation_fn(self.gate_proj(hidden_states))
        return self.down_proj(expanded * gate)


class GemmaDecoderLayer(nn.Module):
    """One decoder block: token mixer and MLP, each wrapped in pre/post RMS norms.

    Both sub-blocks follow the pattern
    ``x + post_norm(sublayer(pre_norm(x)))``.
    """

    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg, dt = self.config, self.dtype
        # Token mixer: sliding-window attention combined with Latte.
        self.self_attn = CausalRopeLatteMachiattoSliding(config=cfg, dtype=dt)
        self.input_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="input_layernorm"
        )
        self.post_attention_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="post_attention_layernorm"
        )
        self.pre_feedforward_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="pre_feedforward_layernorm"
        )
        self.post_feedforward_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="post_feedforward_layernorm"
        )
        self.mlp = GemmaMLP(cfg, dtype=dt, name="mlp")

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        train: bool = False,
        cache: GemmaMachiattoCache = None,
        do_inference=False,
    ):
        """Run the block and return ``(hidden_states, cache)``.

        ``cache`` is whatever the token mixer returns as its second output.
        """
        # --- token mixing sub-block ---
        mixer_in = self.input_layernorm(hidden_states)
        mixer_out, cache = self.self_attn(
            mixer_in,
            attention_mask=attention_mask,
            position_ids=None,
            train=train,
            cache=cache,
            do_inference=do_inference,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(mixer_out)

        # --- feed-forward sub-block ---
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_feedforward_layernorm(mlp_out)
        return hidden_states, cache


class GemmaDecoder(nn.Module):
    """Stack of ``config.nlayers`` decoder layers followed by a final RMS norm.

    The layers are applied through ``nn.scan`` over a single
    ``GemmaDecoderLayer`` definition, so their parameters are stacked along a
    leading axis. Scanning compiles faster than unrolling, at the price of
    merged blocks and the loss of some default optimisations.

    Attributes:
        config: task configuration.
        sharded: when False, ``metadata_params={"partition_name": None}`` is
            passed to the scan (no partitioning over the layer axis).
        dtype: computation dtype.
    """

    config: LMTaskConfig
    sharded: bool
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X,
        attention_mask,
        train=False,
        cache: List[GemmaMachiattoCache] = None,
        do_inference: bool = False,
    ):
        """Return ``(normalised hidden states, stacked per-layer caches)``."""
        final_norm = FlaxGemmaRMSNorm(config=self.config, dtype=self.dtype, name="norm")
        block = GemmaDecoderLayer(
            config=self.config, dtype=self.dtype, name="residual_block"
        )

        def module_apply(module, carry, layer_cache):
            # scan body: returns (new carry, per-layer output)
            hidden, new_layer_cache = module(
                carry,
                attention_mask,
                train=train,
                cache=layer_cache,
                do_inference=do_inference,
            )
            return (hidden, new_layer_cache)

        scan_kwargs = {
            "variable_axes": {"params": 0, "intermediates": 0},
            "split_rngs": {"params": True, "dropout": True},
            "length": self.config.nlayers,
        }
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_kwargs["metadata_params"] = {"partition_name": None}

        X, cache = nn.scan(module_apply, **scan_kwargs)(block, X, cache)
        return final_norm(X), cache
