import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
import torchattacks
from model2 import NetworkCIFAR as NetworkCIFAR2
import genotypes2
import os
import utils

parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# %(default)s keeps the help text in sync with the actual default value.
parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
# NOTE: nothing in this script reads args.no_cuda; all tensors are moved with .cuda().
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA (currently ignored: evaluation always runs on CUDA)')
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
# NOTE: args.data_type is not read anywhere in this script; the test set is always CIFAR-10.
parser.add_argument('--data_type', type=str, default='cifar10', help='which dataset to use')
parser.add_argument('--checkpoint', type=str, default="",
                    help='path to the saved .pth/.pt file')
args = parser.parse_args()

# Make the selected GPU the current CUDA device, so later bare `.cuda()` calls
# (on the model and on each data batch) place tensors on it.
torch.cuda.set_device(args.gpu)

# Let cuDNN autotune convolution algorithms; input shapes are presumably
# constant across evaluation batches, which is when this pays off.
cudnn.benchmark = True




def cal_acc(model, X, y):
    """Return the top-1 accuracy of ``model`` on one batch.

    Args:
        model: callable mapping ``X`` to a tensor of class scores; the
            prediction is the argmax over dim 1 (presumably
            ``(batch, num_classes)`` logits).
        X: input batch passed to ``model``.
        y: tensor of integer class labels, one per sample.

    Returns:
        0-dim float tensor: the fraction of samples whose predicted class
        equals ``y``.
    """
    # Pure evaluation: skip building an autograd graph, which would otherwise be
    # recorded here because callers may pass inputs with requires_grad=True.
    with torch.no_grad():
        out = model(X)
    _, predicted = out.max(dim=1)
    correct = (predicted == y).float().sum()
    return correct / y.size(0)

def eval_adv_acc(model, test_loader, type, moe_model=None):
    """Attack every batch of ``test_loader`` with PGD and report robust accuracy.

    Prints the accuracy of each batch and then the sample-weighted average.

    Args:
        model: classifier to attack and evaluate. Batches are moved with
            ``.cuda()``, so the model is expected to be on the current CUDA device.
        test_loader: iterable of ``(data, target)`` batches.
        type: attack name; one of ``'PGD20_8'``, ``'PGD20_6'``, ``'PGD20_4'``,
            ``'PGD20_3'``, ``'PGD20_2'`` or ``'PGD20_1'``. The name shadows the
            builtin ``type``; it is kept because callers may pass it by keyword.
        moe_model: unused; kept for interface compatibility.

    Returns:
        Sample-weighted mean adversarial accuracy as a 0-dim tensor.

    Raises:
        ValueError: if ``type`` is not a known attack name, or if
            ``test_loader`` yields no samples.
    """
    # attack name -> (eps, alpha, steps), in the [0, 1] pixel scale.
    # NOTE: the *_3/_2/_1 variants use 30/40/50 steps despite the "PGD20" prefix.
    # These values are kept exactly as they were; the names do not encode step counts.
    pgd_configs = {
        'PGD20_8': (8 / 255, 2 / 255, 20),
        'PGD20_6': (6 / 255, 1.5 / 255, 20),
        'PGD20_4': (4 / 255, 1 / 255, 20),
        'PGD20_3': (3 / 255, 0.75 / 255, 30),
        'PGD20_2': (2 / 255, 0.5 / 255, 40),
        'PGD20_1': (1 / 255, 0.25 / 255, 50),
    }
    if type not in pgd_configs:
        raise ValueError(
            f'Unknown attack type {type!r}; expected one of {sorted(pgd_configs)}')
    eps, alpha, steps = pgd_configs[type]

    model.eval()
    # The attack configuration does not depend on the batch, so build it once.
    attack = torchattacks.PGD(model, eps=eps, alpha=alpha, steps=steps)

    total_acc = 0.0
    total_samples = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        adv_images = attack(data, target)

        # Scoring the adversarial batch needs no gradients.
        with torch.no_grad():
            acc = cal_acc(model, adv_images, target)

        batch_size = adv_images.size(0)
        total_acc += acc * batch_size
        total_samples += batch_size
        print('batch accuracy: ', acc.item())

    if total_samples == 0:
        raise ValueError('test_loader yielded no samples')

    avg_accuracy = total_acc / total_samples
    print(f'Average accuracy for attack type {type}: {avg_accuracy:.4f}')
    return avg_accuracy




def load_model_weights(model, checkpoint_path, device='cpu'):
    """Load weights from ``checkpoint_path`` into ``model`` and return it.

    The checkpoint may be a bare state dict, or a dict that wraps one under the
    ``'model'`` or ``'state_dict'`` key. If every checkpoint key has the
    ``'module.'`` prefix added by ``nn.DataParallel`` and the model's own keys
    do not, that prefix is stripped before loading.

    Loading is non-strict (``strict=False``). Missing and unexpected keys are
    printed as warnings rather than silently ignored, because a mismatched
    checkpoint would otherwise leave the model partly at its random init.

    Args:
        model: ``torch.nn.Module`` to load into; it is modified in place.
        checkpoint_path: path to the ``.pth``/``.pt`` file.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        The same ``model`` object.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is not an existing file.
        ValueError: if the file is smaller than 1024 bytes.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')
    if os.path.getsize(checkpoint_path) < 1024:
        raise ValueError('Checkpoint file is suspiciously small!')

    # weights_only=True restricts unpickling to tensors and plain containers.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Unwrap the common "training checkpoint" layouts.
    state = ckpt
    if isinstance(ckpt, dict):
        for wrapper_key in ('model', 'state_dict'):
            if wrapper_key in ckpt:
                state = ckpt[wrapper_key]
                break

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; without stripping, a non-strict load would match nothing.
    prefix = 'module.'
    if (state
            and all(k.startswith(prefix) for k in state)
            and not any(k.startswith(prefix) for k in model.state_dict())):
        state = {k[len(prefix):]: v for k, v in state.items()}

    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f'Warning: {len(result.missing_keys)} model keys not found in checkpoint: '
              f'{result.missing_keys}')
    if result.unexpected_keys:
        print(f'Warning: {len(result.unexpected_keys)} checkpoint keys not used by the model: '
              f'{result.unexpected_keys}')
    print(f"✓ Model weights loaded from {checkpoint_path}")
    return model


# Evaluation-time preprocessing: only PIL -> float tensor conversion, with no
# normalisation step.
# NOTE(review): the PGD attacks operate in the [0, 1] pixel scale, so any input
# normalisation presumably happens inside the model — confirm.
transform_list = [transforms.ToTensor()]
transform_test = transforms.Compose(transform_list)

# CIFAR-10 test split, downloaded into ./data if it is not already there.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform_test,
    download=True,
)

# Deterministic batch order (no shuffling), so repeated runs see identical batches.
test_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=args.test_batch_size,
    num_workers=4,
    shuffle=False,
)




def main():
    """Build the network, load ``args.checkpoint`` and report PGD robust accuracy.

    The network is evaluated on ``test_loader`` under each PGD budget from
    eps = 8/255 down to eps = 1/255.
    """
    CIFAR_CLASSES = 10
    # NOTE(review): the genotype comes from `genotypes2`, but the network class
    # comes from `model` rather than `model2` — confirm this pairing is intended.
    genotype = genotypes2.ADVRUSH
    model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model.drop_path_prob = args.drop_path_prob

    # The model is still on the CPU here, and load_state_dict copies values
    # into its existing parameters. Mapping the checkpoint to the CPU avoids
    # a throwaway GPU copy and avoids hard-coding a GPU that ignores --gpu.
    model = load_model_weights(model, args.checkpoint, device='cpu')
    # Moves to the current CUDA device, which is selected by --gpu.
    model = model.cuda()
    print("param size: ", utils.count_parameters_in_MB(model), 'MB')
    print(model)

    for attack_type in ('PGD20_8', 'PGD20_6', 'PGD20_4', 'PGD20_3', 'PGD20_2', 'PGD20_1'):
        eval_adv_acc(model, test_loader, attack_type)


if __name__ == '__main__':
    main()