
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``linear2(silu(linear1(x)) * linear3(x))``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        ffn_dim_multiplier: Hidden width is ``int(embed_dim * ffn_dim_multiplier)``.
            May be fractional (e.g. 2.5); the product is truncated to an int.
            Integer multipliers give exactly ``embed_dim * ffn_dim_multiplier``.
        use_bias: Whether the three linear projections carry a bias term.

    Raises:
        ValueError: If the resulting hidden width is less than 1.
    """

    def __init__(self, embed_dim: int, ffn_dim_multiplier: float = 4, use_bias: bool = False):
        super().__init__()
        # Truncate so fractional multipliers yield a valid integer layer size.
        hidden_dim = int(embed_dim * ffn_dim_multiplier)
        if hidden_dim < 1:
            raise ValueError(
                f"hidden dimension must be >= 1, got {hidden_dim} "
                f"(embed_dim={embed_dim}, ffn_dim_multiplier={ffn_dim_multiplier})"
            )
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.linear1 = nn.Linear(embed_dim, hidden_dim, bias=use_bias)  # gate branch (goes through SiLU)
        self.linear2 = nn.Linear(hidden_dim, embed_dim, bias=use_bias)  # output projection
        self.linear3 = nn.Linear(embed_dim, hidden_dim, bias=use_bias)  # linear (value) branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated FFN over the last dimension of ``x`` (size ``embed_dim``)."""
        return self.linear2(F.silu(self.linear1(x)) * self.linear3(x))


class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation: ``linear2(gelu(linear1(x)))``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        ffn_dim_multiplier: Hidden width is ``int(embed_dim * ffn_dim_multiplier)``.
            May be fractional (e.g. 2.5); the product is truncated to an int.
            Integer multipliers give exactly ``embed_dim * ffn_dim_multiplier``.
        use_bias: Whether both linear projections carry a bias term.

    Raises:
        ValueError: If the resulting hidden width is less than 1.
    """

    def __init__(self, embed_dim: int, ffn_dim_multiplier: float = 4, use_bias: bool = False):
        super().__init__()
        # Truncate so fractional multipliers yield a valid integer layer size.
        hidden_dim = int(embed_dim * ffn_dim_multiplier)
        if hidden_dim < 1:
            raise ValueError(
                f"hidden dimension must be >= 1, got {hidden_dim} "
                f"(embed_dim={embed_dim}, ffn_dim_multiplier={ffn_dim_multiplier})"
            )
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.linear1 = nn.Linear(embed_dim, hidden_dim, bias=use_bias)  # up projection
        self.linear2 = nn.Linear(hidden_dim, embed_dim, bias=use_bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``x`` (size ``embed_dim``)."""
        return self.linear2(F.gelu(self.linear1(x)))
