import torch
import os
from dataset import data_process
import torchvision.transforms as transforms
from utils.util import read_yaml, get_backbone, save_finetune_acc
import foolbox
from foolbox.attacks import LinfPGD

class Evaluation():
    """Measure clean and L-inf PGD adversarial accuracy of a saved classifier.

    The options object passed to ``__init__`` must expose the attributes
    ``root``, ``dataset``, ``device``, ``backbone``, ``data_conf``, ``load``,
    ``adv`` and ``eps``.  ``eps`` is the attack budget; it is divided by 255
    before being handed to the attack, whose model bounds are ``(0, 1)``.
    """

    # Per-channel statistics shared by the foolbox preprocessing and the
    # explicit normalisation applied before calling the raw network.
    _MEAN = (0.5, 0.5, 0.5)
    _STD = (0.5, 0.5, 0.5)
    # Divisor that converts ``eps`` into the (0, 1) bounds used by the attack.
    _PIXEL_MAX = 255

    def __init__(self, opt):
        """Copy the evaluation settings from ``opt`` and resolve the checkpoint path."""
        self.root = opt.root
        self.dataset = opt.dataset
        self.device = opt.device
        self.backbone = opt.backbone
        # Filled in by data_process(); net_process() reads it.
        self.num_classes = None
        self.conf = read_yaml(opt.data_conf, self.backbone, self.dataset)
        self.load = opt.load
        # Stored for completeness; not read by any method of this class.
        self.adv = opt.adv
        self.eps = opt.eps

        if self.load == 'default':
            self.load = 'checkpoints/{}/{}_{}_standard.pth'.format(self.backbone,self.backbone, self.dataset)

    def data_process(self):
        """Build the data loaders and record the number of classes.

        Returns:
            The ``(train_loader, test_loader)`` pair produced by the
            module-level ``data_process`` helper.
        """
        batch_size = self.conf['batch_size']
        # Inside a method the bare name resolves to the imported module-level
        # function, not to this method.
        train_loader, test_loader = data_process(root=self.root, dataset=self.dataset,
                                                 batch_size=batch_size, train=False)
        self.num_classes = test_loader.dataset.num_classes
        return train_loader, test_loader

    def net_process(self):
        """Instantiate the backbone and load the checkpoint weights into it.

        Reads ``self.num_classes``, which is only set by ``data_process()``,
        so that method has to run first.

        Returns:
            The network with the checkpoint's ``'state_dict'`` loaded.
        """
        net = get_backbone(self.backbone)
        net = net(pretrained=False, num_classes=self.num_classes)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(self.load, map_location='cpu')['state_dict']
        net.load_state_dict(state_dict)
        return net

    def _pixel_eps(self):
        """Return ``self.eps`` rescaled to the (0, 1) range without mutating it."""
        return self.eps / self._PIXEL_MAX

    @staticmethod
    def _count_correct(logits, labels):
        """Return, as a Python int, how many arg-max predictions equal ``labels``."""
        return int((logits.argmax(dim=1) == labels).sum().item())

    def eval(self):
        """Print clean and adversarial accuracy over the test loader.

        Adversarial examples are generated with a 10-step ``LinfPGD`` attack
        against a foolbox wrapper of the network.  The method can be called
        repeatedly: the scaled budget is computed locally, so ``self.eps``
        is left untouched.
        """
        print('-------------eval-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))
        print('load from: {}'.format(self.load))

        _, test_loader = self.data_process()
        test_size = len(test_loader.dataset)
        net = self.net_process().to(self.device).eval()

        mean, std = list(self._MEAN), list(self._STD)
        normalize = transforms.Normalize(mean=mean, std=std)

        attacker = LinfPGD(steps=10, rel_stepsize=1/8)
        # The foolbox wrapper normalises internally, so the attack operates on
        # un-normalised images inside the (0, 1) bounds.
        preprocessing = dict(mean=mean, std=std, axis=-3)
        fnet = foolbox.PyTorchModel(net, (0, 1), self.device, preprocessing)
        # Local value instead of ``self.eps /= 255``: the in-place form shrank
        # the budget again on every subsequent eval() call.
        eps = self._pixel_eps()

        corrects = 0
        corrects_adv = 0
        for images, labels in test_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            # The attack returns a 3-tuple; the second element is the one used.
            _, images_adv, _ = attacker(fnet, images, labels, epsilons=eps)
            # The raw network is called directly, so normalise explicitly.
            with torch.no_grad():
                logits = net(normalize(images))
                logits_adv = net(normalize(images_adv))

            # Python ints keep the final ratio a plain true division.
            corrects += self._count_correct(logits, labels)
            corrects_adv += self._count_correct(logits_adv, labels)
        accuracy = corrects / test_size
        adv_accuracy = corrects_adv / test_size
        print('accuracy:{:.4f}'.format(accuracy))
        print('adv accuracy:{:.4f}'.format(adv_accuracy))


import argparse

if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = (
        ('--backbone', str, 'resnet18'),
        ('--dataset', str, 'cifar10'),
        ('--data_conf', str, 'conf.yaml'),
        ('--root', str, 'datasets'),
        ('--eps', float, 8),
        ('--load', str, 'default'),
        ('--device', int, 0),
        ('--adv', str, None),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    # Parse the command line and run a single evaluation pass.
    opt = parser.parse_args()
    Evaluation(opt).eval()
