import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from models.positional_encodings import sin_cos_positional_encoding


class ToyTransformer(nn.Module):
    """A simplified, single-head, transformer-like model.

    Forward pass (default):
        output_t = sum_{t' <= t} softmax(a[:t+1])_{t'} * (W @ x_{t'})

    where ``a`` is a learnable ``nn.Parameter`` of length ``block_size``
    initialised to zeros (so attention starts out uniform over the causal
    prefix, diagonal included) and ``W`` is a bias-free linear layer mapping
    the input features to ``output_size``.

    Variations:
    - add_qk_attention: replace the position-only vector ``a`` with scaled
      dot-product QK attention computed from the (possibly position-encoded)
      inputs.
    - add_positional_encoding: pass the inputs through a sin/cos positional
      encoding before attention and projection.
    - cat_positional_encoding: when positional encoding is enabled,
      concatenate it to the inputs (growing the feature dimension by
      ``positional_encoding_dim``); otherwise it is assumed to be combined
      in place without changing the feature dimension.

    ``embedding_dim`` is stored but not used by this class; any extra keyword
    arguments are accepted and ignored.
    """

    def __init__(
        self,
        vocab_size,
        output_size,
        block_size,
        add_qk_attention=False,
        add_positional_encoding=False,
        cat_positional_encoding=True,
        head_dim=64,
        embedding_dim=64,
        positional_encoding_dim=256,
        **kwargs,
    ):
        """Build the model.

        Args:
            vocab_size: Feature dimension of the input ``x`` (last axis).
            output_size: Feature dimension of the output.
            block_size: Maximum supported sequence length.
            add_qk_attention: Use learned QK attention instead of ``a``.
            add_positional_encoding: Apply ``sin_cos_positional_encoding``.
            cat_positional_encoding: Concatenate (rather than combine in place)
                the positional encoding; only relevant when enabled.
            head_dim: Query/key dimension for QK attention.
            embedding_dim: Stored on the instance; unused here.
            positional_encoding_dim: ``d_model`` of the positional encoding
                (previously hard-coded to 256).
            **kwargs: Ignored.
        """
        super().__init__()
        self.input_dim = vocab_size
        self.output_size = output_size
        self.block_size = block_size
        self.add_qk_attention = add_qk_attention
        self.add_positional_encoding = add_positional_encoding
        self.cat_positional_encoding = cat_positional_encoding
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.positional_encoding_dim = positional_encoding_dim

        # Feature width of x as seen by W / W_Q / W_K, i.e. *after* the
        # optional positional encoding. Only concatenation widens it; the
        # in-place variant is assumed to preserve the input width (matching
        # how the QK projections were already sized).
        in_dim = vocab_size
        if add_positional_encoding and cat_positional_encoding:
            in_dim += positional_encoding_dim

        self.W = nn.Linear(in_dim, output_size, bias=False)

        if self.add_positional_encoding:
            self.positional_encoding = sin_cos_positional_encoding(
                max_len=block_size,
                d_model=positional_encoding_dim,
                concatenate=self.cat_positional_encoding,
            )

        if self.add_qk_attention:
            # Q and K are projected from x itself (not from W(x)), so they
            # take the input width, not output_size.
            self.W_Q = nn.Linear(in_dim, self.head_dim, bias=True)
            self.W_K = nn.Linear(in_dim, self.head_dim, bias=True)
        else:
            # Position-only attention logits; learnable, zero-initialised.
            self.a = nn.Parameter(torch.zeros(block_size))

        # Additive causal mask: 0 where attention is allowed (t' <= t) and
        # -inf strictly above the diagonal, so softmax zeroes future tokens.
        mask = torch.zeros(self.block_size, self.block_size)
        mask = mask.masked_fill(
            torch.ones_like(mask, dtype=torch.bool).triu(diagonal=1), float("-inf")
        )
        self.register_buffer("causal_mask", mask, persistent=False)

        print(
            f"\nToyTransformer Initialized:\n"
            f"Input Dimension (Vocab Size): {vocab_size}\n"
            f"Output Size: {output_size}\n"
            f"Block Size (context length + 1): {self.block_size}\n"
            f"Add QK Attention: {self.add_qk_attention}\n"
            + (f"Head Dimension: {self.head_dim}\n" if self.add_qk_attention else "")
        )

    def forward(self, x, return_attention=False):
        """Apply causal single-head attention followed by the ``W`` projection.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, input_dim)
                with ``seq_len <= block_size``.
            return_attention (bool): If True, also return the attention weights.

        Returns:
            torch.Tensor of shape (batch_size, seq_len, output_size), or, when
            ``return_attention`` is True, a tuple ``(output, [[weights]])``
            where ``weights`` has shape (batch_size, seq_len, seq_len).
        """
        B, T, _ = x.shape
        assert (
            T <= self.block_size
        ), f"Input sequence length ({T}) exceeds model block size ({self.block_size})"

        if self.add_positional_encoding:
            x = self.positional_encoding(x)

        # 1. Attention logits, shape (B, T, T).
        if self.add_qk_attention:
            q = self.W_Q(x)  # (B, T, head_dim)
            k = self.W_K(x)  # (B, T, head_dim)
            attention_scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        else:
            # scores[b, t, t'] = a[t'], independent of the query position t.
            attention_scores = self.a[:T].expand(B, T, T)

        # Causal mask (T, T) broadcasts over the batch dimension.
        attention_scores = attention_scores + self.causal_mask[:T, :T]
        attention_weights = F.softmax(attention_scores, dim=-1)

        # 2. Weighted sum of projected values: (B, T, T) @ (B, T, out) -> (B, T, out).
        output = attention_weights @ self.W(x)

        return (output, [[attention_weights]]) if return_attention else output
