"""Implementations of GeGLU"""
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    """GELU-gated linear unit activation.

    Splits the last dimension of the input into two equal halves ``a`` and
    ``b`` and returns ``a * gelu(b)``, so the last dimension of the output is
    half the size of the input's.
    """

    def geglu(self, x):
        """Apply the GeGLU gate to ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            ``a * gelu(b)``, where ``a`` and ``b`` are the first and second
            halves of ``x`` along the last dimension.

        Raises:
            AssertionError: If the last dimension of ``x`` has an odd size.
        """
        width = x.shape[-1]
        # Raised explicitly instead of via an ``assert`` statement so that the
        # check is not stripped under ``python -O``; the exception type is
        # kept as AssertionError for existing callers. Without the check, an
        # odd width such as 3 would be chunked into widths 2 and 1 and the
        # product below would broadcast silently instead of failing.
        if width % 2 != 0:
            raise AssertionError(
                f"GEGLU expects an even last dimension, got {width}"
            )
        a, b = x.chunk(2, dim=-1)
        return a * F.gelu(b)

    def forward(self, x):
        """Return ``self.geglu(x)``."""
        return self.geglu(x)
    

# Lookup table from an activation's string name to the module class that
# implements it. The values are classes, not instances.
ACT2FN = dict(
    relu=nn.ReLU,
    gelu=nn.GELU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    geglu=GEGLU,
)