import gzip
import hashlib
import os
import pickle
import struct
import tarfile
import urllib.request
from typing import Literal

import numpy
import scipy.io as sio
from tqdm import tqdm


def verify_file(file_path, digest):
    """Check whether a file exists and matches an expected SHA-256 digest.

    Args:
        file_path: Path of the file to check.
        digest: Expected SHA-256 digest as a hexadecimal string. The
            comparison ignores letter case and surrounding whitespace.

    Returns:
        True if ``file_path`` is a regular file whose SHA-256 hex digest
        equals ``digest``; False if the file is missing, is not a regular
        file, or has a different digest.
    """
    if not os.path.isfile(file_path):
        return False

    hasher = hashlib.sha256()
    with open(file_path, 'rb') as file:
        # Hash in 1 MiB pieces so large archives are never held in memory at once.
        while chunk := file.read(1024 * 1024):
            hasher.update(chunk)

    # hexdigest() is always lower-case, so normalise the expected value to match.
    expected = digest.strip().lower() if isinstance(digest, str) else digest
    return hasher.hexdigest() == expected


def download_file(file_path, url):
    """Download ``url`` to ``file_path`` while showing a tqdm progress bar.

    The payload is first written to ``<file_path>.part`` and only moved onto
    ``file_path`` once the transfer has finished, so an interrupted download
    never leaves a truncated file at the destination.

    Args:
        file_path: Destination path of the downloaded file.
        url: URL to fetch with ``urllib.request.urlopen``.

    Raises:
        OSError: If fewer bytes were received than the ``Content-Length``
            header announced. Errors raised by ``urlopen`` or by the file
            operations propagate unchanged.
    """
    partial_path = f"{file_path}.part"
    with urllib.request.urlopen(url) as response:
        print(f"Downloading {os.path.basename(file_path)} from the following URL: {url}.")
        # A missing header yields 0; pass None so tqdm shows an open-ended counter
        # instead of a bar with a total of zero.
        expected_size = int(response.headers.get("Content-Length", 0)) or None
        received = 0
        try:
            with tqdm(total=expected_size, unit='B', unit_scale=True) as progress_bar:
                with open(partial_path, 'wb') as file:
                    while data := response.read(64 * 1024):
                        file.write(data)
                        received += len(data)
                        progress_bar.update(len(data))
            if expected_size is not None and received < expected_size:
                raise OSError(
                    f"Incomplete download of {url}: got {received} of {expected_size} bytes."
                )
            os.replace(partial_path, file_path)
        finally:
            # After a successful os.replace the partial file no longer exists;
            # on any failure this removes the incomplete data.
            if os.path.exists(partial_path):
                os.remove(partial_path)

def _mnist_read_images(file_path):
    """Read a gzip-compressed IDX image file into a ``(number, 1, row, column)`` uint8 array."""
    with gzip.open(file_path) as data_file:
        # Header: four big-endian 32-bit integers (magic, image count, rows, columns).
        magic, number, row, column = struct.unpack(">iiii", data_file.read(16))
        if magic != 2051:
            raise ValueError(f"{file_path} is not an IDX image file (magic number {magic}).")
        return numpy.frombuffer(data_file.read(), numpy.uint8).reshape(number, 1, row, column)


def _mnist_read_labels(file_path):
    """Read a gzip-compressed IDX label file into a one-dimensional uint8 array."""
    with gzip.open(file_path) as data_file:
        # Header: two big-endian 32-bit integers (magic, label count).
        magic, _number = struct.unpack(">ii", data_file.read(8))
        if magic != 2049:
            raise ValueError(f"{file_path} is not an IDX label file (magic number {magic}).")
        return numpy.frombuffer(data_file.read(), numpy.uint8)


def MNIST(path: str):
    """Load the MNIST dataset, downloading any missing or corrupted file into ``path``.

    Each of the four archives is checked against its pinned SHA-256 digest and
    downloaded again when the check fails.

    Args:
        path: Directory in which the ``.gz`` archives are stored. It is
            created if it does not exist.

    Returns:
        A tuple ``(data_train, label_train, data_test, label_test)``. The data
        arrays are uint8 with shape ``(number, 1, row, column)`` as given by
        each file header; the label arrays are one-dimensional uint8. All of
        them come from ``numpy.frombuffer`` and are therefore read-only.

    Raises:
        ValueError: If a file's header does not carry the expected IDX magic number.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

    files = {
        "train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
        "train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
        "t10k-images-idx3-ubyte.gz": "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
        "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
    }
    for filename, digest in files.items():
        file_path = os.path.join(path, filename)
        if not verify_file(file_path, digest):
            download_file(file_path, "http://yann.lecun.com/exdb/mnist/" + filename)

    data_train = _mnist_read_images(os.path.join(path, "train-images-idx3-ubyte.gz"))
    label_train = _mnist_read_labels(os.path.join(path, "train-labels-idx1-ubyte.gz"))
    data_test = _mnist_read_images(os.path.join(path, "t10k-images-idx3-ubyte.gz"))
    label_test = _mnist_read_labels(os.path.join(path, "t10k-labels-idx1-ubyte.gz"))

    return data_train, label_train, data_test, label_test


def _cifar_load_batch(file_path, key):
    """Unpickle one CIFAR batch file and return ``(images, labels)``.

    ``images`` is the batch's ``b"data"`` entry reshaped to ``(-1, 3, 32, 32)``;
    ``labels`` is the entry stored under ``key``, returned as found in the file.
    """
    with open(file_path, "rb") as data_file:
        # SECURITY NOTE: pickle can execute arbitrary code while loading. These
        # files come from a downloaded archive, so only use trusted sources.
        batch = pickle.load(data_file, encoding="bytes")
    return batch[b"data"].reshape((-1, 3, 32, 32)), batch[key]


def CIFAR(path: str, classes: Literal[10, 100]):
    """Load CIFAR-10 or CIFAR-100, downloading and extracting it into ``path`` if needed.

    The archive is only fetched and extracted when at least one of the
    expected batch files is missing from the extracted folder.

    Args:
        path: Directory holding the archive and the extracted batch folder.
            It is created if it does not exist.
        classes: 10 to load CIFAR-10, 100 to load CIFAR-100 (using the
            ``fine_labels`` entry of each batch).

    Returns:
        A tuple ``(data_train, label_train, data_test, label_test)``. The data
        arrays have shape ``(N, 3, 32, 32)``; the label arrays are
        one-dimensional.

    Raises:
        ValueError: If ``classes`` is neither 10 nor 100.
    """
    # Previously any value other than 10 silently selected CIFAR-100.
    if classes not in (10, 100):
        raise ValueError(f"classes must be 10 or 100, got {classes!r}.")

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

    if classes == 10:
        foldername = "cifar-10-batches-py/"
        filename_train = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]
        filename_test = "test_batch"
        key = b"labels"
        archive_name = "cifar-10-python.tar.gz"
        archive_digest = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
    else:
        foldername = "cifar-100-python/"
        filename_train = ["train"]
        filename_test = "test"
        key = b"fine_labels"
        archive_name = "cifar-100-python.tar.gz"
        archive_digest = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"

    files_exist = all(
        os.path.exists(os.path.join(path, foldername, filename))
        for filename in [*filename_train, filename_test]
    )
    if not files_exist:
        archive_path = os.path.join(path, archive_name)
        if not verify_file(archive_path, archive_digest):
            download_file(archive_path, "https://www.cs.toronto.edu/~kriz/" + archive_name)
        with tarfile.open(archive_path, 'r:gz') as tar:
            # Use the "data" extraction filter where the interpreter provides it,
            # so archive members cannot escape ``path``; older interpreters
            # fall back to the unfiltered call.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path, filter="data")
            else:
                tar.extractall(path)

    data = []
    labels = []
    for filename in filename_train:
        batch_data, batch_labels = _cifar_load_batch(os.path.join(path, foldername, filename), key)
        data.append(batch_data)
        labels.append(batch_labels)
    data_train = numpy.concatenate(data, axis=0)
    label_train = numpy.concatenate(labels)

    data_test, test_labels = _cifar_load_batch(os.path.join(path, foldername, filename_test), key)
    label_test = numpy.array(test_labels)

    return data_train, label_train, data_test, label_test


def _svhn_load_split(path, filename, digest):
    """Ensure one SVHN ``.mat`` file is present in ``path`` and return ``(images, labels)``.

    The file is downloaded when it is missing or fails its SHA-256 check.
    ``images`` is the ``'X'`` entry with its axes permuted by ``(3, 2, 0, 1)``
    (presumably from (height, width, channel, sample) to (sample, channel,
    height, width) -- confirm against the dataset description); ``labels`` is
    the ``'y'`` entry with singleton dimensions removed.
    """
    file_path = os.path.join(path, filename)
    if not verify_file(file_path, digest):
        download_file(file_path, "http://ufldl.stanford.edu/housenumbers/" + filename)
    contents = sio.loadmat(file_path)
    return numpy.transpose(contents['X'], (3, 2, 0, 1)), contents['y'].squeeze()


def SVHN(path: str, extra: bool = False):
    """Load the cropped-digit SVHN dataset, downloading missing files into ``path``.

    Args:
        path: Directory in which the ``.mat`` files are stored. It is created
            if it does not exist.
        extra: When true, ``extra_32x32.mat`` is also fetched and its samples
            are appended to the training data and labels.

    Returns:
        A tuple ``(data_train, label_train, data_test, label_test)`` where the
        data arrays have the sample axis first and the label arrays are the
        squeezed ``'y'`` entries of the corresponding files.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

    data_train, label_train = _svhn_load_split(
        path, "train_32x32.mat", "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8")
    data_test, label_test = _svhn_load_split(
        path, "test_32x32.mat", "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b")

    if extra:
        data_extra, label_extra = _svhn_load_split(
            path, "extra_32x32.mat", "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3")
        data_train = numpy.concatenate((data_train, data_extra), axis=0)
        label_train = numpy.concatenate((label_train, label_extra), axis=0)

    return data_train, label_train, data_test, label_test