import random
import heapq
import networkx as nx
from typing import List, Tuple, Callable
from src.graph_utils import path_to_str
def beam_search(graph: nx.DiGraph,
                start_node: str,
                query: str,
                score_model,
                beam_width: int = 10,
                max_depth: int = 3,
                threshold: float = 0.9,
                max_results: int = 100) -> List[Tuple[float, List[Tuple[str, str, str]]]]:
    """Beam-search relation paths out of ``start_node`` that score well for ``query``.

    Each beam path is extended by one outgoing edge per step. All extensions
    at a depth are scored together with ``score_model.score_batch`` and the
    best ``beam_width`` of them form the next beam. Every beam path whose score
    reaches ``threshold`` is recorded, including paths at the final depth
    ``max_depth``.

    Args:
        graph: Directed graph; every traversed edge must carry a ``"relation"``
            attribute (read via ``get_edge_data``).
        start_node: Node the paths start from. Paths never step back into it.
        query: Query text passed to the scorer.
        score_model: Object exposing ``score_batch(query, paths_str)``; assumed
            to return one score per candidate path, in the same order.
        beam_width: Number of paths kept at each depth.
        max_depth: Maximum number of edges in a returned path.
        threshold: Minimum score for a beam path to be reported.
        max_results: The search returns as soon as this many results exist.

    Returns:
        ``(score, path)`` pairs in discovery order: shallower depths first and,
        within a depth, higher scores first. ``path`` is a list of
        ``(head, relation, tail)`` triples.
    """
    beam: List[Tuple[float, List[Tuple[str, str, str]]]] = [(0.0, [])]
    results: List[Tuple[float, List[Tuple[str, str, str]]]] = []

    for _ in range(max_depth):
        # Extend every beam path by one outgoing edge.
        cand_paths = []
        for _score, path in beam:
            last_node = path[-1][2] if path else start_node
            for neighbor in graph.successors(last_node):
                # Never walk back into the node the search started from.
                if neighbor == start_node:
                    continue
                edge_data = graph.get_edge_data(last_node, neighbor)
                cand_paths.append(path + [(last_node, edge_data["relation"], neighbor)])

        # Frontier exhausted: stop before calling the scorer with an empty batch.
        if not cand_paths:
            break

        # path_to_str presumably renders each path as text for the scorer
        # (one entry per path) -- confirm in src.graph_utils.
        paths_str = path_to_str(start_node, cand_paths)
        scores = score_model.score_batch(query, paths_str)
        candidates = list(zip(scores, cand_paths))

        # Rank on score only, so tied scores never fall back to comparing paths.
        beam = heapq.nlargest(beam_width, candidates, key=lambda item: item[0])

        # Check the freshly built beam immediately so that paths at the final
        # depth (max_depth edges) are considered too.
        for score, path in beam:
            if score >= threshold:
                results.append((score, path))
                if len(results) >= max_results:
                    return results

    return results

# Example usage
def example_scorer(query: str, path: List[Tuple[str, str, str]]) -> float:
    """Return a placeholder score drawn uniformly from the interval [0.5, 1.0).

    Both ``query`` and ``path`` are ignored; the value depends only on the
    global ``random`` generator state.
    """
    floor, width = 0.5, 0.5
    return floor + width * random.random()


if __name__ == "__main__":
    # Demo: run beam search on one WebQSP test sample, scored by a reranker model.
    # It uses the beam_search defined in this file; the former
    # `from src.beam_search import beam_search` re-import rebound that name to a
    # separately imported copy, so the demo could exercise different code.
    import datasets
    from src.graph_utils import summary
    from src.models import select_rerank_model

    dataset = datasets.load_dataset("data/RoG-webqsp", split="test")

    # with_graph=True presumably attaches the sample's graph under "graph"
    # (it is passed to beam_search below) -- confirm in src.graph_utils.summary.
    ins = summary(dataset[10], with_graph=True)

    rerank_model = select_rerank_model("bge-reranker-v2-m3")

    result = beam_search(graph=ins["graph"],
                         start_node=ins["q_entity"][0],
                         query=ins["question"],
                         score_model=rerank_model,
                         threshold=0.5, beam_width=10, max_depth=3)
    print(result)