import argparse
import copy
import logging
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import *
import attack.pgd as pgd
import attack.fgsm as fgsm
from utils import *
from eval import *
from datasets import CIFAR10_dataloader
from datasets import MNIST_dataloader
from datasets import GTSRB_dataloader
from datasets import TIN_dataloader
from datasets import SVHN_dataloader


# Generate targeted adversarial examples (AEs) on the source ("attack"/generator) model and measure how well they transfer to the target classifier.
def get_args(argv=None):
    """Parse command-line options for the transfer-attack evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=128, type=int,
                        help='test batch size (passed to the CIFAR-10 loader)')
    parser.add_argument('--epsilon', default=8, type=int,
                        help='perturbation budget (not currently used; the PGD budget is fixed in main)')
    # Architecture names, not checkpoint directories: they select which
    # model class is instantiated before the checkpoint is loaded.
    parser.add_argument('--classifier', default='preresnet18',
                        help='architecture of the target classifier '
                             '(e.g. resnet18, preresnet18, vgg, netc, svhn_net, mobilev2)')
    parser.add_argument('--dataset', default='cifar10',
                        help='dataset to evaluate on (cifar10, gtsrb or svhn)')
    parser.add_argument('--generator', default='preresnet18',
                        help='architecture of the source model used to craft adversarial examples '
                             '(e.g. resnet18, preresnet18, vgg, mobilev2, efficientB0, svhn_net)')
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--trial', default=0, type=int, help='experiment index')
    parser.add_argument('--norm_type', default='PGDAT', choices=['TRADES', 'PGDAT'], type=str,
                        help='training-scheme tag (not used by this script)')
    parser.add_argument('--target', help='The checkpoint path of target model')
    parser.add_argument('--attack', help='The checkpoint path of source model')
    return parser.parse_args(argv)


# Location of the CIFAR-10 data on the original author's machine.
_CIFAR10_DIR = '/home/harry/dataset/cifar10'

# Target checkpoint used when --target is not given. Other checkpoints that
# were used in experiments (pass them via --target):
#   /home/harry/nnet/ImplicitBackdoor/outputs/SVHN/svhn_net/new/IBA3_clean_False_0.1_12.75/ckpt/best.pth
#   /home/harry/nnet/ImplicitBackdoor/outputs/GTSRB/preresnet18/new/IBA3_clean_False_0.1_12.75/ckpt/best.pth
#   /home/harry/nnet/ImplicitBackdoor/outputs/MNIST/netc/NT/ckpt/best.pth
#   /home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/preresnet18/new/IBA3_clean_True_0.1_8/ckpt/best_asr_clean.pth
#   /home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/preresnet18/new/IBA3_0.1/ckpt/best_asr.pth
#   /home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/resnet18/AT_eps8/ckpt/best_AT.pth
_DEFAULT_TARGET_CKPT = '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/preresnet18/NT/ckpt/best.pth'

# Target-classifier architecture name -> constructor. Lambdas keep the name
# lookup lazy, so an architecture missing from `models` only fails when it is
# actually requested (as with the original if/elif chain).
_CLASSIFIER_FACTORIES = {
    'resnet18': lambda: ResNet18(),
    'preresnet18': lambda: PreActResNet18(),
    'vgg': lambda: vgg11_bn(),
    'netc': lambda: NetC_MNIST(),
    'svhn_net': lambda: svhn(),
    'mobilev2': lambda: MobileNetV2(),
}

# Source (attack) architecture name -> (constructor, default checkpoint path
# used when --attack is not given).
_GENERATOR_SPECS = {
    'resnet18': (lambda: ResNet18(),
                 '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/resnet18/NT/0825/ckpt/best.pth'),
    'preresnet18': (lambda: PreActResNet18(10),
                    '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/preresnet18/NT2/ckpt/best.pth'),
    'vgg': (lambda: vgg11_bn(),
            '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/vgg11bn/NT/ckpt/best.pth'),
    'mobilev2': (lambda: MobileNetV2(),
                 '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/mobilev2/NT/ckpt/best.pth'),
    'efficientB0': (lambda: EfficientNetB0(),
                    '/home/harry/nnet/ImplicitBackdoor/outputs/CIFAR10/efficientB0/NT/ckpt/best.pth'),
    'svhn_net': (lambda: svhn(),
                 '/home/harry/nnet/ImplicitBackdoor/outputs/SVHN/svhn_net/NT/ckpt/best.pth'),
}


def _build_test_loader(dataset, batch_size):
    """Return ``(test_loader, norm_layer)`` for ``dataset``.

    Raises:
        ValueError: if ``dataset`` is not one of 'cifar10', 'gtsrb', 'svhn'.
    """
    if dataset == 'cifar10':
        _, _, test_loader = CIFAR10_dataloader(_CIFAR10_DIR, batch_size, val=False)
        norm_layer = Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616])
    elif dataset == 'gtsrb':
        test_loader = GTSRB_dataloader(train=False)
        norm_layer = Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
    elif dataset == 'svhn':
        _, _, test_loader = SVHN_dataloader()
        norm_layer = Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(f"Unsupported dataset {dataset!r}; expected 'cifar10', 'gtsrb' or 'svhn'")
    return test_loader, norm_layer


def _lookup(table, name, role):
    """Return ``table[name]``, raising a descriptive ValueError for unknown names."""
    try:
        return table[name]
    except KeyError:
        raise ValueError(
            f"Unknown {role} architecture {name!r}; expected one of {sorted(table)}"
        ) from None


def _load_checkpoint(model, path, device):
    """Strictly load weights from ``path`` into ``model``.

    Checkpoints that wrap their weights under a ``'state_dict'`` key are
    unwrapped first.
    """
    state = torch.load(path, map_location=device)
    if 'state_dict' in state:
        state = state['state_dict']
    model.load_state_dict(state, strict=True)


def _eval_targeted_transfer(attack_model, classifier, test_loader, norm_layer,
                            target_class, epsilon, device):
    """Craft targeted PGD AEs on ``attack_model`` and score them on ``classifier``.

    Samples whose true label already equals ``target_class`` are excluded.

    Returns:
        ``(acc_avg, asr_avg)``. ``acc_avg`` is the classifier's accuracy
        against the true labels. ``asr_avg`` is its agreement with
        ``target_class`` (the attack success rate). Both are taken from
        ``AverageMeter.avg``.
    """
    acc = AverageMeter()
    asr = AverageMeter()
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        target_labels = torch.full_like(labels, target_class)
        keep = target_labels != labels  # boolean mask over the batch
        n_kept = int(keep.sum())
        if n_kept == 0:
            # Every sample already has the target label; nothing to attack.
            continue
        # NOTE(review): `import attack.pgd as pgd` binds a module, so this call
        # presumably relies on a later star import (utils/eval) rebinding `pgd`
        # to the attack function — verify.
        X_adv = pgd(attack_model, images[keep], target_labels[keep], targeted=True,
                    normalize=norm_layer, epsilon=epsilon, attack_iters=10, restarts=1)

        # The target model is only evaluated here, so no autograd graph is needed.
        with torch.no_grad():
            output = classifier(norm_layer(X_adv)).float()

        # Weight each batch by the number of samples actually evaluated, not
        # the full batch size (len(keep) would be the whole batch).
        prec1 = accuracy(output, labels[keep])[0]
        acc.update(prec1.item(), n_kept)
        asr1 = accuracy(output, target_labels[keep])[0]
        asr.update(asr1.item(), n_kept)
    return acc.avg, asr.avg


def main():
    """Evaluate targeted transfer attacks from a source model to a target classifier.

    Steps:
      1. Load the test set for ``--dataset``.
      2. Build the target classifier (``--classifier``, weights from
         ``--target``) and the source model (``--generator``, weights from
         ``--attack``). A built-in default checkpoint is used for either path
         when it is omitted.
      3. Print the clean accuracy of both models.
      4. For each target class, craft targeted PGD examples on the source
         model and print the target classifier's accuracy and attack success
         rate, then their averages over all classes.

    Raises:
        ValueError: for an unsupported dataset or architecture name.
    """
    args = get_args()

    set_seed(args.seed)

    logger = logging.getLogger(__name__)

    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.StreamHandler()
        ])

    logger.info(args)

    test_loader, norm_layer = _build_test_loader(args.dataset, args.batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Target classifier: honour --target and fall back to the default path only
    # when it was not supplied (previously the CLI value was always overwritten).
    classifier = _lookup(_CLASSIFIER_FACTORIES, args.classifier, 'classifier')().to(device)
    if args.target is None:
        args.target = _DEFAULT_TARGET_CKPT
    _load_checkpoint(classifier, args.target, device)

    # Source model used to craft the adversarial examples.
    make_attack_model, default_attack_ckpt = _lookup(_GENERATOR_SPECS, args.generator, 'generator')
    attack_model = make_attack_model().to(device)
    if args.attack is None:
        args.attack = default_attack_ckpt
    _load_checkpoint(attack_model, args.attack, device)

    classifier.float()
    classifier.eval()
    attack_model.float()
    attack_model.eval()

    ### Evaluate clean acc ###
    # First line: target classifier; second line: source (attack) model.
    test_acc = eval_clean(classifier, test_loader, norm_layer, device)
    print('Clean acc: ', test_acc)
    test_acc = eval_clean(attack_model, test_loader, norm_layer, device)
    print('Clean acc: ', test_acc)

    # NOTE(review): only targets 0..9 are evaluated. Confirm this matches the
    # dataset's class count (e.g. for GTSRB).
    num_classes = 10
    # NOTE(review): --epsilon is not used. The PGD budget is fixed at 12.75;
    # confirm the intended units with the pgd implementation.
    epsilon = 12.75
    ACC = torch.zeros(num_classes)
    ASR = torch.zeros(num_classes)

    for c in range(num_classes):
        acc_avg, asr_avg = _eval_targeted_transfer(
            attack_model, classifier, test_loader, norm_layer, c, epsilon, device)

        print(f"For targeted label: {c};     the ACC is {acc_avg:.3f}")
        print(f"For targeted label: {c};     the ASR is {asr_avg:.3f}")

        ACC[c] = acc_avg
        ASR[c] = asr_avg

    print('AVG ACC: ', ACC.mean())
    print('AVG ASR: ', ASR.mean())

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()