{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e22ee0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "importing Jupyter notebook from utils.ipynb\n",
      "/home/taodehao/miniconda3/envs/myjupyter/bin/python\n"
     ]
    }
   ],
   "source": [
    "import Ipynb_importer\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\" \n",
    "from utils import *\n",
    "import re\n",
    "import json\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "68448aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "class Exploration_record:\n",
    "\n",
    "    def __init__(self, init_nodes, triples, unknown_entity, all_clues):\n",
    "\n",
    "        ### global\n",
    "        self.expected_structure = triples\n",
    "        self.initial_nodes = {}\n",
    "        for node in init_nodes:\n",
    "            self.initial_nodes[node.clue] = self.initial_nodes.get(node.clue, []) + [node]\n",
    "        self.unexplored_initial_clues = set(list(self.initial_nodes.keys())[1:])\n",
    "        self.unknown_entity = unknown_entity\n",
    "        self.all_clues = set(all_clues)\n",
    "        self.connection = {}\n",
    "        self.results = []\n",
    "\n",
    "        ### path exploration from the initial nodes\n",
    "        self.explored_clues = []\n",
    "        self.unfound_clues = []\n",
    "        self.left_clues = set(all_clues) - set([init_nodes[0].clue])\n",
    "        self.beyond_expectation = {}  ###{clue: [(r, former_id, score)]}\n",
    "        \n",
    "        ### one round exploration\n",
    "        self.current_clue2nodes = {init_nodes[0].clue: [init_nodes[0]]}\n",
    "        self.expected_clues = self.get_expected_clues([init_nodes[0].clue])\n",
    "\n",
    "        \n",
    "\n",
    "    ### call after each round of exploration\n",
    "    def early_stop_check(self):\n",
    "        sum_clues = set()\n",
    "        for couple in self.connection.keys():\n",
    "            for clue in couple:\n",
    "                sum_clues.add(clue)\n",
    "        if sum_clues != self.all_clues:\n",
    "            return False\n",
    "        return self.update_path() \n",
    "\n",
    "    ### call only when self.expected_clues == {}\n",
    "    def success_check(self):\n",
    "        return self.update_path()\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "class Node:\n",
    "    def __init__(self, name, id, clue, label=None, relations: dict[str, str] = None):\n",
    "        self.name = name\n",
    "        self.id = id\n",
    "        self.clue = clue\n",
    "        self.label = label\n",
    "        if relations is not None:\n",
    "            self.relations = relations ### {related_id: relation}\n",
    "        else:\n",
    "            self.relations = {}\n",
    "        self.related_id = []\n",
    "        self.relation = []\n",
    "        if relations is not None:\n",
    "            for related_id, relation in relations.items():\n",
    "                self.related_id.append(related_id)\n",
    "                self.relation.append(relation)\n",
    "        \n",
    "\n",
    "\n",
    "class NodeChain:\n",
    "    def __init__(self, init_node, pairs, set_clue=None):\n",
    "        self.batches = [[init_node]]  # 存储所有批次的节点\n",
    "        self.clues = [init_node.clue]    # 存储每个批次的 clue\n",
    "        self.pairs = pairs    # 存储头尾实体的关联关系\n",
    "        self.checked_batches = set()  # 缓存已经检查过的批次索引 \n",
    "        self.explored_clues = set(init_node.clue)  # 表明这个clue我已经找过了\n",
    "        self.next_clue, self.current_clue, self.current_pair = self._get_next_clue(set_clue=set_clue)\n",
    "        self.potential_clues = self._get_potential_clues()  # 初始化 potential_clues\n",
    "        \n",
    "\n",
    "    def add_batch(self, new_nodes):\n",
    "        if not new_nodes:\n",
    "            self.explored_clues.add(self.next_clue)\n",
    "            self.next_clue, self.current_clue, self.current_pair = self._get_next_clue()\n",
    "            return  # 如果新批次为空，直接返回\n",
    "\n",
    "        # 检查新批次的 clue 是否与之前批次的 clue 重复\n",
    "        new_clue = new_nodes[0].clue  # 同一批次的节点有同一个 clue\n",
    "        if new_clue in self.clues:\n",
    "            raise ValueError(f\"Clue '{new_clue}' already exists in a previous batch.\")\n",
    "        \n",
    "        self.explored_clues.add(self.next_clue)\n",
    "        # 将新批次节点加入总批次列表\n",
    "        self.batches.append(new_nodes)\n",
    "        self.clues.append(new_clue)\n",
    "\n",
    "        # 重置缓存，因为新批次可能引入新的关联关系\n",
    "        self.checked_batches = set()\n",
    "\n",
    "        # 检查并删除无法关联到最新批次的节点\n",
    "        self._prune_unrelated_nodes()\n",
    "\n",
    "        self.next_clue, self.current_clue, self.current_pair = self._get_next_clue()\n",
    "        \n",
    "        # self.potential_clues = self._get_potential_clues()\n",
    "\n",
    "    \n",
    "\n",
    "    def generate_triples(self):\n",
    "        triples = set()\n",
    "        # 遍历所有 clue 对\n",
    "        print('generate triples------------')\n",
    "        for clue1, clue2 in self.pairs:\n",
    "            print('pairs:', clue1, clue2)\n",
    "            # 找到与 clue1 对应的批次\n",
    "            batch1 = self._find_batch_by_clue(clue1)\n",
    "            if batch1 is None:\n",
    "                continue\n",
    "            # 找到与 clue2 对应的批次\n",
    "            batch2 = self._find_batch_by_clue(clue2)\n",
    "            if batch2 is None:\n",
    "                continue\n",
    "            # 遍历 batch1 和 batch2 中的节点，生成三元组\n",
    "            for node1 in batch1:\n",
    "                for node2 in batch2:\n",
    "                    # 检查 node1.id 是否在 node2.relations 中\n",
    "                    if node1.id in node2.relations:\n",
    "                        relation = node2.relations[node1.id]\n",
    "                        triples.add((node1.name, relation, node2.name))\n",
    "                    # 检查 node2.id 是否在 node1.relations 中\n",
    "                    elif node2.id in node1.relations:\n",
    "                        relation = node1.relations[node2.id]\n",
    "                        triples.add((node2.name, relation, node1.name))\n",
    "        return list(triples)\n",
    "\n",
    "    def _get_next_clue(self, set_clue=None):\n",
    "        \"\"\"获取下一个要探索的 clue\"\"\"\n",
    "        # 检查当前批次的 clue 是否在 pairs 中作为头或尾实体\n",
    "        for pair in self.pairs:\n",
    "            if set_clue is not None:\n",
    "                if set_clue not in pair:\n",
    "                    continue\n",
    "            if pair[0] in self.clues and pair[1] not in self.clues and pair[1] not in self.explored_clues:\n",
    "                return pair[1], pair[0], pair\n",
    "            elif pair[1] in self.clues and pair[0] not in self.clues and pair[0] not in self.explored_clues:\n",
    "                return pair[0], pair[1], pair\n",
    "        return None, None, None\n",
    "    \n",
    "    def _get_potential_clues(self):\n",
    "        \"\"\"获取与 next_clue 相关联且不在 clues 中的潜在线索\"\"\"\n",
    "        potential_clues = set()\n",
    "        # if self.next_clue is not None:\n",
    "        #     for pair in self.pairs:\n",
    "        #         if pair[0] == self.next_clue and pair[1] not in self.clues:\n",
    "        #             potential_clues.add(pair[1])\n",
    "        #         elif pair[1] == self.next_clue and pair[0] not in self.clues:\n",
    "        #             potential_clues.add(pair[0])\n",
    "        if self.next_clue is not None:\n",
    "            all_clues = set()\n",
    "            for pair in self.pairs:\n",
    "                all_clues.add(pair[0])\n",
    "                all_clues.add(pair[1])\n",
    "            potential_clues = all_clues - set(self.clues) - set([self.next_clue])\n",
    "        return list(potential_clues)  # 返回一个列表\n",
    "    \n",
    "    def check_multi_branch(self):\n",
    "        clues = set()\n",
    "        clue2nodes = {}\n",
    "        for pair in self.pairs:\n",
    "            if self.next_clue == pair[0] and pair[1] in self.clues:\n",
    "                clues.add(pair[1])\n",
    "            elif self.next_clue == pair[1] and pair[0] in self.clues:\n",
    "                clues.add(pair[0])\n",
    "        if len(clues) > 1:\n",
    "            for clue in clues:\n",
    "                clue2nodes[clue] = copy.deepcopy(self._find_batch_by_clue(clue))\n",
    "            return clue2nodes\n",
    "        return None\n",
    "            \n",
    "\n",
    "\n",
    "    def _prune_unrelated_nodes(self):\n",
    "        \"\"\"从后往前检查每一批次节点，删除缺失了关联的节点\"\"\"\n",
    "        if not self.pairs:\n",
    "            return  # 如果没有关联关系，无需处理\n",
    "\n",
    "        # 从后往前遍历批次\n",
    "        for i in range(len(self.batches) - 1, -1, -1):\n",
    "            if i in self.checked_batches:\n",
    "                continue  # 如果批次已经检查过，跳过\n",
    "\n",
    "            batch = self.batches[i]\n",
    "            clue = self.clues[i]\n",
    "\n",
    "            # 检查当前批次的 clue 是否在 pairs 中作为头或尾实体\n",
    "            related_batches = []\n",
    "            for pair in self.pairs:\n",
    "                if clue == pair[0]:  # clue 是头实体\n",
    "                    related_batches.append(self._find_batch_by_clue(pair[1]))\n",
    "                elif clue == pair[1]:  # clue 是尾实体\n",
    "                    related_batches.append(self._find_batch_by_clue(pair[0]))\n",
    "\n",
    "            # 如果没有关联的批次，则删除当前批次\n",
    "            if not related_batches:\n",
    "                # print(f\"Removing batch with clue '{clue}' due to missing association.\")\n",
    "                # del self.batches[i]\n",
    "                # del self.clues[i]\n",
    "                print('no related batch, clue:', clue)\n",
    "                continue\n",
    "\n",
    "            # 检查当前批次与关联批次之间的节点是否通过 id 和 related_id 关联\n",
    "            for related_batch in related_batches:\n",
    "                if related_batch is not None:\n",
    "                    self._check_and_prune_nodes(batch, related_batch)\n",
    "\n",
    "            # 标记当前批次为已检查\n",
    "            self.checked_batches.add(i)\n",
    "\n",
    "    def _find_batch_by_clue(self, clue):\n",
    "        \"\"\"根据 clue 查找对应的批次\"\"\"\n",
    "        for i, c in enumerate(self.clues):\n",
    "            if c == clue:\n",
    "                return self.batches[i]\n",
    "        return None\n",
    "\n",
    "    def _check_and_prune_nodes(self, batch1, batch2):\n",
    "        \"\"\"检查两个批次之间的节点是否通过 id 和 related_id 关联，删除无关联的节点\"\"\"\n",
    "        # 遍历 batch1 中的节点\n",
    "        # for node1 in batch1[:]:  # 使用切片复制以避免修改迭代中的列表\n",
    "        #     has_relation = False\n",
    "        #     for node2 in batch2:\n",
    "        #         if node2.id in node1.related_id or node1.id in node2.related_id:\n",
    "        #             has_relation = True\n",
    "        #             break\n",
    "\n",
    "        #     # 如果 node1 与 batch2 中的任何节点都没有关联，则删除 node1\n",
    "        #     if not has_relation:\n",
    "        #         print(f\"Removing node {node1} due to missing association.\")\n",
    "        #         batch1.remove(node1)\n",
    "\n",
    "        # 遍历 batch2 中的节点\n",
    "        clue = batch2[0].clue\n",
    "        for node2 in batch2[:]:  # 使用切片复制以避免修改迭代中的列表\n",
    "            has_relation = False\n",
    "            for node1 in batch1:\n",
    "                if node2.id in node1.related_id or node1.id in node2.related_id:\n",
    "                    has_relation = True\n",
    "                    break\n",
    "\n",
    "            # 如果 node2 与 batch1 中的任何节点都没有关联，则删除 node2\n",
    "            if not has_relation:\n",
    "                print(f\"Removing node {node2} due to missing association.\")\n",
    "                batch2.remove(node2)\n",
    "        \n",
    "        if batch2 == []:\n",
    "            print(f'all nodes of {clue} are deleted!!!!!!')\n",
    "            batch2_index = self.batches.index(batch2)\n",
    "            del self.batches[batch2_index]\n",
    "            del self.clues[batch2_index]\n",
    "            return\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"NodeChain(batches={self.batches}, clues={self.clues}, pairs={self.pairs})\"\n",
    "    \n",
    "class Exploration:\n",
    "    def __init__(self, init_nodes, pairs, triples, target, generic, statement):\n",
    "        self.init_clue2node = {node.clue: node for node in init_nodes}  # 初始化 clue 到 node 的映射\n",
    "        self.chain = NodeChain(init_nodes[0], pairs)  # 以第一个 init_node 作为起始，创建主要 chain\n",
    "        self.explored_init = [init_nodes[0].clue]  # 存储已经探索过的 init_clue\n",
    "        self.pairs = pairs  # 存储头尾实体的关联关系\n",
    "        self.triples = triples\n",
    "        self.target = target  # 目标 clue\n",
    "        # self.explored_clues = set()\n",
    "        self.generic = generic\n",
    "        self.statement = statement\n",
    "        self.llm_call = 0\n",
    "        self.next_flag = 0     ## 弃用，本来是用于查看映射是否指向next_clue\n",
    "        self.exceed = 0\n",
    "\n",
    "    def explore(self):\n",
    "            \"\"\"探索 chain 的下一个 clue\n",
    "            - 如果 next_clue 在 init_clue 中，调用 intersection 方法，如果没有交集，考虑1. 结构错误 2. init_node是抽象节点，如person\n",
    "            - 如果 next_clue 与 init_clue 在 pairs 中相连，调用 intersection 方法，如果没有交集，同上\n",
    "            - 以上是与init_clue夹逼，最后判断非init_clue，chain自身分支夹逼，如果没有交集，考虑 结构错误\n",
    "            - 无交集 则调用 llm_prune 方法\n",
    "            - 情况二和三，如果交集大于1，还应llm_prune\n",
    "           \n",
    "            \"\"\"\n",
    "            input_token = 0\n",
    "            output_token = 0\n",
    "            total_token = 0\n",
    "            next_clue = self.chain.next_clue\n",
    "            current_clue = self.chain.current_clue\n",
    "            current_batch = self.chain._find_batch_by_clue(current_clue)\n",
    "            if next_clue in self.explored_init:\n",
    "                return [] , {'input': input_token, 'output': output_token, 'total': total_token} # 如果 clue 已经被探索过，则停止\n",
    "\n",
    "            # 情况1: 如果 next_clue 是 init_clue 中的另外一个\n",
    "            if next_clue in self.init_clue2node.keys() and next_clue not in self.explored_init:\n",
    "                print('------  next_clue is init_clue  ------')\n",
    "                # 调用 intersection 方法\n",
    "                new_batch, relations_return = self.intersection([current_batch], next_clue)\n",
    "                if new_batch:\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                else:\n",
    "                    new_batch, token_num = self.llm_prune(self.chain)  ###  无论如何，将next_clue加入explored_clues，不用判断是否potential\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "\n",
    "            # 情况2: 如果 next_clue 与 init_clue（不在 chain.clues 里）在 pairs 中相连\n",
    "            for init_node in self.init_clue2node.values():\n",
    "                if init_node.clue not in self.chain.clues and init_node.clue not in self.chain.explored_clues and init_node.clue not in self.explored_init and ((next_clue, init_node.clue) in self.pairs or (init_node.clue, next_clue) in self.pairs):\n",
    "                    print('------  next_clue and init_clue are in pairs  ------')\n",
    "                    batches = [current_batch, [init_node]]\n",
    "                    new_batch, relations_return = self.intersection(batches, next_clue)\n",
    "                    batch1 = []\n",
    "                    batch2 = []\n",
    "                    if new_batch:\n",
    "                        ## 如果任一边的关系数大于1，还得剪枝一次\n",
    "                        r_current = []\n",
    "                        r_init = []\n",
    "                        current_ids = [node.id for node in current_batch]\n",
    "                        for relation in relations_return:\n",
    "                            if relation[1] == init_node.id:\n",
    "                                r_init.append(relation)\n",
    "                            elif relation[1] in current_ids:\n",
    "                                r_current.append(relation)\n",
    "                        if len(r_current) > 1:\n",
    "                            batch1, token_num = self.llm_prune(current_batch, r_current, self.chain)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                            batch1, token_num = self.check_unnamed_entity(batch1)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        else:\n",
    "                            batch1 = new_batch.copy()\n",
    "                        if len(r_init) > 1:\n",
    "                            batch2, token_num = self.llm_prune([init_node], r_init, self.chain)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                            batch2, token_num = self.check_unnamed_entity(batch2)\n",
    "                            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        else:\n",
    "                            batch2 = new_batch.copy()\n",
    "                        new_batch = []\n",
    "                        for node in batch1:\n",
    "                            for node2 in batch2:\n",
    "                                if node.id == node2.id:\n",
    "                                    new_batch.append(node)\n",
    "               \n",
    "                    \n",
    "                    # 不存在交集，则开始llm校验\n",
    "                    if new_batch == []:\n",
    "                        batch1, token_num = self.llm_prune(self.chain)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        batch1, token_num = self.check_unnamed_entity(batch1)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        tmp_chain = NodeChain(init_node, self.pairs, set_clue=next_clue)\n",
    "                        batch2, token_num = self.llm_prune(tmp_chain)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        batch2, token_num = self.check_unnamed_entity(batch2)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        str1 = self.chain.current_pair[0] + ' ' + batch1[0].relation[0] + ' ' + self.chain.current_pair[1]\n",
    "                        str2 = tmp_chain.current_pair[0] + ' ' + batch2[0].relation[0] + ' ' + tmp_chain.current_pair[1]\n",
    "                        choosen, token_num = self.choose_one(str1, str2, tmp_chain)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        if choosen == 1:\n",
    "                            new_batch = batch1\n",
    "                        elif choosen == 2:\n",
    "                            new_batch = batch2\n",
    "                        else:\n",
    "                            print('choose error')\n",
    "                            return [], {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                        \n",
    "                        \n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                    \n",
    "                    \n",
    "            \n",
    "            clue2nodes = self.chain.check_multi_branch()\n",
    "            if clue2nodes:\n",
    "                print('------  multi branch  ------')\n",
    "                batches = list(clue2nodes.values())\n",
    "                new_batch, relations_return = self.intersection(batches, next_clue)\n",
    "                if new_batch:\n",
    "                    if len(new_batch) == 1:\n",
    "                        return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                    else:\n",
    "                        new_batch, token_num = self.llm_prune(new_batch, relations_return, self.chain)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                        return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "                else:\n",
    "                    print('真的有需要合并的情况！！！！！！！！！！！！！！！！！！！！！！！！！！！！！！！')\n",
    "                    new_batch, token_num = self.llm_prune(self.chain)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "                    input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                    return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "            print('------  normal exploration  ------')\n",
    "            new_batch, token_num = self.llm_prune(self.chain)\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            new_batch, token_num = self.check_unnamed_entity(new_batch)\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "\n",
    "    def choose_one(self, batch1_clue_triples, batch2_clue_triples, tmp_chain):\n",
    "        choosen, token_num = choose(self.statement, batch1_clue_triples, batch2_clue_triples)\n",
    "        print('choosen:', choosen)\n",
    "        ### init是比较确定的，所以llm校验仍然不选择init，就放入explored_clues\n",
    "        ### 但如果选的是init，choose的另一分支可能出发点就错了，因此还可以再探索一遍\n",
    "        if choosen == '2':\n",
    "            self.explored_init.append(tmp_chain.current_clue)\n",
    "            self.chain = tmp_chain\n",
    "            return 2, token_num\n",
    "        elif choosen == '1':\n",
    "            self.chain.explored_clues.add(tmp_chain.current_clue)\n",
    "            self.explored_init.append(tmp_chain.current_clue)\n",
    "            return 1, token_num\n",
    "        else:\n",
    "            print('choose error')\n",
    "            return 0, token_num\n",
    "\n",
    "\n",
    "    def check_unnamed_entity(self, batch):\n",
    "        \"\"\"\n",
    "        检查是否存在未命名的实体\n",
    "        \"\"\"\n",
    "        print('check unnamed entity--------------')\n",
    "        related_nodes = []\n",
    "        new_batch = []\n",
    "        token_num = {'input': 0, 'output': 0, 'total': 0}\n",
    "        if batch == []:\n",
    "            return [], token_num\n",
    "        for node in batch:\n",
    "            if node.name == 'UnName_Entity':\n",
    "                tmp = get_related_nodes_by_id(node.id) ### (r, name, id)\n",
    "                tmp = [(t[0], t[1], t[2], node.relations) for t in tmp if t[2] not in node.related_id]\n",
    "                for t in tmp:\n",
    "                    if t[0] in node.relation:\n",
    "                        # print('unName_entity:', t[1])\n",
    "                        new_batch.append(Node(name=t[1], id=t[2], clue=node.clue, relations=t[3]))\n",
    "                    else:\n",
    "                        related_nodes.append(t)\n",
    "            else:\n",
    "                new_batch.append(node)\n",
    "                # print('named entity:', node.name)\n",
    "\n",
    "        # 进一步过滤 related_nodes\n",
    "        related_nodes = [node for node in related_nodes if node[2] not in [n.id for n in new_batch]]\n",
    "        print('related_nodes of UnNamed:', related_nodes)\n",
    "        \n",
    "        if related_nodes:    \n",
    "            UnName_relations = [node[0] for node in related_nodes]\n",
    "            UnName_relations = list(set(UnName_relations))\n",
    "            clue = batch[0].clue\n",
    "            matched_clues = []\n",
    "            pair = self.chain.current_pair\n",
    "            for triple in self.triples:\n",
    "                if pair[0] == triple[0] and pair[1] == triple[2]:\n",
    "                    if clue == pair[0]:\n",
    "                        matched_clues.append(f\"the target {clue}_{triple[1]}_{triple[2]}\")  # clue + \"_\" + triple[2]\n",
    "                    elif clue == pair[1]:\n",
    "                        matched_clues.append(f\"{triple[0]}_{triple[1]}_the target {clue}\")  # triple[0] + \"_\" + clue\n",
    "                    break\n",
    "            \n",
    "            UnName_relations_result, token_num = prune_node(matched_clues, UnName_relations, {})\n",
    "            UnName_relations_pruned = UnName_relations_result.get(matched_clues[0], [])\n",
    "            print('UnName_relations_pruned:', UnName_relations_pruned)\n",
    "            self.llm_call += 1\n",
    "            if UnName_relations_pruned != []:\n",
    "                for UnName_relation in UnName_relations_pruned:\n",
    "                    for node in related_nodes:\n",
    "                        if node[0] == UnName_relation[0]:\n",
    "                            print('unName_entity:', node[1])\n",
    "                            new_batch.append(Node(name=node[1], id=node[2], clue=clue, relations=node[3]))\n",
    "        return new_batch, token_num\n",
    "\n",
    "    def llm_prune(self, *args):\n",
    "        \"\"\"\n",
    "        根据输入参数的数量执行不同的操作逻辑。\n",
    "        - 如果传入一个参数（chains），说明是无限制探索，返回值应该是batch_of_nodes，记得把next_clue加入explored_clues\n",
    "        - 如果传入三个参数，则要基于所给节点进行剪枝\n",
    "        \"\"\"\n",
    "        relations_info = []\n",
    "        batch = []\n",
    "        input_token = 0\n",
    "        output_token = 0\n",
    "        total_token = 0\n",
    "        if len(args) == 1:\n",
    "            # 逻辑 A：传入 chains\n",
    "            chain = args[0]\n",
    "            batch = chain._find_batch_by_clue(chain.current_clue)\n",
    "            \n",
    "            for node in batch:\n",
    "                tmp = get_relations(node.id)   # [(r, id)]\n",
    "                print('current_pair', chain.current_pair)\n",
    "                print('get relations of ', node.name)\n",
    "                if node.related_id != None: ## not the initial node\n",
    "                    tmp = [t for t in tmp if t[0] != node.relation]\n",
    "                relations_info.extend(tmp) # [(r, id)]\n",
    "                relations_info = list(set(relations_info))\n",
    "        elif len(args) == 3:\n",
    "            # 逻辑 B：传入 batches_of_nodes 和 relations\n",
    "            batch, relations_info, chain = args\n",
    "        else:\n",
    "            raise ValueError(\"Invalid number of arguments. Expected 1 or 2 arguments.\")\n",
    "             \n",
    "        relations = [info[0] for info in relations_info]\n",
    "        if relations == []:\n",
    "            print(chain.current_clue, 'has no relations')\n",
    "            return [], {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "        relations = list(set(relations))\n",
    "        candidate_clues = []\n",
    "        if len(args) == 3:\n",
    "            candidate_clues = [chain.next_clue]\n",
    "        else:\n",
    "            candidate_clues = [chain.next_clue]\n",
    "            # candidate_clues = [chain.next_clue] + chain.potential_clues\n",
    "        matched_clues = []\n",
    "        \n",
    "       \n",
    "\n",
    "        \"\"\"\n",
    "        先不考虑结构会出错\n",
    "        \"\"\"\n",
    "        clue = chain.next_clue\n",
    "        pair = chain.current_pair\n",
    "        for triple in self.triples:\n",
    "            if pair[0] == triple[0] and pair[1] == triple[2]:\n",
    "                if clue == pair[0]:\n",
    "                    matched_clues.append(f\"the target {clue}_{triple[1]}_{triple[2]}\")  # clue + \"_\" + triple[2]\n",
    "                elif clue == pair[1]:\n",
    "                    matched_clues.append(f\"{triple[0]}_{triple[1]}_the target {clue}\")  # triple[0] + \"_\" + clue\n",
    "                break\n",
    "        \n",
    "        print('relations:', relations)\n",
    "        pruned_relations = {}\n",
    "        prune_once_more = 0 ### 如果多次剪枝了，需要再综合判断一次\n",
    "        if len(relations)/20 + self.llm_call > 30:\n",
    "            self.exceed = 1\n",
    "            return [], {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "        while len(relations) > 0:\n",
    "            prune_flag = 0\n",
    "            if len(relations) > 20:\n",
    "                prune_once_more = 1\n",
    "                relations_for_match = relations[:20]\n",
    "                relations = relations[20:]\n",
    "            else:\n",
    "                relations_for_match = relations\n",
    "                prune_flag = 1                  \n",
    "            pruned_relations, token_num = prune_node(matched_clues, relations_for_match, pruned_relations) # {clue: [(r, label, score)]}\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            self.llm_call += 1\n",
    "            print('pruned_info:', pruned_relations)\n",
    "\n",
    "            if prune_flag == 1:\n",
    "                break\n",
    "        if pruned_relations == {}:\n",
    "            print('no relations can be matched')\n",
    "            return [], {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        if prune_once_more == 1:\n",
    "            relations_for_match = []\n",
    "            matched_clues = []\n",
    "            for clue in pruned_relations.keys():\n",
    "                if 'the target ' + chain.next_clue in clue:\n",
    "                # 如果 next_clue 在 pruned_relations 中，取得分前三的 relation\n",
    "                    relations_for_match = pruned_relations[clue]\n",
    "                    relations_for_match = [rel[0] for rel in relations_for_match]\n",
    "                    matched_clues.append(clue)\n",
    "                    break\n",
    "            # if relations_for_match == []:\n",
    "            #     # 如果 next_clue 不在，取得分最高的 relation 以及同一个 clue 且得分仅次于它的两个 relation\n",
    "            #     self.next_flag = 1\n",
    "            #     all_relations = []\n",
    "            #     for clue, relations in pruned_relations.items():\n",
    "            #         all_relations.extend([(rel[0], clue, rel[2]) for rel in relations])\n",
    "                \n",
    "            #     # 按得分从高到低排序\n",
    "            #     sorted_all_relations = sorted(all_relations, key=lambda x: x[2], reverse=True)\n",
    "                \n",
    "            #     if not sorted_all_relations:\n",
    "            #         return []  # 如果没有 relations，返回空列表\n",
    "                \n",
    "            #     # 取得分最高的 relation\n",
    "            #     top_relation = sorted_all_relations[0] \n",
    "            #     top_clue = top_relation[1]\n",
    "            #     matched_clues.append(top_clue)\n",
    "                \n",
    "            #     # 找到同一个 clue 且得分仅次于它的 relation\n",
    "            #     relations_for_match = [rel[0] for rel in sorted_all_relations if rel[1] == top_clue]\n",
    "            print('select top relation')\n",
    "            pruned_relations = {}\n",
    "            if relations_for_match:\n",
    "                pruned_relations, token_num = prune_node(matched_clues, relations_for_match, pruned_relations) # {clue: [(r, label, score)]}\n",
    "                input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "                self.llm_call += 1\n",
    "            print('pruned_info:', pruned_relations)\n",
    "        top_relations = []\n",
    "        top_clue = ''\n",
    "        for clue_r, relations in pruned_relations.items():\n",
    "            clues = clue_r.split('_')\n",
    "            if len(clues) != 3:\n",
    "                print('strange len of clue_r:', clue_r)\n",
    "                continue\n",
    "            if 'the target' in clues[0]:\n",
    "                top_clue = clues[0].replace(\"the target \", \"\")\n",
    "                if top_clue in candidate_clues:\n",
    "                    max_value = max(relations, key=lambda x: x[2])[2]  # 先找到最大的x[2]值\n",
    "                    top_relations = [rel[0] for rel in relations if rel[2] == max_value] \n",
    "                else:\n",
    "                    print('strange clue[0]:', clue_r)\n",
    "            elif 'the target' in clues[2]:\n",
    "                top_clue = clues[2].replace(\"the target \", \"\")\n",
    "                if top_clue in candidate_clues:\n",
    "                    max_value = max(relations, key=lambda x: x[2])[2]  # 先找到最大的x[2]值\n",
    "                    top_relations = [rel[0] for rel in relations if rel[2] == max_value]\n",
    "                else:\n",
    "                    print('strange clue[2]:', clue_r)\n",
    "            else:\n",
    "                print('strange clue:', clue_r)\n",
    "\n",
    "\n",
    "        new_batch = []\n",
    "        if len(args) == 3:\n",
    "            for node in batch:\n",
    "                for top_relation in top_relations:\n",
    "                    if top_relation in node.relation:\n",
    "                        new_batch.append(node)\n",
    "        else:\n",
    "                for info in relations_info:\n",
    "                    if info[0] in top_relations:\n",
    "                        print('get related nodes of ', info[1], info[0])\n",
    "                        try:\n",
    "                            nodes = get_related_nodes(info[1], info[0]) # [(label, name, id, relation, former_id)]\n",
    "                        except Exception as e:\n",
    "                            print('error:', e)\n",
    "                            continue\n",
    "                        for node in nodes:\n",
    "                            new_batch.append(Node(name=node[1], id=node[2], clue=top_clue, label=node[0], relations={node[4]: node[3]}))\n",
    "        return new_batch, {'input': input_token, 'output': output_token, 'total': total_token}\n",
    "        \n",
    "\n",
    "    def intersection(self, batches, next_clue):\n",
    "        \"\"\"\n",
    "        Intersect one or more exploration branches on a commonly related node.\n",
    "\n",
    "        Args:\n",
    "            batches: list of branches, each a list of Node objects.\n",
    "            next_clue: clue attached to every Node created for the intersection.\n",
    "\n",
    "        Returns:\n",
    "            A ``(new_batch, relations)`` pair. ``new_batch`` lists the Nodes linked\n",
    "            to at least one node of every branch; ``relations`` is a set of\n",
    "            ``(relation, original_node_id, related_node_id)`` tuples, or ``None``\n",
    "            when ``next_clue`` is an initial clue. ``(None, None)`` is returned when\n",
    "            an initial node is not linked to some branch, or when one of the size\n",
    "            checks below trips (the printed message names llm_prune as the fallback).\n",
    "        \"\"\"\n",
    "        new_batch = []\n",
    "        relations_return = set()\n",
    "\n",
    "        batch_ids = [set(node.id for node in batch) for batch in batches]\n",
    "        if next_clue in self.init_clue2node.keys():\n",
    "            # next_clue is an already resolved initial node: keep it only when it is\n",
    "            # linked to at least one node of every branch.\n",
    "            related_nodes_in_batches = []\n",
    "            init_node = self.init_clue2node[next_clue]\n",
    "            related_nodes = get_related_nodes_by_id(init_node.id)  # (r, name, id)\n",
    "\n",
    "            for batch_id in batch_ids:\n",
    "                related_nodes_in_batch = [node for node in related_nodes if node[2] in batch_id]\n",
    "                if related_nodes_in_batch == []:\n",
    "                    print(f'clue = {next_clue} has no related nodes in former batch')\n",
    "                    return None, None\n",
    "                related_nodes_in_batches.append(related_nodes_in_batch)\n",
    "            # Map former node id -> relation. node[0] is the relation; node[1] is the\n",
    "            # entity name, which the non-initial branch below never stores here.\n",
    "            relations = {\n",
    "                node[2]: node[0]\n",
    "                for nodes in related_nodes_in_batches\n",
    "                for node in nodes\n",
    "            }\n",
    "            new_batch.append(Node(name=init_node.name, id=init_node.id, clue=init_node.clue, relations=relations))\n",
    "            return new_batch, None\n",
    "\n",
    "        # Expand from the branch with the smallest total count_nodes() value.\n",
    "        base_batch = batches[0]\n",
    "        min_num = 10000\n",
    "        for batch in batches:\n",
    "            entities_num = 0\n",
    "            for node in batch:\n",
    "                entities_num += int(count_nodes(node.id))\n",
    "            if entities_num < min_num:\n",
    "                min_num = entities_num\n",
    "                base_batch = batch\n",
    "\n",
    "        if min_num > 5000:\n",
    "            print('batch size is too large, use llm_prune')\n",
    "            return None, None\n",
    "\n",
    "        related_nodes_cache = {}\n",
    "\n",
    "        def get_related_nodes_cached(node_id, name, related_id):\n",
    "            # NOTE: the cache is keyed by node_id only, so the related_id filter of the\n",
    "            # first call for a given id is the one that gets cached.\n",
    "            if node_id not in related_nodes_cache:\n",
    "                # Fetch the neighbours and drop the ones listed in related_id.\n",
    "                print('intersection: get related nodes of ', name, node_id)\n",
    "                neighbours = [\n",
    "                    neighbour for neighbour in get_related_nodes_by_id(node_id)\n",
    "                    if neighbour[2] not in related_id\n",
    "                ]\n",
    "                rst = []\n",
    "                for r_node in neighbours:\n",
    "                    if r_node[1] == 'UnName_Entity':\n",
    "                        # Hop through unnamed entities: use their neighbours instead,\n",
    "                        # excluding the node we came from.\n",
    "                        rst.extend(\n",
    "                            hop for hop in get_related_nodes_by_id(r_node[2])\n",
    "                            if hop[2] != node_id\n",
    "                        )\n",
    "                    else:\n",
    "                        rst.append(r_node)\n",
    "                related_nodes_cache[node_id] = rst\n",
    "            return related_nodes_cache[node_id]\n",
    "\n",
    "        # Collect the neighbours of every node in the base branch as (name, id).\n",
    "        related_nodes = set()\n",
    "        for node in base_batch:\n",
    "            for related_node in get_related_nodes_cached(node.id, node.name, node.related_id):\n",
    "                related_nodes.add((related_node[1], related_node[2]))\n",
    "\n",
    "        # Give up as soon as a single candidate exceeds the per-node limit.\n",
    "        for node in related_nodes:\n",
    "            if int(count_nodes(node[1])) > 10000:\n",
    "                print('batch size is too large, use llm_prune')\n",
    "                return None, None\n",
    "\n",
    "        for related_node in related_nodes:\n",
    "            nodes = get_related_nodes_cached(related_node[1], related_node[0], [])  # (r, name, id)\n",
    "\n",
    "            # A candidate must touch every branch; remember the first matching\n",
    "            # neighbour per branch as {original_node_id: relation}.\n",
    "            matched_nodes = {}\n",
    "            is_candidate = True\n",
    "            for id_set in batch_ids:\n",
    "                match = next((node for node in nodes if node[2] in id_set), None)\n",
    "                if match is None:\n",
    "                    is_candidate = False\n",
    "                    break\n",
    "                matched_nodes[match[2]] = match[0]\n",
    "            if not is_candidate:\n",
    "                continue\n",
    "\n",
    "            for original_id, r in matched_nodes.items():\n",
    "                relations_return.add((r, original_id, related_node[1]))  # (r, original_node_id, related_node_id)\n",
    "\n",
    "            # Build the Node that represents this intersection candidate.\n",
    "            new_batch.append(Node(\n",
    "                name=related_node[0],\n",
    "                id=related_node[1],\n",
    "                clue=next_clue,\n",
    "                relations=matched_nodes\n",
    "            ))\n",
    "\n",
    "        print('----------------- intersection ends ---------------')\n",
    "        return new_batch, relations_return\n",
    "                    \n",
    "def add_tokens(input_tok, output_tok, total_tok, token_num):\n",
    "    \"\"\"Return the running token totals after adding one call's usage.\n",
    "\n",
    "    Args:\n",
    "        input_tok: running input-token total.\n",
    "        output_tok: running output-token total.\n",
    "        total_tok: running overall-token total.\n",
    "        token_num: mapping with 'input', 'output' and 'total' entries.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(input_tok, output_tok, total_tok)`` with the usage added.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        input_tok + token_num['input'],\n",
    "        output_tok + token_num['output'],\n",
    "        total_tok + token_num['total'],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1433c3f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import time\n",
    "from collections import Counter\n",
    "\n",
    "data_path = ''\n",
    "datas = prepare_dataset(data_path)\n",
    "\n",
    "# Driver loop: for every question, structure it with the LLM, explore from the\n",
    "# matched initial entities, answer from the collected triples, and append one\n",
    "# JSON record per question to the output file.\n",
    "count = 0\n",
    "for i, data in enumerate(datas):\n",
    "    start_time = time.time()\n",
    "    input_token = 0\n",
    "    output_token = 0\n",
    "    total_token = 0\n",
    "    # Reset per question: if answer_question() fails below, the record must not\n",
    "    # reuse the previous question's answer (or hit a NameError on the first one).\n",
    "    answer_w_triples = ''\n",
    "    print('the', i, 'question')\n",
    "    llm_call = 2  # presumably the complete_info + structure calls below -- TODO confirm\n",
    "    print('data:', data)\n",
    "    question = data['question']\n",
    "\n",
    "    try:\n",
    "        statement, specific, generic, token_num = complete_info(question)\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "        triples, unknown_entity, token_num = structure(statement, generic, specific)\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        continue\n",
    "\n",
    "    # Collect every entity of the triples as a lower-cased, de-duplicated clue.\n",
    "    entity_counter = Counter()\n",
    "    all_clues = []\n",
    "    try:\n",
    "        for head, relation, tail in triples:\n",
    "            entity_counter[head] += 1\n",
    "            entity_counter[tail] += 1\n",
    "            for entity in (head, tail):\n",
    "                lowered = entity.lower()\n",
    "                # all_clues only stores lower-cased names, so test the lower-cased form.\n",
    "                if lowered not in all_clues:\n",
    "                    all_clues.append(lowered)\n",
    "    except Exception as e:\n",
    "        # Malformed triples (wrong arity, non-string entities): skip this question.\n",
    "        print(e)\n",
    "        continue\n",
    "    ###\n",
    "    ### If the triples are not fully connected, check whether the tail entity of one\n",
    "    ### triple appears inside the head-entity string of another (not implemented here).\n",
    "    ###\n",
    "    # Entities that occur in exactly one triple (not used further in this cell).\n",
    "    entities_once = [entity for entity, occurrences in entity_counter.items() if occurrences == 1]\n",
    "    # all_clues is lower-cased, so compare the lower-cased unknown entity.\n",
    "    if str(unknown_entity).lower() not in all_clues:\n",
    "        print('unknown_entity not in all_clues')\n",
    "\n",
    "    print('statement:', statement)\n",
    "    print('specific:', specific)\n",
    "    print('generic:', generic)\n",
    "    print('unknown_entity:', unknown_entity)\n",
    "    print('triples:', triples)\n",
    "    initial_nodes = []\n",
    "    print('all_clues:', all_clues)\n",
    "    for clue in all_clues:\n",
    "        if clue.lower() in specific:\n",
    "            # Strip one pair of surrounding quotes before the name lookup.\n",
    "            if re.match(r'^[\"\\'].*[\"\\']$', clue):\n",
    "                nodes = query_node_by_name(clue[1:-1])\n",
    "            else:\n",
    "                nodes = query_node_by_name(clue)  # [(name, id)]\n",
    "            if len(nodes) == 1:\n",
    "                node = nodes[0]\n",
    "                initial_nodes.append(Node(node[0], node[1], clue))\n",
    "            elif len(nodes) > 1:\n",
    "                # Freebase often has many same-named entities (songs, books, ...):\n",
    "                # keep the first one whose id is among the dataset's topic entities.\n",
    "                topic_entity = data['topic_entity']\n",
    "                for node in nodes:\n",
    "                    if node[1] in topic_entity.keys():\n",
    "                        print(node[0], node[1])\n",
    "                        initial_nodes.append(Node(node[0], node[1], clue))\n",
    "                        break\n",
    "\n",
    "    cot_answer = False\n",
    "    if initial_nodes != []:\n",
    "        pairs = [(head, tail) for (head, _, tail) in triples]\n",
    "        exploration = Exploration(initial_nodes, pairs, triples, unknown_entity, generic, statement)\n",
    "        print('init clue:', exploration.chain.clues)\n",
    "        print('exploration begins......')\n",
    "        # Keep exploring while a next_clue remains and the call limit is not exceeded.\n",
    "        while True:\n",
    "            if exploration.exceed == 1:\n",
    "                print('exceed the limit')\n",
    "                break\n",
    "            if exploration.chain.next_clue is None:\n",
    "                break\n",
    "\n",
    "            batch, token_num = exploration.explore()\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            ### Per the original note: when batch is empty, add_batch records next_clue\n",
    "            ### as explored and moves on to a new next_clue.\n",
    "            exploration.chain.add_batch(batch)\n",
    "            print('clues done:', exploration.chain.clues)\n",
    "\n",
    "        llm_call += exploration.llm_call\n",
    "        result_triplets = exploration.chain.generate_triples()\n",
    "        # Cap the evidence handed to the answering prompt at 30 triples.\n",
    "        if len(result_triplets) > 30:\n",
    "            result_triplets = result_triplets[:30]\n",
    "        next_flag = exploration.next_flag\n",
    "        print('next_flag:', next_flag)\n",
    "        print('result_triplets:', result_triplets)\n",
    "        success_flag = 0\n",
    "\n",
    "        try:\n",
    "            ### Always answer on the basis of result_triplets.\n",
    "            answer_w_triples, token_num = answer_question(question, result_triplets, success_flag)\n",
    "            input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "            llm_call += 1\n",
    "        except Exception as e:\n",
    "            # answer_w_triples keeps its per-question default ('') in this case.\n",
    "            print(e)\n",
    "    else:\n",
    "        # No entity could be grounded: fall back to a plain step-by-step LLM answer.\n",
    "        print('no initial nodes')\n",
    "        result_triplets = []\n",
    "        answer_w_triples, token_num = chat('please answer the question step by step: ' + question)\n",
    "        input_token, output_token, total_token = add_tokens(input_token, output_token, total_token, token_num)\n",
    "        cot_answer = True\n",
    "        llm_call += 1\n",
    "    print('input_token:', input_token)\n",
    "    print('output_token:', output_token)\n",
    "    print('total_token:', total_token)\n",
    "    time_used = time.time() - start_time\n",
    "    result = {\n",
    "            'data': data,\n",
    "            'structure': triples,\n",
    "            'answer_w_triples': answer_w_triples,\n",
    "\n",
    "            'llm_call': llm_call,\n",
    "            'triples': result_triplets,\n",
    "            'input_token': input_token,\n",
    "            'output_token': output_token,\n",
    "            'total_token': total_token,\n",
    "            'time_used': time_used,\n",
    "            'cot_answer': cot_answer\n",
    "        }\n",
    "    # Append one JSON object per line (JSON Lines) to the output file.\n",
    "    with open('', 'a', encoding='utf-8') as file:\n",
    "            json.dump(result, file, ensure_ascii=False)\n",
    "            file.write('\\n')\n",
    "    count+=1\n",
    "    print(count, 'done')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myjupyter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
