import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy as cce
from tensorflow.keras.metrics import categorical_accuracy as ca

# Number of ensemble members whose outputs are concatenated along the last
# axis; used as the default split count by the helpers below.
num_ensemble = 3

# Layer indices; not referenced in this section.
# NOTE(review): presumably positions of conv layers in the model — confirm.
conv_layers = [
    3, 12, 21, 36, 45, 60, 69, 84, 93, 111,
    120, 135, 144, 159, 168, 186, 195, 210, 219, 240,
]


def lr_schedule(epoch):
    """Return (and print) the learning rate for ``epoch``.

    The base rate 1e-3 is used while ``epoch <= 75``. It is scaled by 0.1
    when ``75 < epoch <= 90`` and by 0.01 when ``epoch > 90``.
    """
    # Test the later phase first so the stronger decay takes precedence.
    if epoch > 90:
        decay = 1e-2
    elif epoch > 75:
        decay = 1e-1
    else:
        decay = 1.0  # exact: 1e-3 * 1.0 == 1e-3
    rate = 1e-3 * decay
    print('Learning rate: ', rate)
    return rate


def ens_loss(y_true, y_pred, num_model=num_ensemble):
    """Sum of per-member categorical cross-entropy for a concatenated ensemble.

    ``y_true`` and ``y_pred`` are each split into ``num_model`` equal chunks
    along the last axis. Chunk ``i`` of the predictions is scored against
    chunk ``i`` of the targets, and the member losses are added together.

    Args:
        y_true: Targets with the member blocks concatenated on the last axis;
            that axis must be divisible by ``num_model`` (``tf.split``).
        y_pred: Predictions laid out the same way as ``y_true``.
        num_model: Number of ensemble members; defaults to ``num_ensemble``.

    Returns:
        The element-wise sum of the ``num_model`` tensors returned by
        ``categorical_crossentropy``.
    """
    y_p = tf.split(y_pred, num_model, axis=-1)
    y_t = tf.split(y_true, num_model, axis=-1)
    # Score every member rather than a fixed three, so ``num_model`` is honored.
    member_losses = [cce(t, p) for t, p in zip(y_t, y_p)]
    total = member_losses[0]
    for loss in member_losses[1:]:
        # Left-to-right accumulation matches the original ``l1 + l2 + l3``.
        total = total + loss
    return total


def acc_metric(y_true, y_pred, num_model=num_ensemble):
    """Categorical accuracy of the ensemble's averaged prediction.

    The ``num_model`` member predictions (equal chunks of ``y_pred``'s last
    axis) are averaged. The mean is scored against the FIRST chunk of
    ``y_true``; the remaining target chunks are not consulted.
    """
    member_targets = tf.split(y_true, num_model, axis=-1)
    member_preds = tf.stack(tf.split(y_pred, num_model, axis=-1), axis=0)
    mean_pred = tf.reduce_mean(member_preds, axis=0)
    # NOTE(review): assumes all target chunks are identical copies — confirm.
    return ca(member_targets[0], mean_pred)


@tf.function
def compute_gradient(model_fn, loss_fn, x, y, num_model=num_ensemble):
    """Gradient of the ensemble-averaged loss with respect to the input ``x``.

    Args:
        model_fn: Callable mapping ``x`` to the concatenated outputs of all
            ensemble members (split into ``num_model`` chunks on the last axis).
        loss_fn: Callable invoked as ``loss_fn(y, mean_output)``.
        x: Input tensor. It is watched explicitly so gradients reach it even
            when it is not a ``tf.Variable``.
        y: Targets, passed as the first argument of ``loss_fn``.
        num_model: Number of ensemble members; defaults to ``num_ensemble``
            so existing callers are unaffected. Because it is a Python value,
            each distinct value is traced separately by ``tf.function``.

    Returns:
        ``d(loss)/d(x)`` as returned by ``GradientTape.gradient``.
    """
    with tf.GradientTape() as tape:
        tape.watch(x)
        output = model_fn(x)
        member_outputs = tf.split(output, num_model, axis=-1)
        # Mean of the members' raw outputs. Whether these are logits or
        # probabilities depends on model_fn, which is not visible here.
        mean_output = tf.reduce_mean(member_outputs, axis=0)
        loss = loss_fn(y, mean_output)
    return tape.gradient(loss, x)


def optimize_linear(grad, eps):
    """Return ``eps * sign(grad)``.

    This is the step of element-wise magnitude ``eps`` that maximizes a
    linear function with gradient ``grad`` under an L-inf bound. Entries
    where ``grad`` is exactly zero get a zero step (``tf.sign(0) == 0``).
    """
    direction = tf.sign(grad)
    return tf.multiply(eps, direction)


def clip_eta(eta, eps):
    """Clamp each element of ``eta`` into ``[-eps, eps]`` (L-inf box projection)."""
    lower, upper = -eps, eps
    return tf.clip_by_value(eta, lower, upper)


def random_lp_vector(shape, eps, dtype=tf.float32, seed=None):
    """Sample a tensor of ``shape`` element-wise uniformly from ``[-eps, eps)``.

    Despite the ``lp`` in the name, only this L-inf box sampling is
    implemented; ``seed`` is forwarded to ``tf.random.uniform``.
    """
    return tf.random.uniform(
        shape, minval=-eps, maxval=eps, dtype=dtype, seed=seed
    )
