from tensorflow.keras.models import load_model
from tensorflow.keras.losses import categorical_crossentropy as cce
import tensorflow as tf
import argparse
import utils
import numpy as np
import os

# Root output directory for every generated adversarial-example file.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair so there is no
# window in which another process can create the directory between the check
# and the creation.
os.makedirs("adversarial_examples", exist_ok=True)

# Command-line interface: the dataset to attack, the attack to run and the
# perturbation strength. All three options are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', action='store', type=str, required=True,
                    help='dataset name; "cifar10" selects 10 classes, '
                         'any other value selects 100')
parser.add_argument('--attack', action='store', type=str, required=True,
                    help='attack name, e.g. "fgsm", "bim", "mim" or "pgd"')
parser.add_argument('--strength', action='store', type=float, required=True,
                    help='perturbation strength (a float)')

args = parser.parse_args()

# Only "cifar10" is special-cased; every other dataset name is treated as a
# 100-class problem.
num_classes = 10 if args.dataset == "cifar10" else 100

# Per-dataset output directory. makedirs(exist_ok=True) creates any missing
# parent as well and does not fail when the directory already exists, which
# avoids the check-then-create race of exists() followed by mkdir().
path = "adversarial_examples/" + args.dataset
os.makedirs(path, exist_ok=True)

# Hyper-parameters used by the attack branches further down.
nb_iter = 50          # number of steps taken by each iterative attack
decay_factor = 0.01   # weight of the previous momentum in the MIM update
eps = args.strength   # perturbation strength given on the command line

# Surrogate network the adversarial examples are crafted against. The custom
# loss and metric it was saved with are supplied by the project's utils module.
surrogate_model = load_model(
    args.dataset + "_baseline_surrogate.h5",
    custom_objects={'ens_loss': utils.ens_loss, 'acc_metric': utils.acc_metric},
)

benign_dir = "benign_samples/" + args.dataset
test_samples = np.load(benign_dir + "/test_samples.npy")
# NOTE(review): test_labels is loaded but not read anywhere in the visible
# code; the labels used by the attacks are derived from the surrogate's own
# predictions just below. Confirm whether this load is still needed.
test_labels = np.load(benign_dir + "/test_labels.npy")

x = tf.cast(test_samples, tf.float32)

# Build the labels used by the attacks from the surrogate's output on the
# benign inputs: split the output into utils.num_ensemble pieces along the
# last axis (presumably one piece per ensemble member), average the pieces,
# take the argmax along axis 1 and one-hot encode it.
ensemble_out = surrogate_model(x)
member_outs = tf.split(ensemble_out, utils.num_ensemble, axis=-1)
mean_out = tf.reduce_mean(member_outs, axis=0)
y = tf.one_hot(tf.argmax(mean_out, 1), num_classes)

# Gradient at the benign inputs; the FGSM branch consumes it directly.
grad = utils.compute_gradient(surrogate_model, cce, x, y)

# FGSM: one step. The perturbation is whatever utils.optimize_linear returns
# for the gradient taken at the benign inputs and the full strength eps.
if args.attack == "fgsm":
    print("FGSM Sample Generation")
    adv_x = x + utils.optimize_linear(grad, eps)
    out_name = args.attack + "_" + str(eps) + ".npy"
    np.save(path + "/" + out_name, adv_x)

# BIM: nb_iter steps of size eps / 5, starting from the benign inputs.
if args.attack == "bim":
    print("BIM Sample Generation")
    step_size = eps / 5
    # The initial perturbation is all zeros, passed through clip_eta like
    # every later perturbation.
    adv_x = x + utils.clip_eta(tf.zeros_like(x), eps)
    for _ in range(nb_iter):
        step_grad = utils.compute_gradient(surrogate_model, cce, adv_x, y)
        adv_x = adv_x + utils.optimize_linear(step_grad, step_size)
        # Re-clip the accumulated perturbation (adv_x - x) with eps before
        # the next step.
        adv_x = x + utils.clip_eta(adv_x - x, eps)
    out_name = args.attack + "_" + str(eps) + ".npy"
    np.save(path + "/" + out_name, adv_x)

# MIM: like the iterative attack above in spirit, but each step follows an
# accumulated momentum term instead of the raw gradient.
if args.attack == "mim":
    print("MIM Sample Generation")
    step_size = eps / 5
    velocity = tf.zeros_like(x)
    adv_x = x
    for _ in range(nb_iter):
        raw_grad = utils.compute_gradient(surrogate_model, cce, adv_x, y)
        # Scale the gradient by its mean absolute value taken over every axis
        # except the first (presumably the batch axis), with the divisor
        # floored at 1e-12 so it can never be zero.
        trailing_axes = list(range(1, len(raw_grad.shape)))
        floor = tf.cast(1e-12, raw_grad.dtype)
        mean_abs = tf.math.reduce_mean(tf.math.abs(raw_grad), trailing_axes, keepdims=True)
        scaled_grad = raw_grad / tf.math.maximum(floor, mean_abs)
        # Momentum update: decayed history plus the freshly scaled gradient.
        velocity = decay_factor * velocity + scaled_grad
        adv_x = adv_x + utils.optimize_linear(velocity, step_size)
        # Re-clip the accumulated perturbation (adv_x - x) with eps.
        adv_x = x + utils.clip_eta(adv_x - x, eps)
    out_name = args.attack + "_" + str(eps) + ".npy"
    np.save(path + "/" + out_name, adv_x)

# PGD: ten independent random restarts. Each restart starts from a random
# perturbation, runs nb_iter steps of size eps / 5 and is saved to its own
# file, suffixed with the 1-based restart number.
if args.attack == "pgd":
    print("PGD Sample Generation")
    step_size = eps / 5
    for restart in range(1, 11):
        print("RANDOM RESTART: " + str(restart))
        # Random starting perturbation from utils.random_lp_vector, clipped
        # with eps before it is added to the benign inputs.
        init_eta = utils.random_lp_vector(tf.shape(x), tf.cast(eps, x.dtype), dtype=x.dtype)
        adv_x = x + utils.clip_eta(init_eta, eps)
        for _ in range(nb_iter):
            step_grad = utils.compute_gradient(surrogate_model, cce, adv_x, y)
            adv_x = adv_x + utils.optimize_linear(step_grad, step_size)
            # Re-clip the accumulated perturbation (adv_x - x) with eps.
            adv_x = x + utils.clip_eta(adv_x - x, eps)
        out_name = args.attack + "_" + str(eps) + "_" + str(restart) + ".npy"
        np.save(path + "/" + out_name, adv_x)
