import torch
import torch.nn as nn
import torch.nn.functional as F

class Vec_Classifier(nn.Module):
    """Module that owns one learnable vector per class.

    All class vectors live in a single trainable ``nn.Parameter`` of shape
    ``(num_classes, input_dim)``. It is initialised with the rectangular
    identity ``torch.eye(num_classes, input_dim)``. Row ``i`` therefore starts
    as the one-hot vector ``e_i`` when ``i < input_dim``. Any rows beyond
    ``input_dim`` start as all zeros.

    NOTE(review): no ``forward`` is defined in this class. Calling the module
    directly falls back to ``nn.Module.forward`` and raises
    ``NotImplementedError`` unless other code supplies one.

    Args:
        input_dim: Length of each class vector (number of columns).
        num_classes: Number of class vectors (number of rows).
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Keep the constructor arguments around for later inspection.
        self.input_dim, self.num_classes = input_dim, num_classes
        # Identity-style init: the first min(num_classes, input_dim) classes
        # get distinct one-hot directions; wrapping in nn.Parameter registers
        # the tensor as a trainable parameter of this module.
        self.class_vectors = nn.Parameter(torch.eye(num_classes, input_dim))
    
    
if __name__ == "__main__":
    # Smoke check: build 10 class vectors in a 20-dimensional space and show
    # the parameter's shape.
    demo = Vec_Classifier(input_dim=20, num_classes=10)
    print(demo.class_vectors.shape)