from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Multi-head cross-attention from a single query vector to a sequence.

    Each batch element contributes one query vector of size ``hidden_dim``
    that attends over a sequence of key/value vectors of size
    ``policy_hidden_dim``. The attended context is projected, added back to
    the query (residual connection) and layer-normalized.

    Args:
        hidden_dim: Size of the query and of the output; must be divisible
            by ``num_heads``.
        policy_hidden_dim: Feature size of the key/value sequence.
        num_heads: Number of attention heads (positive).
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``hidden_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim: int, policy_hidden_dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail early with a clear message instead of a cryptic ``view`` shape
        # error (or ZeroDivisionError) surfacing later.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.hidden_dim = hidden_dim
        self.policy_hidden_dim = policy_hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Standard 1/sqrt(d_head) scaling of the dot-product scores.
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(policy_hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(policy_hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``query`` over ``key_value`` and return the fused vector.

        Args:
            query: Tensor of shape ``[B, hidden_dim]``.
            key_value: Tensor of shape ``[B, S, policy_hidden_dim]``.
            key_padding_mask: Optional tensor of shape ``[B, S]``; nonzero /
                ``True`` entries mark key positions to ignore. Rows that are
                fully masked receive a zero attention context, so the output
                for them is the layer-normalized query plus the output
                projection bias.

        Returns:
            Tensor of shape ``[B, hidden_dim]``:
            ``LayerNorm(query + out_proj(attention_context))``.

        Raises:
            ValueError: If ``query`` is not 2-D or ``key_padding_mask`` does
                not have shape ``[B, S]``.
        """
        # A 3-D query of shape [B, 1, D] would slip through the reshapes below
        # and then broadcast to [B, B, D] in the residual add; reject it.
        if query.dim() != 2:
            raise ValueError(
                f"query must have shape [B, hidden_dim], got {tuple(query.shape)}"
            )

        batch_size = query.size(0)

        query = query.unsqueeze(1)  # [B, 1, D]

        q = self.q_proj(query)  # [B, 1, D]
        k = self.k_proj(key_value)  # [B, S, D]
        v = self.v_proj(key_value)  # [B, S, D]

        # Split the feature dimension into heads and move heads before sequence.
        q = q.view(batch_size, 1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, 1, D/H]
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, S, D/H]
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, S, D/H]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, H, 1, S]

        ignore = None
        if key_padding_mask is not None:
            num_keys = k.size(2)
            if key_padding_mask.shape != (batch_size, num_keys):
                raise ValueError(
                    f"key_padding_mask must have shape ({batch_size}, {num_keys}), "
                    f"got {tuple(key_padding_mask.shape)}"
                )
            ignore = key_padding_mask.to(torch.bool)[:, None, None, :]  # [B, 1, 1, S]
            # Use the most negative finite value rather than -inf so that a
            # fully masked row still yields a finite softmax (no NaNs).
            attn_scores = attn_scores.masked_fill(ignore, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        if ignore is not None:
            # Guarantees exactly zero weight on ignored keys, including the
            # fully-masked case where softmax would otherwise be uniform.
            attn_weights = attn_weights.masked_fill(ignore, 0.0)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, v)  # [B, H, 1, D/H]

        # Merge heads back into a single feature dimension.
        context = context.transpose(1, 2).contiguous().view(batch_size, 1, self.hidden_dim)

        output = self.out_proj(context).squeeze(1)  # [B, D]

        # Residual connection on the raw (unprojected) query, then post-norm.
        output = query.squeeze(1) + output
        output = self.layer_norm(output)

        return output
