import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import utils
import math
import random
import torch.nn.functional as F
import argparse
import os
import pdb
import random
from torch.utils.data import DataLoader
import dill
from robustness.resnet import resnet50
# Command-line options.  They are parsed at import time because the __main__
# section further down reads the module-level ``args`` object.
parser = argparse.ArgumentParser(description='Evaluates the accuracy of clean images with defenses')
parser.add_argument('--defense', type=str, default='gaussian',
                    help="defense to evaluate: 'rnp' adds random resize + pad, 'AT' loads adversarially "
                         "trained weights; input noise is added for every value when sigma > 0")
parser.add_argument('--random_seed', type=int, default=1,
                    help='base random seed; run k is seeded with random_seed + k')
parser.add_argument('--data_root', type=str, required=True, help='root directory of imagenet data')
parser.add_argument('--result_dir', type=str, default='save', help='directory for saving results')
parser.add_argument('--sampled_image_dir', type=str, default='save', help='directory to cache sampled images')
parser.add_argument('--model', type=str, default='resnet50', help='type of base model to use')
parser.add_argument('--batch_size', type=int, default=100, help='batch size for parallel runs')
parser.add_argument('--sigma', type=float, default=0.01, help='sigma of input gaussian noise')
parser.add_argument('--alpha', type=float, default=0, help='alpha value of beta distribution')
parser.add_argument('--beta', type=float, default=0, help='beta value of beta distribution')
parser.add_argument('--num_workers', type=int, default=0, help='the number of workers')
# Used as ``range(args.num_iters)`` below, so 0 means "no runs", not "unlimited".
parser.add_argument('--num_iters', type=int, default=5,
                    help='number of evaluation runs, each with a different seed')
# NOTE(review): --log_every is not read anywhere in the visible part of this script.
parser.add_argument('--log_every', type=int, default=10, help='log every n iterations')
parser.add_argument('--save_suffix', type=str, default='', help='suffix appended to save file')
args = parser.parse_args()
print(args)

# Preview of the result path for run index 0; the __main__ section recomputes
# the path for every run before saving.
savefile = '%s/%s_%s_%.3f_%.1f_%.1f_%d_%s.pth' % (
    args.result_dir, args.defense, args.model, args.sigma, args.alpha, args.beta, 0, args.save_suffix)
print('SAVE_FILE : ', savefile)

def expand_vector(x, size):
    """Place a ``size`` x ``size`` RGB patch in the top-left corner of a zero image.

    ``x`` is reshaped to ``(-1, 3, size, size)`` and copied into a fresh
    zero tensor of shape ``(x.size(0), 3, image_size, image_size)``.

    Note: ``image_size`` is a module-level global that is only assigned in
    the ``__main__`` section of this file, so it must exist before this
    function is called.

    Args:
        x: tensor whose first dimension is the batch dimension.
        size: side length of the square patch encoded by each row of ``x``.

    Returns:
        A new CPU float tensor holding the patch, zero everywhere else.
    """
    num_images = x.size(0)
    patch = x.view(-1, 3, size, size)
    canvas = torch.zeros(num_images, 3, image_size, image_size)
    canvas[:, :, 0:size, 0:size] = patch
    return canvas

def normalize(x):
    """Return ``x`` passed through ``utils.apply_normalization`` with the 'imagenet' setting."""
    normalized = utils.apply_normalization(x, 'imagenet')
    return normalized

########################################
# RN Version 1
from torch.distributions import Beta
# [0,1] -> [-1,1]
# mean, std [0.5, 0.5]
def model_noise(imgs_tensor, sigma=0.001, alpha=0, beta=0):
    """Draw additive Gaussian noise with the same shape as ``imgs_tensor``.

    The noise is ``sigma * N(0, 1)`` element-wise.  When both ``alpha`` and
    ``beta`` are positive, every sample in the batch is additionally scaled
    by its own factor drawn from ``Beta(alpha, beta)``, so the effective
    noise level of an image lies in ``[0, sigma]``.

    Args:
        imgs_tensor: batch tensor; only its shape, dtype and device are used.
            The per-sample scaling reshapes the factors to ``(N, 1, 1, 1)``,
            i.e. it expects a 4-D batch.
        sigma: standard deviation of the Gaussian noise.
        alpha: first concentration parameter of the Beta distribution.
        beta: second concentration parameter of the Beta distribution.

    Returns:
        The noise tensor only (the caller adds it to the images), on the same
        device as ``imgs_tensor``.
    """
    # Gaussian draw first, Beta draw second: keeps the RNG consumption order
    # of the previous implementation.
    noise = torch.randn_like(imgs_tensor) * sigma
    if alpha > 0 and beta > 0:
        # Previously the distribution was hard-coded to Beta(1, 1), silently
        # ignoring the requested alpha/beta.  alpha = beta = 1 still yields
        # exactly that uniform distribution.
        scale_dist = Beta(torch.tensor([float(alpha)]), torch.tensor([float(beta)]))
        per_sample = scale_dist.sample((imgs_tensor.size(0),)).view(-1, 1, 1, 1)
        # Follow the input's device instead of forcing CUDA, so CPU batches
        # (and machines without a GPU) work as well.
        noise = noise * per_sample.to(imgs_tensor.device)
    return noise
#######################################

def get_probs(model, x, y):
    """Return the softmax probability ``model`` assigns to label ``y[i]`` for image ``x[i]``.

    Args:
        model: classifier called on the normalized batch; it is expected to
            return 2-D ``(batch, classes)`` scores.
        x: image batch; moved to CUDA and normalized before the forward pass.
        y: 1-D tensor of class indices, one per image, on the CPU.

    Returns:
        1-D CPU tensor of length ``batch`` with the probability of each
        image's own label.
    """
    logits = model(normalize(x.cuda())).cpu()
    # Explicit dim: the implicit-dim form of Softmax is deprecated and warns.
    probs = F.softmax(logits.detach(), dim=1)
    # Pick probs[i, y[i]] directly instead of building a (batch x batch)
    # matrix with index_select and then taking its diagonal.
    return probs.gather(1, y.view(-1, 1).long()).squeeze(1)

def get_preds(model, x):
    """Return the index of the highest-scoring class for every image in ``x``.

    Args:
        model: classifier called on the normalized batch.
        x: image batch; moved to CUDA and normalized before the forward pass.

    Returns:
        1-D CPU tensor with one predicted class index per image.
    """
    # Plain tensors are passed straight to the model; the autograd.Variable
    # wrapper is deprecated and no longer needed.
    logits = model(normalize(x.cuda())).cpu()
    # max(1) returns (values, indices); only the indices are needed.
    return logits.detach().max(1)[1]

# applies the selected input defense (additive noise, plus random resize + pad for 'rnp') to a batch of images and returns the model output together with the noised batch

# Side lengths used by the 'rnp' (random resize + pad) defense.
_RNP_MIN_SIZE = 224
_RNP_OUT_SIZE = 248


def _random_resize_and_pad(img, min_size=_RNP_MIN_SIZE, out_size=_RNP_OUT_SIZE):
    """Randomly resize ``img`` (N, C, H, W), normalize it and pad it to ``out_size`` x ``out_size``."""
    # Draw order (size, vertical offset, horizontal offset) is kept fixed so
    # seeded runs stay reproducible.
    rnd_size = np.random.randint(min_size, out_size + 1)
    resized = F.interpolate(img, size=(rnd_size, rnd_size), mode='nearest')
    slack = out_size - rnd_size
    top = np.random.randint(0, slack + 1)
    left = np.random.randint(0, slack + 1)
    # F.pad expects (left, right, top, bottom) for the last two dimensions.
    pads = (left, slack - left, top, slack - top)
    # Pad after normalizing so the border is 0 in the space the model sees.
    return F.pad(normalize(resized), pads, "constant", 0)


def apply_defense(model,batch,sigma,alpha,beta, defense):
    """Apply the input defense to ``batch`` and run ``model`` on the result.

    Steps:
      1. If ``sigma > 0``, add noise from ``model_noise`` and clamp to [0, 1].
      2. If ``defense == 'rnp'``, every image is independently resized to a
         random side length in [224, 248], normalized and padded to 248 x 248
         at a random offset.  Otherwise the batch is only normalized.
      3. The model is called on the processed batch.

    Args:
        model: classifier to evaluate.
        batch: image batch (N, C, H, W); the clamp implies values in [0, 1].
        sigma: noise standard deviation; ``<= 0`` disables the noise.
        alpha: forwarded to ``model_noise``.
        beta: forwarded to ``model_noise``.
        defense: defense name; only ``'rnp'`` changes the processing here.

    Returns:
        ``(output, perturbed_batch)``: the model output and the noised batch
        as it was before normalization / resizing.
    """
    batch_size = batch.size(0)
    if sigma > 0:
        noise = model_noise(batch, sigma=sigma, alpha=alpha, beta=beta)
        noised_batch = (batch + noise).clamp(0, 1)
    else:
        noised_batch = batch
    perturbed_batch = noised_batch

    if defense == 'rnp':
        # One code path for every batch size: each image gets its own random
        # size and offset.  The result stays on the device of the input
        # instead of being forced onto CUDA.
        padded = [_random_resize_and_pad(noised_batch[idx:idx + 1]) for idx in range(batch_size)]
        if padded:
            model_input = torch.cat(padded, dim=0)
        else:
            model_input = noised_batch.new_zeros(
                (0, noised_batch.size(1), _RNP_OUT_SIZE, _RNP_OUT_SIZE))
    else:
        model_input = normalize(noised_batch)
    return model(model_input), perturbed_batch

if __name__=='__main__':
    # makedirs(exist_ok=True) also creates nested paths and avoids the
    # check-then-create race of exists() + mkdir().
    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.sampled_image_dir, exist_ok=True)

    # load model and dataset
    model = getattr(models, args.model)(pretrained=True).cuda()
    if args.defense=='AT':
        # Download pretrained model weight from https://github.com/MadryLab/robustness
        # ImageNet L2-norm (ResNet50): epsilon = 3
        print("=> loading checkpoint '{}'".format('models/imagenet_l2_3_0.pt'))
        # NOTE: torch.load unpickles arbitrary objects; only load checkpoints
        # obtained from a trusted source.
        checkpoint = torch.load('models/imagenet_l2_3_0.pt', pickle_module=dill)
        state_dict_path = 'model' if 'model' in checkpoint else 'state_dict'
        sd = checkpoint[state_dict_path]
        # Strip the 'module.' prefix only where it is present, then drop the
        # 'model.' component so the keys match the torchvision model.
        sd = {(k[len('module.'):] if k.startswith('module.') else k).replace('model.', ''): v
              for k, v in sd.items()}

        # strict=False ignores the checkpoint's extra entries; report missing
        # weights instead of silently evaluating a partially loaded model.
        load_result = model.load_state_dict(sd, strict=False)
        if load_result is not None and load_result.missing_keys:
            print('WARNING: %d model parameters were not found in the checkpoint'
                  % len(load_result.missing_keys))
        # Standard pretrained model used as the undefended reference.
        o_model = getattr(models, args.model)(pretrained=True).cuda()
        # The reference model must be in eval mode too; otherwise BatchNorm
        # uses batch statistics and its predictions are not meaningful.
        o_model.eval()

    model.eval()
    reference_model = o_model if args.defense == 'AT' else model

    # The dataset and loader do not depend on the run index, so build them
    # once instead of re-scanning the image folder for every run.
    if args.model.startswith('inception'):
        image_size = 299
        transform = utils.INCEPTION_TRANSFORM
    else:
        image_size = 224
        transform = utils.IMAGENET_TRANSFORM
    testset = dset.ImageFolder(args.data_root + '/val', transform)
    test_loader = DataLoader(
        testset, shuffle=False, num_workers=args.num_workers, batch_size=args.batch_size)

    total_accs = np.zeros(args.num_iters)
    total_succ = np.zeros(args.num_iters)
    total_l2 = np.zeros(args.num_iters)
    total_li = np.zeros(args.num_iters)
    for run_idx in range(args.num_iters):
        seed = args.random_seed + run_idx
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # For stricter GPU reproducibility one can additionally set
        # torch.backends.cudnn.deterministic = True and
        # torch.backends.cudnn.benchmark = False.

        # Per-batch results are collected in lists and concatenated once at
        # the end of the run (repeated torch.cat would copy quadratically).
        chunks = {'probs': [], 'preds': [], 'accs': [], 'succs': [],
                  'l2_norms': [], 'linf_norms': []}
        # Sample-weighted running sums, so a smaller final batch does not
        # skew the averages.
        acc_sum = succ_sum = l2_sum = li_sum = 0.0
        n_seen = 0
        with torch.no_grad():
            for i, (images_batch, labels_batch) in enumerate(test_loader):
                batch_size = images_batch.size(0)
                # Module-level global that expand_vector() reads.
                image_size = images_batch.size(2)

                # Correctness of the undefended reference prediction.
                prev_preds = get_preds(reference_model, images_batch).eq(labels_batch)

                output, perturbed_image = apply_defense(
                    model, images_batch.cuda(), args.sigma, args.alpha, args.beta, args.defense)
                output = output.cpu()
                perturbation = perturbed_image.cpu() - images_batch

                # Probability of each image's own label: shape (batch,).
                # Previously a (batch x batch) matrix was kept, which breaks
                # the concatenation whenever the last batch is smaller.
                probs = F.softmax(output, dim=1).gather(1, labels_batch.view(-1, 1)).squeeze(1)
                preds = output.max(1)[1]

                accs = preds.eq(labels_batch)
                # True where the defended prediction is correct/incorrect
                # exactly like the reference prediction.
                succs = accs.logical_xor(prev_preds).logical_not()
                flat_perturbation = perturbation.view(batch_size, -1)
                l2_norms = flat_perturbation.norm(2, 1)
                linf_norms = flat_perturbation.abs().max(1)[0]

                chunks['probs'].append(probs)
                chunks['preds'].append(preds)
                chunks['accs'].append(accs)
                chunks['succs'].append(succs)
                chunks['l2_norms'].append(l2_norms)
                chunks['linf_norms'].append(linf_norms)

                acc_sum += accs.float().sum().item()
                succ_sum += succs.float().sum().item()
                l2_sum += l2_norms.sum().item()
                li_sum += linf_norms.sum().item()
                n_seen += batch_size
                print('Iter: %d / acc: %.4f / succ: %.4f / l_2: %.4f / l_inf: %.4f'%(
                    i, acc_sum / n_seen * 100, succ_sum / n_seen * 100, l2_sum / n_seen, li_sum / n_seen))

        total_accs[run_idx] = acc_sum / n_seen * 100
        total_succ[run_idx] = succ_sum / n_seen * 100
        total_l2[run_idx] = l2_sum / n_seen
        total_li[run_idx] = li_sum / n_seen
        print('** STAT **', run_idx)
        print(total_accs)
        print(total_succ)
        print(total_l2)
        print(total_li)

        savefile = '%s/%s_%s_%.3f_%.1f_%.1f_%d_%s.pth' % (
            args.result_dir, args.defense, args.model, args.sigma, args.alpha, args.beta, run_idx, args.save_suffix)
        torch.save({key: torch.cat(parts, dim=0) for key, parts in chunks.items()}, savefile)

    # Summary over all runs (mean + std).  Kept inside the __main__ guard
    # because the totals only exist when the script is executed directly.
    print('Acc: %.4f + %.4f / succ: %.4f + %.4f / l_2: %.4f + %.4f / l_inf: %.4f + %.4f' %(np.mean(total_accs),np.std(total_accs),
        np.mean(total_succ),np.std(total_succ),
        np.mean(total_l2),np.std(total_l2),
        np.mean(total_li),np.std(total_li)))