import argparse, time
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
import dgl
from models.dgi import DGI, MultiClassifier, Classifier
from src.models.subgi import SubGI
from src.models.vgae import VGAE
# from src.models.utils import Mine, train_mine
from IPython import embed
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.sparse as sp
import collections
from tqdm import tqdm
from collections import defaultdict
import pickle
from copy import deepcopy
from gensim.models import KeyedVectors
from sklearn.manifold import SpectralEmbedding

def degree_bucketing(graph, args, degree_emb=None, max_degree=10, test=False):
    """Build one-hot node features from node degree.

    Every node gets a row of width ``args.n_hidden`` containing a single 1 in
    the column given by its in-degree. Degrees of ``args.n_hidden - 1`` or
    more all share the last column.

    Args:
        graph: graph object exposing ``number_of_nodes()``, ``in_degree(i)``
            and ``out_degree(i)``.
        args: namespace whose ``n_hidden`` attribute sets the feature width.
        degree_emb: unused; kept so existing call sites keep working.
        max_degree: unused; the width always comes from ``args.n_hidden``.
        test: unused; kept so existing call sites keep working.

    Returns:
        Float tensor of shape ``[graph.number_of_nodes(), args.n_hidden]``.

    Raises:
        ValueError: if some node's in-degree differs from its out-degree.
    """
    num_nodes = graph.number_of_nodes()
    width = args.n_hidden
    features = torch.zeros([num_nodes, width])
    for node in range(num_nodes):
        in_deg = graph.in_degree(node)
        # The bucketing relies on every edge being stored in both directions.
        # This is an explicit raise (not an assert) so the check survives -O.
        if in_deg != graph.out_degree(node):
            raise ValueError(
                "node {} has in-degree {} but out-degree {}".format(
                    node, in_deg, graph.out_degree(node)))
        features[node, min(in_deg, width - 1)] = 1
    return features

def createTraining(labels, valid_mask = None, train_ratio=0.8):
    """Map each index to the other indices that carry the same label.

    Args:
        labels: iterable of hashable labels; position in the iterable is the
            index used in the result.
        valid_mask: unused.
        train_ratio: unused.

    Returns:
        ``defaultdict(list)`` mapping an index to the list of all *other*
        indices with an identical label. Indices whose label is unique do
        not appear as keys.
    """
    members_by_label = defaultdict(list)
    for position, label in enumerate(labels):
        members_by_label[label].append(position)

    eval_set = defaultdict(list)
    for members in members_by_label.values():
        # A label held by a single index has no peers to retrieve.
        if len(members) < 2:
            continue
        for anchor in members:
            eval_set[anchor] = [other for other in members if other != anchor]
    return eval_set

# debug the synthetic encoder part
def eval_equ(eval_set, emb):
    """Score how well nearest neighbours in ``emb`` recover ``eval_set``.

    For every anchor index in ``eval_set`` the rows of ``emb`` are ranked by
    Euclidean distance to the anchor's row. The closest ``m`` rows are taken
    as candidates, where ``m`` is the number of expected peers; one extra row
    is allowed when the anchor itself falls inside the top ``m + 1``.

    Args:
        eval_set: mapping from an anchor index to the list of indices that
            should be retrieved for it.
        emb: 2-D tensor of embeddings, indexed by node along dim 0.

    Returns:
        Tuple ``(ratio, n_anchors)`` where ``ratio`` is the number of expected
        peers found among the candidates divided by the number of expected
        peers, and ``n_anchors`` is ``len(eval_set)``.
    """
    hits = 0
    expected_total = 0
    for anchor, peers in eval_set.items():
        distances = torch.norm(emb - emb[anchor], dim=1, p=2)
        order = np.argsort(distances.cpu().numpy()).tolist()

        window = order[:len(peers) + 1]
        if anchor not in window:
            window = order[:len(peers)]

        expected_total += len(peers)
        hits += len(set(peers) & set(window))

    return float(hits) / expected_total, len(eval_set)

def read_struct_net(args):
    """Load an undirected graph and its node labels from two text files.

    The edge file (``args.file_path``) holds one edge per line as two
    whitespace-separated integer node ids. The label file
    (``args.label_path``) has one line that is skipped, followed by lines of
    ``<node id> <label>`` integers.

    Both files are split on arbitrary whitespace and blank lines are ignored,
    so tabs, repeated spaces and empty lines do not abort the load.

    Args:
        args: namespace providing ``file_path`` and ``label_path``.

    Returns:
        Tuple ``(g, labels)``: a ``networkx.Graph`` and a dict mapping node id
        to integer label.
    """
    g = nx.Graph()
    with open(args.file_path) as edge_file:
        for line in edge_file:
            fields = line.split()
            if not fields:
                continue
            g.add_edge(int(fields[0]), int(fields[1]))

    labels = {}
    with open(args.label_path) as label_file:
        label_file.readline()  # the first line is not a label record
        for line in label_file:
            fields = line.split()
            if not fields:
                continue
            labels[int(fields[0])] = int(fields[1])
    return g, labels
    
def constructDGL(graph, labels):
    """Convert a labelled graph into a ``DGLGraph`` with contiguous node ids.

    Nodes are renumbered ``0..n-1`` in sorted order of their original ids and
    every edge of ``graph`` is inserted in both directions.

    Args:
        graph: graph exposing ``nodes()`` and ``edges()``; each edge is
            indexed as ``edge[0]``, ``edge[1]``.
        labels: mapping from original node id to label; must have exactly one
            entry per node (checked with an assertion).

    Returns:
        Tuple ``(new_g, relabels)`` where ``relabels[i]`` is the label of the
        node that was renumbered to ``i``.
    """
    ordered_nodes = sorted(graph.nodes())
    node_mapping = defaultdict(int)
    node_mapping.update(
        (node, position) for position, node in enumerate(ordered_nodes))
    relabels = [labels[node] for node in ordered_nodes]
    assert len(node_mapping) == len(labels)

    new_g = DGLGraph()
    new_g.add_nodes(len(node_mapping))
    for edge in graph.edges():
        head = node_mapping[edge[0]]
        tail = node_mapping[edge[1]]
        # Store both directions so the DGL graph is symmetric.
        new_g.add_edge(head, tail)
        new_g.add_edge(tail, head)

    return new_g, relabels

def outputGraph(graph, out_path):
    """Write the edges of ``graph`` to ``out_path`` as a text edge list.

    Each edge becomes one line of the form ``0 <src> <dst> 1``.

    Args:
        graph: object whose ``all_edges()`` returns a pair of tensors
            (source ids, destination ids).
        out_path: path of the file to create or overwrite.
    """
    sources, targets = graph.all_edges()
    with open(out_path, 'w') as OUT:
        edge_pairs = zip(sources.numpy().tolist(), targets.numpy().tolist())
        OUT.writelines('0 {} {} 1\n'.format(u, v) for u, v in edge_pairs)

def output_adj(graph):
    """Return the dense adjacency matrix of ``graph`` as a NumPy array.

    Args:
        graph: object exposing ``number_of_nodes()`` and ``all_edges()``, the
            latter returning a pair of integer tensors (sources, targets).

    Returns:
        ``numpy.ndarray`` of shape ``[n, n]`` with ``A[src, dst] == 1`` for
        every edge and 0 elsewhere.
    """
    num_nodes = graph.number_of_nodes()
    adjacency = np.zeros([num_nodes, num_nodes])
    sources, targets = graph.all_edges()
    # Index-array assignment sets every (src, dst) cell in one step.
    adjacency[sources.numpy(), targets.numpy()] = 1
    return adjacency
#TODO: add a nearest search part
def _load_graph_pickle(path):
    """Load a pickled ``{'graphs': ..., 'label': ...}`` dict and close the file."""
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _build_models(args, g, in_feats, n_hidden):
    """Create the model selected by ``args.model`` and an untrained deep copy."""
    if args.model == 0:
        dgi = DGI(g,
                  in_feats,
                  n_hidden,
                  args.n_layers,
                  F.relu,
                  args.dropout)
    elif args.model == 1:
        dgi = VGAE(g,
                   in_feats,
                   n_hidden,
                   args.n_hidden,
                   args.dropout)
        dgi.prepare()
        dgi.adj_train = sp.csr_matrix(output_adj(g))
    elif args.model == 2:
        dgi = SubGI(g,
                    in_feats,
                    n_hidden,
                    args.n_layers,
                    F.relu,
                    args.dropout,
                    args.model_id)
    elif args.model == 3:
        # NOTE(review): SUBVGAE is not imported in this file, so this branch
        # raises NameError until the import is added -- confirm its module.
        dgi = SUBVGAE(g,
                      in_feats,
                      args.n_hidden,
                      args.n_hidden,
                      args.dropout)
        dgi.prepare()
    else:
        raise ValueError("unknown --model value: {}".format(args.model))
    # The copy is taken before any training so it serves as the untrained
    # reference the trained model is compared against.
    return dgi, deepcopy(dgi)


def _encode(model, features, model_type):
    """Return node embeddings from ``model`` without tracking gradients."""
    with torch.no_grad():
        if model_type != 1:
            return model.encoder(features, corrupt=False)
        _, embeds, _ = model.forward(features)
        return embeds


def main(args):
    """Pre-train on one source graph and evaluate on a set of target graphs.

    The run is repeated 10 times. Each repetition:
      1. builds the model chosen by ``args.model`` on source graph 39 from
         ``barabasi_small_graphs_full.pkl`` together with an untrained copy,
      2. trains for up to ``args.n_dgi_epochs`` epochs with early stopping on
         the training loss (``args.patience``), checkpointing the best
         weights to ``best_dgi.pkl``,
      3. scores trained and untrained embeddings with ``eval_equ`` on the
         source graph and on every graph in ``forest_fire_graphs_full.pkl``.

    Mean and standard deviation of the collected scores are printed at the
    end. Tensors and models are moved to the GPU only when ``args.gpu >= 0``.

    Args:
        args: parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    source_data = _load_graph_pickle('barabasi_small_graphs_full.pkl')
    source_graphs = source_data['graphs']
    for idx, g in enumerate(source_graphs):
        source_graphs[idx] = dgl.transform.remove_self_loop(g)
    source_label_sets = source_data['label']

    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)

    in_feats = args.n_hidden
    n_hidden = 32
    source_idx = 39  # index of the pre-training graph inside source_graphs

    torch.manual_seed(2)
    acc_1, acc_2, cnt = [], [], []
    acc_src_1, acc_src_2 = 0, 0
    acc_source_a, acc_source_b = [], []

    for _run in range(10):
        g = source_graphs[source_idx]
        g.readonly()

        dgi, dgi_ref = _build_models(args, g, in_feats, n_hidden)

        features = degree_bucketing(g, args, test=False)
        if cuda:
            # Only move to the GPU when one was requested; an unconditional
            # .cuda() breaks CPU runs and mismatches a CPU-resident model.
            features = features.cuda()
            dgi.cuda()
            dgi_ref.cuda()

        dgi_optimizer = torch.optim.Adam(dgi.parameters(),
                                         lr=args.dgi_lr,
                                         weight_decay=args.weight_decay)

        # train the model with early stopping on the training loss
        cnt_wait = 0
        best = 1e9
        dur = []
        for epoch in range(args.n_dgi_epochs):
            train_sampler = dgl.contrib.sampling.NeighborSampler(
                g, 128, 5,
                neighbor_type='in', num_workers=1,
                add_self_loop=False,
                num_hops=args.n_layers + 1, shuffle=True)
            dgi.train()
            if epoch >= 3:
                t0 = time.time()

            if args.model == 0:
                dgi_optimizer.zero_grad()
                loss = dgi(features)
                loss.backward()
                dgi_optimizer.step()
                loss = loss.item()
            elif args.model == 1 or args.model == 3:
                dgi.optimizer = dgi_optimizer
                dgi.train_sampler = train_sampler
                dgi.features = features
                loss = dgi.train_model()
            elif args.model == 2:
                batch_losses = []
                for nf in train_sampler:
                    dgi_optimizer.zero_grad()
                    batch_loss = dgi(features, nf)
                    batch_loss.backward()
                    batch_losses.append(batch_loss.item())
                    dgi_optimizer.step()
                loss = np.sum(batch_losses)

            if loss < best:
                best = loss
                cnt_wait = 0
                torch.save(dgi.state_dict(), 'best_dgi.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print('Early stopping!')
                break

            if epoch >= 3:
                dur.append(time.time() - t0)

            # The first three epochs are not timed; report nan for them
            # instead of calling np.mean on an empty list (which warns).
            mean_dur = np.mean(dur) if dur else float('nan')
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | ".format(epoch, mean_dur, loss))

        if args.n_dgi_epochs > 0:
            dgi.load_state_dict(torch.load('best_dgi.pkl'))

        # Score the trained model and the untrained reference on the source
        # graph. This stays best-effort, but the failure is now reported.
        try:
            embeds = _encode(dgi, features, args.model)
            acc_source_a.append(eval_equ(source_label_sets[source_idx], embeds)[0])
            embeds = _encode(dgi_ref, features, args.model)
            acc_source_b.append(eval_equ(source_label_sets[source_idx], embeds)[0])
        except Exception as exc:
            print("Exception: {!r}".format(exc))

        target_data = _load_graph_pickle('forest_fire_graphs_full.pkl')
        graphs = target_data['graphs']
        label_sets = target_data['label']

        for idx, test_g in enumerate(graphs):
            features = degree_bucketing(test_g, args, test=False)
            if cuda:
                features = features.cuda()

            # Point both models at the graph under evaluation.
            dgi.g = test_g
            dgi_ref.g = test_g
            if args.model != 1:
                dgi.encoder.g = test_g
                dgi.encoder.conv.g = test_g
                dgi_ref.encoder.g = test_g
                dgi_ref.encoder.conv.g = test_g

            if cuda:
                dgi.cuda()
                dgi_ref.cuda()

            embeds = _encode(dgi, features, args.model)
            if len(label_sets[idx]) == 0:
                continue

            acc_a, num_anchors = eval_equ(label_sets[idx], embeds)
            embeds = _encode(dgi_ref, features, args.model)
            acc_b, num_anchors = eval_equ(label_sets[idx], embeds)
            if idx == source_idx:
                acc_src_1, acc_src_2 = acc_a, acc_b
                continue
            acc_1.append(acc_a)
            acc_2.append(acc_b)
            cnt.append(num_anchors)

    print(np.mean(acc_1), np.std(acc_1))
    print(np.mean(acc_2), np.std(acc_2))
    print(np.mean(cnt), np.std(cnt))
    print(len(acc_1), len(acc_2))
    print(acc_src_1, acc_src_2)
    print('pretrain acc:{}, {}'.format(np.mean(acc_source_a), np.mean(acc_source_b)) )
    print('pretrain acc std:{}, {}'.format(np.std(acc_source_a), np.std(acc_source_b)) )

# Command-line entry point: parse the options and run the experiment.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--model", type=int, default=0,
                        help="[0: DGI, 1: VGAE, 2: SubGI, 3: SUBVGAE]")
    parser.add_argument("--dgi-lr", type=float, default=1e-3,
                        help="dgi learning rate")
    parser.add_argument("--classifier-lr", type=float, default=1e-3,
                        help="classifier learning rate")
    parser.add_argument("--n-dgi-epochs", type=int, default=300,
                        help="number of training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=128,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="Weight for L2 loss")
    parser.add_argument("--patience", type=int, default=20,
                        help="early stop patience condition")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--file-path", type=str,
                        help="graph path")
    parser.add_argument("--label-path", type=str,
                        help="label path")
    # The help text here was previously a copy of the --self-loop help.
    parser.add_argument("--graph-type", type=str, default="barabasi_small",
                        help="graph type name (default=barabasi_small)")
    parser.add_argument("--model-id", type=int, default=0,
                        help="[0, 1, 2, 3]")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)