import unittest

import torch

from data_utils import row_normalize, col_normalize, get_adj_row_sum, get_adj_col_sum


class TestDataUtils(unittest.TestCase):
    """Tests for the adjacency sum/normalization helpers in ``data_utils``.

    Every test uses the same small weighted graph on 4 nodes: each of the
    pairs 0-2, 0-3, 1-2, 1-3, 2-3 appears in both directions, and the 10
    resulting edges carry the distinct weights 1..10. Because the two
    directions of a pair have different weights, the row sums and column
    sums differ, so swapping rows and columns in an implementation is caught.
    """

    def setUp(self):
        # Build fresh tensors for every test instead of sharing class-level
        # tensors, so a function under test that modifies its inputs in place
        # cannot affect the other tests.
        self.n_nodes = 4
        # Written as one [row, col] pair per edge, then transposed to shape
        # (2, n_edges): edge_index[0] holds row indices and edge_index[1]
        # holds column indices.
        self.edge_index = torch.tensor(
            [
                [0, 2],
                [2, 0],
                [0, 3],
                [3, 0],
                [1, 2],
                [2, 1],
                [1, 3],
                [3, 1],
                [2, 3],
                [3, 2],
            ],
            dtype=torch.long,
        ).T
        self.edge_weight = torch.tensor(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float32
        )
        # Hand-computed per-node sums of edge_weight, grouped by row index
        # (edge_index[0]) and by column index (edge_index[1]) respectively.
        self.expected_adj_row_sum = torch.tensor(
            [4, 12, 17, 22], dtype=torch.float32
        )
        self.expected_adj_col_sum = torch.tensor(
            [6, 14, 16, 19], dtype=torch.float32
        )

    def _assert_allclose(self, expected, actual):
        """Assert ``torch.allclose(expected, actual)``, showing both on failure."""
        self.assertTrue(
            torch.allclose(expected, actual),
            msg=f"expected {expected}, got {actual}",
        )

    def test_row_normalize(self):
        row = self.edge_index[0]
        # Each edge weight divided by the total weight of its row. Computed
        # before the call under test so it is independent of that call.
        expected = self.edge_weight / self.expected_adj_row_sum[row]
        actual = row_normalize(
            edge_index=self.edge_index,
            edge_weight=self.edge_weight,
            n_nodes=self.n_nodes,
        )
        self._assert_allclose(expected, actual)

    def test_col_normalize(self):
        col = self.edge_index[1]
        # Each edge weight divided by the total weight of its column. Computed
        # before the call under test so it is independent of that call.
        expected = self.edge_weight / self.expected_adj_col_sum[col]
        actual = col_normalize(
            edge_index=self.edge_index,
            edge_weight=self.edge_weight,
            n_nodes=self.n_nodes,
        )
        self._assert_allclose(expected, actual)

    def test_get_adj_row_sum(self):
        actual = get_adj_row_sum(
            edge_index=self.edge_index,
            edge_weight=self.edge_weight,
            n_nodes=self.n_nodes,
        )
        self._assert_allclose(self.expected_adj_row_sum, actual)

    def test_get_adj_col_sum(self):
        actual = get_adj_col_sum(
            edge_index=self.edge_index,
            edge_weight=self.edge_weight,
            n_nodes=self.n_nodes,
        )
        self._assert_allclose(self.expected_adj_col_sum, actual)


if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()
