import argparse
import torch
import torchvision
from torch.autograd import Variable
from torchvision import transforms
import genotypes2
import genotypes

import torch.backends.cudnn as cudnn
from model import NetworkCIFAR
from model2 import NetworkCIFAR as NetworkCIFAR2
from model3 import NetworkCIFAR as NetworkCIFAR3
import torchattacks
from utils import _data_transforms_cifar100

# Command-line options. `--data_type` selects the dataset and class count used by main();
# `--data` is the dataset root directory (defaults to the previously hard-coded path).
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--test-batch-size', type=int, default=16, metavar='N', help='input batch size for testing (default: 16)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA')
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
parser.add_argument('--drop_path_prob', type=float, default=0.0, help='drop path probability')
# main() branches on args.data_type; it was read there but never registered here.
parser.add_argument('--data_type', type=str, default='cifar100', choices=['cifar10', 'cifar100'], help='dataset to evaluate on')
parser.add_argument('--data', type=str, default='/home3/data/', help='root directory of the CIFAR dataset')
args = parser.parse_args()

# Only touch CUDA device state when CUDA will actually be used, so CPU-only hosts
# and --no-cuda runs do not crash here.
if torch.cuda.is_available() and not args.no_cuda:
    torch.cuda.set_device(args.gpu)
cudnn.benchmark = True

# NOTE(review): the test transform comes from `_data_transforms_cifar100` for either
# dataset; if it normalizes with CIFAR-100 statistics, CIFAR-10 inputs will be slightly
# off — confirm before evaluating CIFAR-10.
_, _, test_transform = _data_transforms_cifar100(args)
dataset_cls = torchvision.datasets.CIFAR10 if args.data_type == 'cifar10' else torchvision.datasets.CIFAR100
testset = dataset_cls(root=args.data, train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)


def cal_acc(model, X, y):
    """Return the top-1 accuracy of ``model`` on the batch ``(X, y)``.

    Args:
        model: callable whose output is a tensor with classes along dim 1.
        X: input batch passed directly to ``model``.
        y: integer class labels, one per sample.

    Returns:
        A 0-dim float tensor holding the fraction of samples whose arg-max
        prediction equals ``y`` (NaN for an empty batch).
    """
    # Pure evaluation: don't build an autograd graph through the model's parameters.
    with torch.no_grad():
        out = model(X)
    # argmax returns the first maximal index on ties, matching the previous max(1) indices.
    predicted = out.argmax(dim=1)
    return (predicted == y).float().mean()




def eval_adv_acc(model_source, model_target, test_loader, eps=8/255, alpha=2/255, steps=20):
    """Measure transfer robustness: PGD examples crafted on one model, scored on another.

    For every batch of ``test_loader``, ``torchattacks.PGD`` is run against
    ``model_source``; the resulting adversarial images are classified by
    ``model_target`` and its top-1 accuracy is accumulated.

    Args:
        model_source: model the adversarial examples are generated against.
        model_target: model whose accuracy on those examples is reported.
        test_loader: iterable of ``(images, labels)`` batches.
        eps: PGD perturbation budget (default 8/255).
        alpha: PGD step size (default 2/255).
        steps: number of PGD iterations (default 20).

    Returns:
        Sample-weighted mean accuracy of ``model_target`` as a Python float,
        or NaN if the loader yields no samples.
    """
    model_source.eval()
    model_target.eval()

    # Send tensors to wherever each model lives instead of assuming CUDA.
    source_device = next(model_source.parameters()).device
    target_device = next(model_target.parameters()).device

    # The attack does not depend on the batch, so construct it once.
    attack = torchattacks.PGD(model_source, eps=eps, alpha=alpha, steps=steps)
    attack_name = type(attack).__name__

    total_correct = 0.0
    total_samples = 0
    for data, target in test_loader:
        data, target = data.to(source_device), target.to(source_device)
        adv_images = attack(data, target)

        # Scoring needs no gradients; avoid building a graph through model_target.
        with torch.no_grad():
            acc = cal_acc(model_target, adv_images.to(target_device), target.to(target_device))
        batch_acc = acc.item()
        batch_size = target.size(0)
        total_correct += batch_acc * batch_size
        total_samples += batch_size
        print('batch accuracy: ', batch_acc)

    # Accuracy is undefined for an empty loader; report NaN instead of dividing by zero.
    avg_accuracy = total_correct / total_samples if total_samples else float('nan')
    print(f'Average accuracy for attack type {attack_name}: {avg_accuracy:.4f}')
    return avg_accuracy



def load_model_weights(model, checkpoint_path, device):
    """Load a checkpoint's weights into ``model``, move it to ``device`` and return it.

    The state dict is taken from ``checkpoint['model']`` or
    ``checkpoint['state_dict']`` when present; otherwise the loaded object itself
    is used as the state dict. A single leading ``module.`` (as added by
    ``nn.DataParallel`` / ``DistributedDataParallel``) is stripped from each key.
    Loading is non-strict: missing and unexpected keys are reported, not raised
    (shape mismatches still raise from ``load_state_dict``).

    Args:
        model: the ``nn.Module`` to populate (modified in place).
        checkpoint_path: path passed to ``torch.load``.
        device: ``map_location`` for ``torch.load`` and target of ``model.to``.

    Returns:
        ``model`` itself, after loading and moving.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(">>> checkpoint keys:", checkpoint.keys())

    if 'model' in checkpoint:
        sd = checkpoint['model']
    elif 'state_dict' in checkpoint:
        sd = checkpoint['state_dict']
    else:
        sd = checkpoint

    # Strip only a leading wrapper prefix; str.replace would also mangle inner
    # occurrences (e.g. 'module.module.weight' -> 'weight').
    prefix = 'module.'
    new_sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    result = model.load_state_dict(new_sd, strict=False)
    model.to(device)
    n_used = len(new_sd) - len(result.unexpected_keys)
    print(f"✅ Loaded {n_used} of {len(new_sd)} checkpoint tensors from\n   {checkpoint_path}\n   on {device}")
    # strict=False hides key mismatches; surface them so a silently-unloaded model is noticed.
    if result.missing_keys:
        print(f"⚠️ {len(result.missing_keys)} model keys missing from checkpoint, e.g. {result.missing_keys[:10]}")
    if result.unexpected_keys:
        print(f"⚠️ {len(result.unexpected_keys)} checkpoint keys unused by model, e.g. {result.unexpected_keys[:10]}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"🔢 Model parameters: total={total_params:,}, trainable={trainable_params:,}")

    return model


def main():
    """Build the attack-source and target models and report PGD transfer accuracy.

    The active configuration generates adversarial examples on the LRNAS model
    (``NetworkCIFAR3`` with ``genotypes2.search_cifar100``) and scores them on the
    PDARTS model (``NetworkCIFAR`` with ``genotypes2.PDARTS``) via ``eval_adv_acc``.
    The commented-out blocks are alternative models/checkpoints that can be swapped in.
    """
    # `--data_type` may not be registered on the parser; fall back to CIFAR-100, which
    # the active checkpoints below target (per their paths).
    data_type = getattr(args, 'data_type', 'cifar100')
    if data_type == 'cifar100':
        CIFAR_CLASSES = 100
    elif data_type == 'cifar10':
        CIFAR_CLASSES = 10
    else:
        # Fail clearly instead of hitting an UnboundLocalError on CIFAR_CLASSES later.
        raise ValueError(f"Unsupported data_type: {data_type!r}")
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    print(f'Using device: {device}')

    # model_pdarts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.PDARTS)
    # model_racl = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.RACL)
    # model_e2rnas = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes.E2RNAS)
    # model_lrnas_100 = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, True, genotypes.search_cifar100)

    # model_darts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.DARTS)
    # checkpoint_path = "/home3/DARTS/checkpoint/cifar100/EV_model_acc60.4400_20250621_003049.pt"
    # model_darts = load_model_weights(model_darts, checkpoint_path,device=str(device))
    # model_darts   = model_darts.to(device)

    # model_darts = NetworkCIFAR3(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.DARTS)
    # checkpoint_path = "/home3/DARTS/checkpoint/cifar100/EV_model_acc60.4400_20250621_003049.pt"
    # model_darts = load_model_weights(model_darts, checkpoint_path,device=str(device))
    # model_darts   = model_darts.to(device)

    # Attack source model.
    model_lrnas_100 = NetworkCIFAR3(args.init_channels, CIFAR_CLASSES, args.layers, True, genotypes2.search_cifar100)
    checkpoint_path = "/home3/LRNASex/checkpoint/LRNAScifar100lrnas/EV_model_acc56.8800_20250620_194601.pt"
    # Use the selected device (was hard-coded 'cuda:0', wrong for --gpu != 0 or --no-cuda).
    model_lrnas_100 = load_model_weights(model_lrnas_100, checkpoint_path, device=str(device))
    model_lrnas_100 = model_lrnas_100.to(device)

    # model_advrush = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.ADVRUSH)
    # checkpoint_path = "/home3/ADVRUSH/checkpoint/EV/EV_model_acc60.8000_20250619_162412.pt"
    # model_advrush = load_model_weights(model_advrush, checkpoint_path,device=str(device))
    # model_advrush = model_advrush.to(device)

    # model_avi = NetworkCIFAR2(32, CIFAR_CLASSES, 10, False, genotypes.search_cifar10_338_0)
    # checkpoint_path = "/home3/CIFAR100/EV_model_acc61.9900_20250314_205311.pt"
    # model_avi = load_model_weights(model_avi, checkpoint_path, device=str(device))
    # model_avi = model_avi.to(device)

    # model_racl = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.RACL_25_1)
    # checkpoint_path = "/home3/RACL/checkpoint/EV/EV_model_acc60.7100_20250619_093902.pt"
    # model_racl = load_model_weights(model_racl, checkpoint_path,device=str(device))
    # model_racl = model_racl.to(device)

    # Target model whose accuracy on the transferred examples is reported.
    model_pdarts = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, False, genotypes2.PDARTS)
    checkpoint_path = "/home3/CIFAR100/checkpoint/pdartscifar100/EV_model_acc60.5600_20250624_062626.pt"
    model_pdarts = load_model_weights(model_pdarts, checkpoint_path, device=str(device))
    model_pdarts = model_pdarts.to(device)

    # model_arnas = modelArnas(64, CIFAR_CLASSES, args.layers, False, genotype3.ARNAS)
    # checkpoint_path = "/home/eval-EXP-20250512-025946/checkpoint-epoch119.pth.tar"
    # model_arnas = load_model_weights(model_arnas, checkpoint_path,device=str(device))
    # model_arnas = model_arnas.to(device)

    model_source = model_lrnas_100

    # Other target models (enable the matching definitions above first):
    # eval_adv_acc(model_source, model_darts, test_loader)
    # eval_adv_acc(model_source, model_racl, test_loader)
    # eval_adv_acc(model_source, model_advrush, test_loader)
    # eval_adv_acc(model_source, model_e2rnas, test_loader)
    # eval_adv_acc(model_source, model_lrnas_100, test_loader)
    eval_adv_acc(model_source, model_pdarts, test_loader)


if __name__ == '__main__':
    main()