import torch
import networkx as nx
from torch_geometric.data import Data
from torch.nn.functional import cosine_similarity
import torch.nn.functional as F

def _mincut_scores(graph: Data, q_emb: torch.Tensor):
    """Return (node_scores, edge_scores): 1-D float32 cosine similarities to the query.

    Edge scores are all zeros when ``graph.edge_attr`` is missing or does not
    have one row per column of ``graph.edge_index``.
    """
    device = graph.x.device
    q = q_emb.to(device)
    if q.dim() == 1:
        q = q.unsqueeze(0)  # [1, d] so it broadcasts against [N, d]

    num_edges = graph.edge_index.shape[1]
    with torch.no_grad():
        # reshape(-1) instead of squeeze(0): with a single node, squeeze(0)
        # would produce a 0-d tensor that can no longer be indexed per node.
        # float() because torch.quantile only accepts float32/float64 input.
        node_scores = cosine_similarity(q, graph.x, dim=-1).reshape(-1).float()
        if graph.edge_attr is not None and graph.edge_attr.shape[0] == num_edges:
            edge_scores = cosine_similarity(q, graph.edge_attr, dim=-1).reshape(-1).float()
        else:
            edge_scores = torch.zeros(num_edges, device=device)
    return node_scores, edge_scores


def _build_flow_graph(node_scores, edge_scores, edge_pairs, num_nodes,
                      node_bias, edge_bias, source, sink):
    """Build the undirected capacity graph used for the S/T minimum cut.

    node_scores / edge_scores are Python float lists aligned with node ids and
    with ``edge_pairs`` (a list of ``[src, dst]`` int pairs).
    """
    base = 0.1  # baseline capacity so no terminal/graph edge has capacity 0
    eps = 1e-6  # floor for graph-edge capacities

    G = nx.Graph()
    G.add_nodes_from((source, sink))
    G.add_nodes_from(range(num_nodes))

    # Terminal edges: nodes at/above the threshold are tied to the source,
    # the rest to the sink. The farther a score is from the threshold, the
    # larger the capacity, i.e. the more expensive it is to flip that node.
    for i, score in enumerate(node_scores):
        if score >= node_bias:
            G.add_edge(source, i, capacity=(score - node_bias) + base)
        else:
            G.add_edge(i, sink, capacity=(node_bias - score) + base)

    # Graph edges: capacity is the price of separating the two endpoints.
    # Query-relevant edges (score above edge_bias) get a LARGE capacity so
    # the cut avoids separating them; irrelevant edges are cheap to cut.
    for (u, v), escore in zip(edge_pairs, edge_scores):
        cap = max(eps, (escore - edge_bias) + base)
        if G.has_edge(u, v):
            # nx.Graph stores one edge per node pair; parallel/reversed edges
            # in edge_index each cost something to cut, so sum their capacities.
            G[u][v]["capacity"] += cap
        else:
            G.add_edge(u, v, capacity=cap)
    return G


def _induced_subgraph(graph: Data, selected_nodes):
    """Return (subgraph, selected_edge_ids) for the nodes in sorted list ``selected_nodes``.

    Only edges with both endpoints selected are kept; node ids are remapped to
    0..len(selected_nodes)-1 in sorted order.
    """
    edge_index = graph.edge_index
    num_nodes = graph.num_nodes
    idx_device = edge_index.device

    sel = torch.tensor(selected_nodes, dtype=torch.long, device=idx_device)
    keep = torch.zeros(num_nodes, dtype=torch.bool, device=idx_device)
    keep[sel] = True

    edge_mask = keep[edge_index[0]] & keep[edge_index[1]]
    selected_edges = edge_mask.nonzero(as_tuple=True)[0]

    old_to_new = torch.full((num_nodes,), -1, dtype=torch.long, device=idx_device)
    old_to_new[sel] = torch.arange(sel.numel(), device=idx_device)
    new_edge_index = old_to_new[edge_index[:, selected_edges]]  # shape [2, k], k may be 0

    new_x = graph.x[sel.to(graph.x.device)]
    if graph.edge_attr is not None:
        new_edge_attr = graph.edge_attr[selected_edges.to(graph.edge_attr.device)]
    else:
        new_edge_attr = None

    subgraph = Data(
        x=new_x,
        edge_index=new_edge_index,
        edge_attr=new_edge_attr,
        num_nodes=len(selected_nodes),
    )
    return subgraph, selected_edges.tolist()


def retrievel_via_mincut(
    graph: Data,
    q_emb: torch.Tensor,
    textual_nodes,
    textual_edges,
    node_percentile,  # quantile in [0, 1] of node scores used as the keep/drop threshold
    edge_percentile,  # quantile in [0, 1] of edge scores used as the edge-capacity threshold
) -> "tuple[Data, str]":
    """Extract the subgraph most relevant to ``q_emb`` via an S/T minimum cut.

    Steps:
    1) Score every node and edge by cosine similarity to the query and derive
       thresholds ``node_positive_bias`` / ``edge_cost_bias`` as the requested
       quantiles of those scores.
    2) Attach nodes at/above the node threshold to a super source S and the
       rest to a super sink T.
    3) Give each graph edge a cut capacity that grows with its query
       similarity, so relevant edges are expensive to cut.
    4) Run a minimum cut and keep the nodes on the S side, together with the
       edges whose endpoints are both kept.

    Parameters
    ----------
    graph:           PyG ``Data`` with ``x``, ``edge_index``, optional ``edge_attr``, ``num_nodes``.
    q_emb:           query embedding, ``[d]`` or ``[1, d]``, same width as ``graph.x``
                     (and ``graph.edge_attr`` if present).
    textual_nodes:   table with ``len``/``iloc``/``to_csv`` (e.g. a pandas DataFrame).
                     Assumed: row i describes node i. TODO confirm with caller.
    textual_edges:   same kind of table with columns ``src``, ``edge_attr``, ``dst``.
                     Assumed: row j describes edge column j of ``edge_index``.
    node_percentile: quantile (0..1) for the node threshold; higher keeps fewer nodes.
    edge_percentile: quantile (0..1) for the edge threshold.

    Returns
    -------
    (subgraph, desc): the S-side PyG ``Data`` subgraph, and a CSV description
    made of the selected node rows, a newline, then the selected edge rows
    (columns ``src``, ``edge_attr``, ``dst``).

    If either textual table is empty, the whole graph and its full textual
    description are returned unchanged.
    """
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    # 1) Query similarities and quantile-based thresholds. Raise/lower the
    #    percentiles to shrink/grow the retrieved subgraph.
    node_scores, edge_scores = _mincut_scores(graph, q_emb)
    node_positive_bias = float(torch.quantile(node_scores, node_percentile)) if node_scores.numel() else 0.0
    edge_cost_bias = float(torch.quantile(edge_scores, edge_percentile)) if edge_scores.numel() else 0.0

    # 2-3) Capacity graph with super terminals. tolist() pulls scores to the
    #      host once instead of syncing per element.
    S, T = "super_source", "super_sink"
    G = _build_flow_graph(
        node_scores.tolist(),
        edge_scores.tolist(),
        graph.edge_index.t().tolist(),
        graph.num_nodes,
        node_positive_bias,
        edge_cost_bias,
        S,
        T,
    )

    # 4) Minimum cut; keep the source side (minus the super source itself).
    _, (set_s, _) = nx.minimum_cut(G, S, T, capacity='capacity')
    selected_nodes = sorted(set_s - {S})

    subgraph, selected_edges = _induced_subgraph(graph, selected_nodes)

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])

    return subgraph, desc