import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes2
import genotypes
import genotype3
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
from model2 import NetworkCIFAR as NetworkCIFAR2
from model_arnas import NetworkCIFAR_new as modelArnas
import torchattacks


# Command-line options for the transfer-attack evaluation script.
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=5, help='gpu device id')
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', help='input batch size for testing (default: 32)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
# main() reads args.data_type to pick the number of classes; without this
# option the script fails with AttributeError before evaluating anything.
parser.add_argument('--data_type', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='dataset name, selects the number of classes (default: cifar10)')

args = parser.parse_args()


# Only bind a CUDA device when one is usable and the user did not pass
# --no-cuda; an unconditional set_device() raises on CPU-only machines.
if torch.cuda.is_available() and not args.no_cuda:
    torch.cuda.set_device(args.gpu)
cudnn.benchmark = True


# Images are only converted to tensors in [0, 1]; no normalisation is applied.
transform_list = [transforms.ToTensor()]
transform_test = transforms.Compose(transform_list)

# Pick the test split that matches the requested dataset. When no `data_type`
# option exists on `args`, fall back to CIFAR-10.
if getattr(args, 'data_type', 'cifar10') == 'cifar100':
    _test_dataset_cls = torchvision.datasets.CIFAR100
else:
    _test_dataset_cls = torchvision.datasets.CIFAR10
testset = _test_dataset_cls(root="/home/data/", train=False, download=True, transform=transform_test)


# Fixed order (shuffle=False) so repeated runs see the batches in the same sequence.
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)


# def cal_acc(model, X, y):
#     out = model(X)
#     err = (out.data.max(1)[1] != y.data).float().sum()
#     return err

def cal_acc(model, X, y):
    """Return the fraction of the batch that ``model`` classifies correctly.

    Args:
        model: callable applied to ``X``; its output is reduced with a max
            over dimension 1 to obtain one predicted class index per sample.
        X: input batch passed to ``model`` unchanged.
        y: ground-truth labels; ``y.size(0)`` is used as the batch size.

    Returns:
        A 0-dim float tensor: (number of correct predictions) / ``y.size(0)``.
    """
    logits = model(X)
    # max(1) yields (values, indices); the indices are the predicted labels.
    predicted = logits.data.max(1)[1]
    num_correct = predicted.eq(y.data).float().sum()
    return num_correct / y.size(0)




def eval_adv_acc(model_source, model_target, test_loader):
    """Measure transfer-attack accuracy of ``model_target``.

    Adversarial examples are crafted against ``model_source`` with PGD
    (eps=8/255, alpha=2/255, 20 steps) and then classified by
    ``model_target``. Per-batch accuracies are weighted by batch size so a
    smaller final batch does not skew the average.

    Args:
        model_source: model the PGD attack is run against.
        model_target: model evaluated on the resulting adversarial images.
        test_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        Sample-weighted average accuracy over all batches (0-dim tensor).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model_source.eval()
    model_target.eval()
    total_acc = 0.0
    total_samples = 0

    # Follow the target model's device instead of hard-coding .cuda(), so the
    # function also works when the models live on the CPU.
    device = next(model_target.parameters()).device

    # The attack configuration is loop-invariant; build it once.
    attack = torchattacks.PGD(model_source, eps=8/255, alpha=2/255, steps=20)

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        adv_images = attack(data, target).to(device)

        # Compute accuracy; no gradients are needed for the target forward pass.
        with torch.no_grad():
            acc = cal_acc(model_target, adv_images, target)

        batch_size = adv_images.size(0)
        total_acc += acc * batch_size
        total_samples += batch_size
        print('batch accuracy: ', acc.item())

    if total_samples == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')

    # Compute the sample-weighted average accuracy.
    avg_accuracy = total_acc / total_samples
    # The original message interpolated the builtin `type`; name the attack instead.
    print(f'Average accuracy for attack type PGD: {avg_accuracy:.4f}')
    return avg_accuracy


def load_model_weights(model, checkpoint_path, device):
    """Load checkpoint weights into ``model`` and move it to ``device``.

    The checkpoint may be a raw state dict or a dict that stores the state
    dict under the ``'model'`` or ``'state_dict'`` key. A leading
    ``'module.'`` prefix is stripped from every key before loading.

    Args:
        model: module that receives the weights (modified in place).
        checkpoint_path: path of the checkpoint file passed to ``torch.load``.
        device: device (string or ``torch.device``) used both as
            ``map_location`` and as the final location of ``model``.

    Returns:
        The same ``model`` object, loaded and moved to ``device``.
    """
    # NOTE(security): torch.load may unpickle arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(">>> checkpoint keys:", checkpoint.keys())

    if 'model' in checkpoint:
        sd = checkpoint['model']
    elif 'state_dict' in checkpoint:
        sd = checkpoint['state_dict']
    else:
        sd = checkpoint

    # Strip only a *leading* 'module.' prefix. A plain str.replace would also
    # mangle keys that merely contain the substring
    # (e.g. 'submodule.weight' -> 'subweight').
    prefix = 'module.'
    new_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

    # strict=False tolerates key mismatches, so report them instead of
    # silently evaluating a partially initialised model.
    load_result = model.load_state_dict(new_sd, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ Missing keys (left at initial values): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"⚠️ Unexpected keys (ignored): {load_result.unexpected_keys}")

    model.to(device)
    print(f"✅ Loaded {len(new_sd)} parameters from\n   {checkpoint_path}\n   on {device}")

    # Count and print the number of model parameters.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"🔢 Model parameters: total={total_params:,}, trainable={trainable_params:,}")

    return model


def main():
    """Build the source/target models, load their checkpoints and run the
    transfer-attack evaluation (source: AVI model, target: RACL model).

    Raises:
        ValueError: if the requested dataset type is not supported.
    """
    # The parser may not define `data_type`; fall back to CIFAR-10 rather
    # than failing with AttributeError.
    data_type = getattr(args, 'data_type', 'cifar10')
    classes_by_dataset = {'cifar10': 10, 'cifar100': 100}
    if data_type not in classes_by_dataset:
        # Previously an unknown value left CIFAR_CLASSES unbound (NameError later).
        raise ValueError(
            f"Unsupported data_type {data_type!r}; expected one of {sorted(classes_by_dataset)}"
        )
    CIFAR_CLASSES = classes_by_dataset[data_type]

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f'Using device: {device}')

    # ------------------------------------------------------------------
    # Alternative model configurations, kept for reference (disabled).
    # ------------------------------------------------------------------
    # model_pdarts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.PDARTS)
    # model_racl = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.RACL)
    # model_e2rnas = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.E2RNAS)
    # model_lrnas_100 = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, True, genotypes.search_cifar100)

    # model_darts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.DARTS)
    # checkpoint_path = "/home/DARTSex/checkpoint/EV/EV_model_acc85.9200_20250507_104405.pt"
    # model_darts = load_model_weights(model_darts, checkpoint_path,device=str(device))
    # model_darts   = model_darts.to(device)

    # model_lrnas_10 = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, True, genotypes2.search_cifar10)
    # checkpoint_path = "/home/LRNASex/checkpoint/lrnasc36cifar10/EV_model_acc83.3900_20250506_143707.pt"
    # model_lrnas_10 = load_model_weights(model_lrnas_10, checkpoint_path, device='cuda:0')
    # model_lrnas_10 = model_lrnas_10.to(device)

    # model_advrush = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.ADVRUSH)
    # checkpoint_path = "/home/adrush/checkpoint/advrushcifar10/EV_model_acc86.3800_20250507_190506.pt"
    # model_advrush = load_model_weights(model_advrush, checkpoint_path,device=str(device))
    # model_advrush = model_advrush.to(device)

    # Source model: crafted adversarial examples are generated against it.
    # Load directly onto the selected device; the previous hard-coded
    # 'cuda:0' ignored --gpu / --no-cuda. load_model_weights() already moves
    # the model to the device it is given, so no extra .to() is needed.
    model_avi = NetworkCIFAR2(32, CIFAR_CLASSES, 10, False, genotypes.search_cifar10_338_0)
    checkpoint_path = "/home/cifar103380/checkpoint/cifar103380NOnormalise/EV_model_acc86.5600_20250620_160255.pt"
    model_avi = load_model_weights(model_avi, checkpoint_path, device=str(device))

    # Target model: evaluated on the adversarial examples.
    model_racl = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.RACL_25_1)
    checkpoint_path = "/home/LRNASex/checkpoint/RACL251CIFAR10/EV_model_acc85.1300_20250621_180215.pt"
    model_racl = load_model_weights(model_racl, checkpoint_path, device=str(device))

    # model_pdarts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.PDARTS)
    # checkpoint_path = "/home/lrnasxiugai/checkpoint/PDARTS/EV_model_acc85.3800_20250622_090837.pt"
    # model_pdarts = load_model_weights(model_pdarts, checkpoint_path,device=str(device))
    # model_pdarts = model_pdarts.to(device)

    # model_arnas = modelArnas(64, CIFAR_CLASSES, args.layers, False, genotype3.ARNAS)
    # checkpoint_path = "/home/eval-EXP-20250512-025946/checkpoint-epoch119.pth.tar"
    # model_arnas = load_model_weights(model_arnas, checkpoint_path,device=str(device))
    # model_arnas = model_arnas.to(device)

    model_source = model_avi

    # Other target evaluations (disabled):
    # eval_adv_acc(model_source, model_darts, test_loader)
    # eval_adv_acc(model_source, model_pdarts, test_loader)
    # eval_adv_acc(model_source, model_racl, test_loader)
    # eval_adv_acc(model_source, model_advrush, test_loader)
    # eval_adv_acc(model_source, model_e2rnas, test_loader)
    eval_adv_acc(model_source, model_racl, test_loader)
    # eval_adv_acc(model_source, model_lrnas_100, test_loader)







# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()