#!/usr/bin/env python
# coding: utf-8

# In[25]:


import tensorflow
import tensorflow.compat.v1 as tf
import numpy as np
from setup_cifar import CIFAR, CIFARModel
from PIL import Image
import Utils_CIFAR as util
import numpy as np
import matplotlib.pyplot as plt
import random
import sys
from settings import imgs_idx

# Step size of the attack. The placeholder string below stays in place unless
# a valid number follows '--eta' on the command line.
eta = 'Error you should specify eta'

try:
    # Position of the '--eta' flag among the command line arguments
    index_eta = sys.argv.index('--eta')
except ValueError:
    # Flag not given: keep the placeholder
    pass
else:
    try:
        # The value is expected right after the flag
        eta = float(sys.argv[index_eta + 1])
    except (IndexError, ValueError):
        print("Invalid value for --eta. Using default.")


def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):
    """Average a zeroth-order gradient estimate of the loss over a set of images.

    For every image in ``imgs_idx`` the loss ``util.f`` is evaluated at the
    perturbed image and at ``q`` randomly shifted copies of it. The finite
    differences along the random directions are averaged into one gradient
    estimate per image, and those estimates are averaged over the images.

    Args:
        delta: perturbation added to every image.
        imgs_idx: indices into ``data.test_data`` of the images to use.
        model: classifier handed to the ``util`` helpers.
        data: dataset exposing ``test_data`` and ``test_labels`` (the class of
            an image is the row-wise argmax of ``test_labels``).
        q: number of random directions drawn per image.
        d: number of input dimensions; has to equal 32 * 32 * 3 because each
            estimate is reshaped to ``(1, 32, 32, 3)``.
        miu: step length of the finite difference.
        s2: forwarded unchanged to ``util.generate_u``.

    Returns:
        Array of shape ``(1, 32, 32, 3)`` holding the mean gradient estimate.
    """
    # The accumulator needs the shape of the reshaped per-image gradient: a
    # flat np.zeros(d) cannot receive a (1, 32, 32, 3) array via in-place add.
    gradient_total = np.zeros((1, 32, 32, 3))
    # The labels do not depend on the image being processed; compute them once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(
            model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        true_label = true_label_list[image_id]

        # Neither the perturbed image nor its loss depends on the random
        # direction, so both are evaluated once per image.
        # NOTE(review): assumes util.f is deterministic for a fixed input.
        X_adv = orig_img + delta
        predict_2 = util.f(model, X_adv, true_label)

        gradient_i = np.zeros((q, d))
        for i in range(q):
            u = util.generate_u(s2, d)
            predict_1 = util.f(model, X_adv + miu * u, true_label)
            u = np.reshape(u, (1, d))
            gradient_i[i] = d / miu * (predict_1 - predict_2) * u

        gradient = np.sum(gradient_i, axis=0) / q
        gradient_total += np.reshape(gradient, (1, 32, 32, 3))
    return gradient_total / len(imgs_idx)
    
def compute_batch_loss(delta, imgs_idx, model, data):
    """Return the mean loss ``util.f`` of one perturbation over a set of images.

    Args:
        delta: perturbation added to every image before clipping.
        imgs_idx: indices into ``data.test_data`` of the images to evaluate.
        model: classifier handed to the ``util`` helpers.
        data: dataset exposing ``test_data`` and ``test_labels`` (the class of
            an image is the row-wise argmax of ``test_labels``).

    Returns:
        The sum of the per-image losses divided by ``len(imgs_idx)``.
    """
    # The label lookup is identical for every image, so it is built once
    # instead of once per loop iteration.
    true_label_list = np.argmax(data.test_labels, axis=1)

    total_loss = 0
    for image_id in imgs_idx:
        # The model's own prediction is what generate_data receives as label.
        _, orig_class, _ = util.model_prediction(
            model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        # Keep the perturbed image inside the [-0.5, 0.5] pixel range.
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)
        total_loss += util.f(model, adv_image, true_label_list[image_id])
    return total_loss / len(imgs_idx)


# In[27]:


import warnings
warnings.filterwarnings("ignore")
from collections import namedtuple

# The step size has to be a number. When '--eta' is missing or malformed,
# `eta` still holds its placeholder string, which would otherwise only blow
# up inside the optimisation loop, after the model has been loaded and queried.
if isinstance(eta, str):
    raise SystemExit("A numeric step size is required: run with --eta <float>.")

# Hyper-parameters read by the attack loop:
#   miu, q, s2, d -> passed to util.compute_gradient
#   eta           -> step size of the update
#   k             -> number of perturbation entries kept after each update
#   d             -> length of the flattened perturbation
#   lname         -> text file the progress lines are appended to
Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])
args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=eta, k=60, d=3072, lname='./CIFAR/result_saga.txt')
# Variant with a hard-coded step size:
# args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=0.01, k=60, d=3072, lname='./CIFAR/result_saga.txt')

# The loop stops once the query counter `nizo` reaches this value.
zomax = 600

# Fix every source of randomness used below.
np.random.seed(42)
random.seed(42)
tf.set_random_seed(42)
RS = np.random.RandomState(42)

# Zeroth-order SAGA with top-k sparsification: one perturbation `delta` is
# optimised for all images in `imgs_idx`. A table stores the latest gradient
# estimate of every image; each step refreshes one entry, takes a step of size
# args.eta and keeps only the args.k largest-magnitude entries of the result.
with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)

    # Bookkeeping names that the loop below never reads.
    use_log = True
    succ_count, ii, iii = 0, 0, 0
    image_number = len(imgs_idx)
    l2_distortion_collect = np.zeros(image_number)
    attack_succ_count = np.zeros(image_number)
    cc = 0
    cc2 = 0

    delta = np.zeros((1,32,32,3))
    hist_loss = []  # batch loss recorded after every update
    hist_zo = []    # value of `nizo` at each recorded loss
    hist_nht = []   # value of `nht` at each recorded loss
    nizo = 0        # grows by args.q + 1 per gradient estimate
    nht = 0         # grows by one per top-k projection

    # The labels are the same on every iteration; compute them once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    total_loss = compute_batch_loss(delta, imgs_idx, model, data)
    print(f"total loss: {total_loss}")
    hist_loss.append(total_loss.item())
    hist_zo.append(nizo)
    hist_nht.append(nht)
    firstit = True
    min_loss = np.inf
    # Fallback so that `min_delta` exists even if no update ever yields a
    # loss below `min_loss` (e.g. every loss is NaN).
    min_delta = delta
    while nizo < zomax:

        if nizo == 0:
            # If it is the first iteration, we need to initialize the table first for SAGA
            table_grad = np.zeros((len(imgs_idx), *delta.shape))
            print(table_grad.shape)
            for idxx, image_id in enumerate(imgs_idx):
                attack_flag = False

                orig_prob, orig_class, orig_prob_str = util.model_prediction(
                    model, np.expand_dims(data.test_data[image_id], axis=0))
                target_label = orig_class
                orig_img, target = util.generate_data(data, image_id, target_label)
                true_label = true_label_list[image_id]
                adv_image = orig_img + delta

                current_gradient = util.compute_gradient(
                    model, adv_image, true_label, args.s2, args.miu, args.q, args.d)

                table_grad[idxx] = current_gradient
                nizo += args.q + 1

            grad_avg = np.mean(table_grad, axis=0)
            # `+ 0.` makes a copy so that `gradient` does not alias `grad_avg`.
            gradient = grad_avg + 0.

        else:
            attack_flag = False
            image_idxx = RS.choice(range(len(imgs_idx)))
            image_id = imgs_idx[image_idxx]

            orig_prob, orig_class, orig_prob_str = util.model_prediction(
                model, np.expand_dims(data.test_data[image_id], axis=0))
            target_label = orig_class
            orig_img, target = util.generate_data(data, image_id, target_label)
            true_label = true_label_list[image_id]
            with open(args.lname,'a+') as f:
                f.write("\n Image ID:{}, infer label:{}, true label:{} \n".format(image_id, orig_class, true_label))
            print("Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label))
            if true_label != orig_class:
                # Raising a plain string is itself a TypeError; raise a real
                # exception so the reason for stopping is visible.
                raise ValueError(
                    "Image {}: true label {} is different from the original "
                    "prediction {}.".format(image_id, true_label, orig_class))
            adv_image = orig_img + delta

            # SAGA step: swap the stored estimate of the sampled image for a
            # fresh one and correct the running average accordingly.
            old_gradient = table_grad[image_idxx]
            current_gradient = util.compute_gradient(
                model, adv_image, true_label, args.s2, args.miu, args.q, args.d)
            gradient = current_gradient - old_gradient + grad_avg
            table_grad[image_idxx] = current_gradient
            grad_avg = grad_avg + 1/len(imgs_idx) * (current_gradient - old_gradient)
            nizo += args.q + 1

        # Gradient step followed by keeping only the k largest-magnitude entries.
        delta_tmp = np.reshape(delta - args.eta * gradient, (args.d))
        top_k_idx = np.argsort(-np.abs(delta_tmp))[0:args.k]
        delta = np.zeros_like(delta_tmp)
        delta[top_k_idx] = delta_tmp[top_k_idx]
        nht += 1

        l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)
        l0_num = np.count_nonzero(delta)
        l0_dist = l0_num / args.d

        delta = np.reshape(delta, (1, 32,32,3))
        # Progress is reported on the image handled last in this iteration.
        adv_image = np.clip(orig_img + delta,-0.5,0.5)
        attack_prob, attack_predict_class,_ = util.model_prediction(model, adv_image)

        succeeded = true_label != attack_predict_class
        status = "Succ" if succeeded else "Fail"
        log_line = "Iter %d (%s): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
            nht + 1, status, image_id, l0_dist, l2_dist, true_label, attack_predict_class)
        with open(args.lname, 'a+') as f:
            f.write(log_line + " \n")
        print(log_line)
        if succeeded:
            attack_flag = True

        # Remember the perturbation with the lowest batch loss seen so far.
        total_loss = compute_batch_loss(delta, imgs_idx, model, data)
        if total_loss < min_loss:
            min_loss = total_loss
            min_delta = delta
        print(f"total loss: {total_loss}")
        hist_loss.append(total_loss.item())
        hist_zo.append(nizo)
        hist_nht.append(nht)

# # Inference

# In[28]:


import os

# Evaluate the perturbation that reached the lowest batch loss.
delta = min_delta
# at the end of training, we save the images in our folder
label_list = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}

# Make sure the output folder exists before anything is saved into it.
os.makedirs('./saga', exist_ok=True)

with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)

    for i in imgs_idx:
        # Derive the label handed to generate_data from this image's own
        # prediction (as the training loop does) instead of relying on a
        # `target_label` global left over from the last training iteration.
        _, image_label, _ = util.model_prediction(model, np.expand_dims(data.test_data[i],axis=0))
        orig_img, target = util.generate_data(data,i,image_label)
        adv_image = np.clip(orig_img + delta,-0.5,0.5)
        resh = np.reshape(adv_image, (32, 32, 3))

        # Pixels live in [-0.5, 0.5]; shift and scale them to 0..255.
        im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), "RGB")
        _, orig_class, _ = util.model_prediction(model, np.expand_dims(orig_img[0],axis=0))
        print(f'----- sample {i} --------')
        print(f"original class: {label_list[orig_class]}")
        _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0],axis=0))
        print(f"predicted class: {label_list[new_class]}")
        print(f'------------------------')

        plt.figure()
        plt.imshow(im)
        plt.axis('off')
        plt.savefig(f'./saga/{eta}_attacked_{i}_{label_list[new_class]}.jpg', bbox_inches='tight')
        # Release the figure; otherwise two figures per image pile up in memory.
        plt.close()

        adv_image = np.clip(orig_img,-0.5,0.5)
        resh = np.reshape(adv_image, (32, 32, 3))
        im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), "RGB")

        plt.figure()
        plt.imshow(im)
        plt.axis('off')
        plt.savefig(f'./saga/{eta}_original_{i}.jpg', bbox_inches='tight')
        plt.close()
        print(f"l2 distortion: {np.linalg.norm(delta)} in input space")


print(delta)
resh = np.reshape(delta, (32, 32, 3))
# The perturbation has negative entries, which cannot be cast to uint8
# directly. Offset it like the images (zero becomes mid-grey) and clip to the
# displayable range before converting.
im = Image.fromarray((np.clip(resh + 0.5, 0.0, 1.0)*255).astype(np.uint8), "RGB")
print(np.count_nonzero(delta))

plt.figure()
plt.imshow(im)
plt.axis('off')
plt.savefig(f'./saga/{eta}_perturb.jpg', bbox_inches='tight')
plt.close()


# In[29]:


# Draw the recorded loss twice, one figure per progress counter on the x axis:
# first against the query count, then against the number of projections.
for x_values, x_label in ((hist_zo, '# IZO'), (hist_nht, '# NHT')):
    plt.figure()
    plt.plot(x_values, hist_loss, linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
    plt.legend()
    plt.ylabel('F(w)')
    plt.xlabel(x_label)
    plt.show()


# In[31]:


import pickle

# Bundle the recorded curves and store them next to the saved images.
results = dict(nizo=hist_zo, nht=hist_nht, hist=hist_loss)
with open(f'./saga/{eta}_curves.pickle', 'wb') as file:
    pickle.dump(results, file)


# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:




