import torch
import torch.nn as nn
from torch.nn import init

class AvgLabel(nn.Module):
    """Look up label representations by index through a wrapped callable.

    The module stores ``features`` and, on ``forward``, converts the given
    indices to a tensor on ``device`` and passes them to ``features``.
    """

    def __init__(self, features, device, options):
        """Store the lookup callable and the target device.

        Args:
            features: Callable invoked with the index tensor in ``forward``.
                Presumably an ``nn.Embedding`` or similar lookup table --
                TODO confirm against the caller.
            device: Device the index tensor is placed on before the lookup.
            options: Accepted for interface compatibility; not used here.
        """
        super().__init__()

        self.features = features
        # NOTE(review): presumably the dimensionality of the vectors returned
        # by `features`; nothing in this class checks it -- confirm.
        self.label_dim = 300
        self.device = device

    def forward(self, label_idx):
        """Return ``features`` applied to ``label_idx`` on ``self.device``.

        Args:
            label_idx: Indices to look up; a list, array, or tensor.

        Returns:
            Whatever ``self.features`` returns for the index tensor.
        """
        # as_tensor accepts tensors as well as lists/arrays without the
        # copy-construct warning torch.tensor() raises for tensor input, and
        # allocates directly on the target device instead of CPU-then-move.
        label_idx = torch.as_tensor(label_idx, device=self.device)
        return self.features(label_idx)