import argparse
import os

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

def parse_args():
    """Parse the command-line arguments of the feature-extraction script.

    Returns:
        argparse.Namespace with a single ``output_prefix`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_prefix",
        type=str,
        required=True,
        help="Prefix for saving the DINO feature .npy files"
    )
    return parser.parse_args()


def extract_features(model, batches, device):
    """Run ``model`` over every batch and collect its outputs as NumPy arrays.

    Args:
        model: Callable mapping a batch of image tensors to a tensor of
            features (first dimension is the batch dimension).
        batches: Iterable yielding ``(images, targets)`` tensor pairs, e.g. a
            ``DataLoader``. ``targets`` must live on the CPU because they are
            converted with ``.numpy()`` directly.
        device: Device the images are moved to before the forward pass.

    Returns:
        Tuple ``(features, labels)``: the per-batch model outputs and the
        per-batch targets, each concatenated along axis 0.

    Raises:
        ValueError: If ``batches`` yields nothing (raised by
            ``np.concatenate`` on an empty list).
    """
    features, labels = [], []
    # Gradients are never needed for pure feature extraction.
    with torch.no_grad():
        for images, targets in batches:
            images = images.to(device)
            features.append(model(images).cpu().numpy())
            labels.append(targets.numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


def main():
    """Extract DINOv2 features for the CIFAR-10 training set and save them.

    Writes ``<prefix>_train_features.npy`` and ``<prefix>_train_labels.npy``,
    where ``<prefix>`` is the value of ``--output_prefix``.
    """
    args = parse_args()
    prefix = args.output_prefix

    # Create the destination directory up front so that a bad prefix is
    # handled before the extraction runs rather than failing at np.save.
    out_dir = os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading DINOv2 ViT-L/14 model...")
    dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').to(device)
    dino_model.eval()

    # Resize to 224x224 and normalize with the commonly used ImageNet
    # per-channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    print("Loading CIFAR-10 training set...")
    cifar10 = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    # shuffle=False keeps features/labels in dataset order.
    data_loader = DataLoader(cifar10, batch_size=32, shuffle=False, num_workers=4)

    print("Extracting DINOv2 features from 50,000 CIFAR-10 images...")
    features, labels = extract_features(
        dino_model,
        tqdm(data_loader, desc="Feature Extraction"),
        device,
    )

    print("Saving extracted features and labels...")
    np.save(f"{prefix}_train_features.npy", features)
    np.save(f"{prefix}_train_labels.npy", labels)

    print("Feature extraction complete! Files saved:")
    print(f"- {prefix}_train_features.npy  (shape: {features.shape})")
    print(f"- {prefix}_train_labels.npy    (shape: {labels.shape})")


# The guard is required because the DataLoader uses worker processes: with the
# "spawn" start method each worker re-imports this module, and without the
# guard it would re-run the whole script.
if __name__ == "__main__":
    main()
