import torch
import torch.nn as nn

class DescEncoder(nn.Module):
    """Encode a fixed collection of descriptions and return rows by label index.

    On every forward pass the wrapped ``text_encoder`` is applied to the whole
    ``description_dict``. The rows of that result selected by ``label_idx``
    along the first dimension are returned.

    Args:
        text_encoder: Callable (typically an ``nn.Module``) applied to
            ``description_dict``. It must expose a ``hidden_dim`` attribute.
        description_dict: Object passed unchanged to ``text_encoder`` on each
            call. NOTE(review): its type/schema is defined by the encoder and
            is not visible here; confirm against the caller.
        device: Kept as ``self.device``. This class does not use it itself.
    """

    def __init__(self, text_encoder, description_dict, device):
        super().__init__()
        # When text_encoder is an nn.Module, this assignment registers it as a
        # child module, so its parameters are part of this module.
        self.text_encoder = text_encoder
        self.description_dict = description_dict
        # Presumably the encoder's output width is twice hidden_dim (e.g. a
        # bidirectional RNN). TODO confirm against the encoder.
        hidden_width = text_encoder.hidden_dim
        self.label_dim = hidden_width * 2
        self.device = device

    def forward(self, label_idx):
        """Encode all descriptions, then select the rows for ``label_idx``.

        Args:
            label_idx: Index (int, slice, index tensor, ...) applied to the
                first dimension of the encoded descriptions.

        Returns:
            ``encoded[label_idx, :]``, where ``encoded`` is the output of
            ``self.text_encoder(self.description_dict)``.
        """
        all_descriptions = self.text_encoder(self.description_dict)
        # The trailing ``:`` keeps every remaining column. For a tensor result,
        # it also requires the encoder output to be at least 2-D.
        selected = all_descriptions[label_idx, :]
        return selected

