import pandas as pd
import os
import argparse
import random
import numpy as np
from seed import set_seed

set_seed()


def add_apa(node_df: pd.DataFrame, edge_df: pd.DataFrame, p: float = 0.56):
    """With probability ``p``, add one extra paper node to the motif.

    The new paper is connected in both directions to every existing
    non-paper node. Since nodes 0 and 1 are authors, this in particular
    creates an author-paper-author path between them.

    Args:
        node_df (pd.DataFrame): one row per node; columns are read by
            position as (nid, ntype). The new node gets id
            ``len(node_df)``, so ids are assumed to be 0..len-1.
        edge_df (pd.DataFrame): one row per directed edge; columns are
            (src_nid, dst_nid, etype) by position.
        p (float): probability of applying the augmentation.

    Returns:
        tuple: ``(node_df, edge_df)``. The inputs themselves are returned
        when the augmentation is skipped; otherwise new frames with the
        extra rows appended and a fresh 0..n-1 index.

    Ntypes:
        a: 0
        p: 1
    Etypes:
        a->p: 1
        p->a: 0
    """
    if random.uniform(0, 1) >= p:
        return node_df, edge_df

    COL = {'nid': 0, 'ntype': 1}
    NTYPES = {'author': 0, 'paper': 1, 'conference': 2, 'term': 3}
    NTYPE_STR = {v: k for k, v in NTYPES.items()}
    ETYPES = {
        ('paper', 'author'): 0,  # 'p->a': 0,
        ('author', 'paper'): 1,  # 'a->p': 1,
        ('paper', 'conference'): 2,  # 'p->c': 2,
        ('conference', 'paper'): 3,  # 'c->p': 3,
        ('paper', 'term'): 4,  # 'p->t': 4,
        ('term', 'paper'): 5,  # 't->p': 5,
    }

    new_paper_node_id = len(node_df)
    # Reuse the caller's column labels so the appended row lines up with the
    # existing columns even when they are not labelled 0 and 1 (otherwise
    # concat would create extra NaN-padded columns).
    new_node = pd.DataFrame(
        [[new_paper_node_id, NTYPES['paper']]], columns=node_df.columns
    )
    node_df = pd.concat([node_df, new_node], axis=0).reset_index(drop=True)

    # Sanity checks on the base motif layout: the first four nodes must be
    # authors 0, 1 followed by papers 2, 3.
    assert node_df.iloc[:4, COL['nid']].tolist() == [0, 1, 2, 3]
    assert node_df.iloc[:4, COL['ntype']].tolist() == [0, 0, 1, 1]

    # Connect the added node to all the existing nodes except paper nodes
    # (which also skips the new paper itself).
    node_ids = node_df.iloc[:, COL['nid']].tolist()
    node_types = node_df.iloc[:, COL['ntype']].tolist()
    new_edge_rows = []
    for node_id_, ntype in zip(node_ids, node_types):
        if ntype == NTYPES['paper']:
            continue
        type_name = NTYPE_STR[ntype]
        new_edge_rows.append(
            [node_id_, new_paper_node_id, ETYPES[(type_name, 'paper')]]
        )
        new_edge_rows.append(
            [new_paper_node_id, node_id_, ETYPES[('paper', type_name)]]
        )

    # Same alignment concern as for the node row; fall back to default
    # integer labels when edge_df does not have exactly three columns
    # (e.g. an empty frame).
    edge_columns = edge_df.columns if edge_df.shape[1] == 3 else None
    new_edges = pd.DataFrame(new_edge_rows, columns=edge_columns)
    edge_df = pd.concat([edge_df, new_edges], axis=0).reset_index(drop=True)
    return node_df, edge_df


# Command-line interface: how many motifs to generate, how many distinct
# labels to cycle through, and where to write the CSV files.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--motif_numbers", nargs='?', type=int, default=1000,
    help="number of motifs"
)
parser.add_argument(
    "--label_numbers", nargs='?', type=int, default=3, help="number of labels"
)
parser.add_argument(
    "--outdir", nargs='?', default='./motifs', help="output dir"
)
args = parser.parse_args()

outdir = args.outdir
motif_numbers = args.motif_numbers
label_numbers = args.label_numbers

# makedirs also creates missing parent directories, and exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(outdir, exist_ok=True)

# Node types: 0 as author, 1 as paper, 2 as conference, 3 as term.
# Types of the extra nodes are drawn without replacement from this pool, so
# a motif gets at most one conference node and at most five term nodes.
extra_ntype_pool = [2, 3, 3, 3, 3, 3]

for motif_id in range(motif_numbers):
    # Number of extra nodes: a N(2, 1) draw truncated to an int, at least 2.
    # It is also capped at the pool size, because sampling without
    # replacement raises ValueError when asked for more items than exist.
    n_nodes_number = int(np.random.normal(2, 1, 1)[0])
    n_nodes_number = min(max(n_nodes_number, 2), len(extra_ntype_pool))
    node_set = np.random.choice(
        extra_ntype_pool, n_nodes_number, replace=False
    ).tolist()

    # Base motif: authors 0 and 1, papers 2 and 3. Author 0 is linked to
    # paper 2 and author 1 to paper 3 (etype 1 = a->p, etype 0 = p->a).
    node_rows = [[0, 0], [1, 0], [2, 1], [3, 1]]
    edge_rows = [
        [0, 2, 1], [2, 0, 0],
        [1, 3, 1], [3, 1, 0],
    ]

    # Every extra node is attached to both papers. For node type t the
    # paper->node etype is (t - 1) * 2 and the reverse edge is that plus one.
    for node_id, ntype in enumerate(node_set, start=len(node_rows)):
        edge_type = (ntype - 1) * 2
        node_rows.append([node_id, ntype])
        edge_rows.extend([
            [2, node_id, edge_type], [3, node_id, edge_type],
            [node_id, 2, edge_type + 1],
            [node_id, 3, edge_type + 1],
        ])

    nodes = pd.DataFrame(node_rows)
    edges = pd.DataFrame(edge_rows)

    # Randomly add one more paper connected to all non-paper nodes.
    nodes, edges = add_apa(nodes, edges)

    nodes.to_csv(
        os.path.join(outdir, 'nodes_' + str(motif_id) + '.csv'), index=False
    )
    edges.to_csv(
        os.path.join(outdir, 'edges_' + str(motif_id) + '.csv'), index=False
    )

    # Both author nodes (ids 0 and 1, ntype 0) get the same label, cycling
    # through 0..label_numbers-1 by motif index.
    label = motif_id % label_numbers
    labels = pd.DataFrame([[0, 0, label], [1, 0, label]])
    labels.to_csv(
        os.path.join(outdir, 'labels_' + str(motif_id) + '.csv'), index=False
    )
