import torch
from torchvision import datasets
from vendi_score import vendi, image_utils, data_utils
import vendi_score

# Compute the Vendi diversity score of the CIFAR-10 training set.
# NOTE: the "digits" variable names are historical (this script was adapted
# from an MNIST example); they are kept so the per-class snippet below still
# works when uncommented.
cifar10 = datasets.CIFAR10("../data", train=True, download=False)

# Walk the dataset exactly once. Every pass over a torchvision dataset re-runs
# its __getitem__ for every sample, so the previous 11 passes (one per class
# plus one for the full list) were pure overhead. Both lists below share the
# same image objects and keep dataset order.
all_digits = []
digits = [[] for _ in cifar10.classes]  # digits[c] = images whose label is c
for img, label in cifar10:
    all_digits.append(img)
    digits[label].append(img)

# Use the GPU when available instead of failing outright on CPU-only machines
# (CPU embedding of the full dataset will be much slower).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-class scores (disabled): pixel-space vs. embedding-space Vendi score.
# The default embeddings are from the pool-2048 layer of the torchvision
# Inception v3 model.
# pixel_vs = [image_utils.pixel_vendi_score(imgs) for imgs in digits]
# inception_vs = [image_utils.embedding_vendi_score(imgs, device=device) for imgs in digits]
# for y, (pvs, ivs) in enumerate(zip(pixel_vs, inception_vs)):
#     print(f"{y}\t{pvs:.02f}\t{ivs:.03f}")

dataset_vendi = image_utils.embedding_vendi_score(all_digits, device=device)
print(f"dataset vendi score: {dataset_vendi:.3f}")