import os
import random
import psutil
from multiprocessing import cpu_count
from six.moves import range
from gensim.models import Word2Vec

import argparse
import numpy as np
import warnings
import seaborn as sns
import torch
from torch_geometric.utils import to_networkx

from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")


# Best-effort CPU-affinity setup for the current process: request the CPU
# index list [0, cpu_count()). Two method names are tried in turn; if the
# Process object exposes neither, affinity is left untouched and the script
# continues silently.
p = psutil.Process(os.getpid())
try:
    # First attempt: the `set_cpu_affinity` method name
    # (presumably the spelling used by older psutil releases -- confirm).
    p.set_cpu_affinity(list(range(cpu_count())))
except AttributeError:
    try:
        # Fallback: the `cpu_affinity` method name, called with the same list.
        p.cpu_affinity(list(range(cpu_count())))
    except AttributeError:
        # Neither attribute exists on this Process object; skip pinning.
        pass


def random_walk(G, path_length, alpha=0, rand=random.Random(), start=None):
    """ Returns a truncated random walk as a list of node ids cast to ``str``.

        G: graph object exposing ``nodes()`` and ``neighbors(node)``.
        path_length: Length of the random walk (upper bound on the number of
            nodes; the walk stops early at a node without neighbors).
        alpha: probability of restarts (jumping back to the first node).
        rand: ``random.Random`` instance used for every random draw. The
            default instance is created once, when the function is defined,
            and is therefore shared by all calls that omit ``rand``.
        start: the start node of the random walk. When ``None``, a start
            node is drawn uniformly from ``G.nodes()``.
    """
    # Compare against None explicitly: a node id such as 0 is falsy but is
    # still a perfectly valid start node.
    if start is not None:
        path = [start]
    else:
        # Sampling is uniform w.r.t V, and not w.r.t E
        path = [rand.choice(list(G.nodes()))]

    while len(path) < path_length:
        cur = path[-1]
        # Sorted so the choice depends only on the RNG state, not on the
        # iteration order of the neighbor container.
        cur_nbrs = sorted(G.neighbors(cur))
        if not cur_nbrs:
            # Dead end: nothing to step to, return the truncated walk.
            break
        if rand.random() >= alpha:
            path.append(rand.choice(cur_nbrs))
        else:
            # Restart: jump back to the first node of the walk.
            path.append(path[0])
    return [str(node) for node in path]


def build_deepwalk_corpus(G, num_paths, path_length, alpha=0, rand=random.Random(0)):
    """Build the whole walk corpus in memory and return it as a list.

    Performs ``num_paths`` passes over the nodes of ``G``. Each pass
    reshuffles the node order with ``rand`` and then starts one
    ``random_walk`` of length ``path_length`` from every node.

    G: graph object exposing ``nodes()`` (forwarded to ``random_walk``).
    num_paths: number of passes, i.e. walks started per node.
    path_length: forwarded to ``random_walk``.
    alpha: restart probability, forwarded to ``random_walk``.
    rand: ``random.Random`` used for shuffling and walking. Note that the
        default instance is built once at definition time, so calls that
        omit ``rand`` share (and advance) the same generator state.
    """
    corpus = []
    visit_order = list(G.nodes())
    for _ in range(num_paths):
        rand.shuffle(visit_order)
        corpus.extend(
            random_walk(G, path_length, rand=rand, alpha=alpha, start=vertex)
            for vertex in visit_order
        )
    return corpus


def build_deepwalk_corpus_iter(G, num_paths, path_length, alpha=0, rand=random.Random(0)):
    """Lazily yield the walk corpus one walk at a time.

    Generator counterpart of ``build_deepwalk_corpus``: for each of the
    ``num_paths`` passes the node order is reshuffled with ``rand`` and one
    ``random_walk`` of length ``path_length`` is yielded per node. Nothing
    (including reading ``G.nodes()``) happens until iteration starts.

    The default ``rand`` is created once at definition time and is shared
    by every call that does not pass its own generator.
    """
    visit_order = list(G.nodes())
    for _ in range(num_paths):
        rand.shuffle(visit_order)
        yield from (
            random_walk(G, path_length, rand=rand, alpha=alpha, start=vertex)
            for vertex in visit_order
        )


def unsupervised_learning(data, args, device):
    """Learn one embedding per node from random walks fed to Word2Vec.

    data: graph object accepted by ``to_networkx``.
    args: namespace providing ``number_walks``, ``walk_length``, ``seed``,
        ``representation_size``, ``window_size`` and ``workers``.
    device: target device for the returned tensor.

    Returns a tensor with ``len(G.nodes())`` rows and
    ``args.representation_size`` columns, moved to ``device``. Row
    ``int(node)`` holds the vector learnt for that node, so node ids are
    assumed to be integers usable as row indices.
    """
    G = to_networkx(data)
    # Give every edge a unit weight attribute.
    for src, dst in G.edges():
        G[src][dst]['weight'] = 1

    # Walk corpus: seeded from args.seed, restarts disabled (alpha=0).
    corpus = build_deepwalk_corpus(
        G,
        num_paths=args.number_walks,
        path_length=args.walk_length,
        alpha=0,
        rand=random.Random(args.seed),
    )

    # min_count=0 keeps every node token in the vocabulary, so the lookup
    # below is defined for each node that occurs in the corpus.
    skipgram = Word2Vec(
        corpus,
        vector_size=args.representation_size,
        window=args.window_size,
        min_count=0,
        sg=1,
        hs=1,
        workers=args.workers,
    )

    embeds = torch.zeros((len(G.nodes()), args.representation_size))
    for vertex in G.nodes():
        # Walk tokens are str(node); rows are addressed by the integer id.
        embeds[int(vertex)] = torch.tensor(skipgram.wv[str(vertex)])
    return embeds.to(device)


def parse_args():
    """Parse the command line of the script and return the namespace.

    Options fall into three groups: general run settings (seed, dataset,
    device, number of runs), split / linear-evaluator settings, and the
    random-walk / embedding hyper-parameters. Options are registered in
    the order listed below.
    """
    option_table = [
        # General run settings.
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='Cora')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--runs', dict(type=int, default=10, help='number of runs.')),
        # Data split and linear evaluator.
        ('--fix_split', dict(action='store_true')),
        ('--train_rate', dict(type=float, default=0.6, help='train set rate.')),
        ('--val_rate', dict(type=float, default=0.2, help='val set rate.')),
        ('--lr2', dict(type=float, default=0.01,
                       help='Learning rate of linear evaluator.')),
        ('--wd2', dict(type=float, default=0.0,
                       help='Weight decay of linear evaluator.')),
        # Random-walk and embedding hyper-parameters.
        ('--number_walks', dict(
            default=10, type=int,
            help='Number of random walks to start at each node')),
        ('--representation_size', dict(
            default=64, type=int,
            help='Number of latent dimensions to learn for each node.')),
        ('--vertex_freq_degree', dict(
            default=False, action='store_true',
            help='Use vertex degree to estimate the frequency of nodes in the random walks. This option is faster than calculating the vocabulary.')),
        ('--walk_length', dict(
            default=40, type=int,
            help='Length of the random walk started at each node')),
        ('--window_size', dict(
            default=5, type=int,
            help='Window size of skipgram model.')),
        ('--workers', dict(
            default=1, type=int,
            help='Number of parallel processes.')),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()


# Script entry point: learn the walk-based embeddings once, then evaluate them
# with a linear classifier over several data splits and report the mean test
# accuracy together with a bootstrap-based uncertainty.
if __name__ == "__main__":
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660,
             2108550661, 1648766618, 629014539, 3212139042, 2424918363]
    # Each run consumes one fixed seed; fail before the (expensive) embedding
    # training instead of with an IndexError part-way through the runs.
    if args.runs > len(SEEDS):
        raise ValueError(
            f'--runs={args.runs} exceeds the {len(SEEDS)} fixed split seeds available')
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0].to(device)

    # Split sizes for the random-split protocol: training nodes per class and
    # total number of validation nodes.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    # The embeddings do not depend on the split, so they are learnt once.
    embeds = unsupervised_learning(data=data, args=args, device=device)

    # Datasets without a public split (Computers / Photo) always use
    # random_splits; for the others --fix_split reuses the masks shipped with
    # the dataset.
    no_public_split = args.dataset in ['Computers', 'Photo']
    if args.fix_split and not no_public_split:
        # Keep the original masks: data.*_mask is overwritten per run below.
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask

    unsup_results = []
    for RP in range(args.runs):
        args.seed = SEEDS[RP]

        if args.fix_split:
            if no_public_split:  # no public splitting, train/val/test=1/1/8
                percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
                val_lb = int(round(0.1 * len(data.y)))
                data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
            else:
                # Column RP of each stored mask is the split used for run RP.
                data.train_mask, data.val_mask, data.test_mask = full_train_mask[:, RP], full_val_mask[:, RP], full_test_mask[:, RP]
        else:
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        # The split chosen above is used as-is; it must not be replaced by
        # another random split here, or --fix_split would have no effect.
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Largest deviation of the 95% bootstrap interval of the mean from the
    # sample mean.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')