import pandas as pd
import numpy as np
import os
import random
import argparse
from multiprocessing import Pool
from sklearn.model_selection import train_test_split

from seed import set_seed

# NOTE(review): set_seed comes from the project-local `seed` module and is
# called with no arguments; presumably it fixes global RNG state so runs are
# reproducible -- confirm against seed.py.
set_seed()

# Command-line options. The parsed values are used further down to rebuild the
# name of the dataset directory under ./output_graph, so they have to match the
# values the dataset was generated with.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_numbers", nargs='?', type=int, default=2000,
    help="number of motifs"
)
parser.add_argument(
    "--label_numbers", nargs='?', type=int, default=4, help="number of labels"
)
parser.add_argument(
    "--outcluster_threshold", nargs='?', type=float, default=0.9,
    help="outcluster threshold"
)
# One or more values; all of them end up in the directory name.
parser.add_argument(
    "--incluster_threshold", nargs='+', type=float, default=[0.1, 0.2, 0.3],
    help="incluster threshold"
)
parser.add_argument(
    "--feature_dim", nargs='?', type=int, default=100,
    help="feature dimension"
)
parser.add_argument(
    "--feature_cd", nargs='?', type=float, default=1,
    help="feature center_distance"
)
# --multilabel and --singlelabel write to the same destination; multi-label is
# the default, so --singlelabel is the flag that actually changes anything.
parser.add_argument("--multilabel", action='store_true', help="multi-label")
parser.add_argument(
    "--singlelabel", dest='multilabel', action='store_false',
    help="single-label"
)
parser.set_defaults(multilabel=True)

args = parser.parse_args()

# Pull the parsed options into module-level names.
motif_number = args.motif_numbers
label_numbers = args.label_numbers
incluster_threshold = args.incluster_threshold
outcluster_threshlod = args.outcluster_threshold
feature_dim = args.feature_dim
center_var = args.feature_cd
multilabel = args.multilabel

# The dataset directory is named after its parameters, joined with '_':
# <motifs>_<labels>_<incluster thresholds...>_<outcluster>_<dim>_<cd>_<M|S>
incluster_tag = '_'.join(str(value) for value in incluster_threshold)
label_mode = 'M' if multilabel else 'S'
name_parts = [
    motif_number,
    label_numbers,
    incluster_tag,
    outcluster_threshlod,
    feature_dim,
    center_var,
    label_mode,
]
param = '_'.join(str(part) for part in name_parts)
outdir = os.path.join('./output_graph', param)
# motif_dir = os.path.join(outdir, 'motifs')

# Load the graph tables from the parameter-specific dataset directory.
nodes = pd.read_csv(os.path.join(outdir, 'node.csv'))
edges = pd.read_csv(os.path.join(outdir, 'link.csv'))
labels = pd.read_csv(os.path.join(outdir, 'labels.csv'))

# Write the tables straight back to the files they were read from (without the
# pandas index column).
# NOTE(review): nothing modifies the frames between the read and the write, so
# this looks like a leftover from another script; kept so the files this
# script produces stay the same -- confirm whether it can be dropped.
nodes.to_csv(os.path.join(outdir, 'node.csv'), index=False)
edges.to_csv(os.path.join(outdir, 'link.csv'), index=False)
labels.to_csv(os.path.join(outdir, 'labels.csv'), index=False)

# Two-stage split of the label table with a fixed random_state: about 70% of
# the rows go to test, and the remaining ~30% is split 80/20 into train and
# val (roughly 24% / 6% of all rows).
# NOTE(review): a 70% test share is unusually large -- confirm it is intended.
train_val, test = train_test_split(labels, test_size=0.7, random_state=69420)
train, val = train_test_split(train_val, test_size=0.2, random_state=69420)

# Save each split next to the inputs, ordered by the column labelled '0'
# (presumably the node id -- verify against labels.csv).
for filename, frame in (
    ('data_train.csv', train),
    ('data_val.csv', val),
    ('data_test.csv', test),
):
    frame.sort_values(['0']).to_csv(
        os.path.join(outdir, filename), index=False
    )
