import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import utils
import math
import random
import argparse
import dill
import os
import time
import pdb
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from modified_art.hop_skip_jump import HopSkipJump
from modified_art.boundary import BoundaryAttack
from attack.Sign_OPT import OPT_attack_sign_SGD
from attack.GeoDA import GeoDA_Attack
from attack.SimBA import SimBA_Attack
from attack.SSA import Subspace_Attack
from attack.Bandit import Bandit_Attack
import logging
logger = logging.getLogger(__name__)
from blackbox_model import BlackBoxModel
from modified_art.pytorch import PyTorchClassifier





parser = argparse.ArgumentParser(description='Attacks with the ImageNet dataset')

parser.add_argument('--random_seed', type=int, default=1, help='random_seeds')
parser.add_argument('--attack', type=str, default='SO', help='attack method')
parser.add_argument('--defense', type=str, default='gaussian', help='defense applied by the black-box model (e.g. gaussian, AT)')
parser.add_argument('--data_root', type=str, default='Path for the ImageNet dataset', help='root directory of imagenet data')
parser.add_argument('--result_dir', type=str, default='save', help='directory for saving results')
parser.add_argument('--sampled_image_dir', type=str, default='save', help='directory to cache sampled images')
parser.add_argument('--model', type=str, default='resnet50', help='type of target model to use')
parser.add_argument('--max_num_queries', type=int, default=20000, help='maximum number of queries')
parser.add_argument('--log_interval', type=int, default=1000, help='log interval')
parser.add_argument('--sigma', type=float, default=0.0, help='sigma of input gaussian noise')
parser.add_argument('--alpha', type=float, default=0, help='alpha value of beta distribution')
parser.add_argument('--beta', type=float, default=0, help='beta value of beta distribution')

parser.add_argument('--num_runs', type=int, default=4, help='number of image samples')
parser.add_argument('--attack_batch_size', type=int, default=100, help='number of batches for ART attacks')

parser.add_argument('--batch_size', type=int, default=1, help='batch size for parallel runs')
# Only support batch size = 1

parser.add_argument('--num_iters', type=int, default=0, help='maximum number of iterations, 0 for unlimited')
parser.add_argument('--log_every', type=int, default=10, help='log every n iterations')
parser.add_argument('--epsilon', type=float, default=0.2, help='step size per iteration')
parser.add_argument('--linf_bound', type=float, default=0.0, help='L_inf bound for frequency space attack')
parser.add_argument('--freq_dims', type=int, default=28, help='dimensionality of 2D frequency space')
parser.add_argument('--order', type=str, default='strided', help='(random) order of coordinate selection')
parser.add_argument('--stride', type=int, default=7, help='stride for block order')
parser.add_argument('--targeted', action='store_true', help='perform targeted attack')
parser.add_argument('--pixel_attack', action='store_true', help='attack in pixel space')

parser.add_argument('--save_suffix', type=str, default='', help='suffix appended to save file')

args = parser.parse_args()

print(args)
# Announce the file the results will be written to at the end of the run.
# This format must stay identical to the one used when calling torch.save,
# otherwise the printed path points at a file that is never created.
savefile = '%s/%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (
    args.result_dir, args.attack, args.model, args.defense, args.sigma, args.alpha, args.beta, args.save_suffix)
print('SAVE_FILE : ', savefile)

def master_seed(seed=1234, set_random=True, set_numpy=True, set_tensorflow=False, set_mxnet=False, set_torch=False):
    """
    Seed every random number generator used by the experiments (adapted from the ART library).

    Seeding all generators with the same value makes runs reproducible and tests stable.
    Frameworks that are not installed are skipped with an info-level log message.

    :param seed: The value to be seeded in the random number generators.
    :type seed: `int`
    :param set_random: The flag to set seed for Python's `random` module.
    :type set_random: `bool`
    :param set_numpy: The flag to set seed for numpy's global RNG (`np.random.*`).
    :type set_numpy: `bool`
    :param set_tensorflow: The flag to set seed for `tensorflow`.
    :type set_tensorflow: `bool`
    :param set_mxnet: The flag to set seed for `mxnet`.
    :type set_mxnet: `bool`
    :param set_torch: The flag to set seed for `torch` (CPU and, if available, all CUDA devices).
    :type set_torch: `bool`
    :raises TypeError: If `seed` is not an integer.
    """
    import numbers

    if not isinstance(seed, numbers.Integral):
        raise TypeError("The seed for random number generators has to be an integer.")

    # Set Python seed
    if set_random:
        import random

        random.seed(seed)

    # Set Numpy seed. Only the global generator needs seeding; constructing a
    # separate RandomState here would have no effect on np.random.* calls.
    if set_numpy:
        np.random.seed(seed)

    # Now try to set seed for all specific frameworks
    if set_tensorflow:
        try:
            import tensorflow as tf

            logger.info("Setting random seed for TensorFlow.")
            if tf.__version__[0] == "2":
                tf.random.set_seed(seed)
            else:
                tf.set_random_seed(seed)
        except ImportError:
            logger.info("Could not set random seed for TensorFlow.")

    if set_mxnet:
        try:
            import mxnet as mx

            logger.info("Setting random seed for MXNet.")
            mx.random.seed(seed)
        except ImportError:
            logger.info("Could not set random seed for MXNet.")

    if set_torch:
        try:
            import torch

            # Log only once the import has succeeded, as in the other branches.
            logger.info("Setting random seed for PyTorch.")
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except ImportError:
            logger.info("Could not set random seed for PyTorch.")
# Seed Python, numpy and torch RNGs so image sampling and attacks are reproducible.
master_seed(args.random_seed, set_torch=True)

# Create the output directories. os.makedirs (unlike os.mkdir) also creates any
# missing parent directories; exist_ok avoids failing when the directory already
# exists (e.g. another run created it concurrently).
os.makedirs(args.result_dir, exist_ok=True)
os.makedirs(args.sampled_image_dir, exist_ok=True)

# load model and dataset
o_model = getattr(models, args.model)(pretrained=True)

if args.defense == 'AT':
    # Replace the standard weights with an adversarially-trained checkpoint.
    # The checkpoint was pickled with dill. torch.load unpickles arbitrary
    # objects, so only ever point this at a trusted file.
    at_checkpoint_path = 'models/imagenet_l2_3_0.pt'
    print("=> loading checkpoint '{}'".format(at_checkpoint_path))
    checkpoint = torch.load(at_checkpoint_path, pickle_module=dill)
    # Makes us able to load models saved with legacy versions
    state_dict_path = 'model' if 'model' in checkpoint else 'state_dict'
    sd = checkpoint[state_dict_path]
    # Strip a leading 'module.' only when it is actually present (blindly slicing
    # would corrupt keys without it), then drop 'model.' so the names line up
    # with the plain torchvision model.
    sd = {
        (k[len('module.'):] if k.startswith('module.') else k).replace('model.', ''): v
        for k, v in sd.items()
    }
    # strict=False tolerates checkpoint entries with no counterpart in the
    # torchvision model. Missing model parameters, however, mean part of the
    # network silently keeps its standard pretrained weights, so report them.
    load_result = o_model.load_state_dict(sd, strict=False)
    if load_result.missing_keys:
        print('WARNING: %d model parameters were not found in the AT checkpoint, e.g. %s'
              % (len(load_result.missing_keys), load_result.missing_keys[:5]))


# Attacks implemented through the (modified) ART library; everything else is a
# native attack class called directly on the black-box model.
ART_Attacks = ['HSJA', 'BA', 'ZOO']

# Wrap the network in the black-box interface (defense + query logging) on GPU.
model = BlackBoxModel(o_model, defense=args.defense).cuda()

model.eval()
# Inception models take 299x299 inputs; the other torchvision models use 224x224.
if args.model.startswith('inception'):
    image_size = 299
    testset = dset.ImageFolder(args.data_root + '/val', utils.INCEPTION_TRANSFORM)
else:
    image_size = 224
    testset = dset.ImageFolder(args.data_root + '/val', utils.IMAGENET_TRANSFORM)

with torch.no_grad():
    # All attacks must be evaluated on the same set of correctly classified
    # images, so the sample is cached on disk (keyed by model and run count)
    # and reused whenever the cache file exists.
    batchfile = '%s/images_%s_%d.pth' % (args.sampled_image_dir, args.model, args.num_runs)
    if not os.path.isfile(batchfile):
        images = torch.zeros(args.num_runs, 3, image_size, image_size)
        labels = torch.zeros(args.num_runs).long()
        # Initialise every prediction as "wrong" so every slot is filled on the first pass.
        preds = labels + 1
        # Keep redrawing random test images into the slots the model gets wrong
        # until every slot holds a correctly classified image.
        while preds.ne(labels).sum() > 0:
            pending = torch.arange(0, images.size(0)).long()[preds.ne(labels)]
            for slot in list(pending):
                images[slot], labels[slot] = testset[random.randint(0, len(testset) - 1)]
            preds[pending], _ = utils.get_preds(model, images[pending], 'imagenet', batch_size=args.batch_size)
        torch.save({'images': images, 'labels': labels}, batchfile)
    else:
        checkpoint = torch.load(batchfile)
        images, labels = checkpoint['images'], checkpoint['labels']

# Number of coordinates an attack may iterate over: the 3 x freq_dims x freq_dims
# frequency subspace for 'rand' ordering, otherwise the full pixel space.
if args.order == 'rand':
    n_dims = 3 * args.freq_dims * args.freq_dims
else:
    n_dims = 3 * image_size * image_size
# num_iters == 0 means "unlimited", i.e. bounded only by the number of coordinates.
if args.num_iters > 0:
    max_iters = int(min(n_dims, args.num_iters))
else:
    max_iters = int(n_dims)
# Number of batches to attack (trailing images that do not fill a batch are skipped).
N = int(math.floor(float(args.num_runs) / float(args.batch_size)))
criterion = nn.CrossEntropyLoss()
# PyTorchClassifier requires a loss and an optimizer; this script never calls
# optimizer.step() itself.
optimizer = optim.Adam(model.parameters(), lr=0.01)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# (1, C, 1, 1) so the statistics broadcast over NCHW batches.
DATASET_MEAN = np.reshape(np.array(IMAGENET_MEAN), [1, 3, 1, 1])
DATASET_STD = np.reshape(np.array(IMAGENET_STD), [1, 3, 1, 1])
classifier = PyTorchClassifier(
    model=model,
    clip_values=(0, 1),
    loss=criterion,
    optimizer=optimizer,
    # Must match the sampled images: 299x299 for inception models, 224x224 otherwise.
    input_shape=(3, image_size, image_size),
    nb_classes=1000,
    preprocessing=(DATASET_MEAN, DATASET_STD),
)


iter_step = 10


# Per-run configuration handed to model.init_model for every image.
attack_setting = {
    'sigma': args.sigma,
    'attack': args.attack,
    'alpha': args.alpha,
    'beta': args.beta,
    'log_interval': args.log_interval,
    'max_num_queries': args.max_num_queries,
}

# Per-image log buffers: one row per image and max_num_queries // log_interval
# columns (presumably one per logging checkpoint of the black-box model).
num_log_points = args.max_num_queries // args.log_interval
total_log_query_point = np.zeros((args.num_runs, num_log_points))
total_log_l_2 = np.zeros((args.num_runs, num_log_points))
total_log_l_inf = np.zeros((args.num_runs, num_log_points))
total_log_acc = np.zeros((args.num_runs, num_log_points))
total_log_prob = np.zeros((args.num_runs, num_log_points))
total_log_adv = []
total_log_l_2_query_count = np.zeros((args.num_runs))
total_log_l_inf_query_count = np.zeros((args.num_runs))
ART_Attacks_1 = ['HSJA', 'BA']
ART_Attacks_2 = ['ZOO']

# Log columns shown in the printouts. The summary labels them 1000/5000/10000/20000
# queries, which assumes log_interval == 1000 and max_num_queries >= 20000.
preset_idx = [0, 4, 9, 19]

# Attacks this script can actually construct. 'ZOO' appears in ART_Attacks but no
# ZOO implementation is imported, so it (like any unknown name) would otherwise
# fail later with a NameError on the undefined `attack`.
SUPPORTED_ATTACKS = ('SO', 'GD', 'SB', 'SBD', 'BD', 'SSA', 'HSJA', 'BA')
if args.attack not in SUPPORTED_ATTACKS:
    raise ValueError('Unsupported attack %r; expected one of: %s'
                     % (args.attack, ', '.join(SUPPORTED_ATTACKS)))

timestart = time.time()
# Native (non-ART) attacks are built once and reused for every image; ART attacks
# are rebuilt per image inside the attack loop.
if args.attack not in ART_Attacks:
    native_attack_classes = {
        'SO': OPT_attack_sign_SGD,
        'GD': GeoDA_Attack,
        'SB': SimBA_Attack,
        'SBD': SimBA_Attack,
        'BD': Bandit_Attack,
        'SSA': Subspace_Attack,
    }
    attack = native_attack_classes[args.attack](model)
# Attack each sampled image in turn. Although batches are sliced by batch_size,
# every call below only uses the first element ([:1]), so only batch_size == 1 is
# effectively supported.
for i in range(N):
    upper = min((i + 1) * args.batch_size, args.num_runs)
    images_batch = images[(i * args.batch_size):upper]
    labels_batch = labels[(i * args.batch_size):upper]
    # Current adversarial example for ART attacks; fed back in as a warm start.
    x_adv = None


    # ART attacks are re-instantiated for every image.
    if args.attack == 'HSJA':
        attack = HopSkipJump(classifier=classifier, max_queries=args.max_num_queries, targeted=False, max_iter=64,
                             max_eval=10000, init_eval=100)
    elif args.attack == 'BA':
        attack = BoundaryAttack(estimator=classifier, targeted=False)
    # NOTE(review): no ZOO attack object is constructed in this file, so 'ZOO'
    # reaches this line with `attack` undefined.
    if args.attack in ART_Attacks:
        attack.batch_size = args.attack_batch_size


    print('IMG: ',(i+1))

    # With adversarial training (AT) the clean image may already be misclassified;
    # in that case the clean image itself is logged and no attack is run.
    incorrect_pred = False
    if args.defense=='AT':
        model.init_model(attack_setting, images_batch[:1].cuda(), np.array(labels_batch[:1]))
        if model.predict_label(images_batch[:1].cuda()).cpu()!=labels_batch[:1]:
            print('Already incorrect')
            model.set_log(images_batch[:1].cuda(),unnormalization=False)
            incorrect_pred=True


    if incorrect_pred==False:
        if args.attack in ART_Attacks:
            # ART attacks take numpy input.
            ori_image= np.array(images_batch[:1])
            model.init_model(attack_setting,ori_image,np.array(labels_batch[:1]))
            if args.attack in ART_Attacks_1:
                # Re-run the attack, resuming from the previous x_adv, until the
                # black-box model's query counter exceeds the budget.
                # NOTE(review): this loop only terminates if generate() keeps
                # issuing queries through `model` — confirm for HSJA/BA.
                while(model.get_num_queries()<=args.max_num_queries):
                    x_adv = attack.generate(x=ori_image, x_adv_init=x_adv,y= labels_batch[:1], resume=True)
            else:
                x_adv = attack.generate(x=ori_image)
            if x_adv is not None:
                model.set_log(torch.cuda.FloatTensor(x_adv), unnormalization=False)

        else:
            model.init_model(attack_setting,images_batch[:1].cuda(),np.array(labels_batch[:1]))
            if args.attack=='GD':
                # NOTE(review): unlike the other branches, the GeoDA result is not
                # passed to model.set_log here — presumably logged elsewhere; confirm.
                adv = attack.attack_untargeted(images_batch[:1].cuda(), labels_batch[:1].cuda(), query_limit=args.max_num_queries)
            elif args.attack == 'SB' or args.attack == 'SBD':
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1],args=args,
                                               query_limit=args.max_num_queries)
                if adv is not None:
                    model.set_log(adv.cuda(),unnormalization=False)
            elif args.attack == 'BD':
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1])
                if adv is not None:
                    model.set_log(adv.cuda(), unnormalization=False)
            elif args.attack == 'SSA':
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1],query_limit=args.max_num_queries)
                if adv is not None:
                    model.set_log(adv.cuda(),unnormalization=False)
            else:
                # Sign-OPT ('SO'): the attack object is called directly. Its result
                # is logged with set_log's default `unnormalization` setting.
                adv = attack(images_batch[:1].cuda(), labels_batch[:1].cuda(), query_limit=args.max_num_queries,TARGETED=False)
                if adv is not None:
                    model.set_log(adv.cuda())


    # x_adv_i=np.transpose(x_adv[0],[1,2,0])*255.0
    # plt.imshow(x_adv_i.astype(np.uint))
    # plt.show(block=False)

    # Collect this image's logs from the black-box model and store them in row i.
    (log_query_point,log_prob,log_acc,log_l_2,log_l_inf,log_l_2_query_count,log_l_inf_query_count,log_adv)=model.get_log()

    total_log_query_point[i,:]=log_query_point
    total_log_l_2[i,:]=log_l_2
    total_log_l_inf[i,:]=log_l_inf
    total_log_acc[i,:]=log_acc
    total_log_prob[i,:]=log_prob
    # Store the final image as uint8 (assumes pixel values in [0, 1] — confirm;
    # .byte() truncates rather than rounds). Fall back to the clean image when
    # no adversarial image was logged.
    if log_adv is not None:
        adv_images=(log_adv * 255.0).byte()
    else:
        adv_images=(images_batch[:1] * 255.0).byte()
    # x_adv_i=np.transpose(adv_images[0].detach().cpu().numpy(),[1,2,0])
    # plt.imshow(x_adv_i.astype(np.uint))
    # plt.show(block=False)
    total_log_adv.append(adv_images)
    total_log_l_2_query_count[i] = log_l_2_query_count
    total_log_l_inf_query_count[i] = log_l_inf_query_count
    print('Acc: ', log_acc[preset_idx])
    print('l_2', log_l_2[preset_idx],'l_inf',log_l_inf[preset_idx])
    print('l_2_count: ', log_l_2_query_count, 'l_inf_count:',log_l_inf_query_count,flush=True)


# Persist all per-image logs. The file name encodes attack, model, defense and
# the noise parameters so that different configurations never overwrite each other.
savefile = '%s/%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (
    args.result_dir,
    args.attack,
    args.model,
    args.defense,
    args.sigma,
    args.alpha,
    args.beta,
    args.save_suffix,
)
results = dict(
    total_log_query_point=total_log_query_point,
    total_log_l_2=total_log_l_2,
    total_log_l_inf=total_log_l_inf,
    total_log_acc=total_log_acc,
    total_log_prob=total_log_prob,
    total_log_adv=total_log_adv,
    total_log_l_2_query_count=total_log_l_2_query_count,
    total_log_l_inf_query_count=total_log_l_inf_query_count,
)
torch.save(results, savefile)



print('-------------------------------------------------')
# Query budgets reported below. With the default log_interval (1000) these are
# the same checkpoints selected by preset_idx.
QUERY_CHECKPOINTS = [1000, 5000, 10000, 20000]
print(', '.join(str(q) for q in QUERY_CHECKPOINTS))


def _nan_safe(stat, values):
    """Return ``stat(values)``, or NaN when ``values`` is empty.

    numpy's mean/median already yield NaN for an empty array, but they also emit
    a RuntimeWarning. That happens here whenever no run reached a threshold.
    """
    return stat(values) if values.size > 0 else float('nan')


avg_log_acc = np.mean(total_log_acc, axis=0)
print('avg_log_acc', avg_log_acc[preset_idx])

median_log_l_2 = np.median(total_log_l_2, axis=0)
avg_log_l_2 = np.mean(total_log_l_2, axis=0)
print('avg_log_l_2', avg_log_l_2[preset_idx])
print('median_log_l_2', median_log_l_2[preset_idx])

median_log_l_inf = np.median(total_log_l_inf, axis=0)
avg_log_l_inf = np.mean(total_log_l_inf, axis=0)
print('avg_log_l_inf', avg_log_l_inf[preset_idx])
print('median_log_l_inf', median_log_l_inf[preset_idx])

# Query-count statistics only include runs with a positive count; a count of 0
# is treated as "threshold never reached".
filtered_l_2_query_count = total_log_l_2_query_count[total_log_l_2_query_count > 0]
median_log_l_2_query_count = _nan_safe(np.median, filtered_l_2_query_count)
avg_log_l_2_query_count = _nan_safe(np.mean, filtered_l_2_query_count)
print('avg_log_l_2_query_count', avg_log_l_2_query_count)
print('median_log_l_2_query_count', median_log_l_2_query_count)

filtered_l_inf_query_count = total_log_l_inf_query_count[total_log_l_inf_query_count > 0]
median_log_l_inf_query_count = _nan_safe(np.median, filtered_l_inf_query_count)
avg_log_l_inf_query_count = _nan_safe(np.mean, filtered_l_inf_query_count)
print('avg_log_l_inf_query_count', avg_log_l_inf_query_count)
print('median_log_l_inf_query_count', median_log_l_inf_query_count)
for i in QUERY_CHECKPOINTS:
    print('Query_count : ', i)
    # A run succeeds within budget i if it reached the threshold (count > 0)
    # after at most i queries.
    l_2_success = (total_log_l_2_query_count > 0) & (total_log_l_2_query_count <= i)
    l_inf_success = (total_log_l_inf_query_count > 0) & (total_log_l_inf_query_count <= i)
    print('success_rate_l_2:', np.mean(l_2_success.astype(float)))
    print('success_rate_l_inf:', np.mean(l_inf_success.astype(float)))
timeend = time.time()
print("\nTime: %.4f seconds" % (timeend - timestart))
