import argparse
import numpy as np
import random
from gensim.models import Word2Vec
import warnings

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx, is_undirected
from evaluate_embedding import evaluate_embedding
warnings.filterwarnings("ignore")


def alias_setup(probs):
    '''
    Build the alias tables used to sample from a discrete distribution.

    probs: sequence of per-outcome probabilities (the callers in this file
        pass values normalised to sum to 1).

    Returns a pair (J, q) of arrays, each of length len(probs):
        J -- integer array; J[k] is the alias outcome stored in column k.
        q -- float array; q[k] is the threshold below which a draw that
             lands in column k keeps k instead of switching to J[k].
    The pair is consumed by alias_draw.
    '''
    K = len(probs)
    q = np.zeros(K)
    # Use the builtin int: the np.int alias was deprecated in NumPy 1.20 and
    # removed in 1.24, where np.int raises AttributeError.
    J = np.zeros(K, dtype=int)

    # Scale each probability by K and split columns into those holding less
    # than one unit of mass and those holding at least one unit.
    smaller = []
    larger = []
    for kk, prob in enumerate(probs):
        q[kk] = K * prob
        if q[kk] < 1.0:
            smaller.append(kk)
        else:
            larger.append(kk)

    # Pair an under-full column with an over-full one: the over-full column
    # becomes the alias of the under-full one and donates the missing mass.
    while smaller and larger:
        small = smaller.pop()
        large = larger.pop()
        J[small] = large
        q[large] = q[large] + q[small] - 1.0
        # The donor may itself have dropped below one unit; re-file it.
        if q[large] < 1.0:
            smaller.append(large)
        else:
            larger.append(large)
    return J, q


def alias_draw(J, q):
    '''
    Sample one outcome index from alias tables (J, q).

    J: per-column alias indices; q: per-column acceptance thresholds, as
    returned by alias_setup. A column is chosen uniformly at random; a second
    uniform draw then decides between the column's own index and its alias.
    '''
    num_columns = len(J)
    # First random number: which column we land in.
    column = int(np.floor(np.random.rand() * num_columns))
    # Second random number: keep the column itself or fall through to its alias.
    keep_column = np.random.rand() < q[column]
    return column if keep_column else J[column]


class Graph():
    '''
    node2vec random-walk sampler built around a weighted graph.

    The wrapped graph object must provide nodes(), edges(), neighbors(n),
    has_edge(u, v) and G[u][v]['weight'] (node2vec() in this file passes the
    result of to_networkx). Call preprocess_transition_probs() before
    generating walks: it creates the alias tables the walk methods read.
    '''

    def __init__(self, nx_G, is_directed, p, q):
        '''
        Store the graph, its directedness flag, and the two bias parameters:
        p divides the weight of stepping back to the previous node, q divides
        the weight of stepping to a node not adjacent to the previous node.
        '''
        self.G = nx_G
        self.is_directed = is_directed
        self.p = p
        self.q = q

    def node2vec_walk(self, walk_length, start_node):
        '''
        Generate a single biased walk of at most walk_length nodes that begins
        at start_node. The walk ends early at a node without neighbors.
        '''
        graph = self.G
        node_tables = self.alias_nodes
        edge_tables = self.alias_edges
        path = [start_node]

        while len(path) < walk_length:
            current = path[-1]
            candidates = sorted(graph.neighbors(current))
            if not candidates:
                # Dead end: nothing to step to.
                break
            if len(path) == 1:
                # No previous node yet, so use the first-order table.
                J, q = node_tables[current]
            else:
                # Second-order step, keyed by the edge just traversed.
                J, q = edge_tables[(path[-2], current)]
            path.append(candidates[alias_draw(J, q)])
        return path

    def simulate_walks(self, num_walks, walk_length):
        '''
        Run num_walks passes over the graph; each pass shuffles the node list
        and starts one walk from every node. Returns the list of all walks.
        '''
        start_nodes = list(self.G.nodes())
        collected = []
        for _ in range(num_walks):
            random.shuffle(start_nodes)
            collected.extend(
                self.node2vec_walk(walk_length=walk_length, start_node=start)
                for start in start_nodes
            )
        return collected

    def get_alias_edge(self, src, dst):
        '''
        Build the alias tables for the step taken after traversing src -> dst.

        Each neighbor of dst (in sorted order) gets its edge weight divided by
        p if it is src, left as-is if it has an edge to src, and divided by q
        otherwise; the weights are then normalised and fed to alias_setup.
        '''
        graph = self.G
        return_param = self.p
        inout_param = self.q

        weights = []
        for candidate in sorted(graph.neighbors(dst)):
            weight = graph[dst][candidate]['weight']
            if candidate == src:
                weight = weight / return_param
            elif not graph.has_edge(candidate, src):
                weight = weight / inout_param
            weights.append(weight)

        total = sum(weights)
        return alias_setup([float(weight) / total for weight in weights])

    def preprocess_transition_probs(self):
        '''
        Precompute the alias tables that drive the walks and store them on the
        instance as alias_nodes (per node) and alias_edges (per ordered edge).
        '''
        graph = self.G

        # First-order tables: edge weights to sorted neighbors, normalised.
        node_tables = {}
        for node in graph.nodes():
            weights = [graph[node][nbr]['weight'] for nbr in sorted(graph.neighbors(node))]
            total = sum(weights)
            node_tables[node] = alias_setup([float(weight) / total for weight in weights])

        # Second-order tables: one per listed edge, plus the reverse direction
        # when the graph is flagged as undirected.
        add_reverse = not self.is_directed
        edge_tables = {}
        for head, tail in graph.edges():
            edge_tables[(head, tail)] = self.get_alias_edge(head, tail)
            if add_reverse:
                edge_tables[(tail, head)] = self.get_alias_edge(tail, head)

        self.alias_nodes = node_tables
        self.alias_edges = edge_tables


def parse_args():
    '''
    Parse the command-line options (read from sys.argv) and return the
    resulting argparse namespace.
    '''
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='MUTAG')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--dimensions', dict(type=int, default=128,
                              help='Number of dimensions. Default is 128.')),
        ('--walk_length', dict(type=int, default=80,
                               help='Length of walk per source. Default is 80.')),
        ('--num_walks', dict(type=int, default=10,
                             help='Number of walks per source. Default is 10.')),
        ('--window_size', dict(type=int, default=10,
                               help='Context size for optimization. Default is 10.')),
        ('--iter', dict(default=1, type=int, help='Number of epochs in SGD')),
        ('--workers', dict(type=int, default=8,
                           help='Number of parallel workers. Default is 8.')),
        ('--p', dict(type=float, default=1,
                     help='Return hyperparameter. Default is 1.')),
        ('--q', dict(type=float, default=1,
                     help='Inout hyperparameter. Default is 1.')),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def node2vec(data):
    '''
    Compute one embedding vector for a single graph sample.

    The graph is converted with to_networkx, every edge gets weight 1, biased
    walks are sampled and used as sentences to train a skip-gram Word2Vec
    model, and the per-node vectors are averaged into a single tensor of
    length args.dimensions. Hyperparameters come from the module-level args.
    '''
    graph = to_networkx(data)
    # Unit weight on every edge.
    for head, tail in graph.edges():
        graph[head][tail]['weight'] = 1

    directed = not is_undirected(data.edge_index)
    walker = Graph(graph, directed, args.p, args.q)
    walker.preprocess_transition_probs()

    # Word2Vec expects string tokens, so stringify the node ids.
    raw_walks = walker.simulate_walks(args.num_walks, args.walk_length)
    sentences = [[str(node) for node in walk] for walk in raw_walks]

    model = Word2Vec(
        sentences=sentences,
        vector_size=args.dimensions,
        window=args.window_size,
        min_count=0,
        sg=1,
        workers=args.workers,
        epochs=args.iter,
    )

    # One row per node, indexed by the integer node id, then mean-pooled.
    node_vectors = torch.zeros((len(graph.nodes()), args.dimensions))
    for node in graph.nodes():
        node_vectors[int(node)] = torch.tensor(model.wv[str(node)])
    return node_vectors.mean(dim=0)


if __name__ == "__main__":
    # Script entry point: embed every graph of the chosen TUDataset with
    # node2vec and report the accuracy returned by evaluate_embedding.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    # Seed all RNGs in play. random drives the node shuffling in
    # Graph.simulate_walks; np.random drives alias_draw, so without seeding it
    # the sampled walks differ between runs even for a fixed --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # NOTE(review): device is computed but not used anywhere in this block.
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')
    dataset = TUDataset('data', name=args.dataset)
    
    # One embedding row and one label per graph in the dataset.
    embeds = np.zeros((len(dataset), args.dimensions))
    labels = np.zeros((len(dataset), 1))
    for idx, data in enumerate(dataset):
        vectors = node2vec(data)
        embeds[idx] = np.array(vectors)
        labels[idx] = data.y
    acc_mean, acc_std = evaluate_embedding(embeds, labels.ravel())
    print(f'test acc: {acc_mean:.4f} +- {acc_std:.4f}')