import torch
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import CLIPModel

def extract(train_dataset, train_loader, model):
    """Run ``model`` over every batch of ``train_loader`` and collect features.

    Args:
        train_dataset: Unused by this function; kept so existing call sites
            continue to work.
        train_loader: Iterable yielding ``(inputs, targets)`` batches. The
            returned indices are running sample positions in iteration
            order, so they only match dataset positions when the loader
            does not shuffle.
        model: Object providing ``to(device)`` and
            ``get_image_features(pixel_values=...)``.

    Returns:
        A tuple ``(features, image_indices, labels)``. Each element is the
        per-batch results concatenated along dim 0. Features are moved to
        the CPU so that all three tensors live on the same device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    feature_list, indices_list, label_list = [], [], []
    # Running count of samples seen so far. Using the real batch length
    # (instead of ``train_loader.batch_size``) keeps indices correct for a
    # short final batch and for loaders without a fixed batch size.
    offset = 0

    with torch.no_grad():
        # Wrap the loader itself so tqdm can show the total number of batches.
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)

            # Extract the features from the inputs
            features = model.get_image_features(pixel_values=inputs)
            batch_len = inputs.size(0)

            # Save the features, indices, and labels. Features go to the CPU
            # right away so device memory is not held for the whole dataset.
            feature_list.append(features.cpu())
            # ``torch.arange`` is end-exclusive and yields integers, giving
            # exactly one index per sample (``torch.range`` is deprecated,
            # end-inclusive, and returns floats).
            indices_list.append(torch.arange(offset, offset + batch_len))
            label_list.append(targets)
            offset += batch_len

    features = torch.cat(feature_list, dim=0)
    image_indices = torch.cat(indices_list, dim=0)
    labels = torch.cat(label_list, dim=0)

    return features, image_indices, labels



# Per-channel statistics of the CIFAR-100 training split, used for input
# normalisation.
# NOTE(review): pretrained CLIP checkpoints ship with their own preprocessing
# statistics; confirm that normalising with CIFAR-100 statistics is intended.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Resize to 224 on the shorter side; no crop is applied afterwards.
transform = Compose([
    Resize(224),
    ToTensor(),
    Normalize(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD),
])

# The DataLoaders below use worker processes. On platforms that start workers
# by re-importing the main module, the pipeline must sit behind this guard so
# it is not re-executed inside every worker.
if __name__ == "__main__":
    # Step 1: Load and preprocess the CIFAR-100 dataset.
    # shuffle=False keeps iteration order equal to dataset order, so the
    # indices returned by extract() line up with dataset positions.
    train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=4)

    test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)

    # Step 2: Load the pretrained CLIP model.
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    # Only move to the GPU when one exists, so the script also runs on
    # CPU-only machines instead of raising.
    if torch.cuda.is_available():
        model = model.cuda()

    # Step 3: Extract features for both splits.
    train_features, train_image_idices, train_labels = extract(train_dataset, train_loader, model)
    test_features, test_image_idices, test_labels = extract(test_dataset, test_loader, model)

    print("Feature extraction completed.")

    # Step 4: Save the extracted features, indices, and labels to disk.
    save_dict = {
        'train_features': train_features,
        'train_image_idices': train_image_idices,
        'train_labels': train_labels,
        'test_features': test_features,
        'test_image_idices': test_image_idices,
        'test_labels': test_labels,
    }
    torch.save(save_dict, "cifar100-clip-features.pt")

    # Step 5: Load the extracted features, indices, and labels back from disk.
    save_dict = torch.load("cifar100-clip-features.pt")
    train_features = save_dict['train_features']
    train_image_idices = save_dict['train_image_idices']
    train_labels = save_dict['train_labels']
    test_features = save_dict['test_features']
    test_image_idices = save_dict['test_image_idices']
    test_labels = save_dict['test_labels']