import math

import torch
import torch.nn as nn


class NoPosEmbedding(nn.Module):
    """Identity positional embedding: hands its input back untouched.

    Its constructor mirrors ``SinusoidalPosEmbedding`` so that either class
    can be returned by ``get_pos_embedding_cls`` and instantiated the same
    way. The arguments themselves are ignored.
    """

    def __init__(self, d_model, max_len=None):
        super().__init__()
        # Accepted purely for signature compatibility; nothing is stored.
        del d_model, max_len

    def forward(self, x):
        """Return the very same object that was passed in (no copy)."""
        unchanged = x
        return unchanged


class SinusoidalPosEmbedding(nn.Module):
    """Add fixed (non-learned) sinusoidal positional encodings to the input.

    For position ``pos`` and feature index ``i``::

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    Args:
        d_model: Size of the last (feature) dimension of the input.
        max_len: If given, a table covering positions ``[0, max_len)`` is
            precomputed once and sliced on each call; longer inputs raise
            ``ValueError``. If ``None``, the table is built on every forward
            call for exactly the sequence length seen.

    NOTE(review): the sequence axis is assumed to be dim ``-2`` (e.g.
    ``(batch, seq, d_model)`` or ``(seq, d_model)``); inputs laid out as
    ``(seq, batch, d_model)`` would need transposing first -- confirm against
    callers.
    """

    def __init__(self, d_model, max_len=None):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        table = None if max_len is None else self._build_table(max_len, d_model)
        # persistent=False: the table is fully derived from d_model/max_len,
        # so keep it out of state_dict (old checkpoints stay loadable).
        self.register_buffer('pe', table, persistent=False)

    @staticmethod
    def _build_table(length, d_model, device=None):
        """Return a ``(length, d_model)`` float32 table of sinusoidal encodings."""
        position = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        # 1 / 10000 ** (2k / d_model), computed in log space for stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(length, d_model, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        table[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return table

    def forward(self, x):
        """Return ``x`` plus the positional encoding for its sequence length.

        Raises:
            ValueError: if a precomputed table exists (``max_len`` was given)
                and the input's sequence length exceeds it.
        """
        seq_len = x.size(-2)
        if self.pe is None:
            pe = self._build_table(seq_len, self.d_model, device=x.device)
        else:
            if seq_len > self.pe.size(0):
                raise ValueError(
                    f'sequence length {seq_len} exceeds max_len {self.pe.size(0)}'
                )
            pe = self.pe[:seq_len]
        # (seq_len, d_model) broadcasts over any leading dimensions of x.
        return x + pe.to(dtype=x.dtype)


def get_pos_embedding_cls(pos_embedding):
    """Map a positional-embedding name to its module class.

    Args:
        pos_embedding: ``None`` for no positional embedding, or
            ``'sinusoidal'``.

    Returns:
        The corresponding ``nn.Module`` subclass (not an instance).

    Raises:
        ValueError: if ``pos_embedding`` is not a recognised option. Previously
            this silently returned ``None``, which only failed later with an
            opaque "object is not callable" error.
    """
    if pos_embedding is None:
        return NoPosEmbedding
    if pos_embedding == 'sinusoidal':
        return SinusoidalPosEmbedding
    raise ValueError(
        f"unknown pos_embedding {pos_embedding!r}; expected None or 'sinusoidal'"
    )