import argparse
import random
import numpy as np
import torch
from gensim.models import doc2vec
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx

from evaluate_embedding import evaluate_embedding


def setup_seed(seed):
    """Seed every random number generator imported by this script.

    Seeds the stdlib ``random`` module (used by ``random_walk``) as well as
    NumPy and PyTorch, so that code relying on any of the three starts from
    the same state for a given ``seed``.

    Args:
        seed: integer seed; NumPy requires it to lie in ``[0, 2**32 - 1]``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def arr2str(arr):
    """Return a new list holding ``str(item)`` for every item of ``arr``.

    Args:
        arr: any iterable.

    Returns:
        A list of strings in the same order as the input.
    """
    return list(map(str, arr))

    
def neighborhood_embedding(nx_g, idx):
    """Embed one graph by training Doc2Vec on a random walk over it.

    A single random walk of up to ``args.walk_length`` nodes is drawn from
    ``nx_g``. The visited node ids, converted to strings, are the words of
    one tagged document (tag 0). A fresh Doc2Vec model is trained on that
    document and the vector stored under tag 0 is returned.

    Reads the module-level ``args`` namespace (``walk_length``, ``d``,
    ``window_size``, ``iter``, ``model``), so it can only be called after
    ``args`` has been assigned.

    Args:
        nx_g: graph object handed to ``random_walk``; it must provide
            ``nodes()`` and ``neighbors(node)``.
        idx: position of the graph in the dataset. Kept for call
            compatibility; not used in the computation.

    Returns:
        ``model.dv[0]``, the learned document vector for the walk.
    """
    walk = random_walk(nx_g, args.walk_length)
    tokens = arr2str(walk)
    # The whole walk is ONE document whose words are node-id tokens.
    # Passing a bare string as ``words`` (one node id per document) would
    # make gensim iterate it character by character, and dv[0] would then
    # describe only the first step of the walk instead of the walk itself.
    sentences = [doc2vec.TaggedDocument(tokens, [0])]
    use_dm = 1 if args.model == 'dm' else 0
    model = doc2vec.Doc2Vec(
        vector_size=args.d,
        window=args.window_size,
        min_count=1,
        workers=4,
        epochs=args.iter,
        dm=use_dm,
    )
    model.build_vocab(sentences)
    model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
    return model.dv[0]


def random_walk(G, walkSize):
    """Draw a uniform random walk over ``G``.

    The walk starts at a node chosen uniformly from ``G.nodes()`` and each
    step moves to a uniformly chosen element of ``G.neighbors(current)``.

    Args:
        G: graph object providing ``nodes()`` and ``neighbors(node)``.
        walkSize: maximum number of nodes in the returned walk.

    Returns:
        A list of visited nodes. It has ``walkSize`` entries unless the walk
        reaches a node without neighbours, in which case it ends there. An
        empty list is returned when ``walkSize <= 0`` or ``G`` has no nodes.
    """
    walkList = []
    nodes = list(G.nodes())
    if walkSize <= 0 or not nodes:
        return walkList

    curNode = random.choice(nodes)
    walkList.append(curNode)
    # Only look for a successor when another step is actually required, so
    # a full-length walk never fails because of its last node.
    while len(walkList) < walkSize:
        neighbors = list(G.neighbors(curNode))
        if not neighbors:
            # Dead end (isolated node or sink): stop instead of raising.
            break
        curNode = random.choice(neighbors)
        walkList.append(curNode)
    return walkList


def parse_args(argv=None):
    """Parse the command-line options of the sub2vec script.

    Args:
        argv: optional list of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is what a
            call without arguments has always done.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``, ``seed``,
        ``walk_length``, ``d``, ``iter``, ``window_size`` and ``model``.
    """
    parser = argparse.ArgumentParser(description="sub2vec.")
    parser.add_argument('--dataset', default='MUTAG', type=str)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--walk_length', default=100, type=int,
                        help='length of random walk on each subgraph')
    parser.add_argument('--d', default=128, type=int,
                        help='dimension of learned features for each subgraph.')
    parser.add_argument('--iter', default=10, type=int,
                        help='training iterations')
    parser.add_argument('--window_size', default=2, type=int,
                        help='Window size of the model.')
    parser.add_argument('--model', default='dm', choices=['dbon', 'dm'],
                        help='models for learning vectors SV-DM (dm) or SV-DBON (dbon).')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # `args` has to remain a module-level name: neighborhood_embedding reads
    # it as a global rather than receiving it as a parameter.
    args = parse_args()
    print(args)
    print('---------------------')
    setup_seed(args.seed)
    # The device is selected here but nothing else visible in this file uses it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = TUDataset('data', name=args.dataset)

    # One embedding row and one label row per graph in the dataset.
    num_graphs = len(dataset)
    embeds = np.zeros((num_graphs, args.d))
    labels = np.zeros((num_graphs, 1))
    for idx, data in enumerate(dataset):
        nx_G = to_networkx(data)
        # Give every edge a 'weight' attribute of 1.
        for src, dst in nx_G.edges():
            nx_G[src][dst]['weight'] = 1
        embeds[idx] = np.array(neighborhood_embedding(nx_G, idx))
        labels[idx] = data.y

    # Evaluate the embeddings against the flattened label vector.
    acc_mean, acc_std = evaluate_embedding(embeds, labels.ravel())
    print(f'test acc: {acc_mean:.4f} +- {acc_std:.4f}')
