import copy
import torch.nn as nn


class Encoder(nn.Module):
    """Run an input through ``n_layer`` independent copies of one block, then ``norm``.

    Each layer is a ``copy.deepcopy`` of ``encoder_block``, so the layers do not
    share parameters with each other or with the block passed in.

    Args:
        encoder_block: Module to replicate. Each copy is called as
            ``block(context, query, src_mask)`` (argument roles inferred from the
            original inline comments — TODO confirm against the block's forward).
        n_layer: Number of copies to stack.
        norm: Callable (typically a module) applied once to the final output.
    """

    def __init__(self, encoder_block, n_layer, norm):
        super().__init__()
        self.n_layer = n_layer
        # Independent deep copies -> no weight sharing between layers.
        clones = [copy.deepcopy(encoder_block) for _ in range(n_layer)]
        self.layers = nn.ModuleList(clones)
        # Assigned after ``layers`` to keep submodule/state_dict ordering stable.
        self.norm = norm

    def forward(self, src, tar, src_mask):
        """Apply every layer in order, then ``self.norm``.

        Args:
            src: Source input. In cross mode it is passed unchanged as the first
                argument to every layer; otherwise it is the starting value
                that gets transformed.
            tar: If not ``None``, selects cross mode and serves as the starting
                query; ``None`` selects self mode.
            src_mask: Forwarded unchanged as the third argument to every layer.

        Returns:
            ``self.norm`` applied to the last layer's output (or to the starting
            value when there are no layers).
        """
        use_cross = tar is not None
        hidden = tar if use_cross else src
        for block in self.layers:
            # Cross mode: fixed src as context; self mode: the running value is
            # its own context.
            context = src if use_cross else hidden
            hidden = block(context, hidden, src_mask)
        return self.norm(hidden)
