from torch import nn

from model.pooling import Pooling


class LabelEncoder(nn.Module):
    """Pool a sequence of embeddings with a configurable ``Pooling`` layer.

    The module is a thin wrapper. It aligns ``word_mask`` with the second
    dimension of ``embeddings``, then delegates the reduction to ``Pooling``.
    """

    def __init__(self, config):
        """Build the encoder.

        Args:
            config: configuration object. Only ``config.pooling`` is read
                here, and it is forwarded unchanged to ``Pooling``.
        """
        super().__init__()
        self.pooling = Pooling(config.pooling)

    def forward(self, embeddings, word_mask):
        """Apply the pooling layer to ``embeddings`` under ``word_mask``.

        Args:
            embeddings: tensor whose dim 1 is compared against the mask
                (presumably ``(batch, seq_len, hidden)``; TODO confirm
                against callers).
            word_mask: mask sliced as ``[:, :n]``. If its dim 1 differs from
                ``embeddings.shape[1]``, it is cut to the first
                ``embeddings.shape[1]`` positions.

        Returns:
            Whatever ``self.pooling`` returns for the aligned inputs.
        """
        seq_len = embeddings.shape[1]
        mask_len = word_mask.shape[1]
        if mask_len != seq_len:
            # Drop trailing mask columns that have no matching embedding.
            # NOTE(review): a mask *shorter* than seq_len is left unchanged
            # by this slice, so that mismatch reaches the pooling layer as-is.
            word_mask = word_mask[:, :seq_len]
        return self.pooling(embeddings, word_mask)
