# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for icosahedral_mesh."""

import chex
import numpy as np

from absl.testing import absltest, parameterized

from . import icosahedral_mesh


def _get_mesh_spec(splits: int):
    """Returns size of the final icosahedral mesh resulting from the splitting."""
    num_vertices = 12
    num_faces = 20
    for _ in range(splits):
        # Each previous face adds three new vertices, but each vertex is shared
        # by two faces.
        num_vertices += num_faces * 3 // 2
        num_faces *= 4
    return num_vertices, num_faces


class IcosahedralMeshTest(parameterized.TestCase):
    def test_icosahedron(self):
        mesh = icosahedral_mesh.get_icosahedron()
        _assert_valid_mesh(mesh, num_expected_vertices=12, num_expected_faces=20)

    @parameterized.parameters(list(range(5)))
    def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
        meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)
        prev_vertices = None
        for mesh_i, mesh in enumerate(meshes):
            # Check that `mesh` is valid.
            num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
            _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)

            # Check that the first N vertices from this mesh match all of the
            # vertices from the previous mesh.
            if prev_vertices is not None:
                leading_mesh_vertices = mesh.vertices[: prev_vertices.shape[0]]
                np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)

            # Increase the expected/previous values for the next iteration.
            if mesh_i < len(meshes) - 1:
                prev_vertices = mesh.vertices

    @parameterized.parameters(list(range(4)))
    def test_merge_meshes(self, splits):
        mesh_hierarchy = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=splits
        )
        mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)

        expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
        np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
        np.testing.assert_array_equal(mesh.faces, expected_faces)

    def test_faces_to_edges(self):
        faces = np.array([[0, 1, 2], [3, 4, 5]])

        # This also documents the order of the edges returned by the method.
        expected_edges = np.array([[0, 1], [3, 4], [1, 2], [4, 5], [2, 0], [5, 3]])
        expected_senders = expected_edges[:, 0]
        expected_receivers = expected_edges[:, 1]

        senders, receivers = icosahedral_mesh.faces_to_edges(faces)

        np.testing.assert_array_equal(senders, expected_senders)
        np.testing.assert_array_equal(receivers, expected_receivers)


def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
    vertices = mesh.vertices
    faces = mesh.faces
    chex.assert_shape(vertices, [num_expected_vertices, 3])
    chex.assert_shape(faces, [num_expected_faces, 3])

    # Vertices norm should be 1.
    vertices_norm = np.linalg.norm(vertices, axis=-1)
    np.testing.assert_allclose(vertices_norm, 1.0, rtol=1e-6)

    _assert_positive_face_orientation(vertices, faces)


def _assert_positive_face_orientation(vertices, faces):
    # Obtain a unit vector that points, in the direction of the face.
    face_orientation = np.cross(
        vertices[faces[:, 1]] - vertices[faces[:, 0]],
        vertices[faces[:, 2]] - vertices[faces[:, 1]],
    )
    face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)

    # And a unit vector pointing from the origin to the center of the face.
    face_centers = vertices[faces].mean(1)
    face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)

    # Positive orientation means those two vectors should be parallel
    # (dot product, 1), and not anti-parallel (dot product, -1).
    dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)

    # Check that the face normal is parallel to the vector that joins the center
    # of the face to the center of the sphere. Note we need a small tolerance
    # because some discretizations are not exactly uniform, so it will not be
    # exactly parallel.
    np.testing.assert_allclose(dot_center_orientation, 1.0, atol=6e-4)


if __name__ == "__main__":
    absltest.main()
