import tensorflow as tf

try:
    from .resnet_utils import _conv, _fc, _bn, _residual_block, _residual_block_first
except (ImportError, ValueError, SystemError):
    # The relative import fails when this file is loaded outside of its package
    # (ImportError on current Python 3, SystemError on 3.3-3.5, ValueError on
    # Python 2). Fall back to a plain top-level import of the sibling module.
    import os
    import sys

    # sys.path entries must be directories: add the folder that contains this
    # file, not the path of the file itself.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    from resnet_utils import _conv, _fc, _bn, _residual_block, _residual_block_first

def mlp(inputs, hidden_layers, num_outputs, is_training=None):
    """
    Build a multi-layer perceptron out of `tf.layers.dense` layers.

    Args:
        inputs: Input tensor. If it has more than two dimensions it is first
            passed through `tf.layers.flatten`.
        hidden_layers: Iterable of hidden layer sizes; one ReLU-activated
            dense layer is created per entry, in order.
        num_outputs: Number of units of the final dense layer, which is
            created without an activation.
        is_training: Accepted for interface compatibility; not used here.

    Returns:
        Output tensor of the final dense layer.
    """
    needs_flatten = len(inputs.shape) > 2
    net = tf.layers.flatten(inputs) if needs_flatten else inputs

    # Hidden stack: dense + ReLU for every requested width.
    for width in hidden_layers:
        net = tf.layers.dense(net, width, tf.nn.relu)

    # Final projection, no activation.
    return tf.layers.dense(net, num_outputs)

def _resnet18_conv_feedforward(h, kernels, filters, strides, num_outputs, is_training, is_ATT_DATASET=False):
    """
    Forward pass through a ResNet-18 network

    Layout: one conv + batch-norm + ReLU stem ('conv_1' / 'bn_1'), four stages
    of two residual blocks each ('conv2_*' .. 'conv5_*'), a mean over axes 1
    and 2, and a final `_fc` layer ('fc_1').

    Args:
        h: Input tensor fed to the first convolution.
        kernels: Sequence of kernel sizes; only kernels[0] is read here.
        filters: Sequence of filter counts; filters[0] is used for the stem
            and filters[2], filters[3], filters[4] for stages 3, 4 and 5.
        strides: Sequence of strides; strides[0] is used for the stem and
            strides[2], strides[3], strides[4] for stages 3, 4 and 5.
        num_outputs: Passed to `_fc` as the size of the output layer.
        is_training: Forwarded to the batch-norm and residual block helpers.
        is_ATT_DATASET: Forwarded to `_residual_block_first` for stages 3-5.

    Returns:
        Logits of a resnet-18 conv network
    """
    # Handed to every helper; presumably each helper appends the variables it
    # creates (verify in resnet_utils). The list itself is not returned.
    trainable_vars = []

    # Conv1: stem convolution, batch norm, ReLU.
    net = _conv(h, kernels[0], filters[0], strides[0], trainable_vars, name='conv_1')
    net = _bn(net, trainable_vars, is_training, name='bn_1')
    net = tf.nn.relu(net)

    # Conv2_x: two plain residual blocks.
    for block_name in ('conv2_1', 'conv2_2'):
        net = _residual_block(net, trainable_vars, is_training, name=block_name)

    # Conv3_x .. Conv5_x: each stage opens with `_residual_block_first`, which
    # receives that stage's filter count and stride, followed by a plain block.
    stages = (
        (2, 'conv3_1', 'conv3_2'),
        (3, 'conv4_1', 'conv4_2'),
        (4, 'conv5_1', 'conv5_2'),
    )
    for idx, first_name, second_name in stages:
        net = _residual_block_first(
            net, filters[idx], strides[idx], trainable_vars, is_training,
            name=first_name, is_ATT_DATASET=is_ATT_DATASET)
        net = _residual_block(net, trainable_vars, is_training, name=second_name)

    # Apply average pooling (mean over axes 1 and 2; presumably the spatial
    # dimensions of an NHWC tensor -- confirm against the input pipeline).
    net = tf.reduce_mean(net, [1, 2])

    # NOTE: is_cifar=True is always passed; no architecture-dependent switch
    # is implemented in this function.
    return _fc(net, num_outputs, trainable_vars, name='fc_1', is_cifar=True)

def resnet(inputs, hidden_layers, num_outputs, is_training):
    """
    Build the ResNet-18 style network with this module's fixed configuration.

    Args:
        inputs: Input tensor passed to `_resnet18_conv_feedforward`.
        hidden_layers: Accepted for interface compatibility; not used here.
        num_outputs: Size of the output layer.
        is_training: Forwarded to `_resnet18_conv_feedforward`.

    Returns:
        Logits returned by `_resnet18_conv_feedforward`.
    """
    # Fixed per-stage configuration: 3x3 kernels for all five entries, with the
    # filter counts and strides listed below.
    return _resnet18_conv_feedforward(
        inputs,
        kernels=[3] * 5,
        filters=[20, 20, 40, 80, 160],
        strides=[1, 0, 2, 2, 2],
        num_outputs=num_outputs,
        is_training=is_training,
    )
