import unittest

import torch
from toponetx.classes import SimplicialComplex as ToponetxSimplicialComplex

from scrawl.datasets import collate_data
from scrawl.transformers import toponetx_to_data


class TestDatasetFunctions(unittest.TestCase):
    """Tests for dataset helpers in ``scrawl.datasets``."""

    def test_collate_fn(self):
        """Collate two simplicial complexes and check every rank of the result.

        ``sc_1`` is a path graph (dimension 1); ``sc_2`` additionally contains
        one triangle (dimension 2). The checks cover boundary matrices,
        per-rank features and auxiliary tensors of the collated batch. In all
        expected values, the cells of ``sc_1`` come first, followed by those
        of ``sc_2``.
        """
        sc_1 = ToponetxSimplicialComplex()
        sc_1.add_simplex((0,), citations=15)
        sc_1.add_simplex((1,), citations=12)
        sc_1.add_simplex((2,), citations=40)
        sc_1.add_simplex((3,), citations=43)
        sc_1.add_simplex((0, 1), citations=10)
        sc_1.add_simplex((1, 2), citations=20)
        sc_1.add_simplex((2, 3), citations=30)

        sc_2 = ToponetxSimplicialComplex()
        sc_2.add_simplex((0,), citations=65)
        sc_2.add_simplex((1,), citations=120)
        sc_2.add_simplex((2,), citations=40)
        sc_2.add_simplex((3,), citations=67)
        sc_2.add_simplex((0, 1), citations=100)
        sc_2.add_simplex((0, 2), citations=200)
        sc_2.add_simplex((1, 2), citations=200)
        sc_2.add_simplex((1, 3), citations=50)
        sc_2.add_simplex((0, 1, 2), citations=40)

        data_1 = toponetx_to_data(sc_1, ["citations"], dtype=torch.int)
        data_2 = toponetx_to_data(sc_2, ["citations"], dtype=torch.int)

        # Rank -1 carries one value per complex; the expected collated value
        # below is simply these two values concatenated in batch order.
        data_1.set_aux_tensor(-1, torch.tensor([1]))
        data_2.set_aux_tensor(-1, torch.tensor([0]))

        collated_data = collate_data([data_1, data_2])

        # Boundary matrices: block-diagonal, with sc_1's block (rows 0-3,
        # cols 0-2 for rank 1) followed by sc_2's block (rows 4-7, cols 3-6).
        # sc_1 has no triangles, so at rank 2 it contributes rows but no
        # columns. Entries are unsigned (0/1).
        expected_boundaries = {
            1: torch.tensor(
                [
                    [1, 0, 0, 0, 0, 0, 0],
                    [1, 1, 0, 0, 0, 0, 0],
                    [0, 1, 1, 0, 0, 0, 0],
                    [0, 0, 1, 0, 0, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0],
                    [0, 0, 0, 1, 0, 1, 1],
                    [0, 0, 0, 0, 1, 1, 0],
                    [0, 0, 0, 0, 0, 0, 1],
                ],
                dtype=torch.float32,
            ),
            2: torch.tensor([[0], [0], [0], [1], [1], [1], [0]], dtype=torch.float32),
        }

        # Iterate over the dict itself so every expectation written is checked.
        for rank, expected in expected_boundaries.items():
            with self.subTest(boundary_rank=rank):
                torch.testing.assert_close(collated_data.domain.boundary[rank], expected)

        # Features: the "citations" attribute of each cell, one column per
        # feature, concatenated across complexes in batch order.
        expected_features = {
            0: torch.tensor(
                [15, 12, 40, 43, 65, 120, 40, 67], dtype=torch.int32
            ).unsqueeze(1),
            1: torch.tensor(
                [10, 20, 30, 100, 200, 200, 50], dtype=torch.int32
            ).unsqueeze(1),
            2: torch.tensor([40], dtype=torch.int32).unsqueeze(1),
        }

        # Previously `range(0, 2)` silently skipped rank 2; iterate the keys
        # so the triangle feature (only present in sc_2) is verified too.
        for rank, expected in expected_features.items():
            with self.subTest(feature_rank=rank):
                torch.testing.assert_close(collated_data[rank], expected)

        # Aux features: rank -1 holds the per-complex values set above; for
        # ranks >= 0 the expected values are the batch index (0 for sc_1,
        # 1 for sc_2) of each cell.
        expected_aux_features = {
            -1: torch.tensor([1, 0]),
            0: torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]),
            1: torch.tensor([0, 0, 0, 1, 1, 1, 1]),
            2: torch.tensor([1]),
        }

        for rank, expected in expected_aux_features.items():
            with self.subTest(aux_rank=rank):
                torch.testing.assert_close(collated_data.aux_tensor(rank), expected)


if __name__ == "__main__":
    # Run this module's tests when executed directly as a script.
    unittest.main(module="__main__")
