"""
The new architexure which is more coplex then the vanilla tranformer. Includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type
import functools
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype

from ..layers import (
    RMSNorm,
    GatedMLP,
    Conv1D,
    PositionalEncoding,
    RopeEmbeds,
    create_rope,
    Embedder,
)
from ..xPos import XPos


from .attention import (
    RopeCausal,
    CausalSelfAttention,
    BidirectionalAttention,
)
from latte_trans.config import Config, ATT_TYPE


class TransformerBlock(nn.Module):
    """Residual block made of a mixing layer followed by a gated MLP.

    Each of the two sub-layers receives its input through its own ``RMSNorm``
    and has its output added back onto the un-normalised input.

    Attributes:
        config: model configuration. ``hidden_dim`` and ``initializer_range``
            are read here; the whole object is also handed to
            ``mixing_layer_factory`` to build the mixing layer.
        dtype: dtype passed to every sub-module created by the block.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Run the block on ``x``.

        Args:
            x: input activations (presumably shaped (batch, seq, hidden_dim);
                confirm against the caller).
            train: forwarded unchanged to the mixing layer.

        Returns:
            The activations after both residual sub-layers.
        """
        cfg = self.config
        width = cfg.hidden_dim

        # Keep this construction order: Flax derives automatic sub-module
        # names (and therefore parameter paths) from it.
        mixer_norm = RMSNorm(width=width, dtype=self.dtype)
        mixer = mixing_layer_factory(cfg, self.dtype)
        ffn_norm = RMSNorm(width=width, dtype=self.dtype)
        ffn = GatedMLP(
            hidden_dim=width,
            initializer_range=cfg.initializer_range,
            dtype=self.dtype,
        )

        # Mixing sub-layer plus skip connection.
        mixed = mixer(mixer_norm(x), train=train) + x

        # Gated-MLP sub-layer plus skip connection.
        return ffn(ffn_norm(mixed)) + mixed


class DecoderSota(nn.Module):
    """
    Token embedding, a scanned stack of ``TransformerBlock`` layers and a
    final ``RMSNorm``.

    Serves as encoder-only or as decoder-only
    depending on the attention_type.

    Attributes:
        vocab_size: number of entries in the embedding table.
        config: model configuration. ``hidden_dim``, ``initializer_range``,
            ``embed_type``, ``pos_embed_max_len`` and ``nlayers`` are read
            here; the object is also passed on to every block.
        dtype: dtype passed to every sub-module.
        sharded: when ``False`` the layer scan is given
            ``metadata_params={"partition_name": None}``; when ``True``
            ``metadata_params`` is left at its default.
    """

    vocab_size: int
    config: Config
    dtype: jnp.dtype = jnp.float32
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        X: jnp.ndarray,
        train: bool = False,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Args:
            X: integer token ids (they are looked up in ``nn.Embed``),
                expected shape (B, T) with B = batch size, T = sequence length
            train: bool - forwarded to every block
            **kwargs: accepted but not used by this method
        Returns:
            out: hidden states after the final norm, shape (B, T, D) with
                D = ``config.hidden_dim``. No projection to the vocabulary is
                applied here.
        """
        cfg = self.config

        embedder = nn.Embed(
            num_embeddings=self.vocab_size,
            features=cfg.hidden_dim,
            dtype=self.dtype,
            embedding_init=jax.nn.initializers.normal(
                stddev=cfg.initializer_range
            ),
        )

        # Only the "absolute" scheme adds positional information at this
        # level; for any other embed_type the embeddings are used as they are.
        pos_embeds = None
        if cfg.embed_type == "absolute":
            pos_embeds = PositionalEncoding(
                d_model=cfg.hidden_dim,
                max_len=cfg.pos_embed_max_len,
                dtype=self.dtype,
            )

        final_norm = RMSNorm(width=cfg.hidden_dim, dtype=self.dtype)

        # No dropout is applied to the embeddings.
        X = embedder(X)
        if pos_embeds is not None:
            X = pos_embeds(X)

        # A single block definition is scanned `nlayers` times; its "params"
        # and "intermediates" collections get a leading per-layer axis.
        block = TransformerBlock(
            config=cfg,
            dtype=self.dtype,
            name="residual_block",
        )

        scan_kwargs = {
            "variable_axes": {"params": 0, "intermediates": 0},
            "split_rngs": {"params": True, "dropout": True},
            "length": cfg.nlayers,
        }
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_kwargs["metadata_params"] = {"partition_name": None}

        # The activations are the scan carry; there are no per-step inputs
        # (`()`) and no per-step outputs (`None`).
        X, _ = nn.scan(
            lambda module, carry, _: (module(carry, train=train), None),
            **scan_kwargs,
        )(block, X, ())

        return final_norm(X)
