import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from .attention import dot_product_attention
from .rope import RoPE2D


class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Both linear layers act on the last dimension of the input, expanding it
    from ``hidden_dim`` to ``ff_dim`` and projecting it back to ``hidden_dim``.

    Args:
        hidden_dim: Size of the input and output feature dimension.
        ff_dim: Size of the intermediate (expanded) feature dimension.
    """

    def __init__(self, hidden_dim: int, ff_dim: int):
        super().__init__()
        # Registration order (linear_1 before linear_2) fixes both the
        # state_dict key order and the order in which the default weight
        # initialisers consume the global RNG.
        self.linear_1 = nn.Linear(hidden_dim, ff_dim)
        self.linear_2 = nn.Linear(ff_dim, hidden_dim)

    @torch.compile
    def forward(self, x: Tensor) -> Tensor:
        """Expand ``x`` to ``ff_dim``, apply ReLU, and project back to ``hidden_dim``."""
        expanded = F.relu(self.linear_1(x))
        return self.linear_2(expanded)


class MHA(nn.Module):
    def __init__(self, hidden_dim: int, n_heads: int):
        super().__init__()

        head_dim = hidden_dim // n_heads
        assert hidden_dim % n_heads == 0
        assert head_dim & (head_dim - 1) == 0, (
            "The head dimension must be a power of 2 for flex-attention"
        )

        self.n_heads = n_heads
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.rope = RoPE2D(head_dim)

    def forward(
        self, x: Tensor, m: Tensor, s: Tensor | None, d: Tensor | None, p: Tensor | None
    ) -> Tensor:
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        q = rearrange(q, "b l (h e) -> b h l e", h=self.n_heads)
        k = rearrange(k, "b s (h e) -> b h s e", h=self.n_heads)
        v = rearrange(v, "b s (h e) -> b h s e", h=self.n_heads)

        if p is not None:
            q = self.rope(q, p)
            k = self.rope(k, p)

        x = dot_product_attention(q, k, v, m, ssmax=s, bias=d)
        x = rearrange(x, "b h l e -> b l (h e)")
        return x


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer.

    Applies two residual sub-layers in sequence, each preceded by its own
    RMSNorm: ``x + MHA(norm_1(x))`` and then ``x + FFN(norm_2(x))``.

    Args:
        hidden_dim: Model width shared by the attention, FFN and norms.
        ff_dim: Intermediate width of the feed-forward sub-layer.
        n_heads: Number of attention heads.
    """

    def __init__(self, hidden_dim: int, ff_dim: int, n_heads: int):
        super().__init__()
        # Keep this registration order: it determines state_dict key order
        # and the order in which sub-module initialisers consume the RNG.
        self.mha = MHA(hidden_dim, n_heads)
        self.ffn = FFN(hidden_dim, ff_dim)
        self.norm_1 = nn.RMSNorm(hidden_dim)
        self.norm_2 = nn.RMSNorm(hidden_dim)

    def forward(
        self, x: Tensor, m: Tensor, s: Tensor | None, d: Tensor | None, p: Tensor | None
    ) -> Tensor:
        """Apply the attention and feed-forward residual sub-layers to ``x``.

        ``m``, ``s``, ``d`` and ``p`` are passed through unchanged to ``MHA``.
        """
        attn_out = self.mha(self.norm_1(x), m, s, d, p)
        hidden = x + attn_out
        return hidden + self.ffn(self.norm_2(hidden))
