import argparse
import numpy as np
import random
from gensim.models import Word2Vec
import warnings
import seaborn as sns

import torch
from torch_geometric.utils import to_networkx, is_undirected

from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")


def alias_setup(probs):
    """
    Build the lookup tables used by alias sampling for a discrete distribution.

    Args:
        probs: sequence of outcome probabilities. The callers in this file
            pass values that have been normalized to sum to 1.

    Returns:
        A tuple ``(J, q)`` of NumPy arrays of length ``len(probs)``. After a
        column ``k`` is picked uniformly, ``q[k]`` is the probability of
        returning ``k`` itself and ``J[k]`` is the outcome returned otherwise.
    """
    K = len(probs)
    q = np.zeros(K)
    # ``np.int`` was a plain alias of the builtin ``int`` and no longer exists
    # in NumPy >= 1.24; the builtin selects the same default integer dtype.
    J = np.zeros(K, dtype=int)

    # Split the columns by whether their scaled mass is below or at/above 1.
    smaller = []
    larger = []
    for kk, prob in enumerate(probs):
        q[kk] = K * prob
        if q[kk] < 1.0:
            smaller.append(kk)
        else:
            larger.append(kk)

    # Pair each under-full column with an over-full one: the over-full column
    # becomes the alias and gives up exactly the mass needed to fill the gap.
    while smaller and larger:
        small = smaller.pop()
        large = larger.pop()
        J[small] = large
        q[large] = q[large] + q[small] - 1.0
        # The donor is re-queued according to the mass it has left.
        if q[large] < 1.0:
            smaller.append(large)
        else:
            larger.append(large)
    return J, q


def alias_draw(J, q):
    """
    Sample one outcome from a non-uniform discrete distribution via alias
    sampling, using the tables ``J`` and ``q``.

    A column is chosen uniformly among the ``len(J)`` entries; a second
    uniform draw then decides between that column's own index and the
    alias stored for it in ``J``.
    """
    n_columns = len(J)
    column = int(np.floor(np.random.rand() * n_columns))
    keep_column = np.random.rand() < q[column]
    return column if keep_column else J[column]


class Graph():
    """
    Biased random-walk generator over a graph whose edges carry a
    ``'weight'`` attribute.

    ``p`` divides the weight of stepping back to the previous node and ``q``
    divides the weight of stepping to a node that has no edge to the previous
    node. ``preprocess_transition_probs`` must be called before walking: it
    creates the ``alias_nodes`` / ``alias_edges`` tables the walks read.
    """

    def __init__(self, nx_G, is_directed, p, q):
        self.G, self.is_directed = nx_G, is_directed
        self.p, self.q = p, q

    def node2vec_walk(self, walk_length, start_node):
        '''
        Generate a single random walk of at most walk_length nodes that
        begins at start_node. The walk ends early at a node with no neighbors.
        '''
        graph = self.G
        node_tables = self.alias_nodes
        edge_tables = self.alias_edges
        path = [start_node]

        while len(path) < walk_length:
            here = path[-1]
            candidates = sorted(graph.neighbors(here))
            if not candidates:
                # Dead end: nothing to step to.
                break
            if len(path) == 1:
                # First step has no previous node, so use the per-node table.
                table = node_tables[here]
            else:
                # Later steps are conditioned on the edge just traversed.
                table = edge_tables[(path[-2], here)]
            path.append(candidates[alias_draw(table[0], table[1])])
        return path

    def simulate_walks(self, num_walks, walk_length):
        '''
        Run num_walks passes over the graph; each pass shuffles the node
        order and starts one walk from every node.
        '''
        start_order = list(self.G.nodes())
        collected = []
        for _ in range(num_walks):
            random.shuffle(start_order)
            for start in start_order:
                collected.append(
                    self.node2vec_walk(walk_length=walk_length, start_node=start)
                )
        return collected

    def get_alias_edge(self, src, dst):
        '''
        Build the alias tables for leaving dst after arriving from src.
        '''
        graph = self.G
        return_param = self.p
        inout_param = self.q

        biased = []
        for candidate in sorted(graph.neighbors(dst)):
            weight = graph[dst][candidate]['weight']
            if candidate == src:
                # Going straight back to where we came from.
                biased.append(weight / return_param)
            elif graph.has_edge(candidate, src):
                # Candidate is also connected to the previous node.
                biased.append(weight)
            else:
                biased.append(weight / inout_param)

        total = sum(biased)
        return alias_setup([float(value) / total for value in biased])

    def preprocess_transition_probs(self):
        '''
        Precompute the alias tables for every node (first step of a walk)
        and every edge (subsequent steps), and store them on the instance.
        '''
        graph = self.G
        add_reverse = not self.is_directed

        node_tables = {}
        for node in graph.nodes():
            weights = [
                graph[node][nbr]['weight'] for nbr in sorted(graph.neighbors(node))
            ]
            total = sum(weights)
            node_tables[node] = alias_setup(
                [float(weight) / total for weight in weights]
            )

        edge_tables = {}
        for edge in graph.edges():
            tail, head = edge[0], edge[1]
            edge_tables[edge] = self.get_alias_edge(tail, head)
            if add_reverse:
                # Undirected input: the walk may traverse the edge either way.
                edge_tables[(head, tail)] = self.get_alias_edge(head, tail)

        self.alias_nodes = node_tables
        self.alias_edges = edge_tables


def parse_args():
    """
    Parse the command-line options of this script and return the resulting
    ``argparse.Namespace``.

    Options are registered in the order listed below: run setup, data split
    and linear-evaluator settings, then random-walk / Word2Vec settings.
    """
    option_table = [
        # Run setup.
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='Cora')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--runs', dict(type=int, default=10, help='number of runs.')),
        # Data split and linear evaluator.
        ('--fix_split', dict(action='store_true')),
        ('--train_rate', dict(type=float, default=0.6, help='train set rate.')),
        ('--val_rate', dict(type=float, default=0.2, help='val set rate.')),
        ('--lr2', dict(type=float, default=0.01,
                       help='Learning rate of linear evaluator.')),
        ('--wd2', dict(type=float, default=0.0,
                       help='Weight decay of linear evaluator.')),
        # Random walks and Word2Vec.
        ('--dimensions', dict(type=int, default=128,
                              help='Number of dimensions. Default is 128.')),
        ('--walk_length', dict(type=int, default=80,
                               help='Length of walk per source. Default is 80.')),
        ('--num_walks', dict(type=int, default=10,
                             help='Number of walks per source. Default is 10.')),
        ('--window_size', dict(type=int, default=10,
                               help='Context size for optimization. Default is 10.')),
        ('--iter', dict(default=1, type=int, help='Number of epochs in SGD')),
        ('--workers', dict(type=int, default=8,
                           help='Number of parallel workers. Default is 8.')),
        ('--p', dict(type=float, default=1,
                     help='Return hyperparameter. Default is 1.')),
        ('--q', dict(type=float, default=1,
                     help='Inout hyperparameter. Default is 1.')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def unsupervised_learning(data, args, device):
    """
    Learn node embeddings from random walks over ``data`` with a skip-gram
    Word2Vec model.

    The graph is converted with ``to_networkx``, every edge is given weight 1,
    walks are generated by ``Graph`` using ``args.p`` / ``args.q`` /
    ``args.num_walks`` / ``args.walk_length``, and Word2Vec is trained on them.

    Returns a tensor with one row of ``args.dimensions`` values per node,
    indexed by ``int(node)``, moved to ``device``.
    """
    nx_G = to_networkx(data)
    # All edges are treated as equally weighted.
    for src, dst in nx_G.edges():
        nx_G[src][dst]['weight'] = 1

    directed = not is_undirected(data.edge_index)
    walker = Graph(nx_G, directed, args.p, args.q)
    walker.preprocess_transition_probs()
    raw_walks = walker.simulate_walks(args.num_walks, args.walk_length)

    # Word2Vec works on string tokens, so node ids are stringified.
    sentences = [list(map(str, walk)) for walk in raw_walks]

    model = Word2Vec(
        sentences=sentences,
        vector_size=args.dimensions,
        window=args.window_size,
        min_count=0,
        sg=1,
        workers=args.workers,
        epochs=args.iter,
    )

    # Copy each learned vector into the row matching the node's integer id.
    embeds = torch.zeros((len(nx_G.nodes()), args.dimensions))
    for node in nx_G.nodes():
        vector = model.wv[str(node)]
        embeds[int(node)] = torch.tensor(vector)
    return embeds.to(device)



if __name__ == "__main__":
    # Script entry point: train the walk-based embeddings once, then score
    # them with a linear evaluator over several data splits and report the
    # mean test accuracy with a bootstrap uncertainty.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Every run consumes one entry of SEEDS. Reject an oversized --runs here,
    # before the embedding training, instead of hitting an IndexError in the
    # evaluation loop after all the earlier work has been done.
    if args.runs > len(SEEDS):
        raise ValueError(
            f"--runs={args.runs} exceeds the {len(SEEDS)} fixed split seeds available"
        )
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    # Sizes handed to random_splits: training nodes per class and the number
    # of validation nodes, both derived from the requested rates.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    # The embeddings do not depend on the split, so they are computed once.
    embeds = unsupervised_learning(data=data, args=args, device=device)

    if args.dataset not in ['Computers', 'Photo']:
        # Keep the dataset's own masks; the loop below overwrites data.*_mask.
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]

        if args.fix_split:
            if args.dataset in ['Computers', 'Photo']:  # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
                data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
            else:
                # NOTE(review): indexing with [:, RP] assumes the stored masks
                # are 2-D with one column per run - confirm against the loader.
                data.train_mask, data.val_mask, data.test_mask = full_train_mask[:, RP], full_val_mask[:, RP], full_test_mask[:, RP]
        else:
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Uncertainty is the larger distance from the mean to either end of the
    # 95% interval of 1000 bootstrap resamples of the per-run accuracies.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')