import torch
import os
from utils.util import get_optimizer, get_backbone, read_yaml, load_state_dict
from dataset import data_process
import torchvision.transforms as transforms
from utils.models import ResnetGenerator
import numpy as np

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

import warnings
warnings.filterwarnings("ignore")

class UAPEval():
    """Evaluate a classifier's clean and UAP-perturbed accuracy on a test set.

    The perturbation (``'delta_im'`` in the file at ``opt.uap``) is added to
    every normalized test batch, the result is clamped to [-1, 1], and top-1
    accuracy on clean and perturbed inputs is printed.
    """

    def __init__(self, opt):
        """Store the run configuration.

        Args:
            opt: parsed command-line namespace providing ``root``, ``dataset``,
                ``device``, ``backbone``, ``data_conf``, ``load`` and ``uap``.
        """
        self.root = opt.root
        self.dataset = opt.dataset
        self.device = opt.device
        self.backbone = opt.backbone
        # Set by data_process(); net_process() needs it for non-'imagenet' models.
        self.num_classes = None
        # Per-(backbone, dataset) settings; data_process() reads 'batch_size' from it.
        self.conf = read_yaml(opt.data_conf, self.backbone, self.dataset)
        self.load = opt.load
        self.uap = opt.uap

    def data_process(self):
        """Build the data loaders and record the dataset's number of classes.

        Returns:
            ``(train_loader, test_loader)`` as produced by the imported
            ``data_process`` helper.
        """
        batch_size = self.conf['batch_size']
        # This bare name is the module-level ``data_process`` imported from
        # ``dataset``, not this method (methods are only reachable via ``self``).
        train_loader, test_loader = data_process(root=self.root, dataset=self.dataset,
                                                 batch_size=batch_size, train=True)
        self.num_classes = test_loader.dataset.num_classes
        return train_loader, test_loader

    def net_process(self):
        """Construct the backbone and load its weights.

        If ``self.load == 'imagenet'`` the backbone is built with
        ``pretrained=True``; otherwise it is built with ``self.num_classes``
        outputs and ``self.load`` is passed to ``load_state_dict``.
        Must be called after ``data_process()`` so ``num_classes`` is set.
        """
        net = get_backbone(self.backbone)
        if self.load == 'imagenet':
            net = net(pretrained=True)
        else:
            net = net(pretrained=False, num_classes=self.num_classes)
            load_state_dict(net, self.load)
        return net

    def info(self):
        """Print a short summary of the evaluation configuration."""
        print('-------------uap eval-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))
        print('load from:{}'.format(self.load))
        print('uap from: {}'.format(self.uap))

    @staticmethod
    def _count_correct(net, loader, delta_im, normalize, device):
        """Count correct clean and perturbed top-1 predictions over ``loader``.

        Args:
            net: classifier module; switched to eval mode here so BatchNorm
                uses running statistics and Dropout is disabled.
            loader: iterable of ``(images, labels)`` batches.
            delta_im: perturbation added to each normalized batch. It is
                sliced along dim 0 to the batch size, so it presumably carries
                a leading batch dimension (size 1 broadcasts) -- TODO confirm
                against the UAP generator's output.
            normalize: callable applied to each image batch before perturbing.
            device: target passed to ``Tensor.to``.

        Returns:
            ``(clean_correct, adv_correct)`` as Python ints.
        """
        # Without this the model would evaluate in training mode: BatchNorm
        # would use (and update) batch statistics and Dropout would be active.
        net.eval()
        clean_correct = 0
        adv_correct = 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                images = normalize(images)
                images_adv = torch.clamp(images + delta_im[:images.size(0)], -1, 1)
                logits = net(images)
                logits_adv = net(images_adv)
                # Accumulate as tensors to avoid a host sync on every batch.
                clean_correct += (logits.argmax(dim=1) == labels).sum()
                adv_correct += (logits_adv.argmax(dim=1) == labels).sum()
        return int(clean_correct), int(adv_correct)

    def eval(self):
        """Print clean and UAP-perturbed top-1 accuracy on the test split.

        Raises:
            ValueError: if the test dataset is empty.
        """
        self.info()
        _, test_loader = self.data_process()
        test_size = len(test_loader.dataset)
        if test_size == 0:
            raise ValueError('test dataset is empty; cannot compute accuracy')

        net = self.net_process().to(self.device)

        # NOTE: torch.load unpickles the file; only point --uap at trusted files.
        delta_im = torch.load(self.uap, map_location='cpu')['delta_im'].to(self.device)

        # Maps inputs assumed to be in [0, 1] onto [-1, 1], which matches the
        # clamp range used for the perturbed images -- TODO confirm loader range.
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        corrects, corrects_adv = self._count_correct(
            net, test_loader, delta_im, normalize, self.device)
        corrects = corrects / test_size
        corrects_adv = corrects_adv / test_size
        print('accuracy:{:.4f}'.format(corrects))
        print('adv accuracy:{:.4f}'.format(corrects_adv))


import argparse
if __name__ == '__main__':
    # Command-line entry point: parse options and run the UAP evaluation.
    parser = argparse.ArgumentParser()

    parser.add_argument('--backbone', type=str, default='resnet18',
                        help='backbone name passed to get_backbone')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name passed to data_process and read_yaml')
    parser.add_argument('--data_conf', type=str, default='conf.yaml',
                        help='YAML config read via read_yaml (must provide batch_size)')
    parser.add_argument('--root', type=str, default='datasets',
                        help='dataset root directory')

    parser.add_argument('--load', type=str,
                        help="'imagenet' to build the backbone with pretrained=True; "
                             "otherwise passed to load_state_dict (e.g. a checkpoint path)")
    # Required: eval() always calls torch.load on this path, so omitting it
    # would only fail later, after the data and model were already built.
    parser.add_argument('--uap', type=str, required=True,
                        help="file saved with torch.save containing a 'delta_im' tensor")
    parser.add_argument('--device', type=int, default=0,
                        help='device passed to Tensor.to (an int selects a CUDA index)')

    opt = parser.parse_args()

    evaluation = UAPEval(opt)
    evaluation.eval()
