import argparse
from pathlib import Path

import torch
from torchvision.models import vit_b_16, vit_l_16, vit_h_14

def save_pretrained_state_dict(build_model, weights, save_path):
    """Build a pretrained model and save its ``state_dict`` to *save_path*.

    Missing parent directories of *save_path* are created first; otherwise
    ``torch.save`` fails when e.g. ``./ckpt`` does not exist yet.

    Args:
        build_model: Callable invoked as ``build_model(weights=weights)``
            that returns an ``nn.Module`` (e.g. ``torchvision.models.vit_b_16``).
        weights: Value forwarded as the ``weights=`` argument, such as the
            torchvision weights-enum member name ``"IMAGENET1K_V1"``.
        save_path: Destination file path (``str`` or ``os.PathLike``).

    Returns:
        The constructed model, so callers can keep using it after saving.
    """
    path = Path(save_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    model = build_model(weights=weights)
    torch.save(model.state_dict(), path)
    return model


def main(argv=None):
    """Export the requested ViT variants' pretrained state dicts.

    Args:
        argv: Argument list to parse (defaults to ``sys.argv[1:]``). Each
            positional argument names a variant (``base``, ``large``,
            ``huge``). With none given only ``base`` is exported, which is
            what this script always did before variants became selectable.
    """
    # variant name -> (model builder, weights name, output path)
    variants = {
        "base": (vit_b_16, "IMAGENET1K_V1", "./ckpt/vit_base.pth"),
        "large": (vit_l_16, "IMAGENET1K_V1", "./ckpt/vit_large.pth"),
        # torchvision publishes no IMAGENET1K_V1 weights for vit_h_14; its
        # ImageNet-1K checkpoints are the SWAG ones (E2E fine-tuned used here).
        "huge": (vit_h_14, "IMAGENET1K_SWAG_E2E_V1", "./ckpt/vit_huge.pth"),
    }

    parser = argparse.ArgumentParser(
        description="Save pretrained torchvision ViT state dicts to ./ckpt/."
    )
    parser.add_argument(
        "variants",
        nargs="*",
        metavar="VARIANT",
        help=f"variants to export (default: base); one of {', '.join(variants)}",
    )
    args = parser.parse_args(argv)

    # `choices=` is deliberately not used: combined with nargs="*" it rejects
    # the empty/default value on several Python versions.
    names = args.variants or ["base"]
    unknown = [name for name in names if name not in variants]
    if unknown:
        parser.error(
            f"unknown variant(s): {', '.join(unknown)}; "
            f"choose from {', '.join(variants)}"
        )

    for name in names:
        build_model, weights, save_path = variants[name]
        save_pretrained_state_dict(build_model, weights, save_path)


if __name__ == "__main__":
    main()
