import torch
import numpy as np
from pcst_fast import pcst_fast
import torch_geometric
from torch_geometric.data.data import Data
from torch_geometric.utils import to_networkx, k_hop_subgraph
import networkx as nx
from datasketch import MinHash, MinHashLSH
from datetime import datetime
import time

def _rank_rewards(scores, k):
    """Turn a 1-D score tensor into rank-based rewards.

    The highest score receives reward ``k``, the next ``k - 1``, and so on
    down to 1; every entry outside the top ``k`` receives 0.

    :param scores: 1-D tensor of scores.
    :param k: number of top-ranked entries that receive a non-zero reward.
    :return: list of int rewards, one per entry of ``scores``.
    """
    scores_np = scores.detach().cpu().numpy()
    order = np.argsort(scores_np)[::-1][:k]  # indices sorted by descending score
    rewards = [0] * len(scores_np)
    for rank, idx in enumerate(order):
        rewards[idx] = k - rank
    return rewards


def retrieval_graphpack(
    graph,
    q_emb,
    root=-1,
    textual_nodes=None,
    textual_edges=None,
    topk=3, # top-k
    n=2, # n-hop
    load=20
):
    """Retrieve a query-relevant subgraph using per-seed knapsack selection.

    The ``topk`` nodes most cosine-similar to ``q_emb`` are used as seeds.
    For each seed, the n-hop neighbourhood is extracted, its nodes (and edges,
    when ``graph.edge_attr`` exists) get rank-based rewards from their
    similarity to the query, their hop distance from the seed is used as the
    weight, and a 0-1 knapsack with capacity ``load`` picks which to keep.
    The union of all picked nodes (plus endpoints of picked edges) defines the
    final node set; every original edge with both endpoints in that set is kept.

    :param graph: PyG ``Data`` with ``x``, ``edge_index`` and optionally
        ``edge_attr``. Assumes ``x`` is (num_nodes, dim) and ``edge_attr`` is
        (num_edges, dim) so they broadcast against ``q_emb`` -- TODO confirm.
    :param q_emb: query embedding compared against node/edge features.
    :param root: not used by this function.
    :param textual_nodes: optional table of node text, indexed positionally by
        node id and having a ``node_attr`` column (presumably a pandas
        DataFrame -- verify against caller).
    :param textual_edges: optional table of edge text with ``src``,
        ``edge_attr`` and ``dst`` columns, indexed positionally by edge id.
    :param topk: number of seed nodes.
    :param n: number of hops around each seed.
    :param load: knapsack capacity (total hop-distance budget per seed).
    :return: ``(data, desc)`` where ``data`` is the relabelled sub-``Data`` and
        ``desc`` is the CSV description, or ``None`` without ``textual_nodes``.
    """
    has_edge_attr = hasattr(graph, 'edge_attr') and graph.edge_attr is not None

    # Tiny graphs are returned whole; there is nothing worth pruning.
    if textual_nodes is not None:
        if len(textual_nodes) <= topk or len(textual_edges) == 0:
            desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
            graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
            return graph, desc

    cosine = torch.nn.CosineSimilarity(dim=-1)
    similarity = cosine(q_emb, graph.x)
    e_prizes = cosine(q_emb, graph.edge_attr) if has_edge_attr else None

    # Clamp so graphs with fewer than `topk` nodes do not make torch.topk raise.
    _, center_nodes = torch.topk(similarity, min(topk, similarity.size(-1)))

    # Loop-invariant lookups, built once instead of once per seed / per edge.
    # NOTE: for parallel edges sharing the same (src, dst) pair the dict keeps
    # only the last edge index, so earlier duplicates are scored via that one.
    edge_list = graph.edge_index.t().tolist() if has_edge_attr else []
    edge_to_idx = {tuple(edge): idx for idx, edge in enumerate(edge_list)}

    reward_k = 5  # how many top-ranked nodes / edges per neighbourhood get a reward
    candidate_nodes = []
    candidate_edges = []
    for center_node in center_nodes.tolist():
        sub_edge_index, subgraph_nodes = extract_n_hop_subgraph(graph.edge_index, center_node, n)
        node_ids = subgraph_nodes.tolist()

        # Values: similarity ranks. Weights: hop distance from the seed.
        rewards = _rank_rewards(similarity[subgraph_nodes], reward_k)
        all_node_dist, all_edge_dist = shortest_hops_k_hop(graph, center_node, n)
        weights = [all_node_dist[i] for i in node_ids]

        sub_edge_idx = []
        if has_edge_attr:
            sub_edge = [tuple(e) for e in sub_edge_index.t().tolist()]
            sub_edge_idx = [edge_to_idx[e] for e in sub_edge]
            # Tensor index so an edgeless neighbourhood yields an empty slice.
            idx_tensor = torch.as_tensor(sub_edge_idx, dtype=torch.long, device=e_prizes.device)
            rewards = rewards + _rank_rewards(e_prizes[idx_tensor], reward_k)
            weights = weights + [all_edge_dist[e] for e in sub_edge]

        selected_indices = knapsack_optimization(rewards, weights, load)

        # Knapsack items are laid out as [nodes..., edges...].
        for i in selected_indices:
            if i < len(node_ids):
                candidate_nodes.append(node_ids[i])
            else:
                candidate_edges.append(sub_edge_idx[i - len(node_ids)])

    # A selected edge pulls in both of its endpoints.
    for i in candidate_edges:
        candidate_nodes.extend(edge_list[i])
    # Explicit integer dtype so an empty selection is still a valid index array.
    final_node_set = np.array(sorted(set(candidate_nodes)), dtype=np.int64)

    # Keep every original edge whose endpoints are both in the final node set.
    edge_index = graph.edge_index.cpu().numpy()
    mask = np.isin(edge_index[0], final_node_set) & np.isin(edge_index[1], final_node_set)
    adj_mask = np.where(mask)[0]

    final_edge_index = graph.edge_index[:, adj_mask]
    final_edge_attr = graph.edge_attr[adj_mask] if has_edge_attr else None

    # Relabel endpoints to 0..len-1 using the SAME node ordering that indexes
    # `x` below; mapping only the nodes that occur in edges would misalign
    # edge_index with x whenever a selected node has no surviving edge.
    map_final = {node: i for i, node in enumerate(final_node_set.tolist())}
    src_final = [map_final[i] for i in final_edge_index[0].tolist()]
    dst_final = [map_final[i] for i in final_edge_index[1].tolist()]

    final_data = Data(
        x=graph.x[final_node_set],
        edge_index=torch.LongTensor([src_final, dst_final]),
        edge_attr=final_edge_attr,
        num_nodes=len(final_node_set)
    )

    # Assemble the textual description of the retrieved subgraph.
    if textual_nodes is not None:
        textual_n_final = textual_nodes.iloc[final_node_set]
        # Drop nodes whose text starts with 'm.' from the description only.
        textual_n_final = textual_n_final[~textual_n_final['node_attr'].str.startswith('m.', na=False)]

        textual_e_final = textual_edges.iloc[adj_mask]

        desc_final = textual_n_final.to_csv(index=False) + '\n' + \
                    textual_e_final.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
    else:
        desc_final = None

    return final_data, desc_final


# Step 1: extract the n-hop subgraph
def extract_n_hop_subgraph(edge_index, center_node, n, num_nodes=None):
    """
    Extract the n-hop subgraph around ``center_node`` with PyG's k_hop_subgraph.

    Node ids are NOT relabelled: the returned edge index and node subset use
    the ids of the original graph.

    :param edge_index: edge index of the graph, shape (2, num_edges)
    :param center_node: index of the center node
    :param n: number of hops
    :param num_nodes: optional total node count of the graph. When None,
        k_hop_subgraph infers it from ``edge_index``, which is too small if
        the highest-numbered nodes have no edges; pass the real count to be
        safe.
    :return: (sub_edge_index, subset) -- the subgraph's edge index and the
        tensor of node ids it contains
    """
    subset, sub_edge_index, _, _ = torch_geometric.utils.k_hop_subgraph(
        node_idx=center_node,
        num_hops=n,
        edge_index=edge_index,
        relabel_nodes=False,
        num_nodes=num_nodes,
    )

    return sub_edge_index, subset

def shortest_hops_k_hop(graph, center_node, n):
    """
    Compute hop distances from ``center_node`` by growing k-hop subgraphs.

    For k = 1..n the k-hop subgraph around the center is extracted; a node or
    edge is assigned the smallest k at which it first appears.

    Args:
    - graph: PyG Data object providing ``edge_index`` (and ``num_nodes``).
    - center_node: index of the center node.
    - n: maximum number of hops to expand.

    Returns:
    - node_distances: list where entry i is the hop count of node i from the
      center (0 for the center itself, ``inf`` for nodes not reached within
      n hops).
    - edge_distances: dict mapping an edge tuple (u, v) to the hop count at
      which it first appeared; edges never reached are absent.
    """
    edge_index = graph.edge_index
    if hasattr(graph, 'num_nodes'):
        num_nodes = graph.num_nodes
    else:
        num_nodes = edge_index.max().item() + 1

    unreached = float('inf')
    node_distances = [unreached] * num_nodes
    node_distances[center_node] = 0
    edge_distances = {}

    for hop in range(1, n + 1):
        reached, hop_edges, _, _ = k_hop_subgraph(center_node, hop, edge_index, num_nodes=num_nodes)

        # Only nodes seen for the first time get this hop count.
        for node in reached.tolist():
            if node_distances[node] == unreached:
                node_distances[node] = hop

        # Same first-seen rule for edges; setdefault never overwrites.
        for u, v in hop_edges.t().tolist():
            edge_distances.setdefault((u, v), hop)

        # Every node is covered, so further expansion cannot add anything.
        if len(reached) == num_nodes:
            break

    return node_distances, edge_distances

# Step 3: define the knapsack optimization algorithm
def knapsack_optimization(values, weights, capacity):
    """
    Solve the 0-1 knapsack problem with dynamic programming.

    :param values: list of item values
    :param weights: list of item weights (integers, since they index the table)
    :param capacity: knapsack capacity
    :return: indices of the chosen items, in descending index order
    """
    item_count = len(values)
    width = capacity + 1

    # took[i][w] is True when item i-1 is part of the best solution that uses
    # the first i items with capacity w. Row 0 is the empty-prefix base case.
    took = [[False] * width]
    best_prev = [0] * width

    for item in range(item_count):
        value, weight = values[item], weights[item]
        best_cur = []
        took_row = []
        for w in range(width):
            skip_total = best_prev[w]
            # Take the item only on a strict improvement, so ties keep the
            # solution that leaves it out.
            if weight <= w and value + best_prev[w - weight] > skip_total:
                best_cur.append(value + best_prev[w - weight])
                took_row.append(True)
            else:
                best_cur.append(skip_total)
                took_row.append(False)
        took.append(took_row)
        best_prev = best_cur

    # Walk back from the last item, spending capacity for each item taken.
    selected_items = []
    remaining = capacity
    i = item_count
    while i > 0:
        if took[i][remaining]:
            selected_items.append(i - 1)
            remaining -= weights[i - 1]
        i -= 1

    return selected_items