import torch
from torch.utils.data import DataLoader
from tinyModels import ResNetExtractor
import torchvision.transforms as transforms
from datasets.folder import ImageFolder
from tqdm import tqdm
import pickle
import torchvision.transforms as transforms

# Per-channel normalization statistics (one entry per image channel) that are
# passed to transforms.Normalize in dump_features.
# NOTE(review): presumably the mean/std of the training split the model was
# trained on, in RGB order — confirm the source of these numbers.
TRAIN_MEAN = [0.4802, 0.4481, 0.3975]
TRAIN_STD = [0.2302, 0.2265, 0.2262]

def dump_features(device, path, root="train", output="features.bin",
                  batch_size=1, num_workers=5):
    """Extract model features for every image under *root* and pickle them.

    A ``ResNetExtractor`` is built, its weights are loaded from the
    ``state_dict`` checkpoint at *path*, and it is run in inference mode over
    ``ImageFolder(root)``. The pipeline is deterministic: resize, center-crop
    to 64 px, then normalize. The result is pickled to *output* as a dict with
    keys:

    * ``"features"`` -- list with one numpy array per sample, in dataset order
    * ``"label"``    -- list of the corresponding targets (Python scalars)
    * ``"path"``     -- the checkpoint path the features were produced with

    Args:
        device: torch device (or device string) to run the model on.
        path: checkpoint file holding the model's ``state_dict``.
        root: dataset root directory passed to ``ImageFolder``.
        output: destination file for the pickled results.
        batch_size: samples per forward pass. In eval mode this should only
            affect speed and memory, assuming the extractor has no
            batch-dependent layers at inference time (TODO confirm).
        num_workers: number of DataLoader worker processes.
    """
    model = ResNetExtractor()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.to(device)
    # Inference only: eval() makes BatchNorm use its running statistics and
    # disables dropout. Without it, features depend on the batch they are in.
    model.eval()

    transform = transforms.Compose([
        # Shorter side -> int(64 / 0.875) = 73 px, so the 64-px center crop
        # keeps ~87.5% of it.
        transforms.Resize(int(64 / 0.875)),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(TRAIN_MEAN, TRAIN_STD),
    ])
    data_train = ImageFolder(root=root, transform=transform)
    # No shuffling (DataLoader default), so output order matches dataset order.
    trainloader = DataLoader(data_train, batch_size=batch_size,
                             num_workers=num_workers, pin_memory=True)

    targets, features = [], []
    # no_grad: we never backprop, so don't build an autograd graph.
    with torch.no_grad():
        # NOTE(review): the project's ImageFolder appears to yield
        # (index, image, target) triples; the index is unused here.
        for _idx, img, target in tqdm(trainloader):
            targets.extend(target.tolist())
            feature = model(img.to(device)).cpu().numpy()
            # Iterating a numpy array walks its first axis: one array per sample.
            features.extend(feature)

    save_dict = {"features": features, "label": targets, "path": path}
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(output, "wb") as file:
        pickle.dump(save_dict, file)


if __name__ == "__main__":
    # Command-line entry point: dump features for the training set.
    import argparse

    parser = argparse.ArgumentParser(
        description="Dump ResNetExtractor features for the training set.")
    # The checkpoint has no sensible default (an empty path always makes
    # torch.load fail), so it is a required positional argument.
    parser.add_argument("path", help="model state_dict checkpoint to load")
    parser.add_argument("--device", default="cuda:7",
                        help="torch device to run on (default: %(default)s)")
    args = parser.parse_args()
    dump_features(args.device, args.path)