#!/usr/bin/env python
# coding: utf-8

# In[1]:


import tensorflow
import tensorflow.compat.v1 as tf
import numpy as np
from setup_cifar import CIFAR, CIFARModel
from PIL import Image
import Utils_CIFAR as util
import numpy as np
import matplotlib.pyplot as plt
import random

import os
import sys
from setting import idx

def _parse_beta(argv):
    """Return the float given after ``--beta`` in *argv*.

    ``beta`` scales the (current - anchor) gradient difference in the update
    step of the attack.  There is no usable default, and ``round(beta, 3)``
    below needs a number, so stop here with a clear message when the flag is
    missing or its value is not a float.
    """
    if '--beta' not in argv:
        sys.exit("Error: you should specify beta with --beta <float>.")
    try:
        return float(argv[argv.index('--beta') + 1])
    except (IndexError, ValueError):
        sys.exit("Error: invalid value for --beta; expected a float, e.g. --beta 1.0.")


beta = _parse_beta(sys.argv)

# Example of an alternative image selection:
# idx = [65, 67, 70, 75, 86, 113, 123, 129, 138, 149] # bird
path = r'./vrzht/'
# One image folder per beta value (rounded to 3 decimals).
pic_path = path + str(round(beta, 3)) + r'/'
result_path = path + 'results/'

for p in (path, pic_path, result_path):
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(p, exist_ok=True)


# In[2]:


def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):
    """Zeroth-order estimate of the batch-averaged attack-loss gradient at ``delta``.

    For every image id in ``imgs_idx`` a random-direction finite-difference
    estimate is formed from ``q`` directions ``u_j = util.generate_u(s2, d)``::

        g_i = (1/q) * sum_j  d/miu * (f(x_i + delta + miu*u_j) - f(x_i + delta)) * u_j

    where ``f = util.f`` evaluated at the image's ground-truth label, and the
    per-image estimates are averaged.  Perturbed images are NOT clipped here
    (``compute_batch_loss`` clips before evaluating).

    Args:
        delta: shared perturbation; must have ``d`` elements in total.
        imgs_idx: non-empty sequence of indices into ``data.test_data``.
        model: classifier forwarded to the ``util`` helpers.
        data: dataset exposing ``test_data`` and one-hot ``test_labels``.
        q: number of random directions per image.
        d: flattened input dimension.
        miu: finite-difference radius.
        s2: forwarded unchanged to ``util.generate_u``.

    Returns:
        np.ndarray with the same shape as ``delta``: the averaged estimate.

    Raises:
        ValueError: if ``imgs_idx`` is empty.
    """
    if len(imgs_idx) == 0:
        raise ValueError("imgs_idx must contain at least one image id")

    # Ground-truth labels do not depend on the image being processed.
    true_label_list = np.argmax(data.test_labels, axis=1)
    gradient_total = np.zeros_like(delta)
    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        true_label = true_label_list[image_id]

        X_adv = orig_img + delta
        # The unperturbed loss is identical for every direction, so query it once:
        # q + 1 evaluations per image, matching what the main loop charges to `nizo`.
        base_loss = util.f(model, X_adv, true_label)

        gradient_i = np.zeros((q, d))
        for i in range(q):
            u = util.generate_u(s2, d)
            predict_1 = util.f(model, X_adv + miu * u, true_label)
            gradient_i[i] = d / miu * (predict_1 - base_loss) * np.reshape(u, (1, d))
        # Average over directions and bring back to the perturbation's shape.
        gradient_total += np.reshape(np.sum(gradient_i, axis=0) / q, np.shape(delta))
    return gradient_total / len(imgs_idx)
    
def compute_batch_loss(delta, imgs_idx, model, data):
    """Average attack loss ``util.f`` over ``imgs_idx`` with ``delta`` applied.

    Each perturbed image ``orig_img + delta`` is clipped to [-0.5, 0.5] before
    being evaluated at the image's ground-truth label.

    Args:
        delta: shared perturbation added to every image.
        imgs_idx: non-empty sequence of indices into ``data.test_data``.
        model: classifier forwarded to the ``util`` helpers.
        data: dataset exposing ``test_data`` and one-hot ``test_labels``.

    Returns:
        The mean of ``util.f`` over the images (same type ``util.f`` yields,
        divided by the number of images).

    Raises:
        ValueError: if ``imgs_idx`` is empty.
    """
    if len(imgs_idx) == 0:
        raise ValueError("imgs_idx must contain at least one image id")

    # Loop invariant: ground-truth class of every test image.
    true_label_list = np.argmax(data.test_labels, axis=1)
    total_loss = 0
    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)
        total_loss += util.f(model, adv_image, true_label_list[image_id])
    return total_loss / len(imgs_idx)


# In[3]:


import warnings
warnings.filterwarnings("ignore")
from collections import namedtuple

# Attack hyper-parameters, as used below:
#   iter  - not read anywhere in this script
#   miu   - finite-difference radius of the zeroth-order gradient estimates
#   d     - flattened input dimension (32 * 32 * 3 = 3072)
#   k     - number of coordinates kept by the hard-thresholding step
#   eta   - step size of the update
#   q     - number of random directions per zeroth-order gradient estimate
#   s2    - forwarded to the util gradient helpers (meaning not visible here)
#   lname - log file that every iteration appends to
Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])

args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=0.01, k=60, d=3072, lname='./CIFAR/result.txt')
zomax = 600  # budget of zeroth-order function queries, tracked in `nizo`
m = 10       # maximum number of inner iterations per anchor (outer) iteration
np.random.seed(42)
random.seed(42)
tf.set_random_seed(42)
RS = np.random.RandomState(42)

# The log is opened in append mode below; its directory must exist first.
os.makedirs(os.path.dirname(args.lname) or '.', exist_ok=True)

with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)

    # Images attacked jointly by the single shared perturbation `delta`.
    imgs_idx = idx
    # Ground-truth class of every test image; loop invariant, so computed once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    delta = np.zeros((1, 32, 32, 3))
    hist_loss = []  # batch loss at each recorded step
    hist_zo = []    # query count `nizo` at each recorded step
    hist_nht = []   # hard-thresholding count `nht` at each recorded step
    nizo = 0
    nht = 0

    total_loss = compute_batch_loss(delta, imgs_idx, model, data)
    print(f"total loss: {total_loss}")
    hist_loss.append(total_loss.item())
    hist_zo.append(nizo)
    hist_nht.append(nht)

    while nizo < zomax:
        # Outer iteration: snapshot the perturbation and estimate the
        # full-batch gradient at that anchor.
        inner_its = 0
        anchor = delta + 0.  # a copy, not an alias

        full_grad = compute_batch_gradient(anchor, imgs_idx, model, data, args.q, args.d, args.miu, args.s2)
        nizo += len(imgs_idx) * (args.q + 1)

        while (inner_its < m) and (nizo < zomax):
            # Inner iteration: one randomly drawn image drives the update.
            image_id = RS.choice(imgs_idx)
            _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))
            target_label = orig_class
            orig_img, target = util.generate_data(data, image_id, target_label)
            true_label = true_label_list[image_id]
            with open(args.lname, 'a+') as f:
                f.write("\n Image ID:{}, infer label:{}, true label:{} \n".format(image_id, orig_class, true_label))
            print("Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label))
            if true_label != orig_class:
                # The attack assumes every image in imgs_idx is classified
                # correctly. (Raising a plain string is itself a TypeError.)
                raise ValueError(
                    "True label {} of image {} differs from the model prediction {}".format(
                        true_label, image_id, orig_class))

            # Stochastic gradient estimates at the current point and at the anchor.
            current_gradient = util.compute_gradient(model, orig_img + delta, true_label,
                                                     args.s2, args.miu, args.q, args.d)
            nizo += args.q + 1
            anchor_gradient = util.compute_gradient(model, orig_img + anchor, true_label,
                                                    args.s2, args.miu, args.q, args.d)
            nizo += args.q + 1

            # Variance-reduced step: full-batch anchor gradient plus the
            # beta-weighted correction (current - anchor).
            delta_tmp = delta - args.eta * (beta * (current_gradient - anchor_gradient) + full_grad)

            # Hard thresholding: keep only the k largest-magnitude coordinates.
            delta_tmp = np.reshape(delta_tmp, (args.d,))
            top_k_idx = np.argsort(-np.abs(delta_tmp))[:args.k]
            delta = np.zeros_like(delta_tmp)
            delta[top_k_idx] = delta_tmp[top_k_idx]

            inner_its += 1
            nht += 1

            l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)
            l0_dist = np.count_nonzero(delta) / args.d

            delta = np.reshape(delta, (1, 32, 32, 3))
            adv_image = np.clip(orig_img + delta, -0.5, 0.5)
            _, attack_predict_class, _ = util.model_prediction(model, adv_image)

            # Log every step; the attack succeeds when the prediction changes.
            status = "Succ" if true_label != attack_predict_class else "Fail"
            message = "Iter %d (%s): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
                nht + 1, status, image_id, l0_dist, l2_dist, true_label, attack_predict_class)
            with open(args.lname, 'a+') as f:
                f.write(message + " \n")
            print(message)

            total_loss = compute_batch_loss(delta, imgs_idx, model, data)
            print(f"total loss: {total_loss}")
            hist_loss.append(total_loss.item())
            hist_zo.append(nizo)
            hist_nht.append(nht)



# # Inference

# In[4]:


# At the end of training, save each attacked image and its original to `pic_path`.
label_list = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}


def _save_image(img, filename):
    """Save *img* (reshapable to 32x32x3, values in [-0.5, 0.5]) as a borderless figure.

    Callers clip to [-0.5, 0.5] first. The figure is closed after saving so that
    repeated calls do not pile up open figures (which a later ``plt.show()``
    would otherwise display all at once).
    """
    pixels = ((np.reshape(img, (32, 32, 3)) + 0.5) * 255).astype(np.uint8)
    plt.figure()
    plt.imshow(Image.fromarray(pixels, "RGB"))
    plt.axis('off')
    plt.savefig(filename, bbox_inches='tight')
    plt.close()


with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)

    for i in imgs_idx:
        # Classify the clean image first so generate_data gets this image's own
        # class, not a label left over from the training loop.
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[i], axis=0))
        orig_img, target = util.generate_data(data, i, orig_class)
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)

        print(f'----- sample {i} --------')
        print(f"original class: {label_list[orig_class]}")
        _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0], axis=0))
        print(f"predicted class: {label_list[new_class]}")
        print('------------------------')

        _save_image(adv_image, pic_path + f'attacked_{i}.jpg')
        _save_image(np.clip(orig_img, -0.5, 0.5), pic_path + f'original_{i}.jpg')

        print(f"l2 distortion: {np.linalg.norm(delta)} in input space")


print(delta)
resh = np.reshape(delta, (32, 32, 3))
# delta is signed, and casting negative floats straight to uint8 is undefined.
# Shift it by 0.5 (zero perturbation -> mid-gray) at the same scale used for the
# images, then clip to the valid byte range.
im = Image.fromarray(np.clip((resh + 0.5) * 255, 0, 255).astype(np.uint8), "RGB")
print(np.count_nonzero(delta))


plt.figure()
plt.imshow(im)
plt.axis('off')
plt.savefig(pic_path + 'perturb.jpg', bbox_inches='tight')
plt.close()  # saved to disk; don't leave it open for the later plt.show() calls


# In[5]:


# pickle the results for further loading
import pickle

# Persist the loss curves. The file name encodes beta (rounded to 3 decimals),
# so runs with different beta values do not overwrite each other.
results = dict(
    nizo=hist_zo,
    nht=hist_nht,
    hist=hist_loss,
)
print(hist_loss)
curves_file = result_path + f'{str(round(beta, 3))}_vrzht_curves.pickle'
with open(curves_file, 'wb') as handle:
    pickle.dump(results, handle)


# In[6]:


# Plot the batch loss against the number of zeroth-order queries (IZO) and against the number of hard-thresholding steps (NHT):
# Batch loss versus the number of zeroth-order queries spent.
izo_style = dict(linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
plt.figure()
plt.plot(hist_zo, hist_loss, **izo_style)
plt.legend()
plt.xlabel('# IZO')
plt.ylabel('F(w)')
# To also save it: plt.savefig(result_path + str(round(beta, 3)) + "_IZO.png")
plt.show()


# In[7]:


# Batch loss versus the number of hard-thresholding steps taken.
nht_style = dict(linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
plt.figure()
plt.plot(hist_nht, hist_loss, **nht_style)
plt.legend()
plt.xlabel('# NHT')
plt.ylabel('F(w)')
# To also save it: plt.savefig(result_path + str(round(beta, 3)) + "_NHT.png")
plt.show()


# In[ ]:




