import os
import pdb


import torch

def _strip_module(state_dict):
    """Return a copy of *state_dict* with every ``'module.'`` substring removed from its keys.

    ``nn.DataParallel`` inserts ``module.`` into parameter names. Every occurrence is
    removed (not only a leading prefix), so a wrapped sub-module such as
    ``encoder.module.conv1.weight`` is unwrapped too. Caveat: a key that merely contains
    ``module.`` (e.g. ``submodule.x``) is also rewritten.
    """
    return {key.replace('module.', ''): value for key, value in state_dict.items()}


def _copy_linear_head(state_dict, linear_ckpt_path):
    """Overwrite ``fc.weight``/``fc.bias`` in *state_dict* (in place) with the tensors
    stored under ``checkpoint['model']`` in the checkpoint at *linear_ckpt_path*.

    Returns *state_dict* for convenience.
    """
    linear_state = torch.load(linear_ckpt_path, map_location='cpu')['model']
    state_dict['fc.weight'] = linear_state['fc.weight']
    state_dict['fc.bias'] = linear_state['fc.bias']
    return state_dict


def get_model(args, num_classes, load_ckpt=True, load_epoch=None):
    """Build the model selected by ``args``, move it to the GPU and put it in eval mode.

    Args:
        args: namespace read for ``in_dataset`` and ``model_arch``. The non-ImageNet
            branches also read ``args.method``, and ``resnet18`` additionally reads ``args.p``.
        num_classes: number of output classes passed to the model constructor
            (not used by the ``vit`` branch).
        load_ckpt: currently unused; checkpoints are always loaded by the
            architectures that use them.
        load_epoch: currently unused.

    Returns:
        The model, moved to CUDA and set to eval mode.

    Raises:
        ValueError: if ``args.model_arch`` is not supported for ``args.in_dataset``.
    """
    # NOTE(review): every 'path_to_ckpt_model' below is a placeholder; the backbone and
    # linear-head checkpoints presumably live at different paths — fill these in.
    # Checkpoints are mapped to CPU: load_state_dict copies them into the model, which
    # is moved to the GPU at the end anyway, so no extra GPU copy is needed.
    if args.in_dataset == 'imagenet':
        if args.model_arch == 'resnet50':
            from models.resnet import resnet50
            model = resnet50(num_classes=num_classes, pretrained=True)
        elif args.model_arch == 'resnet50-supcon':
            from models.resnet_supcon import SupConResNet
            model = SupConResNet(num_classes=num_classes)
            state_dict = _strip_module(torch.load('path_to_ckpt_model', map_location='cpu')['model'])
            _copy_linear_head(state_dict, 'path_to_ckpt_model')
            model.load_state_dict(state_dict)
        elif args.model_arch == 'vit':
            from pytorch_pretrained_vit import ViT
            model = ViT('B_16_imagenet1k', pretrained=True)
        else:
            # Previously fell through with `model` unbound -> UnboundLocalError.
            raise ValueError('Not supported model arch: {}'.format(args.model_arch))
    else:
        # create model
        if args.model_arch == 'resnet18':
            from models.resnet import resnet18_cifar
            model = resnet18_cifar(num_classes=num_classes, method=args.method, p=args.p)
            state_dict = _strip_module(torch.load('path_to_ckpt_model', map_location='cpu')['state_dict'])
            model.load_state_dict(state_dict)
        elif args.model_arch == 'resnet18-supcon':
            from models.resnet_ss import resnet18_cifar
            model = resnet18_cifar(num_classes=num_classes, method=args.method)
            state_dict = _strip_module(torch.load('path_to_ckpt_model', map_location='cpu')['state_dict'])
            _copy_linear_head(state_dict, 'path_to_ckpt_model')
            model.load_state_dict(state_dict)
        else:
            # Explicit raise instead of `assert False`, which is stripped under `python -O`.
            raise ValueError('Not supported model arch: {}'.format(args.model_arch))

    model.cuda()
    model.eval()
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))
    return model
