import noise_embed.torch_utils as torch_utils
import torch.nn as nn
import torch.nn.functional as F
from noise_embed.torch_utils import Tensor

def noisy_forward(self, input: Tensor) -> Tensor:
    """Embed ``input`` and perturb the vectors of a fixed set of token ids.

    The lookup itself is a plain ``F.embedding`` call driven by the module's
    own ``weight`` and lookup options. Afterwards, every position of ``input``
    whose id is in ``target_ids`` gets noise added to its embedding vector;
    the noise is drawn with ``torch_utils.randn_like`` and scaled by
    ``jitter_scale``. All other positions are returned untouched.

    Args:
        input: Tensor of token ids to embed.

    Returns:
        The looked-up embeddings, with noise added at positions holding one
        of the target ids. When no such id is present the lookup result is
        returned as-is and no random numbers are drawn.
    """
    jitter_scale = 0.001
    # NOTE(review): these look like special-token ids of specific tokenizers
    # -- confirm against the vocabularies this patch is meant for.
    target_ids = {128000, 128001, 128008, 128009, 151645, 151643}

    looked_up = F.embedding(
        input,
        self.weight,
        padding_idx=self.padding_idx,
        max_norm=self.max_norm,
        norm_type=self.norm_type,
        scale_grad_by_freq=self.scale_grad_by_freq,
        sparse=self.sparse,
    )

    # Guard kept so that emptying the set above degrades to a plain lookup.
    if not target_ids:
        return looked_up

    id_table = torch_utils.tensor(list(target_ids), device=input.device)
    hit = torch_utils.isin(input, id_table)
    if not hit.any():
        return looked_up

    jitter = torch_utils.randn_like(looked_up) * jitter_scale
    # ``hit`` has one entry per id; the trailing axis added here lets it
    # select whole embedding vectors.
    return torch_utils.where(hit.unsqueeze(-1), looked_up + jitter, looked_up)

# Monkeypatch: rebind ``forward`` on the ``nn.Embedding`` class itself. This
# runs as a module-level side effect, so simply importing this module makes
# every ``nn.Embedding`` instance that does not override ``forward`` use
# ``noisy_forward`` for its lookups.
nn.Embedding.forward = noisy_forward