import pandas as pd
import os
import argparse
import random
import numpy as np
from seed import set_seed

# Seed first so the random draws below are repeatable.
# NOTE(review): set_seed comes from the project-local `seed` module; it is
# presumably seeding numpy's global RNG -- confirm there.
set_seed()

# Command-line interface. Each option takes exactly one value; the former
# nargs='?' allowed a bare flag (e.g. `--outdir` with no value), which
# silently produced None and crashed later with a TypeError.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_numbers", type=int, default=1000, help="number of motifs"
)
parser.add_argument(
    "--label_numbers", type=int, default=3, help="number of labels"
)
parser.add_argument("--outdir", default='./motifs', help="output dir")
args = parser.parse_args()

outdir = args.outdir
motif_numbers = args.motif_numbers
label_numbers = args.label_numbers

# Each motif's label is computed as `index % label_numbers` further down, so
# zero would raise ZeroDivisionError and a negative value would yield
# non-positive labels. Reject both up front with a clear usage error.
if label_numbers < 1:
    parser.error("--label_numbers must be at least 1")

# makedirs creates missing parent directories and, with exist_ok=True, does
# not fail when the directory already exists (no check-then-create race).
os.makedirs(outdir, exist_ok=True)

# 1 as director, 2 as actor, 3 as keyword for IMDB

# Pool of node types the extra nodes of a motif are sampled from. Repeated
# entries weight the draw: type 3 is the most likely, type 1 the least.
node_type_pool = [1, 2, 2, 2, 3, 3, 3, 3, 3]

# Write one nodes/edges/labels CSV triple per motif into `outdir`.
for motif_id in range(motif_numbers):
    # Number of extra nodes: a draw from N(3, 1) truncated to an int.
    # TODO: n_nodes_number should be drawn from a Poisson distribution.
    n_nodes_number = int(np.random.normal(3, 1, 1)[0])
    # Clamp to [2, len(pool)]. The sampling below is without replacement, so
    # requesting more items than the pool holds would raise ValueError.
    n_nodes_number = min(max(n_nodes_number, 2), len(node_type_pool))
    node_set = np.random.choice(node_type_pool, n_nodes_number, replace=False)

    # Rows are collected in plain lists and turned into a DataFrame once,
    # instead of growing a frame with pd.concat on every iteration.
    # Node rows are [node_id, node_type]; nodes 0 and 1 are the two fixed
    # type-0 nodes every motif starts with.
    node_rows = [[0, 0], [1, 0]]
    # Edge rows are [source, target, edge_type].
    edge_rows = []
    for node_id, node_type in enumerate(node_set, start=2):
        node_type = int(node_type)
        # Each node type owns a pair of edge types: the even one for edges
        # from nodes 0/1 to the new node, the odd one for the reverse edges.
        edge_type = (node_type - 1) * 2
        node_rows.append([node_id, node_type])
        edge_rows.extend(
            [
                [0, node_id, edge_type],
                [1, node_id, edge_type],
                [node_id, 0, edge_type + 1],
                [node_id, 1, edge_type + 1],
            ]
        )

    nodes = pd.DataFrame(node_rows)
    edges = pd.DataFrame(edge_rows)
    # One label row for node 0 and one for node 1; the label cycles through
    # 0 .. label_numbers - 1 with the motif index.
    # NOTE(review): the meaning of the constant middle column (0) is not
    # visible in this script -- confirm against the consumer of these files.
    motif_label = motif_id % label_numbers
    labels = pd.DataFrame([[0, 0, motif_label], [1, 0, motif_label]])

    nodes.to_csv(
        os.path.join(outdir, 'nodes_' + str(motif_id) + '.csv'), index=False
    )
    edges.to_csv(
        os.path.join(outdir, 'edges_' + str(motif_id) + '.csv'), index=False
    )
    labels.to_csv(
        os.path.join(outdir, 'labels_' + str(motif_id) + '.csv'), index=False
    )
