"""
The new architecture, which is more complex than the vanilla transformer. It includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type
from functools import partial
import einops
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
from ..layers import DepthConv1D, S5Layer, Embedder, Conv1D, RMSNorm
from ..layers import SlidingWindowAtt
from latte_trans.config import LMTaskConfig, ATT_TYPE
from recurrentgemma.jax.layers import RGLRU

# Alias for JAX's associative (parallel prefix) scan; `latte_attention4`
# uses it to evaluate the Latte recurrence in parallel over time.
parallel_scan = jax.lax.associative_scan


class FlaxGemmaRMSNorm(nn.Module):
    """Gemma-style RMS normalisation with a ``(1 + scale)`` gain.

    The mean of squares is computed in float32; the gain and the normalised
    activations are cast to ``dtype`` before the final multiply.

    Attributes:
        config: Model config; ``config.hidden_dim`` is the feature width used
            when ``width`` is not given.
        width: Optional explicit size of the normalised (last) axis.
        dtype: Dtype of the gain and of the returned activations.
    """

    config: LMTaskConfig
    width: int = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        features = self.config.hidden_dim if self.width is None else self.width
        # Fixed epsilon; config.rms_norm_eps is intentionally not consulted.
        self.epsilon = 1e-6

        def _ones_init(_rng, shape):
            return jnp.ones(shape)

        # NOTE(review): the scale starts at ones, so with the (1 + scale)
        # parameterisation below the initial gain is 2, not 1 -- confirm this
        # is intended when training from scratch.
        self.weight = self.param("scale", _ones_init, features)

    def __call__(self, hidden_states):
        as_f32 = jnp.asarray(hidden_states, dtype=jnp.float32)
        mean_square = jnp.mean(jnp.square(as_f32), axis=-1, keepdims=True)
        # `jax.numpy.sqrt` is used because `jax.lax.rsqrt` does not match
        # `torch.rsqrt` numerically.
        normed = hidden_states / jnp.sqrt(mean_square + self.epsilon)
        gain = 1 + self.weight.astype(self.dtype)
        return gain * normed.astype(self.dtype)


class Gemma2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables in the Hugging Face Gemma layout.

    ``__call__`` returns ``cos``/``sin`` tables whose last axis is the
    frequency vector duplicated (``[freqs, freqs]``), the layout expected by
    ``rotate_half`` / ``apply_rotary_pos_emb``.

    Attributes:
        dim: Rotary dimension (the attention head dim in this file).
        max_position_embeddings: Stored but currently unused (context
            interpolation, https://arxiv.org/pdf/2306.15595, is disabled).
        base: Frequency base; ``inv_freq[i] = base ** (-2 i / dim)``.
    """

    dim: int
    max_position_embeddings: int = 2048
    base: float = 10000

    def setup(self):
        exponents = (
            jnp.arange(0, self.dim, 2, dtype=jnp.int32).astype(jnp.float32) / self.dim
        )
        self.inv_freq = 1.0 / (self.base**exponents)

    def __call__(self, x, position_ids, seq_len=None):
        """Compute the cos/sin tables.

        Args:
            x: Only its dtype and, when ``position_ids`` is None, its
                ``shape[2]`` (sequence length) are used; callers in this file
                pass (batch, num_heads, seq_len, head_dim).
            position_ids: Integer positions of shape (batch, seq_len), or None
                for ``arange(x.shape[2])`` shared across the batch.
            seq_len: Unused; kept for API compatibility.

        Returns:
            ``(cos, sin)``, each of shape (batch, seq_len, 2 * ceil(dim / 2))
            -- batch is 1 when ``position_ids`` is None -- cast to ``x.dtype``.
        """
        if position_ids is None:
            position_ids = jnp.arange(0, x.shape[2])[None, ...]

        # Outer product positions x inverse frequencies, in float32.
        # Deliberately an elementwise broadcast multiply, not a matmul:
        # dot_general at default precision runs float32 inputs through
        # bfloat16 passes on TPU (and possibly TF32 on GPUs), which cannot
        # represent large integer positions exactly and corrupts the rotation
        # angles on long contexts. The broadcast product is exact float32.
        positions = position_ids.astype(jnp.float32)[:, :, None]  # (B, T, 1)
        inv_freq = self.inv_freq.astype(jnp.float32)[None, None, :]  # (1, 1, dim/2)
        freqs = positions * inv_freq  # (B, T, dim/2)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        return jnp.cos(emb).astype(dtype=x.dtype), jnp.sin(emb).astype(dtype=x.dtype)


def rotate_half(x, neg=False):
    """Swap the two halves of the last axis and negate one of them.

    With ``x = [a, b]`` along the last axis this returns ``[-b, a]`` (classic
    RoPE convention) or, when ``neg`` is True, ``[b, -a]`` (the transposed
    rotation).
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    parts = (second, -first) if neg else (-second, first)
    return jnp.concatenate(parts, axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key arrays by the rotary position embedding.

    Args:
        q: Query array, e.g. (batch, heads, seq_len, head_dim).
        k: Key array laid out like ``q`` (head count may differ).
        cos: Cosine table, e.g. (batch, seq_len, head_dim) as returned by
            ``Gemma2RotaryEmbedding``.
        sin: Sine table with the same shape as ``cos``.
        position_ids: Accepted for API compatibility; ignored.
        unsqueeze_dim: Axis inserted into ``cos``/``sin`` so they broadcast
            against ``q``/``k``: 1 for (batch, heads, seq, dim) inputs, 2 for
            (batch, seq, heads, dim) inputs.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the shapes of ``q`` and ``k``.
    """
    cos_b = jnp.expand_dims(cos, unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, unsqueeze_dim)

    def _rotate(t):
        # t * cos + rotate_half(t) * sin, element-wise.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


def apply_vapor(x, cos, sin, neg=False, unsqueeze_dim=1):
    """Apply a rotary position embedding to a time-major array.

    Args:
        x: Array of shape (T, B, H, D).
        cos: Cosine table that broadcasts against (B, H, T, D) after
            ``expand_dims(cos, unsqueeze_dim)``; with a (B, T, D) table use
            ``unsqueeze_dim=1``.
        sin: Sine table with the same shape as ``cos``.
        neg: Forwarded to ``rotate_half``; selects the sign convention of the
            rotation (``[b, -a]`` instead of ``[-b, a]``).
        unsqueeze_dim: Axis inserted into ``cos``/``sin`` for broadcasting.

    Returns:
        The rotated array, again of shape (T, B, H, D).
    """
    # (T, B, H, D) -> (B, H, T, D) so the tables broadcast over heads.
    x_bhtd = jnp.transpose(x, (1, 2, 0, 3))
    cos_b = jnp.expand_dims(cos, unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, unsqueeze_dim)
    rotated = (x_bhtd * cos_b) + (rotate_half(x_bhtd, neg=neg) * sin_b)
    # Back to the time-major layout: (B, H, T, D) -> (T, B, H, D).
    return jnp.transpose(rotated, (2, 0, 1, 3))


def repeat_kv(hidden_states: jax.Array, n_rep: int) -> jax.Array:
    """Expand grouped key/value heads to the full number of attention heads.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: each
    KV head is repeated ``n_rep`` times consecutively, taking
    (batch, num_key_value_heads, seq_len, head_dim) to
    (batch, num_key_value_heads * n_rep, seq_len, head_dim).
    The input is returned unchanged when ``n_rep == 1``.
    """
    # Unpacking also enforces a rank-4 input.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    interleaved = jnp.repeat(hidden_states, n_rep, axis=1)
    return interleaved.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


class SlidingWindowAtt:
    """Sliding-window (local) attention applied to precomputed Q, K, V.

    Not a Flax module: it owns no parameters and is called directly with the
    projected queries, keys and values. The sequence is split into windows of
    ``window_size``; each window attends to itself plus the previous window
    (and the next one in the bidirectional case). The cost is therefore
    O(T * 2W) instead of O(T^2); a dedicated kernel would be needed for
    O(T * W). Because the two-window lookup exposes up to 2W keys, pass
    ``exact_windowsize=True`` to mask keys further than ``window_size``
    positions from the query.

    NOTE(review): this definition shadows the ``SlidingWindowAtt`` imported
    from ``..layers`` at the top of the module.

    Args:
        window_size: Window length W. For ``causal=False`` it is halved, so
            each query sees roughly +-W/2 positions.
        exact_windowsize: Mask keys outside the exact window.
        causal: Causal (left-only) attention when True.
        attn_logit_softcapping: If set, logits become ``cap * tanh(logits / cap)``.
        query_pre_attn_scalar: Logits are multiplied by
            ``query_pre_attn_scalar ** -0.5``. Defaults to 256 regardless of
            the head dim; pass the head dim for standard 1/sqrt(d) scaling.
    """

    def __init__(
        self,
        window_size: int,
        exact_windowsize: bool = True,
        causal=True,
        attn_logit_softcapping=None,
        query_pre_attn_scalar=256,
    ):
        self.attn_logit_softcapping = attn_logit_softcapping
        self.exact_windowsize = exact_windowsize
        self.causal = causal
        self.query_pre_attn_scalar = query_pre_attn_scalar
        if causal:
            self.window_size = window_size
        else:  # bidirectional case
            self.window_size = window_size // 2

    @staticmethod
    def look_around(x, backward=1, forward=0, pad_value=-1, dim=2):
        """Concatenate each window with its neighbours.

        ``x`` has windows on axis 1. The window axis is padded with
        ``pad_value`` by ``backward`` / ``forward`` windows. Shifted copies
        are then concatenated along ``dim``, so each window is followed by
        the contents of its neighbouring windows.
        """
        n_windows = x.shape[1]
        pad_width = [(0, 0)] * x.ndim
        pad_width[1] = (backward, forward)
        padded_x = jnp.pad(x, pad_width=pad_width, constant_values=pad_value)
        shifted = [
            padded_x[:, offset : offset + n_windows, ...]
            for offset in range(backward + forward + 1)
        ]
        return jnp.concatenate(shifted, axis=dim)

    @staticmethod
    def pad_to_multiple(x, multiple, dim=-1, value=0):
        """Right-pad axis ``dim`` of ``x`` with ``value`` up to a multiple of ``multiple``.

        Returns:
            ``(was_padded, array)``.
        """
        remainder = (-x.shape[dim]) % multiple
        if remainder == 0:
            return False, x
        pad_width = [(0, 0)] * x.ndim
        pad_width[dim] = (0, remainder)
        return True, jnp.pad(x, pad_width=pad_width, constant_values=value)

    def __call__(self, Q, K, V, input_mask, attn_dropout=None):
        """Apply sliding-window attention.

        Args:
            Q, K, V: Arrays of shape (..., T, D) (in this file (B, H, T, D)).
                All leading axes are flattened together internally.
            input_mask: None, or a key mask of shape (B, T) where truthy means
                attend and falsy means ignore. The flattened leading size of
                Q must be a multiple of B; the mask is repeated over the
                remaining (head) axis.
            attn_dropout: Optional callable applied to the attention weights
                (e.g. a ``flax.linen.Dropout``); skipped when None.

        Returns:
            Array of shape (..., T, D_v) with the same leading shape as Q.

        Raises:
            ValueError: if the flattened leading size of Q is not a multiple
                of ``input_mask``'s batch size.
        """
        mask_value = -9e15
        pad_value = -1
        window = self.window_size
        forward = 0 if self.causal else 1

        lead_shape = Q.shape[:-2]
        orig_seq_len = Q.shape[-2]

        # Merge batch and heads (all leading axes) into one axis.
        Q, K, V = (t.reshape(-1, *t.shape[-2:]) for t in (Q, K, V))
        packed = Q.shape[0]

        # Pad the sequence to a multiple of the window; the padded query rows
        # are discarded before returning.
        _, Q = self.pad_to_multiple(Q, window, dim=-2, value=pad_value)
        _, K = self.pad_to_multiple(K, window, dim=-2, value=pad_value)
        _, V = self.pad_to_multiple(V, window, dim=-2, value=pad_value)
        T = Q.shape[-2]
        windows = T // window

        # (packed, T, d) -> (packed, windows, window, d)
        bq, bk, bv = (
            t.reshape(t.shape[0], windows, window, t.shape[-1]) for t in (Q, K, V)
        )
        # Attach the previous window (and the next one when bidirectional) so
        # the first token of each window still sees `window` keys of context.
        bk = self.look_around(bk, backward=1, forward=forward, pad_value=pad_value)
        bv = self.look_around(bv, backward=1, forward=forward, pad_value=pad_value)

        sim = jnp.einsum("bwie,bwje->bwij", bq, bk) * self.query_pre_attn_scalar**-0.5

        if self.attn_logit_softcapping is not None:
            cap = self.attn_logit_softcapping
            sim = jnp.tanh(sim / cap) * cap

        # Absolute positions of queries (i) and looked-around keys (j).
        positions = jnp.arange(T).reshape(1, windows, window)
        bq_t = positions[..., :, None]
        bq_k = self.look_around(
            positions, backward=1, forward=forward, pad_value=pad_value
        )[..., None, :]
        # Never attend to look-around padding, nor to keys that only exist
        # because the sequence was padded to a multiple of the window.
        pad_mask = (bq_k == pad_value) | (bq_k >= orig_seq_len)

        if self.causal:
            causal_mask = bq_t < bq_k
            if self.exact_windowsize:
                causal_mask = causal_mask | (bq_t > (bq_k + window))
            sim = jnp.where(causal_mask, mask_value, sim)
        elif self.exact_windowsize:
            window_mask = ((bq_k - window) > bq_t) | (bq_t > (bq_k + window))
            sim = jnp.where(window_mask, mask_value, sim)
        sim = jnp.where(pad_mask, mask_value, sim)

        if input_mask is not None:
            key_mask = jnp.asarray(input_mask, dtype=bool)
            _, key_mask = self.pad_to_multiple(key_mask, window, dim=-1, value=False)
            key_mask = key_mask.reshape(-1, windows, window)
            key_mask = self.look_around(
                key_mask, backward=1, forward=forward, pad_value=False
            )
            key_mask = key_mask[:, :, None, :]  # (B, windows, 1, j)
            mask_batch = key_mask.shape[0]
            if packed % mask_batch != 0:
                raise ValueError(
                    f"input_mask batch {mask_batch} does not divide the packed "
                    f"batch*heads size {packed}"
                )
            # Packing put batch outermost, so repeat each mask row per head.
            key_mask = jnp.repeat(key_mask, packed // mask_batch, axis=0)
            sim = jnp.where(key_mask, sim, mask_value)

        attn = jax.nn.softmax(sim, axis=-1)
        if attn_dropout is not None:
            attn = attn_dropout(attn)
        out = jnp.einsum("bwij,bwje->bwie", attn, bv)
        out = out.reshape(packed, T, -1)[:, :orig_seq_len, :]
        return out.reshape(*lead_shape, orig_seq_len, out.shape[-1])


class GemmaSlidingCausalAtt(nn.Module):
    """Gemma-style causal sliding-window attention with RoPE and grouped KV heads.

    Only used for testing sliding window attention; use the
    ``CausalRopeLatteMachiattoSliding`` layer for the Machiatto model.

    Attributes:
        config: Model config. Reads ``nheads``, ``num_key_value_heads``,
            ``head_dim``, ``hidden_dim``, ``att_block_len`` (window size),
            ``pos_embed_max_len``, ``attention_bias``, ``initializer_range``
            and ``dropout_att``.
        dtype: Computation dtype of the projections.
        attn_logit_softcapping: Soft cap applied to attention logits.
    """

    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32
    attn_logit_softcapping: float = 50.0

    @nn.compact
    def __call__(self, X: jnp.array, train: bool, **kwargs) -> jnp.array:
        """Apply the attention layer.

        Args:
            X: Input of shape (B, T, C).
            train: Enables attention dropout when True.
            **kwargs: Ignored.

        Returns:
            Array of shape (B, T, config.hidden_dim).
        """
        cfg = self.config
        nheads = cfg.nheads
        # Fall back to plain multi-head attention / hidden_dim // nheads.
        num_key_value_heads = (
            cfg.num_key_value_heads if cfg.num_key_value_heads else cfg.nheads
        )
        head_dim = cfg.head_dim if cfg.head_dim else cfg.hidden_dim // cfg.nheads
        num_key_value_groups = cfg.nheads // num_key_value_heads

        sliding_att = SlidingWindowAtt(
            window_size=cfg.att_block_len,
            exact_windowsize=True,
            causal=True,
            attn_logit_softcapping=self.attn_logit_softcapping,
        )
        # Use the resolved head_dim: config.head_dim may be unset (None/0).
        rotary_emb = Gemma2RotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=cfg.pos_embed_max_len,
            base=10000.0,
        )

        def dense(features, name):
            return nn.Dense(
                features,
                use_bias=cfg.attention_bias,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
                name=name,
            )

        q_proj = dense(nheads * head_dim, "q_proj")
        k_proj = dense(num_key_value_heads * head_dim, "k_proj")
        v_proj = dense(num_key_value_heads * head_dim, "v_proj")
        c_proj = dense(cfg.hidden_dim, "out_proj")

        attn_dropout = nn.Dropout(rate=cfg.dropout_att, deterministic=not train)
        B, T, _ = X.shape

        # Project and move heads in front of time: (B, T, n*hs) -> (B, n, T, hs).
        q = q_proj(X).reshape(B, T, nheads, -1).transpose(0, 2, 1, 3)
        k = k_proj(X).reshape(B, T, num_key_value_heads, -1).transpose(0, 2, 1, 3)
        v = v_proj(X).reshape(B, T, num_key_value_heads, -1).transpose(0, 2, 1, 3)

        cos, sin = rotary_emb(v, position_ids=None)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        attend = partial(sliding_att, attn_dropout=attn_dropout)
        # jax.checkpoint is a raw JAX transform; Flax modules must not draw
        # RNGs inside it (Flax raises JaxTransformError). Rematerialise only
        # when the dropout is a no-op.
        if not train or cfg.dropout_att == 0.0:
            attend = jax.checkpoint(attend)
        y = attend(q, k, v, input_mask=None)
        # Re-assemble all head outputs side by side: (B, n, T, hs) -> (B, T, n*hs).
        y = y.transpose(0, 2, 1, 3).reshape(B, T, -1)
        return c_proj(y)


def accumulate(carry, args):
    """One step of the sequential, max-stabilised Latte recurrence (for ``lax.scan``).

    Per latent ``l``, running sums are kept rescaled by the running maximum
    ``m_t`` of the key logits:

        alpha_t = alpha_{t-1} * exp(m_{t-1} - m_t) + exp(k_t - m_t)
        nu_t    = nu_{t-1}    * exp(m_{t-1} - m_t) + exp(k_t - m_t) * v_t

    and ``y_t = sum_l (q_t / alpha_t)[l] * nu_t[l]`` is emitted.

    Args:
        carry: ``(nu, alpha, prev_max)`` with shapes (B,H,L,D), (B,H,L), (B,H,L).
        args: ``(Qs_t, curr_alph, V_t, c_mx)`` with shapes (B,H,L), (B,H,L),
            (B,H,D), (B,H,L): latent query weights, key logits, values and
            the running maximum at step t.

    Returns:
        ``((nu, alpha, c_mx), y)`` with ``y`` of shape (B,H,D).
    """
    nu_prev, alpha_prev, max_prev = carry
    q_t, k_t, v_t, max_t = args

    rescale = jnp.exp(-max_t + max_prev)
    weight = jnp.exp(k_t - max_t)

    alpha_t = jnp.einsum("BHL,BHL->BHL", alpha_prev, rescale) + weight
    nu_t = jnp.einsum("BHLD,BHL->BHLD", nu_prev, rescale) + jnp.einsum(
        "BHL,BHD->BHLD", weight, v_t
    )
    out = jnp.einsum("BHL,BHLD->BHD", q_t / alpha_t, nu_t)
    return (nu_t, alpha_t, max_t), out


class CausalRopeLatteMachiattoSliding(nn.Module):
    """Hybrid causal sliding-window + Latte latent attention mixer.

    For every head the output is a weighted sum of two paths. The weights are
    a per-head softmax over ``L / nheads + 1`` logits, followed by dropout:

    * slot 0 weights causal sliding-window softmax attention over RoPE'd
      q/k/v projections of the input (window ``config.att_block_len``);
    * the remaining slots weight Latte latent attention. Its queries/keys
      come from an RG-LRU-processed, RMS-normalised width-L copy of the
      input, and its values are the same ``v`` projections.

    The sum is projected back to ``config.hidden_dim`` by ``out_proj``.
    Assumes ``config.L`` is divisible by ``config.nheads`` (needed by the
    per-head reshapes of the Latte q/k) -- TODO confirm at config level.

    Attributes:
        config: Model config (reads nheads, num_key_value_heads, head_dim,
            hidden_dim, L, att_block_len, pos_embed_max_len, attention_bias,
            initializer_range, dropout, dropout_att).
        attn_logit_softcapping: Soft cap for the sliding-window logits.
        unroll: ``lax.scan`` unroll factor used by ``latte_attention``.
        dtype: Computation dtype of the projections.
    """

    config: LMTaskConfig
    attn_logit_softcapping: float = 50.0
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X, train=False, **kwargs):
        """Mix the input sequence.

        Args:
            X: Input of shape (B, T, C).
            train: Enables the dropout layers when True.
            **kwargs: Ignored.

        Returns:
            Array of shape (B, T, config.hidden_dim).
        """

        nheads = self.config.nheads
        if self.config.num_key_value_heads:
            num_key_value_heads = self.config.num_key_value_heads
        else:
            num_key_value_heads = self.config.nheads
        if self.config.head_dim:
            head_dim = self.config.head_dim
        else:
            head_dim = self.config.hidden_dim // self.config.nheads

        num_key_value_groups = self.config.nheads // num_key_value_heads

        rotary_emb = Gemma2RotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=self.config.pos_embed_max_len,
            base=10000.0,
        )

        sliding_att = SlidingWindowAtt(
            window_size=self.config.att_block_len,
            exact_windowsize=True,
            causal=True,
            attn_logit_softcapping=self.attn_logit_softcapping,
        )

        # key, query, value projections for all heads, but in a batch
        q_proj = nn.Dense(
            nheads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="q_proj",
        )

        k_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="k_proj",
        )
        v_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="v_proj",
        )
        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="out_proj",
        )

        # latte attention
        # Latte keys: L logits per position (split across heads below).
        Wk = nn.Dense(
            self.config.L,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_Wk",
        )
        # Latte queries: L latent logits plus one extra logit per head for the
        # sliding-window slot (index 0 of each head after the reshape).
        Wq = nn.Dense(
            self.config.L + self.config.nheads,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_Wq",
        )

        # latte_conv = DepthConv1D(
        #     nchannels=self.config.hidden_dim,  # self.config.L,  #
        #     out_channels=self.config.hidden_dim,  # self.config.L, #
        #     kernel_size=4,  # self.config.max_seq_len,  # 3,
        #     dtype=self.dtype,
        #     name="latte_conv",
        # )
        # latte_conv = S5Layer(
        #     ssm_size=self.config.L,
        #     hidden_dim=self.config.hidden_dim,
        #     blocks=self.config.nheads,
        #     dtype=self.dtype,
        #     name="S5",
        # )

        # Projects the input to width L before the RG-LRU.
        lru_in = nn.Dense(
            self.config.L,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_lru_in",
        )

        latte_conv = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
            name="latte_rglru",
        )

        # regularization
        # NOTE(review): attention-weight dropout here uses config.dropout,
        # while Q_drop (and GemmaSlidingCausalAtt's attention dropout) use
        # config.dropout_att -- confirm which rate is intended.
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = X.shape

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k = q_proj(X), k_proj(X)
        v = v_proj(X)
        k = k.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        q = q.reshape(B, T, nheads, -1).transpose(0, 2, 1, 3)  # (B, nh, T, hs)
        v = v.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        cos, sin = rotary_emb(v, position_ids=None)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Grouped KV heads are repeated up to nheads.
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        # jax.vmap(fun, in_axes=0, out_axes=0,
        # sliding_att = jax.checkpoint(sliding_att)
        # Sliding-window branch, (B, nheads, T, hs).
        p_s_l0 = sliding_att(q, k, v, input_mask=None, attn_dropout=attn_dropout)

        # X = latte_conv(X)
        # Latte branch input: RG-LRU over a width-L projection of X, then
        # RMSNorm. X is rebound here and is no longer the layer input.
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        X, _ = latte_conv(lru_in(X), pos_ids, return_cache=False)
        X = FlaxGemmaRMSNorm(
            config=self.config,
            dtype=self.dtype,
            width=self.config.L,
            name="latte_lru_norm",
        )(X)
        # multi head implementation
        q = Wq(X)  # (B, T, L + nheads)
        k = Wk(X)  # (B, T, L)
        # Time-major, split per head: (T, B, nheads, L/nheads [+ 1 for q]).
        q = jnp.einsum("BTL->TBL", q).reshape(T, B, self.config.nheads, -1)
        k = jnp.einsum("BTL->TBL", k).reshape(T, B, self.config.nheads, -1)

        # q = jnp.einsum("DL,BTD->TBL", Wq2, X).reshape(T, B, self.config.nheads, -1)
        # k = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, self.config.nheads, -1)

        # # V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, self.config.nheads, -1)
        # Per-head mixture weights over {sliding window} U {latent states}.
        q = jax.nn.softmax(q, axis=-1)
        q = Q_drop(q)

        # Effectively a no-op: sliding_att already returns (B, H, T, D).
        p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, -1)  # BHNLD -> BHTD
        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        causal_att = jnp.einsum("TBH,BHTD->BHTD", q[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        # v = V
        v = v.transpose(2, 0, 1, 3)  # .reshape(T, B, self.config.nheads, -1)

        # print("Second Shape of input Q,K,V: ", q.shape, k.shape, v.shape)
        latte_att = self.latte_attention4((cos, sin), q[:, :, :, 1:], K=k, V=v)  # BHTD
        v = causal_att + latte_att
        v = v.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        v = v.reshape(B, T, -1)
        return c_proj(v)

    def latte_attention(self, rot_embeds, Qs, K, V):
        """Sequential Latte attention via ``lax.scan`` over time.

        Computes the same quantity as ``latte_attention4`` one step at a time
        with the checkpointed ``accumulate`` step. Not called by ``__call__``.

        Args:
            rot_embeds: Tuple[jax.Array, jax.Array]; unpacked but unused.
            Qs: jax.Array(T,B,H,L) latent query weights.
            K: jax.Array(T,B,H,L) latent key logits.
            V: jax.Array(T,B,H,D) values.

        Returns:
            jax.Array(B,H,T,D).
        """
        T, B, H, C = V.shape
        L = Qs.shape[-1]
        cos, sin = rot_embeds

        # Running maximum of the key logits over time, for stable exponentials.
        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
        _, y = jax.lax.scan(
            jax.checkpoint(accumulate),
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)

    def latte_attention4(self, rot_embeds, Q, K, V):
        """Parallel Latte attention via an associative scan, O(TLD) work.

        Evaluates the max-stabilised linear recurrence

            alpha_t = alpha_{t-1} * exp(m_{t-1} - m_t) + exp(K_t - m_t)
            nu_t    = nu_{t-1}    * exp(m_{t-1} - m_t) + exp(K_t - m_t) V_t

        (m_t = running max of K) with ``parallel_scan`` and returns
        ``y_t = sum_l (Q_t / alpha_t)[l] * nu_t[l]``.

        Args:
            rot_embeds: Tuple[jax.Array, jax.Array]; unpacked but unused.
            Q: jax.Array(T,B,H,L) latent query weights.
            K: jax.Array(T,B,H,L) latent key logits.
            V: jax.Array(T,B,H,D) values.

        Returns:
            jax.Array(B,H,T,D).
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]
        cos, sin = rot_embeds

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)
        # revert maxi: exp(m_{t-1} - m_t), and 1 at t = 0
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        # print("Test: ", revert_maxi.shape, add_maxi.shape, nu.shape)

        # Composition of two affine steps x -> x * rm + am: applying A then B
        # multiplies the decays and carries A's accumulator through B's decay.
        def bin_V(A, B):
            rmA, amA, nuA = A
            rmB, amB, nuB = B
            nu = nuA * rmB[..., None] + nuB
            alpha = amA * rmB + amB
            return (rmA * rmB, alpha, nu)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Q / alpha, y)
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)


class CausalRopeLatteMachiattoSlidingAblation(nn.Module):
    """Ablation of ``CausalRopeLatteMachiattoSliding`` without the Latte path.

    The input goes through ``latte_lru_in`` -> RG-LRU -> RMSNorm ->
    ``latte_lru_out``. The result ``Y`` provides the sliding-window queries
    and keys, while the values are projected from the raw input ``X``. Only
    causal sliding-window attention is used, followed by ``out_proj``.

    NOTE(review): ``Q_drop`` is constructed but unused, and
    ``latte_attention`` / ``latte_attention4`` are not called by
    ``__call__`` (the Latte branch is commented out).

    Attributes:
        config: Model config.
        attn_logit_softcapping: Soft cap for the sliding-window logits.
        unroll: ``lax.scan`` unroll factor used by ``latte_attention``.
        dtype: Computation dtype of the projections.
    """

    config: LMTaskConfig
    attn_logit_softcapping: float = 50.0
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X, train=False, **kwargs):
        """Mix the input sequence.

        Args:
            X: Input of shape (B, T, C).
            train: Enables the dropout layers when True.
            **kwargs: Ignored.

        Returns:
            Array of shape (B, T, config.hidden_dim).
        """

        nheads = self.config.nheads
        if self.config.num_key_value_heads:
            num_key_value_heads = self.config.num_key_value_heads
        else:
            num_key_value_heads = self.config.nheads
        if self.config.head_dim:
            head_dim = self.config.head_dim
        else:
            head_dim = self.config.hidden_dim // self.config.nheads

        num_key_value_groups = self.config.nheads // num_key_value_heads

        rotary_emb = Gemma2RotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=self.config.pos_embed_max_len,
            base=10000.0,
        )

        sliding_att = SlidingWindowAtt(
            window_size=self.config.att_block_len,
            exact_windowsize=True,
            causal=True,
            attn_logit_softcapping=self.attn_logit_softcapping,
        )

        # key, query, value projections for all heads, but in a batch
        q_proj = nn.Dense(
            nheads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="q_proj",
        )

        k_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="k_proj",
        )
        v_proj = nn.Dense(
            num_key_value_heads * head_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="v_proj",
        )
        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=self.config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            name="out_proj",
        )

        # latte attention
        # Wk = nn.Dense(
        #     self.config.L,
        #     use_bias=False,
        #     kernel_init=jax.nn.initializers.normal(
        #         stddev=self.config.initializer_range
        #     ),
        #     dtype=self.dtype,
        #     name="latte_Wk",
        # )
        # Wq = nn.Dense(
        #     self.config.L + self.config.nheads,
        #     use_bias=False,
        #     kernel_init=jax.nn.initializers.normal(
        #         stddev=self.config.initializer_range
        #     ),
        #     dtype=self.dtype,
        #     name="latte_Wq",
        # )

        # latte_conv = DepthConv1D(
        #     nchannels=self.config.hidden_dim,  # self.config.L,  #
        #     out_channels=self.config.hidden_dim,  # self.config.L, #
        #     kernel_size=4,  # self.config.max_seq_len,  # 3,
        #     dtype=self.dtype,
        #     name="latte_conv",
        # )
        # latte_conv = S5Layer(
        #     ssm_size=self.config.L,
        #     hidden_dim=self.config.hidden_dim,
        #     blocks=self.config.nheads,
        #     dtype=self.dtype,
        #     name="S5",
        # )

        # Projects the input to width L before the RG-LRU.
        lru_in = nn.Dense(
            self.config.L,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_lru_in",
        )

        # Projects the normalised RG-LRU output back to hidden_dim.
        lru_out = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
            dtype=self.dtype,
            name="latte_lru_out",
        )

        latte_conv = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
            name="latte_rglru",
        )

        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        # NOTE(review): unused in this ablation.
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = X.shape

        # X = latte_conv(X)
        # Recurrent features Y (B, T, hidden_dim) used for queries and keys.
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        Y, _ = latte_conv(lru_in(X), pos_ids, return_cache=False)
        Y = FlaxGemmaRMSNorm(
            config=self.config,
            dtype=self.dtype,
            width=self.config.L,
            name="latte_lru_norm",
        )(Y)
        Y = lru_out(Y)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # Queries/keys come from Y, values from the raw input X.
        q, k = q_proj(Y), k_proj(Y)
        v = v_proj(X)
        k = k.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        q = q.reshape(B, T, nheads, -1).transpose(0, 2, 1, 3)  # (B, nh, T, hs)
        v = v.reshape(B, T, num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        cos, sin = rotary_emb(v, position_ids=None)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        k = repeat_kv(k, num_key_value_groups)
        v = repeat_kv(v, num_key_value_groups)

        # jax.vmap(fun, in_axes=0, out_axes=0,
        # sliding_att = jax.checkpoint(sliding_att)
        p_s_l0 = sliding_att(q, k, v, input_mask=None, attn_dropout=attn_dropout)

        # # multi head implementation
        # q = Wq(X)
        # k = Wk(X)
        # q = jnp.einsum("BTL->TBL", q).reshape(T, B, self.config.nheads, -1)
        # k = jnp.einsum("BTL->TBL", k).reshape(T, B, self.config.nheads, -1)

        # # q = jnp.einsum("DL,BTD->TBL", Wq2, X).reshape(T, B, self.config.nheads, -1)
        # # k = jnp.einsum("DL,BTD->TBL", Wk, X).reshape(T, B, self.config.nheads, -1)

        # # # V = jnp.einsum("DM,BTD->TBM", Wv, X).reshape(T, B, self.config.nheads, -1)
        # q = jax.nn.softmax(q, axis=-1)
        # q = Q_drop(q)

        # p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, -1)  # BHNLD -> BHTD
        # # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        # causal_att = jnp.einsum("TBH,BHTD->BHTD", q[:, :, :, 0], p_s_l0)

        # # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s
        # # v = V
        # v = v.transpose(2, 0, 1, 3)  # .reshape(T, B, self.config.nheads, -1)

        # # print("Second Shape of input Q,K,V: ", q.shape, k.shape, v.shape)
        # latte_att = self.latte_attention4((cos, sin), q[:, :, :, 1:], K=k, V=v)  # BHTD
        # v = causal_att + latte_att
        # v = v.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        # v = v.reshape(B, T, -1)
        # return c_proj(v)
        # Effectively a no-op: sliding_att already returns (B, H, T, D).
        p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, -1)  # BHNLD -> BHTD
        causal_att = p_s_l0
        v = causal_att
        v = v.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        v = v.reshape(B, T, -1)
        return c_proj(v)

    def latte_attention(self, rot_embeds, Qs, K, V):
        """Sequential Latte attention via ``lax.scan`` over time.

        Not called by ``__call__`` in this ablation.

        Args:
            rot_embeds: Tuple[jax.Array, jax.Array]; unpacked but unused.
            Qs: jax.Array(T,B,H,L) latent query weights.
            K: jax.Array(T,B,H,L) latent key logits.
            V: jax.Array(T,B,H,D) values.

        Returns:
            jax.Array(B,H,T,D).
        """
        T, B, H, C = V.shape
        L = Qs.shape[-1]
        cos, sin = rot_embeds

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)

        init_alpha = jnp.zeros(shape=(B, H, L), dtype=self.dtype)
        init_nu = jnp.zeros((B, H, L, C), dtype=self.dtype)
        _, y = jax.lax.scan(
            jax.checkpoint(accumulate),
            unroll=self.unroll,
            init=(
                init_nu,
                init_alpha,
                K[0],
            ),
            xs=[Qs, K, V, maxi],
        )
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)

    def latte_attention4(self, rot_embeds, Q, K, V):
        """Parallel Latte attention via an associative scan, O(TLD) work.

        Not called by ``__call__`` in this ablation.

        Args:
            rot_embeds: Tuple[jax.Array, jax.Array]; unpacked but unused.
            Q: jax.Array(T,B,H,L) latent query weights.
            K: jax.Array(T,B,H,L) latent key logits.
            V: jax.Array(T,B,H,D) values.

        Returns:
            jax.Array(B,H,T,D).
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]
        cos, sin = rot_embeds

        maxi = jax.lax.cummax(K, axis=0)
        # maxi for stability should be treated as a constant - no grad is faster
        maxi = jax.lax.stop_gradient(maxi)
        # revert maxi: exp(m_{t-1} - m_t), and 1 at t = 0
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        # print("Test: ", revert_maxi.shape, add_maxi.shape, nu.shape)

        # Composition of two affine steps x -> x * rm + am.
        def bin_V(A, B):
            rmA, amA, nuA = A
            rmB, amB, nuB = B
            nu = nuA * rmB[..., None] + nuB
            alpha = amA * rmB + amB
            return (rmA * rmB, alpha, nu)

        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Q / alpha, y)
        # TBHD ->   BHTD
        return y.transpose(1, 2, 0, 3)


class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(up_proj(x) * act(gate_proj(x)))``.

    All three projections are bias-free ``nn.Dense`` layers initialised from a
    normal distribution with std ``config.initializer_range``.

    Attributes:
        config: model config; reads ``hidden_dim``, ``intermediate_dim`` and
            ``initializer_range``.
        dtype: computation dtype passed to the Dense layers.
    """

    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.hidden_dim
        inner_dim = self.config.intermediate_dim
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        # Tanh-approximated GELU. NOTE(review): hard-coded rather than read from
        # a config activation name (e.g. ACT2FN[hidden_act]); confirm intent.
        self.activation_fn = partial(jax.nn.gelu, approximate=True)
        # The three projections share every setting except their output width.
        dense = partial(
            nn.Dense, use_bias=False, dtype=self.dtype, kernel_init=kernel_init
        )
        self.gate_proj = dense(inner_dim)
        self.down_proj = dense(embed_dim)
        self.up_proj = dense(inner_dim)

    def __call__(self, hidden_states):
        """Apply the gated MLP.

        Args:
            hidden_states: input whose last axis is ``config.hidden_dim``.

        Returns:
            Array with the same leading axes and last axis ``config.hidden_dim``.
        """
        up_proj_states = self.up_proj(hidden_states)
        # Use the activation configured in setup instead of re-hardcoding it.
        gate_states = self.activation_fn(self.gate_proj(hidden_states))
        return self.down_proj(up_proj_states * gate_states)


class GemmaDecoderLayer(nn.Module):
    """One decoder layer with "sandwich" normalisation.

    Each of the two sub-blocks (token mixer, then MLP) is wrapped as
    ``x + post_norm(sublayer(pre_norm(x)))``. The token mixer is
    ``CausalRopeLatteMachiattoSliding``.

    Attributes:
        config: model config forwarded to every submodule.
        dtype: computation dtype forwarded to every submodule.
    """

    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg, dt = self.config, self.dtype
        # Alternatives previously tried here: GemmaCausalAtt,
        # GemmaSlidingCausalAtt, get_decoder_mixer(config, dtype).
        self.self_attn = CausalRopeLatteMachiattoSliding(config=cfg, dtype=dt)
        self.input_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="input_layernorm"
        )
        self.post_attention_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="post_attention_layernorm"
        )
        self.pre_feedforward_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="pre_feedforward_layernorm"
        )
        self.post_feedforward_layernorm = FlaxGemmaRMSNorm(
            config=cfg, dtype=dt, name="post_feedforward_layernorm"
        )
        self.mlp = GemmaMLP(cfg, dtype=dt, name="mlp")

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        train: bool = False,
    ):
        """Run the token-mixing sub-block, then the MLP sub-block.

        Args:
            hidden_states: layer input.
            attention_mask: forwarded unchanged to the token mixer.
            position_ids: forwarded unchanged to the token mixer.
            train: forwarded unchanged to the token mixer.

        Returns:
            Layer output with the same shape as ``hidden_states``.
        """
        mixed = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            train=train,
        )
        after_attn = hidden_states + self.post_attention_layernorm(mixed)

        ff = self.mlp(self.pre_feedforward_layernorm(after_attn))
        return after_attn + self.post_feedforward_layernorm(ff)


class GemmaDecoder(nn.Module):
    """Stack of ``config.nlayers`` ``GemmaDecoderLayer``s followed by a final RMSNorm.

    The layers are applied with ``nn.scan`` over a single ``residual_block``
    module. Its parameters are stacked along a new leading layer axis, which
    compiles faster than an unrolled Python loop. The trade-off is that all
    layers share one traced body, so some per-layer XLA optimisations are lost.

    Attributes:
        config: model config; reads ``nlayers`` and forwards it to the layers.
        sharded: when False, the stacked layer axis is given no partition name
            (``metadata_params={"partition_name": None}``). When True, the
            ``nn.scan`` default metadata is used.
        dtype: computation dtype forwarded to submodules.
    """

    config: LMTaskConfig
    sharded: bool
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, X, attention_mask, train=False):
        """Apply all decoder layers and the final norm.

        Args:
            X: hidden states fed to the first layer.
            attention_mask: passed positionally to every layer.
            train: currently unused. NOTE(review): it is not forwarded to the
                layers (nor is ``position_ids``), so the layers always run
                with their own default ``train=False``. Confirm whether this is
                intentional before relying on dropout inside the mixer.

        Returns:
            Normalised hidden states with the same shape as ``X``.
        """
        final_norm = FlaxGemmaRMSNorm(config=self.config, dtype=self.dtype, name="norm")
        block = GemmaDecoderLayer(
            config=self.config, dtype=self.dtype, name="residual_block"
        )

        # Both configurations share everything except the layer-axis metadata.
        scan_kwargs = {
            "variable_axes": {"params": 0, "intermediates": 0},
            "split_rngs": {"params": True, "dropout": True},
            "length": self.config.nlayers,
        }
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_kwargs["metadata_params"] = {"partition_name": None}

        def layer_step(module, carry, _):
            # One scan step = one decoder layer; nothing is emitted per step.
            return module(carry, attention_mask), None

        X, _ = nn.scan(layer_step, **scan_kwargs)(block, X, ())
        return final_norm(X)
