from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
import scipy
import scipy.io
from sklearn.preprocessing import label_binarize
import torch_geometric.transforms as T

from data_utils import rand_train_test_idx, even_quantile_labels, to_sparse_tensor, dataset_drive_url, class_rand_splits

from torch_geometric.datasets import Planetoid, Amazon, Coauthor
from torch_geometric.transforms import NormalizeFeatures
from os import path

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from torch_sparse import SparseTensor
#from google_drive_downloader import GoogleDriveDownloader as gdd

import networkx as nx
import scipy.sparse as sp

from ogb.nodeproppred import NodePropPredDataset, PygNodePropPredDataset
import os

from torch_geometric.utils import subgraph, k_hop_subgraph, to_undirected
import pickle as pkl

import math
import torch_geometric

from pathlib import Path

import pickle
import math


class CustomDataset(Dataset):
    """Map-style dataset that reads one pickled sample per ``.pkl`` file.

    Every file in ``data_dir`` whose name ends with ``.pkl`` is one sample.
    Each file must unpickle to a mapping with a ``'data'`` entry and a
    ``'label'`` entry.

    Args:
        data_dir: Directory scanned (non-recursively) for ``.pkl`` files.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting makes the index -> file mapping reproducible across runs.
        self.file_list = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if f.endswith('.pkl')
        )

    def __len__(self):
        """Return the number of ``.pkl`` samples found in ``data_dir``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Returns:
            A ``(data, label)`` tuple: ``data`` is a float32 tensor with NaNs
            replaced by 0.0, ``label`` is an int64 tensor.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this dataset at directories containing trusted files.
        with open(self.file_list[idx], 'rb') as file:
            sample = pickle.load(file)
        data = torch.tensor(sample['data'], dtype=torch.float32)
        data = torch.nan_to_num(data, nan=0.0)
        label = torch.tensor(sample['label'], dtype=torch.long)
        return data, label

def collate_fn(batch):
    """Collate variable-length ``(data, label)`` samples into a padded batch.

    Returns a tuple ``(padded_data, labels, lengths)``: the sequences
    zero-padded along their first dimension (batch first), the labels as an
    un-stacked tuple, and a tensor holding each sequence's original length.
    """
    # Split the batch into parallel tuples of sequences and labels.
    sequences, labels = zip(*batch)

    # Record the un-padded length (first dimension) of every sequence.
    lengths = torch.tensor([seq.size(0) for seq in sequences])

    # Zero-pad to the longest sequence; trailing feature dimensions are kept.
    padded_data = pad_sequence(sequences, batch_first=True, padding_value=0)

    return padded_data, labels, lengths


class NCDataset(object):
    """Single-graph node-classification dataset container.

    Based off of ogb NodePropPredDataset
    https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py
    Gives torch tensors instead of numpy arrays.

    Usage after construction (``graph`` and ``label`` are filled in by the
    loader functions):

        split_idx = dataset.get_idx_split()
        train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
        graph, label = dataset[0]

    Where the graph is a dictionary of the following form:
        dataset.graph = {'edge_index': edge_index,
                         'edge_feat': None,
                         'node_feat': node_feat,
                         'num_nodes': num_nodes}
    For additional documentation, see OGB Library-Agnostic Loader
    https://ogb.stanford.edu/docs/nodeprop/
    """

    def __init__(self, name):
        """Create an empty dataset.

        Args:
            name (str): name of the dataset, e.g. ``'ogbn-proteins'``.
        """
        self.name = name  # original name, e.g., ogbn-proteins
        self.graph = {}
        self.label = None

    def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25, label_num_per_class=20):
        """Create a train/valid/test split of the node indices.

        Args:
            split_type: ``'random'`` for a proportional random split or
                ``'class'`` for a split with a fixed number of labelled
                nodes per class.
            train_prop: The proportion of dataset for train split. Between 0 and 1.
                Only used when ``split_type == 'random'``.
            valid_prop: The proportion of dataset for validation split. Between 0 and 1.
                Only used when ``split_type == 'random'``.
            label_num_per_class: Forwarded to ``class_rand_splits``; only
                used when ``split_type == 'class'``.

        Returns:
            dict with keys ``'train'``, ``'valid'`` and ``'test'``.

        Raises:
            ValueError: if ``split_type`` is not one of the supported values.
        """
        if split_type == 'random':
            # Negative labels are kept only for ogbn-proteins.
            ignore_negative = self.name != 'ogbn-proteins'
            train_idx, valid_idx, test_idx = rand_train_test_idx(
                self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative)
        elif split_type == 'class':
            train_idx, valid_idx, test_idx = class_rand_splits(self.label, label_num_per_class=label_num_per_class)
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError(
                f"Unknown split_type {split_type!r}; expected 'random' or 'class'")
        return {'train': train_idx,
                'valid': valid_idx,
                'test': test_idx}

    def __getitem__(self, idx):
        """Return ``(graph, label)``; only index 0 is valid."""
        assert idx == 0, 'This dataset has only one graph'
        return self.graph, self.label

    def __len__(self):
        """The dataset always holds exactly one graph."""
        return 1

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, len(self))


class BipartiteNodeData(torch_geometric.data.Data):
    """
    Class Description:
    Holds one bipartite (constraint / variable) graph observation, such as the one
    returned by the `ecole.observation.NodeBipartite` observation function, in a
    format the pytorch geometric data handlers can batch.
    """
    def __init__(self, constraint_features, edge_indices, edge_features, variable_features):
        super().__init__()
        # Assignment order is kept: constraints, edges, edge attributes, variables.
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features

    def __inc__(self, key, value, store, *args, **kwargs):
        """
        Function Description:
        Tell pytorch geometric by how much to shift index-valued entries when graphs
        are concatenated into a batch. Only `edge_index` and `candidates` need custom
        offsets; every other key defers to the parent implementation.
        """
        if key == "candidates":
            return self.variable_features.size(0)
        if key != "edge_index":
            return super().__inc__(key, value, *args, **kwargs)
        # Row 0 is offset by the constraint count, row 1 by the variable count.
        num_constraints = self.constraint_features.size(0)
        num_variables = self.variable_features.size(0)
        return torch.tensor([[num_constraints], [num_variables]])


class GraphDataset(torch_geometric.data.Dataset):
    """
    Class Description:
    A collection of bipartite graphs stored as one pickle file each. Graphs are read
    from disk lazily, so the dataset can be handed to the pytorch geometric data loaders.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        return len(self.sample_files)

    def get(self, index):
        """
        Function Description:
        Read the pickle at position `index` and wrap its four arrays
        (variable features, constraint features, edge indices, edge features)
        in a `BipartiteNodeData` object.
        """
        sample_path = self.sample_files[index]
        with open(sample_path, "rb") as handle:
            var_feats, cons_feats, edge_idx, edge_feats = pickle.load(handle)

        num_cons = len(cons_feats)
        num_vars = len(var_feats)

        graph = BipartiteNodeData(
            torch.FloatTensor(cons_feats),
            torch.LongTensor(edge_idx),
            torch.FloatTensor(edge_feats),
            torch.FloatTensor(var_feats),
        )

        # pytorch geometric needs the total node count for indexing purposes
        graph.num_nodes = num_cons + num_vars
        graph.cons_nodes = num_cons
        graph.vars_nodes = num_vars

        return graph


def load_dataset(data_dir, dataname, sub_dataname=''):
    """ Loader for NCDataset
        Returns NCDataset

    Dispatches on ``dataname`` to the matching ``load_*`` function.

    Args:
        data_dir: Directory prefix handed to the individual loaders.
        dataname: Name of the dataset to load.
        sub_dataname: Sub-graph selector, only used for ``'twitch-e'`` and
            ``'fb100'``; an invalid value falls back to a default graph.

    Raises:
        ValueError: if ``dataname`` is not a known dataset.
    """
    print(dataname)
    if dataname == 'twitch-e':
        # twitch-explicit graph
        if sub_dataname not in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'):
            print('Invalid sub_dataname, deferring to DE graph')
            sub_dataname = 'DE'
        dataset = load_twitch_dataset(data_dir, sub_dataname)
    elif dataname == 'fb100':
        if sub_dataname not in ('Penn94', 'Amherst41', 'Cornell5', 'Johns Hopkins55', 'Reed98'):
            print('Invalid sub_dataname, deferring to Penn94 graph')
            sub_dataname = 'Penn94'
        dataset = load_fb100_dataset(data_dir, sub_dataname)
    elif dataname == 'ogbn-proteins':
        dataset = load_proteins_dataset(data_dir)
    elif dataname == 'deezer-europe':
        dataset = load_deezer_dataset(data_dir)
    elif dataname == 'arxiv-year':
        dataset = load_arxiv_year_dataset(data_dir)
    elif dataname == 'pokec':
        dataset = load_pokec_mat(data_dir)
    elif dataname == 'snap-patents':
        dataset = load_snap_patents_mat(data_dir)
    elif dataname == 'yelp-chi':
        dataset = load_yelpchi_dataset(data_dir)
    elif dataname == 'amazon2m':
        dataset = load_amazon2m_dataset(data_dir)
    elif dataname in ('ogbn-arxiv', 'ogbn-products'):
        dataset = load_ogb_dataset(data_dir, dataname)
    elif dataname == 'ogbn-papers100M':
        dataset = load_papers100M(data_dir)
    elif dataname == 'ogbn-papers100M-sub':
        dataset = papers100M_sub(data_dir)
    elif dataname in ('cora', 'citeseer', 'pubmed'):
        dataset = load_planetoid_dataset(data_dir, dataname)
    elif dataname in ('amazon-photo', 'amazon-computer'):
        dataset = load_amazon_dataset(data_dir, dataname)
    elif dataname in ('coauthor-cs', 'coauthor-physics'):
        dataset = load_coauthor_dataset(data_dir, dataname)
    elif dataname in ('chameleon', 'cornell', 'film', 'squirrel', 'texas', 'wisconsin'):
        dataset = load_geom_gcn_dataset(data_dir, dataname)
    else:
        raise ValueError('Invalid dataname')
    return dataset


def load_dataset_bipartite(graph_dir, batch_size):
    """Build an un-shuffled PyG DataLoader over every ``*.pickle`` file in ``graph_dir``.

    The number of files found is printed before the loader is created.
    """
    pickle_paths = Path(graph_dir).glob("*.pickle")
    sample_files = list(map(str, pickle_paths))
    print(len(sample_files))

    graphs = GraphDataset(sample_files)
    return torch_geometric.loader.DataLoader(graphs, batch_size=batch_size, shuffle=False)


def split_bipartite_graph(variable_features, constraint_features, edge_indices, edge_features, maximum_batch_size, device):
    """Randomly partition a bipartite graph into size-bounded subgraphs.

    The variable nodes are shuffled and cut into chunks; for each chunk the
    constraints adjacent to it are collected (and randomly sub-sampled when
    there are too many) together with a mask over the edges that connect the
    selected constraints and variables.

    Args:
        variable_features: Per-variable features; only ``shape[0]`` is used.
        constraint_features: Per-constraint features; only ``shape[0]`` is used.
        edge_indices: Tensor whose row 0 holds constraint indices and row 1
            holds variable indices of each edge.
        edge_features: Unused; kept for interface compatibility.
        maximum_batch_size: Target upper bound on constraints + variables per
            subgraph, split according to the variable/constraint ratio.
        device: Device the returned index tensors are moved to.

    Returns:
        List of ``(constraint_indices, variable_indices, edge_mask)`` tuples,
        one per variable chunk.
    """
    variable_num = variable_features.shape[0]
    constraint_num = constraint_features.shape[0]

    # Size of each variable chunk, proportional to the variable share of nodes.
    if constraint_num > 0:
        ratio = variable_num / constraint_num
        part_size_b = math.floor((maximum_batch_size * ratio) / (1 + ratio))
    else:
        # No constraints at all: the whole budget goes to variables
        # (previously this divided by zero).
        part_size_b = maximum_batch_size
    # A chunk size of 0 (tiny batch size or very few variables) used to
    # trigger a ZeroDivisionError in the integer division below.
    part_size_b = max(1, part_size_b)
    maximum_parts_a = max(0, maximum_batch_size - part_size_b)
    num_parts_b_graphs = variable_num // part_size_b + 1

    parts_b = torch.randperm(variable_num).chunk(num_parts_b_graphs)

    # Build the list of subgraphs
    subgraphs = []

    for part_b in parts_b:
        variable_indices = part_b.to(device)

        # Constraint endpoint of every edge touching this variable chunk.
        # NOTE(review): this holds one entry per edge, so a constraint with
        # several selected neighbours appears multiple times - confirm that
        # duplicates are acceptable to callers.
        neighbor_mask = torch.isin(edge_indices[1], variable_indices)
        constraint_neighbors = edge_indices[0][neighbor_mask]

        if len(constraint_neighbors) > maximum_parts_a:
            keep = torch.randperm(len(constraint_neighbors))[:maximum_parts_a]
            constraint_indices = constraint_neighbors[keep]
        else:
            constraint_indices = constraint_neighbors

        constraint_indices = constraint_indices.to(device)

        # Keep only the edges whose both endpoints are in the subset
        mask_a = torch.isin(edge_indices[0], constraint_indices)
        mask_b = torch.isin(edge_indices[1], variable_indices)
        edge_mask = mask_a & mask_b

        # Add the subgraph to the list
        subgraphs.append((constraint_indices, variable_indices, edge_mask))

    return subgraphs


def load_twitch_dataset(data_dir, lang):
    """Load one language graph of the Twitch (musae) dataset.

    Reads ``musae_{lang}_target.csv``, ``musae_{lang}_edges.csv`` and
    ``musae_{lang}_features.json`` from ``{data_dir}twitch/{lang}``.

    Args:
        data_dir: Directory prefix; it is concatenated directly with
            ``'twitch/...'``, so it is expected to end with a path separator.
        lang: One of ``'DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'``.

    Returns:
        NCDataset with a 3170-dimensional multi-hot ``node_feat`` matrix and
        binary labels taken from column 2 of the target file.
    """
    # csv and json are not imported at module level; import them here so the
    # function does not fail with a NameError.
    import csv
    import json

    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'

    filepath = data_dir + f"twitch/{lang}"
    label = []
    node_ids = []
    src = []
    targ = []
    uniq_ids = set()
    with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            node_id = int(row[5])
            # handle FR case of non-unique rows
            if node_id not in uniq_ids:
                uniq_ids.add(node_id)
                label.append(int(row[2] == "True"))
                node_ids.append(node_id)

    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    node_ids = np.array(node_ids, dtype=int)
    with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            src.append(int(row[0]))
            targ.append(int(row[1]))
    with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
        j = json.load(f)
    src = np.array(src)
    targ = np.array(targ)
    label = np.array(label)

    # Position in the target file of each node id, used to reorder the labels
    # so that label[i] belongs to node id i.
    inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
    reorder_node_ids = np.zeros_like(node_ids)
    for i in range(label.shape[0]):
        reorder_node_ids[i] = inv_node_ids[i]

    n = label.shape[0]
    A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)),
                                shape=(n, n))
    features = np.zeros((n, 3170))
    for node, feats in j.items():
        if int(node) >= n:
            continue
        features[int(node), np.array(feats, dtype=int)] = 1
    # features = features[:, np.sum(features, axis=0) != 0] # remove zero cols. not need for cross graph task
    label = label[reorder_node_ids]

    dataset = NCDataset(lang)
    # Stack the (row, col) arrays first; building a tensor straight from a
    # tuple of ndarrays is slow and triggers a PyTorch warning.
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    node_feat = torch.tensor(features, dtype=torch.float)
    num_nodes = node_feat.shape[0]
    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': num_nodes}
    dataset.label = torch.tensor(label)
    return dataset

def load_fb100_dataset(data_dir, filename):
    """Load one Facebook100 school graph with one-hot encoded node features.

    All five supported school files are read first so the one-hot categories
    of each feature column are shared across schools (same feature width no
    matter which ``filename`` is loaded).

    Args:
        data_dir: Directory prefix; concatenated directly with
            ``'facebook100/'``, so it is expected to end with a path separator.
        filename: School name, i.e. the ``.mat`` file stem.

    Returns:
        NCDataset whose label is derived from column 1 of ``local_info``.
    """
    feature_vals_all = np.empty((0, 6))
    for f in ['Penn94', 'Amherst41', 'Cornell5', 'Johns Hopkins55', 'Reed98']:
        mat = scipy.io.loadmat(data_dir + 'facebook100/' + f + '.mat')
        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        metadata = mat['local_info'].astype(int)
        # Every column except column 1 (the label) is a categorical feature.
        feature_vals = np.hstack(
            (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
        feature_vals_all = np.vstack(
            (feature_vals_all, feature_vals)
        )

    mat = scipy.io.loadmat(data_dir + 'facebook100/' + filename + '.mat')
    A = mat['A']
    metadata = mat['local_info'].astype(int)
    dataset = NCDataset(filename)
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    label = metadata[:, 1] - 1  # gender label, -1 means unlabeled

    # make features into one-hot encodings
    feature_vals = np.hstack(
        (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
    features = np.empty((A.shape[0], 0))
    for col in range(feature_vals.shape[1]):
        feat_col = feature_vals[:, col]
        # Categories come from all schools, not just this one.
        feat_onehot = label_binarize(feat_col, classes=np.unique(feature_vals_all[:, col]))
        features = np.hstack((features, feat_onehot))

    node_feat = torch.tensor(features, dtype=torch.float)
    num_nodes = metadata.shape[0]
    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': num_nodes}
    dataset.label = torch.tensor(label)
    # NOTE(review): this maps both -1 (unlabeled) and 0 to class 0 - confirm
    # that merging unlabeled nodes into class 0 is intended.
    dataset.label = torch.where(dataset.label > 0, 1, 0)
    return dataset

def load_deezer_dataset(data_dir):
    """Load the deezer-europe graph from ``{data_dir}deezer/deezer-europe.mat``.

    The ``.mat`` file supplies the adjacency matrix ``A``, the ``label`` array
    and the ``features`` matrix (densified here).
    """
    dataset = NCDataset('deezer-europe')
    mat = scipy.io.loadmat(f'{data_dir}deezer/deezer-europe.mat')

    adjacency = mat['A']
    raw_label = mat['label']
    raw_features = mat['features']

    edge_index = torch.tensor(adjacency.nonzero(), dtype=torch.long)
    node_feat = torch.tensor(raw_features.todense(), dtype=torch.float)
    label = torch.tensor(raw_label, dtype=torch.long).squeeze()

    dataset.graph = {
        'edge_index': edge_index,
        'edge_feat': None,
        'node_feat': node_feat,
        'num_nodes': label.shape[0],
    }
    dataset.label = label
    return dataset


def load_arxiv_year_dataset(data_dir, nclass=5):
    """Load ogbn-arxiv and relabel nodes by publication-year quantile.

    The graph comes from OGB (root ``{data_dir}ogb``); labels are produced by
    ``even_quantile_labels`` over ``node_year`` with ``nclass`` classes and
    stored as a column vector.
    """
    dataset = NCDataset('arxiv-year')
    ogb_dataset = NodePropPredDataset(name='ogbn-arxiv', root=f'{data_dir}ogb')

    graph = ogb_dataset.graph
    dataset.graph = graph
    for key in ('edge_index', 'node_feat'):
        graph[key] = torch.as_tensor(graph[key])

    years = graph['node_year'].flatten()
    year_classes = even_quantile_labels(years, nclass, verbose=False)
    dataset.label = torch.as_tensor(year_classes).reshape(-1, 1)
    return dataset

def load_amazon2m_dataset(data_dir):
    """Load ogbn-products as the 'amazon2m' dataset with cached random splits.

    Args:
        data_dir: Base data directory; OGB data lives in ``<data_dir>/ogb``.

    Returns:
        NCDataset with an extra ``load_fixed_splits`` callable attached.
    """
    # Build both the OGB root and the split directory from the same base so
    # they agree whether or not data_dir ends with a path separator
    # (previously one used '{data_dir}/ogb' and the other '{data_dir}ogb').
    ogb_root = os.path.join(data_dir, 'ogb')
    ogb_dataset = NodePropPredDataset(name='ogbn-products', root=ogb_root)
    dataset = NCDataset('amazon2m')
    dataset.graph = ogb_dataset.graph
    dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index'])
    dataset.graph['node_feat'] = torch.as_tensor(dataset.graph['node_feat'])
    dataset.label = torch.as_tensor(ogb_dataset.labels).reshape(-1, 1)

    def load_fixed_splits(train_prop=0.5, val_prop=0.25):
        """Load the cached split, or create and cache it on first use.

        NOTE: the cache directory name is fixed to ``random_0.5_0.25``, so
        ``train_prop``/``val_prop`` only take effect when no cache exists yet.
        """
        split_dir = os.path.join(ogb_root, 'ogbn_products', 'split', 'random_0.5_0.25')
        parts = ('train', 'valid', 'test')
        tensor_split_idx = {}
        if os.path.exists(split_dir):
            for part in parts:
                tensor_split_idx[part] = torch.as_tensor(
                    np.loadtxt(os.path.join(split_dir, f'amazon2m_{part}.txt')),
                    dtype=torch.long)
        else:
            os.makedirs(split_dir)
            tensor_split_idx['train'], tensor_split_idx['valid'], tensor_split_idx['test'] \
                = rand_train_test_idx(dataset.label, train_prop=train_prop, valid_prop=val_prop)
            for part in parts:
                np.savetxt(os.path.join(split_dir, f'amazon2m_{part}.txt'),
                           tensor_split_idx[part], fmt='%d')
        return tensor_split_idx
    dataset.load_fixed_splits = load_fixed_splits
    return dataset

def load_papers100M(data_dir):
    """Load the full ogbn-papers100M graph through the PyG OGB wrapper.

    The returned NCDataset carries the official OGB train/valid/test indices
    via ``dataset.load_fixed_splits``.
    """
    ogb_dataset = PygNodePropPredDataset('ogbn-papers100M', root=data_dir)
    pyg_graph = ogb_dataset[0]

    dataset = NCDataset('ogbn-papers100M')
    dataset.graph = {
        'edge_index': torch.as_tensor(pyg_graph.edge_index),
        'node_feat': torch.as_tensor(pyg_graph.x),
        'num_nodes': pyg_graph.num_nodes,
    }

    # Use mapped train, valid and test index, same as OGB.
    official_split = ogb_dataset.get_idx_split()
    train_idx = official_split['train']
    valid_idx = official_split['valid']
    test_idx = official_split['test']

    # 99% labels are nan, not available
    raw_labels = pyg_graph.y.data
    dataset.label = torch.as_tensor(raw_labels, dtype=int).reshape(-1, 1)

    def get_idx_split():
        return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}

    dataset.load_fixed_splits = get_idx_split

    return dataset


def load_proteins_dataset(data_dir):
    """Load ogbn-proteins and derive node features from its edge features.

    Node features are the ``mean(dim=1)`` of the sparse tensor built from the
    edge index and edge features; ``edge_feat`` is then set to None.
    ``dataset.load_fixed_splits`` returns the original OGB split as tensors.
    """
    ogb_dataset = NodePropPredDataset(name='ogbn-proteins', root=f'{data_dir}ogb')
    dataset = NCDataset('ogbn-proteins')

    def protein_orig_split(**kwargs):
        original = ogb_dataset.get_idx_split()
        return {
            'train': torch.as_tensor(original['train']),
            'valid': torch.as_tensor(original['valid']),
            'test': torch.as_tensor(original['test']),
        }

    dataset.load_fixed_splits = protein_orig_split

    graph = ogb_dataset.graph
    dataset.graph = graph
    dataset.label = ogb_dataset.labels

    graph['edge_index'] = torch.as_tensor(graph['edge_index'])
    graph['edge_feat'] = torch.as_tensor(graph['edge_feat'])
    dataset.label = torch.as_tensor(dataset.label)

    sparse_adj = to_sparse_tensor(
        graph['edge_index'], graph['edge_feat'], graph['num_nodes'])
    graph['node_feat'] = sparse_adj.mean(dim=1)
    graph['edge_feat'] = None

    return dataset


def load_ogb_dataset(data_dir, name):
    """Load an OGB node-property dataset (root ``{data_dir}ogb``) as an NCDataset.

    ``dataset.load_fixed_splits`` returns the official OGB split converted to
    tensors; labels are stored as a column vector.
    """
    dataset = NCDataset(name)
    ogb_dataset = NodePropPredDataset(name=name, root=f'{data_dir}ogb')

    graph = ogb_dataset.graph
    dataset.graph = graph
    for key in ('edge_index', 'node_feat'):
        graph[key] = torch.as_tensor(graph[key])

    def ogb_idx_to_tensor():
        official = ogb_dataset.get_idx_split()
        converted = {}
        for part in official:
            converted[part] = torch.as_tensor(official[part])
        return converted

    dataset.load_fixed_splits = ogb_idx_to_tensor  # ogb_dataset.get_idx_split
    dataset.label = torch.as_tensor(ogb_dataset.labels).reshape(-1, 1)
    return dataset


def load_pokec_mat(data_dir):
    """ requires pokec.mat

    Loads the pokec graph from ``<data_dir>/pokec/pokec.mat``; if that file
    cannot be read, falls back to ``edge_index.npy``, ``node_feat.npy`` and
    ``label.npy`` in the same directory.

    Returns:
        NCDataset with an extra ``load_fixed_splits`` callable attached.
    """
    # One base directory for every pokec path, so the data files and the
    # cached splits agree whether or not data_dir ends with a separator.
    pokec_dir = os.path.join(data_dir, 'pokec')
    mat_path = os.path.join(pokec_dir, 'pokec.mat')

    if not os.path.exists(mat_path):
        # The google_drive_downloader import is commented out at the top of
        # this file, so `gdd` is normally undefined. Download only when it is
        # available; otherwise continue to the .npy fallback below instead of
        # dying with a NameError.
        downloader = globals().get('gdd')
        if downloader is not None:
            downloader.download_file_from_google_drive(
                file_id=dataset_drive_url['pokec'],
                dest_path=mat_path, showsize=True)

    try:
        fulldata = scipy.io.loadmat(mat_path)
        edge_index = fulldata['edge_index']
        node_feat = fulldata['node_feat']
        label = fulldata['label']
    except Exception:
        # Best-effort fallback (missing or unreadable .mat, missing keys).
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        edge_index = np.load(os.path.join(pokec_dir, 'edge_index.npy'))
        node_feat = np.load(os.path.join(pokec_dir, 'node_feat.npy'))
        label = np.load(os.path.join(pokec_dir, 'label.npy'))

    dataset = NCDataset('pokec')
    print(edge_index)
    print(len(edge_index[0]), len(edge_index[1]))
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    print(node_feat.shape)
    node_feat = torch.tensor(node_feat).float()
    num_nodes = int(node_feat.shape[0])
    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': num_nodes}

    # Convert once; re-wrapping a tensor in torch.tensor() emits a warning.
    dataset.label = torch.tensor(label, dtype=torch.long).flatten()

    def load_fixed_splits(train_prop=0.5, val_prop=0.25):
        """Load the cached split, or create and cache it on first use.

        NOTE: the cache directory name is fixed to ``split_0.5_0.25``, so
        ``train_prop``/``val_prop`` only take effect when no cache exists yet.
        """
        split_dir = os.path.join(pokec_dir, 'split_0.5_0.25')
        parts = ('train', 'valid', 'test')
        tensor_split_idx = {}
        if os.path.exists(split_dir):
            for part in parts:
                tensor_split_idx[part] = torch.as_tensor(
                    np.loadtxt(os.path.join(split_dir, f'pokec_{part}.txt')),
                    dtype=torch.long)
        else:
            os.makedirs(split_dir)
            tensor_split_idx['train'], tensor_split_idx['valid'], tensor_split_idx['test'] \
                = rand_train_test_idx(dataset.label, train_prop=train_prop, valid_prop=val_prop)
            for part in parts:
                np.savetxt(os.path.join(split_dir, f'pokec_{part}.txt'),
                           tensor_split_idx[part], fmt='%d')
        return tensor_split_idx

    dataset.load_fixed_splits = load_fixed_splits
    return dataset

def load_snap_patents_mat(data_dir, nclass=5):
    """Load the snap-patents graph from ``{data_dir}snap_patents.mat``.

    Labels are produced by ``even_quantile_labels`` over the ``years`` array
    with ``nclass`` classes.

    Raises:
        FileNotFoundError: if the ``.mat`` file is missing and no downloader
            (``gdd``) is available to fetch it.
    """
    mat_path = f'{data_dir}snap_patents.mat'
    if not path.exists(mat_path):
        p = dataset_drive_url['snap-patents']
        print(f"Snap patents url: {p}")
        # The google_drive_downloader import is commented out at the top of
        # this file, so `gdd` is normally undefined; fail with a clear message
        # instead of a NameError.
        downloader = globals().get('gdd')
        if downloader is None:
            raise FileNotFoundError(
                f"{mat_path} not found and google_drive_downloader is not "
                f"imported; download Google Drive file id {p} manually")
        downloader.download_file_from_google_drive(
            file_id=p, dest_path=mat_path, showsize=True)

    fulldata = scipy.io.loadmat(mat_path)

    dataset = NCDataset('snap_patents')
    edge_index = torch.tensor(fulldata['edge_index'], dtype=torch.long)
    node_feat = torch.tensor(
        fulldata['node_feat'].todense(), dtype=torch.float)
    # The value is loaded as an array; .item() extracts the scalar without
    # relying on implicit array-to-int conversion (deprecated in NumPy 1.25).
    num_nodes = int(np.asarray(fulldata['num_nodes']).item())
    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': num_nodes}

    years = fulldata['years'].flatten()
    label = even_quantile_labels(years, nclass, verbose=False)
    dataset.label = torch.tensor(label, dtype=torch.long)

    return dataset


def load_yelpchi_dataset(data_dir):
    """Load the YelpChi graph from ``{data_dir}YelpChi.mat``.

    Uses the ``'homo'`` adjacency matrix, the ``'features'`` matrix
    (densified) and the ``'label'`` array of the ``.mat`` file.

    Raises:
        FileNotFoundError: if the ``.mat`` file is missing and no downloader
            (``gdd``) is available to fetch it.
    """
    mat_path = f'{data_dir}YelpChi.mat'
    if not path.exists(mat_path):
        # The google_drive_downloader import is commented out at the top of
        # this file, so `gdd` is normally undefined; fail with a clear message
        # instead of a NameError.
        downloader = globals().get('gdd')
        if downloader is None:
            raise FileNotFoundError(
                f"{mat_path} not found and google_drive_downloader is not "
                f"imported; download Google Drive file id "
                f"{dataset_drive_url['yelp-chi']} manually")
        downloader.download_file_from_google_drive(
            file_id=dataset_drive_url['yelp-chi'],
            dest_path=mat_path, showsize=True)
    fulldata = scipy.io.loadmat(mat_path)
    A = fulldata['homo']
    edge_index = np.array(A.nonzero())
    node_feat = fulldata['features']
    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    label = np.array(fulldata['label'], dtype=int).flatten()
    num_nodes = node_feat.shape[0]

    dataset = NCDataset('YelpChi')
    edge_index = torch.tensor(edge_index, dtype=torch.long)
    node_feat = torch.tensor(node_feat.todense(), dtype=torch.float)
    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes}
    dataset.label = torch.tensor(label, dtype=torch.long)
    return dataset


def load_planetoid_dataset(data_dir, name):
    """Load a Planetoid graph (root ``{data_dir}Planetoid``) with normalized features.

    The masks shipped with the dataset are converted to index tensors and
    stored as ``train_idx``, ``valid_idx`` and ``test_idx`` on the result.
    """
    pyg_dataset = Planetoid(root=f'{data_dir}Planetoid',
                            name=name, transform=T.NormalizeFeatures())
    pyg_graph = pyg_dataset[0]

    dataset = NCDataset(name)

    mask_by_attr = (
        ('train_idx', pyg_graph.train_mask),
        ('valid_idx', pyg_graph.val_mask),
        ('test_idx', pyg_graph.test_mask),
    )
    for attr, mask in mask_by_attr:
        setattr(dataset, attr, torch.where(mask)[0])

    dataset.graph = {
        'edge_index': pyg_graph.edge_index,
        'node_feat': pyg_graph.x,
        'edge_feat': None,
        'num_nodes': pyg_graph.num_nodes,
    }
    dataset.label = pyg_graph.y

    return dataset

def load_amazon_dataset(data_dir, name):
    """Load an Amazon co-purchase graph (root ``{data_dir}Amazon``) with normalized features.

    Args:
        data_dir: Directory prefix, concatenated directly with ``'Amazon'``.
        name: ``'amazon-photo'`` or ``'amazon-computer'``.

    Raises:
        ValueError: if ``name`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    pyg_names = {'amazon-photo': 'Photo', 'amazon-computer': 'Computers'}
    if name not in pyg_names:
        raise ValueError(
            f"Invalid Amazon dataset name {name!r}; expected one of {sorted(pyg_names)}")

    transform = T.NormalizeFeatures()
    torch_dataset = Amazon(root=f'{data_dir}Amazon',
                           name=pyg_names[name], transform=transform)
    data = torch_dataset[0]

    edge_index = data.edge_index
    node_feat = data.x
    label = data.y
    num_nodes = data.num_nodes

    dataset = NCDataset(name)

    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes}
    dataset.label = label

    return dataset

def load_coauthor_dataset(data_dir, name):
    """Load a Coauthor graph (root ``{data_dir}Coauthor``) with normalized features.

    Args:
        data_dir: Directory prefix, concatenated directly with ``'Coauthor'``.
        name: ``'coauthor-cs'`` or ``'coauthor-physics'``.

    Raises:
        ValueError: if ``name`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    pyg_names = {'coauthor-cs': 'CS', 'coauthor-physics': 'Physics'}
    if name not in pyg_names:
        raise ValueError(
            f"Invalid Coauthor dataset name {name!r}; expected one of {sorted(pyg_names)}")

    transform = T.NormalizeFeatures()
    torch_dataset = Coauthor(root=f'{data_dir}Coauthor',
                             name=pyg_names[name], transform=transform)
    data = torch_dataset[0]

    edge_index = data.edge_index
    node_feat = data.x
    label = data.y
    num_nodes = data.num_nodes

    dataset = NCDataset(name)

    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes}
    dataset.label = label

    return dataset

def load_geom_gcn_dataset(data_dir, name):
    """Load a Geom-GCN style dataset from its two tab-separated text files.

    Reads ``out1_graph_edges.txt`` and ``out1_node_feature_label.txt`` from
    ``{data_dir}geom-gcn/{name}/``. The result has self-loops added to the
    edge list and row-normalized node features.

    Only nodes that appear in at least one edge are added to the graph, so
    nodes listed in the feature file but absent from the edge file are dropped.
    """
    graph_adjacency_list_file_path = f'{data_dir}geom-gcn/{name}/out1_graph_edges.txt'
    graph_node_features_and_labels_file_path = f'{data_dir}geom-gcn/{name}/out1_node_feature_label.txt'

    G = nx.DiGraph()
    graph_node_features_dict = {}
    graph_labels_dict = {}

    if name == 'film':
        # 'film' stores features as a list of active indices; expand them
        # into a 932-dimensional multi-hot vector.
        with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file:
            graph_node_features_and_labels_file.readline()  # skip first line
            for line in graph_node_features_and_labels_file:
                line = line.rstrip().split('\t')
                assert (len(line) == 3)
                assert (int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict)
                feature_blank = np.zeros(932, dtype=np.uint8)
                feature_blank[np.array(line[1].split(','), dtype=np.uint16)] = 1
                graph_node_features_dict[int(line[0])] = feature_blank
                graph_labels_dict[int(line[0])] = int(line[2])
    else:
        # Other datasets store the full feature vector as comma-separated values.
        with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file:
            graph_node_features_and_labels_file.readline()  # skip first line
            for line in graph_node_features_and_labels_file:
                line = line.rstrip().split('\t')
                assert (len(line) == 3)
                assert (int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict)
                graph_node_features_dict[int(line[0])] = np.array(line[1].split(','), dtype=np.uint8)
                graph_labels_dict[int(line[0])] = int(line[2])

    with open(graph_adjacency_list_file_path) as graph_adjacency_list_file:
        graph_adjacency_list_file.readline()  # skip first line
        for line in graph_adjacency_list_file:
            line = line.rstrip().split('\t')
            assert (len(line) == 2)
            # Nodes are created lazily, the first time they appear in an edge.
            if int(line[0]) not in G:
                G.add_node(int(line[0]), features=graph_node_features_dict[int(line[0])],
                           label=graph_labels_dict[int(line[0])])
            if int(line[1]) not in G:
                G.add_node(int(line[1]), features=graph_node_features_dict[int(line[1])],
                           label=graph_labels_dict[int(line[1])])
            G.add_edge(int(line[0]), int(line[1]))

    # Adjacency rows/columns follow ascending node id; the identity matrix
    # adds a self-loop to every node.
    adj = nx.adjacency_matrix(G, sorted(G.nodes()))
    adj = sp.coo_matrix(adj)
    adj = adj + sp.eye(adj.shape[0])
    adj = adj.tocoo().astype(np.float32)
    # Features and labels are ordered by ascending node id, matching adj.
    features = np.array(
        [features for _, features in sorted(G.nodes(data='features'), key=lambda x: x[0])])
    labels = np.array(
        [label for _, label in sorted(G.nodes(data='label'), key=lambda x: x[0])])
    print(features.shape)

    def preprocess_features(feat):
        """Row-normalize the feature matrix so each non-zero row sums to 1."""
        rowsum = np.array(feat.sum(1))
        # Replace zero row sums by 1 so all-zero rows stay zero instead of
        # producing a division by zero.
        rowsum = (rowsum == 0) * 1 + rowsum
        r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        r_mat_inv = sp.diags(r_inv)
        feat = r_mat_inv.dot(feat)
        return feat
    features = preprocess_features(features)

    edge_index = torch.from_numpy(
        np.vstack((adj.row, adj.col)).astype(np.int64))
    node_feat = torch.FloatTensor(features)
    labels = torch.LongTensor(labels)
    num_nodes = node_feat.shape[0]
    print(f"Num nodes: {num_nodes}")

    dataset = NCDataset(name)

    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes}
    dataset.label = labels

    return dataset

def papers100M_sub(data_dir):
    """Build a subgraph of ogbn-papers100M restricted to the first 1,000,000 node ids.

    The official train/valid/test indices below that bound are concatenated
    and re-cut 70% / 10% / remainder. The induced edge index is cached at
    ``<data_dir>/ogbn_papers100M/subgraph.pt`` and reused when present.
    """
    cache_file = os.path.join(data_dir, 'ogbn_papers100M', 'subgraph.pt')

    num_nodes = 1000000
    ogb_dataset = PygNodePropPredDataset('ogbn-papers100M', root=data_dir)
    full_graph = ogb_dataset[0]
    full_edge_index = torch.as_tensor(full_graph.edge_index)
    full_feat = torch.as_tensor(full_graph.x)
    total_nodes = full_graph.num_nodes
    full_labels = torch.as_tensor(full_graph.y.data, dtype=int).reshape(-1, 1)

    official = ogb_dataset.get_idx_split()

    # Keep only the labelled indices that fall inside the subgraph, in
    # train, valid, test order.
    kept = [official[part][official[part] < num_nodes]
            for part in ('train', 'valid', 'test')]
    pooled = torch.cat(kept)

    # Re-cut the pooled indices: 70% train, 10% valid, the rest test.
    pooled_len = pooled.shape[0]
    n_train = int(pooled_len * 0.7)
    n_valid = int(pooled_len * 0.1)
    train_idx_i = pooled[:n_train]
    valid_idx_i = pooled[n_train:n_train + n_valid]
    test_idx_i = pooled[n_train + n_valid:]

    if os.path.exists(cache_file):
        sub_edge_index = torch.load(cache_file)
    else:
        node_subset = torch.arange(num_nodes)
        sub_edge_index, _ = subgraph(
            node_subset, full_edge_index, num_nodes=total_nodes, relabel_nodes=False)

    sub_feat = full_feat[:num_nodes]
    sub_labels = full_labels[:num_nodes]

    print(f'train new: {train_idx_i.shape}')
    print(f'valid new: {valid_idx_i.shape}')
    print(f'test new: {test_idx_i.shape}')

    dataset = NCDataset('ogbn-papers100M')
    dataset.graph = {
        'edge_index': sub_edge_index,
        'node_feat': sub_feat,
        'num_nodes': num_nodes,
    }
    dataset.label = sub_labels

    def get_idx_split():
        return {'train': train_idx_i, 'valid': valid_idx_i, 'test': test_idx_i}

    dataset.load_fixed_splits = get_idx_split

    # Write the edge index to the cache if it is not there yet.
    cache_dir = os.path.join(data_dir, 'ogbn_papers100M')
    if not os.path.exists(cache_dir):
        os.mkdir(cache_dir)
    if not os.path.exists(cache_file):
        torch.save(sub_edge_index, cache_file)

    return dataset



if __name__=='__main__':
    # Ad-hoc smoke test when the module is run as a script: build the
    # bipartite-graph loader over ./Graphs/ with a batch size of 3.
    # Note that `dataset` here is the DataLoader returned by
    # load_dataset_bipartite, not an NCDataset.
    # load_papers100M('../data')
    # papers100M_sub('../../data')
    dataset = load_dataset_bipartite("Graphs/", 3)