from models.mlp import *


def get_model(model_name, output_size, data_shape, **kwargs):
    """Instantiate the model identified by ``model_name``.

    Args:
        model_name: One of ``"Linear"``, ``"SMLP"`` or ``"DMLP"`` (exact,
            case-sensitive match).
        output_size: Forwarded to the model constructor as ``num_classes``.
        data_shape: Forwarded unchanged to the model constructor.
        **kwargs: Only ``width`` is read, and only for ``"SMLP"``/``"DMLP"``
            (default 350). Any other keys are ignored, as is ``width`` for
            ``"Linear"``.

    Returns:
        A ``LinearModel``, ``SMLP`` or ``DMLP`` instance. These classes are
        presumably provided by ``models.mlp`` via the module's star import.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Reject unknown names before touching any model class.
    if model_name not in ("Linear", "SMLP", "DMLP"):
        raise ValueError(
            f"model_name='{model_name}' is not supported. Choose from ['Linear', 'SMLP', 'DMLP']."
        )

    # Constructor arguments shared by every supported model.
    shared = {"num_classes": output_size, "data_shape": data_shape}

    if model_name == "Linear":
        return LinearModel(**shared)

    # The two MLP variants take the same signature and differ only in class.
    mlp_cls = SMLP if model_name == "SMLP" else DMLP
    return mlp_cls(width=kwargs.get("width", 350), **shared)
