import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Shared scaled dot-product attention core with an output projection.

    Subclasses are expected to project their inputs to queries, keys and
    values and then call :meth:`attention`.

    The ``config`` object must expose ``d_model``, ``bias``, ``dropout`` and
    ``return_attention``; ``block_size`` is additionally read when the fused
    kernel is unavailable (to size the causal mask buffer).
    """

    def __init__(self, config, is_causal=False):
        super().__init__()
        # Stored here (not only in subclasses) because dot_product_attention()
        # and attention() both read self.config.
        self.config = config
        self.is_causal = is_causal
        # output projection
        self.c_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # https://arxiv.org/abs/2205.14135
        # Use the fused kernel only when this PyTorch build provides it.
        self.flash = hasattr(F, "scaled_dot_product_attention")
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(config.block_size, config.block_size)).view(
                    1,
                    1,
                    config.block_size,
                    config.block_size,
                ),
            )

    def split_heads(self, q, k, v):
        """Return ``q``, ``k`` and ``v`` unchanged.

        Heads splitting is not necessary for MP_attention, so the usual
        multi-head reshape (view to ``(B, T, n_heads, C // n_heads)`` followed
        by ``transpose(1, 2)``) is intentionally not applied here.
        """
        return q, k, v

    def dot_product_attention(self, k, q, v):
        """Compute scaled dot-product attention.

        Note the parameter order: key first, then query, then value.

        Returns:
            A tuple ``(attn, y)``. ``y`` is the contiguous attention output.
            ``attn`` holds the post-softmax, post-dropout attention weights
            on the manual path and is ``None`` on the fused path, because
            ``F.scaled_dot_product_attention`` does not return them.
        """
        attn = None
        if self.flash:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.config.dropout if self.training else 0.0,
                is_causal=self.is_causal,
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

            if self.is_causal:
                # Sequence lengths come from the tensors themselves. The mask
                # is sliced down to 2-D so that it broadcasts over however
                # many leading (batch / head) dimensions the scores have.
                t_q, t_k = q.size(-2), k.size(-2)
                visible = self.bias[0, 0, :t_q, :t_k] != 0
                attn = attn.masked_fill(~visible, float('-inf'))

            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            # (..., T_q, T_k) x (..., T_k, head_size) -> (..., T_q, head_size)
            y = attn @ v
        y = y.contiguous()
        return attn, y

    def attention(self, q, k, v):
        """Attend with already-projected ``q``, ``k``, ``v`` and project the result.

        Returns:
            ``y`` alone, or ``(y, attn)`` when ``config.return_attention`` is
            truthy. ``attn`` is ``None`` whenever the fused kernel was used.
        """
        q, k, v = self.split_heads(q, k, v)
        attn, y = self.dot_product_attention(k, q, v)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        if self.config.return_attention:
            return y, attn
        else:
            return y


class CrossAttention(Attention):
    """Attention whose queries, keys and values come from separate inputs.

    Each of the three inputs gets its own ``d_model -> d_model`` linear
    projection before the shared attention routine is applied.
    """

    def __init__(self, config, is_causal=False):
        super().__init__(config, is_causal)
        self.config = config
        width = config.d_model
        use_bias = config.bias
        # Independent projections for query, key and value.
        self.w_q = nn.Linear(width, width, bias=use_bias)
        self.w_k = nn.Linear(width, width, bias=use_bias)
        self.w_v = nn.Linear(width, width, bias=use_bias)

    def forward(self, q, k, v):
        """Project ``q``, ``k`` and ``v`` and return the result of ``attention``."""
        queries = self.w_q(q)
        keys = self.w_k(k)
        values = self.w_v(v)
        return self.attention(queries, keys, values)


class SelfAttention(Attention):
    """Attention whose queries, keys and values all derive from one input.

    A single fused linear layer maps ``d_model -> 3 * d_model``; its output
    is cut into three ``d_model``-wide pieces along the last dimension.
    """

    def __init__(self, config, is_causal=False):
        super().__init__(config, is_causal)
        self.config = config
        width = config.d_model
        # One matmul produces queries, keys and values together.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)

    def forward(self, x):
        """Project ``x`` to q/k/v and return the result of ``attention``."""
        fused = self.c_attn(x)
        # Chunks of size d_model along the last axis, in q, k, v order.
        q, k, v = torch.split(fused, self.config.d_model, dim=-1)
        return self.attention(q, k, v)
