# coding=utf-8
# Lint as: python3
from m_layer import MLayer
import numpy
import tensorflow as tf

"""Tests for m_layer.

We test that a model containing an MLayer can be set up and run for inference.
We are not trying to ensure that training works.
"""


class MLayerTest(tf.test.TestCase):
  """Smoke tests for building a model around MLayer and running inference."""

  def test_m_layer(self):
    """Builds a small Sequential model containing MLayer and runs predict().

    Checks the shape of the MLayer's first trainable weight and that the
    prediction contains only finite values (no NaN and no +/-inf).
    """
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(3,)),
        MLayer(dim_m=5, matrix_init='normal'),
        tf.keras.layers.ActivityRegularization(l2=1e-4),
        tf.keras.layers.Flatten()
    ])
    # model.layers[0] is the leading Flatten; the MLayer is the second layer.
    mlayer = model.layers[1]
    # With 3 input features and dim_m=5, the first trainable weight is
    # expected to have shape [3, 5, 5].
    self.assertEqual(mlayer.trainable_weights[0].shape, [3, 5, 5])

    prediction = model.predict(tf.ones((1, 3)))
    # isfinite rejects both NaN and +/-inf; checking isnan alone would let an
    # overflowed (infinite) output pass silently.
    self.assertTrue(numpy.isfinite(prediction).all())

if __name__ == "__main__":
  # When run as a script, hand control to TensorFlow's test runner.
  tf.test.main()
