import torch
from utils import AverageMeter, accuracy, get_adv
import attack
from tqdm import tqdm



# standard accuracy
def eval_clean(model, dataloader, norm_layer, device):
    """Return the top-1 accuracy of ``model`` on the unperturbed data.

    Args:
        model: classifier to evaluate; it is switched to eval mode.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        norm_layer: callable applied to each input batch before ``model``.
        device: device every batch is moved to.

    Returns:
        ``AverageMeter.avg`` of ``accuracy(...)[0]`` over all batches, each
        batch weighted by its size.
    """
    top1 = AverageMeter()
    model.eval()

    # Plain evaluation needs no gradients; disabling autograd avoids building
    # (and holding on to) a computation graph for every batch.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            output = model(norm_layer(inputs)).float()

            # accuracy(...)[0] is assumed to be the top-1 precision of the batch.
            prec1 = accuracy(output, targets)[0]
            top1.update(prec1.item(), inputs.size(0))

    return top1.avg


# robust accuracy against pgd
def eval_pgd(model, dataloader, norm_layer, device, config=None, epsilon=8, iters=20):
    """Return the top-1 accuracy of ``model`` on PGD adversarial examples.

    Args:
        model: classifier to attack and evaluate; it is switched to eval mode.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        norm_layer: callable applied to inputs before ``model``; also handed
            to ``attack.pgd`` as ``normalize``.
        device: device every batch is moved to.
        config: optional object; when truthy, its ``epsilon`` and
            ``iterations`` attributes replace ``epsilon`` and ``iters``.
        epsilon: perturbation budget forwarded to ``attack.pgd`` (units are
            whatever ``attack.pgd`` expects -- TODO confirm, presumably /255).
        iters: number of PGD iterations forwarded as ``attack_iters``.

    Returns:
        ``AverageMeter.avg`` of ``accuracy(...)[0]`` on the adversarial inputs,
        each batch weighted by its size.
    """
    top1 = AverageMeter()
    model.eval()

    # Resolve the attack hyper-parameters once rather than on every batch.
    if config:
        epsilon, iters = config.epsilon, config.iterations

    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # The attack needs input gradients, so it runs with autograd enabled.
        # norm_layer is passed by keyword: eval_marksman_CL shows attack.pgd
        # accepts ``normalize=``, whereas a positional 4th argument binds to
        # whatever parameter happens to be declared 4th.
        adv_inputs = attack.pgd(model, inputs, targets, normalize=norm_layer,
                                epsilon=epsilon, attack_iters=iters)

        # Scoring the adversarial batch needs no graph.
        with torch.no_grad():
            output = model(norm_layer(adv_inputs)).float()

        prec1 = accuracy(output, targets)[0]
        top1.update(prec1.item(), inputs.size(0))

    return top1.avg


# backdoor accuracy
def eval_backdoor_adv(model, dataloader, norm_layer, device, config):
    """Return the top-1 accuracy of ``model`` on ``get_adv`` examples w.r.t. their attack labels.

    The label used both to craft the adversarial example and to score the
    prediction depends on ``config``:

    * ``config.targeted``: every sample uses ``config.target_label``;
    * ``config.all2all``: sample of class ``c`` uses ``(c + 1) % config.num_classes``;
    * otherwise: the sample's own label.

    Args:
        model: classifier to attack and evaluate; it is switched to eval mode.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        norm_layer: callable applied to inputs before ``model``; also passed
            to ``get_adv``.
        device: device every batch is moved to.
        config: object providing ``targeted``, ``all2all`` and, depending on
            those, ``target_label`` / ``num_classes``; also passed to ``get_adv``.

    Returns:
        ``AverageMeter.avg`` of ``accuracy(...)[0]`` against the chosen labels,
        each batch weighted by its size.
    """
    top1 = AverageMeter()
    model.eval()

    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Tensors derived from ``targets`` already live on ``device``.
        if config.targeted:
            poisoned_target = torch.ones_like(targets) * config.target_label
        elif config.all2all:
            # Original note (translated from Chinese): "generate adversarial
            # examples that do not move toward (label + 1)".
            # NOTE(review): (label + 1) is nonetheless what get_adv receives;
            # the attack direction is decided inside get_adv -- confirm.
            poisoned_target = torch.remainder(targets + 1, config.num_classes)
        else:
            poisoned_target = targets

        # Crafting the example needs input gradients, so autograd stays on here.
        poisoned_input = get_adv(model, inputs, poisoned_target, norm_layer, config)

        # Scoring needs no graph.
        with torch.no_grad():
            output = model(norm_layer(poisoned_input)).float()

        prec1 = accuracy(output, poisoned_target)[0]
        top1.update(prec1.item(), inputs.size(0))

    return top1.avg


# clean label backdoor ASR
def eval_marksman_CL(generator, classifier, dataloader, norm_layer, device, config):
    """Report the targeted attack success rate (ASR) of ``classifier`` for every class.

    For each class ``cc``, every batch of ``dataloader`` is perturbed by a
    targeted attack (``attack.pgd`` or ``attack.fgsm``, per
    ``config.attack_cfg.name``) run against ``generator`` toward label ``cc``.
    The perturbed batch is then classified by ``classifier``, and the ASR for
    ``cc`` is ``accuracy(...)[0]`` of those predictions against ``cc``.

    Args:
        generator: model the adversarial examples are crafted against.
        classifier: model whose predictions on those examples are scored.
        dataloader: iterable yielding ``(inputs, labels)`` batches; it is
            iterated once per class.
        norm_layer: callable applied to inputs before the models.
        device: device every batch is moved to.
        config: provides ``attack_cfg.name`` and ``dataset_cfg.num_classes``.

    Returns:
        0-dim tensor: mean of the per-class ASR values.

    Raises:
        ValueError: if ``config.attack_cfg.name`` is neither ``'pgd'`` nor ``'fgsm'``.
    """
    attack_name = config.attack_cfg.name
    # Fail fast, before any data is loaded or moved.
    if attack_name not in ('pgd', 'fgsm'):
        raise ValueError("Wrong attack method")

    num_classes = config.dataset_cfg.num_classes
    asr_per_class = torch.zeros(num_classes)

    # Consistent with the other evaluators in this file.
    generator.eval()
    classifier.eval()

    for cc in range(num_classes):
        asr = AverageMeter()
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Every sample is attacked toward class ``cc``.
            target_label = torch.ones_like(labels) * cc

            # Crafting adversarial examples needs input gradients.
            if attack_name == 'pgd':
                x_adv = attack.pgd(generator, inputs, target_label, targeted=True,
                                   normalize=norm_layer, epsilon=8, attack_iters=7, restarts=1)
            else:
                x_adv = attack.fgsm(generator, inputs, target_label, targeted=True,
                                    normalize=norm_layer, epsilon=8, rs=True)

            # Scoring the classifier needs no graph.
            with torch.no_grad():
                output = classifier(norm_layer(x_adv)).float()

            # Attack success = prediction equals the attack target ``cc``.
            asr1 = accuracy(output, target_label)[0]
            asr.update(asr1.item(), inputs.size(0))

        print(f"For targeted label: {cc};     the ASR is {asr.avg:.3f}")
        asr_per_class[cc] = asr.avg

    mean_asr = torch.mean(asr_per_class)
    print("Average ASR: " + str(mean_asr))
    return mean_asr