import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable

import numpy as np
import time
import random
from sklearn.metrics import accuracy_score, f1_score
from collections import defaultdict
import networkx as nx
from networkx.readwrite import json_graph
# This script targets networkx <= 1.11. NOTE(review): the reason is not
# visible in this file (presumably API differences in 2.x); confirm before
# relaxing the bound.
version_info = list(map(int, nx.__version__.split('.')))
major = version_info[0]
minor = version_info[1]
# Compare (major, minor) as a tuple: testing each component separately would
# also reject older releases whose minor number exceeds 11 (e.g. 0.99).
# Raise instead of assert so the guard is not stripped by ``python -O``.
if (major, minor) > (1, 11):
    raise ImportError("networkx major version > 1.11")
import os
import pickle as pkl

from encoders import Encoder
from aggregators import MeanAggregator

import argparse

from hyperopt import fmin, tpe, hp


class SupervisedGraphSage(nn.Module):
    """Linear classifier on top of a GraphSAGE encoder.

    ``enc`` must expose ``embed_dim`` and, when called with a sequence of node
    ids, return an embedding matrix laid out as ``(embed_dim, len(nodes))`` --
    the layout required by ``self.weight.mm(embeds)`` in ``forward``.
    """

    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        self.xent = nn.CrossEntropyLoss()

        # (num_classes, embed_dim) projection from embeddings to class scores.
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        # In-place initialiser; the un-suffixed ``xavier_uniform`` is the
        # deprecated alias of this function.
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """Return ``(scores, embeds)`` for ``nodes``.

        ``scores`` has shape ``(len(nodes), num_classes)``; ``embeds`` is the
        raw encoder output of shape ``(embed_dim, len(nodes))``.
        """
        embeds = self.enc(nodes)
        scores = self.weight.mm(embeds)
        return scores.t(), embeds

    def loss(self, nodes, labels):
        """Return ``(cross_entropy_loss, embeds)`` for ``nodes``.

        ``labels`` holds one class index per node, either as a flat vector or
        as a ``(len(nodes), 1)`` column. It is flattened with ``reshape(-1)``
        rather than ``squeeze()`` so that a single-node batch still yields a
        1-D target instead of a 0-d scalar, which ``CrossEntropyLoss`` rejects
        for a 2-D score matrix.
        """
        scores, intermediate = self.forward(nodes)
        return self.xent(scores, labels.reshape(-1)), intermediate

def load_cora(feat_addition):
    """Load the Cora citation dataset from ``../data/cora``.

    Args:
        feat_addition: 0 -> bag-of-words features only; 1 -> also append each
            paper's adjacency-matrix row; 2 -> also append each paper's
            personalized-PageRank row (cached as a pickle in ``cora_rw``).

    Returns:
        ``(feat_data, additional_feat, labels, adj_lists)``: ``feat_data`` is
        the float feature matrix with one row per paper, with the
        ``(num_nodes, num_nodes)`` ndarray ``additional_feat`` appended as
        extra columns when ``feat_addition > 0`` (``additional_feat`` is
        ``None`` otherwise); ``labels`` is a ``(num_nodes, 1)`` int64 array of
        class ids; ``adj_lists`` maps each paper index to the set of its
        neighbours (citations are treated as undirected edges).

    Raises:
        ValueError: if ``feat_addition`` is positive but not 1 or 2.
    """
    if feat_addition > 0 and feat_addition not in (1, 2):
        raise ValueError("feat_addition must be 0, 1 or 2, got %r" % (feat_addition,))

    num_nodes = 2708
    num_feats = 1433
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes, 1), dtype=np.int64)
    node_map = {}
    label_map = {}
    with open("../data/cora/cora.content") as fp:
        for i, line in enumerate(fp):
            info = line.strip().split()
            feat_data[i, :] = list(map(float, info[1:-1]))
            node_map[info[0]] = i
            # Class names get integer ids in order of first appearance.
            if info[-1] not in label_map:
                label_map[info[-1]] = len(label_map)
            labels[i] = label_map[info[-1]]

    adj_lists = defaultdict(set)
    with open("../data/cora/cora.cites") as fp:
        for line in fp:
            info = line.strip().split()
            paper1 = node_map[info[0]]
            paper2 = node_map[info[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)

    additional_feat = None
    if feat_addition > 0:
        G = nx.from_dict_of_lists(adj_lists)
        # Register every paper as a node, including any without citations, so
        # the matrices below have exactly one row/column per row of feat_data.
        G.add_nodes_from(range(num_nodes))
        if feat_addition == 1:  # Concatenate Adjacency Matrix rows
            adj = nx.adjacency_matrix(G, nodelist=sorted(G.nodes()))
            additional_feat = np.asarray(adj.todense(), dtype=np.float64)
        else:  # feat_addition == 2: Concatenate Random Walk Matrix Rows
            rw_matrix_file = "cora_rw"
            if os.path.exists(rw_matrix_file):
                # NOTE(review): unpickling can execute arbitrary code; only
                # point this at cache files written by this script.
                with open(rw_matrix_file, 'rb') as rw_file:
                    X_rw = pkl.load(rw_file)
            else:
                rows = []
                for i in range(num_nodes):
                    # Personalized PageRank with all restart mass on node i.
                    personalization = {x: (1 if x == i else 0) for x in range(num_nodes)}
                    ppr_from_node_i = nx.pagerank(G, personalization=personalization)
                    rows.append([ppr_from_node_i[j] for j in range(num_nodes)])
                X_rw = np.array(rows)
                with open(rw_matrix_file, 'wb') as rw_file:
                    pkl.dump(X_rw, rw_file)
            # The cache may hold a dense array (as written above) or an object
            # exposing .todense(); normalise to a plain 2-D ndarray so that
            # downstream numpy code (e.g. np.diag of row sums) sees 1-D rows.
            if hasattr(X_rw, "todense"):
                X_rw = X_rw.todense()
            additional_feat = np.asarray(X_rw, dtype=np.float64)
        feat_data = np.hstack((feat_data, additional_feat))

    return feat_data, additional_feat, labels, adj_lists

def run_cora(feat_addition, laplacian_reg, lr=0.7, laplacian_reg_term=1e-7):
    """Train a two-layer supervised GraphSAGE model on Cora and validate it.

    The random split and mini-batch shuffling use the global ``np.random`` and
    ``random`` states; seed them beforehand for reproducible runs.

    Args:
        feat_addition: forwarded to ``load_cora`` (0 = raw features, 1 =
            append adjacency rows, 2 = append random-walk rows).
        laplacian_reg: if truthy, add ``laplacian_reg_term * tr(E L E^T)`` to
            each batch loss, where ``E`` is the batch embedding matrix and
            ``L = D - S`` is the Laplacian of the batch's block ``S`` of the
            additional-feature matrix (symmetrised when ``feat_addition == 2``).
        lr: SGD learning rate.
        laplacian_reg_term: weight of the Laplacian penalty.

    Returns:
        ``(val_acc, val_f1_score)`` measured on a 500-node validation split.

    Raises:
        ValueError: if ``laplacian_reg`` is set but the loader produced no
            additional features to build the Laplacian from.
    """
    feat_data, additional_feat, labels, adj_lists = load_cora(feat_addition)
    if laplacian_reg and additional_feat is None:
        raise ValueError("laplacian_reg requires feat_addition to be 1 or 2")
    # Derive sizes from the loaded data so they always match feat_data.
    num_nodes, num_features = feat_data.shape
    features = nn.Embedding(num_nodes, num_features)
    features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

    # Features, encoders and the classifier all stay on the CPU here, so the
    # first aggregator must not request CUDA either.
    agg1 = MeanAggregator(features, cuda=False)
    enc1 = Encoder(features, num_features, 128, adj_lists, agg1, gcn=True, cuda=False)
    agg2 = MeanAggregator(lambda nodes: enc1(nodes).t(), cuda=False)
    enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
                   base_model=enc1, gcn=True, cuda=False)
    enc1.num_samples = 5
    enc2.num_samples = 5

    graphsage = SupervisedGraphSage(7, enc2)
    rand_indices = np.random.permutation(num_nodes)
    # rand_indices[:1000] is reserved as a test split and not evaluated here.
    val = rand_indices[1000:1500]
    train = list(rand_indices[1500:])

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, graphsage.parameters()), lr=lr)
    times = []
    for batch in range(100):
        batch_nodes = train[:256]
        random.shuffle(train)
        start_time = time.time()
        optimizer.zero_grad()
        loss, intermediate = graphsage.loss(
            batch_nodes, Variable(torch.LongTensor(labels[np.array(batch_nodes)])))

        if laplacian_reg:
            # Batch-by-batch block of the additional features. Fancy indexing
            # copies, so additional_feat itself is never modified; asarray
            # guards against np.matrix inputs, whose 2-D row sums would make
            # np.diag extract a diagonal instead of building one.
            S = np.asarray(additional_feat)[np.ix_(batch_nodes, batch_nodes)]
            if feat_addition == 2:
                # Random-walk scores are asymmetric: set every off-diagonal
                # entry to S[i, j] + S[j, i] and keep the diagonal unchanged.
                diagonal = S.diagonal().copy()
                S = S + S.T
                np.fill_diagonal(S, diagonal)
            # Unnormalised Laplacian D - S with D = diag(row sums of S).
            delta = np.diag(S.sum(axis=1)) - S
            delta = torch.from_numpy(delta).float()
            # intermediate is (embed_dim, batch), so this is tr(E L E^T).
            lapl_loss = torch.trace(torch.matmul(torch.matmul(intermediate, delta), intermediate.t()))
            loss = loss + laplacian_reg_term * lapl_loss

        loss.backward()
        optimizer.step()
        end_time = time.time()
        times.append(end_time - start_time)
        print(batch, loss.data)

    # Inference only: no autograd graph is needed for validation.
    with torch.no_grad():
        val_output, _ = graphsage.forward(val)
    val_pred = val_output.data.numpy().argmax(axis=1)
    val_acc = accuracy_score(labels[val], val_pred)
    val_f1_score = f1_score(labels[val], val_pred, average="micro")
    print("Validation Accuracy:", val_acc)
    print("Validation F1:", val_f1_score)
    print("Average batch time:", np.mean(times))
    return val_acc, val_f1_score


def load_pubmed(feat_addition):
    """Load the Pubmed-Diabetes citation dataset from ``../data/pubmed-data``.

    Args:
        feat_addition: 0 -> TF-IDF word features only; 1 -> also append each
            paper's adjacency-matrix row; 2 -> also append each paper's
            personalized-PageRank row (cached as a pickle in ``pubmed_rw``).

    Returns:
        ``(feat_data, additional_feat, labels, adj_lists)``: ``feat_data`` is
        the float feature matrix with one row per paper, with the
        ``(num_nodes, num_nodes)`` ndarray ``additional_feat`` appended as
        extra columns when ``feat_addition > 0`` (``additional_feat`` is
        ``None`` otherwise); ``labels`` is a ``(num_nodes, 1)`` int64 array of
        0-based class ids; ``adj_lists`` maps each paper index to the set of
        its neighbours (citations are treated as undirected edges).

    Raises:
        ValueError: if ``feat_addition`` is positive but not 1 or 2.
    """
    if feat_addition > 0 and feat_addition not in (1, 2):
        raise ValueError("feat_addition must be 0, 1 or 2, got %r" % (feat_addition,))

    # Sizes are hardcoded for simplicity.
    num_nodes = 19717
    num_feats = 500
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes, 1), dtype=np.int64)
    node_map = {}
    with open("../data/pubmed-data/Pubmed-Diabetes.NODE.paper.tab") as fp:
        fp.readline()
        # Second header line: tab-separated entries whose second ':'-field is
        # a column name; each name maps to its position minus one (presumably
        # because the first entry is the label column -- verify).
        feat_map = {entry.split(":")[1]: i - 1
                    for i, entry in enumerate(fp.readline().split("\t"))}
        for i, line in enumerate(fp):
            info = line.split("\t")
            node_map[info[0]] = i
            # The value after '=' is shifted down by one to get 0-based ids.
            labels[i] = int(info[1].split("=")[1]) - 1
            for word_info in info[2:-1]:
                key_value = word_info.split("=")
                feat_data[i, feat_map[key_value[0]]] = float(key_value[1])

    adj_lists = defaultdict(set)
    with open("../data/pubmed-data/Pubmed-Diabetes.DIRECTED.cites.tab") as fp:
        # Skip the two header lines.
        fp.readline()
        fp.readline()
        for line in fp:
            info = line.strip().split("\t")
            # Endpoint columns carry the paper id after the first ':'.
            paper1 = node_map[info[1].split(":")[1]]
            paper2 = node_map[info[-1].split(":")[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)

    additional_feat = None
    if feat_addition > 0:
        G = nx.from_dict_of_lists(adj_lists)
        # Register every paper as a node, including any without citations, so
        # the matrices below have exactly one row/column per row of feat_data.
        G.add_nodes_from(range(num_nodes))
        if feat_addition == 1:  # Concatenate Adjacency Matrix rows
            adj = nx.adjacency_matrix(G, nodelist=sorted(G.nodes()))
            additional_feat = np.asarray(adj.todense(), dtype=np.float64)
        else:  # feat_addition == 2: Concatenate Random Walk Matrix Rows
            rw_matrix_file = "pubmed_rw"
            if os.path.exists(rw_matrix_file):
                # NOTE(review): unpickling can execute arbitrary code; only
                # point this at cache files written by this script.
                with open(rw_matrix_file, 'rb') as rw_file:
                    X_rw = pkl.load(rw_file)
            else:
                rows = []
                for i in range(num_nodes):
                    # Personalized PageRank with all restart mass on node i.
                    personalization = {x: (1 if x == i else 0) for x in range(num_nodes)}
                    ppr_from_node_i = nx.pagerank(G, personalization=personalization)
                    rows.append([ppr_from_node_i[j] for j in range(num_nodes)])
                X_rw = np.array(rows)
                with open(rw_matrix_file, 'wb') as rw_file:
                    pkl.dump(X_rw, rw_file)
            # The cache may hold a dense ndarray (what this function writes,
            # which has no .todense()) or an object exposing .todense();
            # normalise both to a plain 2-D ndarray rather than an np.matrix,
            # so downstream row sums stay 1-D.
            if hasattr(X_rw, "todense"):
                X_rw = X_rw.todense()
            additional_feat = np.asarray(X_rw, dtype=np.float64)
        feat_data = np.hstack((feat_data, additional_feat))

    return feat_data, additional_feat, labels, adj_lists

def run_pubmed(feat_addition, laplacian_reg, lr=0.7, laplacian_reg_term=1e-7):
    """Train a two-layer supervised GraphSAGE model on Pubmed and validate it.

    The random split and mini-batch shuffling use the global ``np.random`` and
    ``random`` states; seed them beforehand for reproducible runs.

    Args:
        feat_addition: forwarded to ``load_pubmed`` (0 = raw features, 1 =
            append adjacency rows, 2 = append random-walk rows).
        laplacian_reg: if truthy, add ``laplacian_reg_term * tr(E L E^T)`` to
            each batch loss, where ``E`` is the batch embedding matrix and
            ``L = D - S`` is the Laplacian of the batch's block ``S`` of the
            additional-feature matrix (symmetrised when ``feat_addition == 2``).
        lr: SGD learning rate.
        laplacian_reg_term: weight of the Laplacian penalty.

    Returns:
        ``(val_acc, val_f1_score)`` measured on a 500-node validation split.

    Raises:
        ValueError: if ``laplacian_reg`` is set but the loader produced no
            additional features to build the Laplacian from.
    """
    feat_data, additional_feat, labels, adj_lists = load_pubmed(feat_addition)
    if laplacian_reg and additional_feat is None:
        raise ValueError("laplacian_reg requires feat_addition to be 1 or 2")
    # Derive sizes from the loaded data so they always match feat_data.
    num_nodes, num_features = feat_data.shape
    features = nn.Embedding(num_nodes, num_features)
    features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

    # Features, encoders and the classifier all stay on the CPU here, so the
    # first aggregator must not request CUDA either.
    agg1 = MeanAggregator(features, cuda=False)
    enc1 = Encoder(features, num_features, 128, adj_lists, agg1, gcn=True, cuda=False)
    agg2 = MeanAggregator(lambda nodes: enc1(nodes).t(), cuda=False)
    enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
                   base_model=enc1, gcn=True, cuda=False)
    enc1.num_samples = 10
    enc2.num_samples = 25

    graphsage = SupervisedGraphSage(3, enc2)
    rand_indices = np.random.permutation(num_nodes)
    # rand_indices[:1000] is reserved as a test split and not evaluated here.
    val = rand_indices[1000:1500]
    train = list(rand_indices[1500:])

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, graphsage.parameters()), lr=lr)
    times = []
    for batch in range(200):
        batch_nodes = train[:1024]
        random.shuffle(train)
        start_time = time.time()
        optimizer.zero_grad()
        loss, intermediate = graphsage.loss(
            batch_nodes, Variable(torch.LongTensor(labels[np.array(batch_nodes)])))

        if laplacian_reg:
            # Batch-by-batch block of the additional features. Fancy indexing
            # copies, so additional_feat itself is never modified; asarray
            # guards against np.matrix inputs, whose 2-D row sums would make
            # np.diag extract a diagonal instead of building one.
            S = np.asarray(additional_feat)[np.ix_(batch_nodes, batch_nodes)]
            if feat_addition == 2:
                # Random-walk scores are asymmetric: set every off-diagonal
                # entry to S[i, j] + S[j, i] and keep the diagonal unchanged.
                diagonal = S.diagonal().copy()
                S = S + S.T
                np.fill_diagonal(S, diagonal)
            # Unnormalised Laplacian D - S with D = diag(row sums of S).
            delta = np.diag(S.sum(axis=1)) - S
            delta = torch.from_numpy(delta).float()
            # intermediate is (embed_dim, batch), so this is tr(E L E^T).
            lapl_loss = torch.trace(torch.matmul(torch.matmul(intermediate, delta), intermediate.t()))
            loss = loss + laplacian_reg_term * lapl_loss

        loss.backward()
        optimizer.step()
        end_time = time.time()
        times.append(end_time - start_time)
        print(batch, loss.data)

    # Inference only: no autograd graph is needed for validation.
    with torch.no_grad():
        val_output, _ = graphsage.forward(val)
    val_pred = val_output.data.numpy().argmax(axis=1)
    val_acc = accuracy_score(labels[val], val_pred)
    val_f1_score = f1_score(labels[val], val_pred, average="micro")
    print("Validation Accuracy:", val_acc)
    print("Validation F1:", val_f1_score)
    print("Average batch time:", np.mean(times))
    return val_acc, val_f1_score


def load_citeseer(feat_addition):
    """Load the Citeseer citation dataset from ``../data/citeseer``.

    Citation lines whose endpoints are not in ``citeseer.content`` are
    skipped.

    Args:
        feat_addition: 0 -> bag-of-words features only; 1 -> also append each
            paper's adjacency-matrix row; 2 -> also append each paper's
            personalized-PageRank row (cached as a pickle in ``citeseer_rw``).

    Returns:
        ``(feat_data, additional_feat, labels, adj_lists)``: ``feat_data`` is
        the float feature matrix with one row per paper, with the
        ``(num_nodes, num_nodes)`` ndarray ``additional_feat`` appended as
        extra columns when ``feat_addition > 0`` (``additional_feat`` is
        ``None`` otherwise); ``labels`` is a ``(num_nodes, 1)`` int64 array of
        class ids; ``adj_lists`` maps each paper index to the set of its
        neighbours (citations are treated as undirected edges).

    Raises:
        ValueError: if ``feat_addition`` is positive but not 1 or 2.
    """
    if feat_addition > 0 and feat_addition not in (1, 2):
        raise ValueError("feat_addition must be 0, 1 or 2, got %r" % (feat_addition,))

    num_nodes = 3312
    num_feats = 3703
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes, 1), dtype=np.int64)
    node_map = {}
    label_map = {}
    with open("../data/citeseer/citeseer.content") as fp:
        for i, line in enumerate(fp):
            info = line.strip().split()
            feat_data[i, :] = list(map(float, info[1:-1]))
            node_map[info[0]] = i
            # Class names get integer ids in order of first appearance.
            if info[-1] not in label_map:
                label_map[info[-1]] = len(label_map)
            labels[i] = label_map[info[-1]]

    adj_lists = defaultdict(set)
    with open("../data/citeseer/citeseer.cites") as fp:
        for line in fp:
            info = line.strip().split()
            # Ignore citations that reference papers without a content row.
            if info[0] not in node_map or info[1] not in node_map:
                continue
            paper1 = node_map[info[0]]
            paper2 = node_map[info[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)

    additional_feat = None
    if feat_addition > 0:
        G = nx.from_dict_of_lists(adj_lists)
        # Register every paper as a node, including those left without any
        # citation, so the matrices below have exactly one row/column per row
        # of feat_data and PageRank node ids match paper indices.
        G.add_nodes_from(range(num_nodes))
        if feat_addition == 1:  # Concatenate Adjacency Matrix rows
            adj = nx.adjacency_matrix(G, nodelist=sorted(G.nodes()))
            additional_feat = np.asarray(adj.todense(), dtype=np.float64)
        else:  # feat_addition == 2: Concatenate Random Walk Matrix Rows
            rw_matrix_file = "citeseer_rw"
            if os.path.exists(rw_matrix_file):
                # NOTE(review): unpickling can execute arbitrary code; only
                # point this at cache files written by this script.
                with open(rw_matrix_file, 'rb') as rw_file:
                    X_rw = pkl.load(rw_file)
            else:
                rows = []
                for i in range(num_nodes):
                    # Personalized PageRank with all restart mass on node i.
                    personalization = {x: (1 if x == i else 0) for x in range(num_nodes)}
                    ppr_from_node_i = nx.pagerank(G, personalization=personalization)
                    rows.append([ppr_from_node_i[j] for j in range(num_nodes)])
                X_rw = np.array(rows)
                with open(rw_matrix_file, 'wb') as rw_file:
                    pkl.dump(X_rw, rw_file)
            # The cache may hold a dense ndarray (what this function writes,
            # which has no .todense()) or an object exposing .todense();
            # normalise both to a plain 2-D ndarray rather than an np.matrix,
            # so downstream row sums stay 1-D.
            if hasattr(X_rw, "todense"):
                X_rw = X_rw.todense()
            additional_feat = np.asarray(X_rw, dtype=np.float64)
        feat_data = np.hstack((feat_data, additional_feat))

    return feat_data, additional_feat, labels, adj_lists

def run_citeseer(feat_addition, laplacian_reg, lr=0.7, laplacian_reg_term=1e-7):
    """Train a two-layer supervised GraphSAGE model on Citeseer and validate it.

    The random split and mini-batch shuffling use the global ``np.random`` and
    ``random`` states; seed them beforehand for reproducible runs.

    Args:
        feat_addition: forwarded to ``load_citeseer`` (0 = raw features, 1 =
            append adjacency rows, 2 = append random-walk rows).
        laplacian_reg: if truthy, add ``laplacian_reg_term * tr(E L E^T)`` to
            each batch loss, where ``E`` is the batch embedding matrix and
            ``L = D - S`` is the Laplacian of the batch's block ``S`` of the
            additional-feature matrix (symmetrised when ``feat_addition == 2``).
        lr: SGD learning rate.
        laplacian_reg_term: weight of the Laplacian penalty.

    Returns:
        ``(val_acc, val_f1_score)`` measured on a 500-node validation split.

    Raises:
        ValueError: if ``laplacian_reg`` is set but the loader produced no
            additional features to build the Laplacian from.
    """
    feat_data, additional_feat, labels, adj_lists = load_citeseer(feat_addition)
    if laplacian_reg and additional_feat is None:
        raise ValueError("laplacian_reg requires feat_addition to be 1 or 2")
    # Derive sizes from the loaded data so they always match feat_data.
    num_nodes, num_features = feat_data.shape
    features = nn.Embedding(num_nodes, num_features)
    features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)

    # Features, encoders and the classifier all stay on the CPU here, so the
    # first aggregator must not request CUDA either.
    agg1 = MeanAggregator(features, cuda=False)
    enc1 = Encoder(features, num_features, 128, adj_lists, agg1, gcn=True, cuda=False)
    agg2 = MeanAggregator(lambda nodes: enc1(nodes).t(), cuda=False)
    enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
                   base_model=enc1, gcn=True, cuda=False)
    enc1.num_samples = 5
    enc2.num_samples = 5

    graphsage = SupervisedGraphSage(6, enc2)
    rand_indices = np.random.permutation(num_nodes)
    # rand_indices[:1000] is reserved as a test split and not evaluated here.
    val = rand_indices[1000:1500]
    train = list(rand_indices[1500:])

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, graphsage.parameters()), lr=lr)
    times = []
    for batch in range(100):
        batch_nodes = train[:256]
        random.shuffle(train)
        start_time = time.time()
        optimizer.zero_grad()
        loss, intermediate = graphsage.loss(
            batch_nodes, Variable(torch.LongTensor(labels[np.array(batch_nodes)])))

        if laplacian_reg:
            # Batch-by-batch block of the additional features. Fancy indexing
            # copies, so additional_feat itself is never modified; asarray
            # guards against np.matrix inputs, whose 2-D row sums would make
            # np.diag extract a diagonal instead of building one.
            S = np.asarray(additional_feat)[np.ix_(batch_nodes, batch_nodes)]
            if feat_addition == 2:
                # Random-walk scores are asymmetric: set every off-diagonal
                # entry to S[i, j] + S[j, i] and keep the diagonal unchanged.
                diagonal = S.diagonal().copy()
                S = S + S.T
                np.fill_diagonal(S, diagonal)
            # Unnormalised Laplacian D - S with D = diag(row sums of S).
            delta = np.diag(S.sum(axis=1)) - S
            delta = torch.from_numpy(delta).float()
            # intermediate is (embed_dim, batch), so this is tr(E L E^T).
            lapl_loss = torch.trace(torch.matmul(torch.matmul(intermediate, delta), intermediate.t()))
            loss = loss + laplacian_reg_term * lapl_loss

        loss.backward()
        optimizer.step()
        end_time = time.time()
        times.append(end_time - start_time)
        print(batch, loss.data)

    # Inference only: no autograd graph is needed for validation.
    with torch.no_grad():
        val_output, _ = graphsage.forward(val)
    val_pred = val_output.data.numpy().argmax(axis=1)
    val_acc = accuracy_score(labels[val], val_pred)
    val_f1_score = f1_score(labels[val], val_pred, average="micro")
    print("Validation Accuracy:", val_acc)
    print("Validation F1:", val_f1_score)
    print("Average batch time:", np.mean(times))
    return val_acc, val_f1_score
    
    
def optimize(launch_function, feat_addition, laplacian_reg):
    """Search ``lr`` and ``laplacian_reg_term`` with hyperopt's TPE algorithm.

    Each of the 100 trials calls ``launch_function(feat_addition,
    laplacian_reg, lr, laplacian_reg_term)``, which must return a pair whose
    first element is the accuracy to maximise. A summary is printed and
    every trial is written to ``optimizations_lr.txt`` as one
    ``Lapl Term, lr, Test acc`` row.
    """
    laplacian_terms = []
    lrs = []
    acc_results = []

    def objective(params):
        laplacian_reg_term = params['laplacian_reg_term']
        lr = params['lr']

        laplacian_terms.append(laplacian_reg_term)
        lrs.append(lr)

        acc, _ = launch_function(feat_addition, laplacian_reg, lr, laplacian_reg_term)
        acc_results.append(acc)
        # fmin minimises its objective, so negate the accuracy.
        return -acc

    best = fmin(objective,
                space={'laplacian_reg_term': hp.uniform('laplacian_reg_term', 1e-10, 1e-6),
                       'lr': hp.uniform('lr', 1e-4, 0.1)},
                algo=tpe.suggest,
                max_evals=100)

    # ``best`` maps both searched parameter names to their best values.
    print("Best reg term:", best)

    print("Regular")
    print("Max", np.max(acc_results))
    print("Mean:", np.mean(np.array(acc_results)))
    print("Std:", np.std(np.array(acc_results)))

    with open('optimizations_lr.txt', 'w') as f:
        f.write("Lapl Term, lr, Test acc\n")
        for row in zip(laplacian_terms, lrs, acc_results):
            # Three comma-separated fields per trial, matching the header
            # (no trailing separator).
            f.write(", ".join(map(str, row)) + "\n")
    
    
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='GraphSage')
    parser.add_argument('--dataset', default="cora", required=False,
                        choices=["cora", "pubmed", "citeseer"],
                        help='Dataset: cora; pubmed; citeseer')
    parser.add_argument('--num_experiments', type=int, default=1, required=False,
                        help='Number of times to repeat training and validation with randomly sampled data')
    parser.add_argument('--feat_addition', type=int, default=0, required=False,
                        choices=[0, 1, 2],
                        help='0->No additional features; 1->adjacency matrix; 2->RW matrix')
    parser.add_argument('--laplacian_reg', type=int, default=0, required=False,
                        help='0->No; 1->yes')
    parser.add_argument('--optimize', type=int, default=0, required=False,
                        help='0->No; 1->yes')
    args = parser.parse_args()

    # Each runner takes (feat_addition, laplacian_reg[, lr, laplacian_reg_term])
    # and returns (accuracy, f1).
    runners = {"cora": run_cora, "pubmed": run_pubmed, "citeseer": run_citeseer}
    run_experiment = runners[args.dataset]

    np.random.seed(1)
    random.seed(1)
    if args.optimize:
        # The hyper-parameter search is run once instead of the repeated
        # experiments below.
        optimize(run_experiment, args.feat_addition, args.laplacian_reg)
    else:
        accuracies = []
        f1_scores = []
        for _ in range(args.num_experiments):
            acc, f1 = run_experiment(args.feat_addition, args.laplacian_reg)
            accuracies.append(acc)
            f1_scores.append(f1)

        if args.num_experiments > 1:
            print("--- Final Random Splits results ---")
            print("- Accuracy")
            print("Mean:", np.mean(np.array(accuracies)))
            print("Std:", np.std(np.array(accuracies)))
            print("- F1 Score")
            print("Mean:", np.mean(np.array(f1_scores)))
            print("Std:", np.std(np.array(f1_scores)))
