import torch
from torch import nn
from ..dataset import load_embeddings


class Embedding(nn.Module):
    """Word-embedding lookup followed by optional dropout.

    Wraps an ``nn.Embedding`` table and, when ``config.use_init_embeddings``
    is set, initialises that table from ``load_embeddings()``.
    """

    def __init__(self, config):
        """Build the embedding table and dropout layer.

        Args:
            config: object exposing the attributes ``num_words``,
                ``embedding_size``, ``padding_idx``, ``dropout_prob`` and
                ``use_init_embeddings``.

        Raises:
            ValueError: if pretrained embeddings are requested and their
                shape is not ``(num_words, embedding_size)``.
        """
        super().__init__()

        self.num_words = config.num_words
        self.embedding_size = config.embedding_size
        self.padding_idx = config.padding_idx
        self.dropout_prob = config.dropout_prob

        self.word_embedding = nn.Embedding(
            self.num_words, self.embedding_size, padding_idx=self.padding_idx
        )
        self.dropout = nn.Dropout(self.dropout_prob)

        if config.use_init_embeddings:
            print('use pretrained embeddings')
            self._init_from_pretrained(load_embeddings())

    def _init_from_pretrained(self, weights):
        """Overwrite the embedding table with a pretrained matrix.

        Args:
            weights: tensor-like (tensor, ndarray or nested list) convertible
                to the table's dtype; must have shape
                ``(num_words, embedding_size)``.

        Raises:
            ValueError: if ``weights`` has the wrong shape.

        The row at ``padding_idx`` is taken from ``weights`` as-is; it is not
        re-zeroed after the copy.
        """
        table = self.word_embedding.weight
        weights = torch.as_tensor(weights, dtype=table.dtype)
        expected = (self.num_words, self.embedding_size)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if tuple(weights.shape) != expected:
            raise ValueError(
                f'pretrained embeddings have shape {tuple(weights.shape)}, '
                f'expected {expected}'
            )
        # In-place copy keeps the same Parameter object (device, requires_grad,
        # optimizer references) rather than swapping its storage via `.data`.
        # `copy_` always copies, so `weights` is never aliased by the module.
        with torch.no_grad():
            table.copy_(weights)

    def forward(self, input_ids, use_dropout=True):
        """Look up embeddings for ``input_ids``.

        Args:
            input_ids: integer tensor of token indices in ``[0, num_words)``.
            use_dropout: if False, skip the dropout layer entirely. When True,
                ``nn.Dropout`` still only has an effect in training mode.

        Returns:
            Tensor of shape ``(*input_ids.shape, embedding_size)``.
        """
        embeddings = self.word_embedding(input_ids)
        if use_dropout:
            embeddings = self.dropout(embeddings)
        return embeddings
