from torch.utils.data import DataLoader
from skimage import transform
from adv_attack import *
from sklearn.preprocessing import LabelBinarizer
import pickle
from torchvision import datasets, transforms
import numpy as np


def normalize(x, max_value):
    """Linearly map values from the range [0, max_value] onto [-1, 1].

    Args:
        x: A scalar or numpy array with values expected in [0, max_value].
        max_value: The upper bound of the input range; converted to float so
            integer inputs are divided with true division.

    Returns:
        ``x`` rescaled so that 0 -> -1 and max_value -> 1.
    """
    unit = x / float(max_value)  # now in [0, 1]
    return unit * 2 - 1


def transform_mnist(X):
    """Resize MNIST digits to 32x32 and rescale their pixels to [-1, 1].

    Args:
        X: Numpy array whose elements can be reshaped to (len(X), 28, 28),
            i.e. one 28x28 image per leading index.

    Returns:
        Numpy array of shape (len(X), 32, 32, 1).
    """
    count = len(X)
    digits = X.reshape(count, 28, 28)

    resized = []
    for img in digits:
        resized.append(transform.resize(img, [32, 32]))
    resized = np.array(resized)

    # normalize(..., 1) assumes the resized pixel values lie in [0, 1]
    # (skimage's resize returns floats) -- NOTE(review): confirm for the
    # dtype/range of the incoming images.
    scaled = normalize(resized, 1)
    return scaled.reshape(len(scaled), 32, 32, 1)


def _attack_dataset(model, dataset, device, eps):
    """Apply `attack` to every sample of `dataset` and collect the results.

    Samples are processed one at a time, in dataset order.

    Returns:
        (images, labels): the perturbed images passed through
        `transform_mnist`, and a list of the integer labels.
    """
    loader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)
    perturbed, labels = [], []
    for data, target in loader:
        perturbed.append(attack(model, data, target, device, epsilon=eps))
        # batch_size=1, so `target` holds exactly one label.
        labels.append(target.item())
    return transform_mnist(np.array(perturbed)), labels


def read_mnist(root, scale_32=True, eps=0.2):
    """Build an adversarially perturbed MNIST dataset and cache it to disk.

    Downloads MNIST (if needed), loads a pretrained `Net` from
    "BaseAttackModel/lenet_mnist_model.pth", runs `attack` on every train and
    test image, resizes the results to 32x32 via `transform_mnist`, and
    one-hot encodes the labels with a `LabelBinarizer`.

    Args:
        root: Prefix that 'mnist/' is appended to (plain string
            concatenation, so it should end with a path separator).
        scale_32: Currently has no effect; both original branches built an
            equivalent ToTensor pipeline and `transform_mnist` always
            resizes to 32x32. Kept for interface compatibility.
        eps: Passed to `attack` as `epsilon`; also used in the cache filename.

    Returns:
        (X_train, Y_train, X_test, Y_test), where the X arrays have shape
        (N, 32, 32, 1) and the Y arrays are the binarized labels.

    Side effects:
        Writes a dict with keys 'X_train', 'Y_train', 'X_test' and 'Y_test'
        to "data/MNIST_target_{eps}.pkl", creating "data/" if it is missing.
    """
    import os

    trans = transforms.ToTensor()
    mnist_train = datasets.MNIST(root=root + 'mnist/', train=True, download=True, transform=trans)
    mnist_test = datasets.MNIST(root=root + 'mnist/', train=False, download=True, transform=trans)

    # Load the pretrained model that the adversarial attack targets.
    pretrained_model = "BaseAttackModel/lenet_mnist_model.pth"
    use_cuda = True
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    model = Net().to(device)
    model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
    model.eval()

    train_data, train_labels = _attack_dataset(model, mnist_train, device, eps)
    test_data, test_labels = _attack_dataset(model, mnist_test, device, eps)

    lb_mnist = LabelBinarizer()
    Y_mnist_train = lb_mnist.fit_transform(train_labels)
    # Reuse the encoder fitted on the training labels so the test one-hot
    # columns are guaranteed to match the training ones (refitting could
    # reorder or drop classes).
    Y_mnist_test = lb_mnist.transform(test_labels)

    all_data = {
        'X_train': train_data,
        'Y_train': Y_mnist_train,
        'X_test': test_data,
        'Y_test': Y_mnist_test,
    }

    os.makedirs("data", exist_ok=True)
    with open("data/MNIST_target_{}.pkl".format(eps), "wb") as pkl_file:
        pickle.dump(all_data, pkl_file)
    return train_data, Y_mnist_train, test_data, Y_mnist_test
