import torch
import torch.nn as nn
from torch.nn import init

class Embedding(nn.Module):
    """Thin wrapper around a single ``nn.Embedding`` lookup table.

    The table is sized from ``args.vocab_size`` (number of rows) and
    ``args.d_model`` (width of each embedding vector).
    """

    def __init__(self, args, device):
        """Create the embedding table.

        Args:
            args: Object exposing ``vocab_size`` and ``d_model`` attributes.
            device: Stored on ``self.device`` as-is. This class does not
                move its parameters to that device itself.
        """
        super().__init__()
        self.device = device
        self.tgt_emb = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.d_model,
        )

    def forward(self, X_input):
        """Return the embedding vectors for the indices in ``X_input``.

        Args:
            X_input: Index tensor handed straight to ``nn.Embedding``
                (presumably token ids in ``[0, vocab_size)`` -- confirm
                against the caller).

        Returns:
            The looked-up embeddings: ``X_input``'s shape with a trailing
            dimension of size ``d_model`` appended.
        """
        return self.tgt_emb(X_input)

class toymodel(nn.Module):
    """Minimal model: embed indices, sum over dim 1, apply GELU, project.

    The final linear layer maps from ``args.d_model`` features to
    ``args.vocab_size`` outputs.
    """

    def __init__(self, args, device):
        """Build the embedding sub-module and the output projection.

        Args:
            args: Object exposing ``vocab_size`` and ``d_model`` attributes.
            device: Stored on ``self.device`` and forwarded to ``Embedding``.
                This class does not move its parameters to that device
                itself.
        """
        super().__init__()
        self.device = device
        self.embedding = Embedding(args, device)
        self.fc1 = nn.Linear(args.d_model, args.vocab_size)

    def forward(self, X_input):
        """Run the embed -> sum -> GELU -> linear pipeline.

        Args:
            X_input: Index tensor passed to the embedding layer
                (assumed to be ``(batch, seq_len)`` token ids -- TODO
                confirm against the caller).

        Returns:
            Output of ``fc1``; its last dimension has size
            ``args.vocab_size``.
        """
        embedded = self.embedding(X_input)
        # Summing over dim 1 removes that axis (the sequence axis if the
        # input is (batch, seq_len)), leaving one vector per remaining index.
        pooled = embedded.sum(dim=1)
        # Functional GELU computes the same default (exact) activation as
        # nn.GELU() without constructing a new module on every forward call.
        activated = nn.functional.gelu(pooled)
        return self.fc1(activated)