import argparse
import itertools
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling
import numpy as np
import torch_geometric.transforms as T
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from logger import Logger, ProductionLogger
from utils import get_dataset, do_edge_split
from models import MLP, GCN, SAGE, LinkPredictor
from torch_sparse import SparseTensor
from sklearn.metrics import *
from os.path import exists
from torch_cluster import random_walk
from torch.nn.functional import cosine_similarity
import torch_geometric
from train_teacher_gnn import test_transductive, test_production
from generate_production_split import do_production_edge_split
from torch_geometric.utils import to_networkx
import networkx as nx
import scipy.sparse as ssp

def cosine_loss(s, t):
    """Return one minus the mean cosine similarity between `s` and `t`.

    The similarity is taken along the last dimension. `t` is detached
    before the comparison, so no gradient flows back into it.
    """
    target = t.detach()
    similarity = cosine_similarity(s, target, dim=-1)
    return 1 - similarity.mean()

def kl_loss(s,t,T):
    """Temperature-scaled KL divergence between two sets of logits.

    Args:
        s: logits whose log-softmax (over the last dim) is the input
            distribution.
        t: logits whose softmax (over the last dim) is the target
            distribution.
        T: temperature dividing both sets of logits. Inside this function
            the name shadows any module-level `T`.

    Returns:
        The summed KL divergence, multiplied by ``T**2`` and divided by the
        size of the first dimension of `s`.
    """
    y_s = F.log_softmax(s/T, dim=-1)
    y_t = F.softmax(t/T, dim=-1)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; it yields the same value without the warning.
    loss = F.kl_div(y_s, y_t, reduction='sum') * (T**2) / y_s.size()[0]
    return loss

def neighbor_samplers(row, col, sample, x, step, ps_method, ns_rate, hops, device='cuda'):
    """Draw random-walk positives and uniformly random negatives for `sample`.

    Args:
        row, col: edge endpoints handed to ``random_walk`` (moved to CPU first).
        sample: tensor of start nodes.
        x: node feature tensor; only ``x.size(0)`` is used, as the exclusive
            upper bound for negative node ids.
        step: number of sampling rounds.
        ps_method: ``'rw'`` runs one walk of length ``step*hops``; ``'nb'``
            runs ``step`` independent walks of length ``hops`` and
            concatenates them column-wise.
        ns_rate: multiplier on the number of negatives per start node.
        hops: walk length of a single sampling round.
        device: only the exact string ``'cuda'`` moves the results to the
            GPU; any other value returns CPU tensors.

    Returns:
        ``(pos_batch, neg_batch)``. ``neg_batch`` has shape
        ``(sample.numel(), step*hops*ns_rate)``. ``pos_batch`` is whatever
        ``random_walk`` returns (for ``'nb'``, with the first column of every
        walk after the first dropped -- presumably the repeated start node;
        confirm against the torch_cluster docs).

    Raises:
        ValueError: if `ps_method` is neither ``'rw'`` nor ``'nb'``.
    """
    # Fail early: previously an unknown method left pos_batch unbound and
    # surfaced as an UnboundLocalError at the return statement.
    if ps_method not in ('rw', 'nb'):
        raise ValueError(f"unknown ps_method {ps_method!r}; expected 'rw' or 'nb'")

    batch = sample
    # Transfer once instead of on every random_walk call.
    row_cpu, col_cpu, start = row.cpu(), col.cpu(), batch.cpu()

    if ps_method == 'rw':
        pos_batch = random_walk(row_cpu, col_cpu, start, walk_length=step*hops,
                                coalesced=False)
    else:
        walks = []
        for i in range(step):
            walk = random_walk(row_cpu, col_cpu, start, walk_length=hops, coalesced=False)
            # Every walk after the first contributes all but its first column.
            walks.append(walk if i == 0 else walk[:, 1:])
        # step == 0 yields no walks; keep the original None result in that case.
        pos_batch = torch.cat(walks, 1) if walks else None

    neg_batch = torch.randint(0, x.size(0), (batch.numel(), step*hops*ns_rate),
                              dtype=torch.long)

    if device == 'cuda':
        return pos_batch.to("cuda"), neg_batch.to("cuda")
    return pos_batch, neg_batch

def get_samples(data, split_edge,
                         args, device):
    """Sample positive and negative nodes for every node of `data`, in batches.

    Args:
        data: graph object. ``data.x`` supplies the node count; edges are read
            from ``data.adj_t`` when ``args.transductive == "transductive"``
            and from ``data.edge_index`` otherwise. Either must unpack into
            ``(row, col)``.
        split_edge: accepted for interface compatibility; not read.
        args: namespace providing ``transductive``, ``node_batch_size``,
            ``rw_step``, ``ps_method``, ``ns_rate`` and ``hops``.
        device: accepted for interface compatibility; sampling always runs
            on CPU.

    Returns:
        ``(all_pos_samples, all_neg_samples)``: two lists of CPU tensors with
        one entry per node batch, in node-id order.
    """
    if args.transductive == "transductive":
        row, col = data.adj_t
    else:
        row, col = data.edge_index

    node_num = data.x.size(0)
    # Order edges by (row, col) before handing them to the sampler.
    perm = torch.argsort(row * node_num + col)
    row, col = row[perm], col[perm]

    all_pos_samples = []
    all_neg_samples = []
    for node_perm in DataLoader(range(node_num), args.node_batch_size, shuffle=False):
        # Sample nearby nodes (positives) and random nodes (negatives);
        # the number of sampling rounds is 10x args.rw_step.
        pos_sample, neg_sample = neighbor_samplers(row, col, node_perm, data.x, 10*args.rw_step, args.ps_method, args.ns_rate, args.hops, 'cpu')
        all_pos_samples.append(pos_sample.cpu())
        all_neg_samples.append(neg_sample.cpu())

    # TODO: when args.heuristic is set, augment pos_sample with pairs of large CN count

    return all_pos_samples, all_neg_samples

def _sample_and_save(data, split_edge, args, device, out_path):
    """Run get_samples once, concatenate the per-batch results and save them to out_path."""
    pos_samples, neg_samples = get_samples(data, split_edge, args, device)
    all_pos_samples = torch.cat(pos_samples, dim=0)
    all_neg_samples = torch.cat(neg_samples, dim=0)
    print(all_pos_samples.size(), all_neg_samples.size())
    torch.save((all_pos_samples, all_neg_samples), out_path)


def main():
    """Parse CLI arguments, load the chosen dataset, and save sampled nodes.

    In the ``transductive`` setting the training edges come from the IGB
    loader, a cached/generated edge split, or the OGB split, and the samples
    are written to ``<datasets>_edges.pth``. In the ``production`` setting the
    cached (or freshly generated) production split is used and the samples
    are written to ``<datasets>_edges_production[_heuristic].pth``.
    The parsed arguments and the distillation mode are appended to
    ``../results/<datasets>_KD_<transductive>.txt``.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    parser.add_argument('--heuristic', action='store_true', default=False)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--encoder', type=str, default='sage')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--link_batch_size', type=int, default=64*1024)
    parser.add_argument('--node_batch_size', type=int, default=64*1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=20000)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--dataset_dir', type=str, default='../data')
    parser.add_argument('--datasets', type=str, default='collab')
    parser.add_argument('--predictor', type=str, default='mlp', choices=['inner','mlp'])
    parser.add_argument('--patience', type=int, default=100, help='number of patience steps for early stopping')
    # NOTE(review): the default is not one of the lowercase choices; left as-is
    # because the value is overwritten below and changing it alters the CLI.
    parser.add_argument('--metric', type=str, default='Hits@20', choices=['auc', 'hits@20', 'hits@50'], help='main evaluation metric')
    parser.add_argument('--use_valedges_as_input', action='store_true')
    parser.add_argument('--True_label', default=0.1, type=float, help="true_label loss")
    parser.add_argument('--KD_RM', default=0, type=float, help="Representation-based matching KD") 
    parser.add_argument('--KD_LM', default=0, type=float, help="logit-based matching KD") 
    parser.add_argument('--LLP_D', default=1, type=float, help="distribution-based matching kd")
    parser.add_argument('--LLP_R', default=1, type=float, help="rank-based matching kd") 
    parser.add_argument('--margin', default=0.1, type=float, help="margin for rank-based kd") 
    parser.add_argument('--rw_step', type=int, default=3, help="nearby nodes sampled times")
    parser.add_argument('--ns_rate', type=int, default=1, help="randomly sampled rate over # nearby nodes") 
    parser.add_argument('--hops', type=int, default=2, help="random_walk step for each sampling time")
    parser.add_argument('--ps_method', type=str, default='nb', help="positive sampling is rw or nb")
    parser.add_argument('--transductive', type=str, default='transductive', choices=['transductive', 'production'])
    parser.add_argument('--minibatch', action='store_true')

    args = parser.parse_args()
    print(args)

    Logger_file = "../results/" + args.datasets + "_KD_" + args.transductive + ".txt"
    # Context manager guarantees the log file is closed even if a write fails.
    with open(Logger_file, "a") as file:
        file.write(str(args)+"\n")
        # KD_RM is representation matching and KD_LM is logit matching (see the
        # help strings above); the labels were previously swapped.
        if args.KD_RM != 0:
            file.write("Representation-matching\n")
        elif args.KD_LM != 0:
            file.write("Logit-matching\n")
        elif args.LLP_D != 0 or args.LLP_R != 0:
            file.write("LLP (Relational Distillation)\n")

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    mini_batch_device = 'cpu'

    ### Prepare the datasets
    if args.transductive == "transductive":
        if args.datasets == 'igb':
            from igb.dataloader import IGB260M
            from torch_geometric.data import Data
            igb_data = IGB260M('/local2/qzy_scai/IGB-Datasets/', size='small', in_memory=1, classes=19, synthetic=0)
            data = Data(x = torch.Tensor(igb_data.paper_feat.copy()),
                    edge_index = torch.Tensor(igb_data.paper_edge.copy()).T.long(),
                    num_nodes = igb_data.num_nodes(),)

            path_name = '/local2/qzy_scai/IGB-Datasets/small-split-edges.pkl'
            if not exists(path_name):
                # The split must be generated beforehand; it is not built here.
                raise RuntimeError(f"missing edge split file: {path_name}")
            # NOTE: torch.load unpickles the file; only load trusted caches.
            split_edge = torch.load(path_name)

            data.adj_t = data.edge_index
            print(data)

        elif args.datasets not in ["collab", 'citation2']:
            dataset = get_dataset(args.dataset_dir, args.datasets)
            data = dataset[0]

            # Reuse the cached edge split when present, otherwise create and cache it.
            if exists("../data/" + args.datasets + ".pkl"):
                split_edge = torch.load("../data/" + args.datasets + ".pkl")
            else:
                split_edge = do_edge_split(dataset)
                torch.save(split_edge, "../data/" + args.datasets + ".pkl")

            # Sampling walks over the training edges only.
            data.adj_t = split_edge['train']['edge'].t()
            args.metric = 'Hits@20'

        else: # collab, citation2
            dataset = PygLinkPropPredDataset(name=f'ogbl-{args.datasets}')
            data = dataset[0]
            edge_index = data.edge_index
            if hasattr(data, "edge_weight") and data.edge_weight is not None:
                data.edge_weight = data.edge_weight.view(-1).to(torch.float)
            data = T.ToSparseTensor()(data)
            print(data)

            split_edge = dataset.get_edge_split()

            if args.datasets == 'citation2':
                # citation2 stores source/target separately; build the (E, 2) edge list.
                source_edge, target_edge = split_edge['train']['source_node'], split_edge['train']['target_node']
                split_edge['train']['edge'] = torch.cat([source_edge.unsqueeze(1), target_edge.unsqueeze(1)], dim=-1)

            args.metric = 'Hits@50'
            data.adj_t = edge_index
            # NOTE: a graph-statistics diagnostic (triangle / connected-component
            # counts) used to return at this point, which left the sampling
            # below unreachable for these datasets.

        _sample_and_save(data, split_edge, args, device, f'{args.datasets}_edges.pth')
        return

    # Production setting.
    production_path = "../data_production/" + args.datasets + "_production.pkl"
    if exists(production_path):
        # NOTE: torch.load unpickles the file; only load trusted caches.
        training_data, val_data, inference_data, data, test_edge_bundle, negative_samples = torch.load(production_path)
    else:
        print("splitting the datasets now...")
        dset = get_dataset('../data_production', args.datasets)
        ## testing edges ratio (0.3 for cora and citeseer, 0.1 for other datasets)
        test_ratio = 0.3
        ## New nodes ratio (0.3 for cora and citeseer, 0.1 for other datasets)
        val_node_ratio = 0.3
        ## validation/training splitting ratio (0.3 for cora and citeseer, 0.1 for other datasets)
        val_ratio = 0.3
        ## Splitting ratio for new old-old edges appearing for the inference (0.1 for all datasets)
        old_old_extra_ratio = 0.1

        if args.datasets not in ("cora", "citeseer"):
            test_ratio = 0.1
            val_node_ratio = 0.1
            val_ratio = 0.1
        # Defined for every dataset; it used to be set only inside the branch
        # above, raising NameError for cora and citeseer.
        test_infer_ratio = test_ratio

        training_data, val_data, inference_data, data, test_edge_bundle, negative_samples = do_production_edge_split(
            dset, args.datasets, test_ratio, val_node_ratio, val_ratio, old_old_extra_ratio,
            test_infer_ratio = test_infer_ratio, max_degree_to_split = None)
        torch.save((training_data, val_data, inference_data, data, test_edge_bundle, negative_samples), production_path)

    if args.minibatch:
        training_data.to(mini_batch_device)
    else:
        training_data.to(device)
    val_data.to('cpu')
    inference_data.to('cpu')

    # args.transductive is always 'production' on this path.
    if args.heuristic:
        out_path = f'{args.datasets}_edges_production_heuristic.pth'
    else:
        out_path = f'{args.datasets}_edges_production.pth'
    _sample_and_save(training_data, None, args, device, out_path)

# Run the pipeline only when executed as a script, so importing this module
# does not parse sys.argv or touch the filesystem.
if __name__ == "__main__":
    main()
