import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

from utils import load_data
from attack_lib import LinfPGDAttack, L2PGDAttack
from models import ResNet18, ResNet50, ResNet152, CIFAR_CNN
import matplotlib.pyplot as plt

# Checkpoints under ./saved_model/ that this script has been run against. Brace
# groups expand to one name per alternative, e.g. 'gauss-{0.5,1.0}' stands for
# 'gauss-0.5' and 'gauss-1.0'. Pass one of them as ``model_name`` to ``main``.
#   naive
#   advT-l2-1.0000  advT-linf-0.0314  advT-l2-0.1000  advTrepr-l2-1.0000
#   alpha-{10.0000,100.0000,10000.0000}
#   alpha-out-{100.0000,1000.0000}
#   alpha-approx-{3,6,10,30,60,100}.0000  alpha-approx-10.0000-adv-l2-1.0000
#   alpha-approxrepr-{3,6,10,30,60,100}.0000
#   alpha-last-{1.0000,0.3000,0.1000,0.0300,-0.1000,-0.3000}
#   reglast_{0.01,0.03,0.1,0.3,1.0,3.0,10.0,30.0,100.0}  reglast_0.1-
#   simclradv-l2-{0.1000,1.0000}  simclradv-l2-0.1000-lamda{0.2500,0.7500}
#   rocl-linf-{0.0157,0.0314,0.0627}  rocl-linf-0.0314-lamda{0.1250,0.2500,0.7500}
#   gauss-{1.0,0.5,0.25,0.125,0.05,0.025}
#   mixing-l2-0.5000-lamda{0.0000,0.5000,0.7500,0.8750,1.0000}-complexcomplex{,_2,_3}
#   mixing-l2-0.5000-lamda{0.0000,0.5000,1.0000}-simplesimple
#   mixing-None-0.0000-lamda{0.0000,0.5000,1.0000}-{complexcomplex,simplesimple}
#   debug{1..13}
#   weightdecay-{0.0005,0.0001,1e-05,1e-06}{,-lastdiff}
#   weightdecay-1e-06-l1reg{0.001,0.0001,1e-05,1e-06}
#   weightdecay-{0.0005,1e-06}-{dropout0.2,dropout0.5,resnet50,resnet152,cifarcnn}
#   aug_standard  aug_rotate{5,10,15,30,45}  aug_erase{0.2,0.33}
#   aug_perspective{0.25,0.5}  aug_affine{15,30}  aug_posterize{2,3,4}


def _dropout_from_name(model_name):
    """Return the dropout rate encoded in ``model_name``.

    The substrings 'dropout0.2' and 'dropout0.5' select those rates. A name
    with no 'dropout' tag gets 0.0.

    Raises:
        ValueError: if ``model_name`` contains any other 'dropout' tag.
    """
    if 'dropout0.2' in model_name:
        return 0.2
    if 'dropout0.5' in model_name:
        return 0.5
    if 'dropout' in model_name:
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError('unsupported dropout setting in model name %r' % (model_name,))
    return 0.0


def _build_model(model_name, normalizer, dropout):
    """Instantiate the architecture selected by a substring of ``model_name``.

    'resnet50', 'resnet152' and 'cifarcnn' select the matching class. Any
    other name falls back to ResNet18.
    """
    if 'resnet50' in model_name:
        return ResNet50(normalizer, dropout)
    if 'resnet152' in model_name:
        return ResNet152(normalizer, dropout)
    if 'cifarcnn' in model_name:
        return CIFAR_CNN(normalizer, dropout)
    return ResNet18(normalizer, dropout)


def _save_attack_figure(model_name, adversary, loader, device, n_images=8):
    """Save adversarial examples and their perturbations from the first batch of ``loader``.

    The figure is written to 'figures/<model_name>_attack.pdf' as a 2-row grid:
    - top row: the adversarial images;
    - bottom row: the perturbation ``adv_x - x``, min-max rescaled to [0, 1]
      for display.
    """
    x, y = next(iter(loader))
    n_images = min(n_images, x.size(0))
    x, y = x[:n_images].to(device), y[:n_images].to(device)
    adv_x = adversary.perturb(x, y).detach()

    fig = plt.figure(figsize=(2 * n_images, 4))
    for i in range(n_images):
        plt.subplot(2, n_images, i + 1)
        plt.imshow(adv_x[i].cpu().numpy().transpose(1, 2, 0))
        plt.axis('off')

        plt.subplot(2, n_images, n_images + i + 1)
        pert = adv_x[i] - x[i]
        # The clamp avoids 0/0 (NaN) when the perturbation is constant.
        pert = (pert - pert.min()) / (pert.max() - pert.min()).clamp(min=1e-12)
        plt.imshow(pert.cpu().numpy().transpose(1, 2, 0))
        plt.axis('off')
    fig.savefig('figures/%s_attack.pdf' % model_name)
    plt.close(fig)


def _adversarial_accuracy(model, adversary, loader, device):
    """Return the accuracy (percent) of ``model`` on attacked batches from ``loader``.

    The running accuracy is shown in the tqdm progress-bar description.
    Returns NaN if ``loader`` yields no samples.
    """
    adv_correct = 0
    total = 0
    with tqdm(loader) as pbar:
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            # Crafting a PGD example needs input gradients, so the attack runs
            # with grad enabled. Only the scoring forward pass is gradient-free.
            adv_x = adversary.perturb(x, y).detach()
            with torch.no_grad():
                _, adv_pred_c = model(adv_x).max(1)
            total += y.size(0)
            adv_correct += adv_pred_c.eq(y).sum().item()
            pbar.set_description('Adv acc: %.3f' % (100. * adv_correct / total))
    return 100. * adv_correct / total if total else float('nan')


def main(model_name='weightdecay-1e-06-cifarcnn', norm='l2', epsilon=0.25,
         test_batch_size=100, device='cuda', plot=False):
    """Evaluate a saved checkpoint's test accuracy under a PGD attack.

    The checkpoint './saved_model/<model_name>.pth' is loaded into the
    architecture and dropout rate encoded in ``model_name``. Every test
    batch is then attacked with ``L2PGDAttack`` or ``LinfPGDAttack``, and
    accuracy on the adversarial examples is reported.

    Args:
        model_name: checkpoint name (see the list of known names above).
        norm: 'l2' selects L2PGDAttack, 'linf' selects LinfPGDAttack.
        epsilon: attack budget passed to the attack constructor.
        test_batch_size: batch size forwarded to ``load_data``.
        device: device that the model and data are moved to.
        plot: if True, first save a figure of attacked examples to
            'figures/<model_name>_attack.pdf'.

    Returns:
        Adversarial accuracy in percent.

    Raises:
        ValueError: for an unknown ``norm`` or an unsupported dropout tag.
    """
    attacks = {'l2': L2PGDAttack, 'linf': LinfPGDAttack}
    if norm not in attacks:
        raise ValueError("norm must be 'l2' or 'linf', got %r" % (norm,))

    trainset, testset, _, testloader, normalizer = load_data(test_batch_size=test_batch_size)
    print(model_name, len(trainset), len(testset))

    dropout = _dropout_from_name(model_name)
    model = _build_model(model_name, normalizer, dropout).to(device)
    # map_location lets a checkpoint saved on another device load onto ``device``.
    model.load_state_dict(torch.load('./saved_model/%s.pth' % model_name, map_location=device))
    model.eval()

    adversary = attacks[norm](model, epsilon=epsilon)

    if plot:
        _save_attack_figure(model_name, adversary, testloader, device)

    adv_acc = _adversarial_accuracy(model, adversary, testloader, device)
    print('Adv acc: %.3f' % adv_acc)
    return adv_acc


if __name__ == '__main__':
    # Run the evaluation only when this file is executed directly as a
    # script; importing the module has no side effects beyond definitions.
    main()
