import torch
import torch.nn as nn
from models import wideresnet, resnet, squeezenet
from models import BaseVAE, vae, cvae

def load_model(model_name, in_channels=1, num_classes=2):
    """Build and return a model instance selected by name.

    Prints a banner naming the requested model, then constructs it.

    Args:
        model_name: One of 'wideresnet', 'resnet50', 'squeezenet', 'vae',
            'cvae'.
        in_channels: Number of input channels passed to the model
            constructor (default 1).
        num_classes: Number of output classes passed to the classifier /
            conditional constructors (default 2). Ignored by 'vae', whose
            constructor is not given a class count.

    Returns:
        The constructed model object.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    print('-' * 50)
    print('LOAD MODEL:', model_name)
    print('-' * 50)

    # Lambdas defer construction so only the requested model is built, and
    # project modules are only touched for the selected entry.
    builders = {
        'wideresnet': lambda: wideresnet.wideresnet(in_channels, num_classes),
        'resnet50': lambda: resnet.resnet50(in_channels, num_classes),
        'squeezenet': lambda: squeezenet.squeezenet(in_channels, num_classes),
        # VAE hyperparameters are fixed here; num_classes is not used.
        'vae': lambda: vae.vae(in_channels, dim=256, z_dim=128),
        # Previously in_channels/num_classes were hard-coded to 1/2 here,
        # silently ignoring the caller's arguments; pass them through now.
        'cvae': lambda: cvae.CVAE(
            in_channels=in_channels,
            num_classes=num_classes,
            latent_dim=4096,
            hidden_dims=None,
            img_size=64,
        ),
    }

    try:
        builder = builders[model_name]
    except KeyError:
        # Without this, an unknown name used to surface as an
        # UnboundLocalError on `model`, which hides the real cause.
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            f"expected one of {sorted(builders)}"
        ) from None
    return builder()