"""
A newer architecture that is more complex than the vanilla transformer: it includes convolutions and gating before the mixing block.
"""

from typing import Literal, Type
import functools
import math
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype

from .layers import (
    RMSNorm,
    GatedMLP,
    Conv1D,
    PositionalEncoding,
    RopeEmbeds,
    create_rope,
    Embedder,
)
from .xPos import XPos
from .latte_rel_em import (
    RotCausalScanLatte,
    RotConvQKCausalScanLatte,
    RotConvAllCausalScanLatte,
)
from .latte_mach import CausalRopeLatteMachiattoChunk, CausalRopeLatteMachiattoSliding
from .latte import CausalScanLatte
from .attention import (
    CausalRope,
    CausalSelfAttention,
    ScanCausalSelfAttention,
    BidirectionalAttention,
)
from latte_trans.config import Config, ATT_TYPE


def mixing_layer_factory(config: Config, dtype: jnp.dtype):
    match config.attention_type:
        case "latte_causal":
            return CausalScanLatte(
                config=config,
                unroll=config.unroll,
                dtype=dtype,
            )
        case "latte_convAll_causal":
            if config.embed_type in ["xpos", "rope"]:
                return RotConvAllCausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")
        case "latte_convQR_causal":
            if config.embed_type in ["xpos", "rope"]:
                return RotConvQKCausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")

        case "latte_mach_simple_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRopeLatteMachiattoChunk(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")

        case "latte_mach_sliding_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRopeLatteMachiattoSliding(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")
        case "standard_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRope(
                    config=config,
                    dtype=dtype,
                )
            else:
                return CausalSelfAttention(
                    config=config,
                    dtype=dtype,
                )
        case "standard_bid":
            return BidirectionalAttention(
                config=config,
                dtype=dtype,
            )
        case "scan_standard_causal":
            return ScanCausalSelfAttention(
                config=config,
                unroll=config.unroll,
                query_chunk_attention=1024,
                dtype=dtype,
            )
        case _ as unreachable:
            raise IOError("Type of attention not supported")


class MixingBlock(nn.Module):
    """Gated, convolution-augmented wrapper around a token-mixing layer.

    The input is processed by two parallel branches:

    * gate branch:   ``gelu(linear_1(x))``
    * mixing branch: ``mixing_layer(gelu(conv(linear_conv(x))))``

    The branches are multiplied element-wise and projected by ``linear_out``.
    """

    hidden_dim: int  # model dimension
    # L / nheads / dropout / att_dropout / rot_embeds are not read directly in this
    # block; they are only visible to the mixing layer via ``config=self``.
    L: int  # number of latent variables
    nheads: int  # number of heads
    unroll: int  # unrolls used for the scan operation
    dropout: float = 0.0
    att_dropout: float = 0.0
    final_w_init_variance_scale: float = 1
    attention_type: ATT_TYPE = "latte_causal"
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    @property
    def kernel_init(self) -> nn.initializers.Initializer:
        """Fan-in, unit-scale normal initializer for ``linear_1`` and ``linear_conv``."""
        return nn.initializers.variance_scaling(
            scale=1.0,
            mode="fan_in",
            distribution="normal",
        )

    @property
    def out_kernel_init(self) -> nn.initializers.Initializer:
        """Fan-in normal initializer for ``linear_out``, scaled by
        ``final_w_init_variance_scale``."""
        return nn.initializers.variance_scaling(
            scale=self.final_w_init_variance_scale,
            mode="fan_in",
            distribution="normal",
        )

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Apply the gated mixing block.

        Args:
            x: input activations; the last axis is projected to ``hidden_dim``.
            train: forwarded to the mixing layer.

        Returns:
            ``linear_out(gate * mixed)``, with last axis of size ``hidden_dim``.
        """
        lin1 = nn.Dense(
            features=self.hidden_dim,
            kernel_init=self.kernel_init,
            name="linear_1",
            dtype=self.dtype,
        )
        lin2 = nn.Dense(
            features=self.hidden_dim,
            kernel_init=self.kernel_init,
            name="linear_conv",
            dtype=self.dtype,
        )
        linear_out = nn.Dense(
            features=self.hidden_dim,
            kernel_init=self.out_kernel_init,
            name="linear_out",
            dtype=self.dtype,
        )
        conv = Conv1D(nchannels=self.hidden_dim, kernel_size=3, dtype=self.dtype)
        # The factory takes (config, dtype); the previous call omitted dtype and
        # therefore always raised TypeError.
        # NOTE(review): this module is passed in place of a ``Config`` (duck
        # typing). It declares ``attention_type`` and ``unroll`` but no
        # ``embed_type``, so attention types whose factory branch reads
        # ``config.embed_type`` will fail. TODO: pass a real Config instead.
        mixing_layer = mixing_layer_factory(self, self.dtype)
        act = jax.nn.gelu

        # Gate branch.
        y = act(lin1(x))

        # Mixing branch: projection -> short convolution -> gelu -> mixing layer.
        x = act(conv(lin2(x)))
        x = mixing_layer(x, train=train)
        # The mixing output is not normalised before gating (its norm is not
        # guaranteed to be 1); normalisation at this point was left disabled.

        # Element-wise gating followed by the output projection.
        return linear_out(y * x)


class ResidualBlock(nn.Module):
    """Pre-norm residual layer built around ``MixingBlock``.

    Computes ``h = x + mix(norm(x))`` and returns ``h + mlp(norm(h))``.
    ``DecoderSota`` stacks ``TransformerBlock`` rather than this block.
    """

    hidden_dim: int  # model dimension
    L: int  # number of latent variables
    nheads: int  # number of heads
    unroll: int  # unrolls used for the scan operation
    dropout: float = 0.0
    att_dropout: float = 0.0
    final_w_init_variance_scale: float = 1
    # Default aligned with MixingBlock: the previous default "stable_latte" is not
    # handled by mixing_layer_factory, so it always raised there.
    attention_type: ATT_TYPE = "latte_causal"
    rot_embeds: jnp.array = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Apply the mixing and MLP sub-layers, each with a pre-norm residual.

        Args:
            x: input activations; last axis presumably of size ``hidden_dim``
                (required by the residual additions) — confirm at call sites.
            train: forwarded to ``MixingBlock``.

        Returns:
            Output with the same shape as ``x``.
        """
        pre_norm = RMSNorm(width=self.hidden_dim, dtype=self.dtype)
        mix_block = MixingBlock(
            hidden_dim=self.hidden_dim,
            L=self.L,
            nheads=self.nheads,
            unroll=self.unroll,
            dropout=self.dropout,
            att_dropout=self.att_dropout,
            final_w_init_variance_scale=self.final_w_init_variance_scale,
            attention_type=self.attention_type,
            rot_embeds=self.rot_embeds,
            dtype=self.dtype,
        )
        mlp_pre_norm = RMSNorm(width=self.hidden_dim, dtype=self.dtype)

        # NOTE(review): TransformerBlock constructs GatedMLP with
        # intermediate_dim / initializer_range instead of
        # final_w_init_variance_scale — confirm GatedMLP still accepts this keyword.
        mlp = GatedMLP(
            hidden_dim=self.hidden_dim,
            final_w_init_variance_scale=self.final_w_init_variance_scale,
            dtype=self.dtype,
        )

        # Mixing sub-layer with residual connection.
        residual = x
        x = pre_norm(x)
        x = mix_block(x, train=train)
        residual = x + residual

        # MLP sub-layer with residual connection.
        x = mlp_pre_norm(residual)
        x = mlp(x)
        return x + residual


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: a mixing sub-layer followed by a gated MLP.

    Computes ``h = mix(norm(x)) + x`` and returns ``mlp(norm(h)) + h``. The mixing
    layer is chosen from ``config`` by ``mixing_layer_factory``.
    """

    config: Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Run one transformer layer.

        Args:
            x: input activations; last axis presumably ``config.hidden_dim`` wide
                (needed for the residual additions) — confirm at call sites.
            train: forwarded to the mixing layer.

        Returns:
            Output with the same shape as ``x``.
        """
        cfg = self.config
        width = cfg.hidden_dim

        # Submodules are constructed in the same order as the calls below need
        # them, which keeps Flax's automatically assigned names stable.
        attn_norm = RMSNorm(width=width, dtype=self.dtype)
        mixer = mixing_layer_factory(cfg, self.dtype)
        ffn_norm = RMSNorm(width=width, dtype=self.dtype)
        ffn = GatedMLP(
            hidden_dim=width,
            intermediate_dim=cfg.intermediate_dim,
            initializer_range=cfg.initializer_range,
            dtype=self.dtype,
        )

        # Token-mixing sub-layer with a residual connection.
        hidden = mixer(attn_norm(x), train=train) + x
        # Feed-forward sub-layer with a residual connection.
        return ffn(ffn_norm(hidden)) + hidden


class DecoderSota(nn.Module):
    """
    Token-level language model: embedding, a scanned stack of ``TransformerBlock``
    layers, a final RMSNorm and a tied LM head.

    Serves as an encoder-only or as a decoder-only model depending on
    ``config.attention_type`` (each block builds its mixing layer from it).
    """

    vocab_size: int
    config: Config
    dtype: jnp.dtype = jnp.float32
    # When False, the scan is given metadata_params={"partition_name": None}, so
    # the stacked layer axis is not partitioned.
    sharded: bool = False

    def lm_head(self, x, embeds):
        """Project hidden states onto the vocabulary with the tied input embedding.

        Args:
            x: hidden states whose last axis has size ``config.hidden_dim``.
            embeds: the ``nn.Embed`` module whose ``embedding`` table is reused.

        Returns:
            ``x @ embeds.embedding.T`` — logits with a trailing axis of size
            ``vocab_size``.
        """
        return x @ embeds.embedding.T

    @nn.compact
    def __call__(
        self,
        X: jnp.array,
        train: bool = False,
        **kwargs,
    ) -> jnp.array:
        """
        Args:
            X: integer token ids fed to ``nn.Embed``; presumably shaped (B, T)
                with B = batch size, T = sequence length — confirm with callers.
            train: forwarded to every block (e.g. to enable dropout).
            **kwargs: accepted for interface compatibility; currently unused.
        Returns:
            logits: jnp.array with ``X``'s shape plus a trailing axis of size
                ``vocab_size``.
        """
        embedder = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.config.hidden_dim,
            dtype=self.dtype,
            embedding_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range
            ),
        )

        # Only absolute encodings are applied here; other embed types are left to
        # the mixing layer that mixing_layer_factory selects.
        pos_embeds = None
        if self.config.embed_type == "absolute":
            pos_embeds = PositionalEncoding(
                d_model=self.config.hidden_dim,
                max_len=self.config.pos_embed_max_len,
                dtype=self.dtype,
            )

        block_fn = functools.partial(
            TransformerBlock,
            config=self.config,
            dtype=self.dtype,
        )

        final_norm = RMSNorm(width=self.config.hidden_dim, dtype=self.dtype)

        X = embedder(X)
        if pos_embeds is not None:
            X = pos_embeds(X)

        # A single block definition is scanned nlayers times; parameters get a
        # leading layer axis (variable_axes) and each layer gets its own rngs.
        block = block_fn(name="residual_block")
        scan_kwargs = dict(
            variable_axes={"params": 0, "intermediates": 0},
            split_rngs={"params": True, "dropout": True},
            length=self.config.nlayers,
        )
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_kwargs["metadata_params"] = {"partition_name": None}
        X, _ = nn.scan(
            lambda module, carry, _: (module(carry, train=train), None),
            **scan_kwargs,
        )(block, X, ())

        # Final normalisation, then the tied LM head produces the logits.
        X = final_norm(X)
        return self.lm_head(X, embedder)
