from vit_pytorch import ViT as ViT_pytorch
from vit_pytorch import SimpleViT as SimpleViT_pytorch

def ViT(
    num_classes: int = 1000,
    image_size: int = 32,
    *,
    patch_size: int = 4,
    dim: int = 192,
    depth: int = 6,
    heads: int = 6,
    mlp_dim: int = 384,
    **kwargs,
) -> "ViT_pytorch":
    """Build a ``vit_pytorch.ViT`` with a small default configuration.

    The keyword defaults (patch_size=4, dim=192, depth=6, heads=6,
    mlp_dim=384) are the values this factory previously hard-coded. Calls
    that pass only ``num_classes``/``image_size`` therefore get the same
    model as before. Other configurations no longer need a new factory.

    Args:
        num_classes: Number of output classes.
        image_size: Input image size, forwarded to ``vit_pytorch.ViT``.
        patch_size: Patch size; vit_pytorch expects ``image_size`` to be
            divisible by it.
        dim: Transformer embedding width.
        depth: Number of transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward layers.
        **kwargs: Any other ``vit_pytorch.ViT`` constructor arguments
            (e.g. ``pool``, ``channels``, ``dropout``, ``emb_dropout``),
            forwarded unchanged.

    Returns:
        The constructed ``vit_pytorch.ViT`` module.
    """
    return ViT_pytorch(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        **kwargs,
    )

def SimpleViT(
    num_classes: int = 1000,
    image_size: int = 32,
    *,
    patch_size: int = 4,
    dim: int = 192,
    depth: int = 6,
    heads: int = 6,
    mlp_dim: int = 384,
    **kwargs,
) -> "SimpleViT_pytorch":
    """Build a ``vit_pytorch.SimpleViT`` with a small default configuration.

    The keyword defaults (patch_size=4, dim=192, depth=6, heads=6,
    mlp_dim=384) are the values this factory previously hard-coded. Calls
    that pass only ``num_classes``/``image_size`` therefore get the same
    model as before. Other configurations no longer need a new factory.

    Args:
        num_classes: Number of output classes.
        image_size: Input image size, forwarded to ``vit_pytorch.SimpleViT``.
        patch_size: Patch size; vit_pytorch expects ``image_size`` to be
            divisible by it.
        dim: Transformer embedding width.
        depth: Number of transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward layers.
        **kwargs: Any other ``vit_pytorch.SimpleViT`` constructor arguments
            (e.g. ``channels``, ``dim_head``), forwarded unchanged.

    Returns:
        The constructed ``vit_pytorch.SimpleViT`` module.
    """
    return SimpleViT_pytorch(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        **kwargs,
    )
