"""
The new architecture, which is more complex than the vanilla transformer. Includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type
import functools
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype

from .layers import (
    RMSNorm,
    GatedMLP,
    Conv1D,
    PositionalEncoding,
    RopeEmbeds,
    create_rope,
    Embedder,
)
from .xPos import XPos
from .latte_rel_em import (
    RotCausalScanLatte,
    RotConvQKCausalScanLatte,
    RotConvAllCausalScanLatte,
)
from .latte_mach import CausalRopeLatteMachiattoChunk
from .latte import CausalScanLatte
from .attention import (
    CausalRope,
    CausalSelfAttention,
    ScanCausalSelfAttention,
    BidirectionalAttention,
)
from latte_trans.models.modules.layers import SlidingWindowAtt
from latte_trans.config import Config

# Module-level alias for JAX's associative scan; `latte_attention4` uses it to
# evaluate its running (prefix) sums over the time axis.
parallel_scan = jax.lax.associative_scan


def mixing_layer_factory(config: Config, dtype: jnp.dtype):
    """Build the token-mixing layer selected by ``config``.

    Args:
        config: model configuration; ``attention_type``, ``embed_type`` and
            ``unroll`` are read here.
        dtype: computation dtype forwarded to the created layer.

    Returns:
        A ``CausalRopeLatteMachiattoSliding`` module.

    Raises:
        IOError: if ``config.attention_type`` is not
            ``"latte_mach_sliding_causal"`` (exception type kept for callers
            that already catch it).
        NotImplementedError: if ``config.embed_type`` is neither ``"xpos"``
            nor ``"rope"``.
    """
    # Guard clauses: reject unsupported configurations first and name the
    # offending value so a misconfiguration is easy to diagnose.
    if config.attention_type != "latte_mach_sliding_causal":
        raise IOError(
            f"Type of attention not supported: {config.attention_type!r}"
        )
    if config.embed_type not in ("xpos", "rope"):
        raise NotImplementedError(
            f"Not yet implemented: embed_type={config.embed_type!r}"
        )
    return CausalRopeLatteMachiattoSliding(
        config=config,
        unroll=config.unroll,
        dtype=dtype,
    )


class CausalRopeLatteMachiattoSliding(nn.Module):
    """Causal mixing layer: sliding-window attention plus Latte latent attention.

    Each head has one extra latent (index 0) whose weight scales the output of
    a causal sliding-window attention; the remaining latents weight a
    latent-state attention computed with an associative scan over time. The two
    contributions are summed.

    Attributes:
        config: model configuration (reads ``embed_type``, ``hidden_dim``,
            ``nheads``, ``L``, ``att_block_len``, ``dropout``, ``dropout_att``,
            ``initializer_range``, ``pos_embed_max_len``, ``max_seq_len``).
        unroll: not read by the methods of this class; kept so existing
            constructor calls stay valid.
        dtype: computation dtype for the projections and parameters.
    """

    config: Config
    unroll: int = 100
    dtype: jnp.dtype = jnp.float32

    def latte_attention4(self, rot_embeds, Q, K, V):
        """Latent attention with the normalisation computed by a parallel scan.

        The original author notes the cost as O(TL + LD).

        Args:
            rot_embeds: either an ``XPos`` instance or an object exposing
                ``apply_vapor(mat=..., neg=...)``.
            Q: jax.Array(T,B,H,L) latent weights per time step.
            K: jax.Array(T,B,H,L) latent logits per time step.
            V: jax.Array(T,B,H,D) values.

        Returns:
            The attended values transposed to (B,H,T,D) in the non-XPos
            branch; the XPos branch returns whatever ``rot_embeds`` yields for
            a (B,H,T,D) input (presumably the same layout -- TODO confirm).
        """
        # calc R^{-s}x_s
        if isinstance(rot_embeds, XPos):
            # XPos is called on BHTD, so go TBHD -> BHTD and back again.
            V = rot_embeds(V.transpose(1, 2, 0, 3), offset=0, downscale=True).transpose(
                2, 0, 1, 3
            )
        else:
            V = rot_embeds.apply_vapor(mat=V, neg=True)
        # NOTE(review): `__call__` passes Q already softmax-normalised (and
        # dropout-masked), so this is a second softmax. Kept as in the original
        # to preserve numerics -- confirm this is intended.
        Qs = jax.nn.softmax(Q, axis=-1)

        # Running maximum of K over time, used only for numerical stability, so
        # it is treated as a constant (no gradient).
        maxi = jax.lax.stop_gradient(jax.lax.cummax(K, axis=0))
        # exp(maxi[t-1] - maxi[t]) rescales sums accumulated under the previous
        # maximum; the first step has nothing to rescale (exp(0) == 1).
        revert_maxi = jnp.zeros_like(maxi)
        revert_maxi = revert_maxi.at[1:].set(-maxi[1:] + maxi[:-1])
        revert_maxi = jnp.exp(revert_maxi)  # TBHL
        add_maxi = jnp.exp(K - maxi)
        nu = jnp.einsum("TBHL,TBHD->TBHLD", add_maxi, V)

        def bin_V(left, right):
            # Combine two scan segments: rescale the left segment's sums by the
            # right segment's correction factor, then add the right sums.
            rm_left, am_left, nu_left = left
            rm_right, am_right, nu_right = right
            nu_sum = nu_left * rm_right[..., None] + nu_right
            alpha_sum = am_left * rm_right + am_right
            return (rm_left * rm_right, alpha_sum, nu_sum)

        # alpha[t] = sum_{s<=t} exp(K[s] - maxi[t]);  y[t] = same sum weighted by V[s].
        _, alpha, y = parallel_scan(bin_V, (revert_maxi, add_maxi, nu))
        y = jnp.einsum("TBHL,TBHLD->TBHD", Qs / alpha, y)
        # calc R^t \sum_l ...
        if isinstance(rot_embeds, XPos):
            # TBHD -> BHTD before applying XPos; its output is returned as is.
            return rot_embeds(y.transpose(1, 2, 0, 3), offset=0, downscale=False)
        y = rot_embeds.apply_vapor(mat=y, neg=False)
        # TBHD -> BHTD
        return y.transpose(1, 2, 0, 3)

    @staticmethod
    def scale_offset(x, gamma=None, beta=None):
        """Apply a per-feature affine transform ``x * gamma + beta``.

        The previous body called ``jnp.var`` (the variance reduction) on the
        shape tuple ``x.shape[-1:]``, which cannot yield a learnable scale or
        offset. ``gamma`` and ``beta`` are now supplied by the caller (as
        trainable parameters in ``__call__``).

        Args:
            x: input array.
            gamma: scale broadcastable against ``x``; ``None`` means no scaling.
            beta: offset broadcastable against ``x``; ``None`` means no offset.

        Returns:
            The transformed array; ``x`` unchanged when both are ``None``.
        """
        if gamma is not None:
            x = x * gamma
        if beta is not None:
            x = x + beta
        return x

    @nn.compact
    def __call__(self, X, train=False):
        """Mix tokens causally.

        Args:
            X: array unpacked as (B, T, C): batch, sequence length, embedding
                dimension.
            train: when True the dropout layers are active.

        Returns:
            Array of shape (B, T, nheads * (C // nheads)).

        Raises:
            NotImplementedError: if ``config.embed_type`` is neither ``"rope"``
                nor ``"xpos"`` (previously this surfaced later as an unbound
                local variable).
        """
        if self.config.embed_type == "rope":
            rot_embeds = RopeEmbeds(
                n_pos=self.config.pos_embed_max_len,
                d_model=self.config.hidden_dim // self.config.nheads,
            )
        elif self.config.embed_type == "xpos":
            rot_embeds = XPos(
                head_dim=self.config.hidden_dim // self.config.nheads,
                scale_base=self.config.max_seq_len,
            )
        else:
            raise NotImplementedError(
                f"Not yet implemented: embed_type={self.config.embed_type!r}"
            )
        # self attention
        v_attn = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )
        qk_attn = nn.Dense(
            self.config.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )

        # latte attention
        # Wk is declared but no longer used below (K is derived from Q); it is
        # kept so the parameter tree does not change.
        Wk = self.param(
            "Wk",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L),
        )
        # One extra column per head: latent 0 weights the sliding-window term.
        Wq = self.param(
            "Wq",
            jax.nn.initializers.normal(stddev=self.config.initializer_range),
            (self.config.hidden_dim, self.config.L + self.config.nheads),
        )
        Wk, Wq = promote_dtype(Wk, Wq, dtype=self.dtype)
        conv = Conv1D(
            nchannels=self.config.hidden_dim,
            out_channels=self.config.hidden_dim,
            kernel_size=3,
            dtype=self.dtype,
        )

        def learned_scale_offset(name, x):
            # Trainable per-feature scale and offset over the last axis of x.
            # NOTE(review): initialised to the identity (ones / zeros); adjust
            # if a different initialisation is wanted.
            gamma = self.param(f"{name}_gamma", jax.nn.initializers.ones, x.shape[-1:])
            beta = self.param(f"{name}_beta", jax.nn.initializers.zeros, x.shape[-1:])
            gamma, beta = promote_dtype(gamma, beta, dtype=self.dtype)
            return self.scale_offset(x, gamma, beta)

        # regularization
        attn_dropout = nn.Dropout(rate=self.config.dropout, deterministic=not train)
        Q_drop = nn.Dropout(self.config.dropout_att, deterministic=not train)
        B, T, C = (
            X.shape
        )  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = learned_scale_offset("qk", qk_attn(X))
        k = q  # queries and keys of the sliding-window attention are shared
        v = v_attn(X)
        att_dim = C // self.config.nheads
        k = k.reshape(B, T, self.config.nheads, att_dim).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        q = q.reshape(B, T, self.config.nheads, att_dim).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)
        v = v.reshape(B, T, self.config.nheads, att_dim).transpose(
            0, 2, 1, 3
        )  # (B, nh, T, hs)

        sliding_attention = SlidingWindowAtt(
            window_size=self.config.att_block_len, exact_windowsize=True, causal=True
        )
        Y = conv(X)
        # multi head implementation
        Q = jnp.einsum("DL,BTD->TBL", Wq, Y).reshape(T, B, self.config.nheads, -1)
        Q = learned_scale_offset("latent", Q)
        # remove the extra latent l = 0 used for combining attentions
        K = Q[:, :, :, 1:]
        Qs = jax.nn.softmax(Q, axis=-1)
        Qs = Q_drop(Qs)

        p_s_l0 = sliding_attention(
            q,
            k,
            v,
            input_mask=None,
            attn_dropout=attn_dropout,
            rot_embeds=rot_embeds,
        )
        p_s_l0 = p_s_l0.reshape(B, self.config.nheads, T, att_dim)  # BHNLD -> BHTD
        # p(l=0|t) \sum_{s \in[t, t-w]}^t p(s|l,t)v_s
        # NOTE(review): this weights by Q[..., 0] (the pre-softmax value), not
        # by Qs[..., 0]; kept as in the original -- confirm which is intended.
        causal_att = jnp.einsum("TBH,BHTD->BHTD", Q[:, :, :, 0], p_s_l0)

        # \sum_{l=1}^Lp(l|t) \sum_{s=0}^t p(s|l,t)v_s # T,B,H,D
        latte_att = self.latte_attention4(
            rot_embeds, Qs[:, :, :, 1:], K=K, V=v.transpose(2, 0, 1, 3)
        )  # BHTD
        y = causal_att + latte_att
        y = y.transpose(0, 2, 1, 3)  # BHTD -> BTHD
        y = y.reshape(B, T, -1)
        return y


class TransformerBlock(nn.Module):
    """Residual block: RMSNorm -> token mixing -> multiplicative gate -> output projection.

    Attributes:
        config: model configuration (reads ``hidden_dim``, ``unroll`` and
            ``initializer_range``; also forwarded to the mixing layer).
        dtype: computation dtype used by every sub-layer.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Apply the block to ``x`` and add the skip connection.

        Args:
            x: input activations, fed to an RMSNorm of width ``hidden_dim``.
            train: forwarded to the mixing layer.

        Returns:
            ``o_proj(gate * mixed) + x`` where ``mixed`` is the mixing-layer
            output for the normalised input and ``gate`` is a bias-free linear
            projection of ``mixed``.
        """
        width = self.config.hidden_dim

        def bias_free_dense():
            # Both projections share this configuration: hidden_dim outputs,
            # no bias, normally-distributed kernel initialisation.
            return nn.Dense(
                width,
                use_bias=False,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(
                    stddev=self.config.initializer_range
                ),
            )

        # Sub-layers are created in the same order as before (norm, mixer,
        # gate, output) so automatically assigned submodule names are unchanged.
        norm = RMSNorm(width=width, dtype=self.dtype)
        mixer = CausalRopeLatteMachiattoSliding(
            config=self.config,
            unroll=self.config.unroll,
            dtype=self.dtype,
        )
        gate_proj = bias_free_dense()
        o_proj = bias_free_dense()

        mixed = mixer(norm(x), train=train)
        gated = gate_proj(mixed) * mixed
        return o_proj(gated) + x


class QualDecoder(nn.Module):
    """
    Serves as an encoder-only or a decoder-only stack,
    depending on the attention_type.

    Attributes:
        vocab_size: number of rows of the token embedding table.
        config: model configuration (reads ``hidden_dim``, ``embed_type``,
            ``pos_embed_max_len``, ``initializer_range`` and ``nlayers``).
        dtype: computation dtype used by every sub-layer.
        sharded: when False, the layer scan is given
            ``metadata_params={"partition_name": None}``; when True that
            argument is omitted.
    """

    vocab_size: int
    config: Config
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        X: jnp.ndarray,
        train: bool = False,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Args:
            X: token indices looked up in an ``nn.Embed`` table
                (presumably shaped (B, T): batch size, sequence length --
                TODO confirm against callers)
            train: bool - forwarded to every block (used for dropout)
            **kwargs: accepted for call compatibility; not used here
        Returns:
            out: output of the final RMSNorm applied to the block stack.
                No projection to vocabulary logits is done here.
        """
        embedder = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.config.hidden_dim,
            dtype=self.dtype,
            embedding_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )

        # Only absolute positions are added at the embedding level; any other
        # embed_type leaves the embeddings untouched here.
        pos_embeds = None
        if self.config.embed_type == "absolute":
            pos_embeds = PositionalEncoding(
                d_model=self.config.hidden_dim,
                max_len=self.config.pos_embed_max_len,
                dtype=self.dtype,
            )

        block_fn = functools.partial(
            TransformerBlock,
            config=self.config,
            dtype=self.dtype,
        )

        final_norm = RMSNorm(width=self.config.hidden_dim, dtype=self.dtype)

        X = embedder(X)
        if pos_embeds is not None:
            X = pos_embeds(X)

        # A single block definition is scanned `nlayers` times; parameters and
        # intermediates are stacked along axis 0, and the params/dropout RNGs
        # are split per layer.
        block = block_fn(name="residual_block")
        scan_kwargs = {
            "variable_axes": {"params": 0, "intermediates": 0},
            "split_rngs": {"params": True, "dropout": True},
            "length": self.config.nlayers,
        }
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_kwargs["metadata_params"] = {"partition_name": None}
        X, _ = nn.scan(
            lambda module, carry, _: (module(carry, train=train), None),
            **scan_kwargs,
        )(block, X, ())

        # Final normalisation of the hidden states.
        return final_norm(X)
