
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
import numpy as np
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from transformers import CLIPModel

# ImageNet split to process; selects both the input directory below and the
# name of the feature file written by the extraction script.
split = 'train'

# Number of images fed through the model per forward pass.
batch_size = 1024

# Define the image preprocessing pipeline.
# The normalization statistics are the ones the OpenAI CLIP checkpoints were
# trained with (OPENAI_CLIP_MEAN / OPENAI_CLIP_STD in transformers), not the
# torchvision ImageNet statistics: feeding ImageNet-normalized pixels to CLIP
# yields features that differ from what the model's own preprocessor produces.
# NOTE(review): the checkpoint's reference preprocessor resizes with bicubic
# resampling, whereas Resize defaults to bilinear -- confirm whether the
# interpolation should be matched as well.
transform = Compose([
    Resize(224),
    CenterCrop(224),
    ToTensor(),
    Normalize(
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711],
    ),
])

# Load the ImageNet dataset using ImageFolder (one sub-directory per class).
dataset_path = f"/local/tlong/data/ImageNet_{split}"
dataset = ImageFolder(dataset_path, transform=transform)

# classes = dataset.classes

if __name__ == '__main__':
    # Local import: only the script entry point needs it.
    import os

    # Resolve the output location and create its directory up front, so a
    # missing "feat/" directory is reported before the (long) extraction
    # starts instead of failing at the final torch.save.
    output_path = f"feat/imagenet_{split}_clip_features.pt"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Load the pre-trained CLIP model and move it to the GPU.
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    model = model.cuda()
    # Make inference mode explicit (disables dropout-style training behavior).
    model.eval()

    # Create a dataloader for the dataset. shuffle=False keeps the dataset
    # order, so row i of the saved tensors corresponds to dataset sample i.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=64)

    # Extract features batch by batch, collecting them on the CPU so GPU
    # memory is only needed for one batch at a time.
    features_list, label_list = [], []
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(loader):
            batch_imgs = batch_imgs.cuda()

            features = model.get_image_features(pixel_values=batch_imgs)

            # No detach() needed: nothing tracks gradients under no_grad, and
            # the labels yielded by the loader were never moved to the GPU.
            features_list.append(features.cpu())
            label_list.append(batch_labels)

    # Concatenate the per-batch results along the sample dimension.
    features = torch.cat(features_list, dim=0)
    labels = torch.cat(label_list, dim=0)

    # Save features and labels together in a single .pt file.
    save_dict = {'features': features, 'labels': labels}
    torch.save(save_dict, output_path)


