#!/usr/bin/env python
# coding: utf-8

# In[25]:


import tensorflow
import tensorflow.compat.v1 as tf
import numpy as np
from setup_cifar import CIFAR, CIFARModel
from PIL import Image
import Utils_CIFAR as util
import numpy as np
import matplotlib.pyplot as plt
import random

import os
import sys
from setting import idx

# Coefficient that scales the SAGA correction term (current - stored gradient)
# in the main loop. It must be supplied on the command line as `--beta <float>`.
beta = None

if '--beta' in sys.argv:
    index_beta = sys.argv.index('--beta')
    try:
        beta = float(sys.argv[index_beta + 1])
    except (IndexError, ValueError):
        print("Invalid value for --beta.")

if beta is None:
    # There is no usable default: beta is used numerically below
    # (round(beta, 3)) and in the SAGA update, so fail fast with an explicit
    # message instead of crashing later with an opaque TypeError.
    sys.exit("Error: you should specify --beta <float>")

# Example of an index list (bird class):
# idx = [65, 67, 70, 75, 86, 113, 123, 129, 138, 149]

# Output layout: pictures go to ./saga/<beta>/, curves to ./saga/results/.
path = r'./saga/'
pic_path = path + str(round(beta, 3)) + r'/'
result_path = path + 'results/'

paths = [path, pic_path, result_path]
for p in paths:
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(p, exist_ok=True)


# In[26]:


def compute_batch_gradient(delta, imgs_idx, model, data, q, d, miu, s2):
    """Average a zeroth-order gradient estimate of ``util.f`` over a batch of images.

    For every image id in ``imgs_idx``, the perturbation ``delta`` is added to
    the image returned by ``util.generate_data``. The gradient at that point
    is estimated from ``q`` random directions ``u = util.generate_u(s2, d)``
    using the forward difference ``d / miu * (f(x + miu*u) - f(x)) * u``, and
    the ``q`` estimates are averaged. The per-image estimates are then
    averaged over the batch.

    Args:
        delta: perturbation added to every image (shape (1, 32, 32, 3) in the
            main loop).
        imgs_idx: sequence of test-set image indices.
        model: classifier passed through to the ``util`` helpers.
        data: dataset object exposing ``test_data`` and ``test_labels``.
        q: number of random directions per image.
        d: number of input coordinates (3072 for CIFAR).
        miu: finite-difference step.
        s2: forwarded to ``util.generate_u``.

    Returns:
        np.ndarray of shape (1, 32, 32, 3) with the batch-averaged estimate.
    """
    # Ground-truth labels of the whole test set; loop-invariant.
    true_label_list = np.argmax(data.test_labels, axis=1)
    # Accumulate in the same (1, 32, 32, 3) shape as each per-image estimate;
    # a flat (d,) accumulator cannot broadcast against it.
    gradient_total = np.zeros((1, 32, 32, 3))
    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(
            model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        true_label = true_label_list[image_id]
        # The evaluation point does not depend on the sampled direction.
        X_adv = orig_img + delta

        gradient_i = np.zeros((q, d))
        for i in range(q):
            u = util.generate_u(s2, d)
            predict_1 = util.f(model, X_adv + miu * u, true_label)
            predict_2 = util.f(model, X_adv, true_label)
            u = np.reshape(u, (1, d))
            gradient_i[i] = d / miu * (predict_1 - predict_2) * u
        gradient = np.sum(gradient_i, axis=0) / q
        gradient_total += np.reshape(gradient, (1, 32, 32, 3))
    return gradient_total / len(imgs_idx)
    
def compute_batch_loss(delta, imgs_idx, model, data):
    """Return the mean of ``util.f`` over ``imgs_idx`` for the perturbation ``delta``.

    Each image (as returned by ``util.generate_data``) is perturbed by
    ``delta`` and clipped to ``[-0.5, 0.5]`` before being scored by ``util.f``
    against its ground-truth label (argmax of ``data.test_labels``).

    Args:
        delta: perturbation shared by all images.
        imgs_idx: sequence of test-set image indices.
        model: classifier passed through to the ``util`` helpers.
        data: dataset object exposing ``test_data`` and ``test_labels``.

    Returns:
        The accumulated ``util.f`` values divided by ``len(imgs_idx)``.
    """
    # Ground-truth labels of the whole test set; loop-invariant, so computed once.
    true_label_list = np.argmax(data.test_labels, axis=1)
    total_loss = 0
    for image_id in imgs_idx:
        _, orig_class, _ = util.model_prediction(
            model, np.expand_dims(data.test_data[image_id], axis=0))
        orig_img, _ = util.generate_data(data, image_id, orig_class)
        adv_image = np.clip(orig_img + delta, -0.5, 0.5)
        total_loss += util.f(model, adv_image, true_label_list[image_id])
    return total_loss / len(imgs_idx)


# In[27]:


import warnings
warnings.filterwarnings("ignore")
from collections import namedtuple

# Attack hyper-parameters.
#   iter  : not read in this script.
#   miu   : forwarded to util.compute_gradient; presumably the finite-difference
#           step (it plays that role in compute_batch_gradient).
#   d     : number of input coordinates (32 * 32 * 3 = 3072).
#   k     : number of coordinates kept by hard thresholding each iteration.
#   eta   : step size of the gradient step on the perturbation.
#   q     : forwarded to util.compute_gradient; each estimate is charged
#           q + 1 queries to the budget.
#   s2    : forwarded to util.compute_gradient (meaning defined there).
#   lname : text log file that every iteration appends to.
Args = namedtuple('Args', ['iter', 'miu', 'd', 'k', 'eta', 'q', 's2', 'lname'])
args = Args(iter=100, miu=0.001, q=10, s2=3072, eta=0.05, k=60, d=3072, lname='./CIFAR/result.txt')
# Query budget: the main loop stops once nizo reaches this value.
zomax = 600
# Fixed seeds for reproducibility; RS drives the SAGA sample selection below.
np.random.seed(42)
random.seed(42)
tf.set_random_seed(42)
RS = np.random.RandomState(42)

# The log directory is not among the directories created earlier; opening
# args.lname in 'a+' mode fails with FileNotFoundError if it is missing.
log_dir = os.path.dirname(args.lname)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

with tf.Session() as sess:
    data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)
    imgs_idx = idx
    use_log = True

    succ_count, ii, iii = 0, 0, 0

    image_number = len(imgs_idx)
    l2_distortion_collect = np.zeros(image_number)
    attack_succ_count = np.zeros(image_number)
    cc = 0
    cc2 = 0

    # Ground-truth labels of the whole test set; loop-invariant, so computed once.
    true_label_list = np.argmax(data.test_labels, axis=1)

    # A single perturbation shared by all images in imgs_idx.
    delta = np.zeros((1,32,32,3))
    hist_loss = []   # batch loss at each record
    hist_zo = []     # cumulative query count at each record
    hist_nht = []    # hard-thresholding step count at each record
    nizo = 0         # queries spent so far
    nht = 0          # hard-thresholding steps done so far

    total_loss = compute_batch_loss(delta, imgs_idx, model, data)
    print(f"total loss: {total_loss}")
    hist_loss.append(total_loss.item())
    hist_zo.append(nizo)
    hist_nht.append(nht)
    firstit = True
    while nizo < zomax:

        if nizo == 0:
            # First iteration: fill the SAGA table with one gradient estimate
            # per image, and use their mean as the search direction.
            table_grad = np.zeros((len(imgs_idx), *delta.shape))
            print(table_grad.shape)
            for idxx in range(len(imgs_idx)):
                attack_flag = False
                image_id = imgs_idx[idxx]

                orig_prob, orig_class, orig_prob_str = util.model_prediction(
                    model, np.expand_dims(data.test_data[image_id],axis=0))
                target_label = orig_class
                orig_img, target = util.generate_data(data,image_id,target_label)
                true_label = true_label_list[image_id]
                adv_image = orig_img + delta

                current_gradient = util.compute_gradient(model, adv_image, true_label,args.s2,args.miu,args.q, args.d)

                table_grad[idxx] = current_gradient
                nizo += args.q + 1

            grad_avg = np.mean(table_grad, axis=0)
            gradient = grad_avg +0.

        else:
            # SAGA step: refresh the stored gradient of one randomly chosen image.
            attack_flag = False
            image_idxx = RS.choice(range(len(imgs_idx)))
            image_id = imgs_idx[image_idxx]

            orig_prob, orig_class, orig_prob_str = util.model_prediction(
                model, np.expand_dims(data.test_data[image_id],axis=0))
            target_label = orig_class
            orig_img, target = util.generate_data(data,image_id,target_label)
            true_label = true_label_list[image_id]
            info = "Image ID:{}, infer label:{}, true label:{}".format(image_id, orig_class, true_label)
            with open(args.lname,'a+') as f:
                f.write("\n " + info + " \n")
            print(info)
            if true_label != orig_class:
                # Raising a bare string is a TypeError in Python 3; use a real exception.
                raise ValueError("True Label is different from the original prediction, pass!")
            adv_image = orig_img + delta

            old_gradient = table_grad[image_idxx]
            current_gradient = util.compute_gradient(model, adv_image, true_label,args.s2,args.miu,args.q, args.d)
            gradient = beta*(current_gradient - old_gradient) + grad_avg
            table_grad[image_idxx] = current_gradient
            # NOTE(review): the table stores the full current_gradient, but the
            # running mean is updated with the beta-scaled difference, so
            # grad_avg equals np.mean(table_grad, axis=0) only when beta == 1.
            # Confirm this is the intended variant.
            grad_avg = grad_avg + 1/len(imgs_idx) * beta*(current_gradient - old_gradient)
            nizo += args.q + 1

        # Gradient step, then hard thresholding: keep only the k entries of
        # largest magnitude and zero the rest.
        delta_tmp = delta + 0.
        delta_tmp = delta_tmp - args.eta * gradient
        delta_tmp = np.reshape(delta_tmp, (args.d))
        top_k_idx = np.argsort(-np.abs(delta_tmp))[0:args.k]
        delta = np.zeros_like(delta_tmp)

        delta[top_k_idx] = delta_tmp[top_k_idx]
        nht += 1
        l2_dist = np.linalg.norm(delta, ord=2, keepdims=False)
        l0_num = np.count_nonzero(delta)
        l0_dist = l0_num / args.d

        delta = np.reshape(delta, (1, 32,32,3))
        # Evaluate on the most recently processed image. In the first
        # iteration that is the last image of the table-filling loop.
        adv_image = np.clip(orig_img + delta,-0.5,0.5)
        attack_prob, attack_predict_class,_ = util.model_prediction(model, adv_image)

        # Logged on every iteration.
        if true_label != attack_predict_class:
            status = "Succ"
            attack_flag = True
        else:
            status = "Fail"
        msg = "Iter %d (%s): ID = %d, l0_dist=%3.5f, l2_dist=%3.5f, TL = %d, PL = %d" % (
            nht + 1, status, image_id, l0_dist, l2_dist, true_label, attack_predict_class)
        with open(args.lname, 'a+') as f:
            f.write(msg + " \n")
        print(msg)
        total_loss = compute_batch_loss(delta, imgs_idx, model, data)
        print(f"total loss: {total_loss}")
        hist_loss.append(total_loss.item())
        hist_zo.append(nizo)
        hist_nht.append(nht)


# # Inference

# In[28]:


# label_list = {0: 'airplane',
#               1: 'automobile',
#               2: 'bird',
#               3: 'cat',
#               4: 'deer',
#               5: 'dog',
#               6: 'frog',
#               7: 'horse',
#               8: 'ship',
#               9: 'truck'}
# with tf.Session() as sess:
#     data, model = CIFAR(), CIFARModel('./models/cifar',sess,True)
    


#     for i in imgs_idx:
#         orig_img, target = util.generate_data(data,i,target_label)
#         adv_image = np.clip(orig_img + delta,-0.5,0.5)
#         resh = np.reshape(adv_image, (32, 32, 3))



#         im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), "RGB")
#         _, orig_class, _ = util.model_prediction(model, np.expand_dims(orig_img[0],axis=0))
#         print(f'----- sample {i} --------')
#         print(f"original class: {label_list[orig_class]}")
#         _, new_class, _ = util.model_prediction(model, np.expand_dims(adv_image[0],axis=0))
#         print(f"predicted class: {label_list[new_class]}")
#         print(f'------------------------')

#         plt.figure()
#         plt.imshow(im)
#         plt.axis('off')
#         plt.savefig(pic_path + f'attacked_{i}.jpg', bbox_inches='tight')


#         adv_image = np.clip(orig_img,-0.5,0.5)
#         resh = np.reshape(adv_image, (32, 32, 3))
#         im = Image.fromarray(((resh + 0.5)*255).astype(np.uint8), "RGB")


#         plt.figure()
#         plt.imshow(im)
#         plt.axis('off')
#         plt.savefig(pic_path + f'original_{i}.jpg', bbox_inches='tight')
#         print(f"l2 distortion: {np.linalg.norm(delta)} in input space")


# print(delta)
# resh = np.reshape(delta, (32, 32, 3))
# im = Image.fromarray(((resh)*255).astype(np.uint8), "RGB")
# print(np.count_nonzero(delta))

# plt.figure()
# plt.imshow(im)
# plt.axis('off')
# plt.savefig(pic_path + 'perturb.jpg', bbox_inches='tight')


# In[31]:


import pickle

# Persist the optimisation curves: query count, hard-thresholding step count,
# and batch loss, recorded at every logged iteration.
results = dict(nizo=hist_zo, nht=hist_nht, hist=hist_loss)
print(hist_loss)
curves_file = result_path + f'{str(round(beta, 3))}_saga_curves.pickle'
with open(curves_file, 'wb') as out_file:
    pickle.dump(results, out_file)


# In[29]:


# plt.figure()
# plt.plot(hist_zo, hist_loss, linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
# plt.legend()
# plt.ylabel('F(w)')
# plt.xlabel('# IZO')
# # plt.savefig(result_path + str(round(beta, 3)) + "_IZO.png")
# plt.show()


# # In[30]:


# plt.figure()
# plt.plot(hist_nht, hist_loss, linestyle='-', marker='^', markersize=5, label=f"lr: {args.eta}")
# plt.legend()
# plt.ylabel('F(w)')
# plt.xlabel('# NHT')
# # plt.savefig(result_path + str(round(beta, 3)) + "_NHT.png")
# plt.show()


# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:




