from transformers import AutoImageProcessor, AutoModel
import torch
import numpy as np
from torchvision.datasets import CIFAR10
import os
import pickle
from tqdm import tqdm


# Pick the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained DINOv2 (base) preprocessor and backbone, loaded by Hub id.
_checkpoint = 'facebook/dinov2-base'
processor = AutoImageProcessor.from_pretrained(_checkpoint)
model = AutoModel.from_pretrained(_checkpoint)
model = model.to(device)

# CIFAR-10 test split from a local directory; download=False, so the data is
# not fetched and is expected to already exist under `root`.
root = r'cifar10_dataset'
cifar10_test = CIFAR10(root=root, train=False, download=False)


# Cache file for the extracted test-set features; extraction is skipped
# entirely when this file already exists.
features_file = r"../cifar10_test_features.pkl"
if os.path.exists(features_file):
    print("You have already extracted features locally.")
else:
    # Maps each CIFAR-10 label (int 0-9) to the feature vectors of its test
    # images; each list is stacked into one numpy array before saving.
    train_test_features = {class_name: [] for class_name in range(10)}
    batch_size = 64
    num_images = len(cifar10_test)
    # Read labels from the dataset's `targets` list instead of indexing the
    # dataset a second time, which would rebuild each PIL image just to get
    # its label. (Valid because no target_transform is set on the dataset.)
    labels = cifar10_test.targets
    # Context managers: gradients off for inference, and the progress bar is
    # closed even if extraction raises.
    with torch.no_grad(), tqdm(total=num_images) as pbar:
        for start in range(0, num_images, batch_size):
            stop = min(start + batch_size, num_images)
            batch_images = processor(
                images=[cifar10_test[j][0] for j in range(start, stop)],
                return_tensors="pt",
            ).to(device)
            outputs = model(**batch_images)
            # Average last_hidden_state over dim 1 (the token axis) to get one
            # vector per image, then move the whole batch to host memory once.
            batch_features = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            for label, feature in zip(labels[start:stop], batch_features):
                train_test_features[label].append(feature)
            pbar.update(stop - start)
    for key in train_test_features:
        train_test_features[key] = np.array(train_test_features[key])
    # Write to a temporary file and atomically move it into place, so an
    # interrupted dump cannot leave a truncated cache that the existence
    # check above would then treat as complete.
    tmp_file = features_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(train_test_features, f)
    os.replace(tmp_file, features_file)