import json
import numpy as np
import networkx as nx
from collections import deque
import re

import torch
import torch.nn.functional as F

from src.utils import logger

# Mapping from raw entity ids to display names, loaded once at import time.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform default, and close the file handle deterministically.
with open("data/entities_names.json", encoding="utf-8") as _entity_names_file:
    ENTITY_NAME_MAP = json.load(_entity_names_file)

# BM25 implementation
class BM25:
    """Okapi BM25 scorer over a fixed, in-memory corpus of text documents.

    Documents and queries are lower-cased and split into ``\\w+`` tokens.
    Document lengths and the average document length are both measured in
    those tokens, so the length normalization is self-consistent.

    Args:
        corpus: Sized sequence of document strings.
        k1: Term-frequency saturation parameter.
        b: Document-length normalization strength.
        epsilon: Stored for interface compatibility; the scoring in this
            class does not read it.
    """

    def __init__(self, corpus, k1=1.5, b=0.75, epsilon=0.25):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        self.corpus_size = len(corpus)
        self.corpus = corpus
        # Filled in by _initialize(); stays 0.0 for an empty corpus.
        self.avg_doc_len = 0.0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []
        self._initialize()

    def _initialize(self):
        """Tokenize the corpus and compute term frequencies, lengths and IDF."""
        containing_docs = {}
        for document in self.corpus:
            terms = self._tokenize(document)
            self.doc_len.append(len(terms))

            term_freqs = {}
            for term in terms:
                term_freqs[term] = term_freqs.get(term, 0) + 1
            self.doc_freqs.append(term_freqs)

            # Document frequency: each document counts once per distinct term.
            for term in term_freqs:
                containing_docs[term] = containing_docs.get(term, 0) + 1

        # Use the same tokenization as doc_len; guard the empty corpus.
        if self.corpus_size:
            self.avg_doc_len = sum(self.doc_len) / self.corpus_size

        # The "+ 1.0" inside the log keeps every IDF value positive.
        self.idf = {
            term: np.log((self.corpus_size - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0)
            for term, doc_freq in containing_docs.items()
        }

    def _tokenize(self, text):
        """Lower-case ``text`` and return its ``\\w+`` tokens."""
        # Simple tokenizer; swap in a more elaborate one if needed.
        return re.findall(r'\w+', text.lower())

    def get_scores(self, query):
        """Return a 1-D array with the BM25 score of every corpus document.

        Query terms absent from the corpus are ignored; a term repeated in
        the query contributes once per occurrence.
        """
        scores = np.zeros(self.corpus_size)

        for term in self._tokenize(query):
            idf = self.idf.get(term)
            if idf is None:
                continue

            for doc_id, term_freqs in enumerate(self.doc_freqs):
                freq = term_freqs.get(term)
                if freq is None:
                    continue

                # A document containing the term has length >= 1, so
                # avg_doc_len is strictly positive whenever we get here.
                length_norm = 1 - self.b + self.b * self.doc_len[doc_id] / self.avg_doc_len
                scores[doc_id] += idf * freq * (self.k1 + 1) / (freq + self.k1 * length_norm)

        return scores

    def get_batch_scores(self, queries):
        """Return the BM25 scores of several queries, one row per query."""
        return np.array([self.get_scores(query) for query in queries])

def build_graph(graph: list) -> nx.DiGraph:
    """Turn a list of (head, relation, tail) triples into a directed graph.

    Head and tail are translated through ENTITY_NAME_MAP (falling back to the
    raw value). Each triple yields two edges whose ``relation`` attribute
    encodes the direction: ``--> r -->`` from head to tail and ``<-- r <--``
    from tail to head. A later triple on the same ordered node pair
    overwrites the earlier edge's attribute.
    """
    digraph = nx.DiGraph()
    for head, rel, tail in graph:
        head_name = ENTITY_NAME_MAP.get(head, head)
        tail_name = ENTITY_NAME_MAP.get(tail, tail)
        # Forward edge first, then its inverse.
        digraph.add_edge(head_name, tail_name, relation=f"--> {rel} -->")
        digraph.add_edge(tail_name, head_name, relation=f"<-- {rel} <--")
    return digraph

def build_vec_graph(graph: list, embed_model, use_agg=True, agg_method="mean") -> nx.DiGraph:
    """Build the graph for ``graph`` and attach an 'embedding' to nodes and edges.

    The graph is created with build_graph(). Every node name is encoded with
    ``embed_model.encode`` and stored under the node attribute 'embedding'.
    Edge embeddings, and any further node-embedding updates, depend on the
    arguments:

    * ``use_agg=False``: each edge gets the encoding of its relation string;
      node embeddings are left as encoded.
    * ``use_agg=True``: ``agg_method`` selects one of "mean", "shallow",
      "transE", "RotatE", "attention", "graphsage", "gcn", "rotate_agg" or
      "weighted_mean" (see the per-branch comments).

    NOTE(review): with ``use_agg=True`` and an ``agg_method`` that matches no
    branch, the if/elif chain falls through and edges end up with no
    'embedding' attribute at all.
    NOTE(review): the RotatE branches read ``.shape`` on the result of
    ``embed_model.encode`` - presumably a 2-D NumPy array with one row per
    input; confirm against the embedding model in use.

    Returns:
        The annotated nx.DiGraph.
    """
    G = build_graph(graph)

    nodes = list(G.nodes())
    node_embeddings = embed_model.encode(nodes)

    for node, embedding in zip(nodes, node_embeddings):
        G.nodes[node]['embedding'] = embedding

    # Edge embeddings
    edges = list(G.edges(data=True))

    if use_agg:

        if agg_method == "mean":
            # Edge text is head + relation + tail concatenated with no separator.
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

            # Read neighbor embeddings from a snapshot so updates made during
            # this loop do not feed into later nodes.
            G_copy = G.copy()
            for node in G.nodes():
                neighbors_embedding = [G_copy.nodes[neighbor]['embedding'] for neighbor in G_copy.neighbors(node)]
                if neighbors_embedding:  # only when the node has neighbors
                    G.nodes[node]['embedding'] = G.nodes[node]['embedding'] + np.mean(neighbors_embedding, axis=0)

        elif agg_method == "shallow":
            # Edge embeddings only; node embeddings stay as encoded.
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

        elif agg_method == "transE":
            # TransE idea: h + r ≈ t
            # First encode every distinct relation string
            relations = list(set([edge[2]['relation'] for edge in edges]))
            relation_embeddings = embed_model.encode(relations)
            relation_dict = {rel: emb for rel, emb in zip(relations, relation_embeddings)}

            # Store the relation embedding on each edge
            for edge in edges:
                G[edge[0]][edge[1]]['embedding'] = relation_dict[edge[2]['relation']]

            # TransE-style node embedding updates
            # NOTE(review): G_copy is assigned but never read in this branch.
            G_copy = G.copy()
            learning_rate = 0.01
            iterations = 5

            for _ in range(iterations):
                for edge in edges:
                    h, t, data = edge
                    rel_emb = relation_dict[data['relation']]

                    # Current head and tail embeddings (updates are applied
                    # in place, edge by edge, so order matters)
                    h_emb = G.nodes[h]['embedding'].copy()
                    t_emb = G.nodes[t]['embedding'].copy()

                    # TransE residual: h + r - t
                    predicted_t = h_emb + rel_emb
                    error = predicted_t - t_emb

                    # Simplified gradient step: move h and t toward each other
                    G.nodes[h]['embedding'] = h_emb - learning_rate * error
                    G.nodes[t]['embedding'] = t_emb + learning_rate * error

        elif agg_method == "RotatE":
            # RotatE: in complex space, h * r = t
            # Embeddings are reinterpreted as complex vectors below
            embed_dim = node_embeddings.shape[1]

            # The dimension must be even so it can be split into real/imag halves
            if embed_dim % 2 != 0:
                # Odd dimension: append one zero component
                for node in G.nodes():
                    G.nodes[node]['embedding'] = np.concatenate([G.nodes[node]['embedding'], [0.0]])
                embed_dim += 1

            # Encode every distinct relation string
            relations = list(set([edge[2]['relation'] for edge in edges]))
            relation_embeddings = embed_model.encode(relations)

            # Adjust relation embeddings if their dimension does not match
            if relation_embeddings.shape[1] != embed_dim:
                if relation_embeddings.shape[1] < embed_dim:
                    # Zero-pad up to embed_dim
                    padding = np.zeros((relation_embeddings.shape[0], embed_dim - relation_embeddings.shape[1]))
                    relation_embeddings = np.concatenate([relation_embeddings, padding], axis=1)
                else:
                    # Truncate down to embed_dim
                    relation_embeddings = relation_embeddings[:, :embed_dim]

            relation_dict = {rel: emb for rel, emb in zip(relations, relation_embeddings)}

            # Store the relation embedding on each edge
            for edge in edges:
                G[edge[0]][edge[1]]['embedding'] = relation_dict[edge[2]['relation']]

            # Real -> complex: first half is the real part, second half the imaginary part
            def to_complex(embedding):
                half_dim = len(embedding) // 2
                real_part = embedding[:half_dim]
                imag_part = embedding[half_dim:]
                return real_part + 1j * imag_part

            # Complex -> real: concatenate real and imaginary parts
            def from_complex(complex_emb):
                return np.concatenate([complex_emb.real, complex_emb.imag])

            # RotatE-style embedding updates
            learning_rate = 0.01
            iterations = 3

            for _ in range(iterations):
                for edge in edges:
                    h, t, data = edge

                    # Convert to complex form
                    h_complex = to_complex(G.nodes[h]['embedding'])
                    t_complex = to_complex(G.nodes[t]['embedding'])
                    rel_complex = to_complex(relation_dict[data['relation']])

                    # RotatE: h * r = t, where * is element-wise complex multiplication
                    predicted_t = h_complex * rel_complex
                    error_complex = predicted_t - t_complex

                    # Simplified gradient step
                    h_update = learning_rate * error_complex * np.conj(rel_complex)
                    t_update = learning_rate * error_complex

                    # Convert back to real form and store
                    G.nodes[h]['embedding'] = from_complex(h_complex - h_update)
                    G.nodes[t]['embedding'] = from_complex(t_complex + t_update)

        elif agg_method == "attention":
            # Attention-based aggregation
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

            # Compute attention weights and aggregate (reads from a snapshot)
            G_copy = G.copy()
            for node in G.nodes():
                neighbors = list(G_copy.neighbors(node))
                if not neighbors:
                    continue

                node_emb = G_copy.nodes[node]['embedding']
                neighbor_embs = []
                attention_scores = []

                for neighbor in neighbors:
                    neighbor_emb = G_copy.nodes[neighbor]['embedding']
                    edge_emb = G_copy[node][neighbor]['embedding']

                    # Simple attention score: dot product with (neighbor + edge)
                    attention_score = np.dot(node_emb, neighbor_emb + edge_emb)
                    attention_scores.append(attention_score)
                    neighbor_embs.append(neighbor_emb)

                # Softmax normalization
                attention_scores = np.array(attention_scores)
                attention_weights = F.softmax(torch.tensor(attention_scores), dim=0).numpy()

                # Weighted sum of neighbor embeddings, added to the node's own embedding
                aggregated = np.sum([w * emb for w, emb in zip(attention_weights, neighbor_embs)], axis=0)
                G.nodes[node]['embedding'] = G.nodes[node]['embedding'] + aggregated

        elif agg_method == "graphsage":
            # GraphSAGE-style sampling and aggregation
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

            G_copy = G.copy()
            max_neighbors = 10  # maximum number of sampled neighbors

            for node in G.nodes():
                neighbors = list(G_copy.neighbors(node))
                if not neighbors:
                    continue

                # Sample neighbors
                # NOTE(review): np.random.choice draws from NumPy's global
                # random state, so results vary between runs unless seeded.
                if len(neighbors) > max_neighbors:
                    sampled_neighbors = np.random.choice(neighbors, max_neighbors, replace=False)
                else:
                    sampled_neighbors = neighbors

                # Gather the sampled neighbors' embeddings
                neighbor_embs = [G_copy.nodes[neighbor]['embedding'] for neighbor in sampled_neighbors]

                # Mean aggregation (an LSTM or another aggregator could be used instead)
                if neighbor_embs:
                    aggregated = np.mean(neighbor_embs, axis=0)
                    # Linear combination instead of concatenation keeps the dimension unchanged
                    node_emb = G.nodes[node]['embedding']
                    # Fixed 0.7 / 0.3 mix of own and aggregated embedding
                    G.nodes[node]['embedding'] = 0.7 * node_emb + 0.3 * aggregated

        elif agg_method == "gcn":
            # GCN-style aggregation
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

            G_copy = G.copy()
            for node in G.nodes():
                neighbors = list(G_copy.neighbors(node))
                if not neighbors:
                    continue

                # GCN aggregation includes the node itself
                all_nodes = [node] + neighbors
                node_embs = [G_copy.nodes[n]['embedding'] for n in all_nodes]

                # Degree normalization
                degree = len(neighbors) + 1  # +1 for self-loop
                normalized_embs = [emb / np.sqrt(degree) for emb in node_embs]

                # Aggregate; the sum replaces the node embedding
                aggregated = np.sum(normalized_embs, axis=0)
                G.nodes[node]['embedding'] = aggregated

        elif agg_method == "rotate_agg":
            # RotatE + layered aggregation: RotatE updates in complex space
            # (h * r ≈ t) combined with several rounds of neighbor aggregation

            # First encode every distinct relation string
            relations = list(set([edge[2]['relation'] for edge in edges]))
            relation_embeddings = embed_model.encode(relations)
            relation_dict = {rel: emb for rel, emb in zip(relations, relation_embeddings)}

            # Store the relation embedding on each edge
            # NOTE(review): this assignment and the relation_dict above are
            # overwritten after the dimension adjustment below.
            for edge in edges:
                G[edge[0]][edge[1]]['embedding'] = relation_dict[edge[2]['relation']]

            # The dimension must be even so it can be split into real/imag halves
            embed_dim = node_embeddings.shape[1]
            if embed_dim % 2 != 0:
                # Odd dimension: append one zero component
                for node in G.nodes():
                    G.nodes[node]['embedding'] = np.concatenate([G.nodes[node]['embedding'], [0.0]])
                embed_dim += 1

            # Pad or truncate relation embeddings to embed_dim
            if relation_embeddings.shape[1] != embed_dim:
                if relation_embeddings.shape[1] < embed_dim:
                    padding = np.zeros((relation_embeddings.shape[0], embed_dim - relation_embeddings.shape[1]))
                    relation_embeddings = np.concatenate([relation_embeddings, padding], axis=1)
                else:
                    relation_embeddings = relation_embeddings[:, :embed_dim]

            relation_dict = {rel: emb for rel, emb in zip(relations, relation_embeddings)}

            # Refresh the edge embeddings with the adjusted relation embeddings
            for edge in edges:
                G[edge[0]][edge[1]]['embedding'] = relation_dict[edge[2]['relation']]

            # Real <-> complex conversion helpers (first half real, second half imaginary)
            def to_complex(embedding):
                half_dim = len(embedding) // 2
                real_part = embedding[:half_dim]
                imag_part = embedding[half_dim:]
                return real_part + 1j * imag_part

            def from_complex(complex_emb):
                return np.concatenate([complex_emb.real, complex_emb.imag])

            # Layered aggregation parameters
            num_layers = 3  # number of aggregation layers
            learning_rate = 0.01
            layer_iterations = 2  # RotatE passes per layer

            # Keep the original embeddings for the residual connection
            original_embeddings = {}
            for node in G.nodes():
                original_embeddings[node] = G.nodes[node]['embedding'].copy()

            # Multi-layer aggregation
            for layer in range(num_layers):
                logger.info(f"Processing RotatE aggregation layer {layer + 1}/{num_layers}")

                # Learning-rate decay for the current layer
                current_lr = learning_rate * (0.8 ** layer)

                # Step 1: RotatE-style embedding optimization
                for iteration in range(layer_iterations):
                    for edge in edges:
                        h, t, data = edge

                        # Convert to complex form
                        h_complex = to_complex(G.nodes[h]['embedding'])
                        t_complex = to_complex(G.nodes[t]['embedding'])
                        rel_complex = to_complex(relation_dict[data['relation']])

                        # RotatE: h * r = t
                        predicted_t = h_complex * rel_complex
                        error_complex = predicted_t - t_complex

                        # Gradient step
                        h_update = current_lr * error_complex * np.conj(rel_complex)
                        t_update = current_lr * error_complex

                        # Convert back to real form and store
                        G.nodes[h]['embedding'] = from_complex(h_complex - h_update)
                        G.nodes[t]['embedding'] = from_complex(t_complex + t_update)

                # Step 2: attention-based neighbor aggregation
                G_snapshot = G.copy()  # freeze the current state for aggregation

                for node in G.nodes():
                    neighbors = list(G_snapshot.neighbors(node))
                    if not neighbors:
                        continue

                    node_emb = G_snapshot.nodes[node]['embedding']
                    neighbor_info = []
                    attention_scores = []

                    # Collect neighbor information
                    for neighbor in neighbors:
                        neighbor_emb = G_snapshot.nodes[neighbor]['embedding']
                        edge_emb = G_snapshot[node][neighbor]['embedding']

                        # Combine neighbor and relation information
                        combined_emb = neighbor_emb + edge_emb
                        neighbor_info.append(combined_emb)

                        # Attention score computed in the complex domain
                        node_complex = to_complex(node_emb)
                        neighbor_complex = to_complex(combined_emb)

                        # Real part of the Hermitian inner product as similarity
                        attention = np.real(np.sum(node_complex * np.conj(neighbor_complex)))
                        attention_scores.append(attention)

                    # Softmax-normalize the attention weights
                    if attention_scores:
                        attention_scores = np.array(attention_scores)
                        # Temperature controls the sharpness of the attention
                        temperature = 0.1 / (layer + 1)  # decreases with depth
                        attention_weights = F.softmax(torch.tensor(attention_scores / temperature), dim=0).numpy()

                        # Weighted sum of the (neighbor + edge) vectors
                        aggregated = np.sum([w * info for w, info in zip(attention_weights, neighbor_info)], axis=0)

                        # Layer fusion: original embedding, current embedding and aggregate
                        alpha = 0.3  # weight of the original embedding
                        beta = 0.4   # weight of the current embedding
                        gamma = 0.3  # weight of the aggregated information

                        # Renormalize so the weights sum to 1
                        total_weight = alpha + beta + gamma
                        alpha, beta, gamma = alpha/total_weight, beta/total_weight, gamma/total_weight

                        G.nodes[node]['embedding'] = (
                            alpha * original_embeddings[node] +
                            beta * node_emb +
                            gamma * aggregated
                        )

                # Step 3: between-layer normalization
                if layer < num_layers - 1:  # skipped after the last layer
                    for node in G.nodes():
                        # Rescale to the L2 norm of the node's original embedding
                        norm = np.linalg.norm(G.nodes[node]['embedding'])
                        if norm > 0:
                            G.nodes[node]['embedding'] = G.nodes[node]['embedding'] / norm * np.linalg.norm(original_embeddings[node])

        elif agg_method == "weighted_mean":
            # Weighted-mean aggregation with degree-based weights
            edge_keys = [edge[0] + edge[2]['relation'] + edge[1] for edge in edges]
            edge_embeddings = embed_model.encode(edge_keys)

            for edge, embedding in zip(edges, edge_embeddings):
                G[edge[0]][edge[1]]['embedding'] = embedding

            G_copy = G.copy()
            for node in G.nodes():
                neighbors = list(G_copy.neighbors(node))
                if not neighbors:
                    continue

                node_emb = G_copy.nodes[node]['embedding']
                neighbor_embs = [G_copy.nodes[neighbor]['embedding'] for neighbor in neighbors]

                # Weight each neighbor by 1 / sqrt(its neighbor count + 1), then normalize
                degrees = [len(list(G_copy.neighbors(neighbor))) + 1 for neighbor in neighbors]
                weights = [1.0 / np.sqrt(degree) for degree in degrees]
                total_weight = sum(weights)
                normalized_weights = [w / total_weight for w in weights]

                # Weighted aggregation, mixed 0.6 / 0.4 with the node's own embedding
                aggregated = np.sum([w * emb for w, emb in zip(normalized_weights, neighbor_embs)], axis=0)
                G.nodes[node]['embedding'] = 0.6 * node_emb + 0.4 * aggregated

    else:
        # No aggregation: each edge is embedded from its relation string alone
        edge_keys = [edge[2]['relation'] for edge in edges]
        edge_embeddings = embed_model.encode(edge_keys)

        for edge, embedding in zip(edges, edge_embeddings):
            G[edge[0]][edge[1]]['embedding'] = embedding

    return G


def summary(ins, with_graph=False):
    """Summarize one dataset instance as a flat dictionary.

    Args:
        ins: Mapping with the keys "id", "question", "graph", "q_entity",
            "a_entity", "choices" and "answer"; an optional "valid_paths"
            entry is used as the truth paths when present.
        with_graph: When true, the built graph is returned under "graph";
            otherwise that entry is None.

    Returns:
        A dict with the mapped question/answer entities, truth paths, graph
        size statistics and ``answer_not_in_graph``, which is true when at
        least one answer is not a node of the graph.
    """
    G = build_graph(ins["graph"])
    q_entity = [ENTITY_NAME_MAP.get(q, q) for q in ins["q_entity"]]
    a_entity = [ENTITY_NAME_MAP.get(a, a) for a in ins["a_entity"]]
    truth_paths = ins["valid_paths"] if "valid_paths" in ins else get_truth_paths(q_entity, a_entity, G)
    mean_path_len = np.mean([len(p) for p in truth_paths]) if truth_paths else 0
    # Graph nodes are stored under their mapped names, so the membership
    # check must use the mapped answers as well, not the raw ids.
    answer = [ENTITY_NAME_MAP.get(a, a) for a in ins["answer"]]

    return {
        "id": ins["id"],
        "question": ins["question"],
        "q_entity": q_entity,
        "a_entity": a_entity,
        "choices": ins["choices"],
        "truth_paths": truth_paths,
        "answer": answer,
        "nodes": G.number_of_nodes(),
        "edges": G.number_of_edges(),
        "paths": len(truth_paths),
        "mean_path_len": mean_path_len,
        "graph": G if with_graph else None,
        "answer_not_in_graph": any(a not in G for a in answer),
    }


def _align_embedding_dims(query_embeds, graph_embeds, kind):
    """Return (query_embeds, graph_embeds) adjusted to the same feature dimension."""
    query_dim = query_embeds.shape[1]
    graph_dim = graph_embeds.shape[1]
    if query_dim == graph_dim:
        return query_embeds, graph_embeds

    logger.warning(f"Dimension mismatch: query {kind}s {query_dim}D vs graph {kind}s {graph_dim}D")
    if query_dim > graph_dim:
        # Query is wider: truncate it
        query_embeds = query_embeds[:, :graph_dim]
        logger.info(f"Truncated query {kind} embeddings to {graph_dim}D")
    elif graph_dim == 2 * query_dim:
        # Graph side is exactly doubled: keep its first half
        graph_embeds = graph_embeds[:, :query_dim]
        logger.info(f"Truncated graph {kind} embeddings to {query_dim}D")
    else:
        # Otherwise zero-pad the query (same dtype/device as the query)
        padding = query_embeds.new_zeros(query_embeds.shape[0], graph_dim - query_dim)
        query_embeds = torch.cat([query_embeds, padding], dim=1)
        logger.info(f"Padded query {kind} embeddings to {graph_dim}D")
    return query_embeds, graph_embeds


def _top_k_positive(score_rows, candidates, top_k):
    """Collect, for every score row, its top_k candidates whose score is > 0."""
    picked = []
    picked_indices = set()
    if top_k <= 0:
        # A slice like [-0:] would select everything, so handle "none" explicitly.
        return picked, picked_indices
    for row in score_rows:
        for idx in np.argsort(row)[-top_k:][::-1]:
            if row[idx] > 0:
                picked.append(candidates[idx])
                picked_indices.add(idx)
    return picked, picked_indices


def _bm25_supplement(queries, corpus, candidates, taken_indices, quota, bm25_params):
    """Pick up to quota extra candidates per query by BM25 score, skipping taken_indices."""
    extras = []
    score_rows = BM25(corpus, **bm25_params).get_batch_scores(queries)
    for row in score_rows:
        added = 0
        for idx in np.argsort(row)[::-1]:
            if row[idx] <= 0:
                # Sorted in descending order: nothing relevant is left.
                break
            if idx in taken_indices:
                continue
            extras.append(candidates[idx])
            added += 1
            if added >= quota:
                break
    return extras


def get_relevants(wholeG,
                    nodes=None,
                    edges=None,
                    embeddings=None,
                    node_embeds=None,
                    edge_embeds=None,
                    node_count:float=0.1,
                    edge_count:float=0.1,
                    use_bm25:bool=False,
                    bm25_params:dict=None,
                    bm25_supplement_ratio:float=0.2):
    """Select the graph nodes and edges most relevant to a set of queries.

    Relevance is cosine similarity between the query embeddings and the
    'embedding' attribute of every node/edge of ``wholeG``. Optionally a
    share of the budget is filled by BM25 text matching instead.

    Args:
        wholeG: Graph whose nodes and edges all carry an 'embedding'
            attribute (edges also a 'relation' attribute when BM25 is used).
        nodes: Query texts for BM25 node matching (presumably a list of
            strings - confirm against the caller).
        edges: Query texts for BM25 edge matching (same caveat).
        embeddings: Query embeddings used for both nodes and edges; takes
            precedence over ``node_embeds`` / ``edge_embeds``.
        node_embeds: Query embeddings matched against graph nodes.
        edge_embeds: Query embeddings matched against graph edges.
        node_count: Per-query node budget. A float below 1 is a fraction of
            the number of graph nodes; any other value is truncated to an
            int. A budget of 0 selects nothing.
        edge_count: Per-query edge budget, same convention.
        use_bm25: Reserve part of each budget for BM25 matches.
        bm25_params: Keyword arguments for BM25; defaults to k1=1.5, b=0.75.
        bm25_supplement_ratio: Fraction of each budget given to BM25.

    Returns:
        ``(most_relevant_nodes, most_relevant_edges)`` - lists of
        ``(name, data)`` and ``(src, dst, data)`` tuples taken from the
        graph. Entries selected for several queries appear once per query.

    Raises:
        AssertionError: If neither ``embeddings`` nor both ``node_embeds``
            and ``edge_embeds`` are given.
    """
    # Cosine matching is mandatory, so some query embedding must be present
    assert embeddings is not None or (node_embeds is not None and edge_embeds is not None), \
        "For cosine method, either embeddings or both node_embeds and edge_embeds must be provided."

    # BM25 needs query text for nodes or edges; otherwise fall back to cosine only
    if use_bm25 and nodes is None and edges is None:
        logger.warning("BM25 is enabled but no nodes or edges text provided. Will only use cosine similarity.")
        use_bm25 = False

    graph_nodes = list(wholeG.nodes(data=True))
    graph_edges = list(wholeG.edges(data=True))

    # Resolve fractional budgets and make sure the budgets are ints
    if isinstance(node_count, float) and node_count < 1:
        node_count = node_count * len(graph_nodes)
    if isinstance(edge_count, float) and edge_count < 1:
        edge_count = edge_count * len(graph_edges)
    node_count = int(node_count)
    edge_count = int(edge_count)

    # A shared `embeddings` overrides the specific ones; warn only when the
    # caller really passed both.
    if embeddings is not None:
        if node_embeds is not None or edge_embeds is not None:
            logger.warning("node_embeds or edge_embeds is not None, but embeddings is also provided.")
        node_embeds = embeddings
        edge_embeds = embeddings

    # Convert query embeddings to tensors when they are not tensors already
    if not isinstance(node_embeds, torch.Tensor):
        node_embeds = torch.tensor(node_embeds, dtype=torch.float32)
    if not isinstance(edge_embeds, torch.Tensor):
        edge_embeds = torch.tensor(edge_embeds, dtype=torch.float32)

    # Stack the graph-side embeddings (via np.array, which is much faster
    # than building a tensor from a list of arrays)
    node_graph_embeds = torch.tensor(
        np.array([entry[1]['embedding'] for entry in graph_nodes]),
        dtype=torch.float32
    )
    edge_graph_embeds = torch.tensor(
        np.array([entry[2]['embedding'] for entry in graph_edges]),
        dtype=torch.float32
    )

    # Reconcile dimension mismatches between query and graph embeddings
    node_embeds, node_graph_embeds = _align_embedding_dims(node_embeds, node_graph_embeds, "node")
    edge_embeds, edge_graph_embeds = _align_embedding_dims(edge_embeds, edge_graph_embeds, "edge")

    # Cosine similarity = dot product of L2-normalized vectors
    node_graph_embeds = F.normalize(node_graph_embeds, p=2, dim=1)
    edge_graph_embeds = F.normalize(edge_graph_embeds, p=2, dim=1)
    node_embeds = F.normalize(node_embeds, p=2, dim=1)
    edge_embeds = F.normalize(edge_embeds, p=2, dim=1)

    node_scores = torch.matmul(node_embeds, node_graph_embeds.T).detach().cpu().numpy()
    edge_scores = torch.matmul(edge_embeds, edge_graph_embeds.T).detach().cpu().numpy()

    # Split each budget between cosine and BM25
    if use_bm25:
        cosine_node_count = int(node_count * (1 - bm25_supplement_ratio))
        cosine_edge_count = int(edge_count * (1 - bm25_supplement_ratio))
        bm25_node_count = node_count - cosine_node_count
        bm25_edge_count = edge_count - cosine_edge_count
    else:
        cosine_node_count = node_count
        cosine_edge_count = edge_count

    # Step 1: base selection by cosine similarity
    most_relevant_nodes, cosine_node_indices = _top_k_positive(node_scores, graph_nodes, cosine_node_count)
    most_relevant_edges, cosine_edge_indices = _top_k_positive(edge_scores, graph_edges, cosine_edge_count)

    # Step 2: BM25 supplements, skipping what cosine already selected
    if use_bm25:
        bm25_params = bm25_params or {"k1": 1.5, "b": 0.75}

        if nodes is not None and bm25_node_count > 0:
            # Corpus: the name of every graph node
            node_corpus = [node_name for node_name, _ in graph_nodes]
            most_relevant_nodes.extend(
                _bm25_supplement(nodes, node_corpus, graph_nodes,
                                 cosine_node_indices, bm25_node_count, bm25_params)
            )

        if edges is not None and bm25_edge_count > 0:
            # Corpus: "<src> <relation> <dst>" for every graph edge
            edge_corpus = [f"{src} {data['relation']} {dst}" for src, dst, data in graph_edges]
            most_relevant_edges.extend(
                _bm25_supplement(edges, edge_corpus, graph_edges,
                                 cosine_edge_indices, bm25_edge_count, bm25_params)
            )

    return most_relevant_nodes, most_relevant_edges


def build_subgraph(wholeG, most_relevant_nodes, most_relevant_edges):
    """Assemble a new, mutable directed graph from the selected nodes and edges.

    Nodes are ``(name, attrs)`` pairs and are copied only when ``name`` is
    present in ``wholeG``. Edges are ``(src, dst, attrs)`` triples and are
    always added, with their attributes.
    """
    result = nx.DiGraph()

    # Nodes first, restricted to those known to the full graph
    known_entries = (entry for entry in most_relevant_nodes if entry[0] in wholeG)
    for entry in known_entries:
        name, attrs = entry[0], entry[1]
        result.add_node(name, **attrs)

    # Then every selected edge
    for entry in most_relevant_edges:
        src, dst, attrs = entry[0], entry[1], entry[2]
        result.add_edge(src, dst, **attrs)

    return result

def fill_subgraph(subgraph, wholeG, ins):
    """Connect stranded subgraph nodes to the question entity using ``wholeG``.

    A node is stranded when ``subgraph`` has no directed path from it to the
    first question entity. For each such node, the edges (with attributes) of
    every shortest path in ``wholeG`` from the question entity to the node
    are added to ``subgraph``.

    Args:
        subgraph: Graph to extend; modified in place.
        wholeG: Full graph the subgraph was drawn from.
        ins: Instance whose ``ins["q_entity"][0]`` names the question entity
            (translated through ENTITY_NAME_MAP).

    Returns:
        The same ``subgraph`` object.

    Raises:
        IndexError: If the question entity is not a node of ``wholeG``.
    """
    q_entity = ENTITY_NAME_MAP.get(ins["q_entity"][0], ins["q_entity"][0])
    if q_entity not in wholeG:
        raise IndexError(f"q_entity {q_entity!r} is not a node of wholeG")

    if q_entity in subgraph:
        # One traversal finds every node that can reach the question entity.
        connected = nx.ancestors(subgraph, q_entity)
        connected.add(q_entity)
        isolated_nodes = [node for node in subgraph.nodes if node not in connected]
    else:
        # Path queries against a missing node raise NodeNotFound, so treat
        # every current node as stranded.
        isolated_nodes = list(subgraph.nodes)

    for node in isolated_nodes:
        try:
            # all_shortest_paths is lazy, so iterate inside the try block.
            for path in nx.all_shortest_paths(wholeG, source=q_entity, target=node):
                for src, dst in zip(path, path[1:]):
                    subgraph.add_edge(src, dst, **wholeG[src][dst])
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            logger.warning(f"Warning: No path found from node {node} to q_entity {q_entity}")

    return subgraph


def bfs_with_rule(graph, start_node, target_rule, max_p = 10):
    """Breadth-first search for paths whose relations spell out ``target_rule``.

    Starting at ``start_node``, an edge is followed only when its 'relation'
    attribute equals the rule entry at the current depth. Every path that
    reaches the full rule length is returned as a list of
    ``(node, relation, neighbor)`` triples. ``max_p`` is accepted but not
    used: no limit is applied to the number of returned paths.
    """
    rule_len = len(target_rule)
    matched = []
    frontier = deque([(start_node, [])])  # (node to expand, path that led to it)

    while frontier:
        node, path = frontier.popleft()
        depth = len(path)

        # A path as long as the rule is complete; it is not expanded further.
        if depth == rule_len:
            matched.append(path)
            continue

        # Nothing to expand from nodes the graph does not know.
        if depth > rule_len or node not in graph:
            continue

        wanted = target_rule[depth]
        for nxt in graph.neighbors(node):
            rel = graph[node][nxt]['relation']
            # Prune branches whose relation differs from the rule at this depth.
            if rel == wanted:
                frontier.append((nxt, path + [(node, rel, nxt)]))

    return matched

def merge_similar_paths(triples_list):
    """Merge paths that share a head entity and relation sequence.

    Paths are grouped by ``(head, relation_1, ..., relation_n)``. Within a
    group, the tail entities found at each hop are de-duplicated, sorted and
    joined with ``" / "``; the joined tails of one hop become the head of the
    next hop. Groups are emitted in order of first appearance.

    Grouping on the full relation sequence means paths of any length are
    merged hop by hop. Results for one- and two-triple paths are the same as
    when only the first two relations were considered.

    Args:
        triples_list: Iterable of paths, each a sequence of
            ``(head, relation, tail)`` triples with string entities.
            Empty paths are skipped.

    Returns:
        A list of paths. A group with a single member yields that original
        path object unchanged; larger groups yield a new list of merged
        triples.
    """
    path_groups = {}
    for path in triples_list:
        if not path:  # Skip empty paths
            continue
        # Key on the start entity plus every relation along the path, so only
        # structurally identical paths end up in the same group.
        path_key = (path[0][0],) + tuple(triple[1] for triple in path)
        path_groups.setdefault(path_key, []).append(path)

    merged_paths = []
    for paths in path_groups.values():
        if len(paths) == 1:
            merged_paths.append(paths[0])
            continue

        template = paths[0]
        head = template[0][0]
        merged_path = []
        for position in range(len(template)):
            tails = " / ".join(sorted({p[position][2] for p in paths}))
            merged_path.append((head, template[position][1], tails))
            # The merged tails of this hop are the head of the next hop.
            head = tails
        merged_paths.append(merged_path)

    return merged_paths

def get_truth_paths(q_entity: list, a_entity: list, graph: nx.Graph) -> list:
    """Get the shortest paths connecting question and answer entities.

    Args:
        q_entity: Candidate start nodes; entries missing from ``graph`` are
            skipped.
        a_entity: Candidate end nodes; entries missing from ``graph`` are
            skipped.
        graph: Graph whose edges carry a ``'relation'`` attribute.

    Returns:
        A list of paths, each a list of ``(head, relation, tail)`` triples.
        Pairs with no connecting path contribute nothing.
    """
    # Select paths
    paths = []
    for h in q_entity:
        if h not in graph:
            continue
        for t in a_entity:
            if t not in graph:
                continue
            try:
                paths.extend(nx.all_shortest_paths(graph, h, t))
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # An unreachable pair is expected; anything else should surface
                # instead of being swallowed by a bare except.
                continue

    # Add relation to paths
    return [
        [(u, graph[u][v]['relation'], v) for u, v in zip(p, p[1:])]
        for p in paths
    ]

def path_to_str(q_entity, paths, with_deco=True):
    """Render each path as a single string that starts with the question entity.

    Args:
        q_entity: Start entity, or a list whose first element is used.
        paths: Iterable of paths, each a sequence of
            ``(head, relation, tail)`` triples.
        with_deco: When true, each hop is written with direction arrows
            (``" --> rel --> tail"`` or ``" <-- rel <-- head"``). The direction
            is chosen by checking which end of the triple matches the end of
            the text built so far. When false, each hop is appended as
            ``" rel tail"`` with no direction check.

    Returns:
        A list with one rendered string per input path.

    Raises:
        ValueError: With ``with_deco``, if neither end of a triple matches the
            end of the text built so far.
        AssertionError: With ``with_deco``, if a relation contains an arrow
            marker.
    """
    if isinstance(q_entity, list):
        q_entity = q_entity[0]  # only the first question entity is used

    def render(start, path):
        text = start
        for triple in path:
            if not with_deco:
                text += f" {triple[1]} {triple[2]}"
                continue

            assert "-->" not in triple[1] and "<--" not in triple[1]
            # Decide the direction from whichever end of the triple the
            # current text ends with.
            if triple[0] == text[-len(triple[0]):]:
                hop = f" --> {triple[1]} --> {triple[2]}"
            elif triple[2] == text[-len(triple[2]):]:
                hop = f" <-- {triple[1]} <-- {triple[0]}"
            else:
                raise ValueError(f"Path: {text}, Triple: {triple}")
            text += hop
        return text

    return [render(q_entity, path) for path in paths]

def path_to_str_with_graph(q_entities, paths, G):
    """Render paths as strings, taking relation names from ``G``.

    For each path the question entities are tried in order. The first one for
    which rendering succeeds produces the string for that path. Failures are
    printed and the next entity is tried. A path that fails for every entity
    is dropped.

    Args:
        q_entities: Iterable of candidate start entities.
        paths: Iterable of paths, each a sequence of triples. Only positions
            0 and 2 (the two endpoints) are read; the relation is looked up as
            ``G[u][v]['relation']``.
        G: Mapping-like graph supporting ``G[u][v]['relation']``.

    Returns:
        A non-empty list of rendered strings.

    Raises:
        AssertionError: If no path could be rendered at all.
    """
    def render(start, path):
        # TODO: something feels off here. With several q_entities, each
        # q_entity would yield a different path string.
        text = start
        for triple in path:
            # Decide the direction from whichever endpoint the current text
            # ends with, then append the relation and the other endpoint.
            if triple[0] == text[-len(triple[0]):]:
                relation = G[triple[0]][triple[2]]['relation']
                reached = triple[2]
            elif triple[2] == text[-len(triple[2]):]:
                relation = G[triple[2]][triple[0]]['relation']
                reached = triple[0]
            else:
                raise ValueError(f"Path: {text}, Triple: {triple}")
            text += f" {relation} {reached}"
        return text

    rendered = []
    for path in paths:
        for candidate in q_entities:
            try:
                rendered.append(render(candidate, path))
            except Exception as e:
                print(e)
            else:
                # First entity that works wins; move on to the next path.
                break

    assert rendered
    return rendered

def get_simple_paths(q_entity: list, a_entity: list, graph: nx.Graph, hop=2) -> list:
    """Get all simple paths connecting question and answer entities within ``hop``.

    Args:
        q_entity: Candidate start nodes; entries missing from ``graph`` are
            skipped.
        a_entity: Candidate end nodes; entries missing from ``graph`` are
            skipped.
        graph: Graph whose edges carry a ``'relation'`` attribute.
        hop: Passed as ``cutoff`` to ``nx.all_simple_edge_paths``.

    Returns:
        A list of paths, each a list of ``(head, relation, tail)`` triples.
    """
    # Select paths
    edge_paths = []
    for h in q_entity:
        if h not in graph:
            continue
        for t in a_entity:
            if t not in graph:
                continue
            try:
                edge_paths.extend(nx.all_simple_edge_paths(graph, h, t, cutoff=hop))
            except (nx.NetworkXNoPath, nx.NodeNotFound):
                # Only networkx lookup failures are tolerated; a bare except
                # would also hide programming errors and KeyboardInterrupt.
                continue

    # Add relation to paths
    return [
        [(e[0], graph[e[0]][e[1]]['relation'], e[1]) for e in p]
        for p in edge_paths
    ]

def get_negative_paths(q_entity: list, a_entity: list, graph: nx.Graph, n_neg: int, hop=2) -> list:
    """Sample random walks from the question entities that do not end on an answer.

    Nodes are addressed by their position in ``list(graph.nodes())`` when
    calling ``walker.random_walks``; the returned walks are treated as
    sequences of such positions and mapped back to node objects.

    Args:
        q_entity: Start entities; those missing from ``graph`` are ignored.
        a_entity: Answer entities; walks whose last node is one of them are
            discarded. Entities missing from ``graph`` are ignored.
        graph: Graph whose edges carry a ``'relation'`` attribute.
        n_neg: Forwarded as ``n_walks`` to ``walker.random_walks``.
        hop: Forwarded as ``walk_len`` to ``walker.random_walks``.

    Returns:
        A list of paths, each a list of ``(head, relation, tail)`` triples.
    """
    # Translate entities into positional indices.
    node_idx = list(graph.nodes())
    start_nodes = [node_idx.index(h) for h in q_entity if h in graph]
    end_nodes = [node_idx.index(t) for t in a_entity if t in graph]

    import walker
    walks = walker.random_walks(graph, n_walks=n_neg, walk_len=hop, start_nodes=start_nodes, verbose=False)

    # Add relation to paths
    result_paths = []
    for walk in walks:
        # Drop walks that end on an answer entity.
        if walk[-1] in end_nodes:
            continue
        visited = [node_idx[i] for i in walk]
        result_paths.append(
            [(u, graph[u][v]['relation'], v) for u, v in zip(visited, visited[1:])]
        )
    return result_paths

def get_random_paths(q_entity: list, graph: nx.Graph, n=3, hop=2) -> tuple [list, list]:
    """Sample random walks from the question entities, with their relation rules.

    Nodes are addressed by their position in ``list(graph.nodes())`` when
    calling ``walker.random_walks``; the returned walks are treated as
    sequences of such positions and mapped back to node objects.

    Args:
        q_entity: Start entities; those missing from ``graph`` are ignored.
        graph: Graph whose edges carry a ``'relation'`` attribute.
        n: Forwarded as ``n_walks`` to ``walker.random_walks``.
        hop: Forwarded as ``walk_len`` to ``walker.random_walks``.

    Returns:
        ``(paths, rules)``. ``paths[i]`` is a list of
        ``(head, relation, tail)`` triples and ``rules[i]`` is the list of
        relations along that same walk.
    """
    # Translate entities into positional indices.
    node_idx = list(graph.nodes())
    start_nodes = [node_idx.index(h) for h in q_entity if h in graph]

    import walker
    walks = walker.random_walks(graph, n_walks=n, walk_len=hop, start_nodes=start_nodes, verbose=False)

    # Add relation to paths
    result_paths = []
    rules = []
    for walk in walks:
        visited = [node_idx[i] for i in walk]
        triples = [(u, graph[u][v]['relation'], v) for u, v in zip(visited, visited[1:])]
        result_paths.append(triples)
        rules.append([triple[1] for triple in triples])
    return result_paths, rules
