import random
from residual_chronos.Aggregator import LinearAggregator
import unittest
import numpy as np
import termcolor

class TestLinearAggregator(unittest.TestCase):
    """Check the coefficients ``LinearAggregator.fit`` produces on tiny inputs.

    Every test fits an aggregator on a hand-written 4-row matrix ``G`` whose
    column count equals the ``num_models`` passed to the constructor, against
    an all-ones target ``y``, and compares ``coef_`` with hard-coded expected
    values.
    """

    def setUp(self) -> None:
        """Seed both RNGs before each test so every fit starts from the same state."""
        np.random.seed(42)
        random.seed(42)

    def _assert_coef_close(self, aggregator, expected_coef, tol) -> None:
        """Fail unless ``aggregator.coef_`` is close to ``expected_coef``.

        ``tol`` is used as both ``rtol`` and ``atol`` of ``np.allclose``.
        The check goes through ``assertTrue`` instead of a bare ``assert`` so
        it is not stripped when Python runs with ``-O``, and so a failure
        reports the actual coefficients.
        """
        actual = aggregator.coef_
        self.assertTrue(
            np.allclose(actual, expected_coef, rtol=tol, atol=tol),
            msg="coef_ {!r} is not close to expected {!r} (rtol=atol={})".format(
                actual, expected_coef, tol
            ),
        )

    def test_linear_aggregator(self) -> None:
        """Column 0 equals the target and column 1 is all zeros: expect [1, 0]."""
        G = np.array([
            [1, 0],
            [1, 0],
            [1, 0],
            [1, 0],
        ])
        y = np.array([1, 1, 1, 1])
        aggregator = LinearAggregator(num_models=2)
        aggregator.fit(G, y)
        self._assert_coef_close(aggregator, np.array([1.0, 0.0]), tol=1e-3)

    def test_linear_aggregator_with_balanced_experts(self) -> None:
        """Each column matches the target on half the rows: expect [0.5, 0.5].

        Uses ``normalizer='sum'``.
        """
        G = np.array([
            [1, 0],
            [1, 0],
            [0, 1],
            [0, 1],
        ])
        y = np.array([1, 1, 1, 1])
        aggregator = LinearAggregator(num_models=2, normalizer='sum')
        aggregator.fit(G, y)
        self._assert_coef_close(aggregator, np.array([0.5, 0.5]), tol=1e-4)

    def test_linear_aggregator_with_multiple_experts(self) -> None:
        """Four columns where columns 2 and 3 duplicate columns 0 and 1.

        Uses ``normalizer='sum'``.
        NOTE(review): the expected values look like reference numbers
        captured from a seeded run rather than a closed-form result - confirm.
        """
        G = np.array([
            [1, 0, 1, 0],
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [0, 1, 0, 1],
        ])
        y = np.array([1, 1, 1, 1])
        aggregator = LinearAggregator(num_models=4, normalizer='sum')
        aggregator.fit(G, y)
        expected_coef = np.array([0.28388229, 0.28388229, 0.21611771, 0.21611771])
        self._assert_coef_close(aggregator, expected_coef, tol=1e-3)

    def test_linear_aggregator_with_multiple_experts_3(self) -> None:
        """Columns 2 and 3 are columns 0 and 1 scaled by 5.

        The expected weight is zero on the unscaled columns and roughly 0.2
        on the scaled ones (0.2 * 5 = 1, the target value).
        """
        G = np.array([
            [1, 0, 5, 0],
            [1, 0, 5, 0],
            [0, 1, 0, 5],
            [0, 1, 0, 5],
        ])
        y = np.array([1, 1, 1, 1])
        aggregator = LinearAggregator(num_models=4)
        aggregator.fit(G, y)
        expected_coef = np.array([0.0, 0.0, 0.199952, 0.199952])
        self._assert_coef_close(aggregator, expected_coef, tol=1e-3)


if __name__ == "__main__":
    # Run as a script: hand control to unittest's command-line runner.
    unittest.main()


