import pandas as pd
import random
import os
import numpy as np
import argparse
from multiprocessing import Pool
import gc
from seed import set_seed

# `set_seed` comes from the project's `seed` module; the object it returns is
# used further down as `rng.permutation(...)`.
# NOTE(review): presumably it also seeds `random` / `np.random` -- confirm.
rng = set_seed()

# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_numbers", nargs='?', type=int, default=2000,
    help="number of motifs"
)
parser.add_argument(
    "--label_numbers", nargs='?', type=int, default=4, help="number of labels"
)
parser.add_argument(
    "--outcluster_threshold", nargs='?', type=float, default=0.9,
    help="outcluster threshold"
)
parser.add_argument(
    "--incluster_threshold", nargs='+', type=float, default=[0.1, 0.2, 0.3],
    help="incluster threshold"
)
parser.add_argument(
    "--feature_dim", nargs='?', type=int, default=100, help="feature dimension"
)
parser.add_argument(
    "--feature_cd", nargs='?', type=float, default=1,
    help="feature center_distance"
)
# --multilabel / --singlelabel toggle the same destination; multi-label is
# the default.
parser.add_argument("--multilabel", action='store_true', help="multi-label")
parser.add_argument(
    "--singlelabel", dest='multilabel', action='store_false',
    help="single-label"
)
parser.set_defaults(multilabel=True)

args = parser.parse_args()

# Module-level settings read by the functions and script sections below.
motif_numbers = args.motif_numbers
label_numbers = args.label_numbers
incluster_threshold = args.incluster_threshold
outcluster_threshlod = args.outcluster_threshold
feature_dim = args.feature_dim
center_var = args.feature_cd
multilabel = args.multilabel

print(
    motif_numbers, label_numbers, ' '.join(map(str, incluster_threshold)),
    outcluster_threshlod, feature_dim, center_var, 'M' if multilabel else 'S'
)

# Directory the input motif CSV files are read from.
motif_dir = './motifs'

# ---------------------------------------------------------------------------
# Output directory: ./output_graph/<parameters joined by '_'>
# ---------------------------------------------------------------------------
outdir = './output_graph'
os.makedirs(outdir, exist_ok=True)
param = '_'.join(
    map(
        str, [
            motif_numbers, label_numbers,
            '_'.join(map(str, incluster_threshold)), outcluster_threshlod,
            feature_dim, center_var, 'M' if multilabel else 'S'
        ]
    )
)
outdir = os.path.join(outdir, param)
try:
    # Creating the directory doubles as the "already generated" check, so
    # there is no window between testing for it and creating it.
    os.mkdir(outdir)
except FileExistsError:
    print(f'{param} has been merged.')
    # Same exit status as the previous bare exit(), without relying on the
    # helper that the `site` module injects.
    raise SystemExit from None


def count_degree(edges):
    """Return the number of rows per distinct value of column ``'0'``.

    The counts come back as a plain list, in the order produced by
    ``value_counts(sort=False)`` (i.e. not ordered by frequency).
    """
    source_column = edges['0']
    occurrences = source_column.value_counts(sort=False)
    return occurrences.tolist()


def degree_corrected(edges):
    """Return one weight per node id, built from per-edge-type shares.

    For every distinct edge type (column ``'2'``) the function counts how
    often each node id occurs in column ``'0'`` and divides by the total
    number of edges of that type.  The weight of a node is the sum of these
    shares over all edge types.

    Args:
        edges: DataFrame with at least the columns ``'0'`` (node id) and
            ``'2'`` (edge type).

    Returns:
        A list of floats with one entry for every id in
        ``0 .. edges['0'].max()``; ids that never occur in column ``'0'``
        get ``0.0``.  An empty ``edges`` frame yields ``[]``.
    """
    if edges.empty:
        return []

    # Only column '0' determines the node range; taking the max of the whole
    # frame would needlessly involve every other column.
    node_ids = pd.Series(range(int(edges['0'].max()) + 1))
    weights = pd.Series(0.0, index=node_ids.index)

    # Accumulate per type directly, keyed by the real type value, so type ids
    # that are not 0..k-1 never create empty placeholder columns.
    for edge_type in pd.unique(edges['2']).tolist():
        per_node = edges.loc[edges['2'] == edge_type,
                             '0'].value_counts(sort=False)
        counts = node_ids.map(per_node).fillna(0)
        # Every type seen here has at least one edge, so the sum is > 0.
        weights = weights + counts / counts.sum()

    return weights.tolist()


def merge(
    theshold, G_node, G_edge, G_label, G_motif, m_G_node, m_G_edge, m_G_label,
    m_G_motif
):
    """Attach one motif graph to a base graph by merging some of its nodes.

    Every non-target node of the motif (node type != 0) gets a uniform random
    draw and is selected for merging when the draw is >= the threshold of
    its type, so higher thresholds merge fewer nodes.  A selected node is
    replaced by a randomly chosen base-graph node of the same type; every
    other motif node is appended to the base graph under a fresh id that
    continues after the base graph's largest id.  At least one node is
    always merged.

    Args:
        theshold: Sequence of thresholds, indexed by ``int(node_type) - 1``.
        G_node: Base-graph nodes; column ``'0'`` is the id, ``'1'`` the type.
        G_edge: Base-graph edges; columns ``'0'``/``'1'`` hold node ids.
        G_label: Base-graph labels; column ``'0'`` holds node ids.
        G_motif: List of motif edge frames already in the base graph.  The
            new motif is appended to this list in place.
        m_G_node, m_G_edge, m_G_label, m_G_motif: The same four tables for
            the motif being added.  ``m_G_edge``, ``m_G_label`` and
            ``m_G_motif`` are renumbered in place.

    Returns:
        ``(nodes, edges, labels, G_motif)`` for the combined graph.

    Raises:
        ValueError: If no motif node has a non-target type that also occurs
            in the base graph, so nothing could be merged.
    """
    # Fresh ids for the motif's surviving nodes start right after the
    # largest id already used by the base graph.
    max_index = G_node['0'].max() + 1

    # One uniform draw per non-target motif node decides whether it merges.
    candidates = m_G_node[m_G_node['1'] != 0].copy()
    candidates.loc[:, 'rnd'] = np.random.rand(candidates.shape[0])

    # Only types present in the base graph can be merged into it.
    base_types = [
        t for t in sorted(pd.unique(G_node['1']).tolist()) if t != 0
    ]

    selected = []
    last_pool = None
    for node_type in base_types:
        pool = candidates[candidates['1'] == node_type]
        last_pool = pool
        chosen = pool[pool['rnd'] >= theshold[int(node_type) - 1]]
        if len(chosen.index) != 0:
            selected.append(chosen[['0', '1']].astype(int))

    if selected:
        to_merges = pd.concat(selected)
    else:
        # Nothing cleared its threshold: force a single merge.  Prefer the
        # candidates of the last processed type (the historical behaviour),
        # and fall back to every mergeable candidate when that pool is empty
        # or no type was processed at all.
        fallback = last_pool
        if fallback is None or fallback.empty:
            fallback = candidates[candidates['1'].isin(base_types)]
        if fallback.empty:
            raise ValueError(
                'motif has no non-target node whose type exists in the '
                'base graph; nothing can be merged'
            )
        to_merges = fallback.sample(n=1)[['0', '1']].astype(int)

    # Anti-join: rows that occur both in the motif and in `to_merges` become
    # duplicates and are dropped entirely with keep=False.
    with_dupes = pd.concat([m_G_node, to_merges])
    kept = with_dupes.drop_duplicates(subset=['0', '1'], keep=False)
    kept = kept.reset_index(drop=True).reset_index()

    # Map original motif ids to ids in the merged graph.
    kept['index'] = kept['index'] + max_index
    replace_dict = dict(zip(kept['0'].to_list(), kept['index'].to_list()))

    # Each merged node is redirected to a random base node of the same type.
    # The per-type target lists are looked up once and reused.
    base_type_col = G_node['1'].astype(int)
    targets_by_type = {}
    for node_number, node_type in to_merges.itertuples(index=False, name=None):
        if node_type not in targets_by_type:
            targets_by_type[node_type] = G_node.loc[
                base_type_col == node_type, '0'].to_list()
        replace_dict[node_number] = random.choice(targets_by_type[node_type])

    # Renumber the motif's tables (in place) and append them to the base.
    m_G_edge[['0', '1']] = m_G_edge[['0', '1']].replace(replace_dict)
    m_G_motif[['0', '1']] = m_G_motif[['0', '1']].replace(replace_dict)
    m_G_label['0'] = m_G_label['0'].replace(replace_dict)

    new_nodes = kept[['index', '1']].rename(columns={'index': '0'})
    output_nodes = pd.concat([G_node, new_nodes])
    output_edges = pd.concat([G_edge, m_G_edge])
    G_motif.append(m_G_motif)
    output_labels = pd.concat([G_label, m_G_label])

    return output_nodes, output_edges, output_labels, G_motif


def read_graph_csv(id):
    """Load the four CSV tables of motif ``id`` from the ``motif_dir`` folder.

    Returns:
        ``(nodes, edges, labels, motifs)`` read from ``nodes_b_<id>.csv``,
        ``edges_b_<id>.csv``, ``labels_<id>.csv`` and ``edges_<id>.csv``
        respectively.
    """

    def _load(prefix):
        # All four files share the "<prefix><id>.csv" naming scheme.
        return pd.read_csv(os.path.join(motif_dir, f'{prefix}{id}.csv'))

    nodes = _load('nodes_b_')
    edges = _load('edges_b_')
    labels = _load('labels_')
    motifs = _load('edges_')
    return nodes, edges, labels, motifs


def cluster_merge(
    threshold, G_node_list, G_edge_list, G_label_list, G_motif_list,
    degree=False, self_loop=False
):
    """Join the per-label cluster graphs and merge nodes across clusters.

    The first entry of every input list is popped and used as the base
    graph.  The remaining clusters are appended with their node ids shifted
    past the current maximum id.  Then, for every node type, a random score
    matrix over the nodes of that type is built, its within-cluster blocks
    are zeroed, and the ``k = int(n * (1 - threshold))`` highest-scoring
    index pairs are picked.  Pairs that share a member are unioned into
    groups and each group is collapsed onto its first member.  For the
    target type 0, groups larger than ``label_numbers - 1`` are first broken
    into random chunks of that size.

    The function reads the module-level names ``multilabel`` (type 0 is
    skipped when it is false) and ``label_numbers``.

    Args:
        threshold: Number used both as the score cut-off and to derive ``k``.
        G_node_list, G_edge_list, G_label_list: Lists of DataFrames, one per
            cluster.  The lists are popped and the remaining frames are
            renumbered in place.
        G_motif_list: List (one entry per cluster) of lists of motif edge
            frames.  Renumbered in place.
        degree: When true, scores are multiplied by the outer product of the
            per-node weights from ``degree_corrected``.
        self_loop: When true, extra edges of type 6 (and the reverse
            direction as type 7) are added between picked pairs of type-0
            nodes.

    Returns:
        ``(nodes, edges, labels, motifs)`` of the combined graph, where
        ``motifs`` is the flat list of all motif edge frames.
    """

    def _k_largest_index_argsort(a, k):
        # (row, col) positions of the k largest entries of `a`, largest first.
        idx = np.argsort(a.ravel())[:-k - 1:-1]
        return np.column_stack(np.unravel_index(idx, a.shape))

    def _merge(lsts):
        # Union index lists that share at least one member until all the
        # resulting groups are pairwise disjoint; each group comes back sorted.
        sets = [set(lst) for lst in lsts if lst]
        merged = True
        while merged:
            merged = False
            results = []
            while sets:
                common, rest = sets[0], sets[1:]
                sets = []
                for x in rest:
                    if x.isdisjoint(common):
                        sets.append(x)
                    else:
                        merged = True
                        common |= x
                results.append(common)
            sets = results
        return [sorted(group) for group in sets]

    # ------------------------------------------------------------------
    # 1. Concatenate all clusters, shifting ids so they never collide.
    # ------------------------------------------------------------------
    cluster_pos = []
    nodes = G_node_list.pop(0)
    edges = G_edge_list.pop(0)
    labels = G_label_list.pop(0)
    motifs = G_motif_list.pop(0)
    for i in range(len(G_node_list)):
        last_id = nodes['0'].max()
        # Largest id of everything concatenated so far marks the boundary
        # before cluster i.
        cluster_pos.append(last_id)
        max_index = last_id + 1
        G_node_list[i]['0'] = G_node_list[i]['0'] + max_index
        G_edge_list[i][['0', '1']] = G_edge_list[i][['0', '1']] + max_index
        G_label_list[i]['0'] = G_label_list[i]['0'] + max_index
        nodes = pd.concat([nodes, G_node_list[i]])
        edges = pd.concat([edges, G_edge_list[i]])
        labels = pd.concat([labels, G_label_list[i]])
        for m in G_motif_list[i]:
            m[['0', '1']] = m[['0', '1']] + max_index
            motifs.append(m)
    del G_node_list
    del G_edge_list
    del G_label_list
    del G_motif_list
    gc.collect()

    nodes = nodes.reset_index(drop=True)
    types = pd.unique(nodes['1']).tolist()

    if degree:
        nodes['weight'] = degree_corrected(edges)

    # In single-label mode the target node type is never merged.
    if not multilabel:
        types.remove(0)

    # ------------------------------------------------------------------
    # 2. Merge nodes of the same type across clusters, one type at a time.
    # ------------------------------------------------------------------
    for node_type in types:
        # NOTE(review): these hard-coded overrides rebind `threshold` itself,
        # so they also apply to every type processed afterwards and to the
        # extra-edge step at the end.  Kept unchanged because generated
        # graphs depend on it -- confirm whether that carry-over is intended.
        if node_type == 2:
            threshold = 0.2
        if node_type == 3:
            threshold = 0.1
        type_nodes = nodes[nodes['1'] == node_type]
        # Number of index pairs to pick for this type.
        k = int(len(type_nodes) * (1 - threshold))

        # Symmetric random score for every pair of nodes of this type.
        # NOTE(review): the uint8 addition wraps around on overflow before
        # the division; left as-is so seeded runs stay reproducible.
        rnd = np.random.randint(
            low=0, high=255, size=((len(type_nodes), len(type_nodes))),
            dtype=np.uint8
        )
        rnd = (rnd + rnd.T) / 2

        # Zero the diagonal blocks so only cross-cluster pairs keep a score.
        # NOTE(review): `cluster_pos` holds node ids while this compares them
        # with row labels; presumably ids equal row positions here -- confirm.
        s = 0
        for pos in cluster_pos:
            e = len(type_nodes[type_nodes.index < pos])
            rnd[s:e, s:e] = 0
            s = e
        rnd[s:, s:] = 0
        # NOTE(review): scores range up to ~127, so for thresholds below 1
        # this only clears scores of 0 and 0.5.
        rnd[rnd < threshold] = 0
        # Keep one orientation of every pair.
        rnd = np.triu(rnd)

        if degree:
            # Weight every pair score by the product of the two nodes'
            # degree-corrected weights (outer product -> n x n matrix).
            weight = np.array(type_nodes['weight'])
            weight = np.dot(
                weight.reshape(len(weight), 1), weight.reshape(1, len(weight))
            )
            # Element-wise product, not a matrix product.
            rnd = rnd * weight

        # NOTE: everything in `mapping` is a position inside `type_nodes`,
        # NOT a node id.
        mapping = _k_largest_index_argsort(rnd, k).tolist()
        del rnd
        gc.collect()

        # Pairs sharing a node end up in the same merge group.
        mapping = _merge(mapping)

        # Translate positions to node ids: every member of a group is
        # redirected to the group's first member.
        replace_dict = {}
        for m in mapping:
            if (int(node_type) == 0) and len(m) > (label_numbers - 1):
                # Oversized groups of target nodes are split into random
                # chunks of `label_numbers - 1` members, each collapsed on
                # its own; whatever is left over is handled below.
                tmp = m.copy()
                while len(tmp) > label_numbers - 1:
                    m = tmp.copy()
                    m = sorted(
                        list(
                            set(
                                np.random.choice(
                                    m, label_numbers - 1, replace=False
                                )
                            )
                        )
                    )
                    tmp = list(set(tmp) - set(m))
                    merge_to = m.pop(0)
                    for i in m:
                        replace_dict[type_nodes.iloc[i]['0']
                                     ] = type_nodes.iloc[merge_to]['0']
                m = sorted(tmp)
            if len(m) > 1:
                merge_to = m.pop(0)
                for i in m:
                    replace_dict[type_nodes.iloc[i]['0']
                                 ] = type_nodes.iloc[merge_to]['0']
        del mapping
        gc.collect()

        # Apply the redirection everywhere node ids are stored.
        nodes['0'] = nodes['0'].replace(replace_dict)
        edges[['0', '1']] = edges[['0', '1']].replace(replace_dict)
        labels['0'] = labels['0'].replace(replace_dict)
        for m in motifs:
            m[['0', '1']] = m[['0', '1']].replace(replace_dict)

        nodes = nodes.drop_duplicates()
        edges = edges.drop_duplicates()
        gc.collect()

    # ------------------------------------------------------------------
    # 3. Optionally connect random pairs of target nodes directly.
    # ------------------------------------------------------------------
    if self_loop:
        type_nodes = nodes[nodes['1'] == 0]
        rnd = np.random.rand(len(type_nodes), len(type_nodes))
        # A node is never paired with itself.
        rnd[np.diag_indices_from(rnd)] = 0
        rnd[rnd < threshold] = 0
        rnd = np.triu(rnd)

        k = int(len(type_nodes) * (1 - threshold))
        mapping = _k_largest_index_argsort(rnd, k).tolist()
        del rnd

        # Collect the two-row frames first and concatenate once; growing
        # `edges` inside the loop would copy it on every iteration.
        new_edges = []
        for m in mapping:
            src = type_nodes.iloc[m[0]]['0']
            dst = type_nodes.iloc[m[1]]['0']
            # Edge type 6 in one direction, 7 in the other.
            new_edges.append(
                pd.DataFrame(
                    [[src, dst, 6], [dst, src, 7]], columns=['0', '1', '2']
                )
            )
        if new_edges:
            edges = pd.concat([edges, *new_edges])

    # After merging, one node id can carry several label rows: replace
    # column '2' by the space-joined set of all its values for that node.
    labels['2'] = labels['0'].map(
        labels.groupby('0')
        ['2'].apply(lambda x: ' '.join(map(str, list(set(x)))))
    )
    labels = labels.drop_duplicates()

    return nodes, edges, labels, motifs


# Bucket the motif ids by label: bucket i holds every id congruent to i
# modulo label_numbers.  The order of the buckets is then randomised.
motif_list = [
    list(range(i, motif_numbers, label_numbers)) for i in range(label_numbers)
]
random.shuffle(motif_list)

g_node_list, g_edge_list, g_label_list, g_motif_list = [], [], [], []

# Build one cluster graph per bucket: start from a random motif of the bucket
# and merge the remaining ones into it, in shuffled order.
# The higher the in-cluster threshold, the fewer nodes get merged.
for bucket in motif_list:
    random.shuffle(bucket)
    first_id = bucket.pop(0)
    nodes, edges, labels, first_motif = read_graph_csv(first_id)
    motifs = [first_motif]
    for motif_id in bucket:
        nodes_2, edges_2, labels_2, motifs2 = read_graph_csv(motif_id)
        nodes, edges, labels, motifs = merge(
            incluster_threshold, nodes, edges, labels, motifs, nodes_2,
            edges_2, labels_2, motifs2
        )
    g_node_list.append(nodes)
    g_edge_list.append(edges)
    g_label_list.append(labels)
    g_motif_list.append(motifs)

# Join the cluster graphs and merge nodes across clusters.
nodes_0, edges_0, labels_0, motifs_0 = cluster_merge(
    outcluster_threshlod, g_node_list, g_edge_list, g_label_list, g_motif_list,
    self_loop=True
)

# Placeholder feature string for every node.
nodes_0['2'] = '0 1'

# Renumber the nodes: sort by (type, id) and use the resulting row position
# as the new id.  `replace_dict` maps old id -> new id.
nodes_0 = (
    nodes_0.sort_values(by=['1', '0']).reset_index(drop=True).reset_index()
)
replace_dict = dict(
    zip(nodes_0['0'].to_list(), nodes_0['index'].to_list())
)
edge_endpoints = edges_0[['0', '1']].replace(replace_dict)
edges_0[['0', '1']] = edge_endpoints


def motif_reindex(motif):
    """Renumber the node-id columns of one motif edge table.

    Columns ``'0'`` and ``'1'`` are rewritten in place through the
    module-level ``replace_dict``; the same frame is returned.
    """
    endpoint_cols = ['0', '1']
    remapped = motif[endpoint_cols].replace(replace_dict)
    motif[endpoint_cols] = remapped
    return motif


# Apply the final node renumbering to every motif edge table in worker
# processes.
# NOTE(review): `motif_reindex` reads the module-level `replace_dict`, so the
# workers must inherit this module's state (e.g. a fork start method) --
# confirm on platforms that spawn fresh interpreters.
with Pool() as pool:
    motifs_0 = pool.map(motif_reindex, motifs_0)
labels_0['0'] = labels_0['0'].replace(replace_dict)

# Keep (new id, type, feature) and expose the new id as column '0'.  `rename`
# returns a new frame, so the assignments below never write into a slice.
nodes_0 = nodes_0[['index', '1', '2']].rename(columns={'index': '0'})
edges_0['3'] = 1

# ---------------------------------------------------------------------------
# Node features (scheme taken from Google GraphWorld, per the original note):
# one Gaussian centre per label, and every labelled node is drawn around the
# centre of (one of) its label(s).
# ---------------------------------------------------------------------------
centers = []
cluster_var = 1.0
center_cov = np.identity(feature_dim) * center_var
cluster_cov = np.identity(feature_dim) * cluster_var
for _ in range(label_numbers):
    center = np.random.multivariate_normal(
        np.zeros(feature_dim), center_cov, 1
    )[0]
    centers.append(center)

features = []
for cluster_index in labels_0['2'].tolist():
    # Column '2' holds the node's label ids as a space-separated string.
    cluster_index = list(map(int, (cluster_index.split(' '))))
    if len(cluster_index) > 1:
        # Multi-label node: draw one vector per label, then keep a randomly
        # chosen one (first row after shuffling the rows).
        feature = np.random.multivariate_normal(
            centers[cluster_index.pop()], cluster_cov, 1
        )[0]
        for i in cluster_index:
            tmp = np.random.multivariate_normal(centers[i], cluster_cov, 1)[0]
            feature = np.concatenate([feature, tmp])
        feature = feature.reshape([-1, feature_dim])
        feature = rng.permutation(feature, axis=0)
        feature = feature[0, :]
    else:
        feature = np.random.multivariate_normal(
            centers[cluster_index.pop()], cluster_cov, 1
        )[0]
    features.append(feature)

# Store each vector as a space-separated string in the first len(labels_0)
# node rows.  A single positional setter is used instead of chained indexing
# (`nodes_0['2'][:n] = ...`), which is a silent no-op under pandas
# Copy-on-Write.
# NOTE(review): this assumes the i-th row of labels_0 describes the i-th row
# of nodes_0 -- confirm against the ordering produced upstream.
feature_count = len(features)
feature_strings = [' '.join(map(str, feature)) for feature in features]
nodes_0.iloc[:feature_count, nodes_0.columns.get_loc('2')] = feature_strings

# ---------------------------------------------------------------------------
# Write the generated graph.
# ---------------------------------------------------------------------------
# From here on `motif_dir` points at the output folder for motifs.
motif_dir = os.path.join(outdir, 'motifs')
os.makedirs(motif_dir, exist_ok=True)
nodes_0.to_csv(os.path.join(outdir, 'node.csv'), index=False)
edges_0.to_csv(os.path.join(outdir, 'link.csv'), index=False)
labels_0.to_csv(os.path.join(outdir, 'labels.csv'), index=False)
for i, m in enumerate(motifs_0):
    m.drop_duplicates().to_csv(
        os.path.join(motif_dir, 'motif_' + str(i) + '.csv'), index=False
    )
