import torch
import numpy as np
from pcst_fast import pcst_fast
from torch_geometric.data.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx
from datasketch import MinHash, MinHashLSH
from datetime import datetime

def retrieval_via_pcst(graph, q_emb, textual_nodes=None, textual_edges=None, topk=3, topk_e=3, cost_e=0.5):
    """Retrieve a query-relevant subgraph with an unrooted Prize-Collecting Steiner Tree.

    Prizes and costs:
      * Node prizes: the ``topk`` nodes whose ``graph.x`` is most cosine-similar
        to ``q_emb`` get prizes ``topk, topk-1, ..., 1``; every other node gets 0.
      * Edge prizes: edges whose similarity to ``q_emb`` equals one of the
        ``topk_e`` largest distinct similarity values get a prize. Ties share it
        as ``(topk_e - k) / count``, and each level is capped at ``(1 - c)`` times
        the previous level. All other edges get 0.
      * ``cost_e`` is lowered to at most ``0.995 * max edge prize`` so that at
        least the best edge is profitable.
      * An edge with ``prize <= cost_e`` is given cost ``cost_e - prize``.
        Otherwise it is replaced by a virtual node (prize ``prize - cost_e``)
        joined to both endpoints by zero-cost edges, since PCST costs must be
        non-negative.

    Args:
        graph: PyG ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``num_nodes``.
        q_emb: query embedding, compared by cosine similarity with ``graph.x``
            and ``graph.edge_attr``.
        textual_nodes, textual_edges: optional tables aligned with the graph's
            nodes/edges (presumably pandas DataFrames, since ``.iloc`` and
            ``.to_csv`` are used). ``textual_edges`` needs ``src``,
            ``edge_attr`` and ``dst`` columns.
        topk: number of nodes that receive a non-zero prize.
        topk_e: number of distinct edge-similarity levels that receive a prize.
        cost_e: base edge cost.

    Returns:
        ``(data, desc)``.
          * ``data`` is the selected subgraph. Nodes are ordered by ascending
            original id and edges are relabelled to ``0..n-1``.
          * ``desc`` is the CSV text of the selected nodes and edges, or
            ``None`` when ``textual_nodes`` is None.
        If the textual nodes or edges are empty, the whole graph and the full
        CSV text are returned instead.
    """
    c = 0.01
    device = graph.x.device

    if textual_nodes is not None and (len(textual_nodes) == 0 or len(textual_edges) == 0):
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    root = -1  # unrooted
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0

    # Node prizes: rank-based (topk .. 1) on the most similar nodes.
    if topk > 0:
        n_sims = torch.nn.functional.cosine_similarity(q_emb, graph.x, dim=-1)
        topk = min(topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_sims, topk, largest=True)
        n_prizes = torch.zeros_like(n_sims)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float().to(n_prizes.device)
    else:
        n_prizes = torch.zeros(graph.num_nodes, device=device)

    # Edge prizes. The guard avoids indexing an empty top-k result on edgeless graphs.
    if topk_e > 0 and graph.num_edges > 0:
        e_sims = torch.nn.functional.cosine_similarity(q_emb, graph.edge_attr, dim=-1)
        unique_vals = e_sims.unique()
        topk_e = min(topk_e, unique_vals.size(0))
        topk_e_values, _ = torch.topk(unique_vals, topk_e, largest=True)

        # Masks come from the raw similarities, never from the array being filled.
        # Otherwise a zeroed-out edge (or an already assigned prize) could match
        # a later similarity level, e.g. when a top similarity is exactly 0.0.
        e_prizes = torch.zeros_like(e_sims)
        last_topk_e_value = float(topk_e)
        for k in range(topk_e):
            indices = e_sims == topk_e_values[k]
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(graph.num_edges, device=device)

    # Build the PCST instance; profitable edges become virtual nodes.
    edge_arr = graph.edge_index.cpu().numpy()
    e_prizes_np = e_prizes.cpu().numpy()

    costs, edges = [], []
    virtual_n_prizes, virtual_edges, virtual_costs = [], [], []
    mapping_n, mapping_e = {}, {}
    for i, (src, dst) in enumerate(edge_arr.T):
        prize_e = e_prizes_np[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = graph.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.extend((0, 0))
            virtual_n_prizes.append(prize_e - cost_e)

    num_edges = len(edges)
    prizes = np.concatenate([n_prizes.cpu().numpy(), np.array(virtual_n_prizes)])
    # Always hand pcst_fast proper arrays, even when there are no virtual edges.
    costs = np.array(costs + virtual_costs, dtype=np.float64)
    edges = np.array(edges + virtual_edges, dtype=np.int64).reshape(-1, 2)

    vertices, pcst_edges = pcst_fast(edges, prizes, costs, root, num_clusters, pruning, verbosity_level)

    # Map the solution back to original node / edge ids.
    selected_nodes = vertices[vertices < graph.num_nodes]
    selected_edges = [mapping_e[e] for e in pcst_edges if e < num_edges]
    selected_edges += [mapping_n[v] for v in vertices[vertices >= graph.num_nodes]]
    # Explicit int64 so an empty selection is still a valid index array.
    selected_edges = np.array(selected_edges, dtype=np.int64)

    edge_index = graph.edge_index[:, selected_edges]
    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].cpu().numpy(), edge_index[1].cpu().numpy()]))

    if textual_nodes is not None:
        n = textual_nodes.iloc[selected_nodes]
        e = textual_edges.iloc[selected_edges]
        desc = n.to_csv(index=False) + '\n' + e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
    else:
        desc = None

    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]
    data = Data(
        x=graph.x[selected_nodes],
        edge_index=torch.LongTensor([src, dst]),
        edge_attr=graph.edge_attr[selected_edges],
        num_nodes=len(selected_nodes),
    )

    return data, desc

def _relabelled_subgraph(graph, node_ids, edge_ids):
    """Build a ``Data`` subgraph from original node ids and edge column ids.

    Every endpoint of the selected edges must appear in ``node_ids``. Original
    node ``node_ids[i]`` becomes node ``i`` of the returned graph.
    """
    edge_index = graph.edge_index[:, edge_ids]
    position = {n: i for i, n in enumerate(node_ids.tolist())}
    src = [position[n] for n in edge_index[0].tolist()]
    dst = [position[n] for n in edge_index[1].tolist()]
    return Data(
        x=graph.x[node_ids],
        edge_index=torch.LongTensor([src, dst]),
        edge_attr=graph.edge_attr[edge_ids],
        num_nodes=len(node_ids),
    )


def retrieval_via_pcst_plus(
    graph,
    q_emb,
    root=-1,
    textual_nodes=None,
    textual_edges=None,
    topk=3,
    topk_e=3,
    cost_e=0.5,
    max_hops=None,
    min_gain=None
):
    """Combine PCST retrieval with gain-thresholded multi-hop expansion.

    Steps:
      1) Run PCST (``pcst_fast``) with rank-based node prizes and top-k edge
         prizes to get an initial subgraph.
      2) Starting from the PCST nodes, expand over the ORIGINAL graph with
         ``multi_hop_search_and_prune_advanced``.
      3) Return the subgraph induced by the resulting node set.

    Args:
        graph: PyG ``Data`` with ``x``, ``edge_index``, ``edge_attr`` and ``num_nodes``.
        q_emb: query embedding, compared by cosine similarity with node and
            edge embeddings.
        root: PCST root node (``-1`` = unrooted). Also selects the default
            expansion settings below.
        textual_nodes, textual_edges: optional tables aligned with nodes/edges
            (presumably pandas DataFrames, since ``.iloc``/``.to_csv`` are used).
            ``textual_edges`` needs ``src``, ``edge_attr`` and ``dst`` columns.
        topk: number of most similar nodes that get prizes ``topk..1``.
        topk_e: number of distinct top edge-similarity levels that get prizes.
        cost_e: base edge cost. It is lowered so the best edge is profitable.
        max_hops: expansion depth. ``None`` means 4 when ``root == -1`` and 2
            otherwise; these are the values that were previously hard-coded.
        min_gain: gain threshold for accepting a neighbour. ``None`` means 0.5
            when ``root == -1`` and 0.35 otherwise.

    Returns:
        ``(data, desc)``.
          * ``data`` is the final subgraph: nodes sorted by original id,
            including the single-node case, with relabelled edges.
          * ``desc`` is the CSV text of the kept nodes and edges, or ``None``
            when ``textual_nodes`` is None.
        If the textual nodes or edges are empty, the whole graph and the full
        CSV text are returned instead.
    """
    c = 0.01

    if textual_nodes is not None and (len(textual_nodes) == 0 or len(textual_edges) == 0):
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    # These arguments used to be ignored in favour of per-mode constants.
    # ``None`` keeps exactly those constants; an explicit value now takes effect.
    if max_hops is None:
        max_hops = 4 if root == -1 else 2
    if min_gain is None:
        min_gain = 0.5 if root == -1 else 0.35

    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    num_nodes = graph.num_nodes
    device = graph.x.device

    # Step 1: node prizes = rank (topk .. 1) of the most query-similar nodes.
    if topk > 0:
        n_sims = torch.nn.functional.cosine_similarity(q_emb, graph.x, dim=-1)
        topk = min(topk, num_nodes)
        _, topk_n_indices = torch.topk(n_sims, topk, largest=True)
        n_prizes = torch.zeros_like(n_sims)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float().to(n_sims.device)
    else:
        n_prizes = torch.zeros(num_nodes, device=device)

    # Step 2: edge prizes and cost_e adjustment (skipped on edgeless graphs,
    # where indexing the empty top-k result would raise).
    if topk_e > 0 and graph.num_edges > 0:
        e_sims = torch.nn.functional.cosine_similarity(q_emb, graph.edge_attr, dim=-1)
        unique_vals = e_sims.unique()
        topk_e = min(topk_e, unique_vals.size(0))
        topk_e_values, _ = torch.topk(unique_vals, topk_e, largest=True)

        # Masks come from the raw similarities, never from the array being filled.
        # Otherwise zeroed-out edges (or assigned prizes) could match a later level,
        # e.g. when one of the top similarities is exactly 0.0.
        e_prizes = torch.zeros_like(e_sims)
        last_topk_e_value = float(topk_e)
        for k in range(topk_e):
            indices = e_sims == topk_e_values[k]
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)

        # Keep cost_e below the best edge prize so at least one edge is profitable.
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(graph.num_edges, device=device)

    # Step 3: PCST inputs. An edge whose prize exceeds cost_e becomes a virtual
    # node joined to both endpoints by zero-cost edges.
    edge_arr = graph.edge_index.cpu().numpy()  # [2, num_edges]
    e_prizes_np = e_prizes.cpu().numpy()

    costs, edges_list = [], []
    virtual_n_prizes, virtual_edges, virtual_costs = [], [], []
    mapping_n, mapping_e = {}, {}
    for i, (src, dst) in enumerate(edge_arr.T):
        prize_e = e_prizes_np[i]
        if prize_e <= cost_e:
            mapping_e[len(edges_list)] = i
            edges_list.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges += [(src, virtual_node_id), (virtual_node_id, dst)]
            virtual_costs += [0, 0]
            virtual_n_prizes.append(prize_e - cost_e)

    num_edges_normal = len(edges_list)
    prizes = np.concatenate([n_prizes.cpu().numpy(), np.array(virtual_n_prizes)])
    costs = np.array(costs + virtual_costs, dtype=np.float64)
    edges_final = np.array(edges_list + virtual_edges, dtype=np.int64).reshape(-1, 2)

    # Step 4: solve PCST and map the result back to original ids.
    vertices, edges_sol = pcst_fast(
        edges_final, prizes, costs,
        root, num_clusters, pruning, verbosity_level
    )

    selected_edges = [mapping_e[e] for e in edges_sol if e < num_edges_normal]
    selected_edges += [mapping_n[v] for v in vertices[vertices >= num_nodes]]
    # Explicit int64: np.array([]) would be float64, which torch rejects as an index.
    selected_edges = np.array(selected_edges, dtype=np.int64)

    edge_index_pcst = graph.edge_index[:, selected_edges]
    selected_nodes = np.unique(np.concatenate([
        vertices[vertices < num_nodes],
        edge_index_pcst[0].cpu().numpy(),
        edge_index_pcst[1].cpu().numpy(),
    ]))

    # Step 5: multi-hop expansion + threshold pruning on the original graph.
    adj_list = build_adjacency_list(num_nodes, graph.edge_index)
    adjacency_map = build_edge_map(graph.edge_index)
    final_node_set = multi_hop_search_and_prune_advanced(
        graph=graph,
        start_nodes=selected_nodes,
        q_emb=q_emb,
        max_hops=max_hops,
        top_n=2,
        gain_threshold=min_gain,
        adjacency_map=adjacency_map,
        adj_list=adj_list,
        alpha=0.5,
        beta=0.5,
        gamma=0,
    )

    # Step 6: subgraph induced by the final node set. Membership is tested
    # against a set (O(1)) instead of scanning a numpy array per edge.
    final_nodes = np.array(sorted(int(n) for n in final_node_set), dtype=np.int64)
    keep = set(final_nodes.tolist())
    adj_mask = np.array(
        [idx for idx, (s, d) in enumerate(edge_arr.T) if s in keep and d in keep],
        dtype=np.int64,
    )
    final_data = _relabelled_subgraph(graph, final_nodes, adj_mask)

    if textual_nodes is not None:
        textual_n_final = textual_nodes.iloc[final_nodes]
        textual_e_final = textual_edges.iloc[adj_mask]
        desc_final = textual_n_final.to_csv(index=False) + '\n' + \
            textual_e_final.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
    else:
        desc_final = None

    if final_data.num_nodes == 0:
        # Debug aid: show the PCST subgraph the expansion started from.
        print(_relabelled_subgraph(graph, selected_nodes, selected_edges))

    return final_data, desc_final


def build_edge_map(edge_index):
    """Map each edge's endpoint pair, in both directions, to its column in ``edge_index``.

    Args:
        edge_index: tensor of shape ``[2, num_edges]``; column ``i`` is the edge ``(src, dst)``.

    Returns:
        dict with ``(src, dst) -> i`` and ``(dst, src) -> i`` for every edge ``i``,
        so an edge can be found from either endpoint (treats the graph as undirected).
        When several columns share an endpoint pair (in either direction), the
        highest column index wins. Keys hold the NumPy integer scalars read from
        ``edge_index``.
    """
    columns = edge_index.cpu().numpy().T  # one (src, dst) row per edge
    return {
        pair: edge_id
        for edge_id, (u, v) in enumerate(columns)
        for pair in ((u, v), (v, u))
    }

def build_adjacency_list(num_nodes, edge_index):
    """Turn a PyG ``edge_index`` into an undirected adjacency list.

    Args:
        num_nodes: number of nodes; every id in ``edge_index`` must be below it.
        edge_index: tensor of shape ``[2, num_edges]``.

    Returns:
        list of ``num_nodes`` sets. Entry ``u`` holds every node joined to ``u``
        by an edge in either direction; a self-loop puts ``u`` in its own set.
    """
    neighbours = [set() for _ in range(num_nodes)]
    sources, targets = edge_index.cpu().numpy()
    for u, v in zip(sources, targets):
        neighbours[u].add(v)
        neighbours[v].add(u)
    return neighbours

def advanced_gain_function(
    graph,
    node_idx: int,
    q_emb: torch.Tensor,
    selected_nodes: set,
    adjacency_map: dict,
    adj_list: list,
    alpha: float = 0.6,
    beta: float = 0.3,
    gamma: float = 0.3
):
    """Score a candidate node for inclusion in the current selection.

    The score is ``alpha * node_sim + beta * submod_part + gamma * edge_sim_part``:

    * ``node_sim``: cosine similarity between ``graph.x[node_idx]`` and the query.
    * ``submod_part``: ``1.0`` when ``selected_nodes`` is empty. Otherwise it is
      ``sqrt(1 + relu(mean cosine similarity between the candidate and the
      selected nodes))``, a value in ``[1, sqrt(2)]``. Note this term grows with
      similarity to the selection, so it rewards closeness rather than novelty.
    * ``edge_sim_part``: mean cosine similarity between the query and
      ``graph.edge_attr`` for each edge from the candidate to a neighbour in
      ``adj_list`` that has an entry in ``adjacency_map``. It is ``0.0`` when
      no such edge exists.

    Args:
        graph: object exposing ``x`` (node embeddings) and ``edge_attr`` (edge embeddings).
        node_idx: candidate node id.
        q_emb: query embedding, ``[d]`` or ``[1, d]``; moved to ``graph.x``'s device.
        selected_nodes: ids already selected.
        adjacency_map: ``(src, dst) -> edge index`` lookup.
        adj_list: ``adj_list[u]`` is the set of neighbours of ``u``.
        alpha, beta, gamma: weights of the three terms.

    Returns:
        float gain.
    """
    cos = torch.nn.functional.cosine_similarity

    query = q_emb.unsqueeze(0) if q_emb.dim() == 1 else q_emb
    query = query.to(graph.x.device)

    candidate = graph.x[node_idx].unsqueeze(0)
    node_sim = cos(query, candidate, dim=-1).item()

    if len(selected_nodes) > 0:
        chosen = list(selected_nodes)
        pairwise = cos(candidate.expand(len(chosen), -1), graph.x[chosen], dim=-1)
        submod_part = torch.sqrt(1.0 + torch.relu(pairwise.mean())).item()
    else:
        submod_part = 1.0

    edge_sims = []
    for nbr in adj_list[node_idx]:
        e_idx = adjacency_map.get((node_idx, nbr))
        if e_idx is not None:
            edge_vec = graph.edge_attr[e_idx].unsqueeze(0)
            edge_sims.append(cos(query, edge_vec, dim=-1).item())
    edge_sim_part = float(np.mean(edge_sims)) if edge_sims else 0.0

    return alpha * node_sim + beta * submod_part + gamma * edge_sim_part

def multi_hop_search_and_prune_advanced(
    graph: Data,
    start_nodes: np.ndarray,
    q_emb: torch.Tensor,
    max_hops: int,
    top_n: int,
    gain_threshold: float,
    adjacency_map: dict,  # (src, dst) -> edge_idx
    adj_list: list,
    alpha: float = 0.6,
    beta: float = 0.3,
    gamma: float = 0.3
):
    """Grow a node set from ``start_nodes`` by gain-thresholded breadth-first expansion.

    Each dequeued node at depth < ``max_hops`` scans its unselected neighbours.
    A neighbour is scored with ``advanced_gain_function`` against the current
    selection and accepted when its gain is ``>= gain_threshold``. At most
    ``top_n`` neighbours are accepted per expanded node. Accepted neighbours are
    enqueued at ``depth + 1``.

    NOTE(review): ``top_n`` keeps the FIRST accepted neighbours in ``adj_list``
    set-iteration order, not the ``top_n`` highest-gain ones. Confirm this is
    intended.

    Args:
        graph: PyG ``Data`` exposing ``x`` and ``edge_attr``.
        start_nodes: array-like of seed node ids (e.g. the PCST selection).
        q_emb: query embedding, ``[d]`` or ``[1, d]``.
        max_hops: maximum expansion depth from the seeds.
        top_n: maximum neighbours accepted per expanded node.
        gain_threshold: minimum gain for a neighbour to be accepted.
        adjacency_map: ``(src, dst) -> edge index`` lookup, forwarded to the gain function.
        adj_list: ``adj_list[u]`` is the set of neighbours of ``u``.
        alpha, beta, gamma: gain weights, forwarded to ``advanced_gain_function``.

    Returns:
        set of node ids: the seeds plus every accepted neighbour.
    """
    from collections import deque

    device = graph.x.device
    if q_emb.dim() == 1:
        q_emb = q_emb.unsqueeze(0)
    q_emb = q_emb.to(device)

    start_nodes = np.asarray(start_nodes)
    # A single set is both the BFS "visited" set and the result. The two were
    # always updated together, so keeping one is equivalent.
    final_nodes = set(start_nodes.tolist())
    queue = deque((node, 0) for node in start_nodes)

    while queue:
        current_node, depth = queue.popleft()  # O(1), unlike list.pop(0)
        if depth >= max_hops:
            continue
        accepted = 0
        for nbr in adj_list[current_node]:
            if accepted >= top_n:
                # Further gains cannot be accepted and the gain function has no
                # side effects, so stop instead of scoring the remaining neighbours.
                break
            if nbr in final_nodes:
                continue
            g = advanced_gain_function(
                graph=graph,
                node_idx=nbr,
                q_emb=q_emb,
                selected_nodes=final_nodes,
                adjacency_map=adjacency_map,
                adj_list=adj_list,
                alpha=alpha,
                beta=beta,
                gamma=gamma
            )
            if g >= gain_threshold:
                final_nodes.add(nbr)
                queue.append((nbr, depth + 1))
                accepted += 1
    return final_nodes
