# coding=utf-8
# Copyright 2023 The Soar Neurips2023 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import utils


class UtilsTest(parameterized.TestCase):
  """Unit tests for the helpers in `utils` on small hand-checkable fixtures."""

  def test_normalize(self):
    """Rows are scaled to unit L2 norm; an all-zero row stays all-zero."""
    ds = np.array([[3, 4], [0, 0]])
    np.testing.assert_array_almost_equal(
        utils.normalize(ds), np.array([[0.6, 0.8], [0, 0]])
    )

  def test_compute_ground_truth(self):
    """Checks the top-2 ground-truth neighbors for each query.

    For this fixture the expected indices are the two datapoints with the
    largest inner product <q, ds[i]>, ordered from largest to smallest.
    """
    ds = np.array([[4, 1], [1, 4], [-2, -2]])
    qs = np.array([[2, 3], [3, 2], [-4, 0]])
    gt = utils.compute_ground_truth(ds, qs, 2)
    np.testing.assert_array_equal(gt, [[1, 0], [0, 1], [2, 1]])

  def test_redo_assignment(self):
    """Checks the two assignment arrays returned by `redo_assignment`.

    For this fixture, `res0` matches each point's nearest center by squared
    L2 distance and `res1` its second-nearest center.
    """
    centers = np.array([[3, 4], [3, 5], [9, 9]])
    ds = np.array([[4, 4], [8, 8], [-1, 8], [2, 9]])
    res0, res1 = utils.redo_assignment(centers, ds)
    np.testing.assert_array_equal(res0, [0, 2, 1, 1])
    np.testing.assert_array_equal(res1, [1, 1, 0, 0])

  def test_compute_score_diffs(self):
    """Checks per-(query, neighbor) score differences.

    Each expected entry equals <q, ds[g]> - <q, centers[tokenization[g]]>
    for query q and ground-truth neighbor g.
    """
    ds = np.array([[4, 5], [6, 7], [9, 9], [0, 0], [-3, 1]])
    centers = np.array([[4, 4], [0, 1]])
    tokenization = np.array([0, 0, 0, 1, 1])
    qs = np.array([[-2, 3], [3, -3]])
    gt = np.array([[0, 1, 2], [3, 2, 4]])
    diffs = utils.compute_score_diffs(ds, centers, tokenization, qs, gt)
    np.testing.assert_array_equal(
        diffs, [[7 - 4, 9 - 4, 9 - 4], [0 + 3, 0 - 0, -12 + 3]]
    )

  @parameterized.parameters((4,), (8,), (13,), (42,))
  def test_kmeans(self, num_pts):
    """With k equal to the dataset size, k-means should reproduce the data."""
    # (num_pts x 3); ds[i] = (i, i, i).
    ds = np.tile(np.arange(num_pts), (3, 1)).T
    centers, tokenization = utils.train_kmeans(ds, len(ds))

    # Running k-means with k=len(ds) should just give us the original dataset,
    # perhaps with points re-ordered (which is why we sort the centers).
    np.testing.assert_array_equal(ds, np.sort(centers, axis=0))
    # Dataset reconstructed from k-means should exactly equal original dataset.
    np.testing.assert_array_equal(ds, centers[tokenization])

  @parameterized.parameters((1,), (2.5,), (5.5,))
  def test_compute_avq_center_invariants(self, eta):
    """Checks `compute_avq_center` on degenerate and collinear inputs."""
    pts = np.array([[0, 0, 0], [3, 4, 0], [6, 8, 0]])
    # No points: the center is the zero vector.
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(pts[[]], eta), np.zeros(3)
    )
    # A single point is its own center.
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(pts[[0]], eta), pts[0]
    )
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(pts[[1]], eta), pts[1]
    )
    # Origin plus one point: the plain mean when eta == 1, otherwise the
    # nonzero point itself.
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(pts[[0, 1]], eta),
        pts[1] / 2 if eta == 1 else pts[1],
    )

    # Two collinear points with norms 5 and 10 along direction (0.6, 0.8, 0):
    # the expected center lies on that direction with the norm below.
    expected_norm = (5**eta + 10**eta) / (5 ** (eta - 1) + 10 ** (eta - 1))
    expected = expected_norm * np.array([0.6, 0.8, 0])
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(pts[[1, 2]], eta), expected
    )

  def test_compute_avq_center(self):
    """Checks `compute_avq_center` on a small non-collinear fixture."""
    points = np.array([[6, 8], [8, 6], [10, 0]])
    # With eta == 1 the expected center is the arithmetic mean of the points.
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(points, 1), [8, 14 / 3]
    )
    np.testing.assert_array_almost_equal(
        utils.compute_avq_center(points, 2), [8.65481382, 4.92284468]
    )

  def test_compute_avq_centers(self):
    """Checks that each center is recomputed from the points assigned to it."""
    pts = np.array([[0, 0, 0], [3, 4, 0], [6, 8, 0]])
    centers = np.zeros((2, 3))
    np.testing.assert_array_almost_equal(
        utils.compute_avq_centers(pts, centers, [0, 0, 1], 2),
        [[3, 4, 0], [6, 8, 0]],
    )

  def test_get_centroid_ranks(self):
    """Checks the rank of each wanted point's center, per query.

    Centers are ranked by descending inner product with the query (rank 0
    is the best center); the expected value for a wanted point is the rank
    of the center it is tokenized to.
    """
    centers = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])
    tokenization = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
    qs = np.array([[1, 1], [-1, -1]])
    want = np.array([[1, 2, 6, 10], [3, 4, 5, 6]])
    np.testing.assert_array_equal(
        utils.get_centroid_ranks(centers, tokenization, qs, want),
        [[3, 3, 1, 0], [1, 1, 1, 2]],
    )

  def test_soar_assign(self):
    """Checks spilled (secondary) assignment with and without the penalty."""
    ds = np.array([[4, 1], [9, 2], [3, 7], [6, 6]])
    centers = np.array([[3, 3], [5, 2]])
    toke1 = np.array([1, 1, 0, 1])

    # This is standard min-L2 assignment
    res0 = utils.soar_assign(ds, centers, toke1, 0, True)
    np.testing.assert_array_equal(res0, [1, 1, 0, 1])
    # With the last flag False the expected result is the complement of the
    # min-L2 assignment for this two-center fixture.
    # NOTE(review): confirm the flag's meaning against utils.soar_assign.
    np.testing.assert_array_equal(
        utils.soar_assign(ds, centers, toke1, 0, False), [0, 0, 1, 0]
    )

    # <ds[i] - centers[j], ds[i] - centers[toke1[i]]>, i.e. the inner product
    # of each candidate residual with the primary residual, for j = 0, 1:
    # row 0: 1, 2
    # row 1: 24, 16
    # row 2: 16, 20
    # row 3: 15, 17
    # With a penalty weight of 1e9, the expected assignment in each row is the
    # center with the smaller |inner product|, i.e. the candidate residual
    # closest to orthogonal to the primary one. Only in rows 0 and 3
    # (0-indexed) does that differ from toke1, so they are the only entries
    # whose assignment changes.
    np.testing.assert_array_equal(
        utils.soar_assign(ds, centers, toke1, 1e9, True), [0, 1, 0, 0]
    )

  def test_group_rank_data(self):
    """Data is grouped by rank value; groups are reduced with the given fn."""
    ranks = [2, 2, 2, 6, 3]
    data = [3, 3, 6, 9, 0]
    xs1, res_mean = utils.group_rank_data(ranks, data, np.mean)
    xs2, res_median = utils.group_rank_data(ranks, data, np.median)
    # Groups: rank 2 -> [3, 3, 6], rank 3 -> [0], rank 6 -> [9].
    np.testing.assert_array_equal(xs1, [2, 3, 6])
    np.testing.assert_array_equal(xs2, [2, 3, 6])
    np.testing.assert_array_equal(res_mean, [4, 0, 9])
    np.testing.assert_array_equal(res_median, [3, 0, 9])

  def test_kmr(self):
    """Checks the recall-vs-points-searched curve, with and without spilling.

    Each expected entry is the fraction of all (query, ground-truth) pairs
    found after searching the given number of datapoints, visiting centers
    in rank order.
    """
    centers = np.array([[1, 1], [2, 2], [4, -1]])
    toke1 = np.array([0, 0, 1, 2])
    qs = np.array([[1, 0], [1, 3]])
    gt = np.array([[1, 3, 2], [0, 2, 3]])

    kmr = utils.kmr(centers, toke1, None, qs, gt)
    # Center rank: [2, 1, 0], [1, 0, 2]
    # Query 1: 2, 0, 1 -> 4, 1, 2 -> [0, 1, 2, 2, 3]
    # Query 2: 1, 0, 2 -> 3, 1, 4 -> [0, 1, 1, 2, 3]
    # Total: [0, 2, 3, 4, 6] / 6
    np.testing.assert_array_almost_equal(kmr, [0, 2 / 6, 3 / 6, 4 / 6, 1])

    toke2 = np.array([1, 2, 0, 1])
    kmr2 = utils.kmr(centers, toke1, toke2, qs, gt)
    # New partition sizes: [3, 3, 2]
    # Ranks below are the toke2 center ranks; "best" takes the min over
    # toke1 and toke2.
    # New query 1 ranks: 0, 1, 2 -> best ranks 0, 0, 1 -> 2, 2, 5
    # New query 2 ranks: 0, 1, 0 -> best ranks 0, 0, 0 -> 3, 3, 3
    expected2 = [0, 0, 2 / 6, 5 / 6, 5 / 6, 1, 1, 1, 1]
    np.testing.assert_array_almost_equal(kmr2, expected2)


if __name__ == "__main__":
  # Run the tests in this module through absl's test runner.
  absltest.main()
