import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class _MHAttention(nn.Module):
    """Multi-head scaled dot-product attention core without projections.

    The module owns no learnable parameters: callers project queries, keys and
    values themselves, reshape them with ``_split``, run ``attn`` and fold the
    heads back with ``_merge``. ``_split``/``_merge`` work on unbatched 2-D
    ``(L, dim)`` tensors.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; must be positive.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode; must lie in ``[0, 1]``.
        hard: If true, ``attn`` uses one-hot (argmax) weights in the forward
            pass while gradients flow through the soft weights
            (straight-through estimator).

    Raises:
        ValueError: If ``num_heads`` is not positive, ``dim`` is not divisible
            by ``num_heads``, or ``dropout`` is outside ``[0, 1]``.
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1, hard: bool = True):
        super().__init__()
        # Explicit checks rather than ``assert``: asserts vanish under
        # ``python -O``, deferring the failure to an opaque ``view`` error.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        dropout = float(dropout)
        # F.dropout would reject this anyway, but only at the first forward call.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        self.dim = dim
        self.h = num_heads
        self.dh = dim // num_heads              # per-head width
        self.scale = 1.0 / math.sqrt(self.dh)   # 1/sqrt(d_head) logit scaling
        self.dropout = dropout
        self.hard = hard

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(L, dim)`` into per-head ``(num_heads, L, dim // num_heads)``."""
        L = x.size(0)
        return x.view(L, self.h, self.dh).transpose(0, 1).contiguous()

    def _merge(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_split``: ``(num_heads, L, d_head)`` back to ``(L, dim)``."""
        return x.transpose(0, 1).contiguous().view(-1, self.dim)

    def attn(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """Scaled dot-product attention applied independently per head.

        Args:
            Q: Queries, ``(num_heads, Lq, d_head)``.
            K: Keys, ``(num_heads, Lk, d_head)``.
            V: Values, ``(num_heads, Lk, d_head)``.

        Returns:
            ``(out, A)``. ``out`` is ``(num_heads, Lq, d_head)`` and ``A`` is
            the ``(num_heads, Lq, Lk)`` weight tensor that was used. With
            ``hard`` set, ``A``'s forward value is the one-hot argmax mask (up
            to float rounding), and its gradient is that of the soft weights.
        """
        qk = torch.einsum("hqd,hkd->hqk", Q, K) * self.scale
        A = F.softmax(qk, dim=-1)
        A = F.dropout(A, p=self.dropout, training=self.training)
        if self.hard:
            # NOTE: the argmax runs on the *dropped-out* weights, so in
            # training, dropout randomises which key is picked rather than
            # rescaling the output. If a whole row is dropped, argmax returns
            # index 0 (the first maximal value).
            index = A.argmax(dim=-1)
            A_hard = torch.zeros_like(A).scatter_(-1, index.unsqueeze(-1), 1.0)
            # Straight-through: forward uses A_hard, backward uses d(A).
            A = (A_hard - A).detach() + A
        out = torch.einsum("hqk,hkd->hqd", A, V)
        return out, A

class Node2BlockCrossAttn(nn.Module):
    """Cross-attention in which block tokens query node tokens.

    Queries are projected from ``B``; keys and values are projected from
    ``x``. Both inputs are optionally layer-normalised first, and the merged
    head outputs go through the output projection ``o``. Inputs are presumably
    unbatched ``(num_tokens, dim)`` tensors, because the head split reshapes
    using only the leading size — confirm against callers.
    """

    def __init__(self, dim: int, num_heads: int = 4,
                 dropout: float = 0.0, use_ln: bool = True, hard: bool = True):
        super().__init__()

        def make_norm():
            # Identity keeps the call sites uniform when normalisation is off.
            return nn.LayerNorm(dim) if use_ln else nn.Identity()

        self.ln_x = make_norm()
        self.ln_b = make_norm()
        # All four projections are built (hence randomly initialised) in
        # q, k, v, o order before being registered in that same order.
        self.q, self.k, self.v, self.o = [nn.Linear(dim, dim, bias=True) for _ in range(4)]
        self.mha = _MHAttention(dim, num_heads=num_heads, dropout=dropout, hard=hard)

    def forward(self, x, B, return_attn=False):
        """Attend from the rows of ``B`` to the rows of ``x``.

        Args:
            x: Node features; supply keys and values.
            B: Block features; supply queries.
            return_attn: When truthy, also return the attention weights.

        Returns:
            The projected output (one row per row of ``B`` for 2-D inputs), or
            ``(output, weights)`` when ``return_attn`` is truthy, where
            ``weights`` is per head, shaped ``(num_heads, len(B), len(x))``.
        """
        split = self.mha._split
        nodes = self.ln_x(x)
        blocks = self.ln_b(B)
        per_head, weights = self.mha.attn(
            split(self.q(blocks)),
            split(self.k(nodes)),
            split(self.v(nodes)),
        )
        result = self.o(self.mha._merge(per_head))
        return (result, weights) if return_attn else result

class Block2NodeCrossAttn(nn.Module):
    """Cross-attention in which node tokens query block tokens.

    Queries are projected from ``x``; keys and values are projected from
    ``B``. Both inputs are optionally layer-normalised first (``B`` through
    ``ln_z``), and the merged head outputs go through the output projection
    ``o``. Inputs are presumably unbatched ``(num_tokens, dim)`` tensors,
    because the head split reshapes using only the leading size — confirm
    against callers.
    """

    def __init__(self, dim: int, num_heads: int = 4,
                 dropout: float = 0.0, use_ln: bool = True, hard: bool = True):
        super().__init__()

        def make_norm():
            # Identity keeps the call sites uniform when normalisation is off.
            return nn.LayerNorm(dim) if use_ln else nn.Identity()

        self.ln_x = make_norm()
        self.ln_z = make_norm()
        # All four projections are built (hence randomly initialised) in
        # q, k, v, o order before being registered in that same order.
        self.q, self.k, self.v, self.o = [nn.Linear(dim, dim, bias=True) for _ in range(4)]
        self.mha = _MHAttention(dim, num_heads=num_heads, dropout=dropout, hard=hard)

    def forward(self, x, B, return_attn=False):
        """Attend from the rows of ``x`` to the rows of ``B``.

        Args:
            x: Node features; supply queries.
            B: Block features; supply keys and values.
            return_attn: When truthy, also return the attention weights.

        Returns:
            The projected output (one row per row of ``x`` for 2-D inputs), or
            ``(output, weights)`` when ``return_attn`` is truthy, where
            ``weights`` is per head, shaped ``(num_heads, len(x), len(B))``.
        """
        split = self.mha._split
        nodes = self.ln_x(x)
        blocks = self.ln_z(B)
        per_head, weights = self.mha.attn(
            split(self.q(nodes)),
            split(self.k(blocks)),
            split(self.v(blocks)),
        )
        result = self.o(self.mha._merge(per_head))
        return (result, weights) if return_attn else result

