from spaghettini import quick_register

from torchvision import transforms
import torch


@quick_register
def one_hot_encode_transform(num_classes=10):
    """Return a torchvision transform that one-hot encodes a class label.

    The transform is registered with spaghettini via ``quick_register``.

    Args:
        num_classes: Length of the one-hot vector produced for each label.

    Returns:
        A ``transforms.Compose`` wrapping a single ``transforms.Lambda``
        that applies ``one_hot_encoder(num_classes)`` to its input label.
    """
    encode = one_hot_encoder(num_classes)
    pipeline = [transforms.Lambda(encode)]
    return transforms.Compose(pipeline)


def one_hot_encoder(num_classes):
    """Build a callable that one-hot encodes a class index.

    Args:
        num_classes: Length of the vector produced by the returned callable.

    Returns:
        A function ``f(single_label)`` returning a 1-D ``torch.float32``
        tensor of length ``num_classes``. The tensor is 0 everywhere except
        for 1.0 at position ``single_label``. The function raises
        ``IndexError`` if ``single_label`` is negative or ``>= num_classes``.
    """

    def one_hot_encode_with_fixed_classes(single_label):
        # Tensor indexing would silently wrap a negative label around
        # (e.g. -1 -> last class). Reject it, the same way torch already
        # rejects labels >= num_classes.
        if torch.as_tensor(single_label).lt(0).any():
            raise IndexError(
                f"label {single_label} is out of range for {num_classes} classes"
            )
        # Same result as torch.zeros(...).float(), but the tensor is created
        # as float32 directly instead of possibly being converted afterwards.
        one_hot_label = torch.zeros((num_classes,), dtype=torch.float32)
        one_hot_label[single_label] = 1.0

        return one_hot_label

    return one_hot_encode_with_fixed_classes
