import os
# import models.densenet as dn
# import models.wideresnet as wn


import torch
import torch.nn as nn


def get_model(args, num_classes, load_ckpt=True, load_epoch=None):
    """Build the network selected by ``args`` and return it in eval mode.

    The architecture is chosen from ``args.in_dataset`` and
    ``args.model_arch``.  Model classes are imported lazily inside each
    branch so that only the module for the requested architecture is loaded.

    Args:
        args: Namespace-like object.  ``in_dataset`` and ``model_arch`` are
            always read.  The DenseNet variants additionally read ``layers``,
            ``growth``, ``reduce``, ``droprate``, ``method`` and ``p``; the
            CIFAR-style ResNets read ``method`` (and ``p`` for the
            non-``-supcon``/``-supce`` variants).
        num_classes: Number of output classes passed to the model
            constructors.  Not used by the ``pathmnist`` branch, whose
            classifier size is fixed in the code below.
        load_ckpt: When true, weights are restored from the fixed checkpoint
            path of the selected branch.  The ImageNet ``resnet18`` /
            ``resnet50`` models are built with ``pretrained=True`` and never
            read a checkpoint file here.
        load_epoch: Accepted for backward compatibility.  The checkpoint
            paths are fixed, so this value currently has no effect.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and switched
        to evaluation mode.

    Raises:
        ValueError: If ``args.model_arch`` is not supported for the selected
            dataset branch.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unsupported = 'Not supported model arch: {}'.format(args.model_arch)

    if args.in_dataset == 'imagenet':
        if args.model_arch == 'resnet18':
            from models.resnet import resnet18
            model = resnet18(num_classes=num_classes, pretrained=True)
        elif args.model_arch == 'resnet50':
            from models.resnet import resnet50
            model = resnet50(num_classes=num_classes, pretrained=True)
        elif args.model_arch == 'resnet50-supcon':
            from models.resnet_supcon import SupCEResNet
            model = SupCEResNet(name='resnet50', num_classes=num_classes)
            if load_ckpt:
                # map_location lets a checkpoint saved on GPU load on a
                # CPU-only host; the model is moved to `device` below anyway.
                checkpoint = torch.load(
                    f"./checkpoints/{args.in_dataset}/model.pth",
                    map_location=device,
                )
                model.load_state_dict(checkpoint)
        else:
            # Previously fell through with `model` unbound.
            raise ValueError(unsupported)

    elif args.in_dataset == 'pathmnist':
        # NOTE(review): the classifier heads below are hard-coded to 9
        # outputs rather than `num_classes` -- presumably to match the stored
        # checkpoint; confirm before generalizing.
        if args.model_arch == 'vgg19':
            from torchvision.models import vgg19
            model = vgg19()
            num_features = model.classifier[6].in_features
            model.classifier[6] = nn.Linear(num_features, 9)
        elif args.model_arch == 'mobilenet_v2':
            from torchvision.models import mobilenet_v2
            model = mobilenet_v2()
            num_features = model.classifier[1].in_features
            model.classifier[1] = nn.Linear(num_features, 9)
        elif args.model_arch == 'resnet50':
            from torchvision.models import resnet50
            # NOTE(review): unlike the two branches above, the final layer is
            # left at torchvision's default size -- verify against the
            # checkpoint.
            model = resnet50(weights=None)
        else:
            # Previously fell through with `model` unbound.
            raise ValueError(unsupported)

        # Honor `load_ckpt` like the other branches (default True keeps the
        # previous behavior of always loading).
        if load_ckpt:
            model.load_state_dict(
                torch.load('./checkpoints/pathmnist/model.pth',
                           map_location=device)
            )

    else:
        # create model
        if args.model_arch == 'densenet':
            # The module-level import of `dn` is commented out at the top of
            # the file, so import it lazily here like the other branches.
            import models.densenet as dn
            model = dn.DenseNet3(args.layers, num_classes, args.growth, reduction=args.reduce, bottleneck=True,
                                 dropRate=args.droprate, normalizer=None, method=args.method, p=args.p)
        elif args.model_arch == 'densenet-supcon':
            from models.densenet_ss import DenseNet3
            model = DenseNet3(args.layers, num_classes, args.growth, reduction=args.reduce, bottleneck=True,
                              dropRate=args.droprate, normalizer=None, method=args.method, p=args.p)
        elif args.model_arch == 'resnet18':
            from models.resnet import resnet18_cifar
            model = resnet18_cifar(num_classes=num_classes, method=args.method, p=args.p)
        elif args.model_arch in ('resnet18-supcon', 'resnet18-supce'):
            from models.resnet_ss import resnet18_cifar
            model = resnet18_cifar(num_classes=num_classes, method=args.method)
        elif args.model_arch == 'resnet34':
            from models.resnet import resnet34_cifar
            model = resnet34_cifar(num_classes=num_classes, method=args.method, p=args.p)
        elif args.model_arch in ('resnet34-supcon', 'resnet34-supce'):
            from models.resnet_ss import resnet34_cifar
            model = resnet34_cifar(num_classes=num_classes, method=args.method)
        else:
            # Explicit raise instead of `assert False`, which is stripped
            # when Python runs with -O.
            raise ValueError(unsupported)

        if load_ckpt:
            # NOTE(review): the path is fixed to CIFAR-10 regardless of
            # `args.in_dataset` / `load_epoch`; weights live under the 'net'
            # key of the checkpoint.
            checkpoint = torch.load('./checkpoints/CIFAR-10/model.pth',
                                    map_location=device)
            model.load_state_dict(checkpoint['net'])

    model.to(device)
    model.eval()
    return model
