import numpy as np

import tensorflow.examples.tutorials.mnist.input_data as mnist_input
from tensorflow.keras.datasets import cifar100, cifar10


def download_raw_dataset(raw_data_name):
    """Placeholder for fetching a raw dataset by name.

    Not implemented yet: ``raw_data_name`` is ignored and nothing is
    downloaded. Always returns ``None``.
    """
    return None

def load_mnist(data_path, one_hot=True, reshape=False):
    """Load MNIST via the TensorFlow tutorial ``read_data_sets`` helper.

    Args:
        data_path: directory passed as ``train_dir`` to ``read_data_sets``.
        one_hot: forwarded to ``read_data_sets`` (one-hot label encoding).
        reshape: forwarded to ``read_data_sets``.

    Returns:
        ``((train_images, train_labels), (test_images, test_labels))``.

    NOTE(review): ``tensorflow.examples.tutorials`` appears to exist only in
    TF 1.x installs — confirm the target TensorFlow version.
    """
    datasets = mnist_input.read_data_sets(
        train_dir=data_path, one_hot=one_hot, reshape=reshape)
    train_split, test_split = datasets.train, datasets.test
    return ((train_split.images, train_split.labels),
            (test_split.images, test_split.labels))

def load_cifar10(data_path, one_hot=True, reshape=False):
    """Load CIFAR-10 via ``tensorflow.keras.datasets.cifar10``.

    Args:
        data_path: unused — ``cifar10.load_data()`` is called without a path.
            Kept so the signature matches ``load_mnist``.
        one_hot: if True, labels are converted with ``get_one_hot`` into
            10-class one-hot arrays.
        reshape: unused; kept so the signature matches ``load_mnist``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as produced by
        ``cifar10.load_data()``, with the label arrays one-hot encoded when
        ``one_hot`` is True.
    """
    train_data, test_data = cifar10.load_data()
    if not one_hot:
        return train_data, test_data

    (x_train, y_train), (x_test, y_test) = train_data, test_data
    return ((x_train, get_one_hot(y_train, 10)),
            (x_test, get_one_hot(y_test, 10)))


def load_cifar100(data_path, one_hot=True, reshape=False, label_mode='fine'):
    """Load CIFAR-100 via ``tensorflow.keras.datasets.cifar100``.

    Args:
        data_path: unused — ``cifar100.load_data()`` is called without a
            path. Kept so the signature matches ``load_mnist``.
        one_hot: if True, labels are converted with ``get_one_hot`` into
            one-hot arrays (100 classes for 'fine', 20 for 'coarse').
        reshape: unused; kept so the signature matches ``load_mnist``.
        label_mode: ``'fine'`` (default, the previous hard-coded behaviour)
            or ``'coarse'``; forwarded to ``cifar100.load_data``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as produced by
        ``cifar100.load_data()``, with the label arrays one-hot encoded when
        ``one_hot`` is True.

    Raises:
        ValueError: if ``label_mode`` is not ``'fine'`` or ``'coarse'``
            (checked before any download is attempted).
    """
    num_classes = {'fine': 100, 'coarse': 20}.get(label_mode)
    if num_classes is None:
        raise ValueError(
            "label_mode must be 'fine' or 'coarse', got %r" % (label_mode,))

    train_data, test_data = cifar100.load_data(label_mode=label_mode)
    if not one_hot:
        return train_data, test_data

    # The former debug print of the label shape was removed: loaders should
    # not write to stdout.
    (x_train, y_train), (x_test, y_test) = train_data, test_data
    return ((x_train, get_one_hot(y_train, num_classes)),
            (x_test, get_one_hot(y_test, num_classes)))

def get_one_hot(targets, nb_classes):
    """Return a float one-hot encoding of integer class labels.

    Args:
        targets: array-like of integer labels (list, scalar or ndarray). A
            trailing axis of length 1, as in labels shaped ``(N, 1)``, is
            treated as a label column and dropped. Both ``(N, 1)`` and
            ``(N,)`` therefore yield ``(N, nb_classes)``.
        nb_classes: number of classes; labels must lie in ``[0, nb_classes)``.

    Returns:
        ndarray of shape ``labels.shape + (nb_classes,)``, where ``labels`` is
        ``targets`` with any trailing singleton label axis removed. The dtype
        is that of ``np.eye`` (float64).

    Raises:
        ValueError: if any label is negative. NumPy indexing would otherwise
            wrap it around silently.
        IndexError: if any label is >= ``nb_classes`` (raised by NumPy).
    """
    labels = np.asarray(targets)
    # Drop only a trailing length-1 label column. The previous blanket
    # np.squeeze also removed a batch axis of size 1, and removed the class
    # axis when nb_classes == 1.
    if labels.ndim > 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]
    if labels.size and labels.min() < 0:
        raise ValueError("one-hot labels must be non-negative")
    return np.eye(nb_classes)[labels]