import json
import math
import random
import argparse
import scipy.sparse
import numpy as np
import networkx as nx
from sklearn.preprocessing import StandardScaler

# Command-line interface. Parsed at import time; ``main`` reads the
# resulting ``args`` namespace.
parser = argparse.ArgumentParser(
    description='generate dataset from network in json')
parser.add_argument('--json', type=str, required=True,
                    help='sub-network file, in .json format')
parser.add_argument('--k', type=int, required=True,
                    help='number of duplicated sub-networks')
parser.add_argument(
    '--c', type=float,
    help='ratio of edges between sub-networks to edges within sub-networks '
         '(required when k > 1)')
parser.add_argument('--n_samples', type=int, required=True,
                    help='number of samples for the dataset')
parser.add_argument('--seed', type=int, help='random seed')
parser.add_argument(
    '--nx', type=str, required=True,
    help='output whole network (without weight), in .json (node link data) format')
parser.add_argument(
    '--npz', type=str, required=True,
    help='output whole network (with weight), in .npz (scipy.sparse.csr_matrix object) format')
parser.add_argument('--npy', type=str, required=True,
                    help='output data, in .npy (numpy.ndarray object) format')
args = parser.parse_args()

# Reject settings the pipeline cannot use before any work is done, so the
# user gets a usage message instead of a traceback from deep inside it.
if args.k < 1:
    parser.error('--k must be at least 1')
if args.n_samples < 1:
    parser.error('--n_samples must be at least 1')
if args.k > 1 and args.c is None:
    parser.error('--c is required when --k > 1')
if args.c is not None and args.c < 0:
    parser.error('--c must be non-negative')


def load_json2nx(j_path):
    with open(j_path, 'r') as f:
        j_bn = json.load(f)

    nx_bn = nx.DiGraph()
    nx_bn.add_nodes_from(j_bn.keys())

    for k, v in j_bn.items():
        assert k == v['node']

        if type(v['children']) == list:
            for c in v['children']:
                assert not nx.has_path(nx_bn, c, k)
                nx_bn.add_edge(k, c)
        else:
            assert type(v['children']) == str
            assert not nx.has_path(nx_bn, v['children'], k)
            nx_bn.add_edge(k, v['children'])

    # check
    for k, v in j_bn.items():
        assert k == v['node']
        if type(v['parents']) == list:
            for p in v['parents']:
                if not nx_bn.has_edge(p, k):
                    print(p, k, v['parents'])
                assert nx_bn.has_edge(p, k)
        else:
            assert type(v['parents']) == str
            assert nx_bn.has_edge(v['parents'], k)

    return nx_bn


def dupliate_and_add_edges(nx_sub, k, c, max_tries=100000):
    """Tile ``k`` copies of a sub-network and wire them together with random edges.

    Copy ``ki`` (numbered ``1..k``) of sub-network node ``v`` is named
    ``'{v}_{ki}'``, and every sub-network edge is reproduced inside each
    copy. When ``k > 1``, ``ceil(c * k * |E_sub|)`` extra edges are then
    added. Each extra edge goes from a random node of one copy to a random
    node of a different copy. Candidates that duplicate an existing edge or
    would close a cycle are redrawn, so an acyclic ``nx_sub`` yields an
    acyclic result.

    Args:
        nx_sub: sub-network to duplicate (an ``nx.DiGraph``).
        k: number of copies.
        c: ratio of cross-copy edges to within-copy edges. It is ignored
            when ``k <= 1``.
        max_tries: maximum number of candidates drawn for any single extra
            edge before giving up. ``None`` retries forever, which was the
            previous behaviour and hangs when ``c`` is too large for any
            acceptable edge to remain.

    Returns:
        The combined ``nx.DiGraph``.

    Raises:
        ValueError: if ``k > 1`` and ``c`` is negative.
        RuntimeError: if no acceptable extra edge is found within
            ``max_tries`` draws.
    """
    def copy_name(node, ki):
        return '{}_{}'.format(node, ki)

    # Hoisted out of the sampling loop. They have the same contents and
    # order as before, so seeded random draws are unchanged.
    copies = list(range(1, k + 1))
    sub_nodes = list(nx_sub.nodes)

    nx_rnt = nx.DiGraph()
    nx_rnt.add_nodes_from(copy_name(v, ki) for ki in copies for v in sub_nodes)
    nx_rnt.add_edges_from((copy_name(u, ki), copy_name(v, ki))
                          for ki in copies for u, v in nx_sub.edges)

    assert len(nx_rnt.nodes) == k * len(nx_sub.nodes)
    assert len(nx_rnt.edges) == k * len(nx_sub.edges)

    if k <= 1:
        return nx_rnt
    if c < 0:
        raise ValueError('c must be non-negative, got {!r}'.format(c))

    extra_edges = math.ceil(c * len(nx_rnt.edges))
    for _ in range(extra_edges):
        rejected = 0
        while True:
            # Two distinct copies, then any node of each. u may equal v,
            # because the two endpoints live in different copies.
            k1, k2 = random.sample(copies, k=2)
            u, v = random.choices(sub_nodes, k=2)
            src, dst = copy_name(u, k1), copy_name(v, k2)
            # src -> dst closes a cycle iff dst already reaches src.
            if not nx_rnt.has_edge(src, dst) and not nx.has_path(nx_rnt, dst, src):
                break
            rejected += 1
            if max_tries is not None and rejected >= max_tries:
                raise RuntimeError(
                    'no acyclic, non-duplicate cross-copy edge found after {} draws '
                    '(c={!r} may be too large)'.format(rejected, c))
        nx_rnt.add_edge(src, dst)

    assert len(nx_rnt.edges) == k * len(nx_sub.edges) + extra_edges
    return nx_rnt


def network2datasets(nx_graph, n_samples):
    """Weight the edges of ``nx_graph`` at random and sample a dataset from it.

    Nodes are indexed by their position in ``sorted(nx_graph.nodes)``. Every
    edge weight is drawn uniformly from [-1, -0.5] or [0.5, 1], with the
    sign chosen at random. Each node is generated in topological order. Its
    value is zero-mean Gaussian noise, whose variance is drawn once per node
    from uniform(0, 1), plus the weighted sum of its parents. Every column
    is then standardised.

    Args:
        nx_graph: the acyclic ``nx.DiGraph`` to sample from.
        n_samples: number of rows to generate.

    Returns:
        ``(adj_graph, sim_data)``. ``adj_graph`` is the weighted sparse
        adjacency matrix, where entry ``[parent, child]`` holds the edge
        weight. ``sim_data`` is the standardised ``(n_samples, n_nodes)``
        array.
    """
    order = sorted(nx_graph.nodes)
    position = dict(zip(order, range(len(order))))

    adj_graph = nx.adjacency_matrix(nx_graph, nodelist=order, dtype=np.float64)
    # Magnitudes are drawn before signs, which keeps the NumPy RNG stream
    # in a fixed order.
    n_edges = adj_graph.nnz
    magnitudes = np.random.uniform(0.5, 1, n_edges)
    signs = np.random.choice([-1, 1], n_edges)
    adj_graph[adj_graph.nonzero()] = magnitudes * signs

    sim_data = np.zeros((n_samples, len(order)))
    for name in nx.topological_sort(nx_graph):
        col = position[name]
        noise_std = np.sqrt(np.random.uniform(0, 1))
        sim_data[:, col] = np.random.normal(0, noise_std, n_samples)
        # Parents come earlier in topological order, so their columns are final.
        for parent_name in nx_graph.predecessors(name):
            row = position[parent_name]
            sim_data[:, col] += adj_graph[row, col] * sim_data[:, row]

    return adj_graph, StandardScaler().fit_transform(sim_data)


def main():
    """Run the whole pipeline using the module-level ``args`` namespace.

    It loads the sub-network, duplicates and interconnects it, simulates
    data, and writes the three output files.
    """
    # The stdlib RNG drives the cross-copy edge sampling. NumPy's RNG drives
    # the edge weights and the simulated data.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(args.seed)

    whole_net = dupliate_and_add_edges(load_json2nx(args.json), args.k, args.c)
    weights, samples = network2datasets(whole_net, args.n_samples)

    # Unweighted structure, as node-link JSON.
    with open(args.nx, 'w') as nx_out:
        json.dump(nx.node_link_data(whole_net), nx_out, indent='\t')

    # Weighted adjacency matrix, as a sparse .npz archive.
    scipy.sparse.save_npz(args.npz, weights)

    # Standardised samples. The file handle is passed to np.save, so the
    # path is used exactly as given.
    with open(args.npy, 'wb') as npy_out:
        np.save(npy_out, samples)


if __name__ == '__main__':
    main()
