import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
import torchattacks
from utils import load_tinyimagenet
import utils

# Command-line options for the evaluation script. Parsed once at import time
# into the module-level ``args`` namespace used by the rest of the file.
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# The help text interpolates the real default so the message and the value
# cannot drift apart (it previously claimed 200 while the default was 50).
parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
# Architecture options forwarded to the network constructor.
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
# Dataset options.
parser.add_argument('--data_type', type=str, default='tinyimagenet', help='which dataset to use')
parser.add_argument('--data_dir', type=str, default="/home3/data/tiny-imagenet-200/", help='imagenet-200 datadir')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
# The help string below reads "number of data-loading threads".
parser.add_argument('--workers', type=int, default=4, help='数据加载线程数')
args = parser.parse_args()


# Make the GPU selected with --gpu the current CUDA device.
# NOTE(review): this runs unconditionally, so the --no-cuda flag has no
# effect on this line — confirm whether CPU-only runs are meant to work.
torch.cuda.set_device(args.gpu)
# Let cuDNN benchmark its kernels and pick the fastest ones.
cudnn.benchmark = True


# Only the second of the three values returned by load_tinyimagenet is kept
# and used as the test loader; presumably the other two are further data
# splits — confirm against utils.load_tinyimagenet.
_, test_loader, _ = load_tinyimagenet(args)





def cal_acc(model, X, y):
    """Return the fraction of one batch that ``model`` classifies correctly.

    Args:
        model: Callable mapping ``X`` to a score tensor; the predicted class
            of each sample is the index of the maximum along dimension 1.
        X: Batch of model inputs.
        y: Tensor of target class indices; its first dimension is used as
            the batch size.

    Returns:
        A 0-dim float tensor: the number of correct predictions divided by
        ``y.size(0)``.
    """
    # Pure evaluation: disable autograd so no graph is recorded, even when
    # the caller hands in an input that has requires_grad=True.
    with torch.no_grad():
        logits = model(X)
        _, predicted = logits.max(1)
        correct = (predicted == y).float().sum()
    return correct / y.size(0)


def eval_adv_acc(model, test_loader, type, moe_model=None):
    """Evaluate ``model`` on adversarial examples crafted from ``test_loader``.

    Args:
        model: Network to attack and evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches. Each batch is
            moved to the current CUDA device.
        type: Attack name, one of ``'FGSM'``, ``'PGD20'``, ``'PGD100'``,
            ``'CW'``, ``'APGD'`` or ``'AA'``.
        moe_model: Not used by this function.

    Returns:
        The sample-weighted mean accuracy over all batches (a 0-dim tensor).

    Raises:
        ValueError: If ``type`` is not one of the supported attack names.
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    # Zero-argument factories so only the requested attack is instantiated.
    attack_factories = {
        'FGSM': lambda: torchattacks.FGSM(model, eps=8 / 255),
        'PGD20': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20),
        'PGD100': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=100),
        'CW': lambda: torchattacks.CW(model, c=0.5, steps=100),
        'APGD': lambda: torchattacks.APGD(model, eps=8 / 255, steps=20),
        'AA': lambda: torchattacks.AutoAttack(model, eps=8 / 255),
    }
    # Fail early with a clear message instead of hitting an unbound
    # ``attack`` variable on the first batch.
    if type not in attack_factories:
        raise ValueError(
            f'Unknown attack type {type!r}; expected one of {sorted(attack_factories)}'
        )

    model.eval()
    # The attack depends only on the model, so build it once, not per batch.
    attack = attack_factories[type]()

    total_acc = 0.0
    total_samples = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()

        # Crafting adversarial examples needs gradients, so this call stays
        # outside the no_grad block below.
        adv_images = attack(data, target)

        # Scoring needs no gradients.
        with torch.no_grad():
            acc = cal_acc(model, adv_images, target)

        batch_size = adv_images.size(0)
        # Weight by batch size so a smaller final batch is not over-counted.
        total_acc += acc * batch_size
        total_samples += batch_size

        print('batch accuracy: ', acc.item())

    avg_accuracy = total_acc / total_samples
    print(f'Average accuracy for attack type {type}: {avg_accuracy:.4f}')
    return avg_accuracy


def eval_acc(model, test_loader, is_adv=False, attack_type=None):
    """Evaluate ``model`` on natural or adversarial samples from ``test_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches. Each batch is
            moved to the current CUDA device.
        is_adv: When true, every batch is replaced by adversarial examples
            produced by the attack named in ``attack_type``.
        attack_type: Attack name, one of ``'FGSM'``, ``'PGD20'``,
            ``'PGD100'``, ``'CW'``, ``'APGD'`` or ``'AA'``. Ignored when
            ``is_adv`` is false.

    Returns:
        The sample-weighted mean accuracy over all batches (a 0-dim tensor).

    Raises:
        ValueError: If ``is_adv`` is true and ``attack_type`` is not one of
            the supported attack names (including ``None``).
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    attack = None
    if is_adv:
        # Zero-argument factories so only the requested attack is instantiated.
        attack_factories = {
            'FGSM': lambda: torchattacks.FGSM(model, eps=8 / 255),
            'PGD20': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20),
            'PGD100': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=100),
            'CW': lambda: torchattacks.CW(model, c=0.5, steps=100),
            'APGD': lambda: torchattacks.APGD(model, eps=8 / 255, steps=20),
            'AA': lambda: torchattacks.AutoAttack(model, eps=8 / 255),
        }
        # Fail early with a clear message instead of hitting an unbound
        # ``attack`` variable on the first batch.
        if attack_type not in attack_factories:
            raise ValueError(
                f'Unknown attack type {attack_type!r}; '
                f'expected one of {sorted(attack_factories)}'
            )

    model.eval()
    if is_adv:
        # The attack depends only on the model, so build it once, not per batch.
        attack = attack_factories[attack_type]()

    total_acc = 0.0
    total_samples = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()

        # Crafting adversarial examples needs gradients, so it happens
        # outside the no_grad block below.
        inputs = attack(data, target) if is_adv else data

        # Scoring needs no gradients.
        with torch.no_grad():
            acc = cal_acc(model, inputs, target)

        batch_size = inputs.size(0)
        # Weight by batch size so a smaller final batch is not over-counted.
        total_acc += acc * batch_size
        total_samples += batch_size

        print('batch accuracy: ', acc.item())

    avg_accuracy = total_acc / total_samples
    if is_adv:
        print(f'Average accuracy for attack type {attack_type}: {avg_accuracy:.4f}')
    else:
        print(f'Average accuracy for natural samples: {avg_accuracy:.4f}')

    return avg_accuracy




def load_model_weights(model, checkpoint_path, device='cuda'):
    """Load the weights stored under the ``'model'`` key of a checkpoint.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path to a file written with ``torch.save``; the
            loaded object must be indexable with the key ``'model'``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # strict=False tolerates key mismatches. Report them so that a model with
    # layers still at their initial values is not evaluated unnoticed.
    result = model.load_state_dict(checkpoint['model'], strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} key(s) missing from checkpoint: "
              f"{result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected key(s) in checkpoint: "
              f"{result.unexpected_keys}")

    print(f"Model weights loaded from {checkpoint_path}")

    return model

def main():
    """Build the network, load its checkpoint and print its accuracies.

    Reports the parameter size and the model structure, then evaluates
    natural accuracy followed by accuracy under the FGSM, PGD20 and AA
    attacks on the module-level ``test_loader``.
    """
    # 200 output classes; the default --data_dir points at tiny-imagenet-200.
    num_classes = 200
    genotype = genotypes.search_cifar10_338_0
    model = NetworkCIFAR(args.init_channels, num_classes, args.layers, args.auxiliary, genotype)
    model.drop_path_prob = args.drop_path_prob
    # Path to the saved checkpoint.
    checkpoint_path = "/home3/AVICAGE/checkpoint/lossgairobuststem3380/EV_model_acc56.8400_20250622_073521.pt"
    # Load onto the CPU instead of a hard-coded 'cuda:0': the model is still
    # on the CPU here, and model.cuda() below moves it to the device chosen
    # with --gpu, so GPU 0 is never touched when another GPU was requested.
    model = load_model_weights(model, checkpoint_path, device='cpu')
    model = model.cuda()
    print("param size: ", utils.count_parameters_in_MB(model), 'MB')
    print(model)
    eval_acc(model, test_loader, is_adv=False)
    eval_adv_acc(model, test_loader, 'FGSM')
    eval_adv_acc(model, test_loader, 'PGD20')
    eval_adv_acc(model, test_loader, 'AA')


# Call main() only when this file is executed as a script.
if __name__ == '__main__':
    main()