import torch
import os
import torch.utils.data

from val_robust.val_imgnet_a import eval_imgnet_a
from val_robust.val_imgnet_r import eval_imgnet_r
from val_robust.val_imgnet_sk import eval_imgnet_sk
from val_robust.val_adv import eval_imgnet_adv
from val_robust.val_imgnet_c import eval_imgnet_c
# Intended to switch off autograd profiling / NVTX annotation to speed things up.
# NOTE(review): `emit_nvtx` and `profile` are context managers. Calling them
# without a `with` statement only constructs the object and never enters it,
# so these two lines appear to have no effect at all. Profiling/NVTX emission
# only happens inside an active context anyway. Confirm and consider removing.
torch.autograd.profiler.emit_nvtx(False)
torch.autograd.profiler.profile(False)

def get_args_parser(add_help=True):
    """Create the command-line parser for the validation script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``. Pass False when
            this parser is reused as a parent of another parser.

    Returns:
        An ``argparse.ArgumentParser`` that defines ``--weights``,
        ``-b/--batch-size``, ``-j/--workers`` and ``--data-path``.
    """
    import argparse

    description = "Validation for ImageNet_1k and robust benchmarks"
    cli = argparse.ArgumentParser(description=description, add_help=add_help)

    cli.add_argument(
        "--weights",
        type=str,
        default="",
        help="the weights enum name to load",
    )
    cli.add_argument(
        "-b", "--batch-size",
        type=int,
        default=512,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    cli.add_argument(
        "-j", "--workers",
        type=int,
        default=8,
        metavar="N",
        help="number of data loading workers (default: 8)",
    )
    cli.add_argument(
        "--data-path",
        type=str,
        default="",
        help="dataset path",
    )
    return cli

def main(args):
    """Evaluate one model on every robustness benchmark and print a summary.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.
            Reads ``data_path``, ``batch_size`` and ``workers``.

    Returns:
        dict mapping a benchmark label to whatever the corresponding
        ``eval_imgnet_*`` function returned.

    Raises:
        RuntimeError: if ``model`` has not been assigned below.
    """
    model = None  # Assign Model here

    # Fail fast with a clear message instead of handing ``None`` to every
    # evaluator (and possibly loading datasets first).
    # NOTE(review): ``args.weights`` is parsed but unused; presumably it is
    # meant for building ``model``. Confirm.
    if model is None:
        raise RuntimeError(
            "No model assigned: set `model` in main() before running the benchmarks."
        )

    dataset_root_loc = args.data_path
    robust_root = os.path.join(dataset_root_loc, 'ImageNet_Robust')
    sk_loc = os.path.join(robust_root, 'sketch')
    a_loc = os.path.join(robust_root, 'imagenet-a')
    r_loc = os.path.join(robust_root, 'imagenet-r')
    c_loc = os.path.join(robust_root, 'ImageNet-C')
    val_loc = os.path.join(dataset_root_loc, 'val')

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.workers)
    # The adversarial benchmark uses 1/8 of the batch size. Clamp to 1 so a
    # small --batch-size never produces an invalid batch size of 0.
    adv_batch_size = max(1, args.batch_size // 8)

    # Dict literals evaluate in order, so benchmarks run in the original
    # sequence: A, R, Sketch, adversarial, C.
    results = {
        'ImageNet-A acc': eval_imgnet_a(net=model, location=a_loc, **loader_kwargs),
        'ImageNet-R acc': eval_imgnet_r(net=model, location=r_loc, **loader_kwargs),
        'ImageNet-Sketch acc': eval_imgnet_sk(net=model, location=sk_loc, **loader_kwargs),
        'Adversarial (val) acc': eval_imgnet_adv(
            net=model,
            batch_size=adv_batch_size,
            num_workers=args.workers,
            location=val_loc,
        ),
        'ImageNet-C mCE': eval_imgnet_c(net=model, location=c_loc, **loader_kwargs),
    }

    print('')
    print('Summary:')
    for name, value in results.items():
        print(f'  {name}: {value}')
    return results
    # print('Summary:')

if __name__ == "__main__":
    # Script entry point: read options from sys.argv and run the benchmark suite.
    parser = get_args_parser(add_help=True)
    args = parser.parse_args()
    main(args)

