import os
import pickle

from torchvision import datasets
from sklearn.preprocessing import LabelBinarizer
from skimage import color
import numpy as np

# Module-wide flag passed as ``download=`` to torchvision's ``datasets.SVHN``
# calls in read_svhn.
download_svhn = True


def normalize(x, max_value):
    """Linearly rescale ``x`` from the range [0, max_value] to [-1, 1].

    Args:
        x: a scalar or numpy array whose values are assumed to lie in
            [0, max_value].
        max_value: upper bound of the input range (converted to float).

    Returns:
        ``2 * (x / max_value) - 1``: 0 maps to -1 and ``max_value`` maps to 1.
    """
    unit_interval = x / float(max_value)  # now in [0, 1] for in-range input
    return 2 * unit_interval - 1


def transform_svhn(X):
    """Convert a batch of RGB images to single-channel grayscale in [-1, 1].

    Args:
        X: image batch laid out as (N, C, H, W), i.e. channels-first; the
            (0, 2, 3, 1) transpose below moves channels last for skimage.
            C is expected to be 3 (RGB) for ``color.rgb2gray``.

    Returns:
        Array of shape (N, H, W, 1). ``color.rgb2gray`` is expected to return
        luminance in [0, 1] (skimage converts integer images to float), which
        ``normalize(..., 1)`` rescales to [-1, 1].
    """
    # Channels-last view (N, H, W, C); np.transpose does not copy the data.
    images = np.transpose(X, (0, 2, 3, 1))
    # Convert one image at a time: a single batched rgb2gray call would
    # materialize a full float copy of all three channels at once.
    gray = np.array([color.rgb2gray(im) for im in images])
    gray = normalize(gray, 1)
    # Append a trailing channel axis using the batch's own spatial size rather
    # than a hard-coded 32x32; an empty batch still comes out as (0, H, W, 1).
    return gray.reshape(images.shape[:3] + (1,))


def read_svhn(root, *, train_split="extra", download=None):
    """Load SVHN via torchvision and return grayscale images with one-hot labels.

    Args:
        root: base data directory (str or path-like). The dataset is read
            from, or downloaded into, its ``svhn`` sub-directory.
        train_split: torchvision SVHN split used as the training set.
            Defaults to ``"extra"`` (the historical behaviour).
        download: forwarded to ``datasets.SVHN(download=...)``. ``None``
            (the default) falls back to the module-level ``download_svhn``.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. The X arrays come from
        ``transform_svhn``. The Y arrays are one-hot label matrices from a
        single ``LabelBinarizer`` fitted on the training labels only, so both
        share the same column order.
    """
    if download is None:
        download = download_svhn

    # os.path.join works whether or not ``root`` ends with a separator (and
    # accepts path-like objects); plain string concatenation would turn
    # root="data" into "datasvhn/".
    svhn_dir = os.path.join(root, "svhn")
    svhn_train = datasets.SVHN(root=svhn_dir, download=download, split=train_split)
    svhn_test = datasets.SVHN(root=svhn_dir, download=download, split="test")

    X_svhn_train = transform_svhn(svhn_train.data)
    X_svhn_test = transform_svhn(svhn_test.data)

    # `% 10` folds a label of 10 into 0 (presumably the raw SVHN convention
    # for digit 0 -- confirm against the torchvision version in use); labels
    # already in 0..9 are unchanged.
    train_labels = svhn_train.labels.flatten() % 10
    test_labels = svhn_test.labels.flatten() % 10

    lb_svhn = LabelBinarizer()
    Y_svhn_train = lb_svhn.fit_transform(train_labels)
    # Reuse the train-fitted mapping: re-fitting on the test labels could
    # produce a different column layout if a class were missing from either split.
    Y_svhn_test = lb_svhn.transform(test_labels)

    return X_svhn_train, Y_svhn_train, X_svhn_test, Y_svhn_test
