##############################################################
# Generate L-inf PGD adversarial examples for the MNIST-M train and test splits
import random
import os
import argparse
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from data_utils import GetLoader
from torchvision import transforms
import numpy as np
from pgd_attack import LinfPGDAttack
import imageio
from skimage import img_as_ubyte
from PIL import Image

dataset_name = 'mnist_m'
image_root = os.path.join('dataset', dataset_name)
cudnn.benchmark = True  # let cuDNN benchmark and cache the fastest conv algorithms
batch_size = 200
image_size = 28

# CUDA reads CUDA_VISIBLE_DEVICES only once, when it is first initialised.
# This module moves the model to the GPU at import time, i.e. before the
# ``__main__`` block parses its arguments, so ``--gpu_id`` has to be applied
# here, ahead of any CUDA call, to have any effect. The variable is only
# overridden when the flag is given explicitly, so an externally exported
# CUDA_VISIBLE_DEVICES keeps working when the flag is omitted.
_gpu_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
_gpu_parser.add_argument('--gpu_id', type=str, nargs='?', default=None)
_gpu_args, _ = _gpu_parser.parse_known_args()
if _gpu_args.gpu_id is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _gpu_args.gpu_id

# Fixed seed so runs are repeatable (e.g. the shuffled train order);
# use random.randint(1, 10000) instead for a different seed on every run.
manual_seed = 1
random.seed(manual_seed)
torch.manual_seed(manual_seed)

# load data: both MNIST-M splits share the same preprocessing
# (resize so the shorter side is image_size, then convert to a tensor).
img_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])


def _split_loader(split, shuffle):
    """Build the labels-file path, dataset and DataLoader for one split.

    *split* is the suffix used by the MNIST-M directory layout ('train' or
    'test'): images are read from ``<image_root>/mnist_m_<split>`` and the
    label list from ``<image_root>/mnist_m_<split>_labels.txt``.

    Returns:
        (labels_file, dataset, loader)
    """
    labels_file = os.path.join(image_root, 'mnist_m_{}_labels.txt'.format(split))
    dataset = GetLoader(
        data_root=os.path.join(image_root, 'mnist_m_{}'.format(split)),
        data_list=labels_file,
        transform=img_transform,
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=8,
    )
    return labels_file, dataset, loader


# Training order is shuffled; test order is kept fixed.
train_list, train_dataset, train_dataloader = _split_loader('train', shuffle=True)
test_list, test_dataset, test_dataloader = _split_loader('test', shuffle=False)

# load model: the checkpoint is a whole pickled nn.Module (it is used
# directly as a module below), not a state_dict.
model_path = os.path.join('saved_models', 'mnistm_model.pt')
# Deserialise onto the CPU first: a checkpoint saved on e.g. cuda:1 would
# otherwise fail to load when fewer GPUs are visible. `.cuda()` below then
# places the model on the current visible device.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which
# refuses whole pickled modules; pass weights_only=False on such versions.
model = torch.load(model_path, map_location='cpu')

# Classification loss. No optimizer is set up: the model is only attacked,
# never trained. Not referenced elsewhere in this script; kept as a
# module-level name for compatibility.
loss_func = torch.nn.CrossEntropyLoss().cuda()

# eval() switches dropout / batch-norm layers to inference behaviour.
model = model.cuda().eval()


def _generate_adversarial(dataloader):
    """Attack every batch of *dataloader* and collect the adversarial images.

    Relies on the module-level ``model`` and on ``attacker``, which is only
    created in the ``__main__`` block, so this must be called from there.
    Prints the running sample count after each batch and, at the end, the
    model's accuracy on the adversarial images ("Adv acc", NaN if the loader
    yielded no samples).

    Args:
        dataloader: iterable of ``(img, label)`` tensor batches.

    Returns:
        ``(adv_data, adv_labels)``: a list holding one numpy array per
        adversarial sample and a list holding the matching labels.
    """
    n_total = 0
    n_correct = 0
    adv_data = []
    adv_labels = []

    for img, label in dataloader:
        n_batch = img.shape[0]
        img = img.cuda()
        label = label.cuda()
        adv_img = attacker.perturb(img, label)
        # detach() so .numpy() also works if the attack returns a tensor
        # that still tracks gradients.
        adv_data.extend(adv_img.detach().cpu().numpy())
        adv_labels.extend(label.cpu().numpy())

        # Scoring only needs a forward pass; don't build an autograd graph.
        with torch.no_grad():
            adv_output = model(input_data=adv_img)
        pred = adv_output.max(1, keepdim=True)[1]
        n_correct += pred.eq(label.view_as(pred)).sum().item()
        n_total += n_batch
        print('Process {}'.format(n_total))

        # Debug aid: dump each adversarial image of the batch as a PNG.
        # for idx in range(adv_img.shape[0]):
        #     tosave = adv_img[idx].cpu().numpy()
        #     tosave = np.moveaxis(tosave, 0, -1)
        #     imageio.imwrite(str(idx) + '.png', img_as_ubyte(tosave))

    # Guard the empty-loader case instead of dividing by zero.
    accu = n_correct * 1.0 / n_total if n_total else float('nan')

    print('Adv acc:', accu)

    return adv_data, adv_labels


def train():
    """Generate adversarial examples for the MNIST-M training split.

    Returns:
        ``(train_adv_data, train_adv_labels)``: per-sample numpy arrays and
        the matching labels for every sample of ``train_dataloader``.
    """
    return _generate_adversarial(train_dataloader)


def test():
    """Generate adversarial examples for the MNIST-M test split.

    Returns:
        ``(test_adv_data, test_adv_labels)``: per-sample numpy arrays and
        the matching labels for every sample of ``test_dataloader``.
    """
    return _generate_adversarial(test_dataloader)


def save_adv_split(path, adv_data, adv_labels):
    """Save one split's adversarial examples to ``path`` (``.npy`` is appended).

    The file holds a 2-element object array ``[images, labels]``. ``images``
    is a single array stacked from *adv_data* (one row per sample) and
    ``labels`` is a 1-D array, so
    ``images, labels = np.load(path + '.npy', allow_pickle=True)`` works.

    Passing the two Python lists straight to ``np.save`` would make NumPy
    build a ragged nested array, which NumPy >= 1.24 rejects with a
    ValueError -- after the whole, slow attack has already run.

    Args:
        path: output file path; ``np.save`` adds ``.npy`` if it is missing.
        adv_data: sequence of equally-shaped numpy arrays (may be empty).
        adv_labels: sequence of labels matching *adv_data*.
    """
    payload = np.empty(2, dtype=object)
    # Assigning to single elements of an object array stores the arrays
    # as-is instead of trying to broadcast them.
    payload[0] = np.stack(adv_data) if len(adv_data) else np.empty((0,), dtype=np.float32)
    payload[1] = np.asarray(adv_labels)
    np.save(path, payload)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--eps', default=8/255, type=float, help='eps')

    args = parser.parse_args()
    # NOTE: the model was already moved to the GPU when this module was
    # imported, so CUDA is initialised by now and this assignment can no
    # longer change which GPU this process uses.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    attacker = LinfPGDAttack(model, eps=args.eps, nb_iter=100,
                             eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
                             targeted=False, num_classes=10, elementwise_best=True)

    adv_data_save_path = os.path.join('dataset', 'adv_mnistm')
    os.makedirs(adv_data_save_path, exist_ok=True)
    eps_tag = str(args.eps)

    ########### generating train
    train_adv_data, train_adv_labels = train()
    save_adv_split(os.path.join(adv_data_save_path, 'train_eps' + eps_tag),
                   train_adv_data, train_adv_labels)
    # Release the train split before attacking the test split to cap peak memory.
    del train_adv_data, train_adv_labels

    ########### generating test
    test_adv_data, test_adv_labels = test()
    save_adv_split(os.path.join(adv_data_save_path, 'test_eps' + eps_tag),
                   test_adv_data, test_adv_labels)
