import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import warnings
from models.positional_encodings import get_pos_encoder_embedding


def expand_or_assert(param, n_layers, name, expected_type):
    """Return a per-layer sequence for a setting given as a scalar or a sequence.

    Args:
        param: Either a single value of ``expected_type`` or a sized sequence
            holding one value per layer.
        n_layers: Number of layers the setting has to cover.
        name: Name of the setting, used only in the assertion message.
        expected_type: Type (or tuple of types) that marks ``param`` as a
            single value to be repeated.

    Returns:
        A list with ``param`` repeated ``n_layers`` times when ``param`` is an
        instance of ``expected_type``; otherwise ``param`` itself, unchanged.

    Raises:
        AssertionError: If ``param`` is a sequence whose length differs from
            ``n_layers``.
    """
    if not isinstance(param, expected_type):
        # Already per-layer: only the length needs checking.
        assert len(param) == n_layers, f"The length of {name} must be equal to the number of layers"
        return param
    return [param for _ in range(n_layers)]


class TransformerLayer(nn.Module):
    """Attention block that keeps one set of ``nn.Linear`` projections per head.

    Every head works on the full ``model_dim``: scores come either from a
    query/key pair (``qk_param=True``) or from a single matrix ``A`` applied to
    the input (``qk_param=False``). Head outputs are summed, or concatenated
    and projected by ``Wout`` when ``enable_Wout`` is set. Skip connections,
    layer norms and a two-layer ReLU MLP are each optional.
    """

    def __init__(
        self,
        qk_dim,
        model_dim,
        ff_dim,
        n_heads,
        masks,
        enable_Wout=False,
        enable_skip=True,
        enable_norm=True,
        enable_mlp=True,
        enable_value=True,
        scale_attention=False,
        qk_param=True,
    ):
        """Create the per-head projections and register the default masks.

        Args:
            qk_dim: Output size of each query/key projection.
            model_dim: Feature size of the layer input and output.
            ff_dim: Hidden size of the feed-forward network.
            n_heads: Number of attention heads.
            masks: Iterable of additive masks, registered as buffers
                ``mask_0``, ``mask_1``, ... and used when ``forward`` gets
                ``masks=None``.
            enable_Wout: Concatenate the heads and project with ``Wout``
                instead of summing them.
            enable_skip: Add residual connections.
            enable_norm: Apply layer normalisation after each sub-block.
            enable_mlp: Append the feed-forward network.
            enable_value: Use a learned value projection per head; otherwise
                the input itself acts as the value.
            scale_attention: Divide the scores by a temperature.
            qk_param: Use separate query/key projections rather than one
                ``model_dim x model_dim`` matrix per head.
        """
        super().__init__()
        self.enable_skip = enable_skip
        self.enable_norm = enable_norm
        self.enable_mlp = enable_mlp
        self.enable_value = enable_value
        self.enable_Wout = enable_Wout or False
        self.qk_param = qk_param
        self.n_heads = n_heads
        self.model_dim = model_dim
        self.qk_dim = qk_dim
        self.scale_attention = scale_attention

        def heads_of(out_features):
            # One bias-free projection per head, all reading the full input.
            return nn.ModuleList(
                [nn.Linear(model_dim, out_features, bias=False) for _ in range(n_heads)]
            )

        # Score parameters: a (Q, K) pair or a single A matrix per head.
        if qk_param:
            self.Q_list = heads_of(qk_dim)
            self.K_list = heads_of(qk_dim)
        else:
            self.A_list = heads_of(model_dim)

        # Value projections.
        if self.enable_value:
            self.V_list = heads_of(model_dim)

        # Projection applied to the concatenated heads.
        if self.enable_Wout:
            self.Wout = nn.Linear(model_dim * n_heads, model_dim, bias=True)

        # Feed-forward network.
        if enable_mlp:
            self.W_F1 = nn.Linear(model_dim, ff_dim, bias=True)
            self.W_F2 = nn.Linear(ff_dim, model_dim, bias=True)

        # Layer norms: one after attention, one after the MLP.
        if self.enable_norm:
            self.norm1 = nn.LayerNorm(model_dim)
            if self.enable_mlp:
                self.norm2 = nn.LayerNorm(model_dim)

        # Default masks travel with the module as buffers.
        for index, head_mask in enumerate(masks):
            self.register_buffer(f"mask_{index}", head_mask)

    def forward(self, X, masks, return_attention=False, tresh_att=None, temp=None):
        """Run attention, then the optional skip / norm / MLP stages.

        Args:
            X: Input of shape ``[batch_size, seq_len, model_dim]``.
            masks: ``None`` to use the registered buffers, or an indexable
                collection with one additive mask per head.
            return_attention: Also return the per-head attention matrices.
            tresh_att: If given, attention weights not strictly greater than
                this value are set to zero (no re-normalisation).
            temp: Score divisor used when ``scale_attention`` is on; defaults
                to ``sqrt(qk_dim)`` (Q/K heads) or ``sqrt(model_dim)`` (A heads).

        Returns:
            The layer output, or ``(output, att_list)`` when
            ``return_attention`` is true.
        """
        att_list = []
        concat_parts = []  # filled only when heads are concatenated
        running_sum = None  # used only when heads are summed

        for head in range(self.n_heads):
            # Raw scores.
            if not self.qk_param:
                projected = self.A_list[head](X)
                scores = torch.matmul(projected, X.transpose(-2, -1))
            else:
                queries = self.Q_list[head](X)
                keys = self.K_list[head](X)
                scores = torch.matmul(queries, keys.transpose(-2, -1))

            # Values.
            values = self.V_list[head](X) if self.enable_value else X

            # Temperature.
            if self.scale_attention:
                if temp is not None:
                    divisor = temp
                elif self.qk_param:
                    divisor = self.qk_dim**0.5
                else:
                    divisor = self.model_dim**0.5
                scores = scores / divisor

            # Additive mask: explicit argument wins over the stored buffer.
            head_mask = self.get_buffer(f"mask_{head}") if masks is None else masks[head]
            scores = scores + head_mask

            # Normalise, optionally threshold.
            att = F.softmax(scores, dim=-1)
            if tresh_att is not None:
                att = torch.where(att > tresh_att, att, torch.zeros_like(att))
            att_list.append(att)

            head_out = torch.matmul(att, values)

            if self.enable_Wout:
                concat_parts.append(head_out)
            elif running_sum is None:
                running_sum = head_out
            else:
                running_sum += head_out

        # Merge the heads.
        if self.enable_Wout:
            # [B, T, n_heads * model_dim] -> [B, T, model_dim]
            attention_output = self.Wout(torch.cat(concat_parts, dim=-1))
        else:
            attention_output = running_sum

        # Attention sub-block: residual + norm.
        Z = attention_output + X if self.enable_skip else attention_output
        if self.enable_norm:
            Z = self.norm1(Z)

        # Feed-forward sub-block: MLP, residual + norm.
        if self.enable_mlp:
            H = self.W_F2(F.relu(self.W_F1(Z)))
            X_prime = H + Z if self.enable_skip else H
            if self.enable_norm:
                X_prime = self.norm2(X_prime)
        else:
            X_prime = Z

        if return_attention:
            return X_prime, att_list
        return X_prime


class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention block with fused (batched) head projections.

      - ``qk_param == True``: scores are ``Q K^T`` with ``qk_dim`` features per head.
      - ``qk_param == False``: scores are ``(X A_h) X^T`` with one
        ``model_dim x model_dim`` matrix ``A_h`` per head.
      - Values are split across heads, ``model_dim // n_heads`` features each.
      - Optionally applies:
           * W_out (projection after concatenating the heads)
           * skip connection
           * layernorm (separate norms for attention & MLP)
           * MLP (2-layer feed-forward, hidden size ``4 * model_dim``)
    """

    def __init__(
        self,
        model_dim,
        n_heads,
        qk_dim,
        qk_param=True,  # If True => separate Q,K; else => single A
        enable_value=True,
        enable_Wout=True,
        scale_attention=True,
        enable_skip=True,
        enable_norm=True,
        enable_mlp=True,
        masks=None,
    ):
        """Build the projections and register the default per-head masks.

        Args:
            model_dim: Feature size of the input/output; must be divisible by
                ``n_heads`` because the values are split evenly across heads.
            n_heads: Number of attention heads (positive).
            qk_dim: Per-head query/key size (used when ``qk_param`` is true).
            qk_param: Use Q/K projections instead of a single ``A`` projection.
            enable_value: Use a learned value projection; otherwise the input
                itself is split across the heads.
            enable_Wout: Apply an output projection to the concatenated heads.
            scale_attention: Divide the scores by a temperature.
            enable_skip: Add residual connections.
            enable_norm: Apply layer normalisation after each sub-block.
            enable_mlp: Append the feed-forward network.
            masks: Optional iterable of additive masks (one per head) stored
                as buffers ``mask_0``, ``mask_1``, ... and used when
                ``forward`` receives no mask. With ``None`` nothing is
                registered and an unmasked ``forward`` call attends everywhere.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``model_dim``.
        """
        super().__init__()
        if n_heads < 1:
            raise ValueError(f"n_heads must be a positive integer, got {n_heads}")
        if model_dim % n_heads != 0:
            # The value tensor is reshaped to [B, T, n_heads, model_dim // n_heads];
            # a remainder would make that reshape impossible at forward time.
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.model_dim = model_dim
        self.head_dim = model_dim // n_heads
        self.n_heads = n_heads
        self.qk_dim = qk_dim
        self.qk_param = qk_param
        self.enable_value = enable_value
        self.enable_Wout = enable_Wout
        self.scale_attention = scale_attention
        self.enable_skip = enable_skip
        self.enable_norm = enable_norm
        self.enable_mlp = enable_mlp

        # ---------------------------------------------------------------------
        # Q/K or A
        # ---------------------------------------------------------------------
        if self.qk_param:
            # Q, K each produce n_heads * qk_dim
            self.W_Q = nn.Linear(model_dim, n_heads * qk_dim, bias=True)
            self.W_K = nn.Linear(model_dim, n_heads * qk_dim, bias=True)
            self.W_A = None
        else:
            # One model_dim x model_dim matrix per head, fused into a single
            # projection so that forward can reshape to [B, T, n_heads, model_dim].
            self.W_Q = None
            self.W_K = None
            self.W_A = nn.Linear(model_dim, n_heads * model_dim, bias=False)

        # ---------------------------------------------------------------------
        # V
        # ---------------------------------------------------------------------
        if self.enable_value:
            # V => model_dim, later split into n_heads chunks of head_dim
            self.W_V = nn.Linear(model_dim, model_dim, bias=True)
        else:
            self.W_V = None

        # ---------------------------------------------------------------------
        # Optional W_out (applied to the concatenated heads)
        # ---------------------------------------------------------------------
        if self.enable_Wout:
            self.W_out = nn.Linear(model_dim, model_dim, bias=True)
        else:
            self.W_out = None

        # ---------------------------------------------------------------------
        # (Optional) Norm layers: one after attention, one after the MLP
        # ---------------------------------------------------------------------
        if self.enable_norm:
            self.norm1 = nn.LayerNorm(model_dim)
            if self.enable_mlp:
                self.norm2 = nn.LayerNorm(model_dim)

        # ---------------------------------------------------------------------
        # (Optional) 2-layer MLP: model_dim -> 4*model_dim -> model_dim
        # ---------------------------------------------------------------------
        if self.enable_mlp:
            hidden_dim = 4 * model_dim
            self.W_F1 = nn.Linear(model_dim, hidden_dim)
            self.W_F2 = nn.Linear(hidden_dim, model_dim)

        # ---------------------------------------------------------------------
        # (Optional) default masks, one buffer per head
        # ---------------------------------------------------------------------
        self._n_masks = 0
        if masks is not None:
            for i, mask in enumerate(masks):
                self.register_buffer(f"mask_{i}", mask)
                self._n_masks = i + 1

    def _resolve_mask(self, mask, seq_len):
        """Return the additive mask to apply, or ``None`` when there is none."""
        if mask is None:
            if self._n_masks == 0:
                return None
            stored = torch.stack(
                [self.get_buffer(f"mask_{i}") for i in range(self.n_heads)], dim=0
            )
            if stored.dim() >= 3:
                # Stored masks are sized for the longest sequence; keep the
                # top-left corner so shorter inputs can be processed too.
                stored = stored[..., :seq_len, :seq_len]
            return stored
        if torch.is_tensor(mask):
            # Already a tensor broadcastable to [B, n_heads, T, T].
            return mask
        # A sequence with one mask per head.
        return torch.stack(list(mask), dim=0)

    def forward(
        self,
        X,  # [B, T, model_dim]
        mask=None,  # list of per-head masks, or a tensor broadcastable to [B, n_heads, T, T]
        return_attention=False,
        thresh_att=None,
        temp=None,
    ):
        """Apply attention followed by the optional skip / norm / MLP stages.

        Args:
            X: Input of shape ``[B, T, model_dim]``.
            mask: ``None`` to use the masks registered at construction, a
                sequence with one additive mask per head, or a tensor that
                broadcasts against ``[B, n_heads, T, T]``.
            return_attention: Also return the per-head attention matrices.
            thresh_att: If given, attention weights not strictly greater than
                this value are zeroed (rows are not re-normalised).
            temp: Score divisor used when ``scale_attention`` is on; defaults
                to ``sqrt(qk_dim)`` (Q/K) or ``sqrt(model_dim)`` (A).

        Returns:
            The output ``[B, T, model_dim]``, or ``(output, att_list)`` with
            ``att_list[h]`` of shape ``[B, T, T]`` when ``return_attention``
            is true.
        """
        B, T, _ = X.shape

        # ---------------------------------------------------------------------
        # 1) Compute attention scores => [B, n_heads, T, T]
        # ---------------------------------------------------------------------
        if self.qk_param:
            # [B, T, n_heads*qk_dim] => [B, n_heads, T, qk_dim]
            Q = self.W_Q(X).view(B, T, self.n_heads, self.qk_dim).transpose(1, 2)
            K = self.W_K(X).view(B, T, self.n_heads, self.qk_dim).transpose(1, 2)
            attention_scores = torch.matmul(Q, K.transpose(-2, -1))
        else:
            # [B, T, n_heads*model_dim] => [B, n_heads, T, model_dim]
            A = self.W_A(X).view(B, T, self.n_heads, self.model_dim).transpose(1, 2)
            # X^T is shared by all heads: [B, 1, model_dim, T] broadcasts over heads.
            attention_scores = torch.matmul(A, X.transpose(-2, -1).unsqueeze(1))

        # (Optional) scaling
        if self.scale_attention:
            if temp is None:
                default_scale = self.qk_dim if self.qk_param else self.model_dim
                temp = default_scale**0.5
            attention_scores = attention_scores / temp

        # Additive mask (skipped when neither stored nor passed in)
        mask = self._resolve_mask(mask, T)
        if mask is not None:
            attention_scores = attention_scores + mask

        # Softmax
        att = F.softmax(attention_scores, dim=-1)
        # Optional thresholding (no re-normalisation afterwards)
        if thresh_att is not None:
            att = torch.where(att > thresh_att, att, torch.zeros_like(att))

        # ---------------------------------------------------------------------
        # 2) Values, split across heads => [B, n_heads, T, head_dim]
        # ---------------------------------------------------------------------
        value_source = self.W_V(X) if self.enable_value else X
        V = value_source.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # => [B, n_heads, T, head_dim]
        attention_output = torch.matmul(att, V)

        # ---------------------------------------------------------------------
        # 3) Combine heads
        # ---------------------------------------------------------------------
        # Move the head axis next to the feature axis BEFORE flattening;
        # flattening [B, n_heads, T, head_dim] directly would interleave
        # positions and heads and leak information across time steps.
        attention_output = attention_output.transpose(1, 2).reshape(B, T, self.model_dim)
        if self.enable_Wout:
            attention_output = self.W_out(attention_output)

        # ---------------------------------------------------------------------
        # 4) Optional skip & norm (for attention)
        # ---------------------------------------------------------------------
        Z = attention_output + X if self.enable_skip else attention_output
        if self.enable_norm:
            Z = self.norm1(Z)

        # ---------------------------------------------------------------------
        # 5) (Optional) MLP + skip & norm
        # ---------------------------------------------------------------------
        if self.enable_mlp:
            H = self.W_F2(F.relu(self.W_F1(Z)))
            X_prime = H + Z if self.enable_skip else H
            if self.enable_norm:
                X_prime = self.norm2(X_prime)
        else:
            X_prime = Z

        # ---------------------------------------------------------------------
        # Return final outputs
        # ---------------------------------------------------------------------
        if return_attention:
            att_list = [att[:, i] for i in range(self.n_heads)]
            return X_prime, att_list
        return X_prime


class Transformer(nn.Module):
    def __init__(
        self,
        qk_dim,
        embedding_dim,
        pos_dim,
        vocab_size,
        n_layers,
        n_heads,
        block_size,
        enable_skip=False,
        enable_norm=False,
        enable_mlp=True,
        enable_value=True,
        enable_Wout=False,
        pos_enc="none",
        scale_attention=False,
        cat_pos=False,
        scale_init=None,
        init_att=None,
        init_mlp=None,
        init_output=None,
        freeze_emb=False,
        one_hot_emb=False,
        skip_embedding=False,
        embedding_type="embedding",
        qk_param=True,
        masks=None,
        temp=None,
        output_size=None,
    ):
        """
        Initialize the Transformer model.

        Args:
            qk_dim (int): The dimension of the query and key vectors; ``0``
                means "use the model dimension".
            embedding_dim (int): The dimension of the input word embeddings.
            pos_dim (int): The dimension of the positional encoding.
            vocab_size (int): The size of the vocabulary.
            n_layers (int): The number of transformer layers.
            n_heads (int or list[int]): Number of heads, either one value for
                all layers or one value per layer.
            block_size (int): Sequence block size; the model itself works on
                contexts of ``block_size - 1`` tokens.
            enable_skip (bool, optional): Whether to enable skip connections. Defaults to False.
            enable_norm (bool, optional): Whether to enable layer normalization. Defaults to False.
            enable_mlp (bool or list[bool], optional): Whether each layer has a
                feed-forward network (one value, or one per layer). Defaults to True.
            enable_value (bool, optional): Whether to use a value projection. Defaults to True.
            enable_Wout (bool, optional): Whether to apply an output projection
                after merging the heads. Defaults to False.
            pos_enc (str, optional): The type of positional encoding. Defaults to 'none'.
            scale_attention (bool, optional): Whether to scale the attention scores.
                Defaults to False.
            cat_pos (bool, optional): Whether to concatenate positional encoding with word
                embeddings. Defaults to False.
            scale_init (float, optional): The scale factor for weight initialization.
                Defaults to None (``1 / sqrt(fan_in)``).
            init_att (str, optional): Initializer name for the attention weights
                ('uniform', 'normal', 'xavier_uniform' or 'constant'). Defaults to
                None (keep PyTorch's default initialization).
            init_mlp (str, optional): Initializer name for the MLP weights. Defaults to None.
            init_output (str, optional): Initializer name for the output layer weights.
                Defaults to None.
            freeze_emb (bool, optional): Whether to freeze the input embeddings. Defaults to False.
            one_hot_emb (bool, optional): Whether to use one-hot embeddings. Defaults to False.
            skip_embedding (bool, optional): Whether to skip the embedding layer. Defaults to False.
            embedding_type (str, optional): The type of embedding to use. Defaults to 'embedding'.
            qk_param (bool, optional): Whether to use separate query and key parameters.
                Defaults to True.
            masks (list, optional): Per-layer lists of per-head attention masks.
                Defaults to None (causal masks for every head).
            temp (float or list, optional): Temperature for scaling attention scores, one
                value for all layers or one per layer. Only has an effect when
                ``scale_attention`` is True. Defaults to None.
            output_size (int, optional): The size of the output layer. Defaults to
                None (uses vocab_size).

        Raises:
            ValueError: If ``temp`` is a sequence whose length differs from ``n_layers``.
        """
        super(Transformer, self).__init__()
        self.pos_enc = pos_enc
        self.vocab_size = vocab_size
        self.block_size = block_size - 1
        self.n_layers = n_layers
        # Broadcast scalar settings to one value per layer.
        enable_mlp = expand_or_assert(enable_mlp, n_layers, "enable_mlp", bool)
        self.n_heads = expand_or_assert(n_heads, n_layers, "n_heads", int)
        masks = self.create_masks(masks)
        self.pos_encoder, self.pos_dim, self.embedding_dim, _, self.model_dim, self.embedding = (
            get_pos_encoder_embedding(
                pos_dim,
                embedding_dim,
                cat_pos,
                self.block_size,
                pos_enc,
                None,
                vocab_size,
                one_hot_emb,
                skip_embedding,
                freeze_emb,
                embedding_type=embedding_type,
            )
        )
        print("model_dim", self.model_dim)
        self.qk_dim = qk_dim if qk_dim != 0 else self.model_dim
        self.qk_param = qk_param
        self.temp = [temp] * n_layers if isinstance(temp, (int, float)) else temp
        if self.temp is not None and len(self.temp) != n_layers:
            raise ValueError("The length of temp must be equal to the number of layers")

        self.layers = nn.ModuleList(
            [
                MultiHeadSelfAttention(
                    qk_dim=self.qk_dim,
                    model_dim=self.model_dim,
                    n_heads=self.n_heads[i],
                    masks=masks[i],
                    enable_Wout=enable_Wout,
                    enable_skip=enable_skip,
                    enable_norm=enable_norm,
                    enable_mlp=enable_mlp[i],
                    enable_value=enable_value,
                    scale_attention=scale_attention,
                    qk_param=qk_param,
                )
                for i in range(n_layers)
            ]
        )
        self.output_layer = nn.Linear(
            self.model_dim, output_size if output_size is not None else vocab_size, bias=False
        )

        self.initialize_weights(scale_init, init_att, init_mlp, init_output)
        # Print a summary of the main dimensions and switches of the model.
        print(
            f"\nModel Summary: \
                \nEmbedding Dimension: {self.embedding_dim}\
                \nPositional Encoding Dimension: {self.pos_dim}\
                \nModel Dimension: {self.model_dim}\
                \nQK Dimension: {self.qk_dim}\
                \nNumber of Layers: {n_layers}\
                \nNumber of heads: {n_heads}\
                \nEnable Skip: {enable_skip}\
                \nEnable Norm: {enable_norm}\
                \nEnable MLP: {enable_mlp}\
                \nScale Attention: {scale_attention}\
                \nPositional Encoding: {pos_enc}\
                \nConcatenate Positional Encoding: {cat_pos}\
                \nBlock Size: {self.block_size}\
                \nOne hot embedding: {one_hot_emb}\
                \nFreeze Embedding: {freeze_emb}\
                \nQK Parametrization: {qk_param}\
                \nOutput Size: {output_size if output_size is not None else vocab_size}\
                \n"
        )
        print(
            f"\nThe model has {sum(p.numel() for p in self.parameters() if p.requires_grad):,} trainable parameters \n"
        )
        print(f"\nModel Architecture: \n{self}")

    def _check_masks(self, masks):
        """Assert that ``masks`` holds one entry per layer and one mask per head."""
        assert (
            len(masks) == self.n_layers
        ), "Number of masks should be equal to the number of layers"
        assert all(
            len(masks[i]) == self.n_heads[i] for i in range(self.n_layers)
        ), "Number of masks should be equal to the number of heads"

    def create_masks(self, masks):
        """Return per-layer, per-head masks, defaulting to causal masks.

        Args:
            masks: ``None`` or a list (one entry per layer) of lists (one mask
                per head).

        Returns:
            ``masks`` itself when given, otherwise square causal masks of size
            ``self.block_size`` for every head of every layer.

        Raises:
            AssertionError: If the given masks do not match the layer/head counts.
        """
        if masks is None:
            return [
                [
                    nn.Transformer.generate_square_subsequent_mask(self.block_size)
                    for _ in range(self.n_heads[i])
                ]
                for i in range(self.n_layers)
            ]
        self._check_masks(masks)
        return masks

    def initialize_weights(
        self, scale_init=None, init_att="uniform", init_mlp="normal", init_output="xavier_uniform"
    ):
        """Initialize weights for Transformer layers including the output layer.

        Passing ``None`` for one of the ``init_*`` arguments leaves the
        corresponding weights untouched.

        Args:
            scale_init: Scale used by the 'uniform' (bounds), 'normal' (std) and
                'constant' (fill value) initializers. ``None`` selects
                ``1 / sqrt(fan_in)`` for each weight matrix.
            init_att: Initializer name for the attention projections
                (Q, K, A, V).
            init_mlp: Initializer name for the MLP layers (W_F1, W_F2); their
                biases are set to zero.
            init_output: Initializer name for the output layer.

        Raises:
            KeyError: If an initializer name is not one of 'uniform', 'normal',
                'xavier_uniform' or 'constant'.
        """
        scale_label = scale_init if scale_init is not None else "dynamic"

        def scale(x):
            # Dynamic scale: 1 / sqrt(fan_in) of the weight matrix.
            return 1 / math.sqrt(x.size(1)) if scale_init is None else scale_init

        init_methods = {
            "uniform": lambda x: nn.init.uniform_(x, a=-scale(x), b=scale(x)),
            "normal": lambda x: nn.init.normal_(x, mean=0.0, std=scale(x)),
            "xavier_uniform": lambda x: nn.init.xavier_uniform_(x),
            # Falls back to the dynamic scale so that a missing scale_init
            # does not end up as a ``None`` fill value.
            "constant": lambda x: nn.init.constant_(x, scale(x)),
        }

        # Attention projections: per-head module lists (TransformerLayer) and
        # fused projections (MultiHeadSelfAttention). Attributes that are
        # missing or set to None are simply skipped.
        if init_att is not None:
            init_fn = init_methods[init_att]
            for layer in self.layers:
                for name in ("Q_list", "K_list", "A_list", "V_list"):
                    module_list = getattr(layer, name, None)
                    if module_list is not None:
                        for module in module_list:
                            init_fn(module.weight)
                for name in ("W_Q", "W_K", "W_A", "W_V"):
                    module = getattr(layer, name, None)
                    if module is not None:
                        init_fn(module.weight)

        # MLP layers W_F1, W_F2 of every layer that has them.
        if init_mlp is not None:
            init_fn = init_methods[init_mlp]
            for layer in self.layers:
                for name in ("W_F1", "W_F2"):
                    module = getattr(layer, name, None)
                    if module is None:
                        continue
                    init_fn(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)

        # Output layer
        if init_output is not None:
            init_methods[init_output](self.output_layer.weight)
            if self.output_layer.bias is not None:
                nn.init.constant_(self.output_layer.bias, 0)  # Initialize biases to zero

        # Print summary of the initializations used
        print(f"\nInitialization summary:")
        print(f"Attention layers (Q, K, V): {init_att}, scale: {scale_label}")
        print(f"MLP layers (W_F1, W_F2): {init_mlp}, scale: {scale_label}")
        print(f"Output layer: {init_output}, scale: {scale_label}")

    def forward(self, X, masks=None, return_attention=False, tresh_att=None):
        """
        Forward pass of the transformer model.

        Args:
            X (torch.Tensor): Input passed to the embedding; token indices of shape
                (batch_size, sequence_length) for the standard embedding.
            masks (list, optional): Per-layer lists of per-head attention masks. ``None``
                lets every layer use the masks registered at construction.
            return_attention (bool, optional): Whether to return the attention weights.
            tresh_att (float or list, optional): Attention threshold, one value for all
                layers or a list with one value per layer. Applied regardless of
                ``return_attention``.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, sequence_length, output_size).
            list: Per-layer lists of per-head attention weights, if return_attention is True.

        Raises:
            AssertionError: If ``masks`` does not match the layer/head counts.
        """
        X = self.embedding(X)
        if self.pos_encoder is not None:
            X = self.pos_encoder(X)
        dtype = next(self.parameters()).dtype
        X = X.type(dtype)

        if masks is None:
            masks = [None] * self.n_layers
        else:
            self._check_masks(masks)

        # A scalar (or None) threshold applies to every layer.
        if not isinstance(tresh_att, list):
            tresh_att = [tresh_att] * len(self.layers)
        temps = self.temp if self.temp is not None else [None] * len(self.layers)

        att_l = []
        for i, layer in enumerate(self.layers):
            # Threshold and temperature are forwarded on both paths so that the
            # output does not depend on whether attention maps are requested.
            result = layer(X, masks[i], return_attention, tresh_att[i], temps[i])
            if return_attention:
                X, att = result
                att_l.append(att)
            else:
                X = result
        X = self.output_layer(X)
        if return_attention:
            return X, att_l
        return X

    def generate(self, max_length, start_token=None, temperature=1.0, top_k=None):
        """Sample a token sequence autoregressively.

        Puts the model in evaluation mode (it is not switched back afterwards).

        Args:
            max_length (int): Total number of tokens to return, start token included.
            start_token (int, optional): First token; drawn uniformly from the
                vocabulary when omitted.
            temperature (float, optional): Divisor applied to the logits before sampling.
            top_k (int, optional): If given, sample only among the ``top_k`` most
                likely tokens.

        Returns:
            list: The generated token ids.
        """
        self.eval()  # Set the model to evaluation mode
        device = next(self.parameters()).device

        # Sample a random starting token (vocabulary indices start from 0)
        if start_token is None:
            start_token = np.random.randint(0, self.vocab_size)

        generated_sequence = [start_token]

        for _ in range(max_length - 1):
            # Condition on at most block_size tokens, the size the default
            # masks (and the positional encoder) are built for.
            context = generated_sequence[-self.block_size :]
            current_seq = torch.tensor([context], dtype=torch.long).to(device)
            att_mask = nn.Transformer.generate_square_subsequent_mask(current_seq.size(1)).to(
                device
            )
            # forward expects one mask per head for every layer.
            layer_masks = [[att_mask] * heads for heads in self.n_heads]

            with torch.no_grad():
                logits = self(current_seq, layer_masks)

            # Keep the prediction for the last position only.
            logits = logits[:, -1, :]
            if temperature != 1.0:
                logits = logits / temperature

            probs = F.softmax(logits, dim=-1)

            # Apply top-k filtering if top_k is specified
            if top_k is not None:
                k = min(top_k, probs.size(-1))
                top_probs, top_ix = torch.topk(probs, k=k)
                # multinomial accepts unnormalised weights.
                sampled_ix = torch.multinomial(top_probs, num_samples=1)
                next_token = top_ix.gather(1, sampled_ix).item()
            else:
                next_token = torch.multinomial(probs, num_samples=1).item()

            generated_sequence.append(next_token)

        return generated_sequence


# Transformer parameters, in order: qk_dim, embedding_dim, pos_dim, vocab_size, n_layers, n_heads, block_size (there is no ff_dim parameter); every remaining parameter has a default value.
if __name__ == "__main__":

    def test_transformer():
        """Smoke-test single-head models over a grid of embedding options.

        Each configuration is built and run once; failures are printed and the
        sweep continues with the next configuration.
        """
        from itertools import product

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        qk_dim = 64
        embedding_dim = [32]
        pos_dim = 16
        vocab_size = 100
        n_layers = 2
        n_heads = 1
        block_size = 54
        batch_size = 10
        enable_skip = True
        enable_norm = True
        enable_mlp = [False, True]
        pos_enc = ["sin_cos", "one_hot", "learned"]
        scale_attention = True
        cat_pos = [True, False]
        freeze_emb = [True, False]
        one_hot_emb = [True, False]
        qk_params = [True, False]
        # The input has to live on the same device as the model.
        inpt = torch.randint(0, vocab_size, (batch_size, block_size - 1)).to(device)

        grid = product(pos_enc, embedding_dim, cat_pos, freeze_emb, one_hot_emb, qk_params)
        for enc, emb_dim, cp, freeze, ohe, qk in grid:
            print(
                f"Positional Encoding: {enc}, Embedding Dimension: {emb_dim}, Positional dimension {pos_dim}, Concatenate Positional Encoding: {cp}",
                f"Freeze Embedding: {freeze}, One hot embedding: {ohe} \n",
            )
            try:
                # Keyword arguments keep the call aligned with the signature.
                model = Transformer(
                    qk_dim=qk_dim,
                    embedding_dim=emb_dim,
                    pos_dim=pos_dim,
                    vocab_size=vocab_size,
                    n_layers=n_layers,
                    n_heads=n_heads,
                    block_size=block_size,
                    enable_skip=enable_skip,
                    enable_norm=enable_norm,
                    enable_mlp=enable_mlp,
                    enable_value=False,
                    pos_enc=enc,
                    scale_attention=scale_attention,
                    cat_pos=cp,
                    scale_init=None,
                    init_att=None,
                    init_mlp=None,
                    init_output=None,
                    freeze_emb=freeze,
                    one_hot_emb=ohe,
                    qk_param=qk,
                ).to(device)
                output = model(inpt)
                print(f"Output shape: {output.shape}")

            except Exception as e:
                # Deliberately broad: the sweep reports a failing
                # configuration and moves on to the next one.
                print(
                    f"!!!!!!!!!!!!!!!! Error occurred !!!!!!!!!!!!!!!!!!!: {str(e)}"
                )
            print("\n")

    def test_multi_head_transformer():
        """Smoke-test compiled models for several head counts and both score types."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        qk_dim = 64
        embedding_dim = 128
        pos_dim = 16
        vocab_size = 100
        n_layers = 2
        n_heads = [1, 2, 4]
        enable_Wout = False
        block_size = 54
        batch_size = 10
        enable_skip = True
        enable_norm = True
        enable_mlp = [True, True]
        pos_enc = "learned"
        scale_attention = True
        cat_pos = False
        freeze_emb = False
        one_hot_emb = False
        qk_params = [True, False]
        init_att = "constant"
        # Give the constant initializer an explicit fill value.
        scale_init = 0.02
        inpt = torch.randint(0, vocab_size, (batch_size, block_size - 1)).to(device)

        for n_head in n_heads:
            for qk in qk_params:
                print(f"Number of Heads: {n_head}, QK Parametrization: {qk}")
                model = Transformer(
                    qk_dim=qk_dim,
                    embedding_dim=embedding_dim,
                    pos_dim=pos_dim,
                    vocab_size=vocab_size,
                    n_layers=n_layers,
                    n_heads=n_head,
                    block_size=block_size,
                    enable_skip=enable_skip,
                    enable_norm=enable_norm,
                    enable_mlp=enable_mlp,
                    enable_value=False,
                    enable_Wout=enable_Wout,
                    pos_enc=pos_enc,
                    scale_attention=scale_attention,
                    cat_pos=cat_pos,
                    scale_init=scale_init,
                    init_att=init_att,
                    init_mlp=None,
                    init_output=None,
                    freeze_emb=freeze_emb,
                    one_hot_emb=one_hot_emb,
                    qk_param=qk,
                ).to(device)

                # Compile the model using torch.compile
                model = torch.compile(model)
                output = model(inpt)

                print(f"Output shape: {output.shape}")
                # Check that each transformer layer contains the correct number of registered masks.
                for idx, layer in enumerate(model.layers):
                    # Count the number of buffers whose keys start with "mask_"
                    num_layer_masks = sum(1 for key in layer._buffers if key.startswith("mask_"))
                    expected_masks = model.n_heads[idx]
                    assert (
                        num_layer_masks == expected_masks
                    ), f"Layer {idx} has {num_layer_masks} masks; expected {expected_masks}."
                print("All layers have the correct number of masks.")

    test_multi_head_transformer()
