import pandas as pd
import os
import random
import argparse
import numpy as np
from seed import set_seed

set_seed()

# Command-line interface. The flags deliberately take a required value (no
# ``nargs='?'``): with ``nargs='?'`` and no ``const``, a bare ``--dir`` or
# ``--motif_number`` silently parses to None and the script only crashes later
# (in ``range(None)`` / ``os.path.join(None, ...)``) instead of reporting a
# usage error at parse time.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_number", type=int, default=1000,
    help="number of motifs"
)
parser.add_argument("--dir", default='./motifs', help="output dir")
args = parser.parse_args()

motif_dir = args.dir
motif_number = args.motif_number

# Allowed neighbour node types for each node type (list index = node type):
# type 0 may connect to types 1-3, and types 1-3 may connect only to type 0.
# Not used by the active code path of this script.
legal_connects = [[1, 2, 3], [0], [0], [0]]
# Fixed-target variant (unused): distribution = [0, 1, 3, 5]

# Weighted populations used to sample the per-motif target count of type-2 and
# type-3 neighbours. They never change, so they are built once here. Sampling
# uniformly from an expanded list (rather than passing ``p=`` weights) is kept
# on purpose: switching to ``p=`` would change the values drawn for a seed.
TYPE2_TARGET_POPULATION = [1] * 2 + [2] * 3 + [3] * 4568
TYPE3_TARGET_POPULATION = (
    [1] * 30 + [2] * 28 + [3] * 20 + [4] * 25 + [5] * 4362
)

# Existing motif node ids that every added neighbour node is attached to.
# NOTE(review): presumably nodes 0 and 1 are the motif's centre nodes — confirm.
ANCHOR_NODES = (0, 1)


def _append_rows(frame, rows):
    """Return *frame* with *rows* appended (each row lists values for columns 0..k)."""
    if not rows:
        return frame
    return pd.concat([frame, pd.DataFrame(rows)], ignore_index=True)


# For each motif, read nodes_<i>.csv / edges_<i>.csv from motif_dir, add new
# neighbour nodes for node types 1-3 whose count is below a sampled target, and
# write the result to nodes_b_<i>.csv / edges_b_<i>.csv in the same directory.
for motif in range(motif_number):
    nodes = pd.read_csv(os.path.join(motif_dir, f'nodes_{motif}.csv'))
    edges = pd.read_csv(os.path.join(motif_dir, f'edges_{motif}.csv'))
    # Headers are read as strings ("0", "1", ...); switch to integer labels.
    nodes.columns = nodes.columns.astype(int)
    edges.columns = edges.columns.astype(int)

    # Target neighbour count per node type (list index = node type). Per the
    # script's notes the types are movie, director, actor, keyword; type 0 is
    # never padded and type 1 always targets exactly one node. The two draws
    # happen in this order (type 2, then type 3) for every motif.
    target_counts = [
        0,
        1,
        np.random.choice(TYPE2_TARGET_POPULATION, 1)[0],
        np.random.choice(TYPE3_TARGET_POPULATION, 1)[0],
    ]

    existing_counts = nodes[1].value_counts()
    next_id = nodes[0].max() + 1  # new nodes get consecutive ids after the max
    new_nodes = []
    new_edges = []
    for node_type in range(1, len(target_counts)):
        deficit = int(target_counts[node_type] - existing_counts.get(node_type, 0))
        # Fixed edge-type scheme: anchor -> new uses 2*(t-1), new -> anchor
        # uses 2*(t-1)+1, where t is the new node's type.
        edge_type = (node_type - 1) * 2
        # range() is empty when the motif already meets or exceeds the target.
        for _ in range(deficit):
            # Each missing node adds one fresh node per anchor, not one total.
            for anchor in ANCHOR_NODES:
                new_nodes.append([next_id, node_type])
                new_edges.append([anchor, next_id, edge_type])
                new_edges.append([next_id, anchor, edge_type + 1])
                next_id += 1

    nodes = _append_rows(nodes, new_nodes).astype(int)
    edges = _append_rows(edges, new_edges).astype(int)

    nodes.to_csv(
        os.path.join(motif_dir, f'nodes_b_{motif}.csv'), index=False
    )
    edges.to_csv(
        os.path.join(motif_dir, f'edges_b_{motif}.csv'), index=False
    )
