from __future__ import print_function
import tensorflow as tf
import numpy as np
from tf_interpolate import three_nn, three_interpolate

class GroupPointTest(tf.test.TestCase):
  """Tests for the ``three_nn`` / ``three_interpolate`` ops from ``tf_interpolate``.

  NOTE(review): the class name looks copied from the grouping-op tests; the
  ops exercised here are the interpolation ones. Name kept so existing test
  selectors (e.g. ``GroupPointTest.test_grad``) keep working.
  """

  def test(self):
    # Intentionally empty placeholder; the real check is in ``test_grad``.
    pass

  def test_grad(self):
    """Compare the numeric and analytic gradients of ``three_interpolate``.

    ``points`` has shape (batch, num_known, num_channels). It is interpolated
    onto the ``num_query`` locations in ``xyz1``, using the neighbour indices
    that ``three_nn`` returns for ``xyz1`` against ``xyz2``. Every neighbour
    weight is 1/3. The gradient is taken with respect to ``points`` only, and
    the max error must stay below 1e-4.
    """
    batch_size, num_known, num_query, num_channels = 1, 8, 128, 16
    points_shape = (batch_size, num_known, num_channels)
    output_shape = (batch_size, num_query, num_channels)
    # Seeded local generator: a failure is reproducible instead of a flaky
    # one-off, and the global NumPy RNG state is left untouched.
    rng = np.random.RandomState(0)
    with self.test_session():
      points = tf.constant(rng.random_sample(points_shape).astype('float32'))
      xyz1 = tf.constant(
          rng.random_sample((batch_size, num_query, 3)).astype('float32'))
      xyz2 = tf.constant(
          rng.random_sample((batch_size, num_known, 3)).astype('float32'))
      dist, idx = three_nn(xyz1, xyz2)
      # Uniform weights; presumably dist's last axis holds 3 neighbours, which
      # makes this a plain average -- TODO confirm against the op definition.
      weight = tf.ones_like(dist) / 3.0
      interpolated_points = three_interpolate(points, idx, weight)
      err = tf.test.compute_gradient_error(
          points, points_shape, interpolated_points, output_shape)
      self.assertLess(err, 1e-4)

if __name__ == '__main__':
  # Run through TensorFlow's test runner, which discovers the TestCase classes
  # defined in this module.
  tf.test.main()
