# head.py
import torch
import torch.nn as nn

class MLPHead(nn.Module):
    """Two-layer MLP head with an optional normalization stage.

    Pipeline: ``Linear(in_dim, hidden) -> [norm] -> GELU -> Dropout ->
    Linear(hidden, out_dim)``, stored as ``self.mlp`` (an ``nn.Sequential``).

    The normalization stage is chosen as follows:

    * ``BatchNorm1d(hidden)`` when ``use_batch_norm`` is truthy. It takes
      precedence even if ``use_layer_norm`` is also truthy.
    * Otherwise ``LayerNorm(hidden)`` when ``use_layer_norm`` is truthy (the
      default).
    * Otherwise no normalization layer at all.

    Args:
        in_dim: Number of input features consumed by the first ``nn.Linear``.
        out_dim: Number of output features produced by the final ``nn.Linear``.
        hidden: Width of the hidden layer.
        dropout: Drop probability passed to ``nn.Dropout``.
        use_batch_norm: Insert ``BatchNorm1d`` after the first projection.
        use_layer_norm: Insert ``LayerNorm`` after the first projection
            (ignored when ``use_batch_norm`` is set).

    NOTE(review): ``BatchNorm1d`` normalizes over dim 1. With batch norm
    enabled, the input presumably needs to be 2-D ``(batch, in_dim)``.
    ``LayerNorm`` and ``Linear`` act on the last dim only. Confirm against
    callers.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden=256,
        dropout=0.1,
        use_batch_norm=False,
        use_layer_norm=True,
    ):
        super().__init__()
        # Submodules are instantiated in forward order. The two Linear layers
        # draw from the RNG at construction, so this order fixes their
        # initial weights under a given seed.
        in_proj = nn.Linear(in_dim, hidden)
        norm = (
            nn.BatchNorm1d(hidden) if use_batch_norm
            else nn.LayerNorm(hidden) if use_layer_norm
            else None
        )
        activation = nn.GELU()
        drop = nn.Dropout(dropout)
        out_proj = nn.Linear(hidden, out_dim)

        stages = (in_proj, norm, activation, drop, out_proj)
        # Drop the absent norm slot so Sequential indices (and therefore
        # state_dict keys such as "mlp.0.weight") stay contiguous.
        self.mlp = nn.Sequential(*[stage for stage in stages if stage is not None])

    def forward(self, x):
        """Apply the MLP pipeline to ``x`` and return the projected tensor."""
        return self.mlp(x)