import torch
import os


# Merge per-dataset image-prototype files into one combined .pt file.
# Datasets whose per-dataset prototype files are merged.
selected_datasets = ['birdsnap', 'caltech101', 'eurosat',
                     'fgvcaircraft', 'flowers102', 'food101',
                     'gtsrb', 'imagenet', 'oxfordpets', 'resisc45',
                     'stanfordcars', 'pascalvoc2007', 'sun397', 'ucf101']
print(len(selected_datasets))

# Directory holding one '{dataset}_{n_shot}_shot.pt' file per dataset; the
# combined file is written back into the same directory.
root_path = '/notebooks/classwise_collection/class_embeddings/imgprotos/'
n_shot = 128

save_dict = {}
# Number of entries loaded across all files, counted before merging, so it
# exceeds len(save_dict) exactly when keys collide across datasets.
n_total = 0
# key -> datasets that provided it, in load order, so overwrites can be reported.
key_sources = {}
for s in selected_datasets:
    # Each file is loaded as a mapping and merged into save_dict.
    # NOTE(review): keys are presumably class names — confirm against the writer.
    a = torch.load(os.path.join(root_path, f'{s}_{n_shot}_shot.pt'))
    n_total += len(a)
    for key in a:
        key_sources.setdefault(key, []).append(s)
    # On duplicate keys, dict.update keeps the value from the later dataset.
    save_dict.update(a)

print(n_total, len(save_dict))

# dict.update silently discards earlier values on collisions; list them so
# the count mismatch printed above is explainable.
collisions = {k: srcs for k, srcs in key_sources.items() if len(srcs) > 1}
if collisions:
    print(f'WARNING: {len(collisions)} duplicate key(s) across datasets; '
          'the value from the last dataset listed is kept:')
    for k, srcs in collisions.items():
        print(f'  {k!r}: {srcs}')

# Derive the output name from the dataset list so it stays accurate if the
# list changes (yields 'combined_14datasets_imgprotos_128_shot.pt' today).
out_name = f'combined_{len(selected_datasets)}datasets_imgprotos_{n_shot}_shot.pt'
torch.save(save_dict, os.path.join(root_path, out_name))