import os
import numpy as np
from PIL import Image
from torchvision import transforms
import torch
from tqdm import tqdm
import torch.nn.functional as F
import random


for mod in range(4):
    image_dir = f'../{mod}'  # Replace with your directory path
    image_filenames = []
    for root, _, files in os.walk(image_dir):
        for file in files:
            if file.endswith('.png'):
                image_filenames.append(os.path.join(root, file))
    print(f"-- ---- {mod}")
    device = torch.device('cuda')
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').to(device)

    preprocess = transforms.Compose([
        transforms.Resize(
                        (224, 224), interpolation=transforms.InterpolationMode.BICUBIC
                    ),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                    )
    ])

    batch_size = 64
    all_features = []

    for i in tqdm(range(0, len(image_filenames), batch_size), desc="Extracting Features"):
        batch_images = []
        for j in range(i, min(i + batch_size, len(image_filenames))):
            image_path = image_filenames[j]
            image = Image.open(image_path).convert('RGB')
            image = preprocess(image)
            batch_images.append(image)
        
        batch_images = torch.stack(batch_images).to(device)
        
        model.eval()
        with torch.no_grad():
            batch_features = model(batch_images)
        
        all_features.append(batch_features.cpu().numpy())

    all_features = np.concatenate(all_features, axis=0)

    np.savez(f'../{mod}.npz', dino_features=all_features)

    print(f"Extracted and normalized features shape: {all_features.shape}")