import numpy as np
import tensorflow as tf


def small_convnet(input_tensor, n_logits, training=True):
    """Build a small plain convolutional classifier.

    Applies four 3x3 convolutions, each with 32 filters, using the
    ``tf.layers.conv2d`` default padding ('valid'). Each convolution is
    followed by a leaky-ReLU. The result is then flattened and projected
    to ``n_logits`` outputs by a dense layer with no activation.

    Args:
        input_tensor: Image batch fed to the first convolution. This
            presumably uses the default channels-last layout; confirm with
            the caller.
        n_logits: Number of units in the final dense layer.
        training: Not read by this builder. It is presumably kept so that
            every builder in ``models`` shares one call signature.

    Returns:
        The un-activated output tensor of the final dense layer.
    """
    features = input_tensor
    for _ in range(4):
        # Each step is conv first, then activation.
        features = tf.nn.leaky_relu(tf.layers.conv2d(features, 32, 3))

    flat = tf.layers.flatten(features)
    return tf.layers.dense(flat, n_logits)


def medium_convnet(input_tensor, n_logits, training=True):
    """Build a medium convolutional classifier with two pooling stages.

    The network has two stages. Each stage applies two 3x3 convolutions
    with 64 filters, using the ``tf.layers.conv2d`` default padding
    ('valid'). Each convolution is followed by a leaky-ReLU. The stage
    ends with a 2x2, stride-2 max-pool with 'VALID' padding. The result is
    then flattened and projected to ``n_logits`` outputs by a dense layer
    with no activation.

    Args:
        input_tensor: Image batch fed to the first convolution. The
            ``[1, 2, 2, 1]`` pooling window assumes a channels-last (NHWC)
            layout; confirm with the caller.
        n_logits: Number of units in the final dense layer.
        training: Not read by this builder. It is presumably kept so that
            every builder in ``models`` shares one call signature.

    Returns:
        The un-activated output tensor of the final dense layer.
    """
    features = input_tensor
    for _stage in range(2):
        for _conv in range(2):
            features = tf.nn.leaky_relu(tf.layers.conv2d(features, 64, 3))
        # Halve the spatial resolution after every pair of convolutions.
        features = tf.nn.max_pool(
            features, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')

    flat = tf.layers.flatten(features)
    return tf.layers.dense(flat, n_logits)


# Registry mapping a dataset name to the builder used to construct its
# network. Every builder takes (input_tensor, n_logits, training=True).
models = dict(
    mnist=small_convnet,
    imagenet=medium_convnet,
    omniglot5=small_convnet,
    omniglot10=small_convnet,
)