import torch
from collections import defaultdict, deque

def find_selected_nodes_with_center(data, target_nodes, center_node):
    # Step 1: 构建邻接表
    edge_index = data.edge_index
    adj = defaultdict(list)
    for src, dst in edge_index.t().tolist():
        adj[src].append(dst)
        adj[dst].append(src)  # 假设是无向图
    
    # Step 2: 初始化选中的节点集合
    selected_nodes = set([center_node])
    
    # Step 3: 对每个目标节点计算最短路径
    def bfs_shortest_path(start, target):
        """使用 BFS 计算从 start 到 target 的最短路径"""
        queue = deque([(start, [start])])  # (当前节点, 路径)
        visited = set([start])
        
        while queue:
            current, path = queue.popleft()
            if current == target:
                return path  # 找到目标节点，返回路径
            
            for neighbor in adj[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))
        
        return None  # 如果没有找到路径
    
    for node in target_nodes:
        if node == center_node:
            continue
        
        # 计算中心节点到目标节点的最短路径
        shortest_path = bfs_shortest_path(center_node, node)
        if shortest_path is None:
            raise ValueError(f"目标节点 {node} 与中心节点 {center_node} 不连通")
        
        # 将路径上的节点加入选中集合
        selected_nodes.update(shortest_path)
    
    # Step 4: 返回选中的节点索引（排序后）
    return sorted(selected_nodes)

# 示例用法
if __name__ == "__main__":
    # 创建一个简单的图
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]
    ], dtype=torch.long)
    
    # 目标节点和中心节点
    target_nodes = [0, 3, 6]
    center_node = 3
    
    # 找到选中的节点索引
    selected_nodes = find_selected_nodes_with_center_pyg(
        data=Data(edge_index=edge_index), 
        target_nodes=target_nodes, 
        center_node=center_node
    )
    
    print("选中的节点索引:", selected_nodes)