import os.path as osp
import torch
from torch_geometric.datasets import Planetoid, Coauthor, CoraFull, Reddit, Amazon, WikipediaNetwork
import scipy.io as scio
import numpy as np
import sys
import pandas as pd
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import pickle
import itertools
from ogb.nodeproppred import PygNodePropPredDataset
import random
from sklearn.neighbors import NearestNeighbors
import torch_geometric
import os

from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import from_scipy_sparse_matrix
import torch
from torch_geometric.data import Data
import torch.serialization
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.datasets import Coauthor, WikiCS


def load_data2(args):
    """Load a pickled hypergraph dataset and clique-expand it into a graph.

    Reads ``features.pickle``, ``labels.pickle``, ``hypergraph.pickle`` and
    ``splits/<args.splits>.pickle`` from ``<dir of this file>/data2/<family>/<name>``.
    Every hyperedge is expanded into all ordered pairs of its distinct
    members, producing a directed ``edge_index``.

    Args:
        args: namespace with ``dataset`` (a key of the path table below),
            ``splits`` (split-file stem) and ``device``.

    Returns:
        dict with ``fts`` (float tensor), ``edge_index`` (long tensor of shape
        ``(2, E)``), ``lbls`` (long tensor) -- all on ``args.device`` -- and
        ``train_idx`` / ``test_idx`` exactly as stored in the split pickle.

    Raises:
        KeyError: if ``args.dataset`` is not a known dataset key.
    """
    dataset = args.dataset
    splits = args.splits
    device = torch.device(args.device)
    # data2/ lives next to this source file.
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data2')

    coauthorship = osp.join(f_path, 'coauthorship')
    cocitation = osp.join(f_path, 'cocitation')
    d_path_dict = {
        'ca_cora': osp.join(coauthorship, 'cora'),
        'ca_dblp': osp.join(coauthorship, 'dblp'),
        'cc_cora': osp.join(cocitation, 'cora'),
        'cc_citeseer': osp.join(cocitation, 'citeseer'),
        # 'ca_pubmed' points under cocitation/, so the 'ca_' prefix looks like
        # a typo; it is kept for backward compatibility and 'cc_pubmed' is
        # accepted as the consistently named alias.
        'ca_pubmed': osp.join(cocitation, 'pubmed'),
        'cc_pubmed': osp.join(cocitation, 'pubmed'),
    }
    root = d_path_dict[dataset]

    def _load(*parts):
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(osp.join(root, *parts), 'rb') as handle:
            return pickle.load(handle)

    features = _load('features.pickle').todense()
    labels = _load('labels.pickle')
    split = _load('splits', str(splits) + '.pickle')
    train, test = split['train'], split['test']
    hypergraph = _load('hypergraph.pickle')

    # Clique expansion: every ordered pair of distinct members of each hyperedge.
    pairs = [pair
             for members in hypergraph.values()
             for pair in itertools.permutations(members, 2)]
    edge_s = [src for src, _ in pairs]
    edge_e = [dst for _, dst in pairs]

    # edge_index goes to the same device as fts/lbls (it used to stay on CPU).
    edge_index = torch.tensor([edge_s, edge_e], dtype=torch.long, device=device)
    features = torch.Tensor(features).to(device)
    labels = torch.LongTensor(labels).to(device)

    return {
        'fts': features,
        'edge_index': edge_index,
        'lbls': labels,
        'train_idx': train,
        'test_idx': test
    }


# def load_cite(args):
#     dname = args.dataset
#     device = torch.device(args.device)
#     path = osp.abspath(__file__)
#     d_path = osp.dirname(path)
#     f_path = osp.join(d_path, 'data')
#
#     dataset = Planetoid(f_path, dname)
#     tmp = dataset[0].to(device)
#     fts = tmp.x
#     lbls = tmp.y
#
#     # 标准划分（使用 Planetoid 自带的 mask）
#     if args.split_ratio < 0:
#         train_idx = tmp.train_mask.nonzero(as_tuple=False).view(-1)
#         val_idx = tmp.val_mask.nonzero(as_tuple=False).view(-1)
#         test_idx = tmp.test_mask.nonzero(as_tuple=False).view(-1)
#     else:
#         # 自定义划分（例如 0.7/0.15/0.15）
#         nums = lbls.shape[0]
#         idx_list = list(range(nums))
#         random.seed(args.seed)
#         random.shuffle(idx_list)
#
#         num_train = int(nums * args.split_ratio)
#         num_val = int(nums * 0.1)
#         num_test = nums - num_train - num_val
#
#         train_idx = torch.tensor(idx_list[:num_train])
#         val_idx = torch.tensor(idx_list[num_train:num_train + num_val])
#         test_idx = torch.tensor(idx_list[num_train + num_val:])
#
#     data = {
#         'fts': fts,
#         'edge_index': tmp.edge_index,
#         'lbls': lbls,
#         'train_idx': train_idx,
#         'val_idx': val_idx if args.split_ratio >= 0 else val_idx,  # 标准划分时 val_idx 是存在的
#         'test_idx': test_idx
#     }
#
#     return data


def load_cite(args):
    """Load a Planetoid citation dataset, build index splits and cache them to disk.

    NOTE: this module defines ``load_cite`` a second time further down; that
    later definition replaces this one at import time.

    Args:
        args: namespace with ``dataset`` (Planetoid name), ``device``,
            ``split_ratio`` (< 0 selects the standard Planetoid masks; any
            other value selects the balanced split below) and ``seed``.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and long index tensors
        ``train_idx`` / ``val_idx`` / ``test_idx``.

    Side effects:
        Writes ``<data>/<dataset>{fts,lbls,edge,train,val,test}.pt``.
    """
    dname = args.dataset
    device = torch.device(args.device)
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = Planetoid(f_path, dname)
    tmp = dataset[0].to(device)
    fts = tmp.x
    lbls = tmp.y
    if args.split_ratio < 0:
        # Standard Planetoid split: boolean masks -> index tensors.
        train_idx = tmp.train_mask.nonzero(as_tuple=False).view(-1)
        val_idx = tmp.val_mask.nonzero(as_tuple=False).view(-1)
        test_idx = tmp.test_mask.nonzero(as_tuple=False).view(-1)
    else:
        # Balanced positive/negative split. With both proportions at 0 every
        # node goes to train and val/test are empty. The value of
        # split_ratio itself is not used.
        val_prop = 0
        test_prop = 0

        labels = lbls.cpu().numpy()
        np.random.seed(args.seed)
        # "positive" = non-zero label, "negative" = label 0. The previous
        # (1 - labels).nonzero() treated every label other than 1 as negative,
        # so multi-class nodes with label >= 2 fell into both groups and were
        # duplicated. For 0/1 labels the result is unchanged.
        pos_idx = (labels != 0).nonzero()[0]
        neg_idx = (labels == 0).nonzero()[0]
        np.random.shuffle(pos_idx)
        np.random.shuffle(neg_idx)
        pos_idx = pos_idx.tolist()
        neg_idx = neg_idx.tolist()
        nb_pos_neg = min(len(pos_idx), len(neg_idx))

        nb_val = round(val_prop * nb_pos_neg)
        nb_test = round(test_prop * nb_pos_neg)
        cut = nb_val + nb_test

        idx_val_pos, idx_test_pos, idx_train_pos = pos_idx[:nb_val], pos_idx[nb_val:cut], pos_idx[cut:]
        idx_val_neg, idx_test_neg, idx_train_neg = neg_idx[:nb_val], neg_idx[nb_val:cut], neg_idx[cut:]

        train_idx = torch.tensor(idx_train_pos + idx_train_neg, dtype=torch.long, device=device)
        val_idx = torch.tensor(idx_val_pos + idx_val_neg, dtype=torch.long, device=device)
        test_idx = torch.tensor(idx_test_pos + idx_test_neg, dtype=torch.long, device=device)

    # Cache every tensor (on CPU) as <data>/<dname><suffix>.
    save_prefix = osp.join(f_path, dname)
    torch.save(fts.cpu(), save_prefix + 'fts.pt')
    torch.save(lbls.cpu(), save_prefix + 'lbls.pt')
    torch.save(tmp.edge_index.cpu(), save_prefix + 'edge.pt')
    torch.save(train_idx.cpu(), save_prefix + 'train.pt')
    torch.save(val_idx.cpu(), save_prefix + 'val.pt')
    torch.save(test_idx.cpu(), save_prefix + 'test.pt')
    print(train_idx.size(), val_idx.size(), test_idx.size())
    return {
        'fts': fts,
        'edge_index': tmp.edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }

def load_cite(args):
    dname = args.dataset
    device = torch.device(args.device)
    path = osp.abspath(__file__)
    d_path = osp.dirname(path)
    f_path = osp.join(d_path, 'data')

    dataset = Planetoid(f_path, dname)
    tmp = dataset[0].to(device)
    fts = tmp.x
    lbls = tmp.y
    if args.split_ratio < 0:
        train_idx = tmp.train_mask.nonzero(as_tuple=False).view(-1)
        val_idx = tmp.val_mask.nonzero(as_tuple=False).view(-1)
        test_idx = tmp.test_mask.nonzero(as_tuple=False).view(-1)
    else:
        # 💡 平衡正负类划分
        # val_prop = 0.25
        # test_prop = 0.25  # 剩下的是训练集
        val_prop = 0.25
        test_prop = 0.25  # 剩下的是训练集

        labels = lbls.cpu().numpy()
        print(labels.shape)
        np.random.seed(args.seed)
        pos_idx = labels.nonzero()[0]
        neg_idx = (1. - labels).nonzero()[0]
        np.random.shuffle(pos_idx)
        np.random.shuffle(neg_idx)
        pos_idx = pos_idx.tolist()
        neg_idx = neg_idx.tolist()
        nb_pos_neg = min(len(pos_idx), len(neg_idx))

        nb_val = round(val_prop * nb_pos_neg)
        nb_test = round(test_prop * nb_pos_neg)

        idx_val_pos, idx_test_pos, idx_train_pos = pos_idx[:nb_val], pos_idx[nb_val:nb_val + nb_test], pos_idx[
                                                                                                       nb_val + nb_test:]
        idx_val_neg, idx_test_neg, idx_train_neg = neg_idx[:nb_val], neg_idx[nb_val:nb_val + nb_test], neg_idx[
                                                                                           nb_val + nb_test:]

        train_idx = torch.tensor(idx_train_pos + idx_train_neg, dtype=torch.long, device=device)
        val_idx = torch.tensor(idx_val_pos + idx_val_neg, dtype=torch.long, device=device)
        test_idx = torch.tensor(idx_test_pos + idx_test_neg, dtype=torch.long, device=device)

    # 保存所有数据
    save_prefix = osp.join(f_path, dname)
    fts_path = save_prefix + 'fts.pt'
    lbls_path = save_prefix + 'lbls.pt'
    edge_path = save_prefix + 'edge.pt'
    train_path = save_prefix + 'train.pt'
    val_path = save_prefix + 'val.pt'
    test_path = save_prefix + 'test.pt'

    torch.save(fts.cpu(), fts_path)
    torch.save(lbls.cpu(), lbls_path)
    torch.save(tmp.edge_index.cpu(), edge_path)
    torch.save(train_idx.cpu(), train_path)
    torch.save(val_idx.cpu(), val_path)
    torch.save(test_idx.cpu(), test_path)
    print(train_idx.size(), val_idx.size(), test_idx.size())
    data = {
        'fts': fts,
        'edge_index': tmp.edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }
    return data



def load_cite1(args):
    """Load a Planetoid dataset with a seeded random 10% / 10% / 80% split.

    Args:
        args: namespace with ``dataset`` (Planetoid name), ``device`` and ``seed``.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and long index tensors
        ``train_idx`` / ``val_idx`` / ``test_idx``, all on ``args.device``.
    """
    dname = args.dataset
    device = torch.device(args.device)
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = Planetoid(f_path, dname)
    tmp = dataset[0].to(device)
    fts = tmp.x
    lbls = tmp.y

    nums = lbls.shape[0]
    idx_list = list(range(nums))
    random.seed(args.seed)
    random.shuffle(idx_list)

    # Custom split: 10% train, 10% val, the remaining ~80% test.
    num_train = int(nums * 0.1)
    num_val = int(nums * 0.1)

    # dtype is explicit so an empty slice still yields an index (long) tensor;
    # torch.tensor([]) would default to float32 and fail as an index.
    train_idx = torch.tensor(idx_list[:num_train], dtype=torch.long, device=device)
    val_idx = torch.tensor(idx_list[num_train:num_train + num_val], dtype=torch.long, device=device)
    test_idx = torch.tensor(idx_list[num_train + num_val:], dtype=torch.long, device=device)

    return {
        'fts': fts,
        'edge_index': tmp.edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }


def load_cite2(args):
    """Load a Planetoid dataset ('Cora', 'Citeseer', 'Pubmed') with a train/test split.

    Args:
        args: namespace with ``dataset``, ``device`` and ``split_ratio``.
            ``split_ratio < 0`` returns the standard Planetoid *boolean masks*;
            otherwise ``split_ratio`` is the training fraction, sampled with
            the global (unseeded here) ``random`` state, and the rest is test.

    Returns:
        dict with ``fts`` ([num_nodes, in_dim]), ``edge_index`` ([2, num_edges]),
        ``lbls`` ([num_nodes]), ``train_idx`` and ``test_idx``.
    """
    dname = args.dataset
    device = torch.device(args.device)

    # data/ folder next to this file.
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = Planetoid(f_path, dname)
    tmp = dataset[0].to(device)

    fts = tmp.x
    lbls = tmp.y
    edge_index = tmp.edge_index

    if args.split_ratio < 0:
        # Standard split, returned as boolean masks rather than indices.
        train_idx = tmp.train_mask
        test_idx = tmp.test_mask
    else:
        nums = lbls.shape[0]
        num_train = int(nums * args.split_ratio)
        idx_list = list(range(nums))
        train_idx = random.sample(idx_list, num_train)
        # Set membership keeps this O(n); testing against the list was O(n^2).
        train_set = set(train_idx)
        test_idx = [i for i in idx_list if i not in train_set]
        train_idx = torch.tensor(train_idx, dtype=torch.long)
        test_idx = torch.tensor(test_idx, dtype=torch.long)

    return {
        'fts': fts,
        'edge_index': edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'test_idx': test_idx
    }

# def load_ft(args):
#     if args.dataset == '40':
#         data_dir = './data/ModelNet40_mvcnn_gvcnn.mat'
#     elif args.dataset == 'NTU':
#         data_dir = './data/NTU2012_mvcnn_gvcnn.mat'
#
#     device = torch.device(args.device)
#     feature_name = args.fts
#
#     data = scio.loadmat(data_dir)
#     lbls = data['Y'].astype(np.long)
#     torch.save(lbls, 'lbls.pt')
#     if lbls.min() == 1:
#         lbls = lbls - 1
#     idx = data['indices'].item()
#
#     if feature_name == 'MVCNN':
#         fts = data['X'][0].item().astype(np.float32)
#         fts = torch.Tensor(fts).to(device)
#     elif feature_name == 'GVCNN':
#         fts = data['X'][1].item().astype(np.float32)
#         fts = torch.Tensor(fts).to(device)
#     else:
#         fts1 = data['X'][0].item().astype(np.float32)
#         fts2 = data['X'][1].item().astype(np.float32)
#         fts1 = torch.Tensor(fts1).to(device)
#         fts2 = torch.Tensor(fts2).to(device)
#         dim1 = fts1.shape[1]
#         dim2 = fts2.shape[1]
#         dim = dim1 + dim2
#         print(f"Multi-modal features shape:  dim1: {dim1}, dim2: {dim2}, total_dim: {dim}")
#
#         fts = torch.cat((fts1, fts2), dim=-1)
#
#     if args.split_ratio < 0:
#         train_idx = np.where(idx == 1)[0]
#         test_idx = np.where(idx == 0)[0]
#     else:
#         nums = lbls.shape[0]
#         num_train = int(nums * args.split_ratio)
#         idx_list = [i for i in range(nums)]
#
#         train_idx = random.sample(idx_list, num_train)
#         test_idx = [i for i in idx_list if i not in train_idx]
#
#     # train_idx = np.where(idx == 1)[0]
#     # test_idx = np.where(idx == 0)[0]
#
#     lbls = torch.Tensor(lbls).squeeze().long().to(device)
#     train_idx = torch.Tensor(train_idx).long().to(device)
#     test_idx = torch.Tensor(test_idx).long().to(device)
#
#     data = {
#         'fts': fts,
#         'lbls': lbls,
#         'train_idx': train_idx,
#         'test_idx': test_idx
#     }
#
#     return data


def load_ft(args):
    """Load ModelNet40 / NTU2012 multi-view features from a ``.mat`` file.

    Args:
        args: namespace with ``dataset`` (``'40'`` or ``'NTU'``), ``device``
            and ``fts`` (``'MVCNN'``, ``'GVCNN'``, or anything else to
            concatenate both feature sets along the last dimension).

    Returns:
        dict with ``fts`` (float tensor), ``lbls`` (long tensor) and long index
        tensors ``train_idx`` / ``val_idx`` / ``test_idx`` from a random
        50% / 25% / 25% split drawn from the global ``random`` state (not
        seeded here). All tensors are on ``args.device``.

    Raises:
        ValueError: if ``args.dataset`` is not ``'40'`` or ``'NTU'``.

    Side effects:
        Saves the raw label array to ``./lbls.pt``.
    """
    data_files = {
        '40': './data/ModelNet40_mvcnn_gvcnn.mat',
        'NTU': './data/NTU2012_mvcnn_gvcnn.mat',
    }
    if args.dataset not in data_files:
        # Previously an unknown name left data_dir unbound (UnboundLocalError).
        raise ValueError(
            f"load_ft: unsupported dataset {args.dataset!r}; expected one of {sorted(data_files)}")
    data_dir = data_files[args.dataset]

    device = torch.device(args.device)
    feature_name = args.fts

    data = scio.loadmat(data_dir)
    lbls = data['Y'].astype(np.int64)

    torch.save(lbls, 'lbls.pt')
    if lbls.min() == 1:
        # Labels start at 1 in this file; shift them to 0-based class ids.
        lbls = lbls - 1
    # NOTE: the file's predefined 'indices' split is not used; a random split is drawn below.

    def _view(k):
        # Feature matrix of the k-th view (0 = MVCNN, 1 = GVCNN) as a float tensor.
        return torch.Tensor(data['X'][k].item().astype(np.float32)).to(device)

    if feature_name == 'MVCNN':
        fts = _view(0)
    elif feature_name == 'GVCNN':
        fts = _view(1)
    else:
        fts1 = _view(0)
        fts2 = _view(1)
        dim1 = fts1.shape[1]
        dim2 = fts2.shape[1]
        dim = dim1 + dim2
        print(f"Multi-modal features shape:  dim1: {dim1}, dim2: {dim2}, total_dim: {dim}")
        fts = torch.cat((fts1, fts2), dim=-1)

    nums = lbls.shape[0]
    idx_list = list(range(nums))
    random.shuffle(idx_list)

    train_size = int(0.5 * nums)
    val_size = int(0.25 * nums)

    train_idx = idx_list[:train_size]
    val_idx = idx_list[train_size:train_size + val_size]
    test_idx = idx_list[train_size + val_size:]

    return {
        'fts': fts,
        'lbls': torch.Tensor(lbls).squeeze().long().to(device),
        'train_idx': torch.tensor(train_idx).long().to(device),
        'val_idx': torch.tensor(val_idx).long().to(device),
        'test_idx': torch.tensor(test_idx).long().to(device)
    }



def load_data(args):
    """Dispatch to the dataset-specific loader selected by ``args.dataset``.

    Args:
        args: namespace whose ``dataset`` attribute names the dataset; the
            selected loader reads whatever other attributes it needs.

    Returns:
        Whatever the selected loader returns (a dict of tensors/indices).

    Raises:
        ValueError: if ``args.dataset`` does not match any loader.
    """
    name = args.dataset
    if name in ('40', 'NTU'):
        return load_ft(args)
    if name in ('Cora', 'Citeseer', 'Pubmed'):
        return load_cite(args)
    if name in ('chameleon', 'crocodile', 'squirrel'):
        return load_node_regress(args)
    if name == 'cora':
        return load_citation_data()
    if name == 'node2vec_PPI':
        return load_bionev_data(args)
    if name == 'airport':
        return load_airport_data(args)
    if name == 'diabet':
        return load_diabet_data(args)
    if name == 'disease_nc':
        return load_synthetic_data(args)
    if name in ('ogbn-products', 'ogbn-arxiv', 'ogbn-proteins'):
        return load_ogbn_dataset(args)
    if name == 'WikiCs':
        return load_wikics_dataset(args)
    if name in ('coauthor-cs', 'coauthor-physics'):
        return load_coauthor_dataset(args)
    # These were written as `in ('reddit')` / `in ('20news')`, which are
    # substring tests on a plain string (so e.g. 'red' or '' matched).
    if name == 'reddit':
        return load_reddit(args)
    if name == '20news':
        return load_20news(args)
    if name == 'arxiv':
        return load_ogbn_arxiv(args)
    if name == 'github':
        return load_github(args)
    if name in ('DE', 'EN', 'ES', 'FR', 'PT', 'RU'):
        return load_twitch(args)
    if name in ('Lung', 'Brain', 'Breast', 'Stomach', 'Artery_Coronary', 'Artery_Aorta',
                'Leukemia', 'Lung_cancer', 'Stomach_cancer',
                'Kidney_renal_papillary_cell_carcinoma'):
        return load_Grand(args)
    if name in ('computers', 'photo'):
        return load_Amazon(args)
    if name == 'DBLP':
        return load_dblp_homo(args)
    # Previously fell through and returned None, failing later far from here.
    raise ValueError(f"Unknown dataset: {name!r}")






def parse_index_file(filename):
    """Read one integer per line from *filename* and return them as a list.

    Adapted from the GCN reference code. Surrounding whitespace is ignored
    and blank lines are skipped; the file is closed before returning.
    """
    with open(filename) as handle:
        return [int(line) for line in handle if line.strip()]


def preprocess_features(features):
    """Row-normalise a feature matrix so that every non-empty row sums to 1.

    Args:
        features: 2-D scipy sparse matrix or dense numpy array.

    Returns:
        ``diag(1 / rowsum) @ features``. Rows summing to zero are left as
        zeros. Sparse input gives a sparse result; dense input a dense one.
    """
    rowsum = np.array(features.sum(1))
    if np.issubdtype(rowsum.dtype, np.integer) or rowsum.dtype == np.bool_:
        # numpy refuses negative powers of integers; promote to float.
        rowsum = rowsum.astype(np.float64)
    with np.errstate(divide='ignore'):
        # Zero rows give inf here; they are mapped to 0 just below.
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv).dot(features)


def load_citation_data():
    """Load a citation dataset (configured for Cora) from the raw GCN pickle files.

    Adapted from the GCN reference code. Reads, from ``cfg['citation_root']``:

    - ``ind.<name>.x`` / ``.tx`` / ``.allx``: feature rows of the labelled
      training nodes, the test nodes, and all training nodes (``allx`` is a
      superset of ``x``), as scipy sparse matrices;
    - ``ind.<name>.y`` / ``.ty`` / ``.ally``: the matching one-hot label arrays;
    - ``ind.<name>.graph``: ``{node: [neighbour, ...]}`` adjacency dict;
    - ``ind.<name>.test.index``: graph positions of the test rows.

    Features are row-normalised. The adjacency lists are only used to print
    degree statistics; no edges are returned.

    Returns:
        dict with ``fts`` (float tensor), ``lbls`` (long tensor of class ids),
        ``train_idx`` and ``test_idx`` (lists of node ids).
    """
    cfg = {
        'citation_root': './data/Cora/raw',
        'activate_dataset': 'cora',
        'add_self_loop': True
    }
    root = cfg['citation_root']
    name = cfg['activate_dataset']

    objects = []
    for suffix in ('x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open("{}/ind.{}.{}".format(root, name, suffix), 'rb') as f:
            objects.append(pkl.load(f, encoding='latin1'))
    x, y, tx, ty, allx, ally, graph = objects

    test_idx_reorder = parse_index_file("{}/ind.{}.test.index".format(root, name))
    test_idx_range = np.sort(test_idx_reorder)

    if name == 'citeseer':
        # Citeseer has isolated test nodes missing from tx/ty: insert zero rows
        # at their positions so that row ids line up with node ids.
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended

    # Test rows are stored in shuffled order; move them back to their node ids.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    features = preprocess_features(features).todense()

    G = nx.from_dict_of_lists(graph)
    # Neighbour list per node, in networkx adjacency order. Position i is
    # treated as node i below (assumes nodes iterate as 0..N-1 -- TODO confirm).
    edge_list = [list(neighbours.keys()) for _, neighbours in G.adjacency()]
    if cfg['add_self_loop']:
        for i, neighbours in enumerate(edge_list):
            neighbours.append(i)
    # Degrees are computed whether or not self-loops were added (they used to
    # stay all-zero when add_self_loop was False).
    degree = [len(neighbours) for neighbours in edge_list]
    max_deg = max(degree)
    mean_deg = sum(degree) / len(degree)
    print(f'max degree: {max_deg}, mean degree:{mean_deg}')

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]  # one-hot labels
    # Class id = column of the 1 in each one-hot row (replaces per-row
    # array-to-scalar assignment, which numpy deprecates).
    lbls = labels.argmax(axis=1).astype(np.int64)
    if name == 'citeseer':
        # All-zero label rows get a new class whose id is the original number
        # of classes (the old code used n_category + 1 after incrementing,
        # which skipped an id).
        lbls[~labels.any(axis=1)] = labels.shape[1]

    idx_test = test_idx_range.tolist()
    idx_val = list(range(len(y), len(y) + 500))

    # NOTE(review): 'train_idx' holds the 500 nodes after the labelled
    # training rows (range(len(y), len(y) + 500)), not the standard GCN
    # training nodes (range(len(y))). Kept to preserve behaviour -- confirm intent.
    return {
        'fts': torch.Tensor(features),
        'lbls': torch.LongTensor(lbls),
        'train_idx': idx_val,
        'test_idx': idx_test
    }


def load_bionev_data(args):
    """Load a BioNEV-style node-classification dataset (e.g. ``node2vec_PPI``).

    Reads ``data/<name>/<name>.edgelist`` (``src dst`` integer pairs, one per
    line) and ``data/<name>/<name>_labels.txt`` (``node label`` integer pairs)
    from the directory containing this file.

    Args:
        args: namespace with ``dataset``, ``device``, ``use_feats`` (truthy:
            the single feature is the node's row sum in the edge-list
            adjacency; falsy: a constant 1), ``split_ratio`` (training
            fraction of the labelled nodes; the rest is split evenly into
            val/test) and optionally ``seed`` (default 42).

    Returns:
        dict with ``fts`` (``(N, 1)`` float), ``edge_index``, ``lbls`` (long,
        ``-1`` for unlabelled nodes) and ``train_idx`` / ``val_idx`` /
        ``test_idx``, all on ``args.device``.
    """
    dataset_str = args.dataset
    data_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
    device = torch.device(args.device)
    use_feats = args.use_feats
    split_ratio = args.split_ratio
    seed = args.seed if hasattr(args, 'seed') else 42

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    edgelist_file_path = os.path.join(data_path, dataset_str, f'{dataset_str}.edgelist')
    labels_file_path = os.path.join(data_path, dataset_str, f'{dataset_str}_labels.txt')

    # ndmin=2 keeps a one-edge file as shape (1, 2) instead of (2,), which
    # would break the column indexing below.
    edges = np.loadtxt(edgelist_file_path, dtype=np.int32, ndmin=2)

    node_list, labels_list = [], []
    with open(labels_file_path, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            node_list.append(int(parts[0]))
            labels_list.append(int(parts[1]))

    # Size the graph so labelled nodes with no edges are still indexable.
    num_nodes = int(edges.max()) + 1 if edges.size else 0
    if node_list:
        num_nodes = max(num_nodes, max(node_list) + 1)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(num_nodes, num_nodes))

    if use_feats:
        degrees = np.asarray(adj.sum(axis=1)).ravel()
        fts = torch.tensor(degrees[:, None], dtype=torch.float)
    else:
        fts = torch.ones((num_nodes, 1), dtype=torch.float)

    lbls = torch.full((num_nodes,), -1, dtype=torch.long)
    if node_list:
        lbls[node_list] = torch.tensor(labels_list, dtype=torch.long)

    edge_index, _ = from_scipy_sparse_matrix(adj)

    # Labelled nodes in file order, de-duplicated so a node listed twice
    # cannot end up in more than one split.
    unique_nodes = list(dict.fromkeys(node_list))
    keep = (lbls[unique_nodes] >= 0).tolist() if unique_nodes else []
    valid_nodes = [node for node, ok in zip(unique_nodes, keep) if ok]
    random.shuffle(valid_nodes)

    num_total = len(valid_nodes)
    num_train = int(num_total * split_ratio)
    num_val = int((num_total - num_train) / 2)

    train_idx = torch.tensor(valid_nodes[:num_train], dtype=torch.long)
    val_idx = torch.tensor(valid_nodes[num_train:num_train + num_val], dtype=torch.long)
    test_idx = torch.tensor(valid_nodes[num_train + num_val:], dtype=torch.long)

    return {
        'fts': fts.to(device),
        'edge_index': edge_index.to(device),
        'lbls': lbls.to(device),
        'train_idx': train_idx.to(device),
        'val_idx': val_idx.to(device),
        'test_idx': test_idx.to(device)
    }


def bin_feat(feat, bins):
    """Map each value of *feat* to its bin index, shifted so the lowest index is 0.

    Bin indices follow ``np.digitize(feat, bins)``; the smallest index that
    occurs in *feat* is subtracted from all of them.
    """
    bucket_ids = np.digitize(feat, bins)
    lowest = bucket_ids.min()
    return bucket_ids - lowest


def load_airport_data(args):
    """Load the ``airport`` graph pickle as a node-classification dataset.

    ``data/<name>/<name>.p`` holds a networkx graph whose nodes carry a
    ``'feat'`` vector. Columns 0-3 are the input features; column 4 (the
    "5th column" the label is documented to live in) is binned with
    thresholds 7/7, 8/7, 9/7 into class labels.

    Args:
        args: namespace with ``dataset``, ``device`` and optionally ``seed``
            (default 42). ``args.split_ratio`` is not used: the split is a
            fixed seeded 70% / 15% / 15%.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and ``train_idx`` /
        ``val_idx`` / ``test_idx``, all on ``args.device``.

    Raises:
        ValueError: if the node feature vectors have no label column.
    """
    dataset_str = args.dataset
    data_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
    device = torch.device(args.device)
    seed = args.seed if hasattr(args, 'seed') else 42

    graph_file = os.path.join(data_path, dataset_str, dataset_str + '.p')
    with open(graph_file, 'rb') as f:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        graph = pkl.load(f)
    adj = nx.adjacency_matrix(graph)
    node_feats = np.array([graph.nodes[u]['feat'] for u in graph.nodes()])

    label_col = 4
    if node_feats.ndim != 2 or node_feats.shape[1] <= label_col:
        raise ValueError(
            f"expected node 'feat' vectors with at least {label_col + 1} columns, "
            f"got shape {node_feats.shape}")
    # The label used to be taken from column 3, which is also one of the
    # input features: the target leaked into the inputs.
    features = torch.tensor(node_feats[:, :label_col], dtype=torch.float)
    labels = bin_feat(node_feats[:, label_col], bins=[7.0 / 7, 8.0 / 7, 9.0 / 7])
    labels = torch.tensor(labels, dtype=torch.long)

    edge_index, _ = from_scipy_sparse_matrix(sp.csr_matrix(adj))

    num_nodes = features.shape[0]
    num_edges = edge_index.shape[1]
    num_classes = labels.max().item() + 1
    print(features.shape[0], features.shape[1])
    print(num_edges)
    print(num_classes)

    # Seeded 70% train / 15% val / 15% test split.
    idx_list = list(range(num_nodes))
    random.seed(seed)
    random.shuffle(idx_list)

    num_train = int(num_nodes * 0.7)
    num_val = int(num_nodes * 0.15)
    train_idx = torch.tensor(idx_list[:num_train], dtype=torch.long)
    val_idx = torch.tensor(idx_list[num_train:num_train + num_val], dtype=torch.long)
    test_idx = torch.tensor(idx_list[num_train + num_val:], dtype=torch.long)

    return {
        'fts': features.to(device),
        'edge_index': edge_index.to(device),
        'lbls': labels.to(device),
        'train_idx': train_idx.to(device),
        'val_idx': val_idx.to(device),
        'test_idx': test_idx.to(device)
    }


def load_diabet_data(args):
    """Turn the tabular ``diabet`` CSV into a 5-nearest-neighbour graph dataset.

    ``data/<name>/<name>.csv`` holds feature columns followed by one label
    column. The graph is built with 5-NN on the raw (un-normalised) feature
    columns and symmetrised.

    Args:
        args: namespace with ``dataset``, ``use_feats`` (truthy: min-max scaled
            features; falsy: a constant 1 per row), ``split_ratio`` (training
            fraction; the rest is halved into val/test), ``device`` and
            optionally ``seed`` (default 42).

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and ``train_idx`` /
        ``val_idx`` / ``test_idx``, all on ``args.device``.
    """
    name = args.dataset
    data_root = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
    use_feats = args.use_feats
    split_ratio = args.split_ratio
    device = torch.device(args.device)
    seed = getattr(args, 'seed', 42)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    frame = pd.read_csv(os.path.join(data_root, name, f'{name}.csv'))
    raw_features = frame.iloc[:, :-1].values

    if not use_feats:
        feature_values = np.ones((frame.shape[0], 1))
    else:
        # Min-max scale each column into [0, 1]; 1e-8 guards constant columns.
        lo = raw_features.min(axis=0)
        hi = raw_features.max(axis=0)
        feature_values = (raw_features - lo) / (hi - lo + 1e-8)
    fts = torch.tensor(feature_values, dtype=torch.float)
    lbls = torch.tensor(frame.iloc[:, -1].values, dtype=torch.long)

    # Graph edges always come from the raw features, regardless of use_feats.
    knn = NearestNeighbors(n_neighbors=5)
    knn.fit(raw_features)
    neighbour_table = knn.kneighbors(return_distance=False)

    row_count = frame.shape[0]
    adj = sp.lil_matrix((row_count, row_count))
    for src, neighbours in enumerate(neighbour_table):
        for dst in neighbours:
            adj[src, dst] = 1
            adj[dst, src] = 1
    edge_index, _ = from_scipy_sparse_matrix(adj.tocsr())

    order = list(range(row_count))
    random.shuffle(order)
    n_train = int(row_count * split_ratio)
    n_val = int((row_count - n_train) / 2)
    train_end, val_end = n_train, n_train + n_val

    split = {
        'train_idx': order[:train_end],
        'val_idx': order[train_end:val_end],
        'test_idx': order[val_end:],
    }
    result = {
        'fts': fts.to(device),
        'edge_index': edge_index.to(device),
        'lbls': lbls.to(device),
    }
    for key, ids in split.items():
        result[key] = torch.tensor(ids, dtype=torch.long).to(device)
    return result


def load_synthetic_data(args):
    dataset_str = args.dataset
    # 当前文件所在路径下的 data 文件夹
    path = osp.abspath(__file__)
    d_path = osp.dirname(path)
    data_path = osp.join(d_path, 'data')
    use_feats = args.use_feats
    split_ratio = args.split_ratio
    seed = args.seed if hasattr(args, 'seed') else 42
    device = torch.device(args.device)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Step 1: 读取边数据
    object_to_idx = {}
    idx_counter = 0
    edges = []
    edge_file = os.path.join(data_path, dataset_str, f"{dataset_str}.edges.csv")

    with open(edge_file, 'r') as f:
        for line in f:
            n1, n2 = line.strip().split(',')
            for node in (n1, n2):
                if node not in object_to_idx:
                    object_to_idx[node] = idx_counter
                    idx_counter += 1
            edges.append((object_to_idx[n1], object_to_idx[n2]))

    # Step 2: 构建邻接矩阵
    n = len(object_to_idx)
    adj = np.zeros((n, n), dtype=np.float32)
    for i, j in edges:
        adj[i, j] = 1.0
        adj[j, i] = 1.0  # 无向图

    adj = sp.csr_matrix(adj)
    edge_index, _ = from_scipy_sparse_matrix(adj)

    # Step 3: 构建特征矩阵
    if use_feats:
        feature_path = os.path.join(data_path, dataset_str, f"{dataset_str}.feats.npz")
        features = sp.load_npz(feature_path).toarray()
    else:
        features = np.eye(n)

    fts = torch.tensor(features, dtype=torch.float)

    # Step 4: 加载标签
    label_path = os.path.join(data_path, dataset_str, f"{dataset_str}.labels.npy")
    labels = np.load(label_path)
    lbls = torch.tensor(labels, dtype=torch.long)

    # Step 5: 数据划分（train/val/test）0.7, 0.15, 0.15
    idx_list = list(range(n))
    random.shuffle(idx_list)

    num_train = int(n * 0.7)
    num_val = int(n * 0.15)
    num_test = n - num_train - num_val  # 剩下的部分作为测试集

    # 划分训练集、验证集和测试集
    train_idx = torch.tensor(idx_list[:num_train], dtype=torch.long)
    val_idx = torch.tensor(idx_list[num_train:num_train + num_val], dtype=torch.long)
    test_idx = torch.tensor(idx_list[num_train + num_val:], dtype=torch.long)

    num_nodes = features.shape[0]
    num_edges = edge_index.shape[1]
    feat_dim = features.shape[1]
    num_classes = len(set(labels.tolist()))
    print("Nodes:", num_nodes)
    print("Edges:", edge_index.shape[1])
    print("Feature Dim:", feat_dim)
    print("Classes:", len(set(labels.tolist())))
    return {
        'fts': fts.to(device),
        'edge_index': edge_index.to(device),
        'lbls': lbls.to(device),
        'train_idx': train_idx.to(device),
        'val_idx': val_idx.to(device),  # 返回验证集索引
        'test_idx': test_idx.to(device)
    }

def load_node_regress(args):
    """Load a WikipediaNetwork graph with a seeded random 30% / 20% / 50% split.

    Args:
        args: namespace with ``dataset`` (``'chameleon'``, ``'crocodile'`` or
            ``'squirrel'``), ``device``, ``seed`` and, optionally,
            ``wiki_root`` to override the dataset directory (defaults to the
            previously hard-coded path).

    Returns:
        dict with ``fts`` (features L1-normalised along dim 1), ``edge_index``,
        ``lbls`` and long index tensors ``train_idx`` / ``val_idx`` /
        ``test_idx``, all on ``args.device``.
    """
    name = args.dataset
    device = torch.device(args.device)
    root = getattr(args, 'wiki_root', '/home/projects/TDHNN/data/Wikipedia')

    # PyG's WikipediaNetwork does not offer a Geom-GCN-preprocessed version of
    # 'crocodile' (it rejects geom_gcn_preprocess=True), so load it raw.
    geom_gcn = name.lower() != 'crocodile'
    dataset = WikipediaNetwork(root=root, name=name, geom_gcn_preprocess=geom_gcn)
    data = dataset[0].to(device)

    labels = data.y
    num_nodes = data.num_nodes

    np.random.seed(args.seed)
    perm = np.random.permutation(num_nodes)

    n_train = int(0.3 * num_nodes)
    n_val = int(0.2 * num_nodes)

    # Index tensors live on the same device as the data, like the other loaders.
    train_idx = torch.tensor(perm[:n_train], dtype=torch.long, device=device)
    val_idx = torch.tensor(perm[n_train:n_train + n_val], dtype=torch.long, device=device)
    test_idx = torch.tensor(perm[n_train + n_val:], dtype=torch.long, device=device)
    data.x = torch.nn.functional.normalize(data.x, p=1)
    return {
        'fts': data.x,
        'edge_index': data.edge_index,
        'lbls': labels,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }


def parse_index_file(filename):
    """Read one integer per line from *filename* and return them as a list.

    Surrounding whitespace on each line is ignored and blank lines are
    skipped. The file is closed before returning.
    """
    with open(filename) as handle:
        # int() itself strips surrounding whitespace, including the newline.
        return [int(line) for line in handle if line.strip()]


def load_ogbn_dataset(args):
    """Load an OGB node-property-prediction dataset together with its official split.

    Args:
        args: namespace providing ``dataset`` (an OGB name such as
            ``'ogbn-products'``) and optionally ``device``. Without
            ``device``, CUDA is used when available, otherwise CPU.

    Returns:
        dict with ``fts`` (float), ``edge_index`` (long) and ``lbls`` (long,
        squeezed) on the chosen device, plus ``split_idx``: the dict returned
        by ``dataset.get_idx_split()``, not moved to the device.
    """
    if hasattr(args, 'device'):
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    dataset_name = args.dataset  # e.g. 'ogbn-products'

    data_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = PygNodePropPredDataset(name=dataset_name, root=data_path)
    data = dataset[0]

    pyg_data = Data(x=data.x.to(torch.float),
                    edge_index=data.edge_index.to(torch.long),
                    y=data.y.squeeze().to(torch.long))

    print("加载的数据类型:", type(pyg_data))
    print('N', pyg_data.num_nodes)

    # The official OGB split is used, so no random split is built here.
    split_idx = dataset.get_idx_split()

    return {
        'fts': pyg_data.x.to(device),
        'edge_index': pyg_data.edge_index.to(device),
        'lbls': pyg_data.y.to(device),
        'split_idx': split_idx
    }


from torch_geometric.utils import subgraph
from torch_geometric.utils import add_self_loops


def load_ogbn_arxiv(args, dataset_name='ogbn-products', num_sampled=200000):
    """Sample a random induced subgraph of an OGB node dataset and split it 10/10/80.

    Despite its name, this loader defaults to ``ogbn-products``, the
    previously hard-coded dataset. Pass ``dataset_name='ogbn-arxiv'`` (or
    another OGB node-property dataset) to load something else.

    Steps:
      1. load the full graph and add self-loops;
      2. sample ``num_sampled`` node ids with ``torch.randperm``;
      3. keep the subgraph induced by those nodes
         (``subgraph(..., relabel_nodes=True)``) and index features and
         labels with the same sampled ids;
      4. shuffle the sampled nodes with ``random.shuffle`` and use 10% for
         training, 10% for validation and the rest for testing.

    Args:
        args: namespace optionally providing ``device``. Without it, CUDA is
            used when available, otherwise CPU.
        dataset_name: OGB dataset name passed to ``PygNodePropPredDataset``.
        num_sampled: number of nodes to sample. Values larger than the graph
            keep every node.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and
        ``train_idx`` / ``val_idx`` / ``test_idx``, all on the chosen device.
    """
    if hasattr(args, 'device'):
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = PygNodePropPredDataset(name=dataset_name, root=data_path)
    data = dataset[0]
    num_total_nodes = data.num_nodes
    edge_index, _ = add_self_loops(data.edge_index, num_nodes=num_total_nodes)
    labels = data.y.squeeze()

    # Uniformly random node sample (driven by torch's global RNG).
    subset_nodes = torch.randperm(num_total_nodes)[:num_sampled]

    # Induced subgraph on the sampled nodes, plus their features and labels.
    sub_edge_index, _ = subgraph(subset_nodes, edge_index, relabel_nodes=True)
    pyg_data = Data(x=data.x[subset_nodes].to(torch.float),
                    edge_index=sub_edge_index.to(torch.long),
                    y=labels[subset_nodes].to(torch.long))

    print("加载的数据类型:", type(pyg_data))

    n = pyg_data.num_nodes
    idx_list = list(range(n))
    random.shuffle(idx_list)  # random split over the sampled nodes

    num_train = int(n * 0.1)
    num_val = int(n * 0.1)
    # Every remaining node (~80%) goes to the test set.

    train_idx = torch.tensor(idx_list[:num_train], dtype=torch.long)
    val_idx = torch.tensor(idx_list[num_train:num_train + num_val], dtype=torch.long)
    test_idx = torch.tensor(idx_list[num_train + num_val:], dtype=torch.long)

    return {
        'fts': pyg_data.x.to(device),
        'edge_index': pyg_data.edge_index.to(device),
        'lbls': pyg_data.y.to(device),
        'train_idx': train_idx.to(device),
        'val_idx': val_idx.to(device),
        'test_idx': test_idx.to(device)
    }


from torch_geometric.datasets import WikiCS


def _wikics_mask_to_index(mask, split_id=0):
    """Return the indices of the nodes selected by a WikiCS boolean mask.

    A 2-D mask has one column per predefined split (``[num_nodes, splits]``)
    and is reduced to column ``split_id``. A 1-D mask is used as-is.
    """
    if mask.dim() > 1:
        mask = mask[:, split_id]
    return mask.nonzero(as_tuple=False).view(-1)


def load_wikics_dataset(args, split_id=0):
    """Load WikiCS with either one of its standard splits or a seeded random split.

    Args:
        args: namespace providing ``device``, ``split_ratio`` and ``seed``.
            When ``split_ratio < 0``, the dataset's predefined masks are
            used. Otherwise nodes are shuffled after
            ``random.seed(args.seed)`` and split into a ``split_ratio``
            train fraction, 10% validation and the rest test.
        split_id: which predefined train/validation split to use when
            ``split_ratio < 0``. Default 0 matches the previous behaviour.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and
        ``train_idx`` / ``val_idx`` / ``test_idx``, all on ``args.device``.
    """
    device = torch.device(args.device)
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = WikiCS(root=osp.join(f_path, 'wikics'))
    tmp = dataset[0].to(device)
    fts = tmp.x
    lbls = tmp.y

    if args.split_ratio < 0:
        # WikiCS ships 20 splits: train_mask and val_mask are [num_nodes, 20]
        # while test_mask is 1-D. nonzero() on a 2-D mask yields (node, split)
        # pairs, which .view(-1) would flatten into a mix of node and split
        # ids, so one split column is selected first.
        train_idx = _wikics_mask_to_index(tmp.train_mask, split_id)
        val_idx = _wikics_mask_to_index(tmp.val_mask, split_id)
        test_idx = _wikics_mask_to_index(tmp.test_mask, split_id)
    else:
        # Custom random split, e.g. 0.7 / 0.1 / 0.2.
        nums = lbls.shape[0]
        idx_list = list(range(nums))
        random.seed(args.seed)
        random.shuffle(idx_list)

        num_train = int(nums * args.split_ratio)
        num_val = int(nums * 0.1)

        train_idx = torch.tensor(idx_list[:num_train], device=device)
        val_idx = torch.tensor(idx_list[num_train:num_train + num_val], device=device)
        test_idx = torch.tensor(idx_list[num_train + num_val:], device=device)

    return {
        'fts': fts,
        'edge_index': tmp.edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }


import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Coauthor
import os.path as osp
import random


def load_coauthor_dataset(args, normalize_features=False):
    """Load Coauthor-CS or Coauthor-Physics with a seeded random 60/20/20 split.

    Args:
        args: namespace providing ``device``, ``dataset``
            (``'coauthor-cs'`` or ``'coauthor-physics'``) and ``seed``.
        normalize_features: if True, attach ``T.NormalizeFeatures()`` to the
            dataset. The default False matches the previous behaviour, where
            the transform was built but never applied.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and
        ``train_idx`` / ``val_idx`` / ``test_idx``, all on ``args.device``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported Coauthor name.
    """
    device = torch.device(args.device)
    data_dir = osp.join(osp.abspath(osp.dirname(__file__)), 'data')
    name = args.dataset

    subsets = {'coauthor-cs': 'CS', 'coauthor-physics': 'Physics'}
    if name not in subsets:
        raise ValueError(f"Unsupported dataset: {name}")

    transform = T.NormalizeFeatures() if normalize_features else None
    torch_dataset = Coauthor(root=osp.join(data_dir, 'Coauthor'), name=subsets[name],
                             transform=transform)

    data = torch_dataset[0].to(device)
    fts = data.x
    lbls = data.y
    edge_index = data.edge_index

    # Seeded random split: 60% train, 20% validation, remaining ~20% test.
    nums = lbls.shape[0]
    idx_list = list(range(nums))
    random.seed(args.seed)
    random.shuffle(idx_list)

    num_train = int(nums * 0.6)
    num_val = int(nums * 0.2)

    train_idx = torch.tensor(idx_list[:num_train], device=device)
    val_idx = torch.tensor(idx_list[num_train:num_train + num_val], device=device)
    test_idx = torch.tensor(idx_list[num_train + num_val:], device=device)

    return {
        'fts': fts,
        'edge_index': edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }


import torch
import random
import os.path as osp
from torch_geometric.datasets import Reddit
from torch_geometric.data import Data


def load_reddit(args):
    """Load Reddit with either its standard split or a seeded random split.

    Args:
        args: namespace providing ``device``, ``split_ratio`` and ``seed``.
            When ``split_ratio < 0``, Reddit's own train/val/test masks are
            used. Otherwise nodes are shuffled after
            ``random.seed(args.seed)`` and split into a ``split_ratio``
            train fraction, 10% validation and the rest test.

    Returns:
        dict with ``fts``, ``edge_index``, ``lbls`` and
        ``train_idx`` / ``val_idx`` / ``test_idx``, all on ``args.device``
        for both split modes.
    """
    device = torch.device(args.device)
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')

    dataset = Reddit(f_path)
    tmp = dataset[0].to(device)
    fts = tmp.x  # node features
    lbls = tmp.y  # node labels

    if args.split_ratio < 0:
        # Standard split stored as masks on the Reddit data object.
        train_idx = tmp.train_mask.nonzero(as_tuple=False).view(-1)
        val_idx = tmp.val_mask.nonzero(as_tuple=False).view(-1)
        test_idx = tmp.test_mask.nonzero(as_tuple=False).view(-1)
    else:
        # Custom random split, e.g. 0.7 / 0.1 / 0.2.
        nums = lbls.shape[0]
        idx_list = list(range(nums))
        random.seed(args.seed)
        random.shuffle(idx_list)

        num_train = int(nums * args.split_ratio)
        num_val = int(nums * 0.1)

        # Created on `device` so both branches return indices on the same
        # device (the mask branch above already does).
        train_idx = torch.tensor(idx_list[:num_train], device=device)
        val_idx = torch.tensor(idx_list[num_train:num_train + num_val], device=device)
        test_idx = torch.tensor(idx_list[num_train + num_val:], device=device)
        print('label', lbls.shape)

    return {
        'fts': fts,
        'edge_index': tmp.edge_index,
        'lbls': lbls,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx
    }


import torch
import numpy as np
import random
from sklearn.datasets import fetch_20newsgroups_vectorized

import os.path as osp
import torch
import numpy as np
import random
from sklearn.datasets import fetch_20newsgroups_vectorized
import os


def load_20news(args):
    """Load 20 Newsgroups as dense features with a seeded random train/test split.

    Features and labels are cached in ``data/20news.npz`` next to this file.
    If the cache exists it is read. Otherwise the vectorized corpus is
    fetched with scikit-learn (headers, footers and quotes removed),
    densified and written to the cache.

    Args:
        args: namespace providing ``device``, ``seed`` and ``split_ratio``.
            Documents are shuffled after ``random.seed(args.seed)``. The
            first ``split_ratio`` fraction (0.6 when ``split_ratio < 0``)
            becomes the training set and the rest the test set.

    Returns:
        dict with ``fts`` (float32), ``lbls`` (int64), ``train_idx`` and
        ``test_idx``, all on ``args.device``. No validation indices and no
        ``edge_index`` are returned.
    """
    f_path = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
    os.makedirs(f_path, exist_ok=True)
    save_path = osp.join(f_path, '20news.npz')

    device = torch.device(args.device)

    if osp.exists(save_path):
        # np.load returns an NpzFile that keeps the archive open; the `with`
        # block closes it once both arrays have been copied into tensors.
        with np.load(save_path) as data_npz:
            fts = torch.tensor(data_npz['fts'], dtype=torch.float32).to(device)
            lbls = torch.tensor(data_npz['lbls'], dtype=torch.long).to(device)
    else:
        print("Downloading 20Newsgroups dataset...")
        dataset = fetch_20newsgroups_vectorized(subset='all', remove=('headers', 'footers', 'quotes'))
        # NOTE(review): toarray() materializes a dense documents x vocabulary
        # matrix, which may need a lot of memory.
        fts_np = dataset.data.toarray().astype(np.float32)
        lbls_np = dataset.target.astype(np.int64)
        np.savez_compressed(save_path, fts=fts_np, lbls=lbls_np)
        fts = torch.tensor(fts_np).to(device)
        lbls = torch.tensor(lbls_np).long().to(device)

    nums = lbls.shape[0]
    idx_list = list(range(nums))
    random.seed(args.seed)
    random.shuffle(idx_list)

    ratio = 0.6 if args.split_ratio < 0 else args.split_ratio
    num_train = int(nums * ratio)

    train_idx = torch.tensor(idx_list[:num_train]).long().to(device)
    test_idx = torch.tensor(idx_list[num_train:]).long().to(device)

    return {
        'fts': fts,
        'lbls': lbls,
        'train_idx': train_idx,
        'test_idx': test_idx
    }



def load_github(args):
    """Load the GitHub graph and split its nodes 60/20/20 at random.

    NumPy's global RNG is seeded with ``args.seed`` and then used to draw
    the node permutation.

    Returns:
        dict with ``fts``, ``edge_index`` and ``lbls`` on ``args.device`` and
        ``train_idx`` / ``val_idx`` / ``test_idx`` as ``LongTensor``s created
        without an explicit device.
    """
    device = torch.device(args.device)
    root_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data', 'github')

    graph = GitHub(root=root_dir)[0].to(device)
    total = graph.num_nodes

    np.random.seed(args.seed)
    order = np.random.permutation(total)

    # First 60% -> train, next 20% -> validation, remainder -> test.
    train_end = int(0.6 * total)
    val_end = train_end + int(0.2 * total)
    train_idx, val_idx, test_idx = (
        torch.tensor(chunk, dtype=torch.long)
        for chunk in np.split(order, [train_end, val_end])
    )

    return {
        'fts': graph.x,
        'edge_index': graph.edge_index,
        'lbls': graph.y,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }


import numpy as np

def load_twitch(args):
    """Load one Twitch graph (``args.dataset`` names the subset) with a 60/20/20 split.

    The dataset object holds a single graph. Its nodes are permuted with
    NumPy's global RNG, seeded with ``args.seed``.

    Returns:
        dict with ``fts``, ``edge_index`` and ``lbls`` on ``args.device`` and
        ``train_idx`` / ``val_idx`` / ``test_idx`` as ``LongTensor``s created
        without an explicit device.
    """
    name = args.dataset
    device = torch.device(args.device)
    graph = Twitch(root='/home/projects/TDHNN/data/twitch', name=name)[0].to(device)

    count = graph.num_nodes
    np.random.seed(args.seed)
    shuffled = np.random.permutation(count)

    n_train, n_val = int(0.6 * count), int(0.2 * count)
    cuts = [0, n_train, n_train + n_val, len(shuffled)]
    train_idx, val_idx, test_idx = [
        torch.tensor(shuffled[lo:hi], dtype=torch.long)
        for lo, hi in zip(cuts, cuts[1:])
    ]

    return {
        'fts': graph.x,
        'edge_index': graph.edge_index,
        'lbls': graph.y,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }





def load_Grand(args):
    """Load one Grand graph (``args.dataset`` names the subset) with a 60/20/20 split.

    Only the first graph of the dataset is used. Its nodes are permuted with
    NumPy's global RNG, seeded with ``args.seed``.

    Returns:
        dict with ``fts``, ``edge_index`` and ``lbls`` on ``args.device`` and
        ``train_idx`` / ``val_idx`` / ``test_idx`` as ``LongTensor``s created
        without an explicit device.
    """
    device = torch.device(args.device)
    name = args.dataset
    dataset = Grand(root='/home/projects/TDHNN/data/grand', name=name)
    graph = dataset[0].to(device)

    node_count = graph.num_nodes
    np.random.seed(args.seed)
    perm = np.random.permutation(node_count)

    boundary_train = int(0.6 * node_count)
    boundary_val = boundary_train + int(0.2 * node_count)
    pieces = np.split(perm, [boundary_train, boundary_val])
    train_idx = torch.tensor(pieces[0], dtype=torch.long)
    val_idx = torch.tensor(pieces[1], dtype=torch.long)
    test_idx = torch.tensor(pieces[2], dtype=torch.long)

    return dict(
        fts=graph.x,
        edge_index=graph.edge_index,
        lbls=graph.y,
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
    )


from hg.datasets import GitHub,Facebook,Twitch,Grand
def load_Amazon(args):
    """Load a PyG Amazon co-purchase graph (``args.dataset`` names it) with a 60/20/20 split.

    Only the first graph of the dataset is used. Its nodes are permuted with
    NumPy's global RNG, seeded with ``args.seed``.

    Returns:
        dict with ``fts``, ``edge_index`` and ``lbls`` on ``args.device`` and
        ``train_idx`` / ``val_idx`` / ``test_idx`` as ``LongTensor``s created
        without an explicit device.
    """
    device = torch.device(args.device)
    name = args.dataset
    graph = Amazon(root='/home/projects/TDHNN/data', name=name)[0].to(device)

    size = graph.num_nodes
    np.random.seed(args.seed)
    permuted = np.random.permutation(size)

    stop_train = int(0.6 * size)
    stop_val = stop_train + int(0.2 * size)
    split_tensors = [torch.tensor(part, dtype=torch.long)
                     for part in (permuted[:stop_train],
                                  permuted[stop_train:stop_val],
                                  permuted[stop_val:])]

    result = {
        'fts': graph.x,
        'edge_index': graph.edge_index,
        'lbls': graph.y,
    }
    result.update(zip(('train_idx', 'val_idx', 'test_idx'), split_tensors))
    return result



from torch_geometric.datasets import DBLP
from torch_geometric.utils import to_undirected
import torch
import torch_geometric.transforms as T
from torch_geometric.data import Data
import os.path as osp
def _author_paper_author_edges(ap_edge_index, num_authors, num_papers):
    """Return the author-paper-author (A-P-A) meta-path graph as an ``edge_index``.

    Two distinct authors are connected when they share at least one paper.
    The result is symmetric, has no self-loops and is an int64 tensor of
    shape ``[2, num_edges]``.

    Args:
        ap_edge_index: ``[2, E]`` tensor of (author id, paper id) pairs.
        num_authors: number of author nodes.
        num_papers: number of paper nodes.
    """
    row, col = ap_edge_index.cpu().numpy()
    incidence = sp.coo_matrix(
        (np.ones(row.shape[0], dtype=np.float32), (row, col)),
        shape=(num_authors, num_papers),
    ).tocsr()
    # (authors x papers) @ (papers x authors): entry (i, j) is the number of
    # papers shared by authors i and j, so every stored entry is an edge.
    co_author = (incidence @ incidence.T).tocoo()
    keep = co_author.row != co_author.col  # drop self-loops
    pairs = np.vstack([co_author.row[keep], co_author.col[keep]]).astype(np.int64)
    return torch.from_numpy(pairs)


def load_dblp_homo(args):
    """Convert PyG's heterogeneous DBLP dataset into a homogeneous author graph.

    Authors are connected through the A-P-A meta-path (co-authorship). Author
    features and labels are kept, and so are the dataset's own author
    train/val/test masks.

    Args:
        args: unused; kept so the signature matches the other loaders.

    Returns:
        ``Data`` with ``x``, ``edge_index`` (A-P-A, author ids only), ``y``
        and boolean ``train_mask`` / ``val_mask`` / ``test_mask`` over
        authors.

    Raises:
        KeyError: if the dataset has no author -> paper edge type.
    """
    path = osp.join(osp.dirname(osp.abspath(__file__)), 'data', 'DBLP')
    dataset = DBLP(path)

    data = dataset[0]
    print(data)

    # Find the author->paper relation by its endpoint types so the code does
    # not depend on the relation's name.
    ap_type = next((et for et in data.edge_types
                    if et[0] == 'author' and et[2] == 'paper'), None)
    if ap_type is None:
        raise KeyError(f"DBLP has no author->paper edge type: {data.edge_types}")

    num_authors = data['author'].num_nodes
    edge_index_apa = _author_paper_author_edges(
        data[ap_type].edge_index, num_authors, data['paper'].num_nodes)

    homo_data = Data(x=data['author'].x, edge_index=edge_index_apa, y=data['author'].y)
    # PyG's DBLP has no get_idx_split() (that is an OGB API); its split is
    # stored as boolean masks on the author node type.
    homo_data.train_mask = data['author'].train_mask.clone()
    homo_data.val_mask = data['author'].val_mask.clone()
    homo_data.test_mask = data['author'].test_mask.clone()

    return homo_data
