import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from models.config import MiMoEConfig


class PreNorm(nn.Module):
    """Apply layer normalization to the input before calling a wrapped callable.

    Args:
        dim: Size of the trailing dimension handed to ``nn.LayerNorm``.
        fn: Callable (typically an ``nn.Module``) invoked on the normalized
            tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (norm, then fn) is kept so that parameter and
        # state-dict ordering stay the same.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs) -> Tensor:
        """Normalize ``x`` and pass it, with any extra keyword arguments, to ``fn``.

        Args:
            x: Input tensor whose last dimension has size ``dim``.
            **kwargs: Forwarded unchanged to the wrapped callable.

        Returns:
            Whatever the wrapped callable returns for the normalized input.
        """
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
    
    
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single bias-free linear layer produces queries, keys and values, which
    are split into ``num_heads`` heads of size ``hidden_dim // num_heads``.
    The per-head results are concatenated back to ``hidden_dim``; this module
    applies no attention mask and no output projection.

    Args:
        config: Configuration object providing ``hidden_dim``, ``num_heads``
            and ``dropout_rate``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        config: MiMoEConfig,
    ):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.n_head = config.num_heads
        # Fail at construction time: with a non-divisible size the floor
        # division below would silently drop channels and the head split in
        # forward() would only fail later with an opaque reshape error.
        if self.hidden_dim % self.n_head != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.n_head})"
            )
        self.head_dim = self.hidden_dim // self.n_head
        self.scale = self.head_dim ** -0.5
        # NOTE: kept on the module for callers, but this class does not apply
        # dropout anywhere in forward().
        self.dropout_rate = config.dropout_rate

        self.to_qkv = nn.Linear(self.hidden_dim, self.hidden_dim * 3, bias=False)

    def forward(
        self,
        x: Tensor
    ) -> Tensor:
        """Compute self-attention over the sequence dimension.

        Args:
            x: Input of shape ``(B, T, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, T, hidden_dim)`` holding the concatenated
            per-head attention outputs.
        """
        B, T, D = x.shape
        # Split the fused projection and move heads ahead of the sequence
        # axis: each of q, k, v is (B, H, T, D_head).
        q, k, v = (
            part.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = (q @ k.transpose(-1, -2)) * self.scale  # (B, H, T, T)
        weights = F.softmax(scores, dim=-1)
        # Merge heads back into the channel dimension.
        return (weights @ v).transpose(1, 2).reshape(B, T, D)


class RoPEAttention(Attention):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Queries and keys are rotated by position-dependent complex phases before
    the scaled dot-product; values are left unrotated. The phase table is
    precomputed once for ``max_seq_len + 1`` positions.

    Args:
        config: Configuration object providing the fields read by
            ``Attention`` plus ``max_seq_len``.

    Raises:
        ValueError: If the per-head dimension is odd (channels are rotated in
            pairs, so it must be even).
    """

    def __init__(
        self,
        config: MiMoEConfig
    ):
        super().__init__(config)
        # Channels are paired into complex numbers in apply_rotary_emb();
        # reject an odd head size here instead of failing inside forward().
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim ({self.head_dim}) must be even for rotary embeddings"
            )
        self.max_seq_len = config.max_seq_len
        freqs_cis = self.build_freqs_cis()
        # Non-persistent: the table is rebuilt from the config, so it is
        # excluded from the state dict.
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(
        self,
        x: Tensor
    ) -> Tensor:
        """Compute rotary self-attention over the sequence dimension.

        Args:
            x: Input of shape ``(B, T, hidden_dim)``.

        Returns:
            Tensor of shape ``(B, T, hidden_dim)`` holding the concatenated
            per-head attention outputs.

        Raises:
            ValueError: If ``T`` exceeds the number of precomputed positions.
        """
        B, T, D = x.shape
        n_positions = self.freqs_cis.shape[0]
        # Slicing past the end of the table would silently truncate and then
        # either fail with a broadcast error or (for a one-row table) reuse
        # position 0 everywhere, so check explicitly.
        if T > n_positions:
            raise ValueError(
                f"sequence length {T} exceeds the {n_positions} positions "
                "available in the rotary table (max_seq_len + 1)"
            )
        rotations = self.freqs_cis[:T]

        # Each of q, k, v is (B, H, T, D_head).
        q, k, v = (
            part.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        q = self.apply_rotary_emb(q, rotations)  # (B, H, T, D_head)
        k = self.apply_rotary_emb(k, rotations)  # (B, H, T, D_head)

        scores = (q @ k.transpose(-1, -2)) * self.scale  # (B, H, T, T)
        weights = F.softmax(scores, dim=-1)
        return (weights @ v).transpose(1, 2).reshape(B, T, D)

    def build_freqs_cis(
        self,
        theta: float = 10000.0
    ):
        """Build the table of unit-magnitude complex rotation factors.

        Args:
            theta: Base of the geometric progression of rotation frequencies.

        Returns:
            Complex tensor of shape ``(max_seq_len + 1, head_dim // 2)`` whose
            entry ``[p, i]`` is ``exp(1j * p * theta ** (-i / (head_dim // 2)))``.
        """
        half_dim = self.head_dim // 2
        freq_seq = torch.arange(0, half_dim, dtype=torch.float32)
        inv_freq = 1.0 / (theta ** (freq_seq / half_dim))

        # One extra row beyond max_seq_len — presumably room for an additional
        # token (e.g. a class token); TODO confirm with the caller.
        positions = torch.arange(self.max_seq_len + 1, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # (positions, half_dim)
        # polar(magnitude=1, angle) -> cos(angle) + 1j * sin(angle)
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rotary_emb(
        self,
        x: Tensor,
        freqs_cis: Tensor
    ) -> Tensor:
        """Rotate consecutive channel pairs of ``x`` by the given phases.

        Args:
            x: Tensor of shape ``(B, H, T, D_head)`` with even ``D_head``.
            freqs_cis: Complex rotation factors of shape ``(T, D_head // 2)``.

        Returns:
            Rotated tensor with the same shape and dtype as ``x``.
        """
        # Pair up adjacent channels as (real, imag); computed in float32.
        x_ = x.float().reshape(*x.shape[:-1], -1, 2)  # (B, H, T, D_head // 2, 2)
        x_ = torch.view_as_complex(x_)  # (B, H, T, D_head // 2)

        freqs_cis = freqs_cis[None, None, :, :]  # (1, 1, T, D_head // 2)
        # Complex multiplication performs the rotation; flatten the trailing
        # (D_head // 2, 2) pair back into D_head.
        x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
        return x_out.type_as(x)