# Copyright (c) 2022 Tianyu Wen
# Licensed under the MIT License.

from Test_class import *
import tqdm


def main(num_nodes=6, num_iterations=10, output_dir='./output'):
    """Find pairs of graphs that ``WL_test.iso_test2`` reports as isomorphic.

    Enumerates the graphs returned by ``get_all_nonisomorphic(num_nodes)``.
    It checks every unordered pair that has the same number of edges with
    ``iso_test2`` and uniform initial node labels. Each pair for which the test
    returns a truthy value is plotted (first graph red, second blue) and saved
    as a PNG under ``output_dir``. Since the input graphs are presumably
    pairwise non-isomorphic, such pairs are presumably ones the WL test fails
    to distinguish — confirm against ``get_all_nonisomorphic``.

    Args:
        num_nodes: Number of nodes of the enumerated graphs.
        num_iterations: Value passed as the last argument of ``iso_test2``
            (presumably the number of WL refinement rounds; TODO confirm).
        output_dir: Directory for the saved figures; created if missing.

    Returns:
        The number of pairs found (also printed).
    """
    import itertools
    import os

    graphs = get_all_nonisomorphic(num_nodes)
    num_graphs = graphs.shape[0]
    # Integer arithmetic; n*(n-1) is always even, so no rounding is needed.
    num_pairs = num_graphs * (num_graphs - 1) // 2
    print('Number of non-isomorphic graphs={}, number of unique pairs={}.'.format(num_graphs, num_pairs))

    # plt.savefig raises FileNotFoundError if the target directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    found_pairs = 0
    # The context manager closes the progress bar even if a test or plot raises.
    with tqdm.tqdm(total=num_pairs) as pbar:
        for i, j in itertools.combinations(range(num_graphs), 2):
            pbar.update(1)

            A1 = graphs[i]
            A2 = graphs[j]

            E1 = get_edges(A1)
            E2 = get_edges(A2)

            # Graphs with different edge counts are never WL-equivalent; skip them.
            if len(E1) != len(E2):
                continue

            # Uniform initial label for every node. Both L and the tester are
            # rebuilt per pair, as in the original, because iso_test2's internals
            # are not visible here. TODO: confirm it neither mutates L nor keeps
            # state on the instance, then hoist both out of the loop.
            L = np.ones(num_nodes)
            my_WL_test = WL_test()

            if my_WL_test.iso_test2(A1, A2, L, L, num_iterations):
                # tqdm's write prints above the bar instead of breaking it.
                pbar.write('Found.')
                found_pairs += 1
                G1 = get_nx_graph(num_nodes, E1)
                G2 = get_nx_graph(num_nodes, E2)

                plt.clf()
                plt.subplot(2, 1, 1)
                nx.draw(G1, node_color='red')
                plt.subplot(2, 1, 2)
                nx.draw(G2, node_color='blue')
                file_name = 'Graph_no_{}_of_{}_nodes_with_WL.png'.format(found_pairs, num_nodes)
                plt.savefig(os.path.join(output_dir, file_name), format='PNG')

    print('Found {} unique pairs.'.format(found_pairs))
    return found_pairs


# Run the exhaustive search only when executed as a script, not on import.
if __name__ == '__main__':
    main()
