import math
import torch
import torch.nn as nn
from torch.nn import functional as F

# https://github.com/karpathy/nanoGPT/blob/master/model.py

class AttentionModule(nn.Module):
    """Multi-head scaled dot-product attention over separate q/k/v inputs.

    The query, key and value inputs are linearly projected, split into
    ``n_heads`` heads, attended, and the per-head outputs are concatenated
    back into a single tensor. No attention mask is applied on either code
    path, so every query position attends to every key position.

    Args:
        params: Configuration object; only ``params.use_flashattention`` is
            read. When truthy (and the installed torch provides
            ``scaled_dot_product_attention``) the fused kernel is used,
            otherwise attention is computed manually.
        embed_dim: Feature size of the inputs and of the value projection.
            Must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        qk_dim: Total feature size of the query/key projections (summed over
            heads). Must be divisible by ``n_heads``.
        bias: Whether the linear projections carry a bias term.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``n_heads``.
        ValueError: If ``qk_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, params, embed_dim, n_heads, qk_dim, bias=True, dropout=0.0):
        super().__init__()
        assert embed_dim % n_heads == 0
        # forward() splits the qk_dim-wide projections into n_heads heads, so
        # reject an indivisible qk_dim here instead of failing inside .view().
        if qk_dim % n_heads != 0:
            raise ValueError(
                f"qk_dim ({qk_dim}) must be divisible by n_heads ({n_heads})"
            )

        # query, key, value projections for all heads, computed in one batch
        self.q_proj = nn.Linear(embed_dim, qk_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, qk_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # output projection
        # NOTE(review): c_proj and resid_dropout are created but never applied
        # in forward(); kept so state_dict keys stay unchanged -- confirm
        # whether the caller is expected to apply them.
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_heads = n_heads
        self.n_embd = embed_dim
        self.dropout = dropout
        # Only take the fused path when this torch build actually provides it.
        self.flash = bool(params.use_flashattention) and hasattr(
            F, "scaled_dot_product_attention"
        )
        print(f'Using {"flash" if self.flash else "manual"} attention')

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v`` and return the merged heads.

        Args:
            q: Query input of shape ``(B, T_q, embed_dim)``.
            k: Key input of shape ``(B, T_kv, embed_dim)``.
            v: Value input of shape ``(B, T_kv, embed_dim)``.

        Returns:
            Tensor of shape ``(B, T_q, embed_dim)`` holding the concatenated
            per-head attention outputs (no output projection is applied).
        """
        # project inputs; q/k become qk_dim wide, v stays embed_dim wide
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        B, T_q, D_qk = q.size()  # batch, query length, qk_dim
        T_kv = k.size(1)  # key/value length
        D_v = v.size(-1)  # embed_dim
        qk_head_dim = D_qk // self.n_heads
        v_head_dim = D_v // self.n_heads

        # split the feature axis into heads and move the head axis forward
        q = q.view(B, T_q, self.n_heads, qk_head_dim).transpose(1, 2)  # (B, nh, T_q, hs_qk)
        k = k.view(B, T_kv, self.n_heads, qk_head_dim).transpose(1, 2)  # (B, nh, T_kv, hs_qk)
        v = v.view(B, T_kv, self.n_heads, v_head_dim).transpose(1, 2)  # (B, nh, T_kv, hs_v)

        # unmasked attention: (B, nh, T_q, hs_qk) x (B, nh, hs_qk, T_kv) -> (B, nh, T_q, T_kv)
        if self.flash:
            # fused kernel; dropout is only active in training mode
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T_q, T_kv) x (B, nh, T_kv, hs_v) -> (B, nh, T_q, hs_v)

        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T_q, D_v)

        return y