from residual_chronos.Aggregator import SingleBestAggregator
import unittest
import numpy as np

class TestSingleBestAggregator(unittest.TestCase):
    """Unit tests for ``SingleBestAggregator``.

    Each test fits the aggregator on a small matrix ``G`` whose columns are
    the predictions of the individual models (``num_models`` equals the
    number of columns) against a target vector ``y``. It then checks that
    ``coef_`` is a one-hot vector selecting the expected model, and that
    ``predict`` reproduces that model's column.
    """

    def setUp(self):
        # Shared relative/absolute tolerance for every float comparison below.
        self.tolerance = 1e-4

    def _assert_close(self, actual, expected):
        """Assert that ``actual`` matches ``expected`` within ``self.tolerance``.

        ``np.testing.assert_allclose`` is used instead of
        ``assertTrue(np.allclose(...))`` for two reasons. On failure it
        reports the mismatching elements, rather than just
        "False is not true". It also rejects shape mismatches, which
        ``np.allclose`` would silently broadcast away.
        """
        np.testing.assert_allclose(
            actual, expected, rtol=self.tolerance, atol=self.tolerance
        )

    def test_single_best_aggregator(self):
        """Model 0 reproduces ``y`` exactly, so it should be selected.

        Column 0 has 0 mismatches against ``y``; column 1 has 6.
        """
        G = np.array([
            [1, 0],
            [0, 1],
            [1, 1],
            [0, 0],
            [1, 0],
            [0, 1],
            [1, 1],
            [0, 0],
            [1, 0],
            [0, 1]])
        y = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
        expected_coef = np.array([1.0, 0.0])
        aggregator = SingleBestAggregator(num_models=2)
        aggregator.fit(G, y)
        self._assert_close(aggregator.coef_, expected_coef)
        self._assert_close(aggregator.predict(G), y)

    def test_single_best_aggregator_with_multiple_experts(self):
        """Model 0 is imperfect but still closer to ``y`` than model 1.

        Column 0 mismatches ``y`` on 2 samples (rows 5 and 6); column 1
        mismatches on 5. The prediction should therefore equal column 0,
        not ``y``.
        """
        G = np.array([
            [1, 0],
            [1, 0],
            [1, 0],
            [1, 0],
            [1, 0],
            [1, 0],
            [1, 0],
            [0, 0],
            [0, 0],
            [0, 0]])
        y = np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
        aggregator = SingleBestAggregator(num_models=2)
        aggregator.fit(G, y)
        expected_coef = np.array([1.0, 0.0])
        self._assert_close(aggregator.coef_, expected_coef)
        self._assert_close(aggregator.predict(G), G[:, 0])

    def test_single_best_aggregator_with_multiple_experts_2(self):
        """Selection among six models when two of them tie for best.

        Columns 0 and 3 both equal ``y`` exactly (all ones), so they tie.
        ``expected_coef`` encodes the assumption that the aggregator breaks
        ties in favour of the lowest model index.
        NOTE(review): this assumes first-index tie-breaking (e.g.
        ``np.argmin`` semantics) in ``SingleBestAggregator``; confirm
        against the implementation. The ``predict`` check holds for either
        tied model, because columns 0 and 3 are identical.
        """
        G = np.array([
            [1, 0, 0, 1, 2, 3],
            [1, 0, 0, 1, 2, 3],
            [1, 0, 0, 1, 2, 3],
            [1, 0, 0, 1, 2, 3]])
        y = np.array([1, 1, 1, 1])
        aggregator = SingleBestAggregator(num_models=6)
        aggregator.fit(G, y)
        expected_coef = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self._assert_close(aggregator.coef_, expected_coef)
        self._assert_close(aggregator.predict(G), G[:, 0])

if __name__ == "__main__":
    # Run the TestCase classes in this module when the file is executed
    # directly as a script; test discovery does not go through this branch.
    unittest.main()


