from torch import nn
import torch
from einops import rearrange

class GaussianAttention(nn.Module):
    """Multi-head attention over G pooled "Gaussian" tokens derived from N input tokens.

    Each of the N input tokens carries G per-Gaussian features: a scalar ``z``
    and ``weight`` plus ``mu`` / ``sigma`` entries. The forward pass:

    1. Flattens all per-Gaussian features of a token, projects them to
       ``hidden_dim`` and splits into heads, then computes per head a soft
       assignment of every token to the G slots (softmax over G, divided by a
       learnable per-head temperature initialised to 0.5).
    2. Projects ``z`` to ``hidden_dim`` and pools it into G slot tokens as an
       assignment-weighted mean over the N tokens.
    3. Runs scaled dot-product attention among the G slot tokens
       (q/k/v projections are shared across heads and act on the per-head dim).
    4. Scatters the attended slot tokens back to the N tokens with the same
       assignment and projects each token to G outputs.

    Args:
        num_gaussians: Number of Gaussians / slots G per token.
        hidden_dim: Total attention width; must be divisible by ``heads``.
        heads: Number of attention heads.
        pos_dim: Feature width of each ``mu`` and ``sigma`` entry; the per-Gaussian
            input width is ``2 * pos_dim + 2`` (z, weight, mu, sigma).
        dropout: Dropout probability applied to the attention matrix.
        *args, **kwargs: Accepted and ignored (presumably so shared config dicts
            can be passed through — confirm with callers).

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``heads``.
    """

    def __init__(self, num_gaussians, hidden_dim, heads=8, pos_dim=2, dropout=0.0, *args, **kwargs):
        super().__init__()
        dim = hidden_dim // heads
        if dim * heads != hidden_dim:
            # Explicit raise rather than assert so the check survives `python -O`.
            raise ValueError("hidden_dim must be divisible by heads")

        self.num_gaussians = num_gaussians
        # Per-Gaussian feature width: z (1) + weight (1) + mu (pos_dim) + sigma (pos_dim).
        self.in_dim = 2 * pos_dim + 2
        self.hidden_dim = hidden_dim
        self.heads = heads
        self.dim = dim
        # Learnable per-head softmax temperature for the slot assignment.
        self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
        self.scale = dim ** -0.5
        # NOTE: creation order below is kept stable so seeded initialisation
        # and state_dict key order do not change.
        self.proj_in = nn.Linear(self.in_dim * num_gaussians, hidden_dim)
        self.proj_gaussian = nn.Linear(dim, num_gaussians)
        self.proj_z = nn.Linear(num_gaussians, hidden_dim)

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(hidden_dim, num_gaussians)

        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z, mu, sigma, weight):
        """Apply Gaussian-slot attention.

        Args:
            z: Tensor of shape (B, N, G).
            mu: Rank-4 tensor, presumably (B, N, G, pos_dim); only the total
                concatenated width G * (2 * pos_dim + 2) is enforced (by ``proj_in``).
            sigma: Same layout as ``mu`` (assumed).
            weight: Tensor of shape (B, N, G).

        Returns:
            Tensor of shape (B, N, G), including when G == 1.
        """
        B, N, G = z.shape

        # Per-token feature vector: [z, weight, mu, sigma] for every Gaussian, flattened.
        gaussian_in = torch.cat(
            [z.unsqueeze(-1), weight.unsqueeze(-1), mu, sigma], dim=-1
        ).reshape(B, N, -1)  # (B, N, G * in_dim)
        gaussian_in = (
            self.proj_in(gaussian_in)
            .reshape(B, N, self.heads, self.dim)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (B, H, N, D)
        # Soft assignment of each token to the G slots; the epsilon keeps the
        # divisor from being exactly zero if the temperature reaches 0.
        gaussian_weight = torch.softmax(
            self.proj_gaussian(gaussian_in) / (self.temperature + 1e-5), dim=-1
        )  # (B, H, N, G)
        # Total assignment mass per slot, clamped so empty slots don't divide by zero.
        gaussian_norm = gaussian_weight.sum(dim=-2).clamp_min(1e-5)  # (B, H, G)

        z = self.proj_z(z).view(B, N, self.heads, self.dim).permute(0, 2, 1, 3).contiguous()  # (B, H, N, D)
        # Assignment-weighted mean of token features per slot.
        gaussian_token = torch.einsum("bhnd,bhng->bhgd", z, gaussian_weight)  # (B, H, G, D)
        gaussian_token = gaussian_token / gaussian_norm.unsqueeze(-1)  # (B, H, G, D)

        q_gaussian = self.to_q(gaussian_token)  # (B, H, G, D)
        k_gaussian = self.to_k(gaussian_token)  # (B, H, G, D)
        v_gaussian = self.to_v(gaussian_token)  # (B, H, G, D)
        dots = torch.matmul(q_gaussian, k_gaussian.transpose(-1, -2)) * self.scale  # (B, H, G, G)
        attn = self.dropout(self.softmax(dots))
        out_gaussian = torch.matmul(attn, v_gaussian)  # (B, H, G, D)

        # Scatter slot outputs back to tokens using the same soft assignment.
        out_z = torch.einsum("bhgd,bhng->bhnd", out_gaussian, gaussian_weight)  # (B, H, N, D)
        out_z = out_z.permute(0, 2, 1, 3).contiguous().view(B, N, -1)  # (B, N, D*H)
        # proj_out already yields (B, N, G). The former `.squeeze(-1)` only ever
        # fired when num_gaussians == 1, silently dropping the Gaussian axis.
        return self.proj_out(out_z)  # (B, N, G)
