import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy as cce
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import argparse
import numpy as np
import utils
from scipy import stats

# Weight of the ensemble-entropy term that Loss_withEE_DPP subtracts.
lamda = 2
# Weight of the log-determinant term that Loss_withEE_DPP subtracts.
log_det_lamda = 0.5
# Scales the identity matrix that log_det adds before calling tf.linalg.logdet.
det_offset = 1e-6
# Added to the argument of the logarithm in Entropy; keeps the log finite
# when an input entry is exactly zero.
log_offset = 1e-20
# Scalar float32 zero that log_det compares against.
zero = tf.constant(0, dtype=tf.float32)


def gal_loss(y_true, y_pred, num_model=utils.num_ensemble):
    """Summed member cross-entropy plus a gradient-alignment penalty.

    ``y_true`` and ``y_pred`` are split into ``num_model`` equal chunks along
    the last axis, one chunk per ensemble member.  For every member the
    gradient of its categorical cross-entropy with respect to
    ``model.layers[0].output`` is flattened to 1-D, and the cosine similarity
    of every unordered pair of those gradients is computed.  The penalty is
    ``0.5 * log(sum(exp(cos)))`` taken over all pairs.

    The module-level ``model`` must be bound by the time this function is
    evaluated.

    Args:
        y_true: Concatenated targets of all members (last axis is split).
        y_pred: Concatenated predictions of all members (last axis is split).
        num_model: Number of ensemble members packed into the last axis.

    Returns:
        The sum of the members' cross-entropies plus the penalty.  With fewer
        than two members there is no pair to compare and only the
        cross-entropy is returned.
    """
    y_p = tf.split(y_pred, num_model, axis=-1)
    y_t = tf.split(y_true, num_model, axis=-1)
    losses = [cce(t, p) for t, p in zip(y_t, y_p)]

    # tf.gradients returns a one-element list; the reshape flattens it to 1-D.
    input_tensor = model.layers[0].output
    flat_grads = [
        tf.reshape(tf.gradients(loss, input_tensor), [-1]) for loss in losses
    ]

    def cosine_similarity(u, v):
        dot_product = tf.tensordot(u, v, axes=1)
        norms = tf.multiply(tf.norm(u), tf.norm(v))
        # The guard has to live in the graph: a Python-level
        # `tf.equal(...) is True` is always False because tf.equal returns a
        # tensor, so a zero gradient used to produce 0 / 0 = NaN.
        safe_norms = tf.where(tf.equal(norms, 0), norms + 1e-9, norms)
        return tf.divide(dot_product, safe_norms)

    # Pair order for three members is (0, 1), (0, 2), (1, 2).
    similarities = []
    for a in range(num_model):
        for b in range(a + 1, num_model):
            similarities.append(cosine_similarity(flat_grads[a], flat_grads[b]))

    total_ce = tf.add_n(losses)
    if not similarities:
        return total_ce

    coherence = tf.math.log(tf.add_n([tf.math.exp(cs) for cs in similarities]))
    return total_ce + 0.5 * coherence


def parl_loss(y_true, y_pred, num_model=utils.num_ensemble):
    """Summed member cross-entropy plus a layer-wise gradient-alignment penalty.

    ``y_true`` and ``y_pred`` are split into ``num_model`` chunks along the
    last axis and the categorical cross-entropy of the first three chunks is
    summed.  For every index ``i`` in ``utils.conv_layers[:args.num_layers]``
    the gradients of the outputs of ``model.layers[i - 2]``,
    ``model.layers[i - 1]`` and ``model.layers[i]`` with respect to
    ``model.layers[0].output`` are flattened, and their three pairwise cosine
    similarities are added to the penalty, which is weighted by 0.5.

    The module-level ``model`` and ``args`` must be bound by the time this
    function is evaluated.

    NOTE(review): the layer indexing and the cross-entropy sum are hard-wired
    to three ensemble members; presumably ``utils.conv_layers`` lists the
    layer index of the third member's convolution - confirm against utils.

    Args:
        y_true: Concatenated targets of all members (last axis is split).
        y_pred: Concatenated predictions of all members (last axis is split).
        num_model: Number of chunks the last axis is split into.

    Returns:
        The three cross-entropies plus 0.5 times the accumulated similarities.
    """
    y_p = tf.split(y_pred, num_model, axis=-1)
    y_t = tf.split(y_true, num_model, axis=-1)
    loss_1 = cce(y_t[0], y_p[0])
    loss_2 = cce(y_t[1], y_p[1])
    loss_3 = cce(y_t[2], y_p[2])

    def cosine_similarity(u, v):
        dot_product = tf.tensordot(u, v, axes=1)
        norms = tf.multiply(tf.norm(u), tf.norm(v))
        # The guard has to live in the graph: a Python-level
        # `tf.equal(...) is True` is always False because tf.equal returns a
        # tensor, so a zero gradient used to produce 0 / 0 = NaN.
        safe_norms = tf.where(tf.equal(norms, 0), norms + 1e-12, norms)
        return tf.divide(dot_product, safe_norms)

    input_tensor = model.layers[0].output
    layer_sum = 0
    for i in utils.conv_layers[:args.num_layers]:
        # tf.gradients returns a one-element list; reshape flattens it to 1-D.
        g_0, g_1, g_2 = [
            tf.reshape(
                tf.gradients(model.layers[i - offset].output, input_tensor),
                [-1],
            )
            for offset in (2, 1, 0)
        ]
        layer_sum = (
            layer_sum
            + cosine_similarity(g_0, g_1)
            + cosine_similarity(g_0, g_2)
            + cosine_similarity(g_1, g_2)
        )
    return loss_1 + loss_2 + loss_3 + 0.5 * layer_sum


def Loss_withEE_DPP(y_true, y_pred, num_model=3):
    """Cross-entropy of all members minus entropy and log-determinant bonuses.

    ``y_true`` is first rescaled so that its sum along axis 1 equals
    ``num_model``.  Targets and predictions are then split into ``num_model``
    chunks along the last axis and the per-chunk categorical cross-entropies
    are added up.  Unless both module-level weights ``lamda`` and
    ``log_det_lamda`` are zero, ``lamda * Ensemble_Entropy(...)`` and
    ``log_det_lamda * log_det(...)`` are subtracted from that sum.

    Args:
        y_true: Concatenated targets of all members.
        y_pred: Concatenated predictions of all members.
        num_model: Number of ensemble members packed into the last axis.

    Returns:
        The combined loss tensor.
    """
    label_mass = tf.reduce_sum(y_true, axis=1, keepdims=True)
    y_true = (num_model * y_true) / label_mass

    pred_chunks = tf.split(y_pred, num_model, axis=-1)
    true_chunks = tf.split(y_true, num_model, axis=-1)
    CE_all = sum(cce(t, p) for t, p in zip(true_chunks, pred_chunks))

    # With both weights disabled only the plain cross-entropy remains.
    if lamda == 0 and log_det_lamda == 0:
        print('This is original ECE!')
        return CE_all

    entropy_term = Ensemble_Entropy(y_true, y_pred, num_model)
    log_det_term = log_det(y_true, y_pred, num_model)
    return CE_all - lamda * entropy_term - log_det_lamda * log_det_term


def Ensemble_Entropy(y_true, y_pred, num_model=3):
    """Entropy of the prediction averaged over all ensemble members.

    ``y_pred`` is split into ``num_model`` chunks along the last axis, the
    chunks are averaged element-wise and the result is passed to ``Entropy``.
    ``y_true`` is accepted for signature compatibility and is not used.
    """
    member_preds = tf.split(y_pred, num_model, axis=-1)
    mean_pred = sum(member_preds) / num_model
    return Entropy(mean_pred)


def Entropy(input):
    """Return ``-sum(input * log(input + log_offset))`` over the last axis."""
    log_terms = tf.math.log(input + log_offset)
    negated = tf.negative(tf.multiply(input, log_terms))
    return tf.reduce_sum(negated, axis=-1)


def log_det(y_true, y_pred, num_model=3):
    """Log-determinant of the Gram matrix of normalised masked predictions.

    Entries of ``y_pred`` are kept wherever ``1 - y_true`` is non-zero, then
    reshaped to ``[-1, num_model, num_classes - 1]`` (``num_classes`` is a
    module-level name).  Each row along the last axis is divided by its norm,
    the rows' Gram matrix is formed per sample, ``det_offset`` times the
    identity is added, and ``tf.linalg.logdet`` of the result is returned.

    NOTE(review): the reshape presumably relies on ``y_true`` holding exactly
    one entry equal to 1 per member - confirm against the caller.
    """
    keep_mask = tf.not_equal(tf.ones_like(y_true) - y_true, zero)
    kept_preds = tf.reshape(
        tf.boolean_mask(y_pred, keep_mask),
        [-1, num_model, num_classes - 1],
    )
    row_norms = tf.norm(kept_preds, axis=2, keepdims=True)
    unit_rows = kept_preds / row_norms
    gram = tf.matmul(unit_rows, tf.transpose(unit_rows, perm=[0, 2, 1]))
    jitter = det_offset * tf.expand_dims(tf.eye(num_model), 0)
    return tf.linalg.logdet(gram + jitter)


def Ensemble_Entropy_metric(y_true, y_pred, num_model=3):
    """Mean of ``Ensemble_Entropy`` over its output, for use as a metric."""
    return K.mean(Ensemble_Entropy(y_true, y_pred, num_model=num_model))


def log_det_metric(y_true, y_pred, num_model=3):
    """Mean of ``log_det`` over its output, for use as a metric."""
    return K.mean(log_det(y_true, y_pred, num_model=num_model))


# Command-line interface.  --random_start and --num_layers are optional and
# default to None; the remaining flags are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True)
parser.add_argument('--attack', type=str, required=True)
parser.add_argument('--random_start', type=int)
parser.add_argument('--strength', type=float, required=True)
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--num_layers', type=int)
args = parser.parse_args()

# Pick the checkpoint file and the custom objects Keras needs to deserialise
# it, depending on which kind of ensemble was requested.
if args.model == "gal":
    # NOTE(review): the file name ignores --dataset; confirm this is intended.
    model_name = "cifar10_gal.h5"
    _custom_objects = {'gal_loss': gal_loss}
elif args.model == "parl":
    model_name = f"{args.dataset}_{args.model}_{args.num_layers}.h5"
    _custom_objects = {'parl_loss': parl_loss}
elif args.model == "adp":
    model_name = f"{args.dataset}_{args.model}.h5"
    _custom_objects = {
        'Loss_withEE_DPP': Loss_withEE_DPP,
        'Ensemble_Entropy_metric': Ensemble_Entropy_metric,
        'log_det_metric': log_det_metric,
    }
else:
    model_name = f"{args.dataset}_{args.model}.h5"
    _custom_objects = {'ens_loss': utils.ens_loss}

# Every checkpoint kind also needs the shared accuracy metric.
_custom_objects['acc_metric'] = utils.acc_metric
model = load_model(model_name, custom_objects=_custom_objects)

# Any dataset other than "cifar10" is treated as having 100 classes.
num_classes = 10 if args.dataset == "cifar10" else 100
path = f"adversarial_examples/{args.dataset}"
test_labels = np.load(f"benign_samples/{args.dataset}/test_labels.npy")
eps = args.strength

# fgsm/bim/mim files are keyed by strength only; every other attack's file
# name additionally carries the --random_start value.
_adv_file = f"{path}/{args.attack}_{eps}"
if args.attack not in ['fgsm', 'bim', 'mim']:
    _adv_file = f"{_adv_file}_{args.random_start}"
adv_x = np.load(_adv_file + ".npy")

# The model emits the members' outputs concatenated along the last axis;
# split them back into one array per member.
predict = model.predict(adv_x)
y_p = np.split(predict, utils.num_ensemble, axis=-1)

# Compare 1-D against 1-D only: labels stored as (N, 1) would otherwise
# broadcast against (N,) predictions into an (N, N) matrix and inflate the
# count of matches.
flat_labels = np.ravel(test_labels)

if args.model == "parl":
    # Majority vote over the members' arg-max predictions.
    predictions = [np.argmax(member, axis=1) for member in y_p]
    # Depending on the SciPy version the mode comes back as (1, N) or (N,);
    # ravel makes both cases 1-D.
    mode_ensemble_predictions = np.ravel(stats.mode(predictions, axis=0)[0])
    ensemble_accuracy = np.sum(mode_ensemble_predictions == flat_labels) / len(adv_x)
else:
    # Arg-max of the members' averaged outputs.
    average_ensemble_probability = np.mean(y_p, axis=0)
    ensemble_prediction_average = np.argmax(average_ensemble_probability, axis=1)
    ensemble_accuracy = np.sum(ensemble_prediction_average == flat_labels) / len(adv_x)

print("Ensemble Accuracy: " + str(ensemble_accuracy))
