from typing import Any, Literal, Type
import functools
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

from .layers_pt import MLP, PositionalEncoding, RopeEmbeds, RMSNorm, GatedMLP
from .xPos_pt import XPos
from ..modules.init_pt import _init_weights

from .attention_pt import CausalSelfAttention, CausalRope
from latte_trans.config import ATT_TYPE, LMTaskConfig


def mixing_layer_factory(config: LMTaskConfig):
    match config.attention_type:
        case "standard_causal":
            if config.embed_type in ["nope", "absolute"]:
                return CausalSelfAttention(config=config)
            elif config.embed_type in ["xpos", "rope"]:
                return CausalRope(config=config)
        case _ as unreachable:
            raise IOError("Type of attention not supported")


class TransBlock(nn.Module):
    """Residual transformer block whose token-mixing layer is pluggable.

    ``forward`` computes::

        X = X + lru(norm1(X))
        X = X + mlp(norm2(X))

    where ``lru`` is the layer returned by ``mixing_layer_factory(config)``
    (e.g. causal attention) and ``mlp`` is a ``GatedMLP``. The submodule
    attribute names are part of the ``state_dict`` keys, so they are kept as-is.
    """

    def __init__(
        self,
        config: "LMTaskConfig",
    ):
        """Create the block's norms, MLP and mixing layer from ``config``.

        Args:
            config: task configuration; ``hidden_dim`` sizes the norms and the
                MLP, and the whole config is passed to ``mixing_layer_factory``.
        """
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.max_seq_len = config.max_seq_len
        self.nheads = config.nheads
        # NOTE(review): prenorm, dropout, att_dropout and batchnorm are stored
        # but never read by this block: forward() always applies pre-norm and
        # no dropout. Confirm whether other code relies on these attributes.
        self.prenorm = config.prenorm
        self.dropout = config.dropout
        self.att_dropout = config.dropout_att
        self.batchnorm = config.batchnorm
        self.attention_type = config.attention_type

        # Registration order matters: it fixes the state_dict key order and the
        # order in which Module.apply(...) initialises weights.
        self.norm1 = RMSNorm(self.hidden_dim)
        self.norm2 = RMSNorm(self.hidden_dim)
        self.mlp = GatedMLP(self.config.hidden_dim)
        self.lru = mixing_layer_factory(config)

    def forward(
        self,
        X: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the mixing sub-block and then the MLP sub-block.

        Args:
            X: input activations, (B, T, D) per the original author's notes
                (B = batch, T = sequence length, D = hidden_dim).
            **kwargs: forwarded unchanged to the mixing layer ``self.lru``
                only (not to the MLP).

        Returns:
            ``X`` plus both sub-layer outputs, added through residual
            connections.
        """
        # Mixing (e.g. attention) sub-block, pre-normalised, with residual.
        X = X + self.lru(self.norm1(X), **kwargs)
        # Feed-forward sub-block, pre-normalised, with residual.
        X = X + self.mlp(self.norm2(X))
        return X


class Decoder(nn.Module):
    """
    Servers as EncoderOnly or as a Decoder only
    depending on the attention_type
    """

    def __init__(
        self,
        vocab_size: int,
        pad_id: int,
        config: LMTaskConfig,
        rot_embeds: torch.tensor = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.max_seq_len = config.max_seq_len
        self.nheads = config.nheads
        self.prenorm = config.prenorm
        self.dropout = config.dropout
        self.att_dropout = config.dropout_att
        self.batchnorm = config.batchnorm
        self.attention_type = config.attention_type
        self.rot_embeds = rot_embeds

        self.post_ln = RMSNorm(self.hidden_dim)
        # self.drop_embed = nn.Dropout(self.dropout)
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.hidden_dim,
            padding_idx=pad_id,
        )

        self.pos_embeds = None
        if self.config.embed_type == "absolute":
            self.pos_embeds = PositionalEncoding(
                d_model=self.hidden_dim,
                max_len=self.config.pos_embed_max_len,
            )

        self.enc_layers = nn.ModuleList(
            [TransBlock(config) for _ in range(config.nlayers)]
        )

        self.apply(functools.partial(_init_weights, self.config))

    def forward(
        self,
        X: torch.tensor,
        **kwargs,
    ) -> torch.tensor:
        """
        Args:
            X: jnp.array(BTD), B = Batch size, T = sequence length, D = embed dimension
            train: bool - used for dropout
        Returns:
            out: jnp.array(BTD) - transformed output sequence
        """
        # absolute
        if not self.pos_embeds is None:
            X = self.pos_embeds(self.embed(X))
            # X = self.drop_embed(X)
        else:  # relative or nope
            X = self.embed(X)
            # X = self.drop_embed(X)

        for l in self.enc_layers:
            X = l(X, **kwargs)

        X = self.post_ln(X)
        return X
