import os
import pickle

# Container for every split that gets pickled at the end of the script.
data = dict()

# The dataset is expected in ./tiny-imagenet-200 relative to the current
# working directory.
root = os.getcwd() + '/tiny-imagenet-200'

# Map each entry of the train folder (taken in sorted order) to a consecutive
# integer label starting at 0.
cls2lbl = {
    name: index
    for index, name in enumerate(sorted(os.listdir(root + '/train')))
}
# Number of labels assigned.
counter = len(cls2lbl)

# Training split: visit every class folder in sorted order and record each
# image path next to the integer label of the folder it was found in.
train_dir = root + '/train'
train_imgs = []
train_labels = []
for cls in sorted(os.listdir(train_dir)):
    img_dir = train_dir + '/' + cls + '/images'
    names = sorted(os.listdir(img_dir))
    train_imgs.extend(img_dir + '/' + name for name in names)
    # One label per image, so both lists stay index-aligned.
    train_labels.extend(cls2lbl[cls] for _ in names)
data['train'] = {'imgs': train_imgs, 'labels': train_labels}

# Validation split. All validation images sit in a single folder; the class of
# each one is read from val_annotations.txt, a tab-separated file whose first
# field is the image file name and whose second field is the class folder
# name. Any further fields on a line are ignored.
valimg2lbl = {}
# The context manager closes the file even if parsing a line raises.
with open(root + '/val/val_annotations.txt') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line) rather than
            # failing with an IndexError on the split below.
            continue
        fields = line.rstrip().split('\t')
        valimg2lbl[fields[0]] = cls2lbl[fields[1]]

val_dir = root + '/val/images'
val_names = sorted(os.listdir(val_dir))
data['val'] = {
    'imgs': [val_dir + '/' + name for name in val_names],
    # Raises KeyError if an image has no entry in the annotation file.
    'labels': [valimg2lbl[name] for name in val_names],
}

# Test split: only image paths are collected here. The labels list is filled
# with zeros as padding so that it has one entry per image.
test_dir = root + '/test/images'
test_imgs = [test_dir + '/' + name for name in sorted(os.listdir(test_dir))]
data['test'] = {
    'imgs': test_imgs,
    # Size the padding from the images actually found instead of assuming
    # exactly 10000 files, so 'imgs' and 'labels' can never get out of sync.
    'labels': [0] * len(test_imgs),
}

# Report the size of each split, then serialize the whole index.
n_train = len(data['train']['labels'])
n_val = len(data['val']['labels'])
n_test = len(data['test']['imgs'])
print("Train images: {}".format(n_train))
print("Val images: {}".format(n_val))
print("Test images: {}".format(n_test))

# The pickle is written to the current working directory using the highest
# protocol supported by the running interpreter.
with open('tiny_imagenet.pickle', 'wb') as out_file:
    pickle.dump(data, out_file, pickle.HIGHEST_PROTOCOL)