
import os
import torch
import numpy as np
from PIL import Image
import jax.numpy as jnp
from transformers import  FlaxViTForImageClassification
# Fetch the pretrained Flax ViT image classifier "google/vit-base-patch16-224"
# via `from_pretrained`, then write a local copy of it (config + weights) into
# the directory registered under the "flax" key of `save_dirs`.
# NOTE(review): the target path contains "imagenet"; presumably this is the
# ImageNet-fine-tuned checkpoint — confirm if downstream code depends on it.
save_dirs = dict(flax="./weights/imagenet/vit-base-patch16-224")
flax_model = FlaxViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224"
)
flax_model.save_pretrained(save_directory=save_dirs["flax"])
