import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import argparse

# Command-line interface. The single, mandatory option is the filename prefix
# under which the extracted arrays are written later in the script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output-prefix",
    help="Prefix for saving your DINO feature .npy files",
    required=True,
    type=str,
)

# argparse exposes "--output-prefix" as the attribute `output_prefix`.
args = parser.parse_args()
PREFIX = args.output_prefix

# Run on CUDA when PyTorch reports it as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the DINOv2 ViT-L/14 entry point through torch.hub, move it to the
# selected device, and put it in evaluation mode.
print("Loading DINOv2 ViT-L/14 model...")
dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dino_model = dino_model.to(device)
dino_model.eval()

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std values given below.
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

# CIFAR-10 training split, stored under ./data (downloaded if requested by
# `download=True`), with the preprocessing above attached.
print("Loading CIFAR-10 dataset...")
cifar10 = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Select the first 1000 images of each class, in dataset order.
#
# Labels are read from `cifar10.targets` instead of indexing the dataset:
# indexing (`cifar10[idx]`) runs the whole transform pipeline on the image
# just to obtain its label, which is wasted work for every visited sample.
selected_indices = []
class_counts = {i: 0 for i in range(10)}

for idx, label in enumerate(cifar10.targets):
    if class_counts[label] >= 1000:
        # This class already has its quota; skip the sample.
        continue
    selected_indices.append(idx)
    class_counts[label] += 1
    # The overall quota can only become complete at the moment some class
    # has just filled up, so the full check is only needed then.
    if class_counts[label] == 1000 and all(
        count >= 1000 for count in class_counts.values()
    ):
        break

# Restrict the dataset to the selected indices and iterate over it in
# fixed-size batches, keeping the selection order (no shuffling).
subset_cifar10 = Subset(cifar10, selected_indices)
data_loader = DataLoader(dataset=subset_cifar10, batch_size=32, shuffle=False)

# Run the model over every batch with gradient tracking disabled, collecting
# the per-batch outputs and labels as numpy arrays.
print("Extracting DINOv2 features...")
feature_batches = []
label_batches = []
with torch.no_grad():
    for batch_images, batch_targets in tqdm(data_loader, desc="Feature Extraction"):
        batch_output = dino_model(batch_images.to(device))
        # Bring the result back to host memory before converting to numpy.
        feature_batches.append(batch_output.cpu().numpy())
        label_batches.append(batch_targets.numpy())

# Stitch the per-batch arrays together along the sample axis.
features = np.concatenate(feature_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)

# Save to disk. The output paths are built once so that the summary printed
# below reports the files that were actually written (they depend on the
# --output-prefix value, not on a fixed "cifar10" name).
features_path = f"{PREFIX}_train_features.npy"
labels_path = f"{PREFIX}_train_labels.npy"

print("Saving extracted features and labels...")
np.save(features_path, features)
np.save(labels_path, labels)

print("Feature extraction complete! Files saved:")
print("- {} (Shape: {})".format(features_path, features.shape))
print("- {} (Shape: {})".format(labels_path, labels.shape))
