
import copy
import torch.nn as nn

from models.layer.residual_connection_layer import ResidualConnectionLayer


class EncoderBlock(nn.Module):
    """Attention sub-layer in which ``tar`` queries ``src``.

    Despite the ``self_attention`` name, the wrapped module is called with
    ``query=tar`` and ``key``/``value`` taken from ``src``. With the residual
    wrapper enabled, ``key``/``value`` are whatever ``ResidualConnectionLayer``
    passes to the sub-layer callable.

    Args:
        self_attention: Callable accepting ``query``, ``key``, ``value`` and
            ``mask`` keyword arguments.
        norm: Normalization module. A deep copy is handed to the residual
            wrapper, so the caller's instance is not shared with this block.
        dr_rate: Dropout rate forwarded to ``ResidualConnectionLayer``.
        residual_block: If True, route the attention call through the
            residual wrapper; otherwise call the attention module directly.
    """

    def __init__(self, self_attention, norm, dr_rate=0, residual_block=True):
        super().__init__()
        self.self_attention = self_attention
        # Constructed unconditionally, even when residual_block is False;
        # presumably so registered parameters / state_dict keys do not depend
        # on the flag -- confirm before making this conditional.
        self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        self.residual_block = residual_block

    def forward(self, src, tar, src_mask):
        """Attend from ``tar`` over ``src``.

        Args:
            src: Tensor used as attention key and value.
            tar: Tensor used as attention query.
            src_mask: Mask forwarded unchanged to the attention module.

        Returns:
            The residual wrapper's output when ``residual_block`` is True,
            otherwise the raw attention output.
        """
        if not self.residual_block:
            return self.self_attention(query=tar, key=src, value=src, mask=src_mask)

        def attend(kv):
            # Key/value come from the tensor the residual wrapper passes in
            # (presumably normalized src -- TODO confirm), not from raw src.
            return self.self_attention(query=tar, key=kv, value=kv, mask=src_mask)

        # NOTE(review): if the wrapper adds its input (src) back onto the
        # attention output (shaped like the tar query), tar and src need
        # compatible shapes -- confirm against ResidualConnectionLayer.
        return self.residual1(src, attend)
