"""Row-wise feedforward block."""

import torch.nn as nn


class rFF(nn.Module):
    """Two-layer feedforward block applied along the last dimension.

    The stack is ``Linear -> GELU -> Dropout -> Linear -> Dropout``, i.e. the
    usual Transformer feedforward layout. The input's last dimension must be
    ``dim``; the output has the same last dimension.

    Args:
        dim: Size of the input and output feature dimension.
        expansion: Multiplier for the hidden width (hidden = ``dim * expansion``).
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim: int, expansion: int = 1, dropout: float = 0.0):
        super().__init__()
        hidden_dim = dim * expansion
        # Layers are built in application order; their positions inside the
        # Sequential determine the parameter names (net.0.*, net.3.*).
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the feedforward stack and return the result."""
        out = self.net(x)
        return out

