import torch
import clip
import os
import random
from torchvision.datasets import ImageNet
from torchvision.transforms import transforms
from tqdm import tqdm
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


def get_imagenet_dataset(input_path, args, preprocess):
    """Build the ImageNet validation dataset for the configured backbone.

    Args:
        input_path: Root directory passed to ``ImageNet(root=...)``.
        args: Object exposing a ``VLM_Base`` attribute; supported values are
            ``"RN50"`` and ``"ViT-B/32"``.
        preprocess: Transform applied to each image when ``VLM_Base`` is
            ``"RN50"``. It is not used for ``"ViT-B/32"``.

    Returns:
        A ``(dataset, classes)`` tuple, where ``classes`` is the dataset's
        ``classes`` attribute.

    Raises:
        ValueError: If ``args.VLM_Base`` is not one of the supported values.
    """
    backbone = args.VLM_Base

    if backbone == "RN50":
        transform = preprocess
    elif backbone == "ViT-B/32":
        # NOTE(review): this branch ignores ``preprocess`` and normalizes with
        # constants that look like the usual ImageNet mean/std rather than the
        # backbone's own preprocessing -- confirm this is intentional.
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        # Fail loudly with a clear message instead of reaching the return
        # statement with unassigned locals.
        raise ValueError(
            f"Unsupported VLM_Base {backbone!r}; expected 'RN50' or 'ViT-B/32'"
        )

    imagenet_val_dataset = ImageNet(root=input_path, split='val', transform=transform)
    return imagenet_val_dataset, imagenet_val_dataset.classes

