# coding=utf-8
# Copyright 2020 The Gsa Net Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for specialized operators for GSA-Net."""
import functools

from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf

from gsa_net import ops


class OpsTest(parameterized.TestCase, tf.test.TestCase):
  """Graph-mode tests for the operators defined in `gsa_net.ops`."""

  def setUp(self):
    super().setUp()
    self.batch_size = 2
    self.height = 3
    self.width = 4
    self.head_count = 8
    self.depth = 16
    # Channel count of `normal_inputs`. Kept separate from `self.depth` (the
    # depth passed to the ops under test) even though both are 16 here.
    input_depth = 16
    # Rank-4 input: (batch, height, width, input_depth).
    self.normal_inputs = tf.random_normal((
        self.batch_size, self.height, self.width, input_depth))
    # Rank-5 input: (batch, height, width, head_count, depth // head_count).
    self.multi_head_inputs = tf.random_normal((
        self.batch_size,
        self.height,
        self.width,
        self.head_count,
        self.depth // self.head_count,
    ))

  def _run_initialized(self, fetches):
    """Initializes global variables, then evaluates `fetches` in one session.

    Initialization and evaluation share a single session so that any
    variables created by the op under test hold their initial values when
    `fetches` is run.

    Args:
      fetches: A tensor, or a structure of tensors, accepted by `Session.run`.

    Returns:
      The evaluated NumPy value(s), structured like `fetches`.
    """
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      return sess.run(fetches)

  @parameterized.named_parameters(
      ('multi_head', True),
      ('single_head', False),
  )
  def test_compute_attention_components(self, multi_head):
    """Queries, keys and values are split per head and are not all zero."""
    head_count = self.head_count if multi_head else 1
    queries, keys, values = ops._compute_attention_components(
        self.normal_inputs, self.depth, head_count)
    expected_shape = (
        self.batch_size,
        self.height,
        self.width,
        head_count,
        self.depth // head_count,
    )
    self.assertEqual(queries.shape, keys.shape)
    self.assertEqual(queries.shape, expected_shape)
    self.assertEqual(values.shape, expected_shape)

    queries, keys, values = self._run_initialized((queries, keys, values))
    self.assertNotAllEqual(queries, 0)
    self.assertNotAllEqual(keys, 0)
    self.assertNotAllEqual(values, 0)

  @parameterized.named_parameters(
      ('training', True),
      ('evaluation', False),
  )
  def test_multi_head_batch_normalization(self, is_training):
    """Batch normalization preserves the multi-head shape in both modes."""
    batch_normalized = ops.multi_head_batch_normalization(
        self.multi_head_inputs,
        is_training=is_training,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-5,
    )
    self.assertEqual(batch_normalized.shape, self.multi_head_inputs.shape)

    numpy_batch_normalized = self._run_initialized(batch_normalized)
    self.assertNotAllEqual(numpy_batch_normalized, 0)

  @parameterized.named_parameters(
      ('normal', False),
      ('multi_head', True),
  )
  def test_softmax_spatial(self, multi_head):
    """The output is a probability distribution over the spatial axes."""
    inputs = self.multi_head_inputs if multi_head else self.normal_inputs
    normalized = ops.softmax_spatial(inputs)
    self.assertEqual(normalized.shape, inputs.shape)

    numpy_normalized = self._run_initialized(normalized)
    # Axes 1 and 2 are height and width for both rank-4 and rank-5 inputs.
    spatial_sums = numpy_normalized.sum(axis=(1, 2))
    self.assertAllClose(spatial_sums, np.ones_like(spatial_sums))
    # Softmax weights can never be negative.
    self.assertAllGreaterEqual(numpy_normalized, 0.0)

  def test_generate_lookup_tensor(self):
    """The lookup tensor has the expected shape and no entry exceeds 1."""
    lookup_tensor = ops._generate_lookup_tensor(self.height)
    self.assertEqual(
        lookup_tensor.shape, (self.height, self.height, 2 * self.height - 1))

    lookup_numpy = self._run_initialized(lookup_tensor)
    # Cast so the bound check also works for non-float (e.g. boolean) dtypes.
    self.assertAllLessEqual(lookup_numpy.astype(np.float32), 1)

  def test_efficient_relative_attention_2d(self):
    """Relative attention keeps the multi-head input shape."""
    batch_norm_fn = functools.partial(
        ops.multi_head_batch_normalization,
        is_training=True,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-5,
    )
    outputs = ops.efficient_relative_attention_2d(
        queries=self.multi_head_inputs,
        keys=self.multi_head_inputs,
        values=self.multi_head_inputs,
        batch_norm_fn=batch_norm_fn,
    )
    self.assertEqual(outputs.shape, self.multi_head_inputs.shape)

    numpy_outputs = self._run_initialized(outputs)
    self.assertNotAllEqual(numpy_outputs, 0)

  def test_global_self_attention(self):
    """Global self-attention maps the input to `self.depth` channels."""
    outputs = ops.global_self_attention(
        self.normal_inputs, self.depth, self.head_count)
    self.assertEqual(outputs.shape, (
        self.batch_size, self.height, self.width, self.depth))

    numpy_outputs = self._run_initialized(outputs)
    self.assertNotAllEqual(numpy_outputs, 0)


if __name__ == '__main__':
  # Every test builds a graph and evaluates it with `Session.run`, which only
  # works in graph mode. TF 2.x enables eager execution by default (even via
  # `tensorflow.compat.v1`), so switch it off before any ops are created.
  # This is a no-op when eager execution is already disabled.
  tf.disable_eager_execution()
  tf.test.main()
