{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e22ee0dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import Ipynb_importer\n",
    "import os\n",
    "from utils import *\n",
    "import re\n",
    "import json\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68448aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "class Exploration_record:\n",
    "    \"\"\"Bookkeeping for one exploration run over a set of clues.\n",
    "\n",
    "    NOTE(review): the methods below call ``get_expected_clues`` and\n",
    "    ``update_path``, which are not defined in this part of the class;\n",
    "    confirm they are provided elsewhere before instantiating.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, init_nodes, triples, unknown_entity, all_clues):\n",
    "        \"\"\"Set up the global, path-level and per-round state.\n",
    "\n",
    "        Args:\n",
    "            init_nodes: non-empty sequence of nodes exposing ``.clue``; the\n",
    "                first one is where the exploration starts.\n",
    "            triples: expected structure, stored unchanged.\n",
    "            unknown_entity: stored unchanged.\n",
    "            all_clues: iterable of every clue; stored as a set.\n",
    "        \"\"\"\n",
    "        # --- global state ---\n",
    "        self.expected_structure = triples\n",
    "        # clue -> list of initial nodes that carry this clue\n",
    "        self.initial_nodes = {}\n",
    "        for node in init_nodes:\n",
    "            self.initial_nodes.setdefault(node.clue, []).append(node)\n",
    "        # every initial clue except the first one that was inserted\n",
    "        self.unexplored_initial_clues = set(list(self.initial_nodes)[1:])\n",
    "        self.unknown_entity = unknown_entity\n",
    "        self.all_clues = set(all_clues)\n",
    "        self.connection = {}\n",
    "        self.results = []\n",
    "\n",
    "        # --- path exploration from the initial nodes ---\n",
    "        first_node = init_nodes[0]\n",
    "        self.explored_clues = []\n",
    "        self.unfound_clues = []\n",
    "        self.left_clues = set(all_clues) - {first_node.clue}\n",
    "        self.beyond_expectation = {}  # {clue: [(r, former_id, score)]}\n",
    "\n",
    "        # --- state of a single exploration round ---\n",
    "        self.current_clue2nodes = {first_node.clue: [first_node]}\n",
    "        self.expected_clues = self.get_expected_clues([first_node.clue])\n",
    "\n",
    "    def early_stop_check(self):\n",
    "        \"\"\"Call after each exploration round.\n",
    "\n",
    "        Returns False unless the clues appearing in the keys of\n",
    "        ``self.connection`` are exactly ``self.all_clues``; in that case the\n",
    "        result of ``self.update_path()`` is returned.\n",
    "        \"\"\"\n",
    "        covered = {clue for couple in self.connection for clue in couple}\n",
    "        if covered == self.all_clues:\n",
    "            return self.update_path()\n",
    "        return False\n",
    "\n",
    "    def success_check(self):\n",
    "        \"\"\"Call only when ``self.expected_clues == {}``; delegates to ``update_path``.\"\"\"\n",
    "        return self.update_path()\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "class Node:\n",
    "    \"\"\"A candidate entity found for one clue, with its links to other nodes.\"\"\"\n",
    "\n",
    "    def __init__(self, name, id, clue, label=None, relations: dict[str, str] = None, graph=True):\n",
    "        \"\"\"Store the node fields and derive the parallel id/relation lists.\n",
    "\n",
    "        Args:\n",
    "            name: node name.\n",
    "            id: node identifier.\n",
    "            clue: clue this node was found for.\n",
    "            label: optional label, stored unchanged.\n",
    "            relations: optional mapping ``{related_id: relation}``; the same\n",
    "                object is kept (not copied). ``None`` means no relations.\n",
    "            graph: stored as ``clue_in_graph``.\n",
    "        \"\"\"\n",
    "        self.name = name\n",
    "        self.id = id\n",
    "        self.clue = clue\n",
    "        self.label = label\n",
    "        self.clue_in_graph = graph\n",
    "        # {related_id: relation}; a fresh dict per node when none is given\n",
    "        self.relations = {} if relations is None else relations\n",
    "        # Parallel lists mirroring the keys and values of self.relations\n",
    "        self.related_id = list(self.relations.keys())\n",
    "        self.relation = list(self.relations.values())\n",
    "        \n",
    "\n",
    "\n",
    "class NodeChain:\n",
    "    \"\"\"Ordered batches of nodes discovered clue by clue.\n",
    "\n",
    "    ``batches[i]`` is the list of nodes found for ``clues[i]``. ``pairs``\n",
    "    holds the clue pairs that decide which clue is explored next and which\n",
    "    batches have to stay connected to each other.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, init_node, pairs, specific, set_clue=None):\n",
    "        \"\"\"Start a chain from a single node.\n",
    "\n",
    "        Args:\n",
    "            init_node: starting node; must expose ``clue`` and ``clue_in_graph``.\n",
    "            pairs: re-iterable collection of two-item clue pairs.\n",
    "            specific: container of clues; ``add_batch`` tests membership in it.\n",
    "            set_clue: optional clue; when given, only pairs containing it are\n",
    "                considered when choosing the first ``next_clue``.\n",
    "        \"\"\"\n",
    "        self.batches = [[init_node]]  # all batches of nodes, in discovery order\n",
    "        self.clues = [init_node.clue]  # clue of each batch, parallel to self.batches\n",
    "        self.pairs = pairs  # clue pairs linking head and tail entities\n",
    "        self.checked_batches = set()  # indices of batches already checked by pruning\n",
    "        # Clues that have already been searched. This must be a one-element\n",
    "        # set: set(init_node.clue) would split a string clue into characters.\n",
    "        self.explored_clues = {init_node.clue}\n",
    "        self.next_clue, self.current_clue, self.current_pair = self._get_next_clue(set_clue=set_clue)\n",
    "        self.potential_clues = self._get_potential_clues()\n",
    "        self.in_graph = init_node.clue_in_graph\n",
    "        self.specific = specific\n",
    "        self.last_generic_in_graph = False\n",
    "\n",
    "    def add_batch(self, new_nodes):\n",
    "        \"\"\"Append the nodes found for the next clue and advance the chain.\n",
    "\n",
    "        An empty ``new_nodes`` only marks ``next_clue`` as explored and moves\n",
    "        on. All nodes of a batch are expected to share one clue; the clue of\n",
    "        the first node is used for the whole batch.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the batch's clue already has a batch in the chain.\n",
    "                The chain is left unmodified in that case.\n",
    "        \"\"\"\n",
    "        if not new_nodes:\n",
    "            self.explored_clues.add(self.next_clue)\n",
    "            self.next_clue, self.current_clue, self.current_pair = self._get_next_clue()\n",
    "            return\n",
    "\n",
    "        first = new_nodes[0]\n",
    "        new_clue = first.clue\n",
    "        # Validate before touching any state so a rejected batch leaves the\n",
    "        # in_graph / last_generic_in_graph flags untouched.\n",
    "        if new_clue in self.clues:\n",
    "            raise ValueError(f\"Clue '{new_clue}' already exists in a previous batch.\")\n",
    "\n",
    "        if first.clue_in_graph:\n",
    "            self.in_graph = True\n",
    "            if new_clue not in self.specific:\n",
    "                self.last_generic_in_graph = True\n",
    "        elif new_clue not in self.specific:\n",
    "            self.last_generic_in_graph = False\n",
    "\n",
    "        self.explored_clues.add(self.next_clue)\n",
    "        self.batches.append(new_nodes)\n",
    "        self.clues.append(new_clue)\n",
    "\n",
    "        # The new batch may introduce new associations, so re-check everything.\n",
    "        self.checked_batches = set()\n",
    "        self._prune_unrelated_nodes()\n",
    "\n",
    "        self.next_clue, self.current_clue, self.current_pair = self._get_next_clue()\n",
    "\n",
    "    def generate_triples(self):\n",
    "        \"\"\"Return de-duplicated ``(name, relation, name)`` triples for linked batches.\n",
    "\n",
    "        For every clue pair whose two batches exist, each node combination is\n",
    "        checked: if one node's id is a key of the other's ``relations``, a\n",
    "        triple is emitted with the referenced node first.\n",
    "        \"\"\"\n",
    "        triples = set()\n",
    "        print('generate triples------------')\n",
    "        for clue1, clue2 in self.pairs:\n",
    "            print('pairs:', clue1, clue2)\n",
    "            batch1 = self._find_batch_by_clue(clue1)\n",
    "            if batch1 is None:\n",
    "                continue\n",
    "            batch2 = self._find_batch_by_clue(clue2)\n",
    "            if batch2 is None:\n",
    "                continue\n",
    "            for node1 in batch1:\n",
    "                for node2 in batch2:\n",
    "                    if node1.id in node2.relations:\n",
    "                        triples.add((node1.name, node2.relations[node1.id], node2.name))\n",
    "                    elif node2.id in node1.relations:\n",
    "                        triples.add((node2.name, node1.relations[node2.id], node1.name))\n",
    "        return list(triples)\n",
    "\n",
    "    def _get_next_clue(self, set_clue=None):\n",
    "        \"\"\"Pick the next clue to explore.\n",
    "\n",
    "        Returns ``(next_clue, current_clue, pair)`` for the first pair that has\n",
    "        exactly one side already in the chain and whose other side has not\n",
    "        been explored yet, or ``(None, None, None)`` when no pair qualifies.\n",
    "        When ``set_clue`` is given, pairs that do not contain it are skipped.\n",
    "        \"\"\"\n",
    "        for pair in self.pairs:\n",
    "            if set_clue is not None and set_clue not in pair:\n",
    "                continue\n",
    "            head, tail = pair[0], pair[1]\n",
    "            if head in self.clues and tail not in self.clues and tail not in self.explored_clues:\n",
    "                return tail, head, pair\n",
    "            if tail in self.clues and head not in self.clues and head not in self.explored_clues:\n",
    "                return head, tail, pair\n",
    "        return None, None, None\n",
    "\n",
    "    def _get_potential_clues(self):\n",
    "        \"\"\"List every clue in ``pairs`` that is neither in the chain nor ``next_clue``.\n",
    "\n",
    "        Returns an empty list when there is no ``next_clue``.\n",
    "        \"\"\"\n",
    "        if self.next_clue is None:\n",
    "            return []\n",
    "        all_clues = set()\n",
    "        for pair in self.pairs:\n",
    "            all_clues.add(pair[0])\n",
    "            all_clues.add(pair[1])\n",
    "        return list(all_clues - set(self.clues) - {self.next_clue})\n",
    "\n",
    "    def check_multi_branch(self):\n",
    "        \"\"\"Detect whether ``next_clue`` is linked to several clues already in the chain.\n",
    "\n",
    "        Returns ``{clue: deep copy of its batch}`` when more than one chain\n",
    "        clue is paired with ``next_clue``, otherwise ``None``.\n",
    "        \"\"\"\n",
    "        clues = set()\n",
    "        for pair in self.pairs:\n",
    "            if self.next_clue == pair[0] and pair[1] in self.clues:\n",
    "                clues.add(pair[1])\n",
    "            elif self.next_clue == pair[1] and pair[0] in self.clues:\n",
    "                clues.add(pair[0])\n",
    "\n",
    "        if len(clues) > 1:\n",
    "            return {clue: copy.deepcopy(self._find_batch_by_clue(clue)) for clue in clues}\n",
    "        return None\n",
    "\n",
    "    def _prune_unrelated_nodes(self):\n",
    "        \"\"\"Walk the batches from last to first and prune nodes that lost their association.\"\"\"\n",
    "        if not self.pairs:\n",
    "            return  # no associations at all, nothing to check\n",
    "\n",
    "        for i in range(len(self.batches) - 1, -1, -1):\n",
    "            if i in self.checked_batches:\n",
    "                continue\n",
    "\n",
    "            batch = self.batches[i]\n",
    "            clue = self.clues[i]\n",
    "\n",
    "            # Collect the batch on the other side of every pair containing this clue\n",
    "            # (None when that clue has no batch yet).\n",
    "            related_batches = []\n",
    "            for pair in self.pairs:\n",
    "                if clue == pair[0]:\n",
    "                    related_batches.append(self._find_batch_by_clue(pair[1]))\n",
    "                elif clue == pair[1]:\n",
    "                    related_batches.append(self._find_batch_by_clue(pair[0]))\n",
    "\n",
    "            if not related_batches:\n",
    "                # The clue appears in no pair: report it, keep the batch.\n",
    "                print('no related batch, clue:', clue)\n",
    "                continue\n",
    "\n",
    "            for related_batch in related_batches:\n",
    "                if related_batch is not None:\n",
    "                    self._check_and_prune_nodes(batch, related_batch)\n",
    "\n",
    "            self.checked_batches.add(i)\n",
    "\n",
    "    def _find_batch_by_clue(self, clue):\n",
    "        \"\"\"Return the batch stored for ``clue``, or ``None`` if the chain has none.\"\"\"\n",
    "        for i, c in enumerate(self.clues):\n",
    "            if c == clue:\n",
    "                return self.batches[i]\n",
    "        return None\n",
    "\n",
    "    def _check_and_prune_nodes(self, batch1, batch2):\n",
    "        \"\"\"Remove from ``batch2`` every node not linked to any node of ``batch1``.\n",
    "\n",
    "        Two nodes are linked when either one's id is in the other's\n",
    "        ``related_id``. ``batch2`` is modified in place; if it ends up empty,\n",
    "        it and its clue are dropped from the chain. ``batch1`` is never pruned.\n",
    "        \"\"\"\n",
    "        if not batch2:\n",
    "            return  # nothing to prune; also avoids IndexError on batch2[0]\n",
    "\n",
    "        clue = batch2[0].clue\n",
    "        for node2 in batch2[:]:  # iterate over a copy, batch2 is mutated below\n",
    "            has_relation = any(\n",
    "                node2.id in node1.related_id or node1.id in node2.related_id\n",
    "                for node1 in batch1\n",
    "            )\n",
    "            if not has_relation:\n",
    "                print(f\"Removing node {node2} due to missing association.\")\n",
    "                batch2.remove(node2)\n",
    "\n",
    "        if not batch2:\n",
    "            print(f'all nodes of {clue} are deleted!!!!!!')\n",
    "            # Locate the batch by identity so an equal-looking list is never\n",
    "            # deleted by mistake.\n",
    "            for index, batch in enumerate(self.batches):\n",
    "                if batch is batch2:\n",
    "                    del self.batches[index]\n",
    "                    del self.clues[index]\n",
    "                    break\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"NodeChain(batches={self.batches}, clues={self.clues}, pairs={self.pairs})\"\n",
    "    \n",
    "class Exploration:\n",
    "    def __init__(self, init_nodes, pairs, triples, target, specific, statement):\n",
    "        \"\"\"Build the main chain from the first initial node and reset the run state.\n",
    "\n",
    "        Args:\n",
    "            init_nodes: non-empty sequence of initial nodes exposing ``.clue``.\n",
    "            pairs: clue pairs, passed on to the NodeChain.\n",
    "            triples: stored unchanged.\n",
    "            target: target clue.\n",
    "            specific: passed on to the NodeChain and stored.\n",
    "            statement: stored unchanged.\n",
    "        \"\"\"\n",
    "        # clue -> initial node (for duplicate clues the last node wins)\n",
    "        self.init_clue2node = {node.clue: node for node in init_nodes}\n",
    "        # The main chain starts from the first initial node.\n",
    "        self.chain = NodeChain(init_nodes[0], pairs, specific)\n",
    "        self.explored_init = [init_nodes[0].clue]  # initial clues explored so far\n",
    "        self.pairs = pairs  # clue pairs linking head and tail entities\n",
    "        self.triples = triples\n",
    "        self.target = target  # target clue\n",
    "        self.specific = specific\n",
    "        self.statement = statement\n",
    "        self.llm_call = 0\n",
    "        self.next_flag = 0  # deprecated: was meant to check whether a mapping points at next_clue\n",
    "        self.exceed = 0\n",
    "        self.relation_bridge = False\n",
    "        self.entity_bridge = False\n",
    "        # explore() assigns self.node_bridge; initialise it here so reading it\n",
    "        # before that branch runs does not raise AttributeError.\n",
    "        # NOTE(review): node_bridge may be intended to be entity_bridge -- confirm.\n",
    "        self.node_bridge = False\n",
    "\n",
    "    def llm_internal(self):\n",
    "        \"\"\"Ask ``internal`` for candidate nodes of the chain's next clue.\n",
    "\n",
    "        Collects the triples that mention both the next clue and a source\n",
    "        clue, plus the node names of each source batch, and hands them to\n",
    "        ``internal``. When the next clue is linked to several clues already in\n",
    "        the chain, every such clue is a source; otherwise only the chain's\n",
    "        current clue is.\n",
    "\n",
    "        Only the raw result is returned; whether it is usable is up to the\n",
    "        callers, which handle it differently: the starting point may need to\n",
    "        retrieve a 'unique' node, while the exploration process uses it as a\n",
    "        speed-up, where finding a node is best but not required.\n",
    "\n",
    "        Returns:\n",
    "            ``(possible_nodes, token_num, triples)`` where the first two come\n",
    "            from ``internal`` and ``triples`` is the list passed to it.\n",
    "        \"\"\"\n",
    "        # NOTE: whether the clue is specific should decide between internal1 and internal2.\n",
    "        chain = self.chain\n",
    "        target = chain.next_clue\n",
    "        branches = chain.check_multi_branch()\n",
    "        names_by_clue = {}\n",
    "        matched = []\n",
    "        if branches and len(branches) > 1:\n",
    "            for clue, nodes in branches.items():\n",
    "                names_by_clue[clue] = [member.name for member in nodes]\n",
    "                matched.extend(t for t in self.triples if clue in t and target in t)\n",
    "        else:\n",
    "            source = chain.current_clue\n",
    "            matched.extend(t for t in self.triples if source in t and target in t)\n",
    "            names_by_clue[source] = [member.name for member in chain._find_batch_by_clue(source)]\n",
    "        candidates, token_num = internal(matched, target, names_by_clue)\n",
    "        return candidates, token_num, matched\n",
    "\n",
    "    def explore(self):\n",
    "            \"\"\"探索 chain 的下一个 clue\n",
    "            - 如果 next_clue 在 init_clue 中，调用 intersection 方法，如果没有交集，考虑1. 结构错误 2. init_node是抽象节点，如person\n",
    "            - 如果 next_clue 与 init_clue 在 pairs 中相连，调用 intersection 方法，如果没有交集，同上\n",
    "            - 以上是与init_clue夹逼，最后判断非init_clue，chain自身分支夹逼，如果没有交集，考虑 结构错误\n",
    "            - 无交集 则调用 llm_prune 方法\n",
    "            - 情况二和三，如果交集大于1，还应llm_prune\n",
    "           \n",
    "            \"\"\"\n",
    "            input_token = 0\n",
    "            output_token = 0\n",
    "            total_token = 0\n",
    "            next_clue = self.chain.next_clue\n",
    "            current_clue = self.chain.current_clue\n",
    "            current_batch = self.chain._find_batch_by_clue(current_clue)\n",
    "            nodes_in_graph = [node.name for node in current_batch if node.clue_in_graph]\n",
    "            nodes_not_in_graph = [node.name for node in current_batch if not node.clue_in_graph]\n",
    "            node_bridge_batch = []\n",
    "            possible_nodes = []\n",
    "            if next_clue in self.explored_init:\n",
    "                return [] , {'input': input_token, 'output': output_token, 'total': total_token} # 如果 clue 已经被探索过，则停止\n",
    "\n",
    "            # 情况1: 如果 next_clue 是 init_clue 中的另外一个\n",
    "            if next_clue in self.init_clue2node.keys() and next_clue not in self.explored_init and current_batch[0].clue_in_graph:\n",
    "                print('------  next_clue is init_clue --- before ------')\n",
    "                # 调用 intersection 方法\n",
    "                new_batch, relations_return = self.intersection([current_batch], next_clue)\n",
    "                if new_batch:\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                triples = []\n",
    "                if not current_batch[0].clue_in_graph:\n",
    "                    for triple in self.triples:\n",
    "                            if self.chain.current_clue in triple and self.chain.next_clue in triple:\n",
    "                                triples.append(triple)\n",
    "                    names = [node.name for node in current_batch]\n",
    "                    if len(names) > 1:\n",
    "                        names, token_num = specific_internal({self.chain.current_clue:[node.name for node in current_batch]}, triples)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        self.llm_call += 1\n",
    "                    relations_dict = {}\n",
    "                    for name in names:\n",
    "                            for node in current_batch:\n",
    "                                if node.name == name:\n",
    "                                    relations_dict[node.id] = triples[0][1]  # triple[1] is the relation\n",
    "                    new_batch = [Node(name = self.chain.next_clue, id = f\"{self.chain.next_clue}_0\", clue = self.chain.next_clue, relations=relations_dict, graph=False)]\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            back_up = []\n",
    "            if self.chain.next_clue in self.specific:\n",
    "                possible_nodes = []\n",
    "                if  not current_batch[0].clue_in_graph:\n",
    "                    triples = []\n",
    "                    for triple in self.triples:\n",
    "                        if self.chain.current_clue in triple and self.chain.next_clue in triple:\n",
    "                            triples.append(triple)\n",
    "                    names = [node.name for node in current_batch]\n",
    "                    if len(names) > 1:\n",
    "                        names, token_num = specific_internal({self.chain.current_clue:[node.name for node in current_batch]}, triples)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        self.llm_call += 1\n",
    "                    relations_dict = {}\n",
    "                    for name in names:\n",
    "                        for node in current_batch:\n",
    "                            if node.name == name:\n",
    "                                relations_dict[node.id] = triples[0][1]  # triple[1] is the relation\n",
    "                    new_batch = [Node(name = self.chain.next_clue, id = f\"{self.chain.next_clue}_0\", clue = self.chain.next_clue, relations=relations_dict, graph=False)]\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            else:\n",
    "                possible_nodes, token_num, triples = self.llm_internal()\n",
    "                self.llm_call += 1\n",
    "                input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "\n",
    "            possible_batch = []\n",
    "\n",
    "            ###   以下为了测试纯用 llm_internal的效果，不测试时应全部注释\n",
    "            # for idx, p_node in enumerate(possible_nodes):\n",
    "            #                 relations_dict = {}\n",
    "            #                 for triple in self.triples:\n",
    "            #                     if current_clue in triple and next_clue in triple:\n",
    "            #                         for node in current_batch:\n",
    "            #                                 relations_dict[node.id] = triple[1]  # triple[1] is the relation\n",
    "            #                 possible_batch.append(Node(\n",
    "            #                     name=p_node,\n",
    "            #                     id=f\"{self.chain.next_clue}_{idx}\",\n",
    "            #                     clue=self.chain.next_clue,\n",
    "            #                     relations=relations_dict,\n",
    "            #                     graph=False\n",
    "            #                 ))\n",
    "            # return possible_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            ### 以上为了测试纯用 llm_internal的效果，不测试时应全部注释\n",
    "            found_nodes = []\n",
    "            not_found_nodes = []\n",
    "            back_up = []  ### 在kg中但没关联上\n",
    "            if current_batch[0].clue_in_graph:\n",
    "                if possible_nodes:\n",
    "                    for p_node in possible_nodes:\n",
    "                        # print('possible node from internal:', p_node)\n",
    "                        queried_nodes = query_node_by_name(p_node)  ## (name, id)\n",
    "                        \n",
    "                        if queried_nodes:\n",
    "                            # print('queried_nodes:', queried_nodes[0])\n",
    "                            found_nodes.append(queried_nodes[0])  ## [0]表示第一个返回结果，不是name\n",
    "                        else:\n",
    "                            not_found_nodes.append(p_node)\n",
    "                \n",
    "                    if found_nodes:\n",
    "                        found_node_ids = [node[1] for node in found_nodes]\n",
    "                        current_batch_ids = [node.id for node in current_batch]\n",
    "                        # print(found_node_ids)\n",
    "                        # print(current_batch_ids)\n",
    "                        if len(current_batch_ids) > 50:\n",
    "                            current_batch_ids = current_batch_ids[:50]\n",
    "                        connections = find_and_expand_connections(found_node_ids, current_batch_ids)  ### (related_name, related_id, existing_id, relation)\n",
    "                        print('connections:', connections)\n",
    "                        related_nodes_data = {}\n",
    "                        for connection in connections:\n",
    "                            related_name, related_id, existing_id, relation = connection\n",
    "                            if related_id not in related_nodes_data:\n",
    "                                related_nodes_data[related_id] = {\n",
    "                                    \"name\": related_name,\n",
    "                                    \"relations\": {}\n",
    "                                }\n",
    "                            related_nodes_data[related_id][\"relations\"][existing_id] = relation\n",
    "\n",
    "                        for related_id, data in related_nodes_data.items():\n",
    "                            related_name = data[\"name\"]\n",
    "                            relations_dict = data[\"relations\"]\n",
    "                            possible_batch.append(Node(\n",
    "                                    name=related_name,\n",
    "                                    id=related_id,\n",
    "                                    clue=self.chain.next_clue,\n",
    "                                    relations=relations_dict,\n",
    "                                    graph=True\n",
    "                                ))\n",
    "\n",
    "                            \n",
    "                        connected_names = [c[0] for c in connections]\n",
    "                            # 关系桥接\n",
    "                        for idx, p_node in enumerate(found_nodes):\n",
    "                                if p_node in connected_names:\n",
    "                                    continue\n",
    "                                relations_dict = {}\n",
    "                                for triple in self.triples:\n",
    "                                    if current_clue in triple and next_clue in triple:\n",
    "                                        for node in current_batch:\n",
    "                                            relations_dict[node.id] = triple[1]  # triple[1] is the relation\n",
    "                                back_up.append(Node(\n",
    "                                    name=p_node[0],\n",
    "                                    id=p_node[1],\n",
    "                                    clue=self.chain.next_clue,\n",
    "                                    relations=relations_dict,\n",
    "                                    graph=True\n",
    "                                ))\n",
    "                                self.relation_bridge = True\n",
    "                    else:\n",
    "                        ### 实体桥接\n",
    "                        for idx, p_node in enumerate(not_found_nodes):\n",
    "                            relations_dict = {}\n",
    "                            for triple in self.triples:\n",
    "                                if current_clue in triple and next_clue in triple:\n",
    "                                    for node in current_batch:\n",
    "                                        relations_dict[node.id] = triple[1]  # triple[1] is the relation\n",
    "                            node_bridge_batch.append(Node(\n",
    "                                name=p_node,\n",
    "                                id=f\"{self.chain.next_clue}_{idx}\",\n",
    "                                clue=self.chain.next_clue,\n",
    "                                relations=relations_dict,\n",
    "                                graph=False\n",
    "                            ))\n",
    "\n",
    "                ### 没有图推理结果\n",
    "                # else:\n",
    "                #     直接llmprune\n",
    "            else:\n",
    "                if possible_nodes:\n",
    "                    for p_node in possible_nodes:\n",
    "                        queried_nodes = query_node_by_name(p_node)  ## (name, id)\n",
    "                        if queried_nodes:\n",
    "                            found_nodes.append(queried_nodes[0])\n",
    "                        else:\n",
    "                            not_found_nodes.append(p_node)\n",
    "\n",
    "                    if found_nodes:\n",
    "                        for node in found_nodes:\n",
    "                            relations_dict = {}\n",
    "                            for triple in self.triples:\n",
    "                                if current_clue in triple and next_clue in triple:\n",
    "                                    for c_node in current_batch:\n",
    "                                        relations_dict[c_node.id] = triple[1]  # triple[1] is the relation\n",
    "                            ### 关系桥接\n",
    "                            possible_batch.append(Node(\n",
    "                                name=node[0],\n",
    "                                id=node[1],\n",
    "                                clue=self.chain.next_clue,\n",
    "                                relations=relations_dict,\n",
    "                                graph=True\n",
    "                            ))\n",
    "                    else:\n",
    "                        for idx, p_node in enumerate(not_found_nodes):\n",
    "                            relations_dict = {}\n",
    "                            for triple in self.triples:\n",
    "                                if current_clue in triple and next_clue in triple:\n",
    "                                    for node in current_batch:\n",
    "                                        relations_dict[node.id] = triple[1]  # triple[1] is the relation\n",
    "                            ## 实体桥接\n",
    "                            possible_batch.append(Node(\n",
    "                                name=p_node,\n",
    "                                id=f\"{self.chain.next_clue}_{idx}\",\n",
    "                                clue=self.chain.next_clue,\n",
    "                                relations=relations_dict,\n",
    "                                graph=False\n",
    "                            ))\n",
    "\n",
    "                    return possible_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                else:\n",
    "                    return [], {'input': input_token, 'output': output_token, 'total': total_token}    \n",
    "                    \n",
    "\n",
    "                \n",
    "\n",
    "            possible_relations = []\n",
    "            new_batch = []\n",
    "            if possible_batch:\n",
    "                possible_relations = [(v, k, node.id) for node in possible_batch for k, v in node.relations.items()]\n",
    "                possible_relations = list(set(possible_relations))\n",
    "            r_type = set([rel[0] for rel in possible_relations])\n",
    "            if len(r_type) == 1:\n",
    "                return possible_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            # 情况1: 如果 next_clue 是 init_clue 中的另外一个\n",
    "            if next_clue in self.init_clue2node.keys() and next_clue not in self.explored_init:\n",
    "                print('------  next_clue is init_clue  ------')\n",
    "                # \n",
    "                connections = find_and_expand_connections([node.id for node in current_batch][:50], [self.init_clue2node[next_clue].id]) ### (related_name, related_id, existing_id, relation)\n",
    "                relations = dict()\n",
    "                for connection in connections:\n",
    "                    related_name, related_id, existing_id, relation = connection\n",
    "                    relations[existing_id] = relation\n",
    "                if relations:\n",
    "                    new_batch = [Node(\n",
    "                        name=self.init_clue2node[next_clue].name,\n",
    "                        id=self.init_clue2node[next_clue].id,\n",
    "                        clue=next_clue,\n",
    "                        relations=relations,\n",
    "                        graph=True\n",
    "                    )]\n",
    "                    \n",
    "                else:\n",
    "                    new_batch, token_num = self.llm_prune(self.chain)  ###  无论如何，将next_clue加入explored_clues，不用判断是否potential\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                if back_up:\n",
    "                    new_batch += back_up\n",
    "                if not new_batch:\n",
    "                    new_batch = node_bridge_batch\n",
    "                    self.node_bridge = True\n",
    "                return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "\n",
    "            # 情况2: 如果 next_clue 与 init_clue（不在 chain.clues 里）在 pairs 中相连\n",
    "            for init_node in self.init_clue2node.values():\n",
    "                if init_node.clue not in self.chain.clues and init_node.clue not in self.chain.explored_clues and init_node.clue not in self.explored_init and ((next_clue, init_node.clue) in self.pairs or (init_node.clue, next_clue) in self.pairs):\n",
    "                    print('------  next_clue and init_clue are in pairs  ------')\n",
    "                    if possible_batch:\n",
    "                        new_batch = possible_batch\n",
    "                        relations_return = possible_relations\n",
    "                    else:\n",
    "                        batches = [current_batch, [init_node]]\n",
    "                        new_batch, relations_return = self.intersection(batches, next_clue)\n",
    "                    batch1 = []\n",
    "                    if new_batch:\n",
    "                        ## 如果任一边的关系数大于1，还得剪枝一次\n",
    "                        r_current = []\n",
    "                        r_init = []\n",
    "                        current_ids = [node.id for node in current_batch]\n",
    "                        for relation in relations_return:\n",
    "                            if relation[1] == init_node.id:\n",
    "                                r_init.append(relation)\n",
    "                            elif relation[1] in current_ids:\n",
    "                                r_current.append(relation)\n",
    "                        if len(r_current) > 1:\n",
    "                            batch1, token_num = self.llm_prune(current_batch, r_current, self.chain)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                            batch1, token_num = self.check_unnamed_entity(batch1)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)                    \n",
    "                            new_batch = batch1\n",
    "                    else:\n",
    "                        new_batch = []\n",
    "                    ### 不考虑两边通过语义对齐向中间夹逼，结果不一致，需要choose，的情况了\n",
    "                    ### 只考虑一边语义对齐，另一边或多边连通，兼顾速度和精准\n",
    "                        \n",
    "                    if  back_up:\n",
    "                        new_batch += back_up \n",
    "                    if not new_batch:\n",
    "                        new_batch = node_bridge_batch\n",
    "                        self.node_bridge = True\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                    \n",
    "                    \n",
    "            \n",
    "            clue2nodes = self.chain.check_multi_branch()\n",
    "            if clue2nodes:\n",
    "                print('------  multi branch  ------')\n",
    "                if possible_batch:\n",
    "                    new_batch = possible_batch\n",
    "                    relations_return = possible_relations\n",
    "                else:\n",
    "                    batches = list(clue2nodes.values())\n",
    "                    new_batch, relations_return = self.intersection(batches, next_clue)\n",
    "                if new_batch:\n",
    "                    if len(new_batch) == 1:\n",
    "                        return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                    else:\n",
    "                        new_batch, token_num = self.llm_prune(new_batch, relations_return, self.chain)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        if  back_up:\n",
    "                            new_batch += back_up\n",
    "                        if not new_batch:\n",
    "                            new_batch = node_bridge_batch\n",
    "                            self.node_bridge = True\n",
    "                        return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                else:\n",
    "                    new_batch, token_num = self.llm_prune(self.chain)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    if back_up:\n",
    "                        new_batch += back_up\n",
    "                    if not new_batch:\n",
    "                        new_batch = node_bridge_batch\n",
    "                        self.node_bridge = True\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            print('------  normal exploration  ------')\n",
    "            if possible_batch:\n",
    "                new_batch = possible_batch\n",
    "                for node in new_batch:\n",
    "                    print('possible node:', node.name, node.id)\n",
    "                relations_return = possible_relations\n",
    "                if len(new_batch) == 1:\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                else:\n",
    "                    new_batch, token_num = self.llm_prune(new_batch, relations_return, self.chain)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    if back_up:\n",
    "                        new_batch += back_up\n",
    "                    if not new_batch:\n",
    "                        new_batch = node_bridge_batch\n",
    "                        self.node_bridge = True\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            else:\n",
    "                new_batch, token_num = self.llm_prune(self.chain)\n",
    "                input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                if back_up:\n",
    "                    new_batch += back_up\n",
    "                if not new_batch:\n",
    "                    new_batch = node_bridge_batch\n",
    "                    self.node_bridge = True\n",
    "                return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "\n",
    "\n",
    "\n",
    "    def check_unnamed_entity(self, batch):\n",
    "        \"\"\"\n",
    "        Replace placeholder 'UnName_Entity' nodes in batch with their named neighbours.\n",
    "\n",
    "        Named nodes are kept unchanged. For every placeholder node the graph\n",
    "        neighbours are fetched with get_related_nodes_by_id; a neighbour reached\n",
    "        through a relation the placeholder already carries is accepted directly,\n",
    "        the remaining neighbours are accepted only when prune_node keeps their\n",
    "        relation for the clue pair that is currently being explored.\n",
    "\n",
    "        Args:\n",
    "            batch: list of nodes sharing one clue (may be empty).\n",
    "\n",
    "        Returns:\n",
    "            (new_batch, token_num): the expanded node list and the token usage\n",
    "            dict with keys 'input', 'output' and 'total' (all zeros when\n",
    "            prune_node was not called).\n",
    "\n",
    "        Side effects:\n",
    "            Adds one to self.llm_call when prune_node is called.\n",
    "        \"\"\"\n",
    "        print('check unnamed entity--------------')\n",
    "        token_num = {'input': 0, 'output': 0, 'total': 0}\n",
    "        if not batch:\n",
    "            return [], token_num\n",
    "\n",
    "        new_batch = []\n",
    "        related_nodes = []\n",
    "        for node in batch:\n",
    "            if node.name != 'UnName_Entity':\n",
    "                new_batch.append(node)\n",
    "                continue\n",
    "            # NOTE(review): llm_prune treats related_id as possibly None, so guard the\n",
    "            # membership test instead of failing on 'in None'.\n",
    "            excluded_ids = node.related_id or ()\n",
    "            for t in get_related_nodes_by_id(node.id):  # (r, name, id)\n",
    "                if t[2] in excluded_ids:\n",
    "                    continue\n",
    "                if t[0] in node.relation:\n",
    "                    new_batch.append(Node(name=t[1], id=t[2], clue=node.clue, relations=node.relations))\n",
    "                else:\n",
    "                    related_nodes.append((t[0], t[1], t[2], node.relations))\n",
    "\n",
    "        # Drop neighbours that were already accepted above (id set built once, not per item).\n",
    "        accepted_ids = {n.id for n in new_batch}\n",
    "        related_nodes = [rn for rn in related_nodes if rn[2] not in accepted_ids]\n",
    "        print('related_nodes of UnNamed:', related_nodes)\n",
    "        if not related_nodes:\n",
    "            return new_batch, token_num\n",
    "\n",
    "        UnName_relations = list({rn[0] for rn in related_nodes})\n",
    "        clue = batch[0].clue\n",
    "        matched_clues = []\n",
    "        pair = self.chain.current_pair\n",
    "        for triple in self.triples:\n",
    "            if pair[0] == triple[0] and pair[1] == triple[2]:\n",
    "                if clue == pair[0]:\n",
    "                    matched_clues.append(f'the target {clue}_{triple[1]}_{triple[2]}')\n",
    "                elif clue == pair[1]:\n",
    "                    matched_clues.append(f'{triple[0]}_{triple[1]}_the target {clue}')\n",
    "                break\n",
    "\n",
    "        if not matched_clues:\n",
    "            # Nothing to rank the relations against: skip the LLM call. The result\n",
    "            # is looked up by matched_clues[0], which used to raise IndexError here.\n",
    "            return new_batch, token_num\n",
    "\n",
    "        UnName_relations_result, token_num = prune_node(matched_clues, UnName_relations, {})\n",
    "        UnName_relations_pruned = UnName_relations_result.get(matched_clues[0], [])\n",
    "        print('UnName_relations_pruned:', UnName_relations_pruned)\n",
    "        self.llm_call += 1\n",
    "        for UnName_relation in UnName_relations_pruned:\n",
    "            for rn in related_nodes:\n",
    "                if rn[0] == UnName_relation[0]:\n",
    "                    new_batch.append(Node(name=rn[1], id=rn[2], clue=clue, relations=rn[3]))\n",
    "        return new_batch, token_num\n",
    "\n",
    "    def llm_prune(self, *args):\n",
    "        \"\"\"\n",
    "        Rank relations with prune_node and return the nodes tied to the best ones.\n",
    "\n",
    "        Two call forms are supported:\n",
    "        - llm_prune(chain): unrestricted exploration. The relations of every node\n",
    "          in the batch of chain.current_clue are fetched with get_relations,\n",
    "          ranked, and the neighbours reached through the top relation(s) are\n",
    "          returned as new Node objects.\n",
    "        - llm_prune(batch, relations_info, chain): pruning. Only the nodes of\n",
    "          batch that carry one of the top relation(s) are returned (each node at\n",
    "          most once); relations_info is an iterable of tuples whose first item is\n",
    "          the relation name.\n",
    "\n",
    "        Returns:\n",
    "            (new_batch, token_num): list of nodes and the accumulated token usage\n",
    "            dict with keys 'input', 'output' and 'total'.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if called with neither one nor three positional arguments.\n",
    "\n",
    "        Side effects:\n",
    "            Adds one to self.llm_call per prune_node call and sets self.exceed = 1\n",
    "            when the estimated number of calls would exceed the budget.\n",
    "        \"\"\"\n",
    "        chunk_size = 20      # relations passed to prune_node per call\n",
    "        max_llm_calls = 30   # budget compared against self.llm_call\n",
    "        input_token = 0\n",
    "        output_token = 0\n",
    "        total_token = 0\n",
    "\n",
    "        def usage():\n",
    "            # Snapshot of the counters accumulated so far.\n",
    "            return {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "\n",
    "        unrestricted = len(args) == 1\n",
    "        if unrestricted:\n",
    "            # Form A: start from the batch of the current clue.\n",
    "            chain = args[0]\n",
    "            batch = chain._find_batch_by_clue(chain.current_clue)\n",
    "            collected = set()\n",
    "            for node in batch:\n",
    "                tmp = get_relations(node.id)   # [(r, id)]\n",
    "                print('current_pair', chain.current_pair)\n",
    "                print('get relations of ', node.name)\n",
    "                if node.related_id is not None:  # not the initial node\n",
    "                    # Drop relations equal to node.relation (presumably the one this\n",
    "                    # node was reached by -- TODO confirm against Node).\n",
    "                    tmp = [t for t in tmp if t[0] != node.relation]\n",
    "                collected.update(tmp)\n",
    "            relations_info = list(collected)\n",
    "        elif len(args) == 3:\n",
    "            # Form B: prune the given batch using the given relation tuples.\n",
    "            batch, relations_info, chain = args\n",
    "            if relations_info is None:\n",
    "                relations_info = []\n",
    "        else:\n",
    "            raise ValueError('Invalid number of arguments. Expected 1 or 3 arguments.')\n",
    "\n",
    "        relations = list({info[0] for info in relations_info})\n",
    "        if not relations:\n",
    "            print(chain.current_clue, 'has no relations')\n",
    "            return [], usage()\n",
    "        candidate_clues = [chain.next_clue]\n",
    "\n",
    "        # Describe the pair being explored as 'head_relation_tail', marking the\n",
    "        # side that has to be found with the prefix 'the target '.\n",
    "        matched_clues = []\n",
    "        clue = chain.next_clue\n",
    "        pair = chain.current_pair\n",
    "        for triple in self.triples:\n",
    "            if pair[0] == triple[0] and pair[1] == triple[2]:\n",
    "                if clue == pair[0]:\n",
    "                    matched_clues.append(f'the target {clue}_{triple[1]}_{triple[2]}')\n",
    "                elif clue == pair[1]:\n",
    "                    matched_clues.append(f'{triple[0]}_{triple[1]}_the target {clue}')\n",
    "                break\n",
    "\n",
    "        print('relations:', relations)\n",
    "        if len(relations) / chunk_size + self.llm_call > max_llm_calls:\n",
    "            self.exceed = 1\n",
    "            return [], usage()\n",
    "\n",
    "        # Rank the relations chunk by chunk; prune_node receives the results so far.\n",
    "        prune_once_more = len(relations) > chunk_size  # several chunks need a final joint ranking\n",
    "        pruned_relations = {}\n",
    "        while True:\n",
    "            relations_for_match, relations = relations[:chunk_size], relations[chunk_size:]\n",
    "            pruned_relations, token_num = prune_node(matched_clues, relations_for_match, pruned_relations)  # {clue: [(r, label, score)]}\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            self.llm_call += 1\n",
    "            print('pruned_info:', pruned_relations)\n",
    "            if not relations:\n",
    "                break\n",
    "        if pruned_relations == {}:\n",
    "            print('no relations can be matched')\n",
    "            return [], usage()\n",
    "\n",
    "        if prune_once_more:\n",
    "            # Re-rank the survivors of all chunks together for the target clue.\n",
    "            relations_for_match = []\n",
    "            matched_clues = []\n",
    "            for key in pruned_relations.keys():\n",
    "                if 'the target ' + chain.next_clue in key:\n",
    "                    relations_for_match = [rel[0] for rel in pruned_relations[key]]\n",
    "                    matched_clues.append(key)\n",
    "                    break\n",
    "            print('select top relation')\n",
    "            pruned_relations = {}\n",
    "            if relations_for_match:\n",
    "                pruned_relations, token_num = prune_node(matched_clues, relations_for_match, pruned_relations)  # {clue: [(r, label, score)]}\n",
    "                input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                self.llm_call += 1\n",
    "            print('pruned_info:', pruned_relations)\n",
    "\n",
    "        # Keep the relation(s) with the highest score for the target clue.\n",
    "        # NOTE(review): a clue or relation name containing '_' (e.g. the '_modified'\n",
    "        # suffix added in the driver cell) splits into more than three parts and is skipped.\n",
    "        top_relations = []\n",
    "        top_clue = ''\n",
    "        for clue_r, scored in pruned_relations.items():\n",
    "            parts = clue_r.split('_')\n",
    "            if len(parts) != 3:\n",
    "                print('strange len of clue_r:', clue_r)\n",
    "                continue\n",
    "            if 'the target' in parts[0]:\n",
    "                target_pos = 0\n",
    "            elif 'the target' in parts[2]:\n",
    "                target_pos = 2\n",
    "            else:\n",
    "                print('strange clue:', clue_r)\n",
    "                continue\n",
    "            top_clue = parts[target_pos].replace('the target ', '')\n",
    "            if top_clue not in candidate_clues:\n",
    "                print(f'strange clue[{target_pos}]:', clue_r)\n",
    "                continue\n",
    "            if not scored:\n",
    "                # max() of an empty sequence raises ValueError; nothing to select here.\n",
    "                continue\n",
    "            max_value = max(rel[2] for rel in scored)\n",
    "            top_relations = [rel[0] for rel in scored if rel[2] == max_value]\n",
    "\n",
    "        new_batch = []\n",
    "        if not unrestricted:\n",
    "            for node in batch:\n",
    "                # any() keeps a node once even when several top relations match it.\n",
    "                if any(top_relation in node.relation for top_relation in top_relations):\n",
    "                    new_batch.append(node)\n",
    "        else:\n",
    "            for info in relations_info:\n",
    "                if info[0] not in top_relations:\n",
    "                    continue\n",
    "                print('get related nodes of ', info[1], info[0])\n",
    "                try:\n",
    "                    nodes = get_related_nodes(info[1], info[0])  # [(label, name, id, relation, former_id)]\n",
    "                except Exception as e:\n",
    "                    # best effort: a failed lookup only skips this relation\n",
    "                    print('error:', e)\n",
    "                    continue\n",
    "                for node in nodes:\n",
    "                    new_batch.append(Node(name=node[1], id=node[2], clue=top_clue, label=node[0], relations={node[4]: node[3]}))\n",
    "        return new_batch, usage()\n",
    "        \n",
    "\n",
    "    def intersection(self, batches, next_clue):\n",
    "        \"\"\"\n",
    "        Find nodes for next_clue that are connected to every batch in batches.\n",
    "\n",
    "        If next_clue is one of the initial clues, its initial node is returned\n",
    "        (with the relations linking it to the batches) provided it has at least\n",
    "        one neighbour in each batch. Otherwise the neighbours of the cheapest\n",
    "        batch (smallest summed count_nodes) are expanded -- 'UnName_Entity'\n",
    "        neighbours are replaced by their own neighbours -- and a neighbour is\n",
    "        kept when it is related to at least one node of every batch.\n",
    "\n",
    "        Args:\n",
    "            batches: list of node lists, one per branch.\n",
    "            next_clue: clue assigned to the returned nodes.\n",
    "\n",
    "        Returns:\n",
    "            (new_batch, relations_return). new_batch is a list of Node objects and\n",
    "            relations_return a set of (relation, batch_node_id, new_node_id)\n",
    "            tuples. The initial-clue branch returns (new_batch, None). (None, None)\n",
    "            is returned when the initial node is not connected to some batch, when\n",
    "            batches is empty, or when the search would be too large.\n",
    "        \"\"\"\n",
    "        new_batch = []\n",
    "        relations_return = set()\n",
    "\n",
    "        batch_ids = [set(node.id for node in batch) for batch in batches]\n",
    "        if next_clue in self.init_clue2node.keys():\n",
    "            init_node = self.init_clue2node[next_clue]\n",
    "            related_nodes = get_related_nodes_by_id(init_node.id)  # (r, name, id)\n",
    "            relations = {}\n",
    "            for batch_id in batch_ids:\n",
    "                hits = [node for node in related_nodes if node[2] in batch_id]\n",
    "                if not hits:\n",
    "                    print(f'clue = {next_clue} has no related nodes in former batch')\n",
    "                    return None, None\n",
    "                for node in hits:\n",
    "                    # id -> relation (node[0]), matching the other Node constructions\n",
    "                    # in this file; the name (node[1]) was stored here by mistake.\n",
    "                    relations[node[2]] = node[0]\n",
    "            new_batch.append(Node(name=init_node.name, id=init_node.id, clue=init_node.clue, relations=relations))\n",
    "            return new_batch, None\n",
    "\n",
    "        if not batches:\n",
    "            return None, None\n",
    "\n",
    "        # Expand from the batch whose nodes have the fewest neighbours in total.\n",
    "        base_batch = batches[0]\n",
    "        min_num = 10000\n",
    "        for batch in batches:\n",
    "            entities_num = sum(int(count_nodes(node.id)) for node in batch)\n",
    "            if entities_num < min_num:\n",
    "                min_num = entities_num\n",
    "                base_batch = batch\n",
    "\n",
    "        if min_num > 5000:\n",
    "            print('batch size is too large, use llm_prune')\n",
    "            return None, None\n",
    "\n",
    "        related_nodes_cache = {}\n",
    "\n",
    "        def get_related_nodes_cached(node_id, name, related_id):\n",
    "            # Neighbours of node_id without those listed in related_id; unnamed\n",
    "            # neighbours are replaced by their own neighbours (except node_id).\n",
    "            # The cache is keyed by node_id only, so the first related_id wins.\n",
    "            if node_id not in related_nodes_cache:\n",
    "                print('intersection: get related nodes of ', name, node_id)\n",
    "                # related_id can be None (llm_prune checks for that), so guard the lookup.\n",
    "                excluded = related_id or ()\n",
    "                rst = []\n",
    "                for r_node in get_related_nodes_by_id(node_id):\n",
    "                    if r_node[2] in excluded:\n",
    "                        continue\n",
    "                    if r_node[1] == 'UnName_Entity':\n",
    "                        rst.extend(\n",
    "                            hop for hop in get_related_nodes_by_id(r_node[2])\n",
    "                            if hop[2] != node_id\n",
    "                        )\n",
    "                    else:\n",
    "                        rst.append(r_node)\n",
    "                related_nodes_cache[node_id] = rst\n",
    "            return related_nodes_cache[node_id]\n",
    "\n",
    "        # Candidate nodes: every neighbour of the base batch, as (name, id).\n",
    "        related_nodes = set()\n",
    "        for node in base_batch:\n",
    "            for related_node in get_related_nodes_cached(node.id, node.name, node.related_id):\n",
    "                related_nodes.add((related_node[1], related_node[2]))\n",
    "\n",
    "        # NOTE(review): the limit is applied to each candidate separately, not to the sum.\n",
    "        for candidate in related_nodes:\n",
    "            if int(count_nodes(candidate[1])) > 10000:\n",
    "                print('batch size is too large, use llm_prune')\n",
    "                return None, None\n",
    "\n",
    "        for related_node in related_nodes:\n",
    "            nodes = get_related_nodes_cached(related_node[1], related_node[0], [])  # (r, name, id)\n",
    "\n",
    "            matched_nodes = {}\n",
    "            for id_set in batch_ids:\n",
    "                match = next((node for node in nodes if node[2] in id_set), None)\n",
    "                if match is None:\n",
    "                    break\n",
    "                matched_nodes[match[2]] = match[0]\n",
    "            else:\n",
    "                # for-else: reached only when every batch contributed a match\n",
    "                for batch_node_id, r in matched_nodes.items():\n",
    "                    relations_return.add((r, batch_node_id, related_node[1]))  # (r, original_node_id, related_node_id)\n",
    "                new_batch.append(Node(\n",
    "                    name=related_node[0],\n",
    "                    id=related_node[1],\n",
    "                    clue=next_clue,\n",
    "                    relations=matched_nodes\n",
    "                ))\n",
    "\n",
    "        print('----------------- intersection ends ---------------')\n",
    "        return new_batch, relations_return\n",
    "                    \n",
    "def add_tokens(input_tok, output_tok, total_tok, token_num):\n",
    "    \"\"\"Return the three running token counters increased by the usage in token_num.\n",
    "\n",
    "    token_num is a dict with the keys 'input', 'output' and 'total'; the result\n",
    "    is the tuple (input, output, total) after adding those values.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        input_tok + token_num['input'],\n",
    "        output_tok + token_num['output'],\n",
    "        total_tok + token_num['total'],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1433c3f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import time\n",
    "from collections import Counter\n",
    "\n",
    "data_path = ''\n",
    "\n",
    "datas = prepare_dataset(data_path)\n",
    "\n",
    "count = 0\n",
    "for i, data in enumerate(datas):\n",
    "    start_time = time.time()\n",
    "    input_token = 0\n",
    "    output_token = 0\n",
    "    total_token = 0\n",
    "    print('the', i, 'question')\n",
    "    llm_call = 2\n",
    "    print('data:', data)\n",
    "    question = data['question']\n",
    "    reext_flag = 1   ### 重新抽取的机会，1代表不重新抽取，2代表重新抽取一次\n",
    "    # question = data['RawQuestion']\n",
    "    # question = \"what ocean is around hawaii\"\n",
    "    try:\n",
    "        # statement = 'Identify the country that borders France and contains an airport serving Nijmegen.'\n",
    "        # specific = ['france', 'nijmegen']\n",
    "        # generic = ['country', 'airport']\n",
    "        # unknown_entity= 'country'\n",
    "        while reext_flag:\n",
    "            statement, specific, generic, token_num = complete_info(question)\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            initial_nodes = []\n",
    "            for clue in specific:\n",
    "                if re.match(r'^[\"\\'].*[\"\\']$', clue):\n",
    "                    print('clue with quotes:', clue)\n",
    "                    nodes = query_node_by_name(clue[1:-1])\n",
    "                else:\n",
    "                    print('clue:', clue)\n",
    "                    nodes = query_node_by_name(clue)  # [(name, id)]\n",
    "                if len(nodes) == 1:\n",
    "                    node = nodes[0]\n",
    "                    initial_nodes.append(Node(node[0], node[1], clue))\n",
    "                elif len(nodes) > 1: ### freebase 实体经常包含大量同名的歌曲、书籍等\n",
    "                    node = nodes[0]\n",
    "                    initial_nodes.append(Node(node[0], node[1], clue))\n",
    "                else:\n",
    "                    print('no node found for clue:', clue)\n",
    "\n",
    "            if initial_nodes == []: ###  如果找不到起始点，那么重新抽一遍，再找不到就算了\n",
    "                reext_flag -= 1           ###  可以看看效果好不好，不好的话，设计一个重新抽取的prompt\n",
    "            else:                   ###  又或者，如果发生次数都很少，那就不管了\n",
    "                break\n",
    "        \n",
    "        print('initial_nodes:', initial_nodes)\n",
    "        triples, unknown_entity, token_num = structure(statement, generic, specific)\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            \n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        continue\n",
    "    \n",
    "    entity_counter = Counter()\n",
    "    all_clues = []\n",
    "    try:\n",
    "        updated_triples = []\n",
    "        for head, relation, tail in triples:\n",
    "            # Remove double quotes from head and tail if present\n",
    "            head = head.strip('\"')\n",
    "            tail = tail.strip('\"')\n",
    "            \n",
    "            if head.lower() == tail.lower():\n",
    "                tail = f\"{tail}_modified\"\n",
    "            if head.lower() not in all_clues:\n",
    "                all_clues.append(head.lower())\n",
    "            if tail.lower() not in all_clues:\n",
    "                all_clues.append(tail.lower())\n",
    "            updated_triples.append((head.lower(), relation, tail.lower()))\n",
    "        triples = updated_triples\n",
    "    except:\n",
    "        continue\n",
    "\n",
    "    # Log the parsed question structure for this sample.\n",
    "    print('statement:', statement)\n",
    "    print('specific:', specific)\n",
    "    print('generic:', generic)\n",
    "    print('unknown_entity:', unknown_entity)\n",
    "    print('triples:', triples)\n",
    "\n",
    "    print('all_clues:', all_clues)\n",
    "\n",
    "    # cot_answer stays False unless the answering step later switches to a step-by-step answer.\n",
    "    cot_answer = False\n",
    "    pairs = [(head, tail) for (head, _, tail) in triples]\n",
    "    if initial_nodes != []:\n",
    "        exploration = Exploration(initial_nodes, pairs, triples, unknown_entity, specific, statement)\n",
    "        print('init clue:', exploration.chain.clues)\n",
    "        print('exploration begins......')\n",
    "    else:\n",
    "        # No initial nodes: seed the exploration with a single node built with graph=False.\n",
    "        if specific:\n",
    "            first_specific = specific[0]\n",
    "            exploration = Exploration([Node(first_specific, first_specific + '_0', first_specific, graph=False)], pairs, triples, unknown_entity, specific, statement)\n",
    "        else:\n",
    "            # Fall back to the unknown entity as the seed; this branch passes an empty triple list.\n",
    "            # The result must be bound to exploration (not a misspelled name), otherwise the code below\n",
    "            # would hit a NameError or reuse the previous sample's exploration.\n",
    "            exploration = Exploration([Node(unknown_entity, unknown_entity + '_0', unknown_entity, graph=False)], pairs, [], unknown_entity, specific, statement)\n",
    "    # Keep exploring while a next clue remains and the exploration limit has not been hit;\n",
    "    # each round in which explore returns a batch of nodes leads on to the next round.\n",
    "    while exploration.exceed != 1 and exploration.chain.next_clue is not None:\n",
    "        batch, token_num = exploration.explore()\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "        ### Original note: if the batch is empty, add_batch puts next_clue into explored_clues and goes on to look for a new next_clue\n",
    "        exploration.chain.add_batch(batch)\n",
    "        print('clues done:', exploration.chain.clues)\n",
    "        # print(exploration.chain.generate_triples())\n",
    "    # Report when the loop ended because the limit flag was set rather than because the clues ran out.\n",
    "    if exploration.exceed == 1:\n",
    "        print('exceed the limit')\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    # Add this sample's LLM calls to the running total, then pull the exploration outputs.\n",
    "    llm_call = llm_call + exploration.llm_call\n",
    "    result_triplets = exploration.chain.generate_triples()\n",
    "    in_graph = exploration.chain.in_graph\n",
    "    last_generic_in_graph = exploration.chain.last_generic_in_graph\n",
    "    # Keep at most the first 30 triples.\n",
    "    if len(result_triplets) > 30:\n",
    "        result_triplets = result_triplets[:30]\n",
    "    relation_bridge, entity_bridge = exploration.relation_bridge, exploration.entity_bridge\n",
    "\n",
    "    print('result_triplets:', result_triplets)\n",
    "    success_flag = 0\n",
    "    # Start from empty answers so a failure below cannot leave these names unbound\n",
    "    # or still holding the previous sample's answers when the record is written.\n",
    "    answer_w_triples = None\n",
    "    answer_w_cot = None\n",
    "\n",
    "    try:\n",
    "        ### If the last generic is in the graph, answer from the triples; the cot answer is then just the triples answer, cot_answer=False\n",
    "        ### If it is not in the graph, record both the triples answer and a cot-based answer, cot_answer=True\n",
    "        ### Evaluating the triples answer measures the effect of the triples obtained by GGoT\n",
    "        ### Evaluating the cot answer measures the effect of dynamically choosing whether to answer from the triples\n",
    "        ### Filtering for samples with cot_answer False selects questions whose last generic is in the graph\n",
    "        ### Filtering for samples with cot_answer True selects questions whose last generic is not in the graph\n",
    "        answer_w_triples, token_num = answer_question(question, result_triplets, success_flag)  ### always answer the question from result_triplets\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "        llm_call += 1\n",
    "        if last_generic_in_graph:\n",
    "            answer_w_cot = answer_w_triples\n",
    "            cot_answer = False\n",
    "        else:\n",
    "            answer_w_cot, token_num = chat('please answer the question step by step: ' + question)\n",
    "            # answer_w_cot, token_num = chat_llama('please answer the question step by step: ' + question)\n",
    "            cot_answer = True\n",
    "            llm_call += 1\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failure and still write a record, with whichever answers are missing left as None.\n",
    "        print(e)\n",
    "\n",
    "    # print('input_token:', input_token)\n",
    "    # print('output_token:', output_token)\n",
    "    # print('total_token:', total_token)\n",
    "    time_used = time.time() - start_time\n",
    "    # One record per processed sample, appended as a single JSON line.\n",
    "    result = dict(\n",
    "        data=data,\n",
    "        structure=triples,\n",
    "        answer_w_triples=answer_w_triples,\n",
    "        answer_w_cot=answer_w_cot,\n",
    "        llm_call=llm_call,\n",
    "        triples=result_triplets,\n",
    "        input_token=input_token,\n",
    "        output_token=output_token,\n",
    "        total_token=total_token,\n",
    "        time_used=time_used,\n",
    "        cot_answer=cot_answer,\n",
    "        relation_bridge=relation_bridge,\n",
    "        entity_bridge=entity_bridge,\n",
    "        in_graph=in_graph,\n",
    "        last_generic_in_graph=last_generic_in_graph,\n",
    "    )\n",
    "\n",
    "    # NOTE(review): the output path is an empty string, which open() rejects; set the results file path before running.\n",
    "    with open('', 'a', encoding='utf-8') as out_file:\n",
    "        json.dump(result, out_file, ensure_ascii=False)\n",
    "        out_file.write('\\n')\n",
    "    count = count + 1\n",
    "    print(count, 'done')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myjupyter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
