import pandas as pd
import numpy as np
import os
import random
import argparse
from multiprocessing import Pool
from sklearn.model_selection import train_test_split
from seed import set_seed

# Project-local helper; presumably seeds the random number generators so the
# random edge sampling further down is reproducible -- confirm in ``seed.py``.
set_seed()

# Command-line interface.  The values parsed here are also used to build the
# name of the output directory, so they must match the values that were used
# when the graph was generated.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_numbers", nargs='?', type=int, default=2000,
    help="number of motifs"
)
parser.add_argument(
    "--label_numbers", nargs='?', type=int, default=4, help="number of labels"
)
# ``nargs='+'`` yields a list when the option is given on the command line, so
# the default must be a list as well: the directory name is later built by
# iterating over this value, and a bare float default raised
# ``TypeError: 'float' object is not iterable`` whenever the option was omitted.
parser.add_argument(
    "--outcluster_threshold", nargs='+', type=float, default=[0.9],
    help="outcluster threshold"
)
parser.add_argument(
    "--incluster_threshold", nargs='+', type=float, default=[0.1, 0.2, 0.3],
    help="incluster threshold"
)
parser.add_argument(
    "--feature_dim", nargs='?', type=int, default=100, help="feature dimention"
)
parser.add_argument(
    "--feature_cd", nargs='?', type=float, default=1,
    help="feature center_distance"
)
# ``--multilabel`` / ``--singlelabel`` toggle the same destination; multi-label
# is the default.
parser.add_argument("--multilabel", action='store_true', help="multi-label")
parser.add_argument(
    "--singlelabel", dest='multilabel', action='store_false',
    help="single-label"
)
parser.set_defaults(multilabel=True)

args = parser.parse_args()

# Module-level aliases used by the rest of the script.  The spelling
# ``outcluster_threshlod`` is kept because later statements refer to it.
motif_number = args.motif_numbers
label_numbers = args.label_numbers
incluster_threshold = args.incluster_threshold
outcluster_threshlod = args.outcluster_threshold
feature_dim = args.feature_dim
center_var = args.feature_cd
multilabel = args.multilabel

# The working directory is ``./output_graph/<param>``, where ``<param>`` is
# every generation parameter joined with underscores (the threshold lists are
# themselves underscore-joined first).
outdir = './output_graph'
param_fields = [
    motif_number,
    label_numbers,
    '_'.join(str(value) for value in incluster_threshold),
    '_'.join(str(value) for value in outcluster_threshlod),
    feature_dim,
    center_var,
    'M' if multilabel else 'S',
]
param = '_'.join(str(field) for field in param_fields)
outdir = os.path.join(outdir, param)

# Load the graph tables, in the same order as before: nodes, edges, labels.
nodes, edges, labels = (
    pd.read_csv(os.path.join(outdir, filename))
    for filename in ('node.csv', 'link.csv', 'labels.csv')
)

# One DataFrame per motif file, kept in motif-number order.
motif = [
    pd.read_csv(os.path.join(outdir, 'motifs/motif_' + str(idx) + '.csv'))
    for idx in range(motif_number)
]
# All motif rows stacked into a single frame (duplicates are still present).
all_motif = pd.concat(motif)
# Rows of ``nodes`` whose column '1' equals 1; treated as the target nodes
# (per the variable name -- the meaning of column '1' is not defined here).
target_nodes = nodes.loc[nodes['1'] == 1]
# Per-edge-type degree limits, matched by position with the edge types
# [0, 2, 4] iterated in the pruning loop.
degrees = [99999, 1, 99999]


def remove_edge(edges, remove_target):
    """Return ``edges`` without the edge named by ``remove_target``.

    The edge is removed in both orientations: every row whose ('0', '1')
    pair equals the target's ('0', '1') pair or its reverse is dropped.

    Args:
        edges: DataFrame with endpoint columns ``'0'`` and ``'1'``.
        remove_target: Row-like object whose ``'0'`` and ``'1'`` entries are
            the two endpoints of the edge to remove.

    Returns:
        A filtered DataFrame; the input frame itself is not modified.
    """
    first_end, second_end = remove_target['0'], remove_target['1']
    same_direction = (edges['0'] == first_end) & (edges['1'] == second_end)
    opposite_direction = (edges['0'] == second_end) & (edges['1'] == first_end)
    # Keep a row only when it matches neither orientation.
    return edges[~same_direction & ~opposite_direction]


def remove_node_motif(m, all, remove_target):
    """Remove every motif row that touches ``remove_target``.

    Args:
        m: DataFrame of one motif with endpoint columns ``'0'`` and ``'1'``.
        all: DataFrame holding the rows of all motifs.
        remove_target: Node id; rows of ``m`` having it in column ``'0'`` or
            ``'1'`` are removed.

    Returns:
        A tuple ``(remain, remain_all)``. ``remain`` is ``m`` without the
        rows touching the node. ``remain_all`` contains the rows that occur
        exactly once in the concatenation of ``all`` and the removed rows.
    """
    touches_target = (m['0'] == remove_target) | (m['1'] == remove_target)
    kept_rows = m[~touches_target]
    # Rows occurring exactly once in (m + kept rows) are the removed ones.
    dropped_rows = pd.concat([m, kept_rows]).drop_duplicates(keep=False)
    # The same "occurs exactly once" trick takes the removed rows out of ``all``.
    pruned_all = pd.concat([all, dropped_rows]).drop_duplicates(keep=False)
    return kept_rows, pruned_all


def count_degree(edges):
    """Count, per source node and edge type, how many edges start at the node.

    Args:
        edges: DataFrame with the source node id in column ``'0'`` and the
            edge type in column ``'2'``. Other columns are ignored.

    Returns:
        A DataFrame with one row per integer in ``0 .. max(edges['0'])``. It
        has a column for every distinct edge type value, plus the columns
        ``0 .. k-1`` where ``k`` is the number of distinct types. Each cell
        is the number of rows of ``edges`` with that source and type, or 0
        where there are none.
    """
    edge_types = pd.unique(edges['2']).tolist()
    # Take the maximum of the source column only.  Reducing the whole frame
    # (``edges.max()``) upcasts the result to float as soon as any other
    # column is float, and ``range`` then rejects the non-integer bound.
    row_count = int(edges['0'].max()) + 1
    # One row per node id; ``reset_index`` exposes the id as column 'index'
    # so it can be mapped against the per-type counts below.
    edges_counts = pd.DataFrame(
        index=range(row_count), columns=range(len(edge_types))
    ).reset_index()
    for edge_type in edge_types:
        sources = edges.loc[edges['2'] == edge_type, '0']
        # Ids without an edge of this type map to NaN, zero-filled at the end.
        edges_counts[edge_type] = edges_counts['index'].map(
            sources.value_counts(sort=False)
        )
    return edges_counts.drop('index', axis=1).fillna(0)


# Prune edges of the target nodes so that, for each of the edge types 0, 2 and
# 4, a target node keeps at most ``degrees[i]`` edges where possible.  Only
# edges that do not appear in any motif are removed.
#
# Per-source-node edge counts, restricted to edge types 0, 2 and 4.
target_degree = count_degree(edges)[[0, 2, 4]]
# ``join`` aligns on the row index, so this pairs each target node with the
# counts stored at the same index label.
# NOTE(review): assumes the row index of ``nodes`` equals the node id -- confirm.
target_degree = target_nodes[['0']].join(target_degree)
all_motif = all_motif.drop_duplicates()
# Iterate through the edge types, then through the target nodes, for pruning
for i, edge_type in enumerate([0, 2, 4]):
    # print(i ,edge_type)
    # Only target nodes whose count for this edge type exceeds the limit.
    for _, row in target_degree[target_degree[edge_type] > degrees[i]
                                ].iterrows():
        # All edges of this type that start at the current node; these are
        # the candidates for bringing the node back within the degree limit.
        target_edges = edges[(edges['0'] == row['0'])
                             & (edges['2'] == edge_type)][['0', '1', '2']]
        # print(target_edges)
        degree = len(target_edges.index)
        tmp = pd.concat([target_edges, all_motif]).astype(int)
        # Find the edges that are in the motifs: ``duplicated()`` flags rows
        # equal to an earlier row, and the candidate edges come first.
        in_motif = tmp[tmp.duplicated()]
        # Find the edges that can be removed: candidates that occur exactly
        # once after stacking them with the in-motif rows.
        to_remove = pd.concat([target_edges,
                               in_motif]).drop_duplicates(keep=False)
        # print(to_remove)
        # Disabled fallback, kept for reference: if there are not enough
        # non-motif edges, remove edges inside the motifs at random.
        # if len(to_remove.index) < (degree - degrees[i]):
        #     # print()
        #     remove_motif = degree-degrees[i]-len(to_remove.index)
        #     # deal with the pruning in the motifs
        #     in_motif = in_motif.sample(remove_motif)
        #     for _, e in in_motif.iterrows():
        #         # remove in the edges
        #         edges = remove_edge(edges, e)
        #         # remove in the motifs
        #         for j in range(len(motif)):
        #             # Check if motifs contain the edges will removed
        #             # Also remove the non-target node from the motif
        #             if ((motif[j]['0']==e['0'])&(motif[j]['1']==e['1'])).any():
        #                 # motif[j] = motif[j][~((motif[j]['0']==e['1']) | (motif[j]['1']==e['1']))]
        #                 motif[j], all_motif = remove_node_motif(motif[j], all_motif,e['1'])
        #
        # Remove the edges that can be removed
        if len(to_remove.index) > 0:
            # If there are more removable edges than the excess over the
            # limit, sample exactly the excess (``degree - degrees[i]``).
            if len(to_remove.index) > (degree - degrees[i]):
                to_remove = to_remove.sample(degree - degrees[i])
            # ``remove_edge`` drops each edge in both orientations.
            for _, e in to_remove.iterrows():
                edges = remove_edge(edges, e)

# Remove the floating nodes, i.e. nodes that are not an endpoint of any
# remaining edge.  One vectorised membership test replaces the previous
# per-node scan of the whole edge table (O(nodes * edges)), and it no longer
# relies on the row position of a node matching its index label.
connected_ids = pd.concat([edges['0'], edges['1']]).unique()
nodes = nodes[nodes['0'].isin(connected_ids)]

# Re-number the surviving nodes 0..n-1.  After the second ``reset_index`` the
# new id is in column 'index' while the old id is still in column '0'.
nodes = nodes.reset_index(drop=True).reset_index()
# Mapping old node id -> new node id.
replace_dict = pd.Series(nodes['index'].to_list(),
                         index=nodes['0'].to_list()).to_dict()
# Rewrite both endpoint columns; ``replace`` leaves any value that is not a
# key of the mapping unchanged.
edges[['0', '1']] = edges[['0', '1']].replace(replace_dict)


def motif_reindex(m):
    """Translate a motif's endpoint ids with ``replace_dict`` and sort it.

    Args:
        m: DataFrame with endpoint columns ``'0'`` and ``'1'``. The two
            columns are overwritten in place on the object passed in.

    Returns:
        The motif sorted by column ``'0'``.
    """
    endpoint_cols = ['0', '1']
    # ``replace_dict`` is the module-level old-id -> new-id mapping.
    m[endpoint_cols] = m[endpoint_cols].replace(replace_dict)
    return m.sort_values(by=['0'])


# Re-number every motif in worker processes; ``map`` returns the results in
# the same order as the input list.
# NOTE(review): the pool is created at module level without an
# ``if __name__ == '__main__':`` guard, and the workers read the module-level
# ``replace_dict`` -- this presumably relies on the "fork" start method; confirm
# before running on a platform that spawns workers.
with Pool() as workers:
    motif = workers.map(motif_reindex, motif)

# Keep the new id (column 'index') together with columns '1' and '2', and
# expose the new id under the name '0' again.
nodes = nodes[['index', '1', '2']].rename(columns={'index': '0'})

# Shift the edge type codes to their output values.
edge_type_codes = {1: 2, 0: 3, 2: 4, 3: 5, 4: 6, 5: 7}
edges['2'] = edges['2'].replace(edge_type_codes)

# Save the output graph part.  The files read at the start are overwritten.
# NOTE(review): ``labels`` is written back unchanged; if its column '0' holds
# node ids they are not re-numbered by this script -- confirm.
for frame, filename in (
    (nodes, 'node.csv'),
    (edges, 'link.csv'),
    (labels, 'labels.csv'),
):
    frame.to_csv(os.path.join(outdir, filename), index=False)

motif_dir = os.path.join(outdir, 'motifs')
for idx, reindexed in enumerate(motif):
    reindexed.to_csv(
        os.path.join(motif_dir, 'motif_' + str(idx) + '.csv'), index=False
    )

# Split the label table into train / validation / test sets and save each,
# sorted by column '0'.
all = labels

# First split: 70% of the rows go to the test set, 30% remain.
train_val, test = train_test_split(all, test_size=0.7, random_state=69420)
# Second split: 20% of the remaining rows become the validation set.
train, val = train_test_split(train_val, test_size=0.2, random_state=69420)

for subset, filename in (
    (train, 'data_train.csv'),
    (val, 'data_val.csv'),
    (test, 'data_test.csv'),
):
    subset.sort_values(['0']).to_csv(
        os.path.join(outdir, filename), index=False
    )
