"""Implementation of evaluate attack result."""
import os
import torch
from torch.autograd import Variable as V
from torch import nn
# from torch.autograd.gradcheck import zero_gradients
from torchvision import transforms as T
from Normalize import Normalize, TfNormalize
from loader import ImageNet
from torch.utils.data import DataLoader
import pretrainedmodels


# Per-channel normalization statistics passed to ``Normalize`` in get_model().
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

batch_size = 10  # images per DataLoader batch in verify()

input_csv = './dataset/images.csv'  # CSV handed to the ImageNet loader (presumably image ids + labels)
input_dir = './dataset/images'  # clean-image directory; not referenced elsewhere in this file
adv_dir = './incv3_pgn_outputs'  # directory of adversarial images that verify() evaluates

# Default to GPU 3, but respect a CUDA_VISIBLE_DEVICES the caller already set.
# This must run before CUDA is first initialized (the first .cuda() call) to take effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '3')

def get_model(net_name):
    """Load a pretrained ImageNet classifier wrapped with input normalization.

    Args:
        net_name: one of 'inception_v3', 'inception_v4', 'inc_res_v2',
            'resnet_50', 'resnet_101', 'resnet_152'.

    Returns:
        A ``torch.nn.Sequential`` of ``Normalize(mean, std)`` followed by the
        pretrained backbone (built with ``num_classes=1000`` and
        ``pretrained='imagenet'``, switched to eval mode and moved to CUDA).

    Raises:
        ValueError: if ``net_name`` is not one of the supported names.
    """
    # Supported name -> constructor attribute of the ``pretrainedmodels`` package.
    constructors = {
        'inception_v3': 'inceptionv3',
        'inception_v4': 'inceptionv4',
        'inc_res_v2': 'inceptionresnetv2',
        'resnet_50': 'resnet50',
        'resnet_101': 'resnet101',
        'resnet_152': 'resnet152',
    }
    if net_name not in constructors:
        # Fail loudly here rather than returning an unbound ``model``.
        raise ValueError('Wrong model name! Got {!r}; expected one of {}'.format(
            net_name, sorted(constructors)))

    build = getattr(pretrainedmodels, constructors[net_name])
    model = torch.nn.Sequential(Normalize(mean, std),
                                build(num_classes=1000, pretrained='imagenet').eval().cuda())
    return model

def verify(model_name):
    """Classify the adversarial images with one model and print its error rate.

    Images are read from ``adv_dir`` through the ``ImageNet`` loader together
    with labels from ``input_csv``. Each image is converted to a tensor and
    resized to 299. The printed value is the fraction of images whose top-1
    prediction differs from the ground-truth label. On adversarial inputs
    this is the attack success rate. The ``acu`` label in the output is kept
    unchanged for log compatibility.

    Args:
        model_name: a name accepted by ``get_model``.

    Returns:
        The misclassification rate as a float, or None if no images were loaded.
    """
    model = get_model(model_name)

    dataset = ImageNet(adv_dir, input_csv, T.Compose([T.ToTensor(), T.Resize(299)]))
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                             pin_memory=True, num_workers=8)

    mismatches = 0
    total = 0
    with torch.no_grad():
        # The loader yields (images, <unused>, labels) per batch.
        for images, _, gt_cpu in data_loader:
            gt = gt_cpu.cuda()
            images = images.cuda()
            mismatches += (model(images).argmax(1) != gt).sum().item()
            # Count what was actually evaluated instead of assuming 1000 images.
            total += gt.numel()

    if total == 0:
        print(model_name + '  acu = n/a (no images loaded)')
        return None

    rate = mismatches / total
    print(model_name + '  acu = {:.2%}'.format(rate))
    return rate


def main():
    """Evaluate every supported model in turn, printing a divider after each."""
    divider = "==================================================="
    for name in ('inception_v3', 'inception_v4', 'inc_res_v2',
                 'resnet_50', 'resnet_101', 'resnet_152'):
        verify(name)
        print(divider)

if __name__ == '__main__':
    # Echo which adversarial-image directory is under evaluation, then run it.
    target_dir = adv_dir
    print(target_dir)
    main()