import unittest, torch, os

import dolphin.modules as sm
from dolphin import Distribution
from dolphin.provenances import get_provenance

class AggregationTest(unittest.TestCase):
    """Symbol-level checks for the aggregation helpers in ``dolphin.modules``.

    Every test combines the same two distributions, ``d1`` over symbols
    ``[1, 2, 3]`` and ``d2`` over ``[4, 5]``, and asserts only on the
    resulting ``symbols`` (probabilities are not checked).
    """

    def setUp(self) -> None:
        # Distribution.provenance is class-level state shared by every test in
        # the process. Record what was stored directly on the class and restore
        # it afterwards. addCleanup (unlike tearDown) also runs if setUp fails
        # after this point. vars() is used rather than getattr so a value
        # inherited from a base class is not copied onto Distribution.
        class_attrs = vars(Distribution)
        if "provenance" in class_attrs:
            self.addCleanup(
                setattr, Distribution, "provenance", class_attrs["provenance"]
            )
        else:
            self.addCleanup(delattr, Distribution, "provenance")
        Distribution.provenance = get_provenance("damp")

        # Shared fixtures, built after the provenance is configured (same order
        # as when each test constructed them itself).
        self.d1 = Distribution(torch.tensor([0.1, 0.2, 0.3]), [1, 2, 3])
        self.d2 = Distribution(torch.tensor([0.4, 0.5]), [4, 5])

    def test_collate(self):
        """Expect every (d1, d2) symbol pair, d1-major order."""
        collated = sm.collate(self.d1, self.d2)

        self.assertEqual(
            list(collated.symbols),
            [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]],
        )

    def test_sum(self):
        """Expect the distinct pairwise sums (e.g. 1+5 and 2+4 both give 6)."""
        summed = sm.sum(self.d1, self.d2)

        self.assertEqual(list(summed.symbols), [5, 6, 7, 8])

    def test_product(self):
        """Expect every pairwise product; all six are distinct here."""
        product = sm.prod(self.d1, self.d2)

        self.assertEqual(list(product.symbols), [4, 5, 8, 10, 12, 15])

    def test_count(self):
        """Expect the possible numbers of even operands per pair: 0, 1 or 2."""
        counted = sm.count(lambda x: x % 2 == 0, self.d1, self.d2)

        self.assertEqual(list(counted.symbols), [0, 1, 2])

    def test_max(self):
        """Every d2 symbol exceeds every d1 symbol, so only d2's can win."""
        maxed = sm.max(self.d1, self.d2)

        self.assertEqual(list(maxed.symbols), [4, 5])

    def test_min(self):
        """Every d1 symbol is below every d2 symbol, so only d1's can win."""
        mined = sm.min(self.d1, self.d2)

        self.assertEqual(list(mined.symbols), [1, 2, 3])

    def test_mean(self):
        """Expect the distinct pairwise means, e.g. (1+4)/2 = 2.5."""
        meaned = sm.mean(self.d1, self.d2)

        self.assertEqual(list(meaned.symbols), [2.5, 3, 3.5, 4])

if "__main__" == __name__:
    # Run the suite directly; buffer=True captures stdout/stderr during each
    # test and only echoes it for tests that fail or error.
    unittest.main(buffer=True)