import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
import torchattacks
import os
import utils

# Command-line options for the evaluation script.
# NOTE: parse_args() below runs at import time, so merely importing this
# module consumes sys.argv.
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=4, help='gpu device id')
# %(default)s keeps the help text in sync with the real default; the previous
# hard-coded text advertised 200 while the default is 100.
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
# NOTE(review): --no-cuda is parsed but never read in the visible part of this
# file; the script calls .cuda() unconditionally.
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
# NOTE(review): --data_type is parsed but never read in the visible part of
# this file; the CIFAR-10 test set is always used.
parser.add_argument('--data_type', type=str, default='cifar10', help='which dataset to use')
parser.add_argument('--checkpoint', type=str, default="",
                    help='path to the saved .pth/.pt file')
args = parser.parse_args()



# Make the GPU chosen with --gpu the current CUDA device for this process.
# This runs at import time and does not consult --no-cuda.
torch.cuda.set_device(args.gpu)



# Enable cuDNN's benchmark mode (auto-tuning of convolution algorithms).
cudnn.benchmark = True




def cal_acc(model, X, y):
    """Return the top-1 accuracy of ``model`` on a single batch.

    Args:
        model: Callable mapping ``X`` to a tensor of per-class scores with
            the classes along dimension 1.
        X: Input batch, passed to ``model`` unchanged.
        y: Tensor of ground-truth class indices, one per sample.

    Returns:
        A 0-dim float tensor holding the fraction of samples whose arg-max
        prediction equals the label. It carries no autograd history.
    """
    # Scoring needs no gradients. Running under no_grad avoids building an
    # autograd graph for every evaluated batch, even when the caller passes
    # inputs that have requires_grad=True.
    with torch.no_grad():
        out = model(X)
        _, predicted = out.max(1)
        correct = (predicted == y).float().sum()
    return correct / y.size(0)

def _build_attack(model, attack_type):
    """Instantiate the torchattacks attack registered under ``attack_type``.

    Raises:
        ValueError: If ``attack_type`` is not a supported attack name.
    """
    # Lazy factories: only the requested attack is actually constructed.
    factories = {
        'FGSM': lambda: torchattacks.FGSM(model, eps=8 / 255),
        'PGD20': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=20),
        'PGD100': lambda: torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=100),
        'APGD': lambda: torchattacks.APGD(model, eps=8 / 255, steps=20),
        'AA': lambda: torchattacks.AutoAttack(model, eps=8 / 255),
    }
    try:
        factory = factories[attack_type]
    except KeyError:
        raise ValueError(
            f'Unknown attack type {attack_type!r}; expected one of {sorted(factories)}'
        ) from None
    return factory()


def eval_adv_acc(model, test_loader, type, moe_model=None):
    """Evaluate ``model`` on adversarial examples of the given attack type.

    Thin wrapper around ``eval_acc(..., is_adv=True)`` so that both entry
    points share a single evaluation loop.

    Args:
        model: Network under attack; it is switched to eval mode.
        test_loader: Iterable yielding ``(data, target)`` batches.
        type: Attack name: ``'FGSM'``, ``'PGD20'``, ``'PGD100'``, ``'APGD'``
            or ``'AA'``.
        moe_model: Unused; kept so existing callers keep working.

    Returns:
        The sample-weighted mean accuracy over all batches.

    Raises:
        ValueError: If ``type`` is not a supported attack name.
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    return eval_acc(model, test_loader, is_adv=True, attack_type=type)


def eval_acc(model, test_loader, is_adv=False, attack_type=None):
    """Evaluate ``model`` on clean or adversarially perturbed test batches.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(data, target)`` batches. Every
            batch is moved to the current CUDA device.
        is_adv: When true, each batch is perturbed with the attack named by
            ``attack_type`` before scoring; otherwise the clean inputs are
            scored and ``attack_type`` is ignored.
        attack_type: Attack name: ``'FGSM'``, ``'PGD20'``, ``'PGD100'``,
            ``'APGD'`` or ``'AA'``. All attacks are created with
            ``eps=8/255``.

    Returns:
        The sample-weighted mean of the per-batch accuracies returned by
        ``cal_acc``.

    Raises:
        ValueError: If ``is_adv`` is true and ``attack_type`` is not a
            supported attack name.
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    model.eval()

    # The attack does not depend on the batch, so it is built once up front.
    # This also rejects an unknown attack name before any data is touched.
    attack = _build_attack(model, attack_type) if is_adv else None

    total_acc = 0.0
    total_samples = 0

    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()

        # The attack runs with autograd enabled (outside no_grad) so it can
        # differentiate through the model if it needs to.
        inputs = attack(data, target) if attack is not None else data

        # Scoring itself needs no gradients.
        with torch.no_grad():
            acc = cal_acc(model, inputs, target)

        batch_size = inputs.size(0)
        total_acc += acc * batch_size
        total_samples += batch_size

        print('batch accuracy: ', acc.item())

    avg_accuracy = total_acc / total_samples
    if is_adv:
        print(f'Average accuracy for attack type {attack_type}: {avg_accuracy:.4f}')
    else:
        print(f'Average accuracy for natural samples: {avg_accuracy:.4f}')

    return avg_accuracy





def load_model_weights(model, checkpoint_path, device='cpu'):
    """Load weights from a checkpoint file into ``model``.

    The checkpoint may be a bare state dict or a dict that stores the state
    dict under the ``'model'`` key.

    Args:
        model: Module whose ``load_state_dict`` is called.
        checkpoint_path: Path to the saved checkpoint file.
        device: ``map_location`` handed to ``torch.load``.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
        ValueError: If the file is smaller than 1024 bytes.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')
    if os.path.getsize(checkpoint_path) < 1024:
        raise ValueError('Checkpoint file is suspiciously small!')

    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
    state = ckpt.get('model', ckpt)

    # strict=False tolerates key mismatches. Report them so a checkpoint that
    # does not match the architecture is not silently evaluated with some
    # layers left at their initial values.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"! {len(result.missing_keys)} key(s) missing from checkpoint "
              f"(left at initial values): {result.missing_keys}")
    if result.unexpected_keys:
        print(f"! {len(result.unexpected_keys)} unexpected key(s) in checkpoint "
              f"(ignored): {result.unexpected_keys}")

    print(f"✓ Model weights loaded from {checkpoint_path}")
    return model


# CIFAR-10 test split, converted to tensors only (no further preprocessing),
# served in fixed order. Built at import time; the dataset is downloaded to
# ./data when it is not already there.
transform_list = [transforms.ToTensor()]
transform_test = transforms.Compose(transform_list)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=4,
)




def main():
    """Build the network, load the checkpoint and report its accuracy.

    Prints the parameter size and the model, then evaluates accuracy on the
    clean test set followed by the FGSM, PGD-20, PGD-100, APGD and
    AutoAttack evaluations.
    """
    CIFAR_CLASSES = 10
    genotype = genotypes.search_cifar10_338_0
    model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model.drop_path_prob = args.drop_path_prob

    # Deserialize onto the CPU and move the model to the GPU afterwards.
    # The checkpoint used to be mapped to the hard-coded 'cuda:0', which need
    # not be the device selected with --gpu; .cuda() below places the model
    # on the current device instead.
    model = load_model_weights(model, args.checkpoint, device='cpu')
    model = model.cuda()
    print("param size: ", utils.count_parameters_in_MB(model), 'MB')
    print(model)

    eval_acc(model, test_loader, is_adv=False)

    for attack_name in ('FGSM', 'PGD20', 'PGD100', 'APGD', 'AA'):
        eval_adv_acc(model, test_loader, attack_name)


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()