import argparse
import os
import random
import json
import traceback
from rich.progress import track
import datasets
from src.graph_utils import summary


def process_edge_relation(edge, relation):
    """Turn an edge plus its relation label into an ordered path triple.

    The label is inspected for a direction marker, checking "<--" before "-->":
    - "<--": the marker is removed and the edge order is kept, giving
      (edge[0], label, edge[1]);
    - "-->": the marker is removed and the edge order is swapped, giving
      (edge[1], label, edge[0]);
    - no marker: the label is used unchanged, giving (edge[0], relation, edge[1]).

    Args:
        edge: sequence whose first two items are the edge endpoints
            (e.g. ``(source, target, attrs)``).
        relation: relation label, possibly carrying a direction marker.

    Returns:
        tuple: ``(start_node, processed_relation, end_node)``.
    """
    head, tail = edge[0], edge[1]
    # Order matters: "<--" is tested first, exactly as before.
    for marker, swap in (("<--", False), ("-->", True)):
        if marker in relation:
            label = relation.replace(marker, "").strip()
            return (tail, label, head) if swap else (head, label, tail)
    return (head, relation, tail)

def load_data(file_path):
    """Load a JSON Lines file into a list of records.

    Blank and whitespace-only lines (including a trailing empty line) are
    skipped instead of being handed to ``json.loads``, which would raise.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list: one decoded JSON value per non-blank line, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # Lines keep their "\n", so test the stripped text for emptiness.
        return [json.loads(line) for line in f if line.strip()]

def get_random_paths(q_entity, G, n=1):
    """Sample up to ``n`` random walks that start at ``q_entity``.

    Each walk length is drawn from 1, 2 or 3 hops with weights 0.7 / 0.25 / 0.05.
    A walk stops early at a node without neighbours. Every hop is stored as the
    triple produced by ``process_edge_relation``.

    Args:
        q_entity: start node.
        G: graph supporting ``in``, ``G.neighbors(node)`` and
            ``G[u][v]['relation']``.
        n: number of walks to attempt.

    Returns:
        list: non-empty walks (lists of triples). Walks that could not take a
        single hop are dropped, so fewer than ``n`` walks may be returned. An
        empty list is returned when ``q_entity`` is not in ``G``.
    """
    if q_entity not in G:
        return []  # start node absent from the graph: nothing to sample

    result_paths = []
    for _ in range(n):
        hop = random.choices([1, 2, 3], weights=[0.7, 0.25, 0.05])[0]
        path = []

        current_node = q_entity
        for _ in range(hop):
            neighbors = list(G.neighbors(current_node))
            if not neighbors:
                break  # dead end: stop this walk early

            next_node = random.choice(neighbors)
            relation = G[current_node][next_node]['relation']
            path.append(process_edge_relation((current_node, next_node, relation), relation))
            current_node = next_node

        if path:  # keep walks with at least one hop
            # The first hop always touches the start node. Check that structurally
            # instead of substring-searching repr(path), which fails spuriously
            # when the entity name contains characters repr() escapes.
            first = path[0]
            assert q_entity in (first[0], first[2]), f"{path}"
            result_paths.append(path)

    return result_paths

def path_to_str_with_graph(q_entities, paths, G):
    """Render triple paths as text, reading relation labels from the graph.

    Each path is walked from the first entry of ``q_entities`` it connects to.
    Every hop appends ``" <relation> <next node>"``, where ``relation`` is
    ``G[current][next]['relation']``. The label therefore follows the direction
    of traversal, not the label stored in the triple.

    Args:
        q_entities: candidate start nodes, tried in order for every path.
        paths: list of paths. Each path is a list of ``(head, relation, tail)``
            triples whose relation must not contain "-->" or "<--".
        G: graph indexable as ``G[u][v]['relation']``.

    Returns:
        list[str]: one string per path that could be rendered from some
        q_entity. A path that fails for every q_entity is reported on stdout
        and skipped.
    """
    paths_str = []

    def process_one_path(q_entity, path):
        # Track the current node explicitly. Matching on the *suffix* of the
        # rendered string confuses nodes whose name ends with another node's
        # name (e.g. "United States" vs "States").
        _path = q_entity
        current = q_entity
        for triple in path:
            if "-->" in triple[1] or "<--" in triple[1]:
                raise ValueError(f"Relation still carries a direction marker: {triple}")
            if triple[0] == current:
                nxt = triple[2]  # hop written in traversal direction
            elif triple[2] == current:
                nxt = triple[0]  # hop written against traversal direction
            else:
                raise ValueError(f"Path: {_path}, Triple: {triple}, q_entity: {q_entity}, paths: {paths}")
            relation = G[current][nxt]['relation']
            _path += f" {relation} {nxt}"
            current = nxt
        return _path

    for path in paths:
        last_error = None
        for q_entity in q_entities:
            try:
                paths_str.append(process_one_path(q_entity, path))
                break
            except Exception as e:  # deliberate best-effort: try the next start entity
                last_error = e
        else:
            # Report only when no q_entity could render the path. Failing on an
            # earlier candidate before succeeding on a later one is expected.
            if last_error is not None:
                print(last_error, "".join(traceback.format_exception(
                    type(last_error), last_error, last_error.__traceback__)))

    return paths_str

def get_last_node(q_entities, path, G):
    """Return the node a path ends on when walked from a query entity.

    The first triple is anchored on whichever of its endpoints is in
    ``q_entities`` (head checked first). Every later triple must share an
    endpoint with the current node, and the walk moves to the other endpoint.

    Args:
        q_entities: candidate start nodes.
        path: list of ``(head, relation, tail)`` triples.
        G: graph object; accepted for interface compatibility, not used here.

    Returns:
        The final node of the walk, or ``q_entities[0]`` when ``path`` is empty.

    Raises:
        ValueError: if the first triple touches no query entity, or a later
            triple does not touch the current node.
    """
    if not path:
        return q_entities[0]

    hops = iter(path)
    first = next(hops)
    head, _, tail = first
    if head in q_entities:
        node = tail
    elif tail in q_entities:
        node = head
    else:
        unresolved = None  # no current node has been established yet
        raise ValueError(f"路径不连续: 当前节点 {unresolved}, 三元组 {first}")

    for hop in hops:
        head, _, tail = hop
        if head == node:
            node = tail
        elif tail == node:
            node = head
        else:
            raise ValueError(f"路径不连续: 当前节点 {node}, 三元组 {hop}")

    return node

def _replace_last_hop(path, q_entities, G):
    """Return ``path`` with its last hop swapped for a random edge of the node before it.

    For a one-hop (or empty) path the prefix is empty, so the anchor is
    ``get_last_node(q_entities, [], G)``, i.e. ``q_entities[0]``.
    Returns None when the anchor has no edges.
    """
    prefix = list(path[:-1])
    anchor = get_last_node(q_entities, prefix, G)
    edges = list(G.edges(anchor, data=True))
    if not edges:
        return None
    new_edge = random.choice(edges)
    return prefix + [process_edge_relation(new_edge, new_edge[2]['relation'])]


def _extend_path(path, q_entities, G):
    """Return ``path`` plus one random hop from its last node, or None at a dead end."""
    last_node = get_last_node(q_entities, path, G)
    neighbors = list(G.neighbors(last_node))
    if not neighbors:
        return None
    next_node = random.choice(neighbors)
    relation = G[last_node][next_node]['relation']
    edge = (last_node, next_node, {'relation': relation})
    return list(path) + [process_edge_relation(edge, relation)]


def _strip_direction_markers(sample):
    """Return ``sample`` with any "<--" / "-->" marker removed from each relation."""
    cleaned = []
    for triple in sample:
        relation = triple[1]
        if "<--" in relation:
            relation = relation.replace("<--", "").strip()
        elif "-->" in relation:
            relation = relation.replace("-->", "").strip()
        cleaned.append((triple[0], relation, triple[2]))
    return cleaned


def generate_neg_samples(dataset, output_path):
    """Build positive/negative path strings per question and write them as JSONL.

    For each of the first 300 truth paths of an instance, negatives are drawn by:
    1) replacing the last hop (probability 0.8),
    2) appending one hop (probability 0.6),
    3) dropping the last hop when the path has at least 2 hops (probability 0.6),
    4) adding one random walk from the first query entity (always).
    Negatives whose rendered string equals a rendered truth path are discarded,
    because they are not negatives. Duplicates are also removed. At most 100
    negatives are sampled. Positives are the rendered strings of the first 100
    truth paths.

    Instances with no truth paths or no query entity are skipped.

    Args:
        dataset: iterable of raw instances accepted by ``summary``.
        output_path: file to (over)write; one JSON object per line, UTF-8.

    Returns:
        list[dict]: the written records, with keys "query", "pos", "entity"
        and "neg".
    """
    data = []

    # "w" truncates any existing file. The context manager guarantees the output
    # is flushed and closed before the caller renames it.
    with open(output_path, "w", encoding="utf-8") as f:
        for ins in track(dataset, description="Generating negative samples..."):
            ins = summary(ins, with_graph=True)
            G = ins["graph"]
            q_entities = ins["q_entity"]
            truth_paths = ins["truth_paths"]

            if not truth_paths or not q_entities:
                continue
            q_entity = q_entities[0]

            iter_data = {"query": ins["question"], "pos": truth_paths, "entity": q_entity}

            neg_samples = []
            for path in truth_paths[:300]:
                # 1) replace the last hop
                if random.random() < 0.8:
                    replaced = _replace_last_hop(path, q_entities, G)
                    if replaced is not None:
                        neg_samples.append(replaced)

                # 2) append one hop
                if random.random() < 0.6:
                    extended = _extend_path(path, q_entities, G)
                    if extended is not None:
                        neg_samples.append(extended)

                # 3) drop the last hop
                if random.random() < 0.6 and len(path) >= 2:
                    neg_samples.append(path[:-1])

                # 4) one random walk from the main query entity
                neg_samples.extend(get_random_paths(q_entity, G))

            pos_head = path_to_str_with_graph(q_entities, truth_paths[:100], G)
            pos_all = set(pos_head) | set(path_to_str_with_graph(q_entities, truth_paths[100:], G))

            # Unrenderable samples are dropped by path_to_str_with_graph itself.
            neg_strs = path_to_str_with_graph(
                q_entities, [_strip_direction_markers(s) for s in neg_samples], G)
            # Deduplicate (order-preserving) and remove false negatives.
            neg_strs = [s for s in dict.fromkeys(neg_strs) if s not in pos_all]

            iter_data["pos"] = pos_head
            iter_data["neg"] = random.sample(neg_strs, min(len(neg_strs), 100))
            f.write(json.dumps(iter_data, ensure_ascii=False) + "\n")

            data.append(iter_data)

    return data

def main():
    """Command-line entry point: load JSONL data, generate negatives, publish the result.

    Output is first written to ``<output_file>.tmp`` and only moved into place
    after generation finishes. An interrupted run therefore does not leave a
    partially written ``output_file``.
    """
    parser = argparse.ArgumentParser(description="生成负样本数据")
    parser.add_argument("--input_file", type=str, default="data/RoG-webqsp.train.valid.jsonl",
                        help="输入文件路径")
    parser.add_argument("--output_file", type=str, default="data/RoG-webqsp.rerank.train.jsonl",
                        help="输出文件路径")
    args = parser.parse_args()

    # Load the input data
    print(f"正在从 {args.input_file} 加载数据...")
    dataset = load_data(args.input_file)

    # Generate negative samples into the temporary file
    print("正在生成负样本...")
    tmp_path = args.output_file + ".tmp"
    generate_neg_samples(dataset, tmp_path)

    # Publish the result. os.replace overwrites an existing destination on every
    # platform, whereas os.rename raises FileExistsError on Windows.
    print(f"正在将结果保存到 {args.output_file}...")
    os.replace(tmp_path, args.output_file)
    print("完成！")


if __name__ == "__main__":
    main()