import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ReGLU(nn.Module):
    """ReLU-gated linear unit activation.

    The input is split into two equal halves along its last dimension;
    the first half is multiplied element-wise by the ReLU of the second
    half. The output's last dimension is therefore half the input's.

    References:
        https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py

        Shazeer et al., "GLU Variants Improve Transformer," 2020.
        https://arxiv.org/abs/2002.05202
    """

    def reglu(self, x: Tensor) -> Tensor:
        """Apply ReGLU to ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            ``values * relu(gates)``, where ``values`` and ``gates`` are the
            first and second halves of ``x`` along the last dimension.

        Raises:
            ValueError: If the last dimension of ``x`` is not even.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and an odd width (e.g. 3) would then chunk into
        # pieces of size 2 and 1 that silently broadcast together.
        if x.shape[-1] % 2 != 0:
            raise ValueError(
                "ReGLU expects an even-sized last dimension, "
                f"got shape {tuple(x.shape)}"
            )
        values, gates = x.chunk(2, dim=-1)
        return values * F.relu(gates)

    def forward(self, x: Tensor) -> Tensor:
        """Module entry point; delegates to :meth:`reglu`."""
        return self.reglu(x)


class GEGLU(nn.Module):
    """GELU-gated linear unit activation.

    The input is split into two equal halves along its last dimension;
    the first half is multiplied element-wise by the GELU of the second
    half. The output's last dimension is therefore half the input's.

    References:
        https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py

        Shazeer et al., "GLU Variants Improve Transformer," 2020.
        https://arxiv.org/abs/2002.05202
    """

    def geglu(self, x: Tensor) -> Tensor:
        """Apply GEGLU to ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            ``values * gelu(gates)``, where ``values`` and ``gates`` are the
            first and second halves of ``x`` along the last dimension.

        Raises:
            ValueError: If the last dimension of ``x`` is not even.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and an odd width (e.g. 3) would then chunk into
        # pieces of size 2 and 1 that silently broadcast together.
        if x.shape[-1] % 2 != 0:
            raise ValueError(
                "GEGLU expects an even-sized last dimension, "
                f"got shape {tuple(x.shape)}"
            )
        values, gates = x.chunk(2, dim=-1)
        return values * F.gelu(gates)

    def forward(self, x: Tensor) -> Tensor:
        """Module entry point; delegates to :meth:`geglu`."""
        return self.geglu(x)
