from models import anil
from models import warp_grad
from models import mwn

def get_model(FLAGS, track_bn_stats=True):
    """Build a ``ConvNet`` from the model module selected by ``FLAGS.model``.

    Args:
        FLAGS: Configuration object (presumably parsed command-line flags —
            confirm against callers) exposing ``model``, ``num_classes``,
            ``conv_channels`` and ``img_size`` attributes. ``model`` must be
            one of ``"anil"``, ``"warp_grad"`` or ``"mwn"``.
        track_bn_stats: Forwarded unchanged as the fourth positional
            argument to the selected module's ``ConvNet``.

    Returns:
        Whatever the selected module's ``ConvNet(...)`` constructor returns.

    Raises:
        KeyError: If ``FLAGS.model`` is not a supported model name.
    """
    # Use an explicit registry rather than ``globals()[FLAGS.model]``. The
    # global lookup accepts any module-level name (e.g. "get_model"), which
    # then fails later with an unrelated AttributeError. The registry is
    # built at call time so it always reflects the module-level imports.
    registry = {
        "anil": anil,
        "warp_grad": warp_grad,
        "mwn": mwn,
    }
    try:
        model_module = registry[FLAGS.model]
    except KeyError:
        raise KeyError(
            f"Unknown model {FLAGS.model!r}; expected one of {sorted(registry)}"
        ) from None
    return model_module.ConvNet(
        FLAGS.num_classes,
        FLAGS.conv_channels,
        FLAGS.img_size,
        track_bn_stats,
    )
