"""Functions to build the synthetic two-community graph dataset."""

import pickle as pkl

import networkx as nx
import numpy as np

from pathlib import Path

from build_datasets.synthetic.synthetic_utils import bfs_node_order


##################################################
# Adapted from GraphRNN's released code          #
# https://github.com/JiaxuanYou/graph-generation #
##################################################

_comm_size_min = 30
_comm_size_max = 80
_num_each = 10
_p_edge = 0.3


def caveman_special(c=2, k=20, p_path=0.1, p_edge=0.3, rng=None):
    """Build a sparse community graph by thinning a caveman graph.

    Starts from ``nx.caveman_graph(c, k)``: ``c`` disjoint cliques of ``k``
    nodes, where clique ``i`` holds nodes ``i*k .. (i+1)*k - 1``. Each clique
    edge is kept independently with probability ``p_edge``. Then
    ``max(ceil(p_path * k), 1)`` random links are drawn between community 0
    (nodes ``0..k-1``) and community 1 (nodes ``k..2k-1``). Duplicate draws
    collapse, so at most that many cross links are added.

    No nodes are removed. Unlike GraphRNN's original, the largest connected
    component is not extracted, so the result may contain isolated nodes or
    be disconnected.

    Args:
        c: Number of communities (cliques). Must be at least 2. Only
            communities 0 and 1 receive cross-community links.
        k: Number of nodes per community. Must be at least 1.
        p_path: Scales the number of cross-community link draws.
        p_edge: Probability of keeping each intra-community edge.
        rng: Optional random source exposing ``rand()`` and
            ``randint(low, high)``, e.g. ``np.random.RandomState``. Defaults
            to the global ``np.random`` state.

    Returns:
        An ``nx.Graph`` with ``c * k`` nodes labelled ``0 .. c*k - 1``.

    Raises:
        ValueError: If ``c < 2`` or ``k < 1``.
    """
    # Cross links target nodes k..2k-1. With c < 2 those nodes do not exist
    # and add_edge would silently create them.
    if c < 2:
        raise ValueError(f"c must be >= 2, got {c}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if rng is None:
        # Module-level functions draw from numpy's global RandomState.
        rng = np.random

    n_links = max(int(np.ceil(p_path * k)), 1)
    G = nx.caveman_graph(c, k)

    # caveman_graph only contains intra-clique edges, so every edge is a
    # removal candidate. Iterate over a snapshot because G is mutated.
    p_remove = 1 - p_edge
    for u, v in list(G.edges()):
        if rng.rand() < p_remove:
            G.remove_edge(u, v)

    # Connect community 0 to community 1.
    for _ in range(n_links):
        u = rng.randint(0, k)
        v = rng.randint(k, 2 * k)
        G.add_edge(u, v)
    return G


def make_two_community_dataset(out_dir, fname):
    """Generate the two-community graph dataset and pickle it to disk.

    For every community size ``k`` in ``[_comm_size_min, _comm_size_max]``
    (inclusive), this draws ``_num_each`` graphs with
    ``caveman_special(c=2, k=k, p_edge=_p_edge)``. Each graph has ``2 * k``
    nodes before ordering, and is passed through
    ``bfs_node_order(graph, start_node=0)``. With the module defaults this
    yields (80 - 30 + 1) * 10 = 510 graphs.

    NOTE(review): caveman_special can leave nodes unreachable from node 0.
    Confirm how bfs_node_order handles disconnected graphs.

    Args:
        out_dir: Directory to write into. It must already exist, because
            ``open`` does not create parent directories.
        fname: Output file name inside ``out_dir``.

    Returns:
        The ``pathlib.Path`` of the written pickle file (a list of graphs).
    """
    graphs = []
    for k in range(_comm_size_min, _comm_size_max + 1):
        for _ in range(_num_each):
            graph = caveman_special(c=2, k=k, p_edge=_p_edge)
            graphs.append(bfs_node_order(graph, start_node=0))

    out_path = Path(out_dir, fname)
    with open(out_path, "wb") as f:
        pkl.dump(graphs, f)
    return out_path


if __name__ == "__main__":
    # Write two_community.pkl into the current working directory.
    out_dir, fname = ".", "two_community.pkl"
    make_two_community_dataset(out_dir, fname)
