import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from model import ResNet50Extractor
from datasets.dataset import CIFAR100
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
import torchvision.transforms as transforms

def dump_features(device, path, out_path="features.bin", batch_size=64, num_workers=5):
    """Extract ResNet50Extractor features for the CIFAR-100 training split and pickle them.

    Args:
        device: Device (e.g. ``"cuda:0"`` or ``"cpu"``) the model runs on.
        path: Path to a ``state_dict`` checkpoint for ``ResNet50Extractor``.
            It is read with ``torch.load``, which unpickles, so only load
            checkpoints you trust.
        out_path: File the pickled result is written to (default ``"features.bin"``).
        batch_size: DataLoader batch size (default 64).
        num_workers: DataLoader worker processes (default 5).

    The pickle holds a dict with keys:
        ``"features"``: list of per-sample numpy arrays (one row of the model
        output per sample), in dataset iteration order;
        ``"label"``: list of targets in the same order;
        ``"path"``: the checkpoint path that was used.
    """
    model = ResNet50Extractor()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model = model.to(device)
    # Inference mode: BatchNorm uses its stored running statistics (and does
    # not update them), and dropout is disabled. Without this, features would
    # depend on batch composition.
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                             (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)),
    ])
    data_train = CIFAR100(root="./data", train=True, transform=transform)

    # DataLoader does not shuffle by default, so features/labels stay aligned
    # with dataset order.
    trainloader = DataLoader(data_train, batch_size=batch_size,
                             num_workers=num_workers, pin_memory=True)

    targets, features = [], []
    # No autograd graph is needed for extraction; this saves memory and time.
    with torch.no_grad():
        for _idx, img, target in tqdm(trainloader):
            targets.extend(target.tolist())
            feature = model(img.to(device)).cpu().numpy()
            # Iterating a numpy array yields its rows: one entry per sample.
            features.extend(feature)

    save_dict = {"features": features, "label": targets, "path": path}
    # Context manager guarantees the file is flushed and closed.
    with open(out_path, "wb") as file:
        pickle.dump(save_dict, file)

if __name__ == "__main__":
    # Script entry point: the checkpoint path and device come from the command
    # line instead of being hard-coded (an empty path always failed in torch.load).
    import argparse

    parser = argparse.ArgumentParser(
        description="Dump ResNet50Extractor features for the CIFAR-100 training set."
    )
    parser.add_argument("path", help="state_dict checkpoint for ResNet50Extractor")
    parser.add_argument(
        "--device",
        # Fall back to CPU so the script still runs on hosts without CUDA.
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="device to run the model on (default: cuda:0 if available, else cpu)",
    )
    args = parser.parse_args()
    dump_features(args.device, args.path)