import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
import torchattacks
import utils
import os

parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# `%(default)s` is expanded by argparse, so the help text can no longer drift
# from the real default (it previously advertised 200 while the default is 32).
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
# NOTE(review): --no-cuda and --data_type are not read anywhere else in this
# file (args is also handed to utils._data_transforms_cifar100 -- confirm
# there); the evaluation below always runs on CUDA with CIFAR-100.
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
parser.add_argument('--data_type', type=str, default='cifar100', help='which dataset to use')
parser.add_argument('--checkpoint', type=str, default="Adv_train_and_eval/Adv_eval_cifar/checkpoint/cifar100/EV_model_acc61.9900_20250314_205311.pt",
                    help='path to the saved .pth/.pt file')
args = parser.parse_args()


# Make every argument-less `.cuda()` call below target the requested GPU, and
# let cuDNN benchmark and cache the fastest convolution algorithms.
torch.cuda.set_device(args.gpu)
cudnn.benchmark = True

def cal_acc(model, X, y):
    """Return the top-1 accuracy of ``model`` on a single batch.

    Args:
        model: Callable mapping ``X`` to class scores; the output is assumed
            to be (batch, num_classes) -- the arg-max is taken over dim 1.
        X: Input batch, passed to ``model`` unchanged.
        y: Tensor of integer class labels, one per sample in the batch.

    Returns:
        A 0-dim float tensor holding the fraction of samples whose arg-max
        prediction equals the label. It does not require grad.
    """
    # Only the predictions are needed, so do not record an autograd graph
    # (callers may pass inputs with requires_grad=True, which would otherwise
    # make the forward pass keep every intermediate activation alive).
    with torch.no_grad():
        logits = model(X)
        predicted = logits.argmax(dim=1)
        correct = (predicted == y).float().sum()
    return correct / y.size(0)

def eval_adv_acc(model, test_loader, type, moe_model=None):
    """Measure ``model``'s accuracy on adversarial examples from one attack.

    Every batch of ``test_loader`` is perturbed by the torchattacks attack
    named ``type`` (run against ``model`` itself) and then classified with
    ``cal_acc``. Per-batch accuracies are printed, and the sample-weighted
    average over the whole loader is printed and returned.

    Args:
        model: Classifier to attack and evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches; both are moved to
            the current CUDA device.
        type: Attack name, one of ``'FGSM'``, ``'PGD20'``, ``'PGD100'``,
            ``'APGD'`` or ``'AA'`` (all built with ``eps=8/255``). The name
            shadows the builtin but is kept for caller compatibility.
        moe_model: Unused; kept so existing call sites keep working.

    Returns:
        The sample-weighted mean adversarial accuracy (a 0-dim tensor).

    Raises:
        ValueError: If ``type`` is not a known attack name, or if
            ``test_loader`` yields no samples.
    """
    # Name -> factory table. An unknown name is rejected up front instead of
    # surfacing later as an UnboundLocalError on ``attack``.
    attack_factories = {
        'FGSM': lambda: torchattacks.FGSM(model, eps=8 / 255),
        'PGD20': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20),
        'PGD100': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=100),
        'APGD': lambda: torchattacks.APGD(model, eps=8 / 255, steps=20),
        'AA': lambda: torchattacks.AutoAttack(model, eps=8 / 255),
    }
    if type not in attack_factories:
        raise ValueError(
            f'Unknown attack type {type!r}; '
            f'expected one of {sorted(attack_factories)}'
        )

    model.eval()
    # The attack does not depend on the batch, so build it once instead of
    # re-instantiating it (AutoAttack included) for every batch.
    attack = attack_factories[type]()
    # NOTE(review): torchattacks documents that inputs must lie in [0, 1]; if
    # the test transform normalizes images, eps and clipping act in the wrong
    # space -- confirm against utils._data_transforms_cifar100.

    total_acc = 0.0
    total_samples = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        adv_images = attack(data, target)

        # Only a forward pass is needed on the adversarial batch, so no
        # autograd graph (and no requires_grad re-wrapping) is necessary.
        with torch.no_grad():
            acc = cal_acc(model, adv_images, target)
        total_acc += acc * adv_images.size(0)
        total_samples += adv_images.size(0)

        print('batch accuracy: ', acc.item())

    if total_samples == 0:
        raise ValueError('test_loader yielded no samples')

    avg_accuracy = total_acc / total_samples
    print(f'Average accuracy for attack type {type}: {avg_accuracy:.4f}')
    return avg_accuracy

def eval_acc(model, test_loader, is_adv=False, attack_type=None):
    """Measure ``model``'s accuracy on clean or adversarial test batches.

    With ``is_adv`` false the raw batches are classified. Otherwise each batch
    is first perturbed by the torchattacks attack named ``attack_type``.
    Per-batch accuracies are printed, and the sample-weighted average over
    the loader is printed and returned.

    Args:
        model: Classifier to evaluate; it is switched to eval mode. For
            adversarial evaluation the attack is also run against it.
        test_loader: Iterable of ``(data, target)`` batches; both are moved to
            the current CUDA device.
        is_adv: Whether to evaluate on adversarial examples.
        attack_type: Attack name used when ``is_adv`` is true: one of
            ``'FGSM'``, ``'PGD20'``, ``'PGD100'``, ``'APGD'`` or ``'AA'``
            (all built with ``eps=8/255``). Ignored otherwise.

    Returns:
        The sample-weighted mean accuracy (a 0-dim tensor).

    Raises:
        ValueError: If ``is_adv`` is true and ``attack_type`` is not a known
            attack name (including ``None``), or if ``test_loader`` yields no
            samples.
    """
    # Name -> factory table. An unknown (or missing) name is rejected up front
    # instead of surfacing later as an UnboundLocalError on ``attack``.
    attack_factories = {
        'FGSM': lambda: torchattacks.FGSM(model, eps=8 / 255),
        'PGD20': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20),
        'PGD100': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=100),
        'APGD': lambda: torchattacks.APGD(model, eps=8 / 255, steps=20),
        'AA': lambda: torchattacks.AutoAttack(model, eps=8 / 255),
    }
    if is_adv and attack_type not in attack_factories:
        raise ValueError(
            f'Unknown attack type {attack_type!r}; '
            f'expected one of {sorted(attack_factories)}'
        )

    model.eval()
    # The attack does not depend on the batch, so build it once rather than
    # re-instantiating it for every batch.
    attack = attack_factories[attack_type]() if is_adv else None
    # NOTE(review): torchattacks documents that inputs must lie in [0, 1]; if
    # the test transform normalizes images, eps and clipping act in the wrong
    # space -- confirm against utils._data_transforms_cifar100.

    total_acc = 0.0
    total_samples = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        X = attack(data, target) if attack is not None else data

        # Evaluation needs forward passes only; skip building an autograd graph.
        with torch.no_grad():
            acc = cal_acc(model, X, target)
        total_acc += acc * X.size(0)
        total_samples += X.size(0)

        print('batch accuracy: ', acc.item())

    if total_samples == 0:
        raise ValueError('test_loader yielded no samples')

    avg_accuracy = total_acc / total_samples
    if is_adv:
        print(f'Average accuracy for attack type {attack_type}: {avg_accuracy:.4f}')
    else:
        print(f'Average accuracy for natural samples: {avg_accuracy:.4f}')

    return avg_accuracy



def load_model_weights(model, checkpoint_path, device='cpu', strict=False):
    """Load saved weights from ``checkpoint_path`` into ``model`` in place.

    The file is read with ``torch.load(..., weights_only=True)``. If it is a
    dict holding the weights under a ``'model'`` (or, failing that,
    ``'state_dict'``) key, that entry is used; otherwise the loaded object
    itself is treated as the state dict. A ``'module.'`` prefix on every
    checkpoint key is stripped when the target model's own keys lack it.

    Args:
        model: Module that receives the weights.
        checkpoint_path: Path to the saved ``.pt``/``.pth`` file.
        device: ``map_location`` passed to ``torch.load``.
        strict: Forwarded to ``model.load_state_dict``. With the default
            ``False``, mismatched keys are tolerated but printed rather than
            silently ignored.

    Returns:
        The same ``model`` instance.
    """
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Training scripts often wrap the weights as {'model': ...} or
    # {'state_dict': ...}; otherwise the file is taken to be the state dict.
    state = ckpt
    if isinstance(ckpt, dict):
        for key in ('model', 'state_dict'):
            if key in ckpt:
                state = ckpt[key]
                break

    # Weights saved from an nn.DataParallel/DistributedDataParallel wrapper
    # carry a 'module.' prefix on every key. With strict=False all of them
    # would be skipped silently, so strip the prefix when the model lacks it.
    prefix = 'module.'
    if (
        isinstance(state, dict)
        and state
        and all(isinstance(k, str) and k.startswith(prefix) for k in state)
        and not any(k.startswith(prefix) for k in model.state_dict())
    ):
        state = {k[len(prefix):]: v for k, v in state.items()}

    incompatible = model.load_state_dict(state, strict=strict)
    # strict=False tolerates mismatches; surface them, because a checkpoint
    # whose keys do not line up leaves those parameters at their initial values.
    if incompatible.missing_keys:
        print(f"! {len(incompatible.missing_keys)} model key(s) missing from checkpoint: "
              f"{incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"! {len(incompatible.unexpected_keys)} unexpected checkpoint key(s) ignored: "
              f"{incompatible.unexpected_keys}")
    print(f"✓ Model weights loaded from {checkpoint_path}")
    return model


# Build the CIFAR-100 test split using the third transform returned by
# utils._data_transforms_cifar100, and wrap it in an unshuffled loader so
# batches come out in a fixed order.
_, _, transform_test = utils._data_transforms_cifar100(args)
testset = torchvision.datasets.CIFAR100(
    root="./data/",
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=4,
)



def main():
    """Build the network, load its checkpoint and report its robustness.

    Prints the model, then evaluates natural accuracy on the test loader,
    followed by accuracy under the FGSM, PGD20, PGD100, APGD and AA attacks.
    """
    # NOTE(review): the class count is fixed at 100 and --data_type is not
    # consulted; the test loader above is always CIFAR-100.
    CIFAR_CLASSES = 100
    genotype = genotypes.search_cifar10_338_0
    model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model.drop_path_prob = args.drop_path_prob
    # Load the checkpoint on the CPU: the model is only moved to the GPU
    # afterwards, and a hard-coded 'cuda:0' map_location would allocate the
    # weights on GPU 0 even when --gpu selects a different device.
    model = load_model_weights(model, args.checkpoint, device='cpu')
    model = model.cuda()
    print(model)

    eval_acc(model, test_loader, is_adv=False)
    for attack_type in ('FGSM', 'PGD20', 'PGD100', 'APGD', 'AA'):
        eval_adv_acc(model, test_loader, attack_type)


if __name__ == '__main__':
    main()