import torch
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl.data import AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset
from dgl.data import PPIDataset,RedditDataset
from ogb.nodeproppred import DglNodePropPredDataset

import scipy.sparse as sp
import numpy as np
import networkx as nx
import sys
import json
import os
from networkx.readwrite import json_graph
import random

def set_seed(seed=1024):
    """Seed NumPy's global random number generator.

    Only ``np.random`` is seeded. The PyTorch (CPU/CUDA) and Python
    ``random`` generators, and the cuDNN determinism flag, are intentionally
    left untouched by this helper.

    Args:
        seed: Value passed to ``np.random.seed``. Defaults to 1024.
    """
    np.random.seed(seed)

def to_bidirected(graph):
    """Return a copy of ``graph`` in which every edge also has its reverse.

    Self-loops are removed first. Each remaining edge ``(u, v)`` is then
    emitted twice, once as ``(u, v)`` and once as ``(v, u)``. Reverse edges
    that already exist in the input are not deduplicated. The result is a
    fresh graph built from the edge lists only, so node/edge features of
    the input are not carried over.

    Args:
        graph: Input DGL graph.

    Returns:
        A new DGL graph with the same number of nodes as ``graph``.
    """
    node_count = graph.num_nodes()

    heads, tails = graph.remove_self_loop().edges()

    # Forward edges followed by their mirrored counterparts.
    both_ways = (torch.cat((heads, tails)), torch.cat((tails, heads)))

    return dgl.graph(both_ways, num_nodes=node_count)

def to_inductive(graph, test_idx):
    """Build the graph seen during training in the inductive setting.

    Self-loops are removed, then every edge that has a test node as either
    endpoint is dropped. Node IDs are not renumbered: the returned graph has
    the same number of nodes as ``graph``, so test nodes simply end up with
    no edges. The result is built from the edge lists only; node/edge
    features of the input are not carried over.

    Args:
        graph: Input DGL graph.
        test_idx: Tensor of test node IDs. It is converted with ``.numpy()``,
            so it must be convertible that way; a 0-d tensor (single node)
            is accepted.

    Returns:
        A new DGL graph containing only edges between non-test nodes, in
        their original order.
    """
    graph = graph.remove_self_loop()
    src, dst = graph.edges()
    src = src.numpy()
    dst = dst.numpy()

    # Flatten so a 0-d tensor (e.g. the result of squeezing a single index)
    # is treated as a one-element list of node IDs.
    held_out = test_idx.numpy().reshape(-1)

    # Vectorised membership test: keep an edge only when neither endpoint is
    # a test node. A boolean mask preserves the original edge order.
    touches_test = np.isin(src, held_out) | np.isin(dst, held_out)
    keep = ~touches_test

    new_src = torch.tensor(src[keep])
    new_dst = torch.tensor(dst[keep])

    return dgl.graph((new_src, new_dst), num_nodes=graph.num_nodes())
    

def load_dataset(name, mode='transductive'):
    """Load a node-classification dataset by name.

    Args:
        name: One of ``'cora'``, ``'citeseer'``, ``'pubmed'`` (Planetoid),
            ``'computer'``, ``'photo'`` (Amazon co-purchase), ``'cs'`` or
            ``'physics'`` (Coauthor).
        mode: ``'transductive'`` returns the dataset exactly as the
            underlying loader produced it. Any other value is treated as
            inductive: the Amazon/Coauthor split is made reproducible via
            the loaders' ``ind`` flag, and the returned graph has self-loops
            and all edges touching test nodes removed.

    Returns:
        Tuple ``(graph, feat, label, num_class, train_idx, val_idx,
        test_idx)``.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    inductive = mode != 'transductive'

    if name in ('cora', 'citeseer', 'pubmed'):
        dataset = load_planetoid_dataset(name)
    elif name in ('computer', 'photo'):
        dataset = load_amazon_dataset(name, ind=inductive)
    elif name in ('cs', 'physics'):
        dataset = load_coauthor_dataset(name, ind=inductive)
    else:
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of 'cora', "
            "'citeseer', 'pubmed', 'computer', 'photo', 'cs', 'physics'"
        )

    if not inductive:
        # NOTE(review): the earlier version computed a self-loop-free graph
        # here but never returned it in transductive mode; that behaviour is
        # kept. Confirm whether transductive callers expect self-loops
        # to be stripped.
        return dataset

    graph, feat, label, num_class, train_idx, val_idx, test_idx = dataset

    # to_inductive removes self-loops itself before filtering test edges.
    graph = to_inductive(graph, test_idx)

    return (graph, feat, label, num_class, train_idx, val_idx, test_idx)


def load_planetoid_dataset(name):
    """Load a Planetoid citation dataset with its predefined split.

    The split masks, features and labels are popped out of ``graph.ndata``,
    so the returned graph no longer carries them.

    Args:
        name: ``'cora'``, ``'citeseer'`` or ``'pubmed'``.

    Returns:
        Tuple ``(graph, feat, label, num_class, train_idx, val_idx,
        test_idx)`` where the ``*_idx`` entries are 1-D tensors of node IDs
        taken from the dataset's train/val/test masks.

    Raises:
        ValueError: If ``name`` is not a supported Planetoid dataset.
    """
    if name == 'cora':
        dataset = CoraGraphDataset()
    elif name == 'citeseer':
        dataset = CiteseerGraphDataset()
    elif name == 'pubmed':
        dataset = PubmedGraphDataset()
    else:
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown Planetoid dataset {name!r}; expected 'cora', "
            "'citeseer' or 'pubmed'"
        )

    graph = dataset[0]
    train_mask = graph.ndata.pop('train_mask')
    val_mask = graph.ndata.pop('val_mask')
    test_mask = graph.ndata.pop('test_mask')

    # nonzero(..., as_tuple=False) on a 1-D mask yields shape (k, 1).
    # Squeezing only dim 1 keeps the result 1-D even when k == 1, whereas a
    # bare squeeze() would collapse a single index to a 0-d tensor.
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze(1)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze(1)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze(1)

    num_class = dataset.num_classes
    feat = graph.ndata.pop('feat')
    label = graph.ndata.pop('label')

    return (graph, feat, label, num_class, train_idx, val_idx, test_idx)

def load_amazon_dataset(name, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, ind=False):
    """Load an Amazon co-purchase dataset with a random node split.

    Node IDs are shuffled with NumPy's global RNG; the first
    ``int(N * train_ratio)`` go to train, the next block up to
    ``int(N * (train_ratio + val_ratio))`` goes to validation, and whatever
    remains becomes the test set. Features and labels are popped out of
    ``graph.ndata``, so the returned graph no longer carries them.

    Args:
        name: ``'computer'`` or ``'photo'``.
        train_ratio: Fraction of nodes used for training.
        val_ratio: Fraction of nodes used for validation.
        test_ratio: Accepted for backward compatibility but not used; the
            test set is always the remainder after train and validation.
        ind: If true, ``set_seed()`` is called first so the split is
            reproducible. Otherwise the split depends on the current state
            of NumPy's global RNG.

    Returns:
        Tuple ``(graph, feat, label, num_class, train_idx, val_idx,
        test_idx)`` with the index entries as tensors of node IDs.

    Raises:
        ValueError: If ``name`` is not a supported Amazon dataset.
    """
    if name not in ('computer', 'photo'):
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown Amazon dataset {name!r}; expected 'computer' or 'photo'"
        )

    # Seed before anything else, matching the original call order.
    if ind:
        set_seed()

    if name == 'computer':
        dataset = AmazonCoBuyComputerDataset()
    else:
        dataset = AmazonCoBuyPhotoDataset()

    graph = dataset[0]

    N = graph.num_nodes()
    train_end = int(N * train_ratio)
    val_end = int(N * (train_ratio + val_ratio))

    order = np.arange(N)
    np.random.shuffle(order)

    train_idx = torch.tensor(order[:train_end])
    val_idx = torch.tensor(order[train_end:val_end])
    test_idx = torch.tensor(order[val_end:])

    num_class = dataset.num_classes
    feat = graph.ndata.pop('feat')
    label = graph.ndata.pop('label')

    return (graph, feat, label, num_class, train_idx, val_idx, test_idx)

def load_coauthor_dataset(name, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, ind=False):
    """Load a Coauthor dataset with a random node split.

    Node IDs are shuffled with NumPy's global RNG; the first
    ``int(N * train_ratio)`` go to train, the next block up to
    ``int(N * (train_ratio + val_ratio))`` goes to validation, and whatever
    remains becomes the test set. Features and labels are popped out of
    ``graph.ndata``, so the returned graph no longer carries them.

    Args:
        name: ``'cs'`` or ``'physics'``.
        train_ratio: Fraction of nodes used for training.
        val_ratio: Fraction of nodes used for validation.
        test_ratio: Accepted for backward compatibility but not used; the
            test set is always the remainder after train and validation.
        ind: If true, ``set_seed()`` is called first so the split is
            reproducible. Otherwise the split depends on the current state
            of NumPy's global RNG.

    Returns:
        Tuple ``(graph, feat, label, num_class, train_idx, val_idx,
        test_idx)`` with the index entries as tensors of node IDs.

    Raises:
        ValueError: If ``name`` is not a supported Coauthor dataset.
    """
    if name not in ('cs', 'physics'):
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown Coauthor dataset {name!r}; expected 'cs' or 'physics'"
        )

    # Seed before anything else, matching the original call order.
    if ind:
        set_seed()

    if name == 'cs':
        dataset = CoauthorCSDataset()
    else:
        dataset = CoauthorPhysicsDataset()

    graph = dataset[0]

    N = graph.num_nodes()
    train_end = int(N * train_ratio)
    val_end = int(N * (train_ratio + val_ratio))

    order = np.arange(N)
    np.random.shuffle(order)

    train_idx = torch.tensor(order[:train_end])
    val_idx = torch.tensor(order[train_end:val_end])
    test_idx = torch.tensor(order[val_end:])

    num_class = dataset.num_classes
    feat = graph.ndata.pop('feat')
    label = graph.ndata.pop('label')

    return (graph, feat, label, num_class, train_idx, val_idx, test_idx)
    