
import torch.nn as nn
from .sublayers import PositionwiseFeedForward,  MultiHeadAttention

class DecoderLayer(nn.Module):
    """Decoder block: encoder-decoder attention, then a position-wise FFN.

    Unlike a textbook Transformer decoder layer, this block has no
    self-attention over the decoder input. ``dec_input`` goes directly into
    the attention sublayer, and the attention output is fed to the
    feed-forward sublayer.

    Args:
        d_model: width forwarded to both ``MultiHeadAttention`` and
            ``PositionwiseFeedForward``.
        d_inner: forwarded to ``PositionwiseFeedForward`` (presumably its
            hidden width -- confirm against ``sublayers``).
        n_head, d_k, d_v: forwarded to ``MultiHeadAttention`` (presumably
            head count and per-head key/value sizes -- confirm).
        dropout: dropout rate handed to both sublayers.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        # Registration order (enc_attn, then pos_ffn) fixes the
        # state_dict key order; keep it stable for checkpoint loading.
        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output):
        """Attend from ``dec_input`` over ``enc_output``, then apply the FFN.

        ``enc_output`` is passed as both the second and third argument to
        ``enc_attn`` (assumed to be keys and values; ``dec_input`` the
        query -- confirm against ``MultiHeadAttention``).

        Returns:
            ``(dec_output, dec_enc_attn)``: the feed-forward result and the
            second value returned by ``enc_attn``, passed through unchanged.
        """
        attended, attn_weights = self.enc_attn(
            dec_input, enc_output, enc_output)
        return self.pos_ffn(attended), attn_weights


