import unittest

import torch
from toponetx.classes import SimplicialComplex

from scrawl.transformers import toponetx_to_data


class TestSimplicialData(unittest.TestCase):
    """Tests for `toponetx_to_data` applied to a TopoNetX simplicial complex."""

    def test_from_simplex(self) -> None:
        """Check the "citations" values extracted from a small complex.

        Four vertices and one edge are added, each carrying a ``citations``
        attribute. Vertex 3 is inserted before vertex 2, while the expected
        rows of ``data[0]`` are listed for vertices 0, 1, 2, 3. ``data[1]``
        is compared against the single edge value. Dtypes are not compared,
        only values.
        """
        # (simplex, citations) pairs, in the order they are inserted.
        simplices_with_citations = (
            ([0], 10),
            ([1], 20),
            ([3], 30),
            ([2], 40),
            ([0, 1], 5),
        )

        simplicial_complex = SimplicialComplex()
        for simplex, citation_count in simplices_with_citations:
            simplicial_complex.add_simplex(simplex, citations=citation_count)

        data = toponetx_to_data(simplicial_complex, "citations", dtype=torch.int)

        expected_vertex_values = torch.tensor([[10], [20], [40], [30]])
        expected_edge_values = torch.tensor([[5]])

        torch.testing.assert_close(data[0], expected_vertex_values, check_dtype=False)
        torch.testing.assert_close(data[1], expected_edge_values, check_dtype=False)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
