# models/region_head.py

import torch.nn as nn

class RegionHead(nn.Module):
    """Self-attention head over a set of region (ROI) features.

    Applies LayerNorm, then multi-head self-attention across the ROIs, then a
    linear projection from ``in_dim`` to ``out_dim`` for every ROI.

    Args:
        in_dim: Feature size of each ROI vector; also the attention embed dim.
            Must be divisible by ``num_heads`` (``nn.MultiheadAttention``
            enforces this at construction time).
        out_dim: Feature size of each projected output vector.
        num_heads: Number of attention heads.
    """

    def __init__(self, in_dim, out_dim, num_heads=8):
        super().__init__()
        # batch_first=True: batched inputs are laid out (batch, seq, feature).
        self.attn = nn.MultiheadAttention(in_dim, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(in_dim)
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, roi_feats, key_padding_mask=None):
        """Attend across ROIs and project each one to ``out_dim``.

        Args:
            roi_feats: ROI features whose last dimension is ``in_dim``; for
                batched input the layout is (batch, num_rois, in_dim) because
                the attention is built with ``batch_first=True``.
            key_padding_mask: Optional boolean mask of shape
                (batch, num_rois). ``True`` marks padded ROIs that no ROI may
                attend to. ``None`` (the default) attends to every ROI, which
                is the previous behavior. Do not mask every ROI of an item,
                since the softmax would then have no valid key.

        Returns:
            Tensor with the same leading dimensions as ``roi_feats`` and last
            dimension ``out_dim``.
        """
        # Normalize before attending.
        # NOTE(review): there is no residual connection around the attention;
        # confirm this is intentional.
        x = self.norm(roi_feats)
        # The attention weights are discarded, so don't compute them;
        # need_weights=False also lets PyTorch use its optimized SDPA path.
        x, _ = self.attn(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.proj(x)