import os
import torch

import sys
sys.path.append('..')
from models.resnet18_32x32 import ResNet18_32x32

def get_model(args, num_classes, load_ckpt=True, load_epoch=None, *,
              ckpt_path='../weights/resnet18_9554.pth'):
    """Build the model selected by ``args``, optionally load weights, and
    return it on the GPU in eval mode.

    Args:
        args: Object exposing ``in_dataset`` and ``model_arch`` attributes.
            Only ``in_dataset == 'cifar10'`` with ``model_arch == 'resnet18'``
            is currently supported.
        num_classes: Not used by this function.
            NOTE(review): presumably meant to be forwarded to the model
            constructor; confirm against ``ResNet18_32x32``'s signature.
        load_ckpt: If True, load a state dict from ``ckpt_path`` into the model.
        load_epoch: Not used by this function (kept for interface compatibility).
        ckpt_path: Path of the state-dict checkpoint loaded when
            ``load_ckpt`` is True.

    Returns:
        The model, moved to CUDA and switched to eval mode.

    Raises:
        ValueError: If the ``(in_dataset, model_arch)`` combination is not
            supported. Previously this surfaced as an ``UnboundLocalError``.
    """
    if args.in_dataset == 'cifar10' and args.model_arch == 'resnet18':
        model = ResNet18_32x32()
    else:
        raise ValueError(
            'Unsupported combination: in_dataset={!r}, model_arch={!r}'.format(
                args.in_dataset, args.model_arch))

    if load_ckpt:
        # Load onto CPU first so the checkpoint does not depend on the device
        # it was saved from; the model is moved to the GPU right after.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)

    model.cuda()
    model.eval()
    # Report the total number of scalar parameters in the model.
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))
    return model