from typing import Any, Literal, Type
import functools
import math
import jax
import numpy as np
from flax import linen as nn
from jax import numpy as jnp

from .latte_rel_em import (
    RotCausalScanLatte,
    RotConvQKCausalScanLatte,
    RotConvAllCausalScanLatte,
)
from .latte_mach import CausalRopeLatteMachiattoChunk, CausalRopeLatteMachiattoSliding
from .latte import CausalScanLatte
from .layers import PositionalEncoding, RopeEmbeds, create_rope, Embedder
from .xPos import XPos

from .attention import (
    CausalSelfAttention,
    ScanCausalSelfAttention,
    CausalRope,
    BidirectionalAttention,
)
from .init_jax import dense_init
from latte_trans.config import LMTaskConfig, ATT_TYPE


def mixing_layer_factory(config: LMTaskConfig, dtype: jnp.dtype):
    match config.attention_type:
        case "latte_causal":
            if config.embed_type in ["nope", "absolute"]:
                return CausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            elif config.embed_type in ["xpos", "rope"]:
                return RotCausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")
        case "latte_convAll_causal":
            if config.embed_type in ["xpos", "rope"]:
                return RotConvAllCausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")
        case "latte_convQR_causal":
            if config.embed_type in ["xpos", "rope"]:
                return RotConvQKCausalScanLatte(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")

        case "latte_mach_simple_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRopeLatteMachiattoChunk(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")

        case "latte_mach_sliding_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRopeLatteMachiattoSliding(
                    config=config,
                    unroll=config.unroll,
                    dtype=dtype,
                )
            else:
                raise Exception("Not yet implemented")
        case "standard_causal":
            if config.embed_type in ["xpos", "rope"]:
                return CausalRope(
                    config=config,
                    dtype=dtype,
                )
            else:
                return CausalSelfAttention(
                    config=config,
                    dtype=dtype,
                )
        case "standard_bid":
            return BidirectionalAttention(
                config=config,
                dtype=dtype,
            )
        case "scan_standard_causal":
            return ScanCausalSelfAttention(
                config=config,
                unroll=config.unroll,
                query_chunk_attention=1024,
                dtype=dtype,
            )
        case _ as unreachable:
            raise IOError("Type of attention not supported")


class TransBlock(nn.Module):
    """Transformer block: a token-mixing layer followed by a two-layer MLP.

    The mixing layer is whatever ``mixing_layer_factory`` returns for
    ``config`` (standard attention or one of the Latte variants). Each of the
    two sub-layers is wrapped in a residual connection, with normalisation
    applied before the sub-layer when ``config.prenorm`` is true and after the
    residual sum otherwise. ``config.batchnorm`` switches the normalisation
    from ``nn.LayerNorm`` to ``nn.BatchNorm``.
    """

    config: LMTaskConfig
    # Not read anywhere in this block; kept so existing constructor calls work.
    rot_embeds: jnp.ndarray = None
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        X: jnp.ndarray,
        train: bool = False,
        do_inference=False,
        cache=None,
        **kwargs,
    ) -> jnp.ndarray:
        """Apply the block to a sequence.

        Args:
            X: jnp.ndarray(BTD), B = Batch size, T = sequence length,
                D = embed dimension
            train: enables dropout and, with batch norm, the use of batch
                statistics instead of running averages.
            do_inference: accepted for interface compatibility; not used by
                this block and not forwarded to the mixing layer.
            cache: accepted for interface compatibility; not used by this
                block and not forwarded to the mixing layer.
            **kwargs: forwarded unchanged to the mixing layer.
        Returns:
            out: jnp.ndarray(BTD) - transformed output sequence
        """
        cfg = self.config
        deterministic = not train

        # Submodules are auto-named per class in creation order, so the order
        # below is kept stable to stay compatible with existing checkpoints.
        if cfg.batchnorm:
            make_norm = functools.partial(
                nn.BatchNorm,
                use_running_average=not train,
                momentum=0.9,
                dtype=self.dtype,
            )
        else:
            make_norm = functools.partial(nn.LayerNorm, dtype=self.dtype)
        norm1 = make_norm()
        norm2 = make_norm()

        mixer = mixing_layer_factory(cfg, dtype=self.dtype)

        make_dense = functools.partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )
        # Two-layer MLP: expand, exact (non-approximated) GELU, project, dropout.
        mlp = (
            make_dense(cfg.intermediate_dim),
            functools.partial(nn.gelu, approximate=False),
            make_dense(cfg.hidden_dim),
            nn.Dropout(cfg.dropout, deterministic=deterministic),
        )

        mixer_dropout = nn.Dropout(cfg.dropout, deterministic=deterministic)

        # Mixing sub-layer with residual connection.
        residual = X
        if cfg.prenorm:
            X = norm1(X)
        X = mixer(X, train=train, **kwargs)
        X = residual + mixer_dropout(X)
        if not cfg.prenorm:
            X = norm1(X)

        # MLP sub-layer with residual connection.
        residual = X
        if cfg.prenorm:
            X = norm2(X)
        for layer in mlp:
            X = layer(X)
        X = residual + X
        if not cfg.prenorm:
            X = norm2(X)
        return X


class Decoder(nn.Module):
    """Language model built from an embedder and a scanned stack of ``TransBlock``s.

    Serves as an encoder-only or as a decoder-only model depending on
    ``config.attention_type``. A single ``TransBlock`` definition is applied
    ``config.nlayers`` times through ``nn.scan``, so per-layer variables are
    stacked along a leading axis.
    """

    vocab_size: int
    config: LMTaskConfig
    dtype: jnp.dtype = jnp.float32
    # When False, the scanned layer axis is explicitly left unpartitioned.
    sharded: bool = False

    @nn.compact
    def __call__(
        self,
        X: jnp.ndarray,
        train: bool = False,
        do_inference=False,
        cache=None,
        **kwargs,
    ) -> jnp.ndarray:
        """Embed the input, run the block stack and decode with the embedder.

        Args:
            X: input passed to ``Embedder.encode``. NOTE(review): presumably
                integer token ids of shape (B, T); the previous docstring said
                (B, T, D) - confirm against ``Embedder``.
            train: enables dropout and, with batch norm, the use of batch
                statistics instead of running averages.
            do_inference: accepted for interface compatibility; not used here
                and not forwarded to the blocks.
            cache: accepted for interface compatibility; not used here and not
                forwarded to the blocks.
            **kwargs: forwarded unchanged to every ``TransBlock`` call.
        Returns:
            The result of ``Embedder.decode`` on the final hidden states.
            NOTE(review): presumably vocabulary logits - confirm against
            ``Embedder.decode``.
        """
        cfg = self.config

        # Constructed even when prenorm leaves it unused: submodules are
        # auto-named per class in creation order, so dropping it would rename
        # the final LayerNorm below and break existing checkpoints.
        ln = nn.LayerNorm(dtype=self.dtype)
        drop_embed = nn.Dropout(rate=cfg.dropout, deterministic=not train)
        text_embed = Embedder(
            vocab_size=self.vocab_size,
            embed_dim=cfg.hidden_dim,
            scale_by_sqrt_dim=False,
            dtype=self.dtype,
            name="model_embed",
        )

        # Only absolute embeddings are added here; any other embed_type gets
        # no positional module at this level.
        pos_embeds = None
        if cfg.embed_type == "absolute":
            pos_embeds = PositionalEncoding(
                d_model=cfg.hidden_dim,
                max_len=cfg.pos_embed_max_len,
                dtype=self.dtype,
            )

        X = text_embed.encode(X)
        if pos_embeds is not None:
            X = pos_embeds(X)
        X = drop_embed(X)

        if not cfg.prenorm:
            X = ln(X)

        block = TransBlock(
            config=cfg,
            dtype=self.dtype,
            name="transformer_block",
        )

        scan_options = dict(
            # TransBlock creates nn.BatchNorm layers when config.batchnorm is
            # set; their statistics live in "batch_stats", which must be
            # listed here for the scanned block to be able to reach them.
            variable_axes={"params": 0, "intermediates": 0, "batch_stats": 0},
            split_rngs={"params": True, "dropout": True},
            length=cfg.nlayers,
        )
        if not self.sharded:
            # We do not need to partition over the layer axis.
            scan_options["metadata_params"] = {"partition_name": None}

        X, _ = nn.scan(
            lambda module, carry, _: (module(carry, train=train, **kwargs), None),
            **scan_options,
        )(block, X, ())

        if cfg.prenorm:
            # NOTE(review): unlike the other norms, these do not receive
            # dtype=self.dtype - confirm whether that is intentional.
            if cfg.batchnorm:
                X = nn.BatchNorm(use_running_average=not train, momentum=0.9)(X)
            else:
                X = nn.LayerNorm()(X)

        return text_embed.decode(X)
