import numpy as np
import tensorflow as tf
import os
import gzip
from utils.download_tools import down_from_url


# Local directory that holds the downloaded MNIST archives.
MNIST_DIR = "resources/mnist/"

# On-disk archive name for each split, keyed by the split identifier used
# throughout this module.
MNIST_FILES = {
    "train_imgs": "train-images-idx3-ubyte.gz",
    "train_labels": "train-labels-idx1-ubyte.gz",
    "test_imgs": "t10k-images-idx3-ubyte.gz",
    "test_labels": "t10k-labels-idx1-ubyte.gz",
}

# Every archive is published under one base URL with the same name it is
# stored under locally, so the download URLs are derived from MNIST_FILES.
# NOTE(review): yann.lecun.com has at times refused dataset downloads; a
# mirror may be needed here -- confirm before relying on it.
_MNIST_BASE_URL = "http://yann.lecun.com/exdb/mnist/"
MNIST_URLS = {split: _MNIST_BASE_URL + fname for split, fname in MNIST_FILES.items()}

IMAGE_SIZE = 28  # images are IMAGE_SIZE x IMAGE_SIZE pixels
NUM_TRAIN_IMAGES = 60_000
NUM_TEST_IMAGES = 10_000


# IDX magic numbers (first big-endian uint32 of each file): 0x00000803 marks a
# 3-D array of unsigned bytes (the image files), 0x00000801 a 1-D array of
# unsigned bytes (the label files).
_IDX_IMAGES_MAGIC = 2051
_IDX_LABELS_MAGIC = 2049


def _read_mnist_images(path, num_images):
    """Read the first ``num_images`` images from a gzipped MNIST IDX3 file.

    Returns a float32 array of shape (num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
    holding the raw pixel bytes 0-255 (not normalized).

    Raises:
        ValueError: if the 16-byte header is missing or has the wrong magic
            number, the images are not IMAGE_SIZE x IMAGE_SIZE, the header
            reports fewer than ``num_images`` images, or the pixel payload is
            shorter than requested.
    """
    with gzip.open(path, "rb") as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError("{}: truncated IDX image header".format(path))
        # Header layout: magic, image count, rows, cols -- all big-endian uint32.
        magic, count, rows, cols = np.frombuffer(header, dtype=">u4").tolist()
        if magic != _IDX_IMAGES_MAGIC:
            raise ValueError("{}: bad magic number {} (expected {})".format(
                path, magic, _IDX_IMAGES_MAGIC))
        if (rows, cols) != (IMAGE_SIZE, IMAGE_SIZE):
            raise ValueError("{}: images are {}x{}, expected {}x{}".format(
                path, rows, cols, IMAGE_SIZE, IMAGE_SIZE))
        if count < num_images:
            raise ValueError("{}: file holds {} images, {} requested".format(
                path, count, num_images))
        expected_bytes = num_images * IMAGE_SIZE * IMAGE_SIZE
        buf = f.read(expected_bytes)
    # The header may be intact while the pixel payload is shorter than promised;
    # catch that here instead of letting reshape fail with an opaque message.
    if len(buf) != expected_bytes:
        raise ValueError("{}: expected {} pixel bytes, got {}".format(
            path, expected_bytes, len(buf)))
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    return data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)


def _read_mnist_labels(path):
    """Read all labels from a gzipped MNIST IDX1 file as an int64 array.

    Raises:
        ValueError: if the 8-byte header is missing or has the wrong magic
            number, or the number of label bytes differs from the count the
            header declares.
    """
    with gzip.open(path, "rb") as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError("{}: truncated IDX label header".format(path))
        # Header layout: magic, label count -- both big-endian uint32.
        magic, count = np.frombuffer(header, dtype=">u4").tolist()
        if magic != _IDX_LABELS_MAGIC:
            raise ValueError("{}: bad magic number {} (expected {})".format(
                path, magic, _IDX_LABELS_MAGIC))
        buf = f.read()
    if len(buf) != count:
        raise ValueError("{}: header declares {} labels, file holds {}".format(
            path, count, len(buf)))
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)


def create_mnist():
    """Download the MNIST archives into MNIST_DIR and load them as arrays.

    Each archive listed in MNIST_FILES is fetched from MNIST_URLS via
    ``down_from_url`` and then parsed.

    Returns:
        dict with keys
          - "train_imgs": float32, shape (NUM_TRAIN_IMAGES, IMAGE_SIZE, IMAGE_SIZE, 1),
            pixel values scaled into [0, 1];
          - "test_imgs": same layout with NUM_TEST_IMAGES images;
          - "train_labels" / "test_labels": one-hot rows of length 10 produced
            by ``tf.one_hot``, one per image.

    Raises:
        ValueError: if any archive is malformed or a split's image and label
            counts disagree.
    """
    os.makedirs(MNIST_DIR, exist_ok=True)
    # Derive every path from the module constants so the download location
    # and the read location cannot drift apart.
    paths = {key: os.path.join(MNIST_DIR, name) for key, name in MNIST_FILES.items()}
    for key, path in paths.items():
        down_from_url(MNIST_URLS[key], path)

    train_imgs = _read_mnist_images(paths["train_imgs"], NUM_TRAIN_IMAGES) / 255.
    test_imgs = _read_mnist_images(paths["test_imgs"], NUM_TEST_IMAGES) / 255.

    train_labels = _read_mnist_labels(paths["train_labels"])
    test_labels = _read_mnist_labels(paths["test_labels"])

    # Images and labels are paired by index, so their counts must agree.
    for split, imgs, labels in (("train", train_imgs, train_labels),
                                ("test", test_imgs, test_labels)):
        if len(labels) != len(imgs):
            raise ValueError("{} split: {} images but {} labels".format(
                split, len(imgs), len(labels)))

    train_labels = tf.one_hot(train_labels, depth=10).numpy()
    test_labels = tf.one_hot(test_labels, depth=10).numpy()

    return {
        "train_imgs": train_imgs,
        "test_imgs": test_imgs,
        "train_labels": train_labels,
        "test_labels": test_labels
    }
