import glob
import os

import numpy as np
from PIL import Image


def get_tiny_imagenet_data(root='./data/tiny-imagenet'):
    """Load every Tiny ImageNet training image as an RGB NumPy array.

    Scans ``<root>/train/<class>/images/*.JPEG``, converts each image to RGB
    and collects the pixel arrays. Class labels are not returned.

    Args:
        root: Dataset root directory. Defaults to ``./data/tiny-imagenet``,
            the path this loader has always used.

    Returns:
        A list of arrays of shape (height, width, 3), one per image, ordered
        by sorted class-folder name and then by sorted filename.
    """
    all_images = []

    # Sort so the output order is reproducible; raw glob order is arbitrary
    # and depends on the filesystem.
    folders = sorted(glob.glob(os.path.join(root, 'train', '*')))
    for folder in folders:
        paths = sorted(glob.glob(os.path.join(folder, 'images', '*.JPEG')))
        for path in paths:
            # The context manager closes the handle opened by Image.open
            # (the original code closed only the converted copy). convert()
            # returns a new, fully loaded image, so the array stays valid.
            with Image.open(path) as train_image:
                all_images.append(np.asarray(train_image.convert('RGB')))

    print('tinyimagenet dataset size: {}'.format(len(all_images)))

    return all_images


def get_tiny_imagenet_test_data(root='./data/tiny-imagenet'):
    """Load Tiny ImageNet validation images as RGB NumPy arrays.

    Reads ``<root>/val/val_annotations.txt``, a tab-separated file whose first
    field is an image filename and whose second field is its class id. It then
    loads each listed image from ``<root>/val/images``.

    Images are emitted class by class, in sorted order of the
    ``<root>/train/*`` folder names. Within a class they keep annotation-file
    order. Annotations whose class has no training folder are skipped, as are
    blank or malformed lines (fewer than two fields). Class labels are not
    returned.

    Args:
        root: Dataset root directory. Defaults to ``./data/tiny-imagenet``,
            the path this loader has always used.

    Returns:
        A list of arrays of shape (height, width, 3), one per matched image.

    Raises:
        OSError: If the annotation file or a listed image cannot be opened.
    """
    annotations_path = os.path.join(root, 'val', 'val_annotations.txt')

    # Index the annotations once (class id -> filenames, in file order)
    # instead of rescanning every line for every class folder.
    files_by_category = {}
    with open(annotations_path, encoding='utf-8') as f:
        for line in f:
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) < 2:
                continue  # blank or malformed line
            files_by_category.setdefault(fields[1], []).append(fields[0])

    image_dir = os.path.join(root, 'val', 'images')
    all_test_images = []
    # Sorted for a reproducible order; raw glob order is filesystem-dependent.
    for folder in sorted(glob.glob(os.path.join(root, 'train', '*'))):
        # basename handles the OS path separator (split('/') fails on Windows).
        category_name = os.path.basename(folder)
        for filename in files_by_category.get(category_name, []):
            # Close the file handle opened by Image.open; convert() returns an
            # independent, loaded copy, so the array remains valid.
            with Image.open(os.path.join(image_dir, filename)) as image:
                all_test_images.append(np.asarray(image.convert('RGB')))

    return all_test_images