import os

import torch

from models import Linear, FC


def create_model(model_name, n_classes, n_channels, model_load_dir=None):
    """Build a model by name, move it to the GPU, and optionally restore weights.

    Args:
        model_name: Registry key understood by ``get_model`` (e.g. ``"linear"``
            or ``"fc"``).
        n_classes: Forwarded to the model constructor as ``n_classes``.
        n_channels: Forwarded to the model constructor as ``n_channels``.
        model_load_dir: Optional directory containing a ``model.pth`` file
            with a state dict. When falsy (``None`` or ``""``), the freshly
            initialised model is returned without loading anything.

    Returns:
        The model instance after ``.cuda()`` has been called on it.
    """
    model = get_model(model_name=model_name, n_classes=n_classes,
                      n_channels=n_channels)
    model.cuda()
    if model_load_dir:
        checkpoint_path = os.path.join(model_load_dir, "model.pth")
        # Deserialize onto the CPU first. Without map_location, tensors are
        # restored onto the device they were saved from, which fails if that
        # GPU index does not exist here and briefly holds a second GPU copy.
        # load_state_dict then copies the values into the CUDA parameters.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        param = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(param)
    return model


def get_model(model_name, **kwargs):
    """Instantiate the model class registered under ``model_name``.

    Args:
        model_name: One of the registry keys: ``"linear"`` or ``"fc"``.
        **kwargs: Passed unchanged to the selected model class constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: If ``model_name`` is not registered. The message lists the
            valid names.
    """
    models = {"linear": Linear,
              "fc": FC,
              }
    try:
        model_cls = models[model_name]
    except KeyError:
        # Keep the KeyError type callers may already catch, but make the
        # message actionable instead of just echoing the bad key.
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(models)}"
        ) from None
    return model_cls(**kwargs)
