import tensorflow as tf
import os
import numpy as np
import torch
import argparse

def get_args():
    """Parse the command-line options shared by the helpers in this module.

    Returns:
        argparse.Namespace with two integer fields:
          iter -- default 1000 (not read by the code in this module; presumably
                  an iteration budget used elsewhere — confirm with callers)
          d    -- default 3072; the length of the flattened direction vector
                  that other helpers reshape to (1, 32, 32, 3)
    """
    cli = argparse.ArgumentParser()
    for flag, default in (('--iter', 1000), ('--d', 3072)):
        cli.add_argument(flag, default=default, type=int)
    return cli.parse_args()

def del_element(arr, index):
    """Return a copy of ``arr`` with the entry at ``index`` along axis 0 removed.

    Args:
        arr: array-like; the element (or row, for N-D input) at ``index`` on
            the first axis is dropped.
        index: position to remove. Negative values count from the end, as in
            ordinary Python indexing.

    Returns:
        A new ndarray one shorter than ``arr`` along axis 0.

    Raises:
        IndexError: if ``index`` is out of range for the first axis.
    """
    # np.delete normalises negative indices. Hand-rolled slicing
    # (arr[:index] + arr[index+1:]) breaks for index == -1, where
    # arr[index+1:] is the entire array and data gets duplicated.
    return np.delete(arr, index, axis=0)

def generate_u(s2, d=None, shape=(1, 32, 32, 3)):
    """Sample a random sparse direction with unit L2 norm.

    Draws ``s2`` standard-normal values and normalises them to unit L2 norm.
    It then scatters them into ``s2`` distinct coordinates of a length-``d``
    zero vector and reshapes the result to ``shape``.

    Args:
        s2: number of non-zero coordinates; must satisfy 0 < s2 <= d.
        d: total number of coordinates. When None (the default), it is read
            from the ``--d`` command-line option via ``get_args()``.
        shape: shape of the returned array; its element count must equal
            ``d``. Defaults to (1, 32, 32, 3).

    Returns:
        ndarray of ``shape`` with ``s2`` non-zero entries (almost surely) and
        unit L2 norm.

    Raises:
        ValueError: if ``s2`` > ``d`` (too few coordinates to choose from), or
            if ``shape`` is incompatible with ``d``.
    """
    if d is None:
        d = get_args().d
    values = np.random.randn(s2)
    values = values / np.linalg.norm(values)
    # Choose distinct positions over the full range [0, d). Sampling with
    # replacement would let a later value overwrite an earlier one, leaving
    # fewer than s2 non-zeros and a norm below 1. Starting the range at 1
    # would make coordinate 0 unreachable.
    index = np.random.choice(d, size=s2, replace=False)
    u = np.zeros(d)
    u[index] = values
    return np.reshape(u, shape)

def f(model, X_adv, y):
    """Log-margin loss of column ``y`` against the best competing column.

    For each row of ``out = model.model.predict(X_adv)``, computes in float64:

        max(log(out[:, y] + 1e-10) - log(max_{j != y} out[:, j] + 1e-10), -1e-10)

    The competing maximum is taken after setting column ``y`` to 0 in a copy,
    so the model's output array is not modified.

    NOTE(review): zeroing column y (rather than masking it with -inf) and
    taking logs assumes the outputs are non-negative probabilities — confirm.

    Args:
        model: object exposing ``model.model.predict``.
        X_adv: input batch passed straight to ``predict``.
        y: column index into the prediction (presumably the class label;
            assumed to be a scalar int).

    Returns:
        float64 ndarray with one loss value per output row.
    """
    eps = 1e-10
    scores = model.model.predict(X_adv)
    masked = scores.copy()
    masked[:, y] = 0
    log_target = np.log(scores[:, y].astype(np.float64) + eps)
    log_rival = np.log(np.amax(masked, axis=1).astype(np.float64) + eps)
    return np.maximum(log_target - log_rival, -eps)

def compute_gradient(model, X_adv, y, s2, miu, q):
    """Estimate the gradient of ``f`` at ``X_adv`` from random finite differences.

    Averages ``q`` one-sided finite differences, each taken along a sparse
    random unit direction ``u_i`` from ``generate_u``:

        g = (1/q) * sum_i (d / miu) * (f(X_adv + miu*u_i) - f(X_adv)) * u_i

    where ``d`` is the ``--d`` command-line option.

    Args:
        model: object exposing ``model.model.predict`` (forwarded to ``f``).
        X_adv: current input; it must be addable to the (1, 32, 32, 3)
            directions produced by ``generate_u``.
        y: column index forwarded to ``f``.
        s2: number of non-zero coordinates in each random direction.
        miu: finite-difference step size.
        q: number of random directions to average; must be >= 1.

    Returns:
        ndarray of shape (1, 32, 32, 3).

    Raises:
        ValueError: if ``q`` < 1 (the average would be 0/0).
    """
    if q < 1:
        raise ValueError(f"q must be >= 1, got {q}")
    d = get_args().d
    scale = d / miu
    # f(X_adv) does not depend on the sampled direction, so query the model
    # for it once rather than once per direction.
    baseline = f(model, X_adv, y)
    samples = np.zeros((q, d))
    for i in range(q):
        u = generate_u(s2)
        perturbed = f(model, X_adv + miu * u, y)
        samples[i] = scale * (perturbed - baseline) * np.reshape(u, (1, d))
    gradient = np.sum(samples, axis=0) / q
    return np.reshape(gradient, (1, 32, 32, 3))


def generate_data(data, id, target_label):
    """Build a one-example input batch and a matching one-hot target.

    Args:
        data: object exposing ``test_data`` (indexable) and ``test_labels``
            (at least 2-D; ``test_labels.shape[1]`` is used as the class count).
        id: index of the example to take from ``data.test_data``.
        target_label: class position set to 1 in the one-hot vector
            (assumed to be a scalar int).

    Returns:
        (inputs, target_vec):
          inputs -- ``data.test_data[id]`` with a leading batch axis of length 1;
          target_vec -- one-hot float array of shape (1, num_classes).
    """
    inputs = np.array([data.test_data[id]])
    num_classes = data.test_labels.shape[1]
    target_vec = np.array([np.eye(num_classes)[target_label]])
    return inputs, target_vec

def model_prediction(model, inputs):
    """Run the model on ``inputs`` and report the predicted class.

    Args:
        model: object exposing ``model.model.predict``.
        inputs: batch passed straight to ``predict``.

    Returns:
        (prob, predicted_class, prob_str):
          prob -- the raw output of ``predict``;
          predicted_class -- arg-max over the last axis. This is a scalar when
            the output holds a single example, or one index per row for a batch;
          prob_str -- ``prob`` rendered by ``np.array2string`` with newlines
            removed, so it fits on one line.
    """
    prob = model.model.predict(inputs)
    # Arg-max per example along the class axis. A bare np.argmax(prob)
    # flattens a batch and returns an index into the flattened array rather
    # than a class id.
    predicted_class = np.argmax(prob, axis=-1)
    if np.size(predicted_class) == 1:
        # Preserve the scalar return for the single-example case.
        predicted_class = np.ravel(predicted_class)[0]
    prob_str = np.array2string(prob).replace('\n', '')
    return prob, predicted_class, prob_str