#!/usr/bin/env python
# coding: utf-8

# In[8]:


import tensorflow
import tensorflow.compat.v1 as tf
import numpy as np
from setup_cifar import CIFAR, CIFARModel
from PIL import Image
import Utils_CIFAR as util
import numpy as np
import matplotlib.pyplot as plt
import random
import sys
from settings import imgs_idx


# In[2]:




# Attack step size (learning rate). There is no sensible default, so it must be
# given on the command line as `--eta <float>`. Without a numeric value the
# script stops here instead of failing much later inside the gradient step
# (`args.eta * full_grad`) after the model has already been loaded and queried.
eta = None

# Check if '--eta' is provided in the command line arguments
if '--eta' in sys.argv:
    # Get the index of '--eta' in the argument list
    index_eta = sys.argv.index('--eta')

    # Try to get the value following '--eta'
    try:
        eta = float(sys.argv[index_eta + 1])
    except (IndexError, ValueError):
        print("Invalid value for --eta.")

if eta is None:
    sys.exit("Error: you should specify eta with --eta <float>")


def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):
    """Estimate the batch-averaged gradient of the attack loss w.r.t. ``delta``.

    For each image ``x`` selected by ``imgs_idx`` a zeroth-order
    (random-direction finite-difference) estimate is built from ``q``
    directions ``u_j`` drawn by ``util.generate_u(s2, d)``::

        g(x) = 1/q * sum_j  d/miu * (f(x + delta + miu*u_j) - f(x + delta)) * u_j

    where ``f`` is ``util.f`` evaluated against the image's label. The
    per-image estimates are then averaged over the batch.

    Args:
        delta: current perturbation; the result has the same shape
            (``(1, 32, 32, 3)`` in this script).
        imgs_idx: non-empty sequence of indices into ``data.test_data``.
        model: classifier forwarded to the ``util`` helpers.
        data: dataset object exposing ``test_data`` and ``test_labels``
            (class index taken with ``argmax`` along axis 1).
        q: number of random directions per image.
        d: size of the flattened perturbation (``delta.size``).
        miu: finite-difference smoothing radius.
        s2: forwarded unchanged to ``util.generate_u``.

    Returns:
        Array shaped like ``delta`` holding the averaged gradient estimate.

    Raises:
        ValueError: if ``imgs_idx`` is empty (the average is undefined).

    Each image costs ``q + 1`` calls to ``util.f``: one base evaluation plus one
    per direction. This matches the ``len(imgs_idx) * (q + 1)`` query count the
    training loop adds after calling this function.
    """
    if len(imgs_idx) == 0:
        raise ValueError("imgs_idx must contain at least one image index")

    # Class index of every test image; identical for each iteration, so compute once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    gradient_total = np.zeros_like(delta)
    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))
        target_label = orig_class
        orig_img, _ = util.generate_data(data, image_id, target_label)
        true_label = true_label_list[image_id]

        X_adv = orig_img + delta
        # The unperturbed loss does not depend on the sampled direction, so it
        # is evaluated once per image rather than once per direction
        # (assumes util.f is deterministic, e.g. a plain model evaluation).
        predict_2 = util.f(model, X_adv, true_label)

        gradient = np.zeros(d)
        for _ in range(q):
            u = util.generate_u(s2, d)
            predict_1 = util.f(model, X_adv + miu * u, true_label)
            u = np.reshape(u, (1, d))
            gradient += np.reshape(d / miu * (predict_1 - predict_2) * u, d)
        gradient /= q

        gradient_total += np.reshape(gradient, np.shape(delta))
    return gradient_total / len(imgs_idx)
    
def compute_batch_loss(delta, imgs_idx, model, data):
    """Return the attack loss ``util.f`` averaged over a batch of test images.

    Each selected image is perturbed by ``delta``, then clipped to the valid
    pixel range ``[-0.5, 0.5]``, then scored with ``util.f`` against the class
    index obtained from ``argmax`` of its row in ``data.test_labels``.

    Args:
        delta: perturbation added to every image.
        imgs_idx: non-empty sequence of indices into ``data.test_data``.
        model: classifier forwarded to the ``util`` helpers.
        data: dataset object exposing ``test_data`` and ``test_labels``.

    Returns:
        The mean of ``util.f`` over the batch (whatever numeric type
        ``util.f`` returns, divided by the batch size).

    Raises:
        ValueError: if ``imgs_idx`` is empty (the mean is undefined).
    """
    if len(imgs_idx) == 0:
        raise ValueError("imgs_idx must contain at least one image index")

    # Class index of every test image; identical for each iteration, so compute once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    total_loss = 0
    for image_id in imgs_idx:
        # generate_data is given the model's own prediction as its target label.
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)
        total_loss += util.f(model, adv_image, true_label_list[image_id])
    return total_loss / len(imgs_idx)


# In[15]:


import warnings
warnings.filterwarnings("ignore")
from collections import namedtuple

# Attack hyper-parameters:
#   iter  - not referenced by this script
#   miu   - finite-difference smoothing radius of the zeroth-order estimator
#   d     - size of the flattened perturbation (32 * 32 * 3)
#   k     - number of coordinates kept by hard thresholding
#   eta   - step size parsed from --eta
#   q     - random directions per gradient estimate
#   s2    - forwarded to the util gradient helpers
#   lname - log file that receives one line per sampled image
Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])

args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=eta, k=60, d=3072, lname='./CIFAR/result_sarah.txt')
zomax = 600  # budget of zeroth-order loss queries (the "# IZO" axis)
m = 10       # maximum number of inner iterations per outer epoch
np.random.seed(42)
random.seed(42)
tf.set_random_seed(42)
RS = np.random.RandomState(42)

with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)

    # Class index of every test image; loop-invariant, so compute it once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    delta = np.zeros((1, 32, 32, 3))
    hist_loss = []
    hist_zo = []
    hist_nht = []
    nizo = 0  # zeroth-order queries spent so far
    nht = 0   # hard-thresholding steps taken so far

    total_loss = compute_batch_loss(delta, imgs_idx, model, data)
    print(f"total loss: {total_loss}")
    hist_loss.append(total_loss.item())
    hist_zo.append(nizo)
    hist_nht.append(nht)

    # Best perturbation seen so far. `min_delta` is initialised here so it is
    # always defined for the inference cell, even when the query budget is
    # exhausted by the first full-gradient evaluation before any step is taken.
    min_loss = np.inf
    min_delta = delta.copy()
    while nizo < zomax:
        # Outer epoch: full batch gradient estimate at the current iterate.
        anchor = delta + 0.
        full_grad = compute_batch_gradient(anchor, imgs_idx, model, data, args.q, args.d, args.miu, args.s2)
        nizo += len(imgs_idx) * (args.q + 1)

        inner_its = 0
        while (inner_its < m) and (nizo < zomax):
            if inner_its > 0:
                # SARAH recursive estimate: v_t = g_i(x_t) - g_i(x_{t-1}) + v_{t-1},
                # where g_i is the zeroth-order gradient of one randomly drawn
                # image. On the first inner iteration the full batch gradient
                # is used unchanged, so that step is a true gradient step.
                image_id = RS.choice(imgs_idx)
                _, orig_class, _ = util.model_prediction(model,
                                                         np.expand_dims(data.test_data[image_id], axis=0))
                target_label = orig_class
                orig_img, _ = util.generate_data(data, image_id, target_label)
                true_label = true_label_list[image_id]
                with open(args.lname, 'a+') as f:
                    f.write("\n Image ID:{}, infer label:{}, true label:{} \n".format(image_id, orig_class, true_label))
                print("Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label))
                if true_label != orig_class:
                    # Must be an exception instance: raising a plain string is
                    # itself a TypeError and hides this message.
                    raise ValueError("True Label is different from the original prediction, pass!")

                adv_image = orig_img + delta
                current_gradient = util.compute_gradient(model, adv_image, true_label, args.s2, args.miu, args.q, args.d)
                nizo += args.q + 1

                anchor_gradient = util.compute_gradient(model, orig_img + anchor, true_label, args.s2, args.miu, args.q, args.d)
                nizo += args.q + 1

                full_grad = current_gradient - anchor_gradient + full_grad

            # The current iterate becomes x_{t-1} for the next recursive update.
            anchor = delta + 0.

            # Gradient step followed by hard thresholding: keep only the k
            # largest-magnitude coordinates, zeroing the rest.
            delta_tmp = np.reshape(delta - args.eta * full_grad, (args.d))
            top_k_idx = np.argsort(-np.abs(delta_tmp))[0:args.k]
            delta = np.zeros_like(delta_tmp)
            delta[top_k_idx] = delta_tmp[top_k_idx]

            inner_its += 1
            nht += 1

            # Distortion diagnostics of the flattened sparse perturbation.
            l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)
            l0_dist = np.count_nonzero(delta) / args.d

            delta = np.reshape(delta, (1, 32, 32, 3))

            # Monitor the batch loss after every step. These evaluations are
            # not added to `nizo`.
            total_loss = compute_batch_loss(delta, imgs_idx, model, data)
            if total_loss < min_loss:
                min_loss = total_loss
                min_delta = delta
            print(f"total loss: {total_loss}")
            hist_loss.append(total_loss.item())
            hist_zo.append(nizo)
            hist_nht.append(nht)


# # Inference

# In[16]:


delta = min_delta
# At the end of training, save the original image, the attacked image and the
# perturbation under ./sarah/, and report the model's prediction on each
# attacked image.
label_list = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}
with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar', sess, True)

    for i in imgs_idx:
        # Label handed to generate_data, derived from this image's own
        # prediction (as compute_batch_loss does) instead of reusing whatever
        # `target_label` the training loop happened to leave behind.
        _, target_label, _ = util.model_prediction(model, np.expand_dims(data.test_data[i], axis=0))
        orig_img, target = util.generate_data(data, i, target_label)
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)
        resh = np.reshape(adv_image, (32, 32, 3))

        # adv_image is clipped to [-0.5, 0.5]; shift to [0, 1] and scale to 8-bit.
        im = Image.fromarray(((resh + 0.5) * 255).astype(np.uint8), "RGB")

        _, orig_class, _ = util.model_prediction(model, np.expand_dims(orig_img[0], axis=0))
        print(f'----- sample {i} --------')
        print(f"original class: {label_list[orig_class]}")
        _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0], axis=0))
        print(f"predicted class: {label_list[new_class]}")
        print('------------------------')

        fig = plt.figure()
        plt.imshow(im)
        plt.axis('off')
        plt.savefig(f'./sarah/{eta}_attacked_{i}_{label_list[new_class]}.jpg', bbox_inches='tight')
        # Close each figure once saved, so that two figures per image do not
        # accumulate and all pop up at the next plt.show().
        plt.close(fig)

        adv_image = np.clip(orig_img, -0.5, 0.5)
        resh = np.reshape(adv_image, (32, 32, 3))
        im = Image.fromarray(((resh + 0.5) * 255).astype(np.uint8), "RGB")

        fig = plt.figure()
        plt.imshow(im)
        plt.axis('off')
        plt.savefig(f'./sarah/{eta}_original_{i}.jpg', bbox_inches='tight')
        plt.close(fig)

        print(f"l2 distortion: {np.linalg.norm(delta)} in input space")


print(delta)
resh = np.reshape(delta, (32, 32, 3))
# delta is signed (it comes from gradient steps). Shift it by 0.5 like the
# images above and clip to [0, 1] before the uint8 cast; negative floats are
# not representable in uint8, and casting them directly yields garbage pixels.
im = Image.fromarray((np.clip(resh + 0.5, 0.0, 1.0) * 255).astype(np.uint8), "RGB")
print(np.count_nonzero(delta))


fig = plt.figure()
plt.imshow(im)
plt.axis('off')
plt.savefig(f'./sarah/{eta}_perturb.jpg', bbox_inches='tight')
plt.close(fig)


# In[19]:


def _plot_loss_curve(x_values, x_label):
    """Show the recorded batch loss ``hist_loss`` against ``x_values``."""
    _, ax = plt.subplots()
    ax.plot(x_values, hist_loss, linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
    ax.legend()
    ax.set_ylabel('F(w)')
    ax.set_xlabel(x_label)
    plt.show()


# Loss versus number of zeroth-order queries.
_plot_loss_curve(hist_zo, '# IZO')


# In[20]:


# Loss versus number of hard-thresholding steps.
_plot_loss_curve(hist_nht, '# NHT')


# In[7]:


import pickle

# Persist the training curves so runs with different step sizes can be compared.
curves = dict(nizo=hist_zo, nht=hist_nht, hist=hist_loss)
with open(f'./sarah/{eta}_curves.pickle', mode='wb') as handle:
    pickle.dump(curves, handle)


# In[ ]:




