import json
import torch
import torch.utils.data as Data
from torch import nn, optim
import numpy as np
import time
from tqdm import tqdm

class FeedForwardNet(nn.Module):
    """Two-layer MLP: d_model -> 4 * d_model -> (ReLU) -> d_model.

    The layers live in ``self.fc`` (an ``nn.Sequential``), so parameter
    names in the state dict are ``fc.0.*`` and ``fc.2.*``.
    """

    def __init__(self, args):
        """Build the block.

        Args:
            args: configuration object; only ``args.d_model`` is read.
        """
        super().__init__()
        width = args.d_model
        # Creation order (expand, then contract) is kept so that seeded
        # initialisation yields the same weights as before.
        expand = nn.Linear(width, 4 * width)
        contract = nn.Linear(4 * width, width)
        self.fc = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, inputs):
        """Apply the MLP along the last dimension.

        Args:
            inputs: tensor whose last dimension is d_model.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        return self.fc(inputs)



class myDNN_averaged(nn.Module):
    """Feed-forward model over an anchor-averaged token representation.

    The input sequence is laid out as ``[X, a1, a2, ..., an]``. The hidden
    state fed to the feed-forward stack is

        emb(X) @ W_1 + (1 / n) * sum_i(emb(a_i) @ W_2)

    and the result is projected onto the vocabulary.
    """

    def __init__(self, args, device):
        """Build the model.

        Args:
            args: configuration object providing ``vocab_size``,
                ``d_model`` and ``n_layers``.
            device: stored as ``self.device``; not otherwise used by this
                class.
        """
        super().__init__()

        self.device = device

        self.tgt_emb = nn.Embedding(args.vocab_size, args.d_model)

        # Bias-free linear maps for the X token and for the anchor tokens.
        self.W_1 = nn.Linear(args.d_model, args.d_model, bias=False)
        self.W_2 = nn.Linear(args.d_model, args.d_model, bias=False)

        self.layers = nn.ModuleList([FeedForwardNet(args) for _ in range(args.n_layers)])

        self.projection = nn.Linear(args.d_model, args.vocab_size)

        self.d_model = args.d_model

    def forward(self, dec_inputs):
        """Compute vocabulary logits for a batch of ``[X, a1, ..., an]`` rows.

        Args:
            dec_inputs: integer token ids, [batch_size, seq_len]. Position 0
                is X and positions 1.. are the anchors.

        Returns:
            A tuple ``(logits, None)`` where ``logits`` is
            [batch_size, vocab_size] (un-normalised projection output). The
            second element is always ``None``.
        """
        batch_size = dec_inputs.size(0)

        embedded = self.tgt_emb(dec_inputs)  # [batch_size, seq_len, d_model]

        # X contribution: emb(X) @ W_1 -> [batch_size, d_model]
        hidden_state = self.W_1(embedded[:, 0, :])

        # Anchor contribution: mean over anchors of emb(a_i) @ W_2.
        # With no anchors (seq_len == 1) the mean over an empty dimension
        # would be NaN and poison every logit, so the term is skipped.
        anchors = embedded[:, 1:, :]
        if anchors.size(1) > 0:
            hidden_state = hidden_state + torch.mean(self.W_2(anchors), dim=1)

        # Shape-preserving for 2-D input; fails loudly on unexpected shapes.
        hidden_state = hidden_state.view(batch_size, self.d_model)

        # Stack of feed-forward blocks.
        for layer in self.layers:
            hidden_state = layer(hidden_state)

        # Project onto the vocabulary.
        logits = self.projection(hidden_state)

        return logits, None
    