"""
Notation:
    B = batch size, T = sequence length, D = embed/hidden dimension, H = number of heads
"""

import warnings
from typing import Any, Dict, Tuple
import jax
from flax import linen as nn
from jax import numpy as jnp
import math
from flax.linen.dtypes import promote_dtype
from latte_trans.config import Config
from latte_trans.models.modules.layers import S5Layer, DepthConv1D, RopeEmbeds, XPos
from recurrentgemma.jax.layers import RGLRU

# Matmul precision handed to the einsum calls in this module that pass
# `precision=PRECISION`. Alternative setting: jax.lax.Precision.HIGHEST
PRECISION = jax.lax.Precision.DEFAULT
# Short alias for JAX's associative (parallel prefix) scan.
parallel_scan = jax.lax.associative_scan


def attention_product(q, k, v, mask=None) -> Tuple[jnp.array, jnp.array]:
    """
    Scaled dot-product attention.

    Args:
        q, k, v: jnp.array(BHTD) - Query, Key, Value
        mask: optional array broadcastable to the (..., T, T) attention logits
            (causal mask, or padding mask for the bidirectional case).
            Positions where ``mask == 0`` are excluded from the softmax.
    Returns:
        values: attention-weighted combination of ``v``
        attention: softmax weights over the key axis
    """
    head_dim = q.shape[-1]
    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        # Fill with the most negative *finite* value of the logits dtype.
        # A fixed constant such as -9e15 overflows to -inf in float16, and a
        # row made only of -inf turns into NaN inside the softmax.
        fill_value = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask == 0, fill_value, attn_logits)
    attention = jax.nn.softmax(attn_logits, axis=-1)
    values = jnp.matmul(attention, v)
    return values, attention


class CausalScanLatte(nn.Module):
    """
    Causal latent attention, kept numerically stable by rescaling the running
    sums with a running maximum of the key logits.

    Fields:
        config: Config - reads hidden_dim, L, nheads, nlayers,
            initializer_range and dropout_att
        unroll: int - unroll factor handed to jax.lax.scan in `mix_sequence`
        dtype: jnp.dtype - dtype the weights are promoted to in `__call__`
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def accumulate(carry, args):
        """
        One step of the sequential scan.
        Args:
            carry: (nu: BHLD, alpha: BHL, previous running max: BHL)
            args: (softmaxed query: BHL, key logits: BHL, value: BHD,
                current running max: BHL)
        Returns:
            (updated carry, y_t: BHD)
        """
        state_nu, state_alpha, old_max = carry
        q_t, k_t, v_t, new_max = args
        # factors that move the running sums from old_max to new_max
        rescale = jnp.exp(-new_max + old_max)
        fresh = jnp.exp(k_t - new_max)

        state_alpha = (
            jnp.einsum("BHL,BHL->BHL", state_alpha, rescale, precision=PRECISION)
            + fresh
        )
        state_nu = jnp.einsum(
            "BHLD,BHL->BHLD", state_nu, rescale, precision=PRECISION
        ) + jnp.einsum("BHL,BHD->BHLD", fresh, v_t, precision=PRECISION)

        out_t = jnp.einsum(
            "BHL,BHLD->BHD", q_t / state_alpha, state_nu, precision=PRECISION
        )
        return (state_nu, state_alpha, new_max), out_t

    @staticmethod
    def accumulate3(carry, args):
        """Leaner scan step: the rescale factors are expected to be computed
        outside the sequential loop and arrive through `args`."""
        q_t, v_t, rescale, fresh = args
        state_nu = jnp.einsum("BHLD,BHL->BHLD", carry, rescale) + jnp.einsum(
            "BHL,BHD->BHLD", fresh, v_t
        )
        return state_nu, jnp.einsum("BHL,BHLD->BHD", q_t, state_nu)

    def inference(self, X, cache, Wq, Wk, Wv, o_proj, Q_drop):
        """
        Recurrent counterpart of the parallel forward pass; consumes a single
        token per call.
        Args:
            X: jnp.array(BD)
            cache: Dict[str, Any] - recurrent state of the previous call
                (keys "alpha", "nu", "prev_max"), or None to start fresh
            Wq, Wk, Wv, o_proj: projection matrices
            Q_drop: dropout applied to the softmaxed queries
        Returns:
            (cache, y): the updated state (the given dict is updated in place)
                and the output jnp.array(BD)
        """
        H = self.config.nheads
        B = X.shape[0]
        D = X.shape[-1] // H
        L = self.config.L // H
        if cache is None:
            # fresh recurrent state
            cache = {
                "alpha": jnp.zeros(shape=(B, H, L), dtype=jnp.float32),
                "nu": jnp.zeros((B, H, L, D), dtype=jnp.float32),
                "prev_max": None,
            }

        q_logits = jnp.einsum("DL,BD->BL", Wq, X, precision=PRECISION)
        q_t = Q_drop(jax.nn.softmax(q_logits.reshape(B, H, -1), axis=-1))
        # the value feeds nu
        v_t = jnp.einsum("DM,BD->BM", Wv, X, precision=PRECISION)
        v_t = v_t.reshape(B, H, -1)
        # the key feeds alpha; the token gets a singleton axis for the einsum
        k_t = jnp.einsum("DL,BTD->TBL", Wk, X[None, ...], precision=PRECISION)
        k_t = k_t.reshape(B, H, -1)

        old_alpha = cache["alpha"]
        old_nu = cache["nu"]
        old_max = cache["prev_max"]
        # neither 0 nor -inf works as a start value; on the first step the
        # current key itself serves as the previous maximum
        if old_max is None:
            old_max = k_t
        new_max = jnp.maximum(old_max, k_t)
        rescale = jnp.exp(-new_max + old_max)
        fresh = jnp.exp(k_t - new_max)

        new_alpha = (
            jnp.einsum("BHL,BHL->BHL", old_alpha, rescale, precision=PRECISION)
            + fresh
        )
        new_nu = jnp.einsum(
            "BHLD,BHL->BHLD", old_nu, rescale, precision=PRECISION
        ) + jnp.einsum("BHL,BHD->BHLD", fresh, v_t, precision=PRECISION)

        y = jnp.einsum("BHL,BHLD->BHD", q_t / new_alpha, new_nu, precision=PRECISION)
        y = y.reshape(B, -1) @ o_proj

        # store the state for the next token
        cache["alpha"] = new_alpha
        cache["nu"] = new_nu
        cache["prev_max"] = new_max
        return cache, y

    def mix_sequence(self, Q, K, V, Q_drop):
        """
        Sequential (jax.lax.scan) mixing over the time axis.
        Args:
            Q, K: latent logits with time on axis 0 (T, B, H, L)
            V: values (T, B, H, D)
            Q_drop: dropout applied to the softmaxed queries
        Returns:
            y reshaped to (B, T, -1)
        """
        T, B, H, C = V.shape
        L = Q.shape[-1]

        # the running max only stabilises the exponentials, so it is treated
        # as a constant - skipping its gradient is faster
        running_max = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))

        start = (
            jnp.zeros((B, H, L, C), dtype=self.dtype),
            jnp.zeros(shape=(B, H, L), dtype=self.dtype),
            K[0],
        )
        q_probs = Q_drop(jax.nn.softmax(Q, axis=-1))

        _, y = jax.lax.scan(
            self.accumulate,
            init=start,
            xs=[q_probs, K, V, running_max],
            unroll=self.unroll,
        )
        # TBHD -> BTHD, then merge the heads
        return y.transpose(1, 0, 2, 3).reshape(B, T, -1)

    def mix_sequence4(self, Q, K, V, Q_drop):
        """Fastest variant, but with O(TLD)
        Args:
            Q: jax.Array(T,B,H,L)
            K: jax.Array(T,B,H,L)
            V: jax.Array(T,B,H,D)
            Q_drop: dropout applied to the softmaxed queries
        Returns:
            y reshaped to (B, T, -1)
        """
        T, B, H, C = V.shape
        q_probs = Q_drop(jax.nn.softmax(Q, axis=-1))

        # the running max only stabilises the exponentials, so it is treated
        # as a constant - skipping its gradient is faster
        running_max = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))
        # per-step rescale exp(max_{t-1} - max_t); the first step keeps exp(0)
        log_rescale = jnp.zeros_like(running_max)
        log_rescale = log_rescale.at[1:].set(-running_max[1:] + running_max[:-1])
        rescale = jnp.exp(log_rescale)  # TBHL
        fresh = jnp.exp(K - running_max)
        weighted_v = jnp.einsum("TBHL,TBHD->TBHLD", fresh, V)

        def combine(left, right):
            # binary operator of the scan over (rescale, alpha, nu) triples
            r_left, a_left, n_left = left
            r_right, a_right, n_right = right
            merged_nu = n_left * r_right[..., None] + n_right
            merged_alpha = a_left * r_right + a_right
            return (r_left * r_right, merged_alpha, merged_nu)

        _, alpha, nu = parallel_scan(combine, (rescale, fresh, weighted_v))
        y = jnp.einsum("TBHL,TBHLD->BTHD", q_probs / alpha, nu)
        return y.reshape(B, T, -1)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        cache: Dict[str, Any] = None,
        do_inference: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD); a single token jnp.array(BD) when do_inference
            train: bool - enables dropout when True
            cache: recurrent state, only forwarded to `inference`
            do_inference: bool - use the recurrent one-token path
        Returns:
            y: jnp.array(BTD) - transformed output sequence; the recurrent
                path returns the tuple (cache, y) produced by `inference`
        """
        cfg = self.config
        hidden = cfg.hidden_dim
        weight_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)
        Wk = self.param("Wk", weight_init, (hidden, cfg.L))
        Wq = self.param("Wq", weight_init, (hidden, cfg.L))
        Wv = self.param("Wv", weight_init, (hidden, hidden))
        # the output projection gets a stddev scaled down by sqrt(2 * nlayers)
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.normal(
                stddev=cfg.initializer_range / math.sqrt(2 * cfg.nlayers)
            ),
            (hidden, hidden),
        )
        Wk, Wq, Wv, o_proj = promote_dtype(Wk, Wq, Wv, o_proj, dtype=self.dtype)
        Q_drop = nn.Dropout(cfg.dropout_att, deterministic=not train)

        if do_inference:
            return self.inference(
                X, cache=cache, Wq=Wq, Wk=Wk, Wv=Wv, o_proj=o_proj, Q_drop=Q_drop
            )

        B, T, _ = X.shape
        H = cfg.nheads

        def per_head(weight, spec):
            # project X, put time first and split the last axis into heads
            return jnp.einsum(spec, weight, X).reshape(T, B, H, -1)

        V = per_head(Wv, "DM,BTD->TBM")
        Q = per_head(Wq, "DL,BTD->TBL")
        K = per_head(Wk, "DL,BTD->TBL")
        # `mix_sequence` is the sequential alternative to the parallel scan
        y = self.mix_sequence4(Q=Q, K=K, V=V, Q_drop=Q_drop)
        return y @ o_proj


class BidLatte(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads. It is applied to
                the key logits while they are laid out as (L, B, T), so it must
                broadcast against that shape (e.g. a (B, T) array); truthy
                entries are kept. When None, no position is masked.
                Only used in bidirectional since we sum up to T, and pad needed for batching
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        Wk = self.param(
            "Wk",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        o_proj = self.param(
            "o_proj",
            jax.nn.initializers.lecun_normal(),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        resid_drop = nn.Dropout(self.config.dropout, deterministic=not train)

        B, T, _ = X.shape
        H = self.config.nheads
        # multi head implementation
        V = jnp.einsum("DM,BTD->TBM", Wv, X, precision=PRECISION).reshape(T, B, H, -1)
        Q = jnp.einsum("DL,BTD->TBL", Wq, X, precision=PRECISION).reshape(T, B, H, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, X, precision=PRECISION)

        # The mask is optional (it defaults to None); only push masked keys to
        # a very negative logit when a mask was actually supplied.
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        K = K.transpose(2, 1, 0).reshape(T, B, H, -1)
        Qs = jax.nn.softmax(Q, axis=-1)  # T B H L
        Qs = Q_drop(Qs)
        # subtract the max over time before exponentiating, for stability
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V, precision=PRECISION)
        # normalize by the sum over time
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv, precision=PRECISION)
        y = jnp.einsum("TBHL,BHLD->BTHD", Qs, Kv, precision=PRECISION)
        y = y.reshape(B, T, -1)
        return resid_drop(y @ o_proj)


class BidLatteRopeConv(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.

    Queries and keys are computed from a depthwise-convolved copy of the
    input, and rotary-style position embeddings ("rope" or "xpos", chosen by
    `config.embed_type`) are applied to the values and to the mixed output.
    """

    config: Config
    dtype: jnp.dtype

    def latte_attention(self, rot_embeds, Qs, K, V, attention_mask):
        """
        Args:
            rot_embeds: XPos or RopeEmbeds module
            Qs: TBHL - softmaxed queries
            K: LBT - key logits
            V: DBT - values
            attention_mask: optional mask broadcastable against K (LBT);
                truthy entries are kept
        Returns:
            y: BHTD
        """
        _, B, T = K.shape
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        K = K.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)
        V = V.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)

        # calc R^{-s}x_s
        if isinstance(rot_embeds, XPos):
            # XPos is fed BHTD: TBHD -> BHTD, embed, then back to TBHD
            V = rot_embeds(V.transpose(1, 2, 0, 3), offset=0, downscale=True).transpose(
                2, 0, 1, 3
            )
        else:
            V = rot_embeds.apply_vapor(mat=V, neg=True)

        # subtract the max over time before exponentiating, for stability
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V)
        # normalize by the sum over time
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv)
        y = jnp.einsum("TBHL,BHLD->TBHD", Qs, Kv)

        # calc R^t \sum_l ...
        if isinstance(rot_embeds, XPos):
            # TBHD -> BHTD before embedding; the result is returned as is
            y = y.transpose(1, 2, 0, 3)
            return rot_embeds(y, offset=0, downscale=False)
        y = rot_embeds.apply_vapor(mat=y, neg=False)
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads. It is applied to
                the key logits laid out as (L, B, T), so it must broadcast
                against that shape (e.g. a (B, T) array).
                Only used in bidirectional since we sum up to T, and pad needed for batching
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        Raises:
            ValueError: if `config.embed_type` is neither "rope" nor "xpos"
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
            )
        else:
            # Fail here with a clear message instead of hitting an unbound
            # `rot_embeds` further down.
            raise ValueError(
                f"Unsupported embed_type {self.config.embed_type!r}; "
                "expected 'rope' or 'xpos'"
            )

        # latte attention
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        Wk, Wq, Wv = promote_dtype(Wk, Wq, Wv, dtype=self.dtype)
        conv = DepthConv1D(
            nchannels=self.config.hidden_dim,
            out_channels=self.config.hidden_dim,
            kernel_size=9,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)

        # batch size, sequence length; the embedding size is not needed here
        B, T, _ = X.shape

        # queries and keys come from the convolved input, values from X itself
        Y = conv(X)
        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, self.config.nheads, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, Y)
        V = jnp.einsum("DM,BTD->MBT", Wv, X)
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^T p(s|l)v_s
        y = self.latte_attention(
            rot_embeds, Qs, K=K, V=V, attention_mask=attention_mask
        )  # BHTD
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return c_proj(y)


class BidLatteConv(nn.Module):
    """
    Bidirectional version in which we sum to "T" instead of "t".
    No sequential implementation required.

    Queries and keys are computed from the output of an RGLRU layer that is
    fed a linear projection of the input; values come from the input itself.
    """

    config: Config
    dtype: jnp.dtype

    def latte_attention(self, Qs, K, V, attention_mask):
        """
        Args:
            Qs: TBHL - softmaxed queries
            K: LBT - key logits
            V: DBT - values
            attention_mask: optional mask broadcastable against K (LBT);
                truthy entries are kept
        Returns:
            y: BHTD
        """
        _, B, T = K.shape
        if attention_mask is not None:
            K = jnp.where(attention_mask, K, -9e15)
        K = K.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)
        V = V.transpose(2, 1, 0).reshape(T, B, self.config.nheads, -1)

        # subtract the max over time before exponentiating, for stability
        maxi = jnp.max(K, axis=0, keepdims=True)
        K = jnp.exp(K - maxi)

        Kv = jnp.einsum("TBHL,TBHD->BHLD", K, V)
        # normalize by the sum over time
        K = K.sum(axis=0)  # BHL
        Kv = jnp.einsum("BHL,BHLD->BHLD", 1 / K, Kv)
        y = jnp.einsum("TBHL,BHLD->BHTD", Qs, Kv)
        return y

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        attention_mask: jnp.array = None,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        B: batch size H: nr heads, T: seq_len, D: hidden_dim. L: latent dimension
        Args:
            X: jnp.array(BTD)
            attention_mask: optional mask used to ignore pads. It is applied to
                the key logits laid out as (L, B, T), so it must broadcast
                against that shape (e.g. a (B, T) array).
                Only used in bidirectional since we sum up to T, and pad needed for batching
            train: bool - enables dropout when True
        Returns:
            y: jnp.array(BTD) - transformed output sequence
        """
        # batch size, sequence length; the embedding size is not needed here
        B, T, _ = X.shape
        # latte attention; Wk and Wq act on the L-wide RGLRU output, hence (L, L)
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.L, self.config.L),
        )
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.L, self.config.L),
        )
        Wv = self.param(
            "Wv",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.hidden_dim),
        )
        # NOTE: earlier revisions carried commented-out DepthConv1D and S5Layer
        # variants at this point; the RGLRU below is the mixer in use.
        lru_in = self.param(
            "lru_in",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        # Promote every raw weight, including the RGLRU input projection, so
        # that all projections run in the module dtype.
        Wk, Wq, Wv, lru_in = promote_dtype(Wk, Wq, Wv, lru_in, dtype=self.dtype)
        lru = RGLRU(
            width=self.config.L,
            num_heads=self.config.nheads,
            dtype=self.dtype,
        )

        # output projection
        c_proj = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)

        # every sequence in the batch gets positions 0..T-1
        pos_ids = jnp.repeat(jnp.arange(T)[None], B, axis=0)
        Y, _ = lru(jnp.einsum("DL,BTD->BTL", lru_in, X), pos_ids, return_cache=False)

        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, self.config.nheads, -1)
        K = jnp.einsum("DL,BTD->LBT", Wk, Y)
        V = jnp.einsum("DM,BTD->MBT", Wv, X)
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^T p(s|l)v_s
        y = self.latte_attention(Qs, K=K, V=V, attention_mask=attention_mask)  # BHTD
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return c_proj(y)
