import torch
import torch.nn as nn
import math

class Embedding(nn.Embedding):
    """Embedding layer with an optional one-hot mode and weight freezing.

    Dense mode (default): looks up rows of the weight matrix and multiplies
    the result by ``sqrt(embedding_dim)``.

    One-hot mode (``one_hot=True``): the layer holds no weight
    (``self.weight`` is ``None``) and ``forward`` returns a float32 one-hot
    encoding whose last dimension is ``num_embeddings``; ``embedding_dim``
    does not affect the output in this mode.

    Args:
        num_embeddings: Size of the vocabulary (and of the one-hot axis).
        embedding_dim: Width of each dense embedding vector.
        one_hot: If true, emit one-hot vectors instead of learned embeddings.
        freeze: If true (and ``one_hot`` is false), the weight is excluded
            from gradient computation.
        *args, **kwargs: Forwarded unchanged to ``nn.Embedding.__init__``.
            Note that extra positional arguments bind to ``one_hot`` and
            ``freeze`` first, so ``nn.Embedding`` options are best passed
            by keyword.
    """

    def __init__(self, num_embeddings, embedding_dim, one_hot=False, freeze=False, *args, **kwargs):
        super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
        self.one_hot = one_hot
        self.freeze = freeze

        if self.one_hot:
            # One-hot mode has nothing to train: drop the parameter that the
            # base class allocated so it is absent from parameters().
            self.weight = None
        elif self.freeze:
            # Keep the weight but stop it from receiving gradients.
            self.weight.requires_grad = False

    def forward(self, input):
        """Embed ``input`` indices.

        Args:
            input: Integer index tensor of arbitrary shape (int32 or int64).

        Returns:
            Dense mode: tensor of shape ``input.shape + (embedding_dim,)``
            scaled by ``sqrt(embedding_dim)``.
            One-hot mode: float32 tensor of shape
            ``input.shape + (num_embeddings,)`` on the same device as
            ``input``.
        """
        if not self.one_hot:
            embeddings = super().forward(input)
            return embeddings * math.sqrt(self.embedding_dim)

        # One-hot is built with scatter_ rather than F.one_hot.
        n_elems = input.numel()

        # The output shape is input.shape + [num_embeddings].
        out_shape = list(input.shape) + [self.num_embeddings]
        result = input.new_zeros(out_shape, dtype=torch.float32)

        # reshape (not view) so non-contiguous inputs are handled.
        index = input.reshape(n_elems, 1)
        if index.dtype == torch.int32:
            # scatter_ only accepts int64 indices, whereas the dense path
            # accepts int32 as well; widen so both modes take the same
            # inputs. Other dtypes are left alone so invalid (e.g. float)
            # indices still raise.
            index = index.to(torch.int64)

        # result is freshly allocated (contiguous), so this view shares its
        # storage and the in-place scatter fills result directly.
        result_flat = result.view(n_elems, self.num_embeddings)
        result_flat.scatter_(1, index, 1.0)

        return result
