# coding=utf-8
import os
import sys
import tarfile
import pickle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image as processor


# Remote location(s) of the CIFAR-10 "python version" archive.
URLS = [
    'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
]

# Name of the archive inside origin_dir, and the folder inside the archive
# that holds the pickled batch files.
DATASET_FILE = 'cifar-10-python.tar.gz'
FOLDER_NAME = 'cifar-10-batches-py'

# Batch files (relative to FOLDER_NAME) making up each split.
TRAINING_SET_DATA_FILES = [
    'data_batch_1',
    'data_batch_2',
    'data_batch_3',
    'data_batch_4',
    'data_batch_5',
]
TESTING_SET_DATA_FILES = [
    'test_batch',
]


class Image:
    """A single CIFAR-10 sample plus the metadata needed to show or save it.

    Attributes:
        data: pixel array (load_data_file builds it height x width x channel;
            presumably the same for any other caller — confirm).
        label: class label exactly as stored in the batch file.
        file_name: image file name; save_image also uses it as the on-disk name.
        batch_file: name of the batch file the image was read from.
    """

    def __init__(self, data, label, file_name, batch_file):
        self.data, self.label = data, label
        self.file_name, self.batch_file = file_name, batch_file

    def get_data(self):
        """Return the pixel array."""
        return self.data

    def get_label(self):
        """Return the class label."""
        return self.label

    def get_file_name(self):
        """Return the image file name."""
        return self.file_name

    def get_batch_file(self):
        """Return the name of the originating batch file."""
        return self.batch_file

    def print_details(self):
        """Print name, label, source batch, array shape and element count."""
        pixels = self.get_data()
        rows = (
            ("Image:", self.get_file_name()),
            ("Class Label:", self.get_label()),
            ("From batch file:", self.get_batch_file()),
            ("Shape:", np.shape(pixels)),
            ("Size:", np.size(pixels)),
        )
        for caption, value in rows:
            print(caption, value)

    def show_image(self):
        """Display the image in a blocking matplotlib window titled '<name> - <label>'."""
        name = self.get_file_name()
        plt.figure(name)
        plt.imshow(self.get_data())
        plt.axis('off')
        plt.title("%s - %s" % (name, self.get_label()))
        plt.show()

    def save_image(self, save_dir):
        """Write the pixels to ``save_dir/<file_name>`` via PIL (format from the extension)."""
        target = os.path.join(save_dir, self.get_file_name())
        processor.fromarray(self.get_data()).save(target)


def download(dir, urls=None):
    """Download every dataset archive into ``dir``.

    Uses the third-party ``wget`` package when it is importable and falls
    back to ``urllib.request`` otherwise. The urllib fallback names each file
    after the last path component of its URL; wget picks the name itself
    (presumably the same one — confirm if it matters).

    Args:
        dir: destination directory; expected to exist already.
        urls: URLs to fetch. Defaults to the module-level ``URLS``; pass a
            different list to use a mirror.
    """
    if urls is None:
        urls = URLS

    # Keep the try limited to the import: an ImportError raised *during* a
    # wget download must not silently restart everything through urllib.
    try:
        import wget
    except ImportError:
        wget = None
        import urllib.request

    total = len(urls)
    for number, url in enumerate(urls, start=1):
        # 1-based progress so the last file reads "N / N", not "N-1 / N".
        print('\nDownloading Dataset File %d / %d...' % (number, total))
        if wget is not None:
            wget.download(url, out=dir)
        else:
            urllib.request.urlretrieve(url, os.path.join(dir, url.split('/')[-1]))
    print('\nAll Downloaded.')


def load_data_file(file, filename):
    """Unpickle one CIFAR-10 batch and wrap each row in an ``Image``.

    Args:
        file: binary file-like object containing one pickled batch dict.
        filename: batch name; printed as progress and stored on every Image.

    Returns:
        list of Image objects, one per row of the batch's ``b'data'`` array,
        in row order.
    """
    print('  %s' % filename)
    # NOTE(security): pickle.load can execute arbitrary code. Only feed this
    # files from the trusted CIFAR-10 archive, never from untrusted sources.
    # encoding='bytes' keeps legacy str objects as bytes, hence the bytes
    # keys below and the .decode() on file names.
    batch = pickle.load(file, encoding='bytes')
    labels = batch[b'labels']
    names = batch[b'filenames']

    # Each row holds 3 colour planes of 32x32 values; reshape the whole batch
    # once and move the channel axis last -> (rows, 32, 32, 3). Each element
    # is the same view the per-row reshape/transpose used to produce.
    pixels = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    return [
        Image(image, labels[row], names[row].decode(), filename)
        for row, image in enumerate(pixels)
    ]


def load_data(tar_file, file_list):
    """Load several batch files from an open CIFAR-10 tar archive.

    Args:
        tar_file: open ``tarfile.TarFile`` containing the batches.
        file_list: member names (paths inside the archive) to read, in order.

    Returns:
        one flat list of the Image objects from all listed batches, in order.

    Raises:
        KeyError: if a name in ``file_list`` is not a member of the archive.
    """
    images_list = []
    for file_name in file_list:
        # extractfile hands back a file object that should be closed; the
        # with-block releases it as soon as the batch has been unpickled.
        with tar_file.extractfile(file_name) as file:
            images_list.extend(load_data_file(file, file_name))
    print('...OK!\n')
    return images_list


def check_path(path):
    """Make sure ``path`` exists, creating it if necessary.

    Missing parent directories are created too. Without this, the default
    ``--origin_dir`` ``./_DOWNLOADS/CIFAR-10`` fails on a fresh checkout
    because ``_DOWNLOADS`` does not exist yet.

    Args:
        path: directory path to check or create.

    Returns:
        True if ``path`` already existed (as a file or directory), False if
        it had to be created.
    """
    if os.path.exists(path):
        return True
    os.makedirs(path)
    return False

def save_images(dir, images):
    """Save every image into a per-label sub-directory of ``dir``.

    Each image ends up at ``<dir>/<label>/<file name>``. A single-line
    progress counter is printed while saving.

    Args:
        dir: output root directory; expected to exist already.
        images: sequence of Image objects (needs ``len``).
    """
    total_images = len(images)
    # Label directories already ensured; avoids one filesystem check per image.
    ready_dirs = set()
    for i, image in enumerate(images, start=1):
        print("\rSaving image %d / %d" % (i, total_images), end="", flush=True)
        save_dir = os.path.join(dir, str(image.get_label()))
        if save_dir not in ready_dirs:
            check_path(save_dir)
            ready_dirs.add(save_dir)
        image.save_image(save_dir)
    print('...OK!')


def main(origin_dir, target_dir):
    """Download CIFAR-10 if needed and export it as per-label image folders.

    Images are written to ``<target_dir>/train/<label>/<file name>`` and
    ``<target_dir>/test/<label>/<file name>``.

    Args:
        origin_dir: directory holding (or to receive) ``DATASET_FILE``.
        target_dir: root directory for the exported images.
    """
    archive_path = os.path.join(origin_dir, DATASET_FILE)
    check_path(origin_dir)
    # Decide based on the archive itself, not on whether origin_dir was just
    # created. An existing origin_dir without the archive (created by hand,
    # or left behind by a failed download) would otherwise skip the download
    # and crash in tarfile.open below.
    if not os.path.isfile(archive_path):
        download(origin_dir)

    check_path(target_dir)
    train_dir = os.path.join(target_dir, 'train')
    check_path(train_dir)
    test_dir = os.path.join(target_dir, 'test')
    check_path(test_dir)

    # Read batches straight out of the .tar.gz; the with-block closes the
    # archive even if loading or saving raises.
    with tarfile.open(archive_path, encoding='utf-8') as tar:
        print('\nLoading Training Data:')
        training_data = load_data(tar, [FOLDER_NAME + '/' + name for name in TRAINING_SET_DATA_FILES])
        save_images(train_dir, training_data)
        # Release the training images before loading the test split to keep
        # peak memory down.
        del training_data

        print('\nLoading Testing Data:')
        testing_data = load_data(tar, [FOLDER_NAME + '/' + name for name in TESTING_SET_DATA_FILES])
        save_images(test_dir, testing_data)

    print('\nFinished!')

########################################################################################################################


if __name__ == '__main__':
    import argparse

    # Command-line entry point; both directories have defaults, so the script
    # can be run with no arguments at all.
    cli = argparse.ArgumentParser(description='Extract all the data from CIFAR-10 dataset into image folders')
    option_specs = (
        ('--origin_dir', './_DOWNLOADS/CIFAR-10',
         'Directory storing / to store the original dataset files'),
        ('--target_dir', './CIFAR-10_Dataset',
         'Directory to store the output image data'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    options = cli.parse_args()

    main(options.origin_dir, options.target_dir)