import os
# import models.densenet as dn
# import models.wideresnet as wn


import torch

def get_model(args, num_classes, load_ckpt=True, load_epoch=None):
    """Build the evaluation model selected by ``args`` and optionally load weights.

    Args:
        args: Config namespace. Always reads ``in_dataset``; for non-ImageNet
            datasets reads ``model_arch``, and the DenseNet branch also reads
            ``layers``, ``growth``, ``reduce``, ``droprate``, ``method`` and ``p``.
        num_classes: Number of output classes. Only forwarded to ``DenseNet3``;
            the ImageNet ``MobileNetV2`` is built with its constructor defaults.
        load_ckpt: If True, load a state dict from a hard-coded path under
            ``./checkpoints/``. The file is loaded onto CPU first and moved to
            the GPU together with the model afterwards.
        load_epoch: Currently unused; kept so existing call sites keep working.

    Returns:
        The model, moved to CUDA and switched to eval mode.

    Raises:
        ValueError: If ``args.in_dataset`` is not ``'imagenet'`` and
            ``args.model_arch`` is not a supported architecture.
    """
    if args.in_dataset == 'imagenet':
        from models.mobilenetv2_knn import MobileNetV2
        model = MobileNetV2()
        ckpt_path = "./checkpoints/mobilenet_v2-b0353104.pth"
    elif args.model_arch == 'densenet':
        from models.densenet_knn import DenseNet3
        model = DenseNet3(args.layers, num_classes, args.growth, reduction=args.reduce, bottleneck=True,
                          dropRate=args.droprate, normalizer=None, method=args.method, p=args.p)
        # NOTE(review): this path is CIFAR-10 specific regardless of
        # args.in_dataset / num_classes; a model built for another class count
        # would presumably fail in load_state_dict — confirm intended datasets.
        ckpt_path = "./checkpoints/densenet100_cifar10.pth"
    else:
        # Raise instead of `assert False`: asserts are stripped under `python -O`,
        # which would let execution continue with `model` unbound.
        raise ValueError('Not supported model arch: {}'.format(args.model_arch))

    if load_ckpt:
        # map_location='cpu' lets the checkpoint load regardless of the device it
        # was saved from; the model is moved to the GPU below.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)

    # Unconditional: requires a CUDA-capable device.
    model.cuda()
    model.eval()
    # get the number of model parameters
    n_params = sum(p.numel() for p in model.parameters())
    print('Number of model parameters: {}'.format(n_params))
    return model
