import numpy as np

from utils.graph_dataset import GraphDatasetGenerator

def structural_hamming_distance(target, prediction):
    """Return the number of entries where two tensors disagree.

    Args:
        target: tensor-like object supporting ``.ne`` (e.g. a ``torch.Tensor``).
        prediction: tensor of the same (or broadcast-compatible) shape.

    Returns:
        int: count of element-wise mismatches.
    """
    # `.ne` is the element-wise complement of `.eq`, i.e. the same as `~a.eq(b)`.
    return target.ne(prediction).sum().item()


def mean_shd(pairs, max_samples=500, log_every=100):
    """Average the SHD over the first ``max_samples`` (target, reference) pairs.

    Args:
        pairs: iterable of ``(target_item, reference_item)``. Only element ``[1]``
            of each item is compared (presumably the graph structure produced by
            the dataset -- confirm against ``GraphDatasetGenerator``).
        max_samples: maximum number of pairs to evaluate. The original loop used
            ``if i > 500: break``, which actually evaluated 501 pairs.
        log_every: print the running index every ``log_every`` pairs; a falsy
            value disables progress printing.

    Returns:
        The mean SHD as computed by ``np.mean``.

    Raises:
        ValueError: if ``pairs`` yields no items (``np.mean([])`` would
            otherwise silently return ``nan``).
    """
    shds = []
    for i, (target, reference) in enumerate(pairs):
        if i >= max_samples:
            break
        if log_every and i % log_every == 0:
            print(i)
        shds.append(structural_hamming_distance(target[1], reference[1]))
    if not shds:
        raise ValueError('mean_shd: no (target, reference) pairs to evaluate')
    return np.mean(shds)


if __name__ == '__main__':
    test_set = GraphDatasetGenerator(test=True)
    # NOTE(review): the "random" baseline is built with the same arguments as the
    # test set; this is only a meaningful baseline if GraphDatasetGenerator draws
    # fresh random graphs per instance (not seeded identically) -- confirm.
    dummy_set = GraphDatasetGenerator(test=True)
    print(len(test_set), len(dummy_set))
    # zip() stops at the shorter dataset; the lengths printed above show whether
    # they match.
    print('Random Baseline SHD: ', mean_shd(zip(test_set, dummy_set)))
