import foolbox as fb
import tensorflow as tf

import numpy as np

from time import time

import art


def CW(model,
       x,
       y_sparse,
       epsilon=0.141,
       confidence=1.0,
       stepsize=1.e-3,
       steps=100,
       clamp=[0, 1],
       batch_size=32,
       num_interpolation=10,
       return_probits=False):
    """Carlini-Wagner (L2) search for nearby boundary points.

    Args:
        model: tf.keras model; forwarded to ``get_boundary_points`` and then
            queried with ``model.predict`` on the attacked points.
        x: Benign inputs.
        y_sparse: Integer class labels (not one-hot) for ``x``.
        epsilon: L2 budget passed to the attack as its epsilon.
        confidence: CW confidence margin.
        stepsize: CW step size.
        steps: Number of CW iterations.
        clamp: ``[low, high]`` data range; the attacked points are clipped into it.
        batch_size: Batch size for both the attack and the prediction.
        num_interpolation: Accepted for signature parity with ``PGDs``; unused.
        return_probits: Return the raw model outputs instead of argmax labels.

    Returns:
        (boundaries, y_pred_adv, is_adv): the attacked points, their predicted
        labels (or raw outputs when ``return_probits``), and a boolean array
        that is True where the prediction differs from ``y_sparse``.
    """
    search_range = ['local', 'l2', epsilon, stepsize, steps]
    adv = get_boundary_points(model,
                              x,
                              y_sparse,
                              batch_size=batch_size,
                              pipeline='cw',
                              search_range=search_range,
                              clamp=clamp,
                              backend='tf.keras',
                              confidence=confidence)

    # Clip into [clamp[0], clamp[1]]: upper bound first, then lower bound.
    adv = np.minimum(adv, np.ones_like(adv) * clamp[1])
    adv = np.maximum(adv, np.ones_like(adv) * clamp[0])

    outputs = model.predict(adv, batch_size=batch_size)
    labels = np.argmax(outputs, -1)
    flipped = np.array([not same for same in labels == y_sparse])

    return adv, (outputs if return_probits else labels), flipped


def PGDs(model,
         x,
         y_sparse,
         epsilon=0.141,
         confidence=1.0,
         stepsize=None,
         steps=100,
         clamp=[0, 1],
         batch_size=32,
         num_interpolation=10,
         return_probits=False):
    """L2 PGD over a ladder of evenly spaced budgets up to ``epsilon``.

    Args:
        model: tf.keras model; forwarded to ``get_boundary_points`` and then
            queried with ``model.predict`` on the attacked points.
        x: Benign inputs.
        y_sparse: Integer class labels (not one-hot) for ``x``.
        epsilon: Largest L2 budget; the ladder is
            ``epsilon * k / num_interpolation`` for ``k = 1..num_interpolation``.
        confidence: Accepted for signature parity with ``CW``; unused.
        stepsize: PGD step size forwarded to ``get_boundary_points``
            (``None`` lets it choose a default).
        steps: Number of PGD iterations per budget.
        clamp: ``[low, high]`` data range.
        batch_size: Batch size for both the attack and the prediction.
        num_interpolation: Number of budgets in the ladder (values below 1
            are raised to 1).
        return_probits: Return the raw model outputs instead of argmax labels.

    Returns:
        (boundaries, y_pred_adv, is_adv): the attacked points, their predicted
        labels (or raw outputs when ``return_probits``), and a boolean array
        that is True where the prediction differs from ``y_sparse``.
    """
    num_interpolation = max(num_interpolation, 1)
    budgets = [
        epsilon * (k / num_interpolation)
        for k in range(1, num_interpolation + 1)
    ]

    adv = get_boundary_points(model,
                              x,
                              y_sparse,
                              batch_size=batch_size,
                              pipeline='pgd',
                              search_range=['local', 'l2', budgets, stepsize, steps],
                              clamp=clamp,
                              backend='tf.keras')

    outputs = model.predict(adv, batch_size=batch_size)
    labels = np.argmax(outputs, -1)
    flipped = np.array([not same for same in labels == y_sparse])

    return adv, (outputs if return_probits else labels), flipped


def IGD_L1(model,
           x,
           y_sparse,
           epsilon=0.141,
           num_classes=2,
           confidence=0.5,
           stepsize=1.e-3,
           steps=100,
           clamp=[0, 1],
           batch_size=32,
           num_interpolation=10,
           return_probits=False):
    """ElasticNet attack with ``beta=1.0`` and the 'L1' decision rule.

    ``y_sparse`` holds integer labels; they are one-hot encoded with
    ``num_classes`` columns before being handed to ``iterative_attack``.
    ``epsilon`` and ``num_interpolation`` are accepted for signature parity
    with the other attack wrappers and are not used.

    Returns:
        Whatever ``iterative_attack`` returns: (boundaries, y_pred_adv, is_adv).
    """
    onehot_labels = tf.keras.utils.to_categorical(y_sparse,
                                                  num_classes=num_classes)
    low, high = clamp[0], clamp[1]

    return iterative_attack(model,
                            x,
                            onehot_labels,
                            beta=1.0,
                            epsilon_iter=stepsize,
                            max_iter=steps,
                            clip_min=low,
                            clip_max=high,
                            decision_rule='L1',
                            batch_size=batch_size,
                            confidence=confidence,
                            return_probits=return_probits,
                            num_classes=num_classes)


def IGD_L2(model,
           x,
           y_sparse,
           epsilon=0.141,
           num_classes=2,
           confidence=0.5,
           stepsize=1.e-3,
           steps=100,
           clamp=[0, 1],
           batch_size=32,
           num_interpolation=10,
           return_probits=False):
    """ElasticNet attack with ``beta=0.0`` and the 'L2' decision rule.

    ``y_sparse`` holds integer labels; they are one-hot encoded (float32,
    ``num_classes`` columns) before being handed to ``iterative_attack``.
    ``epsilon`` and ``num_interpolation`` are accepted for signature parity
    with the other attack wrappers and are not used.

    Returns:
        Whatever ``iterative_attack`` returns: (boundaries, y_pred_adv, is_adv).
    """
    onehot_labels = tf.keras.utils.to_categorical(y_sparse,
                                                  num_classes=num_classes,
                                                  dtype='float32')
    low, high = clamp[0], clamp[1]

    return iterative_attack(model,
                            x,
                            onehot_labels,
                            beta=0.0,
                            epsilon_iter=stepsize,
                            max_iter=steps,
                            clip_min=low,
                            clip_max=high,
                            decision_rule='L2',
                            batch_size=batch_size,
                            confidence=confidence,
                            return_probits=return_probits,
                            num_classes=num_classes)


def iterative_attack(model,
                     x,
                     y_orig_onehot,
                     beta,
                     epsilon_iter=1e-1,
                     max_iter=100,
                     clip_min=-1.,
                     clip_max=1.,
                     decision_rule='EN',
                     batch_size=128,
                     confidence=0.5,
                     return_probits=True, 
                     num_classes=2):
    """Run ART's ElasticNet attack on ``x`` and report which inputs flipped.

    Args:
        model: tf.keras model, wrapped in ART's ``TensorFlowV2Classifier``.
        x: Benign inputs; ``x.shape[1:]`` is used as the classifier input shape.
        y_orig_onehot: One-hot labels for ``x``.
        beta: ElasticNet L1 weight.
        epsilon_iter: ElasticNet learning rate.
        max_iter: ElasticNet iteration count.
        clip_min, clip_max: Data range; reduced to a global min / max.
        decision_rule: ElasticNet decision rule ('EN', 'L1' or 'L2').
        batch_size: Batch size for the attack and for prediction.
        confidence: ElasticNet confidence.
        return_probits: Return the raw model outputs instead of argmax labels.
        num_classes: Number of classes given to the ART classifier.

    Returns:
        (boundaries, y_pred_adv, is_adv): the attacked points, their predicted
        labels (or raw outputs when ``return_probits``), and a boolean array
        that is True where the prediction differs from the original label.
    """
    bounds = [np.min(clip_min), np.max(clip_max)]
    wrapped = art.estimators.classification.tensorflow.TensorFlowV2Classifier(
        model=model,
        clip_values=bounds,
        nb_classes=num_classes,
        input_shape=x.shape[1:])

    ead = art.attacks.evasion.ElasticNet(classifier=wrapped,
                                         confidence=confidence,
                                         learning_rate=epsilon_iter,
                                         beta=beta,
                                         max_iter=max_iter,
                                         batch_size=batch_size,
                                         decision_rule=decision_rule,
                                         verbose=False)

    adv = ead.generate(x, y=y_orig_onehot)
    true_labels = np.argmax(y_orig_onehot, axis=-1)

    outputs = model.predict(adv, batch_size=batch_size)
    pred_labels = np.argmax(outputs, -1)
    flipped = np.array([not same for same in pred_labels == true_labels])

    return adv, (outputs if return_probits else pred_labels), flipped


def batch_flatten(x):
    """Reshape ``x`` to 2-D, keeping axis 0 and merging all remaining axes."""
    rows = x.shape[0]
    return np.reshape(x, [rows, -1])


def IntegratedGradient(model, x, y, res=50, baseline=None, num_class=2):
    """Integrated-gradients attribution plus per-step path quantities.

    Points are sampled along the straight line ``baseline -> x`` at
    ``k / res`` for ``k = 1..res``.

    Args:
        model: Callable Keras-style model, invoked as ``model(x, training=False)``.
        x: Batch of inputs; axis 0 is the batch axis.
        y: Integer labels, one per example.
        res: Number of steps along the path.
        baseline: ``None`` (zeros), a Python scalar, an array lacking the batch
            axis (broadcast over the batch), or an array shaped like ``x``.
        num_class: One-hot depth used when the model has more than one output.

    Returns:
        attr: ``(x - baseline) * mean_k(d qoi / d x_k)``, same shape as ``x``.
        qois: (batch, res) quantity of interest at each path step.
        grads_t: (batch, res) gradient of each step's summed qoi w.r.t. each
            example's scalar path coordinate.

    Raises:
        ValueError: If ``baseline`` cannot be brought to the shape of ``x``.
    """
    if baseline is None:
        baseline = tf.zeros_like(x)
    elif isinstance(baseline, (float, int)):
        baseline = tf.zeros_like(x) + baseline
    elif len(baseline.shape) == len(x.shape) - 1:
        baseline = tf.ones_like(x) * baseline[None, :]

    if baseline.shape != x.shape:
        raise ValueError(
            f"baseline shape {baseline.shape} does not match input shape {x.shape}")

    rank = len(x.shape)
    grad_sum = tf.zeros_like(x)
    qois = []
    grads_t = []

    for i in range(1, res + 1):
        # Per-example path coordinate, shape (batch, 1); watched so we can
        # differentiate the qoi with respect to position along the path.
        j = tf.ones((x.shape[0], 1), dtype=np.float32) * i
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(j)
            # Reshape to (batch, 1, ..., 1) so the coordinate broadcasts over
            # every non-batch axis; a bare (batch, 1) misaligns for rank > 2.
            scale = tf.reshape(j, [-1] + [1] * (rank - 1)) / res
            x_in = baseline + (x - baseline) * scale
            out = model(x_in, training=False)
            if out.shape[1] > 1:
                per_qoi = out * tf.one_hot(tf.cast(y, tf.int32), num_class)
            else:
                # Single output: +out supports label 1, -out supports label 0.
                # Reshape labels to (batch, 1) so they pair row-wise with out.
                sign = tf.reshape(tf.cast(y, out.dtype), (-1, 1)) * 2 - 1
                per_qoi = out * sign

            qoi = tf.reduce_sum(per_qoi, axis=-1)
            sum_qoi = tf.reduce_sum(qoi)
        qois.append(tf.expand_dims(qoi, 0))
        input_grad = tape.gradient(sum_qoi, x_in)
        grads_t.append(tape.gradient(sum_qoi, j))
        del tape
        grad_sum += input_grad / res

    # Standard IG: scale the averaged path gradient by the input-baseline gap.
    attr = grad_sum * (x - baseline)
    qois = tf.transpose(tf.concat(qois, axis=0))
    grads_t = tf.concat(grads_t, axis=1)

    return attr, qois, grads_t


def Saliency(model, x, y, num_class=2):
    """Gradient of the label-``y`` model output with respect to the input.

    Args:
        model: Callable Keras-style model, invoked as ``model(x, training=False)``.
        x: Batch of inputs.
        y: Integer labels used to select one output column per example.
        num_class: One-hot depth used for the selection mask.

    Returns:
        Tensor shaped like ``x``.
    """
    inputs = tf.constant(x)
    mask = tf.one_hot(tf.cast(y, tf.int32), num_class)
    with tf.GradientTape() as tape:
        tape.watch(inputs)
        scores = tf.reduce_sum(model(inputs, training=False) * mask, axis=-1)
    return tape.gradient(scores, inputs)


def normalize_attr(attr):
    """Collapse the last axis of a rank-4 attribution and scale it into [0, 1].

    A rank-4 input (presumably (batch, H, W, C) — TODO confirm) is averaged
    over its last axis and made non-negative. Each example is then divided by
    its maximum over axes 1 and 2. An all-zero map stays zero rather than
    turning into NaN. Inputs of any other rank are returned unchanged.
    """
    if len(attr.shape) != 4:
        return attr
    attr = tf.abs(tf.reduce_mean(attr, axis=-1))
    peak = tf.reduce_max(attr, axis=(1, 2), keepdims=True)
    return tf.math.divide_no_nan(attr, peak)


def get_pdf(cdf):
    """Turn a (batch, steps) cumulative curve into per-row normalized increments.

    The absolute differences between consecutive columns are divided by their
    row sum, giving (batch, steps - 1). A constant row yields zeros instead
    of NaN.
    """
    pdf = tf.abs(cdf[:, 1:] - cdf[:, :-1])
    return tf.math.divide_no_nan(pdf, tf.reduce_sum(pdf, -1, keepdims=True))


def unit_vector(x):
    """Scale every example in ``x`` to (approximately) unit L2 norm.

    The norm is taken over all non-batch axes; ``l2_norm`` adds a small
    epsilon, so zero inputs stay finite.
    """
    flat = tf.keras.backend.batch_flatten(x)
    # Reshape (batch, 1) -> (batch, 1, ..., 1) so the division broadcasts
    # over every non-batch axis; a bare (batch, 1) misaligns for rank > 2.
    norm = tf.reshape(l2_norm(flat), [-1] + [1] * (len(x.shape) - 1))
    return x / norm


def update_delta(x, delta, grad, step_size, epsilon, p, clamp):
    """Take one projected gradient-ascent step and return the new point.

    Args:
        x: Original inputs (centre of the epsilon ball).
        delta: Current point, not the perturbation (the perturbation is
            ``delta - x``).
        grad: Ascent direction, shaped like ``delta``.
        step_size: Step length.
        epsilon: Radius of the allowed perturbation ball.
        p: Norm: ``np.inf``, ``2`` or ``1``.
        clamp: ``[low, high]`` data range applied after projection.

    Returns:
        The updated point, projected into the ball and clipped to ``clamp``.

    Raises:
        ValueError: If ``p`` is not one of ``np.inf``, 2 or 1.
    """
    if p == np.inf:
        delta = delta + step_size * tf.sign(grad)
    elif p == 2:
        delta = delta + step_size * unit_vector(grad)
    elif p != 1:
        raise ValueError(f"Unsupported norm p={p!r}; expected np.inf, 2 or 1")
    # NOTE: for p == 1 no ascent step is taken; only the projection runs.

    perturbation = delta - x
    if p == np.inf:
        # Exact projection onto the L-inf ball is per-coordinate clipping.
        perturbation = tf.clip_by_value(perturbation, -epsilon, epsilon)
    else:
        flat = tf.keras.backend.batch_flatten(perturbation)
        if p == 2:
            norm = l2_norm(flat)
        else:
            # Rescaling keeps p == 1 inside the ball but is not the exact
            # Euclidean projection onto the L1 ball.
            norm = tf.reduce_sum(tf.abs(flat), axis=-1, keepdims=True)
        coeff = tf.clip_by_value(epsilon / norm, 0.0, 1.0)
        perturbation = tf.reshape(flat * coeff, tf.shape(perturbation))

    return clip_with_clamp(x + perturbation, clamp)


def clip_with_clamp(delta, clamp):
    """Clip ``delta`` into ``[clamp[0], clamp[1]]``.

    Python-float bounds use ``tf.clip_by_value``. Any other bounds (ints,
    arrays, tensors) are broadcast against ``delta`` and applied with
    minimum / maximum.
    """
    lo, hi = clamp[0], clamp[1]
    if not (isinstance(lo, float) and isinstance(hi, float)):
        capped = tf.minimum(delta, tf.ones_like(delta) * hi)
        return tf.maximum(capped, tf.ones_like(capped) * lo)
    return tf.clip_by_value(delta, lo, hi)


def l2_norm(x):
    """L2 norm over the last axis (kept as size 1) plus the Keras epsilon."""
    squared_sum = tf.reduce_sum(x**2., axis=-1, keepdims=True)
    return tf.sqrt(squared_sum) + tf.keras.backend.epsilon()


def batch_path_attack(model,
                      x,
                      y,
                      output_layer=-1,
                      a = 0.5,
                      res=10,
                      baseline=0,
                      max_steps=100,
                      adv_step_size=1e-3,
                      clamp=[0, 1],
                      adv_epsilon=0.1,
                      p=2,
                      num_class=1000,
                      alpha=1,
                      target=None,
                      stddev=1e-3,
                      distance_fn=None,
                      use_kl=True,
                      return_probits=False):
    """Search near ``x`` for label-preserving points via integrated-gradient paths.

    Each step runs gradient ascent (through ``update_delta``) on two terms:
      * the summed path qois from ``IntegratedGradient``;
      * when ``use_kl``, ``alpha * KL(pdf_of_original_path || path_gradients)``,
        or ``-alpha * KL(target || path_gradients)`` when ``target`` is given.
    For every input, the last iterate whose argmax prediction still equals
    ``y`` is kept.

    Args:
        model: tf.keras model.
        x: Batch of inputs.
        y: Integer labels for ``x``.
        output_layer: If not -1, attack the output of ``model.layers[output_layer]``.
        a: Unused (accepted for compatibility).
        res: Integrated-gradients resolution.
        baseline: Integrated-gradients baseline.
        max_steps: Number of ascent steps.
        adv_step_size, adv_epsilon, p, clamp: Forwarded to ``update_delta``.
        num_class: One-hot depth for multi-output models.
        alpha: Weight of the KL term.
        target: Optional target distribution for the KL term.
        stddev: Std-dev of the Gaussian noise added to the starting point.
        distance_fn: Unused (the attribution-distance tracking is disabled).
        use_kl: Whether to include the KL term.
        return_probits: Return raw model outputs instead of argmax labels.

    Returns:
        (optimal_adv, y_pred, best_diff). ``best_diff`` is a list of zeros
        kept for API compatibility.
    """
    if output_layer != -1:
        logit_model = tf.keras.models.Model(model.input,
                                            model.layers[output_layer].output)
        logit_model.compile()
        model = logit_model

    x = tf.constant(x)
    # tf.argmax returns int64; compare labels in the same dtype, since TF
    # rejects == between int32/float and int64 tensors.
    labels = tf.cast(y, tf.int64)

    best_diff = [0] * x.shape[0]
    # Explicit copy: arrays returned by EagerTensor.numpy() may be read-only.
    optimal_adv = np.array(x)

    _, original_qoi, _ = IntegratedGradient(model,
                                            x,
                                            y,
                                            res=res,
                                            baseline=baseline,
                                            num_class=num_class)
    pdf = get_pdf(original_qoi)

    kl = tf.keras.losses.KLDivergence()
    delta = x + tf.random.normal(x.shape, stddev=stddev, dtype=x.dtype)
    delta = tf.Variable(delta)

    for _ in range(max_steps):
        with tf.GradientTape() as tape:
            tape.watch(delta)
            loss = 0

            _, qois, grads_t = IntegratedGradient(model,
                                                  delta,
                                                  y,
                                                  res=res,
                                                  baseline=baseline,
                                                  num_class=num_class)
            # The KL term compares against the raw per-step path gradients,
            # not a normalized pdf of the current qois.
            adv_pdf = grads_t

            if use_kl:
                if target is not None:
                    loss -= tf.reduce_mean(kl(target, adv_pdf)) * alpha
                else:
                    loss += tf.reduce_mean(kl(pdf, adv_pdf)) * alpha

            loss += tf.reduce_mean(tf.reduce_sum(qois, -1))

        grad = tape.gradient(loss, delta)
        delta = update_delta(x, delta, grad, adv_step_size, adv_epsilon, p,
                             clamp)

        # NOTE(review): for a single-output model argmax is always 0, so this
        # only keeps examples labelled 0 — confirm intended behaviour.
        keep = (labels == tf.argmax(model(delta), -1)).numpy()
        optimal_adv[keep] = delta.numpy()[keep]

    y_pred = model.predict(optimal_adv)

    if y_pred.shape[1] == 1:
        # NOTE(review): IntegratedGradient treats a positive single output as
        # support for label 1, but this puts it in column 0 — confirm.
        y_pred = np.concatenate([y_pred, -y_pred], axis=1)

    if not return_probits:
        y_pred = np.argmax(y_pred, axis=-1)
    return optimal_adv, y_pred, best_diff


def RNS(model, X, Y, batch_size=256, print_grad_norm=True, **kwargs):
    """Robust Neighbor Search: run ``batch_path_attack`` over ``X`` in batches.

    Args:
        model: tf.keras model.
        X: Inputs.
        Y: Integer labels for ``X``.
        batch_size: Number of examples per ``batch_path_attack`` call.
        print_grad_norm: Print the mean input-gradient L2 norm before and
            after the search.
        **kwargs: Forwarded to ``batch_path_attack``.

    Returns:
        (optimal_adv, y_pred, best_diff), each concatenated over batches.
    """
    print('Robust Neighbor Searching in Progress ...')

    pb = tf.keras.utils.Progbar(target=X.shape[0])

    optimal_adv, y_pred, best_diff = [], [], []
    for i in range(0, X.shape[0], batch_size):
        x = X[i:i + batch_size]
        y = Y[i:i + batch_size]
        opt, y_p, diff = batch_path_attack(model, x, y, **kwargs)
        optimal_adv.append(opt)
        y_pred.append(y_p)
        best_diff.append(diff)

        pb.add(x.shape[0])

    optimal_adv = np.concatenate(optimal_adv, 0)
    y_pred = np.concatenate(y_pred, 0)
    best_diff = np.concatenate(best_diff, 0)

    if print_grad_norm:
        # Fall back to batch_path_attack's own default so omitting
        # num_class does not raise KeyError after the whole search.
        num_class = kwargs.get('num_class', 1000)
        grad = Saliency(model, X, Y, num_class=num_class)
        opt_grad = Saliency(model, optimal_adv, Y, num_class=num_class)
        avg_norm = tf.reduce_mean(l2_norm(
            tf.keras.backend.batch_flatten(grad)))
        opt_avg_norm = tf.reduce_mean(
            l2_norm(tf.keras.backend.batch_flatten(opt_grad)))
        print(f"|| df(x) / dx ||_2 before RNS: {avg_norm}")
        print(f"|| df(x) / dx ||_2 after RNS: {opt_avg_norm}")

    return optimal_adv, y_pred, best_diff


def Simple(model, X, Y, batch_size=256, print_grad_norm=True, **kwargs):
    """Run ``simple_attack`` over ``X`` in batches and concatenate the results.

    ``print_grad_norm`` is accepted for signature parity with ``RNS`` but is
    not used here. Extra keyword arguments go to ``simple_attack``.

    Returns:
        (optimal_adv, y_pred, best_diff), each concatenated over batches.
    """
    print('Simple Searching in Progress ...')

    progress = tf.keras.utils.Progbar(target=X.shape[0])
    advs, preds, diffs = [], [], []
    for start in range(0, X.shape[0], batch_size):
        stop = start + batch_size
        x_batch, y_batch = X[start:stop], Y[start:stop]

        adv, pred, diff = simple_attack(model, x_batch, y_batch, **kwargs)
        advs.append(adv)
        preds.append(pred)
        diffs.append(diff)

        progress.add(x_batch.shape[0])

    return (np.concatenate(advs, 0),
            np.concatenate(preds, 0),
            np.concatenate(diffs, 0))


def simple_attack(model,
                  x,
                  y,
                  res=10,
                  baseline=0,
                  max_steps=100,
                  adv_step_size=1e-3,
                  clamp=[0, 1],
                  adv_epsilon=0.1,
                  p=2,
                  num_class=1000,
                  alpha=1,
                  target=None,
                  stddev=1e-3,
                  distance_fn=None,
                  use_kl=True,
                  return_probits=False):
    """Saliency-gradient ascent that keeps the last label-preserving iterate.

    Each step ascends the label-``y`` output gradient (``Saliency``) through
    ``update_delta``. For every input, the last iterate whose argmax
    prediction still equals ``y`` is kept.

    ``res``, ``baseline``, ``alpha``, ``target``, ``stddev``, ``distance_fn``
    and ``use_kl`` are accepted for signature parity with
    ``batch_path_attack`` and are not used.

    Returns:
        (optimal_adv, y_pred, best_diff). ``y_pred`` holds labels, or raw
        outputs when ``return_probits``. ``best_diff`` is a list of zeros
        kept for API compatibility.
    """
    x = tf.constant(x)
    delta = tf.identity(x)
    # tf.argmax returns int64; compare labels in the same dtype, since TF
    # rejects == between int32/float and int64 tensors.
    labels = tf.cast(y, tf.int64)

    best_diff = [0] * x.shape[0]
    # Explicit copy: arrays returned by EagerTensor.numpy() may be read-only.
    optimal_adv = np.array(x)
    ever_changed = np.zeros(x.shape[0], dtype=bool)

    for _ in range(max_steps):
        grad = Saliency(model, delta, y, num_class=num_class)
        delta = update_delta(x, delta, grad, adv_step_size, adv_epsilon, p,
                             clamp)

        keep = (labels == tf.argmax(model(delta), -1)).numpy()
        optimal_adv[keep] = delta.numpy()[keep]
        ever_changed |= ~keep

    if ever_changed.any():
        # The search does not stop early; report once instead of per sample
        # per step.
        print(f"predictions changed for {int(ever_changed.sum())} input(s); "
              "kept their last label-preserving points")

    y_pred = model.predict(optimal_adv)
    if not return_probits:
        y_pred = np.argmax(y_pred, axis=-1)
    return optimal_adv, y_pred, best_diff


def _foolbox_pgd(norm, stepsize, steps):
    """Build a foolbox PGD attack: L2 when ``norm == 'l2'``, otherwise L-inf."""
    # foolbox multiplies rel_stepsize by epsilon. The default 2 / steps gives
    # an absolute step of 2 * eps / steps, so `steps` iterations can move up
    # to 2 * eps in total. (The old 2 * eps / steps became eps**2-scaled.)
    rel_stepsize = stepsize if stepsize is not None else 2. / steps
    attack_cls = fb.attacks.L2PGD if norm == 'l2' else fb.attacks.LinfPGD
    return attack_cls(rel_stepsize=rel_stepsize, steps=steps)


def _pgd_boundary_points(fmodel, x, y_sparse, batch_size, search_range):
    """Run PGD over ``x`` in batches and return the clipped points as numpy."""
    norm, eps, stepsize, steps = search_range[1:5]
    if isinstance(eps, (list, tuple, np.ndarray)):
        eps_list, report_per_eps = list(eps), True
    elif isinstance(eps, (int, float, np.number)) and not isinstance(eps, bool):
        eps_list, report_per_eps = [eps], False
    else:
        raise TypeError(f"Expecting eps as float or list, but got {type(eps)}")
    if not eps_list:
        raise ValueError("Expecting at least one eps value")

    attack = _foolbox_pgd(norm, stepsize, steps)
    points, num_success = [], 0
    for start in range(0, x.shape[0], batch_size):
        batch_x = x[start:start + batch_size]
        batch_y = y_sparse[start:start + batch_size]

        best_points, best_success = None, None
        for e in eps_list:
            _, clipped, success = attack(fmodel, batch_x, batch_y, epsilons=[e])
            cur_points = np.array(clipped[0])  # writable copy
            cur_success = np.array(success[0], dtype=bool)

            if report_per_eps:
                print(
                    f">>> Attacking with EPS={e} (norm={norm}), Success Rate={cur_success.mean()} <<<"
                )

            if best_points is None:
                best_points, best_success = cur_points, cur_success
            else:
                # Adopt a later budget's point only for inputs that every
                # earlier budget failed to flip.
                newly = cur_success & ~best_success
                best_points[newly] = cur_points[newly]
                best_success = best_success | newly

        points.append(best_points)
        num_success += int(best_success.sum())

    if not report_per_eps:
        print(
            f">>> Attacking with EPS={eps} (norm={norm}), Success Rate={num_success / x.shape[0]} <<<"
        )
    return np.concatenate(points, axis=0)


def _cw_boundary_points(fmodel, x, y_sparse, batch_size, search_range,
                        confidence):
    """Run L2 Carlini-Wagner over ``x`` in batches; return clipped points as numpy."""
    norm, eps, stepsize, steps = search_range[1:5]
    attack = fb.attacks.L2CarliniWagnerAttack(
        stepsize=stepsize if stepsize is not None else 2 * eps / steps,
        steps=steps,
        confidence=confidence)

    points, num_success = [], 0.
    for start in range(0, x.shape[0], batch_size):
        _, clipped, success = attack(fmodel,
                                     x[start:start + batch_size],
                                     y_sparse[start:start + batch_size],
                                     epsilons=[eps])
        points.append(np.array(clipped[0]))
        num_success += float(tf.reduce_sum(tf.cast(success, tf.float32)))

    print(
        f">>> Attacking with EPS={eps} (norm={norm}), Success Rate={num_success / x.shape[0]} <<<"
    )
    return np.concatenate(points, axis=0)


def get_boundary_points(model,
                        x,
                        y_sparse,
                        batch_size=64,
                        pipeline=['pgd'],
                        search_range=['local', 'l2', 0.3, None, 100],
                        clamp=[0, 1],
                        backend='pytorch',
                        confidence=1.0,
                        **kwargs):
    """Find nearby boundary points by running adversarial attacks.

    Reference: Boundary Attributions for Normal (Vector) Explanations https://arxiv.org/pdf/2103.11257.pdf

    https://github.com/zifanw/boundary

    Args:
        model (tf.keras.Model): Model to attack (only the 'tf.keras' backend
            is implemented).
        x (np.ndarray): Benign inputs.
        y_sparse (np.ndarray): Integer class labels for ``x``.
        batch_size (int, optional): Batch size. Defaults to 64.
        pipeline (str or list of str, optional): Attacks to run, any of 'pgd'
            and 'cw'. If both are given, 'cw' runs last and its points are
            returned. Defaults to ['pgd'].
        search_range (list, optional): ``[scope, norm, epsilon, stepsize, steps]``.
            ``scope`` is not read. ``norm`` 'l2' selects L2 PGD; any other
            value selects L-inf PGD (CW is always L2). ``epsilon`` is a number
            or, for PGD only, a list of budgets tried in order: an input keeps
            the point from the first budget that flips its label, otherwise
            the point from the first budget. ``stepsize`` is foolbox's
            ``rel_stepsize`` for PGD and the absolute step size for CW; ``None``
            picks a default. ``steps`` is the iteration count.
            Defaults to ['local', 'l2', 0.3, None, 100].
        clamp (list, optional): Data range ``[low, high]``; non-float bounds
            are reduced to their global min / max. Defaults to [0, 1].
        backend (str, optional): Must be 'tf.keras'; other values raise.
            Defaults to 'pytorch'.
        confidence (float, optional): CW confidence. Defaults to 1.0.
        **kwargs: Accepted for compatibility and ignored.

    Returns:
        np.ndarray: One attacked point per input, as returned (clipped to the
        epsilon budget) by foolbox.

    Raises:
        ValueError: If ``pipeline`` names nothing, or names an unknown attack.
        NotImplementedError: If ``backend`` is not 'tf.keras'.
        TypeError: If the PGD epsilon is neither a number nor a list.
    """
    # Accept a single attack name or a collection of names.
    pipelines = [pipeline] if isinstance(pipeline, str) else list(pipeline)
    unknown = [name for name in pipelines if name not in ('pgd', 'cw')]
    if not pipelines or unknown:
        raise ValueError(
            f"pipeline must name 'pgd' and/or 'cw', got {pipeline!r}")
    if backend != 'tf.keras':
        raise NotImplementedError(
            f"Only the 'tf.keras' backend is implemented, got {backend!r}")

    if not isinstance(clamp[0], float):
        clamp = [np.min(clamp[0]), np.max(clamp[1])]

    fmodel = fb.TensorFlowModel(model, bounds=(clamp[0], clamp[1]))
    x = tf.constant(x, dtype=tf.float32)
    y_sparse = tf.constant(y_sparse, dtype=tf.int32)

    boundary_points = None
    if 'pgd' in pipelines:
        print(">>> Start PGD Attack <<<", end='\n', flush=True)
        boundary_points = _pgd_boundary_points(fmodel, x, y_sparse,
                                               batch_size, search_range)

    if 'cw' in pipelines:
        print(">>> Start CW Attack <<<", end='\n', flush=True)
        boundary_points = _cw_boundary_points(fmodel, x, y_sparse, batch_size,
                                              search_range, confidence)

    return convert_to_numpy(boundary_points)


def convert_to_numpy(x):
    """Return ``x`` as a numpy array.

    numpy arrays are returned as-is (same object); anything else must provide
    a ``.numpy()`` method (e.g. an eager TensorFlow tensor), whose result is
    returned.

    Reference: Boundary Attributions for Normal (Vector) Explanations https://arxiv.org/pdf/2103.11257.pdf

    https://github.com/zifanw/boundary
    """
    return x if isinstance(x, np.ndarray) else x.numpy()
