import torch.nn as nn

# Base embedding classes for encoding models
class BaseEmbedding(nn.Module):
    """Common parent for the embedding layers of encoding models.

    Concrete subclasses override :meth:`forward`. :meth:`decode` defaults to
    the identity mapping, so layers with no inverse transform can inherit it
    unchanged.

    Args:
        config: Model configuration object, kept on ``self.config`` so that
            subclasses can read their hyperparameters from it.
    """

    def __init__(self, config):
        super().__init__()
        # Retained as-is; subclasses pull whatever fields they need from it.
        self.config = config

    def decode(self, hidden_states):
        """Map hidden states back out of embedding space.

        The default implementation is a no-op and returns ``hidden_states``
        unchanged; subclasses override it when an inverse transform exists.
        """
        return hidden_states

    def forward(self, input_ids=None, input_embeds=None, token_types=None):
        """Embed the given inputs. Must be overridden by every subclass."""
        raise NotImplementedError

class StandardEmbedding(BaseEmbedding):
    """Token embedding backed by a learned ``nn.Embedding`` lookup table.

    Args:
        config: Must provide ``vocab_size`` (number of rows in the table) and
            ``d_model`` (size of each embedding vector).
    """

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

    def forward(self, input_ids=None, input_embeds=None, token_types=None):
        """Return embeddings for the given inputs.

        Args:
            input_ids: Index tensor to look up in ``self.embedding``.
            input_embeds: Precomputed embeddings. When given, they are returned
                unchanged and ``input_ids`` is ignored.
            token_types: Accepted for interface compatibility with
                ``BaseEmbedding`` but not used by this layer.

        Returns:
            ``input_embeds`` if provided, otherwise ``self.embedding(input_ids)``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``input_embeds`` is given.
        """
        if input_embeds is not None:
            # Caller-supplied embeddings take precedence and bypass the table.
            return input_embeds
        if input_ids is None:
            # Without this guard, nn.Embedding fails on None with an unclear
            # TypeError from deep inside torch.
            raise ValueError(
                "StandardEmbedding.forward requires either input_ids or input_embeds"
            )
        return self.embedding(input_ids)