import unittest

from weight_function import WeightFunction

import numpy as np

class TestWeightFunction(unittest.TestCase):
    """Unit tests for ``WeightFunction``.

    The tests compare the incrementally maintained ``running_norm`` with the
    from-scratch ``calculate_rkhs_norm()``. They also exercise the operators
    (``+=``, ``-``) and the center-maintenance helpers
    (``remove_useless_centers``, ``merge_duplicate_centers``, ``project``).
    """

    # Absolute slack for comparisons against norms that are produced by
    # floating-point rescaling (``project``), where the recomputed norm can
    # exceed the target radius by a few ULPs.
    _TOL = 1e-9

    def test_rkhs_norm(self):
        """``running_norm`` agrees with ``calculate_rkhs_norm()``."""
        a = WeightFunction()
        self.assertEqual(0, a.running_norm)
        self.assertEqual(0, a.calculate_rkhs_norm())
        a.add_center([1, 2, 3], 1)
        self.assertEqual(1, a.running_norm)
        self.assertEqual(1, a.calculate_rkhs_norm())
        a.add_center([2, 2, -1], 1)
        self.assertAlmostEqual(a.running_norm, a.calculate_rkhs_norm())

    def test_iadd(self):
        """Adding a function to itself in place doubles its norm."""
        a = WeightFunction()
        a.add_center([1, 2, 3], 1)
        self.assertEqual(1, a.running_norm)
        a += a
        self.assertEqual(2, a.running_norm)
        self.assertAlmostEqual(2, a.norm())
        a += a
        self.assertEqual(4, a.running_norm)
        self.assertAlmostEqual(4, a.norm())

    def test_remove_useless_coefs_empty(self):
        """Pruning an empty function is a no-op."""
        a = WeightFunction()
        a.remove_useless_centers()
        self.assertEqual(0, a.get_n_centers())

    def test_remove_useless_coefs_one_dim(self):
        """A single zero-coefficient 1-D center is pruned."""
        a = WeightFunction().add_center([1], 0)
        a.remove_useless_centers()
        self.assertEqual(0, a.get_n_centers())

    def test_remove_useless_coefs_multiple_dim(self):
        """Zero-coefficient centers are pruned and non-zero ones are kept."""
        a = WeightFunction().add_center([1, 1], 0)
        a.remove_useless_centers()
        self.assertEqual(0, a.get_n_centers())
        # All centers of one function must share a dimension; otherwise the
        # kernel between them (needed for the running norm) is ill-defined.
        a = WeightFunction().add_center([1, 1, 1], 0)
        a.add_center([1, 2, 3], 1)
        a.add_center([2, 2, -1], 1)
        a.remove_useless_centers()
        self.assertEqual(2, a.get_n_centers())

    def test_difference_has_zero_norm(self):
        """``a - a`` is (numerically) the zero function."""
        a = WeightFunction()
        a.add_center([1, 2, 3], 1)
        a.add_center([2, 2, -1], 1)
        # Cancellation happens in floating point, so allow round-off.
        self.assertAlmostEqual(0, (a - a).norm())

    def test_project(self):
        """``project(r)`` leaves a function with norm at most ``r``."""
        a = WeightFunction()
        a.add_center([1, 2, 3], 1)
        a.project(2)
        self.assertLessEqual(a.norm(), 2 + self._TOL)
        a.project(0.1)
        self.assertLessEqual(a.norm(), 0.1 + self._TOL)

    def test_merge_duplicate_centers(self):
        """Merging identical centers reduces the count but keeps the norm."""
        a = WeightFunction()
        a.merge_duplicate_centers()
        self.assertEqual(0, a.get_n_centers())
        self.assertEqual(0, a.norm())
        a.add_center([1, 2, 3], 1)
        self.assertEqual(1, a.get_n_centers())
        self.assertEqual(1, a.running_norm)
        self.assertEqual(1, a.calculate_rkhs_norm())
        a.add_center([1, 2, 3], 1)
        self.assertEqual(2, a.get_n_centers())
        self.assertEqual(2, a.running_norm)
        self.assertEqual(2, a.calculate_rkhs_norm())
        a.merge_duplicate_centers()
        self.assertEqual(1, a.get_n_centers())
        self.assertEqual(2, a.running_norm)
        self.assertEqual(2, a.calculate_rkhs_norm())

    def test_scalar_product_with_self(self):
        """``<a, a>`` stays positive for random non-zero functions.

        Dimensions 1..9 are used, with 100 random functions of 10 centers
        each. The check runs after every added center.
        """
        rng = np.random.default_rng(0)
        for size in range(1, 10):
            for _ in range(100):
                a = WeightFunction()
                for _ in range(10):
                    param = rng.normal(size=size).tolist()
                    # Draw a scalar directly: float() on a 1-element ndarray
                    # is deprecated since NumPy 1.25.
                    coef = float(rng.normal())
                    a.add_center(param, coef)
                    self.assertGreater(a.scalar_product(a), 0)

if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()