import models.base.preact_resnet as preact_resnet
import models.base.resnet_20_32_cifar as resnet_20_32_cifar
import models.base.resnet_18_34_cifar as resnet_18_34_cifar
import models.base.vit as vit

import torchvision

# Maps each supported model name to a factory taking ``num_classes``.
# The lambdas defer attribute lookups on the imported modules until a model
# is actually requested, so building this table has no side effects.
_MODEL_FACTORIES = {
    'vgg19': lambda n: torchvision.models.vgg19(num_classes=n),
    'vgg19_bn': lambda n: torchvision.models.vgg19_bn(num_classes=n),
    'resnet18': lambda n: resnet_18_34_cifar.resnet18(num_classes=n),
    'resnet34': lambda n: resnet_18_34_cifar.resnet34(num_classes=n),
    'resnet50': lambda n: torchvision.models.resnet50(num_classes=n),
    'preact_resnet18': lambda n: preact_resnet.PreActResNet18(num_classes=n),
    'preact_resnet34': lambda n: preact_resnet.PreActResNet34(num_classes=n),
    'resnet20': lambda n: resnet_20_32_cifar.resnet20(num_classes=n),
    'resnet32': lambda n: resnet_20_32_cifar.resnet32(num_classes=n),
    'resnet56': lambda n: resnet_20_32_cifar.resnet56(num_classes=n),
    'vit_b_16': lambda n: torchvision.models.vision_transformer.vit_b_16(num_classes=n),
    'densenet161': lambda n: torchvision.models.densenet161(num_classes=n),
}


def load_model(model_name, num_classes=10):
    """Build a fresh model instance by name.

    The name and class count are printed as a banner before construction.
    No pretrained weights are requested: each constructor is called only
    with ``num_classes``.

    Args:
        model_name: One of 'vgg19', 'vgg19_bn', 'resnet18', 'resnet34',
            'resnet50', 'preact_resnet18', 'preact_resnet34', 'resnet20',
            'resnet32', 'resnet56', 'vit_b_16', 'densenet161'.
        num_classes: Forwarded as ``num_classes`` to the model constructor.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``model_name`` is not a supported model name.
            (Previously an unknown name silently returned ``None``.)
    """
    print('-' * 50)
    print('MODEL NAME:', model_name)
    print('NUM CLASSES:', num_classes)
    print('-' * 50)

    # NOTE(review): per-model figures recorded by the original author; they
    # are not used by this function and their exact definition is not
    # documented. They appear related to a largest per-layer fan-in
    # (e.g. vgg19_bn: 512 * 7 * 7 = 25088) - confirm before relying on them.
    #   resnet32 : 578
    #   resnet50 : 4610
    #   vit_b_16 : 3075
    #   preact_resnet18 : 4610
    #   vgg19_bn : 25089

    try:
        factory = _MODEL_FACTORIES[model_name]
    except KeyError:
        # Fail loudly here instead of handing the caller a None model.
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: "
            + ", ".join(sorted(_MODEL_FACTORIES))
        ) from None
    return factory(num_classes)

if __name__ == "__main__":
    # Smoke check: build the 'resnet56' model for 10 classes and print its layers.
    model = load_model(model_name='resnet56', num_classes=10)
    print(model)


