import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import torch
import scipy
import scipy.sparse
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.data import Data

def convert_to_pyg_data(adj_matrix, node_features, edge_features_norm):
    """
    Convert the graph data of a single scene into a PyG ``Data`` object.

    Parameters:
    - adj_matrix (numpy.ndarray): adjacency matrix (N x N). Every non-zero
      entry becomes one directed edge; the entry values themselves are not
      stored on the returned object.
    - node_features (numpy.ndarray): node feature matrix (N x F).
    - edge_features_norm (numpy.ndarray | array-like | None): edge feature
      matrix (E x D). Row ``i`` must describe column ``i`` of ``edge_index``,
      i.e. the non-zero entries of ``adj_matrix`` in the order produced by
      ``from_scipy_sparse_matrix`` on the CSR matrix (row-major: by source
      node, then target node). ``None`` or an empty array means
      "no edge features".

    Returns:
    - data (torch_geometric.data.Data): PyG Data object with ``x``,
      ``edge_index`` and ``edge_attr`` (``edge_attr`` is ``None`` when no
      edge features are given).

    Raises:
    - ValueError: if edge features are given but their number of rows does
      not equal the number of edges found in ``adj_matrix``.
    """
    # Dense -> CSR: only the non-zero entries are stored, so each one is an edge.
    adj_sparse = scipy.sparse.csr_matrix(adj_matrix)

    # edge_index has shape (2, E); the returned edge weights (the adjacency
    # values) are intentionally discarded.
    edge_index, _ = from_scipy_sparse_matrix(adj_sparse)
    num_edges = edge_index.size(1)

    # Edge features: accept None, lists, tensors or ndarrays; empty -> None.
    edge_attr = None
    if edge_features_norm is not None:
        edge_features_norm = np.asarray(edge_features_norm)
        if edge_features_norm.size > 0:
            edge_attr = torch.tensor(edge_features_norm, dtype=torch.float)
            # Guard against silently misaligned edge features: PyG pairs
            # edge_attr rows with edge_index columns purely by position.
            if edge_attr.dim() == 0 or edge_attr.shape[0] != num_edges:
                raise ValueError(
                    f"edge_features_norm has shape {tuple(edge_attr.shape)}, "
                    f"expected {num_edges} rows (one per non-zero entry of "
                    f"adj_matrix)"
                )

    # Node features.
    x = torch.tensor(node_features, dtype=torch.float)

    # Build the Data object.
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data
