import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
import random
import heapq
import networkx as nx
from typing import List, Tuple
import datasets
from rich.progress import track
from src.graph_utils import summary, path_to_str
from src.models import select_rerank_model

def load_data(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path to a ``.jsonl`` file containing one JSON document per
            line. The file is read as UTF-8.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            # Lines keep their trailing newline, so a blank line is "\n" (truthy).
            # Strip before testing so blank/whitespace-only lines are skipped
            # instead of crashing json.loads.
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records

def generate_hard_negatives(
    graph: nx.DiGraph,  # must be a DiGraph: get_edge_data() is read as a single attr dict
    start_node: str,
    query: str,
    positive_paths: List[List[Tuple[str, str, str]]],
    score_model,
    beam_width: int = 10,
    max_depth: int = 3,
    num_negatives: int = 10
) -> List[List[Tuple[str, str, str]]]:
    """Mine hard negative paths with reranker-guided beam search.

    Starting from ``start_node``, paths are extended one hop at a time along
    ``graph.neighbors`` for up to ``max_depth`` hops. At each depth every new
    candidate path is scored against ``query`` by ``score_model``. Candidates
    whose string form matches a positive path are discarded. The
    ``beam_width`` highest-scoring survivors are expanded at the next depth.
    Every surviving candidate from every depth is pooled, and the
    ``num_negatives`` highest-scoring ones are returned.

    Args:
        graph: Graph whose edge data carries a ``"relation"`` key.
        start_node: Node the search starts from.
        query: Question text passed to ``score_model.score_batch``.
        positive_paths: Ground-truth paths as lists of
            ``(head, relation, tail)`` triples. These are excluded from the
            result.
        score_model: Object exposing ``score_batch(query, path_strings)``.
            It is expected to return one score per path string, since scores
            are zipped with the candidates.
        beam_width: Number of paths kept for expansion after each depth.
        max_depth: Maximum number of hops per path.
        num_negatives: Maximum number of paths to return.

    Returns:
        Up to ``num_negatives`` paths (lists of triples), highest score first.
        Returns ``[]`` when ``num_negatives <= 0`` or ``start_node`` is not in
        ``graph``.
    """
    # Nothing requested, or nothing to expand from: skip the reranker calls.
    # (A negative num_negatives previously sliced [:-k] and returned almost
    # every candidate.)
    if num_negatives <= 0 or start_node not in graph:
        return []

    def by_score(item):
        # Rank on the score only. Comparing whole (score, path) tuples would
        # fall back to comparing path lists whenever two scores tie.
        return item[0]

    positive_str = set(path_to_str(start_node, positive_paths))
    beam = [[]]  # depth-0 beam: the single empty path rooted at start_node
    all_candidates = []

    for _ in range(max_depth):
        candidates = []
        for path in beam:
            last_node = path[-1][2] if path else start_node
            # neighbors() yields successors on a DiGraph and also works on an
            # undirected Graph (successors() does not).
            for neighbor in graph.neighbors(last_node):
                relation = graph.get_edge_data(last_node, neighbor)["relation"]
                candidates.append(path + [(last_node, relation, neighbor)])

        if not candidates:
            break

        # Score all candidate paths of this depth in one batch.
        paths_str = path_to_str(start_node, candidates)
        scores = score_model.score_batch(query, paths_str)

        # Drop candidates that coincide with a positive path.
        scored_candidates = [
            (score, path)
            for score, path, path_str in zip(scores, candidates, paths_str)
            if path_str not in positive_str
        ]
        all_candidates.extend(scored_candidates)

        beam = [path for _, path in heapq.nlargest(beam_width, scored_candidates, key=by_score)]

    # The highest-scoring non-positive paths across all depths are the hard negatives.
    all_candidates.sort(key=by_score, reverse=True)
    return [path for _, path in all_candidates[:num_negatives]]

def main():
    """Build the hard-negative reranker training file for one dataset split.

    Reads ``data/{dataset_name}.{split}.valid.jsonl`` and mines hard
    negatives for every sample that has at least one ground-truth path and a
    topic entity. It then writes one JSON object per kept sample to
    ``data/{dataset_name}.rerank.{split}.valid.hard_neg.jsonl``. Each object
    has the keys ``query``, ``pos``, ``neg`` and ``entity``; paths are
    serialized with ``path_to_str``.
    """
    # Load the dataset and the reranker used to score candidate paths.
    split = "train"
    dataset_name = "RoG-webqsp"
    dataset = load_data(f"data/{dataset_name}.{split}.valid.jsonl")
    rerank_model = select_rerank_model("bge-reranker-v2-m3")

    data = []
    for ins in track(dataset, description="Generating hard negatives..."):
        ins = summary(ins, with_graph=True)
        pos_cnt = len(ins["truth_paths"])
        # Samples without positives are never written out, so skip them before
        # paying for the reranker-driven beam search. Samples without a topic
        # entity have no start node (and would raise IndexError below).
        if pos_cnt == 0 or not ins["q_entity"]:
            continue
        G = ins["graph"]
        q_entity = ins["q_entity"][0]

        # Mine hard negatives.
        hard_negs = generate_hard_negatives(
            graph=G,
            start_node=q_entity,
            query=ins["question"],
            positive_paths=ins["truth_paths"],
            score_model=rerank_model,
            num_negatives=pos_cnt * 5
        )

        # Skip the two highest-scoring negatives. This is presumably to reduce
        # the risk of unlabeled true paths (false negatives); confirm intent.
        # Then randomly keep up to 3 negatives per positive.
        candidates = hard_negs[2:]
        hard_negs = random.sample(candidates, min(pos_cnt * 3, len(candidates)))

        data.append({
            "query": ins["question"],
            "pos": path_to_str(q_entity, ins["truth_paths"]),
            "neg": path_to_str(q_entity, hard_negs),
            "entity": q_entity,
        })

    # Save the results. Write as UTF-8 explicitly: with ensure_ascii=False the
    # locale default encoding may be unable to encode non-ASCII entity names.
    out_path = f"data/{dataset_name}.rerank.{split}.valid.hard_neg.jsonl"
    with open(out_path, "w", encoding="utf-8") as f:
        for record in data:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    # Run the hard-negative generation pipeline only when executed as a script,
    # not when this module is imported.
    main()