from models.protonet_model.mlp import MLPProto


def get_model(P, modelstr):
    """Build the model selected by *modelstr* and the settings in *P*.

    The only supported combination is ``modelstr == 'mlp'`` together with
    ``'protonet' in P.mode`` and ``P.dataset == 40670``; it yields
    ``MLPProto(180, 256, 256)``.

    Args:
        P: Configuration object. Its ``mode`` and ``dataset`` attributes are
            read when ``modelstr`` is ``'mlp'``.
        modelstr: Name of the requested architecture.

    Returns:
        The constructed ``MLPProto`` instance.

    Raises:
        NotImplementedError: If the architecture, mode, or dataset is not a
            supported combination.
    """
    # Guard clauses: every unsupported combination fails explicitly instead
    # of falling through to ``return model`` with ``model`` never assigned
    # (which previously surfaced as an UnboundLocalError).
    if modelstr != 'mlp':
        raise NotImplementedError(f"unsupported model: {modelstr!r}")

    if 'protonet' not in P.mode:
        raise NotImplementedError(
            f"model 'mlp' is not implemented for mode {P.mode!r}"
        )

    if P.dataset != 40670:
        raise NotImplementedError(
            f"model 'mlp' is not implemented for dataset {P.dataset!r}"
        )

    # NOTE(review): 180/256/256 are presumably the input, hidden and output
    # sizes for dataset 40670 -- confirm against MLPProto's signature.
    return MLPProto(180, 256, 256)
