import torch
import torchvision.transforms as transforms
import numpy as np
import utils
import math
import random
import argparse
import os
import time
from cifar.resnet import resnet20
from cifar.resnet_rse import resnet20_RSE
from models.PNI.noisy_resnet_cifar import noise_resnet20
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from modified_art.hop_skip_jump import HopSkipJump
from modified_art.boundary import BoundaryAttack
from attack.Sign_OPT import OPT_attack_sign_SGD
from attack.GeoDA import GeoDA_Attack
from attack.SimBA import SimBA_Attack
from attack.SSA import Subspace_Attack
from attack.Bandit import Bandit_Attack

import logging
logger = logging.getLogger(__name__)
from blackbox_model import BlackBoxModel
from modified_art.pytorch import PyTorchClassifier





# Command-line interface. Every option is parsed once into the module-level
# `args` namespace, which the rest of the script reads directly.
parser = argparse.ArgumentParser(description='Attacks with the CIFAR-10 dataset')

# Experiment identity, output locations and query budget.
parser.add_argument('--random_seed', type=int, default=1, help='random_seeds')
parser.add_argument('--attack', type=str, default='HSJA', help='attack method')
parser.add_argument('--defense', type=str, default='gaussian', help='defense method')
parser.add_argument('--result_dir', type=str, default='save', help='directory for saving results')
parser.add_argument('--sampled_image_dir', type=str, default='save', help='directory to cache sampled images')
parser.add_argument('--model', type=str, default='resnet20', help='type of target model to use')
# The help text used to read 'number of image samples' (copied from --num_runs);
# the value is actually the query budget handed to every attack.
parser.add_argument('--max_num_queries', type=int, default=10000, help='maximum number of model queries per attack')
parser.add_argument('--log_interval', type=int, default=200, help='log interval')
parser.add_argument('--sigma', type=float, default=0.0, help='sigma of input gaussian noise')
parser.add_argument('--alpha', type=float, default=0, help='alpha value of beta distribution')
parser.add_argument('--beta', type=float, default=0, help='beta value of beta distribution')

parser.add_argument('--avg_iter', type=int, default=1, help='number of queries for expectation-based adaptive attacks')

parser.add_argument('--num_runs', type=int, default=4, help='number of image samples')
# The value is assigned to `attack.batch_size` for ART attacks, so it is a
# batch size rather than a number of batches.
parser.add_argument('--attack_batch_size', type=int, default=100, help='batch size for ART attacks')

# NOTE: only batch_size = 1 is supported.
parser.add_argument('--batch_size', type=int, default=1, help='batch size for parallel runs')

# Options consumed by the frequency/pixel-space attacks (the whole namespace is
# forwarded to the SimBA attack via `args=args`).
parser.add_argument('--num_iters', type=int, default=0, help='maximum number of iterations, 0 for unlimited')
parser.add_argument('--log_every', type=int, default=10, help='log every n iterations')
parser.add_argument('--epsilon', type=float, default=0.2, help='step size per iteration')
parser.add_argument('--linf_bound', type=float, default=0.0, help='L_inf bound for frequency space attack')
parser.add_argument('--freq_dims', type=int, default=8, help='dimensionality of 2D frequency space')
parser.add_argument('--order', type=str, default='strided', help='(random) order of coordinate selection') # 28 7
parser.add_argument('--stride', type=int, default=1, help='stride for block order')
parser.add_argument('--targeted', action='store_true', help='perform targeted attack')
parser.add_argument('--pixel_attack', action='store_true', help='attack in pixel space')

parser.add_argument('--save_suffix', type=str, default='', help='suffix appended to save file')

args = parser.parse_args()

print(args)
# Result file name: when several queries are averaged (avg_iter > 1) the file
# is tagged with avg_iter, otherwise with the user-supplied suffix.
if args.avg_iter>1:
    savefile = '%s/CIFAR_%s_%s_%s_%.3f_%.1f_%.1f_%d.pth' % (
        args.result_dir, args.attack, args.model, args.defense, args.sigma, args.alpha, args.beta, args.avg_iter)
else:
    savefile = '%s/CIFAR_%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (
        args.result_dir, args.attack, args.model, args.defense, args.sigma, args.alpha, args.beta, args.save_suffix)
print('SAVE_FILE : ', savefile)
def master_seed(seed=1234, set_random=True, set_numpy=True, set_tensorflow=False, set_mxnet=False, set_torch=False):
    """
    Seed the random number generators used by the script so that experiments are reproducible.

    This function is borrowed from the ART library.

    :param seed: The value to be seeded in the random number generators.
    :type seed: `int`
    :param set_random: The flag to set seed for `random`.
    :type set_random: `bool`
    :param set_numpy: The flag to set seed for `numpy`.
    :type set_numpy: `bool`
    :param set_tensorflow: The flag to set seed for `tensorflow`.
    :type set_tensorflow: `bool`
    :param set_mxnet: The flag to set seed for `mxnet`.
    :type set_mxnet: `bool`
    :param set_torch: The flag to set seed for `torch`.
    :type set_torch: `bool`
    :raises TypeError: If `seed` is not an integer.
    """
    import numbers

    if not isinstance(seed, numbers.Integral):
        raise TypeError("The seed for random number generators has to be an integer.")

    # Set Python seed
    if set_random:
        import random

        random.seed(seed)

    # Set Numpy seed. Only the global generator needs seeding; the original
    # code also built a throw-away `np.random.RandomState(seed)` whose result
    # was discarded, which had no effect and has been removed.
    if set_numpy:
        np.random.seed(seed)

    # Framework-specific seeding is best effort: a framework that is not
    # installed is logged and skipped instead of raising.
    if set_tensorflow:
        try:
            import tensorflow as tf

            logger.info("Setting random seed for TensorFlow.")
            if tf.__version__[0] == "2":
                tf.random.set_seed(seed)
            else:
                tf.set_random_seed(seed)
        except ImportError:
            logger.info("Could not set random seed for TensorFlow.")

    if set_mxnet:
        try:
            import mxnet as mx

            logger.info("Setting random seed for MXNet.")
            mx.random.seed(seed)
        except ImportError:
            logger.info("Could not set random seed for MXNet.")

    if set_torch:
        try:
            logger.info("Setting random seed for PyTorch.")
            import torch

            torch.manual_seed(seed)
            # Seed every visible GPU as well when CUDA is usable.
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except ImportError:
            logger.info("Could not set random seed for PyTorch.")

# Seed Python's `random`, NumPy and PyTorch (the defaults cover the first two,
# `set_torch=True` adds the third) before any sampling takes place.
master_seed(args.random_seed,set_torch=True)

# Create the output directories. `makedirs(..., exist_ok=True)` also creates
# missing parent directories and does not fail when the directory already
# exists or is created concurrently by another run, unlike the previous
# `os.path.exists` + `os.mkdir` pair.
os.makedirs(args.result_dir, exist_ok=True)
os.makedirs(args.sampled_image_dir, exist_ok=True)

# Build the target network that matches --defense and restore its weights
# from the corresponding checkpoint file.
if args.defense == 'PNI':
    o_model = noise_resnet20()
    checkpoint = torch.load('models/PNI/checkpoint_channel.pth.tar')
    # Every key loses its first two characters (the length of the '1.' prefix).
    prefix_len = len('1.')
    weights = {}
    for name, tensor in checkpoint['state_dict'].items():
        weights[name[prefix_len:]] = tensor
    # The 'mean' and 'std' entries are removed before loading.
    del weights['mean']
    del weights['std']
    o_model.load_state_dict(weights)

elif args.defense == 'RSE':
    o_model = resnet20_RSE()
    resume_path = 'cifar/model_resnet20_rse.th'
    checkpoint = torch.load(resume_path)
    best_prec1 = checkpoint['best_prec1']
    print('Best Acc: ', best_prec1)
    # Drop the leading 'module.'-sized prefix from every key.
    prefix_len = len('module.')
    weights = {}
    for name, tensor in checkpoint['state_dict'].items():
        weights[name[prefix_len:]] = tensor
    o_model.load_state_dict(weights)

else:
    # Any other defense value uses the plain ResNet-20.
    o_model = resnet20()
    resume_path = 'cifar/model_resnet20.th'
    print("=> loading checkpoint '{}'".format(resume_path))
    checkpoint = torch.load(resume_path)
    best_prec1 = checkpoint['best_prec1']
    print('Best Acc: ', best_prec1)
    # Drop the leading 'module.'-sized prefix, then remove any 'model.' part.
    prefix_len = len('module.')
    weights = {}
    for name, tensor in checkpoint['state_dict'].items():
        weights[name[prefix_len:].replace('model.', '')] = tensor
    o_model.load_state_dict(weights)



# Attack names that are routed through the ART-style code path further below.
# NOTE(review): 'ZOO' is listed here, but no ZOO attack object is constructed
# anywhere in this script — confirm whether it is still supported.
ART_Attacks = ['HSJA', 'BA', 'ZOO']

# Wrap the raw network in the black-box interface that the attacks query.
model=BlackBoxModel(o_model,defense=args.defense,dataset='cifar',avg_iter=args.avg_iter).cuda()

model.eval()
# CIFAR-10 test split; ToTensor() is the only transform applied here.
testset = datasets.CIFAR10(root='./data',download=True, train=False, transform=transforms.Compose([
    transforms.ToTensor()
]))

image_size = 32
with torch.no_grad():
    # load sampled images or sample new ones
    # this is to ensure all attacks are run on the same set of correctly classified images
    # The cache file name depends only on the cache directory, --model and
    # --num_runs (not on the seed or the defense), so an existing file is
    # reused across those settings.
    batchfile = '%s/images_%s_%d.pth' % (args.sampled_image_dir, args.model, args.num_runs)
    if os.path.isfile(batchfile):
        checkpoint = torch.load(batchfile)
        images = checkpoint['images']
        labels = checkpoint['labels']
    else:
        images = torch.zeros(args.num_runs, 3, image_size, image_size)
        labels = torch.zeros(args.num_runs).long()
        # Start with preds != labels everywhere so the first pass samples every slot.
        preds = labels + 1
        # Keep re-drawing the slots whose prediction differs from the label
        # until all num_runs images are classified correctly by o_model.
        while preds.ne(labels).sum() > 0:
            idx = torch.arange(0, images.size(0)).long()[preds.ne(labels)] # indices of slots still mismatched
            for i in list(idx):
                # Uniformly random test example, drawn with Python's `random` module.
                images[i], labels[i] = testset[random.randint(0, len(testset) - 1)]
            # Only the first value returned by utils.get_preds is used, as the predictions.
            preds[idx], _ = utils.get_preds(o_model, images[idx], 'cifar', batch_size=args.batch_size)

        # Cache the sampled set so later runs load the same images.
        torch.save({'images': images, 'labels': labels}, batchfile)




# Number of coordinates: 3 channels times the squared side length, where the
# side is freq_dims for the 'rand' order and the full image size otherwise.
side = args.freq_dims if args.order == 'rand' else image_size
n_dims = 3 * side * side

# --num_iters caps the iteration count only when it is positive.
max_iters = int(min(n_dims, args.num_iters)) if args.num_iters > 0 else int(n_dims)

# Number of batches processed by the main loop (a trailing partial batch is dropped).
N = int(math.floor(args.num_runs / args.batch_size))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Per-channel CIFAR statistics rescaled from the 0-255 range.
CIFAR_MEAN = [channel / 255 for channel in [125.3, 123.0, 113.9]]  # [0.4914, 0.4822, 0.4465]
CIFAR_STD = [channel / 255 for channel in [63.0, 62.1, 66.7]]  # [0.2023, 0.1994, 0.2010]
# Reshaped to (1, 3, 1, 1) and handed to the classifier as its preprocessing pair.
DATASET_MEAN = np.array(CIFAR_MEAN).reshape(1, 3, 1, 1)
DATASET_STD = np.array(CIFAR_STD).reshape(1, 3, 1, 1)

# Classifier wrapper around the black-box model, used by the HSJA and BA attacks.
classifier = PyTorchClassifier(
    model=model,
    loss=criterion,
    optimizer=optimizer,
    input_shape=(3, 32, 32),
    nb_classes=10,
    clip_values=(0, 1),
    preprocessing=(DATASET_MEAN, DATASET_STD),
)


iter_step = 10

# Settings passed to model.init_model() before every attacked image.
# (The original assigned 'max_num_queries' twice; the duplicate is removed.)
attack_setting = {
    'sigma': args.sigma,
    'attack': args.attack,
    'alpha': args.alpha,
    'beta': args.beta,
    'log_interval': args.log_interval,
    'max_num_queries': args.max_num_queries,
}

# Number of logged points per image: one every `log_interval` queries.
num_log_points = args.max_num_queries // args.log_interval

# Per-image result buffers, filled row by row in the main loop.
total_log_query_point=np.zeros((args.num_runs,num_log_points))
total_log_l_2=np.zeros((args.num_runs,num_log_points))
total_log_ne_count=np.zeros((args.num_runs))
total_log_query_count=np.zeros((args.num_runs))
total_log_l_inf=np.zeros((args.num_runs,num_log_points))
total_log_acc=np.zeros((args.num_runs,num_log_points))
total_log_prob=np.zeros((args.num_runs,num_log_points))
total_log_adv = []
total_log_l_2_query_count=np.zeros((args.num_runs))
total_log_l_inf_query_count=np.zeros((args.num_runs))

# ART attacks split by how the main loop drives them: the first group is
# resumed repeatedly, the second is called once.
ART_Attacks_1 = ['HSJA', 'BA']
ART_Attacks_2 = ['ZOO']

# Log points printed as progress snapshots (with the default log_interval of
# 200 these are 1000, 2000, 5000 and 10000 queries). Indices beyond the
# available log points are dropped so that a smaller --max_num_queries does
# not raise an IndexError when the snapshots are printed.
preset_idx=[p for p in (4, 9, 24, 49) if p < num_log_points]
timestart=time.time()

# The non-ART attacks are built once here and reused for every image.
if args.attack not in ART_Attacks:
    if args.attack == 'SO':
        attack = OPT_attack_sign_SGD(model)
    elif args.attack == 'GD':
        attack = GeoDA_Attack(model,dataset='cifar')
    elif args.attack == 'SB' or args.attack == 'SBD':
        attack = SimBA_Attack(model)
    elif args.attack=='BD':
        attack=Bandit_Attack(model,dataset='cifar')
    elif args.attack == 'SSA':
        attack = Subspace_Attack(model,dataset='cifar')
    else:
        # Fail fast: previously an unknown name left `attack` undefined and
        # only surfaced later as a NameError inside the attack loop.
        raise ValueError(
            "Unknown attack '%s'; expected one of %s"
            % (args.attack, ART_Attacks + ['SO', 'GD', 'SB', 'SBD', 'BD', 'SSA']))
# Main experiment loop: one iteration per batch slice. Only the first image of
# each slice (`[:1]`) is attacked, and its results are stored in row `i` of the
# `total_log_*` buffers.
for i in range(N):
    upper = min((i + 1) * args.batch_size, args.num_runs)
    images_batch = images[(i * args.batch_size):upper]
    labels_batch = labels[(i * args.batch_size):upper]
    # print(torch.min(images_batch),torch.max(images_batch))
    # print(labels_batch, model.predict_label(images_batch.cuda()))
    # replace true label with random target labels in case of targeted attack
    # We don't consider targeted attack for this experiment
    # if args.targeted:
    #     labels_targeted = labels_batch.clone()
    #     while labels_targeted.eq(labels_batch).sum() > 0:
    #         labels_targeted = torch.floor(1000 * torch.rand(labels_batch.size())).long()
    #     labels_batch = labels_targeted
    #print(np.argmax(classifier.predict(ori_image)[0]))
    #print(np.argmax(model(normalize(ori_image.cuda())).detach().cpu().numpy()[0]))
    # Reset per image; for HSJA/BA this is also the initial `x_adv_init` of the resume loop.
    x_adv = None


    # HSJA and BA attack objects are re-created for every image.
    if args.attack == 'HSJA':
        attack = HopSkipJump(classifier=classifier, max_queries=args.max_num_queries, targeted=False, max_iter=64,
                             max_eval=10000, init_eval=100)
    elif args.attack == 'BA':
        attack = BoundaryAttack(estimator=classifier, targeted=False,max_queries=args.max_num_queries)

    # NOTE(review): for 'ZOO' no attack object is created above, so `attack`
    # would be undefined at this point — confirm whether ZOO is still supported.
    if args.attack in ART_Attacks:
        attack.batch_size = args.attack_batch_size


    print('IMG: ',(i+1))
    # HSJA Complete
    # Boundary Attack Complete
    # For the PNI and RSE defenses, check first whether the model already
    # mislabels the clean image; if it does, the clean image is logged and the
    # attack is skipped.
    incorrect_pred = False
    if args.defense=='PNI' or args.defense=='RSE':
        model.init_model(attack_setting, images_batch[:1].cuda(), np.array(labels_batch[:1]))
        if model.predict_label(images_batch[:1].cuda()).cpu()!=labels_batch[:1]:
            print('Already incorrect')
            model.set_log(images_batch[:1].cuda(),unnormalization=False)
            incorrect_pred=True

    if incorrect_pred==False:
        if args.attack in ART_Attacks:
            # ART-style attacks operate on NumPy arrays.
            ori_image= np.array(images_batch[:1])
            # init_model is called (again, for PNI/RSE) right before the attack;
            # presumably this resets the query counter and logs — TODO confirm.
            model.init_model(attack_setting,ori_image,np.array(labels_batch[:1]))
            if args.attack in ART_Attacks_1:
                # Keep resuming from the previous adversarial example until the
                # model reports more than max_num_queries queries.
                # NOTE(review): this terminates only if generate() keeps issuing
                # queries to `model` — confirm.
                while(model.get_num_queries()<=args.max_num_queries):
                    x_adv = attack.generate(x=ori_image, x_adv_init=x_adv,y= labels_batch[:1], resume=True)
            else:
                x_adv = attack.generate(x=ori_image)
            if x_adv is not None:
                model.set_log(torch.cuda.FloatTensor(x_adv), unnormalization=False)
                # print('done!',model.get_num_queries())
                # break
        else:
            model.init_model(attack_setting,images_batch[:1].cuda(),np.array(labels_batch[:1]))
            if args.attack=='GD':
                # NOTE(review): unlike the other branches, the GD result is not
                # passed to model.set_log() — confirm this is intentional.
                adv = attack.attack_untargeted(images_batch[:1].cuda(), labels_batch[:1].cuda(), query_limit=args.max_num_queries)
            elif args.attack == 'SB' or args.attack == 'SBD':
                # SimBA additionally receives the whole `args` namespace.
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1],args=args,
                                               query_limit=args.max_num_queries)
                if adv is not None:
                    model.set_log(adv.cuda(),unnormalization=False)
            elif args.attack == 'BD':
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1], query_limit=args.max_num_queries)
                if adv is not None:
                    model.set_log(adv.cuda(), unnormalization=False)
            elif args.attack == 'SSA':
                adv = attack.attack_untargeted(images_batch[:1], labels_batch[:1],query_limit=args.max_num_queries)
                if adv is not None:
                    model.set_log(adv.cuda(),unnormalization=False)
            else:
                # Remaining attacks (e.g. 'SO') are invoked as callables; this is
                # the only branch that calls set_log without unnormalization=False.
                adv = attack(images_batch[:1].cuda(), labels_batch[:1].cuda(), query_limit=args.max_num_queries,TARGETED=False)
                if adv is not None:
                    model.set_log(adv.cuda())


    # x_adv_i=np.transpose(x_adv[0],[1,2,0])*255.0
    # plt.imshow(x_adv_i.astype(np.uint))
    # plt.show()

    # Collect everything the black-box model logged for this image.
    (log_query_point,log_prob,log_acc,log_l_2,log_l_inf,log_l_2_query_count,log_l_inf_query_count,log_adv,log_ne_count)=model.get_log()

    total_log_query_count[i]=model.get_num_queries()
    total_log_ne_count[i]=log_ne_count
    total_log_query_point[i,:]=log_query_point
    total_log_l_2[i,:]=log_l_2
    total_log_l_inf[i,:]=log_l_inf
    total_log_acc[i,:]=log_acc
    total_log_prob[i,:]=log_prob
    # Store the adversarial image as bytes (scaled by 255); fall back to the
    # clean image when no adversarial example was logged.
    if log_adv is not None:
        adv_images=(log_adv * 255.0).byte()
    else:
        adv_images=(images_batch[:1] * 255.0).byte()
    # x_adv_i=np.transpose(adv_images[0].detach().cpu().numpy(),[1,2,0])
    # plt.imshow(x_adv_i.astype(np.uint))
    # plt.show(block=False)
    total_log_adv.append(adv_images)
    total_log_l_2_query_count[i] = log_l_2_query_count
    total_log_l_inf_query_count[i] = log_l_inf_query_count
    # Progress snapshot at the preset log points.
    print('Acc: ', log_acc[preset_idx])
    print('l_2', log_l_2[preset_idx],'l_inf',log_l_inf[preset_idx])
    print('l_2_count: ', log_l_2_query_count, 'l_inf_count:',log_l_inf_query_count,flush=True)
    print('query_count: ', model.get_num_queries(), 'log_ne_count:',log_ne_count,flush=True)

# Rebuild the result path: the first seven fields are shared, the last one is
# avg_iter when averaging is enabled and the save suffix otherwise.
path_fields = (args.result_dir, args.attack, args.model, args.defense, args.sigma, args.alpha, args.beta)
if args.avg_iter > 1:
    savefile = '%s/CIFAR_%s_%s_%s_%.3f_%.1f_%.1f_%d.pth' % (path_fields + (args.avg_iter,))
else:
    savefile = '%s/CIFAR_%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (path_fields + (args.save_suffix,))

# Write every per-image log buffer to a single file.
results = {
    'total_log_query_point': total_log_query_point,
    'total_log_l_2': total_log_l_2,
    'total_log_l_inf': total_log_l_inf,
    'total_log_acc': total_log_acc,
    'total_log_prob': total_log_prob,
    'total_log_adv': total_log_adv,
    'total_log_l_2_query_count': total_log_l_2_query_count,
    'total_log_l_inf_query_count': total_log_l_inf_query_count,
    'total_log_ne_count': total_log_ne_count,
    'total_log_query_count': total_log_query_count,
}
torch.save(results, savefile)



# Summary statistics over all attacked images.
print('-------------------------------------------------')
# Per-image ratio of log_ne_count to the number of queries, averaged over images.
ne_ratio_per_image = np.divide(total_log_ne_count, total_log_query_count)
avg_log_ne_ratio = ne_ratio_per_image.mean(axis=0)
print('avg_log_ne_ratio',avg_log_ne_ratio)

# Accuracy curve averaged over images, shown at the preset log points.
avg_log_acc = total_log_acc.mean(axis=0)
print('avg_log_acc',avg_log_acc[preset_idx])

# Mean and median L2 curves.
avg_log_l_2 = total_log_l_2.mean(axis=0)
median_log_l_2 = np.median(total_log_l_2, axis=0)
print('avg_log_l_2',avg_log_l_2[preset_idx])
print('median_log_l_2',median_log_l_2[preset_idx])

# Mean and median L_inf curves.
avg_log_l_inf = total_log_l_inf.mean(axis=0)
median_log_l_inf = np.median(total_log_l_inf, axis=0)
print('avg_log_l_inf',avg_log_l_inf[preset_idx])
print('median_log_l_inf',median_log_l_inf[preset_idx])

# Query-count statistics use only the images with a positive recorded count.
l_2_positive = total_log_l_2_query_count > 0
filtered_l_2_query_count = total_log_l_2_query_count[l_2_positive]
avg_log_l_2_query_count = np.mean(filtered_l_2_query_count)
median_log_l_2_query_count = np.median(filtered_l_2_query_count)
print('avg_log_l_2_query_count',avg_log_l_2_query_count)
print('median_log_l_2_query_count',median_log_l_2_query_count)

l_inf_positive = total_log_l_inf_query_count > 0
filtered_l_inf_query_count = total_log_l_inf_query_count[l_inf_positive]
avg_log_l_inf_query_count = np.mean(filtered_l_inf_query_count)
median_log_l_inf_query_count = np.median(filtered_l_inf_query_count)
print('avg_log_l_inf_query_count',avg_log_l_inf_query_count)
print('median_log_l_inf_query_count',median_log_l_inf_query_count)

# Success rate per budget: fraction of images whose recorded query count is
# positive and does not exceed the budget.
print('1000, 2000, 5000, 10000')
for budget in (1000, 2000, 5000, 10000):
    print('Query_count : ',budget)
    l_2_within = np.logical_and(total_log_l_2_query_count <= budget, l_2_positive)
    print('success_rate_l_2:',np.mean(l_2_within.astype(float)))
    l_inf_within = np.logical_and(total_log_l_inf_query_count <= budget, l_inf_positive)
    print('success_rate_l_inf:',np.mean(l_inf_within.astype(float)))

timeend = time.time()
print("\nTime: %.4f seconds" % (timeend - timestart))
