import torch
import networkx as nx
from torch.nn.functional import cosine_similarity
from torch_geometric.data import Data

def retrieve_subgraph_via_personalized_pagerank(
    graph: Data,
    q_emb: torch.Tensor,
    top_k_seeds: int = 10,
    top_m_nodes: int = 10,
    alpha: float = 0.85,
    textual_nodes=None,
    textual_edges=None
) -> "tuple[Data, str]":
    """Retrieve a query-relevant subgraph using Personalized PageRank (PPR).

    The nodes most similar to the query (cosine similarity) are used as PPR
    seeds on an undirected view of the graph. The ``top_m_nodes`` highest
    scoring nodes are kept, and the largest connected component among them
    becomes the returned subgraph.

    Args:
        graph: PyG ``Data`` object providing:
            - ``x``: node features/embeddings, used row-wise against ``q_emb``.
            - ``edge_index``: ``[2, num_edges]`` edge indices.
            - ``edge_attr``: optional per-edge features, sliced alongside the
              kept edges.
            - ``num_nodes``: node count (falls back to ``x.shape[0]`` if None).
        q_emb: Query embedding, either ``[d]`` or ``[1, d]``, with the same
            feature dimension as ``graph.x``.
        top_k_seeds: Number of most-similar nodes used as PPR seeds. If this
            yields no seeds (e.g. ``0``), plain (uniform) PageRank is used.
        top_m_nodes: Number of top-PPR nodes considered before the largest
            connected component is extracted.
        alpha: PageRank damping factor (probability of following an edge;
            ``1 - alpha`` is the teleport probability).
        textual_nodes: Required. Table of node text supporting ``len()``,
            ``.iloc`` and ``.to_csv`` (presumably a pandas DataFrame whose
            row ``i`` describes node ``i`` -- confirm against the caller).
        textual_edges: Required. Table of edge text with ``src``,
            ``edge_attr`` and ``dst`` columns (presumably row ``j`` describes
            column ``j`` of ``graph.edge_index`` -- confirm against the
            caller).

    Returns:
        A ``(subgraph, desc)`` tuple: ``subgraph`` is a new ``Data`` holding
        the selected nodes (re-indexed from 0) and the edges between them;
        ``desc`` is the CSV text of the selected node rows, a newline, and
        the CSV text of the selected edge rows. When either table has at most
        one row, the whole graph and the full tables are returned instead.

    Raises:
        TypeError: If ``textual_nodes`` or ``textual_edges`` is None.
    """
    if textual_nodes is None or textual_edges is None:
        raise TypeError(
            "textual_nodes and textual_edges must both be provided "
            "(tables supporting len(), .iloc and .to_csv)"
        )

    # Degenerate graphs (at most one node row or one edge row): nothing to
    # rank, so hand back a copy of the whole graph with the full description.
    if len(textual_nodes) <= 1 or len(textual_edges) <= 1:
        desc = (
            textual_nodes.to_csv(index=False)
            + '\n'
            + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        )
        full_graph = Data(
            x=graph.x,
            edge_index=graph.edge_index,
            edge_attr=graph.edge_attr,
            num_nodes=graph.num_nodes,
        )
        return full_graph, desc

    # =========== Step 0: put the query on the same device as the features ===========
    q_emb = q_emb.to(graph.x.device)
    if q_emb.dim() == 1:
        q_emb = q_emb.unsqueeze(0)  # shape: [1, d]

    # =========== Step 1: similarity between every node and the query ===========
    with torch.no_grad():
        # reshape(-1) keeps a 1-D score vector even when there is one node
        # (squeeze(0) would collapse it to a 0-d tensor).
        node_scores = cosine_similarity(q_emb, graph.x, dim=-1).reshape(-1)

    # =========== Step 2: the top-K most similar nodes become seeds ===========
    num_nodes = graph.num_nodes
    if num_nodes is None:
        num_nodes = graph.x.shape[0]  # fallback
    top_k_seeds = min(top_k_seeds, num_nodes)
    seed_indices = torch.topk(node_scores, k=top_k_seeds, largest=True).indices.tolist()

    # =========== Step 3: build an undirected networkx view of the graph ===========
    # tolist() yields plain Python ints, usable as node ids and dict keys.
    edge_index_cpu = graph.edge_index.cpu()
    src_list = edge_index_cpu[0].tolist()
    dst_list = edge_index_cpu[1].tolist()

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(zip(src_list, dst_list))

    # =========== Step 4: personalization vector (uniform over the seeds) ===========
    # Nodes missing from the dict get zero weight. An all-zero vector makes
    # nx.pagerank fail, so with no seeds fall back to uniform PageRank.
    if seed_indices:
        seed_weight = 1.0 / len(seed_indices)
        personalization = {seed: seed_weight for seed in seed_indices}
    else:
        personalization = None

    # =========== Step 5: run Personalized PageRank ===========
    ppr_scores = nx.pagerank(G, alpha=alpha, personalization=personalization)

    # =========== Step 6: keep the top_m_nodes highest-scoring nodes ===========
    top_m_nodes = min(top_m_nodes, num_nodes)
    candidate_nodes = sorted(ppr_scores, key=ppr_scores.get, reverse=True)[:top_m_nodes]

    # =========== Step 7: largest connected component among the candidates ===========
    G_sub = G.subgraph(candidate_nodes)
    largest_comp = max(nx.connected_components(G_sub), key=len, default=set())
    final_nodes = sorted(largest_comp)

    # =========== Step 8: build the subgraph, remembering original edge positions ===========
    node_old2new = {old_idx: new_idx for new_idx, old_idx in enumerate(final_nodes)}

    kept_edge_ids = []
    new_src = []
    new_dst = []
    for e_idx, (src, dst) in enumerate(zip(src_list, dst_list)):
        # Keep an edge only when both endpoints survived the selection.
        if src in node_old2new and dst in node_old2new:
            kept_edge_ids.append(e_idx)
            new_src.append(node_old2new[src])
            new_dst.append(node_old2new[dst])

    # Create the new edge_index on the input graph's device so the returned
    # Data does not mix CPU and accelerator tensors.
    edge_device = graph.edge_index.device
    if kept_edge_ids:
        new_edge_index = torch.tensor([new_src, new_dst], dtype=torch.long, device=edge_device)
    else:
        new_edge_index = torch.empty((2, 0), dtype=torch.long, device=edge_device)

    # Node features of the subgraph.
    new_x = graph.x[final_nodes]
    # Edge features of the subgraph (only when present and some edge was kept).
    if graph.edge_attr is not None and kept_edge_ids:
        new_edge_attr = graph.edge_attr[kept_edge_ids]
    else:
        new_edge_attr = None

    subgraph = Data(
        x=new_x,
        edge_index=new_edge_index,
        edge_attr=new_edge_attr,
        num_nodes=len(final_nodes)
    )

    # Positional row lookup: the tables are indexed by node id / edge position.
    node_rows = textual_nodes.iloc[final_nodes]
    edge_rows = textual_edges.iloc[kept_edge_ids]
    desc = (
        node_rows.to_csv(index=False)
        + '\n'
        + edge_rows.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
    )

    return subgraph, desc