import unittest

import torch
from toponetx.classes import SimplicialComplex as ToponetxSimplicialComplex

from scrawl.simplicial import SimplicialData
from scrawl.transformers import toponetx_to_data, toponetx_to_sc
from scrawl.walker import RandomWalk, Walker


class RandomWalkTest(unittest.TestCase):
    """Tests for ``RandomWalk.feature_matrix`` plus a ``Walker`` smoke test."""

    def setUp(self) -> None:
        """Build a small complex, attach per-rank features and a fixed rank-1 walk."""
        complex_ = ToponetxSimplicialComplex()

        # Simplices are added in exactly this order; the expected tables in the
        # tests address edges by position (e.g. walk index 5 -> row 5 of data[1]).
        for vertex in range(7):
            complex_.add_simplex([vertex])
        edges = (
            [0, 1], [0, 2], [1, 2], [1, 3], [1, 4],
            [2, 3], [3, 4], [4, 5], [4, 6], [5, 6],
        )
        triangles = ([0, 1, 2], [1, 2, 3], [4, 5, 6])
        for simplex in (*edges, *triangles):
            complex_.add_simplex(simplex)

        self.data = toponetx_to_data(
            complex_, None, torch.float, torch.device("cpu")
        )

        # Hand-written features per rank: 7 nodes x 3, 10 edges x 2, 3 triangles x 3.
        per_rank_features = (
            [
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 1],
                [0, 1, 0],
                [0, 0, 0],
                [0, 0, 0],
            ],
            [
                [0, 0],
                [0, 0],
                [0, 1],
                [0, 0],
                [0, 0],
                [4, 1],
                [1, 1],
                [0, 0],
                [2, 2],
                [0, 0],
            ],
            [[0, 0, 0], [0, 1, 1], [1, 0, 0]],
        )
        for rank, rows in enumerate(per_rank_features):
            self.data[rank] = torch.tensor(rows, dtype=torch.float)

        # A fixed walk over edges (rank 1) so the feature matrix is deterministic.
        self.random_walk_indices = torch.tensor([2, 5, 6, 7, 8, 7])
        self.random_walk = RandomWalk(
            toponetx_to_sc(complex_),
            rank=1,
            walk_simplices=self.random_walk_indices,
            connection_direction=torch.tensor([0, 1, -1, -1, 1, 1]),
            connection_simplices=torch.tensor([-1, 1, 3, 4, 2, 2]),
            window_size=4,
        )

    def test_walk_features(self) -> None:
        """Each column group of the feature matrix matches its hand-computed table."""
        _, features = self.random_walk.feature_matrix(
            self.data, local_window_size=4, lower_feature_size=3, upper_feature_size=3
        )

        # (column range, expected rows), checked in order; the first mismatch fails.
        column_checks = (
            # walked edges' own features (rows of data[1] at the walk indices)
            (slice(0, 2), [[0, 1], [4, 1], [1, 1], [0, 0], [2, 2], [0, 0]]),
            # lower connection features
            (
                slice(2, 5),
                [[0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 0], [0, 0, 0]],
            ),
            # upper connection features
            (
                slice(5, 8),
                [[0, 0, 0], [0, 1, 1], [0, 0, 0], [0, 0, 0], [1, 0, 0], [1, 0, 0]],
            ),
            # identity encoding over the local window
            (
                slice(8, 12),
                [
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 1, 0],
                ],
            ),
            # lower connectivity encoding
            (
                slice(12, 15),
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]],
            ),
            # upper connectivity encoding
            (
                slice(15, 18),
                [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]],
            ),
        )
        for columns, rows in column_checks:
            torch.testing.assert_close(
                features[:, columns], torch.tensor(rows), check_dtype=False
            )

    def test_walk_features_without_cofaces(self) -> None:
        """With empty data, only the structural encoding columns are populated."""
        bare_data = SimplicialData(self.data.domain, torch.float, torch.device("cpu"))

        # 6 walk steps x 10 columns, all zero except a handful of ones.
        expected = torch.zeros((6, 10), dtype=torch.long)
        expected[4, 6] = 1
        expected[5, [2, 5, 6, 9]] = 1

        walk_indices, features = self.random_walk.feature_matrix(
            bare_data,
            local_window_size=4,
            lower_feature_size=0,
            upper_feature_size=0,
        )

        torch.testing.assert_close(
            walk_indices, self.random_walk_indices, check_dtype=False
        )
        torch.testing.assert_close(features, expected, check_dtype=False)

    def test_walk_no_errors(self) -> None:
        """Walks of every rank up to 2 run on a small irregular complex without raising."""
        complex_ = ToponetxSimplicialComplex()
        for simplex in ([1, 2, 3], [2, 3, 4], [2, 5]):
            complex_.add_simplex(simplex)

        try:
            walker = Walker(
                toponetx_to_sc(complex_),
                use_lower_connections=True,
                use_upper_connections=True,
                max_rank=2,
            )
            for rank in range(3):
                walker.random_walk(rank, 0, 5)
        except Exception as error:
            self.fail(f"Walker raised an exception: {error}")


class WalkerTest(unittest.TestCase):
    """Tests for the batched walk generation in ``Walker``."""

    def setUp(self) -> None:
        """Build a 7-node complex with 10 edges and 3 triangles."""
        simplicial_complex = ToponetxSimplicialComplex()

        # nodes
        for i in range(7):
            simplicial_complex.add_simplex([i])

        # edges
        simplicial_complex.add_simplex([0, 1])
        simplicial_complex.add_simplex([0, 2])
        simplicial_complex.add_simplex([1, 2])
        simplicial_complex.add_simplex([1, 3])
        simplicial_complex.add_simplex([1, 4])
        simplicial_complex.add_simplex([2, 3])
        simplicial_complex.add_simplex([3, 4])
        simplicial_complex.add_simplex([4, 5])
        simplicial_complex.add_simplex([4, 6])
        simplicial_complex.add_simplex([5, 6])

        # triangles
        simplicial_complex.add_simplex([0, 1, 2])
        simplicial_complex.add_simplex([1, 2, 3])
        simplicial_complex.add_simplex([4, 5, 6])

        self.simplicial_complex = toponetx_to_sc(simplicial_complex)

    def test_parallel_walks(self) -> None:
        """Rank-1 walks driven by fixed random values visit the expected edges."""
        start_indices = torch.tensor([0, 2])
        # Pre-drawn random values (one row per start index) make the walks
        # deterministic, so the visited simplices can be asserted exactly.
        random_values = torch.tensor([[1, 2, 0, 0, 2, 1], [3, 2, 0, 0, 0, 2]])

        walker = Walker(
            self.simplicial_complex,
            use_lower_connections=True,
            use_upper_connections=True,
            max_rank=2,
        )
        # Only the visited simplices are checked; the connection direction and
        # connection simplex outputs are not asserted by this test.
        walk_simplices, _, _ = walker._parallel_walks(
            1, start_indices, 3, random_values
        )

        # ``torch.testing.assert_allclose`` is deprecated; ``assert_close`` with
        # ``check_dtype=False`` keeps its dtype-lenient comparison.
        torch.testing.assert_close(
            walk_simplices,
            torch.tensor([[0, 3, 0, 1], [2, 5, 1, 0]]),
            check_dtype=False,
        )


if __name__ == "__main__":
    # Allow running this module directly; unittest discovers the TestCases above.
    unittest.main()
