import torch
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
import glob, os
import numpy as np

# from transformers import CLIPModel

def extract(train_dataset, train_loader, model):
    """Compute image features for every batch yielded by ``train_loader``.

    Args:
        train_dataset: Unused by this function; kept so that existing
            callers do not need to change.
        train_loader: Iterable yielding ``(inputs, targets)`` batches, where
            both elements are tensors batched along dim 0.
        model: Object providing ``to(device)`` and
            ``get_image_features(pixel_values=...)``.

    Returns:
        A tuple ``(features, image_indices, labels)``, each concatenated over
        all batches along dim 0 and all of the same length.
        ``image_indices`` is an integer tensor ``0 .. N-1`` giving each
        sample's position in iteration order (this only matches the dataset
        index if the loader does not shuffle -- confirm against the caller).

    Note:
        ``torch.cat`` raises if the loader yields no batches at all.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    feature_list, indices_list, label_list = [], [], []
    # Running count of samples seen so far. Using the real batch length
    # (instead of ``train_loader.batch_size``) keeps indices correct for a
    # short final batch and for loaders whose ``batch_size`` is None.
    offset = 0

    with torch.no_grad():
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            batch_len = inputs.size(0)

            # Extract the features from the inputs
            features = model.get_image_features(pixel_values=inputs)

            # Save the features, running indices, and labels.
            feature_list.append(features)
            # ``torch.arange`` is half-open, so each batch contributes exactly
            # ``batch_len`` indices (``torch.range`` included the end point
            # and produced one overlapping extra index per batch).
            indices_list.append(torch.arange(offset, offset + batch_len))
            label_list.append(targets)

            offset += batch_len

    features = torch.cat(feature_list, dim=0)
    image_indices = torch.cat(indices_list, dim=0)
    labels = torch.cat(label_list, dim=0)

    return features, image_indices, labels

# Merge every ``.npy`` file found directly under ``data_root`` into a single
# array (concatenated along axis 0) and save it next to that directory.
data_root = '/ssd2/quickdraw/'

# ``glob.glob`` returns matches in arbitrary order; sorting makes the row
# order of the merged array reproducible across runs and machines.
files = sorted(glob.glob(os.path.join(data_root, '*.npy')))

if not files:
    # Fail with a clear message instead of the opaque error that
    # ``np.concatenate`` raises for an empty list.
    raise ValueError(f"no .npy files found in {data_root!r}")

data_list = [np.load(path) for path in files]

data = np.concatenate(data_list, axis=0)

# The output is written to '/ssd2', outside ``data_root``, so a later run
# does not pick the merged file up as one of its inputs.
np.save(os.path.join('/ssd2', 'quickdraw_all_data.npy'), data)