# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for geopoly."""
import itertools

from absl.testing import absltest
from internal import geopoly
import jax
from jax import random
import numpy as np


def is_same_basis(x, y, tol=1e-10):
  """Check if `x` and `y` describe the same linear basis."""
  match = np.minimum(
      geopoly.compute_sq_dist(x, y), geopoly.compute_sq_dist(x, -y)) <= tol
  return (np.all(np.array(x.shape) == np.array(y.shape)) and
          np.all(np.sum(match, axis=0) == 1) and
          np.all(np.sum(match, axis=1) == 1))


class GeopolyTest(absltest.TestCase):

  def test_compute_sq_dist_reference(self):
    """Test against a simple reimplementation of compute_sq_dist."""
    num_points = 100
    num_dims = 10
    rng = random.PRNGKey(0)
    key, rng = random.split(rng)
    mat0 = jax.random.normal(key, [num_dims, num_points])
    key, rng = random.split(rng)
    mat1 = jax.random.normal(key, [num_dims, num_points])

    sq_dist = geopoly.compute_sq_dist(mat0, mat1)

    sq_dist_ref = np.zeros([num_points, num_points])
    for i in range(num_points):
      for j in range(num_points):
        sq_dist_ref[i, j] = np.sum((mat0[:, i] - mat1[:, j])**2)

    np.testing.assert_allclose(sq_dist, sq_dist_ref, atol=1e-4, rtol=1e-4)

  def test_compute_sq_dist_single_input(self):
    """Test that compute_sq_dist with a single input works correctly."""
    rng = random.PRNGKey(0)
    num_points = 100
    num_dims = 10
    key, rng = random.split(rng)
    mat0 = jax.random.normal(key, [num_dims, num_points])

    sq_dist = geopoly.compute_sq_dist(mat0)
    sq_dist_ref = geopoly.compute_sq_dist(mat0, mat0)
    np.testing.assert_allclose(sq_dist, sq_dist_ref)

  def test_compute_tesselation_weights_reference(self):
    """A reference implementation for triangle tesselation."""
    for v in range(1, 10):
      w = geopoly.compute_tesselation_weights(v)
      perm = np.array(list(itertools.product(range(v + 1), repeat=3)))
      w_ref = perm[np.sum(perm, axis=-1) == v, :] / v
      # Check that all rows of x are close to some row in x_ref.
      self.assertTrue(is_same_basis(w.T, w_ref.T))

  def test_generate_basis_golden(self):
    """A mediocre golden test against two arbitrary basis choices."""
    basis = geopoly.generate_basis('icosahedron', 2)
    basis_golden = np.array([[0.85065081, 0.00000000, 0.52573111],
                             [0.80901699, 0.50000000, 0.30901699],
                             [0.52573111, 0.85065081, 0.00000000],
                             [1.00000000, 0.00000000, 0.00000000],
                             [0.80901699, 0.50000000, -0.30901699],
                             [0.85065081, 0.00000000, -0.52573111],
                             [0.30901699, 0.80901699, -0.50000000],
                             [0.00000000, 0.52573111, -0.85065081],
                             [0.50000000, 0.30901699, -0.80901699],
                             [0.00000000, 1.00000000, 0.00000000],
                             [-0.52573111, 0.85065081, 0.00000000],
                             [-0.30901699, 0.80901699, -0.50000000],
                             [0.00000000, 0.52573111, 0.85065081],
                             [-0.30901699, 0.80901699, 0.50000000],
                             [0.30901699, 0.80901699, 0.50000000],
                             [0.50000000, 0.30901699, 0.80901699],
                             [0.50000000, -0.30901699, 0.80901699],
                             [0.00000000, 0.00000000, 1.00000000],
                             [-0.50000000, 0.30901699, 0.80901699],
                             [-0.80901699, 0.50000000, 0.30901699],
                             [-0.80901699, 0.50000000, -0.30901699]])
    self.assertTrue(is_same_basis(basis.T, basis_golden.T))

    basis = geopoly.generate_basis('octahedron', 4)
    basis_golden = np.array([[0.00000000, 0.00000000, -1.00000000],
                             [0.00000000, -0.31622777, -0.94868330],
                             [0.00000000, -0.70710678, -0.70710678],
                             [0.00000000, -0.94868330, -0.31622777],
                             [0.00000000, -1.00000000, 0.00000000],
                             [-0.31622777, 0.00000000, -0.94868330],
                             [-0.40824829, -0.40824829, -0.81649658],
                             [-0.40824829, -0.81649658, -0.40824829],
                             [-0.31622777, -0.94868330, 0.00000000],
                             [-0.70710678, 0.00000000, -0.70710678],
                             [-0.81649658, -0.40824829, -0.40824829],
                             [-0.70710678, -0.70710678, 0.00000000],
                             [-0.94868330, 0.00000000, -0.31622777],
                             [-0.94868330, -0.31622777, 0.00000000],
                             [-1.00000000, 0.00000000, 0.00000000],
                             [0.00000000, -0.31622777, 0.94868330],
                             [0.00000000, -0.70710678, 0.70710678],
                             [0.00000000, -0.94868330, 0.31622777],
                             [0.40824829, -0.40824829, 0.81649658],
                             [0.40824829, -0.81649658, 0.40824829],
                             [0.31622777, -0.94868330, 0.00000000],
                             [0.81649658, -0.40824829, 0.40824829],
                             [0.70710678, -0.70710678, 0.00000000],
                             [0.94868330, -0.31622777, 0.00000000],
                             [0.31622777, 0.00000000, -0.94868330],
                             [0.40824829, 0.40824829, -0.81649658],
                             [0.40824829, 0.81649658, -0.40824829],
                             [0.70710678, 0.00000000, -0.70710678],
                             [0.81649658, 0.40824829, -0.40824829],
                             [0.94868330, 0.00000000, -0.31622777],
                             [0.40824829, -0.40824829, -0.81649658],
                             [0.40824829, -0.81649658, -0.40824829],
                             [0.81649658, -0.40824829, -0.40824829]])
    self.assertTrue(is_same_basis(basis.T, basis_golden.T))


if __name__ == '__main__':
  absltest.main()
