{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G 中点的数目: 48126\n",
      "G 中来自 G1 的点的数目: 43005\n",
      "G 中来自 G2 的点的数目: 6221\n",
      "G 中来自 G2 的点中与 G1 有连边的点的数目: 5068\n",
      "新增点的数目5121\n",
      "新增点的数目与 G1 有连边的点的数目: 3968\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n# 可视化图\\nplt.figure(figsize=(12, 8))\\npos = nx.spring_layout(G, seed=42)  # 使用 spring 布局\\nnx.draw(G, pos, with_labels=True, node_size=300, node_color=color_map, font_size=10, font_weight=\"bold\")\\nplt.title(\"Merged Graph G with Different Edge Labels\")\\nplt.savefig(\"merged_graph.png\")  # 保存图形到本地文件\\nplt.close()  # 关闭图形\\n'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "import re\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 定义正则表达式模式\n",
    "patterns = {\n",
    "    'P17': r'^Who is (.+) married to\\?$',\n",
    "    'P19': r'^Where was (.+) born\\?$',\n",
    "    'P20': r'^Where did (.+) die\\?$',\n",
    "    'P26': r'^Who is (.+) married to\\?$',\n",
    "    'P36': r'^What is the capital of (.+)\\?$',\n",
    "    'P40': r'^Who is (.+)\\'s child\\?$',\n",
    "    'P50': r'^Who is the author of (.+)\\?$',\n",
    "    'P69': r'^Where was (.+) educated\\?$',\n",
    "    'P106': r'^What kind of work does (.+) do\\?$',\n",
    "    'P112': r'^Who founded (.+)\\?$',\n",
    "    'P127': r'^Who owns (.+)\\?$',\n",
    "    'P131': r'^Where is (.+) located\\?$',\n",
    "    'P136': r'^What type of music does (.+) play\\?$',\n",
    "    'P159': r'^Where is the headquarter of (.+)\\?$',\n",
    "    'P170': r'^Who was (.+) created by\\?$',\n",
    "    'P175': r'^Who performed (.+)\\?$',\n",
    "    'P176': r'^Which company is (.+) produced by\\?$',\n",
    "    'P264': r'^What music label is (.+) represented by\\?$',\n",
    "    'P276': r'^Where is (.+) located\\?$',\n",
    "    'P407': r'^Which language was (.+) written in\\?$',\n",
    "    'P413': r'^What position does (.+) play\\?$',\n",
    "    'P495': r'^Which country was (.+) created in\\?$',\n",
    "    'P740': r'^Where was (.+) founded\\?$',\n",
    "    'P800': r'^What is (.+) famous for\\?$',\n",
    "}\n",
    "\n",
    "def extract_replaceable_part(pattern_key, sentence):\n",
    "    pattern = patterns[pattern_key]\n",
    "    match = re.match(pattern, sentence)\n",
    "    if match:\n",
    "        return match.group(1)\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "# 创建图\n",
    "G1 = nx.Graph()\n",
    "\n",
    "# 读取 JSON Lines 文件并解析数据\n",
    "with open('', 'r') as f:\n",
    "    for line in f:\n",
    "        data = json.loads(line.strip())\n",
    "        question = data['messages'][1]['content']\n",
    "        answer = data['messages'][2]['content']\n",
    "    \n",
    "        # 匹配问题并添加边\n",
    "        for pattern_key in patterns.keys():\n",
    "            replaceable_part = extract_replaceable_part(pattern_key, question)\n",
    "            if replaceable_part:\n",
    "                if not G1.has_node(replaceable_part):\n",
    "                    G1.add_node(replaceable_part, label='origin')\n",
    "                if not G1.has_node(answer):\n",
    "                    G1.add_node(answer, label='origin')\n",
    "                G1.add_edge(replaceable_part, answer)\n",
    "                break\n",
    "\n",
    "G2 = nx.Graph()\n",
    "\n",
    "with open('','r') as f:\n",
    "    for line in f:\n",
    "        data = json.loads(line.strip())\n",
    "        question = data['messages'][1]['content']\n",
    "        answer = data['messages'][2]['content']\n",
    "        # 匹配问题并添加边\n",
    "        for pattern_key in patterns.keys():\n",
    "            replaceable_part = extract_replaceable_part(pattern_key, question)\n",
    "            if replaceable_part:\n",
    "                if not G2.has_node(replaceable_part):\n",
    "                    G2.add_node(replaceable_part, label='changed')\n",
    "                if not G2.has_node(answer):\n",
    "                    G2.add_node(answer, label='changed')\n",
    "                G2.add_edge(replaceable_part, answer)\n",
    "                break\n",
    "\n",
    "\n",
    "G = nx.Graph()\n",
    "\n",
    "# 合并图\n",
    "G = nx.Graph()\n",
    "for node, data in G1.nodes(data=True):\n",
    "    G.add_node(node, label=data['label'])\n",
    "for node, data in G2.nodes(data=True):\n",
    "    if not G.has_node(node):\n",
    "        G.add_node(node, label=data['label'])\n",
    "\n",
    "# 添加边\n",
    "for edge in G1.edges():\n",
    "    G.add_edge(*edge)\n",
    "for edge in G2.edges():\n",
    "    G.add_edge(*edge)\n",
    "\n",
    "    # 统计来自 G2 的节点数目\n",
    "g2_nodes_count = sum(1 for node, data in G.nodes(data=True) if data['label'] in ['changed', 'both'])\n",
    "g1_nodes_count = sum(1 for node, data in G.nodes(data=True) if data['label'] in ['origin', 'both'])\n",
    "g_nodes_count=sum(1 for node, data in G.nodes(data=True))\n",
    "g2_changed_count=sum(1 for node, data in G.nodes(data=True) if data['label'] in ['changed'])\n",
    "# 统计来自 G2 的节点中与 G1 有连边的节点数目\n",
    "g2_nodes_with_g1_edges_count = 0\n",
    "for node, data in G.nodes(data=True):\n",
    "    if data['label'] in ['changed', 'both']:\n",
    "        for neighbor in G.neighbors(node):\n",
    "            if G.nodes[neighbor]['label'] in ['origin', 'both']:\n",
    "                g2_nodes_with_g1_edges_count += 1\n",
    "                break\n",
    "\n",
    "g2_nodes_changed_with_g1_edges_count = 0\n",
    "for node, data in G.nodes(data=True):\n",
    "    if data['label'] in ['changed']:\n",
    "        for neighbor in G.neighbors(node):\n",
    "            if G.nodes[neighbor]['label'] in ['origin','both']:\n",
    "                g2_nodes_changed_with_g1_edges_count += 1\n",
    "                break\n",
    "\n",
    "print(f\"G 中点的数目: {g_nodes_count}\")\n",
    "print(f\"G 中来自 G1 的点的数目: {g1_nodes_count}\")\n",
    "print(f\"G 中来自 G2 的点的数目: {g2_nodes_count}\")\n",
    "print(f\"G 中来自 G2 的点中与 G1 有连边的点的数目: {g2_nodes_with_g1_edges_count}\")\n",
    "print(f\"新增点的数目{g2_changed_count}\")\n",
    "print(f\"新增点的数目与 G1 有连边的点的数目: {g2_nodes_changed_with_g1_edges_count}\")\n",
    "# 为不同标签的节点着色\n",
    "\n",
    "color_map = []\n",
    "for node in G:\n",
    "    if G.nodes[node]['label'] == 'origin':\n",
    "        color_map.append('red')\n",
    "    elif G.nodes[node]['label'] == 'changed':\n",
    "        color_map.append('blue')\n",
    "    elif G.nodes[node]['label'] == 'both':\n",
    "        color_map.append('purple')\n",
    "\n",
    "# 为不同标签的边着色\n",
    "'''\n",
    "# 可视化图\n",
    "plt.figure(figsize=(12, 8))\n",
    "pos = nx.spring_layout(G, seed=42)  # 使用 spring 布局\n",
    "nx.draw(G, pos, with_labels=True, node_size=300, node_color=color_map, font_size=10, font_weight=\"bold\")\n",
    "plt.title(\"Merged Graph G with Different Edge Labels\")\n",
    "plt.savefig(\"merged_graph.png\")  # 保存图形到本地文件\n",
    "plt.close()  # 关闭图形\n",
    "'''\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "度最大的十个节点:\n",
      "节点: United States of America, 度: 4436\n",
      "节点: midfielder, 度: 2195\n",
      "节点: politician, 度: 1432\n",
      "节点: actor, 度: 643\n",
      "节点: English, 度: 523\n",
      "节点: Los Angeles, 度: 423\n",
      "节点: London, 度: 407\n",
      "节点: New York City, 度: 267\n",
      "节点: Harvard University, 度: 224\n",
      "节点: jazz, 度: 223\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import re\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 定义正则表达式模式\n",
    "patterns = {\n",
    "    'P17': r'^Who is (.+) married to\\?$',\n",
    "    'P19': r'^Where was (.+) born\\?$',\n",
    "    'P20': r'^Where did (.+) die\\?$',\n",
    "    'P26': r'^Who is (.+) married to\\?$',\n",
    "    'P36': r'^What is the capital of (.+)\\?$',\n",
    "    'P40': r'^Who is (.+)\\'s child\\?$',\n",
    "    'P50': r'^Who is the author of (.+)\\?$',\n",
    "    'P69': r'^Where was (.+) educated\\?$',\n",
    "    'P106': r'^What kind of work does (.+) do\\?$',\n",
    "    'P112': r'^Who founded (.+)\\?$',\n",
    "    'P127': r'^Who owns (.+)\\?$',\n",
    "    'P131': r'^Where is (.+) located\\?$',\n",
    "    'P136': r'^What type of music does (.+) play\\?$',\n",
    "    'P159': r'^Where is the headquarter of (.+)\\?$',\n",
    "    'P170': r'^Who was (.+) created by\\?$',\n",
    "    'P175': r'^Who performed (.+)\\?$',\n",
    "    'P176': r'^Which company is (.+) produced by\\?$',\n",
    "    'P264': r'^What music label is (.+) represented by\\?$',\n",
    "    'P276': r'^Where is (.+) located\\?$',\n",
    "    'P407': r'^Which language was (.+) written in\\?$',\n",
    "    'P413': r'^What position does (.+) play\\?$',\n",
    "    'P495': r'^Which country was (.+) created in\\?$',\n",
    "    'P740': r'^Where was (.+) founded\\?$',\n",
    "    'P800': r'^What is (.+) famous for\\?$',\n",
    "}\n",
    "\n",
    "def extract_replaceable_part(pattern_key, sentence):\n",
    "    pattern = patterns[pattern_key]\n",
    "    match = re.match(pattern, sentence)\n",
    "    if match:\n",
    "        return match.group(1)\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "# 创建二部图\n",
    "B = nx.Graph()\n",
    "\n",
    "# 读取 JSON Lines 文件并解析数据\n",
    "def process_file(filename, label, bi=0):\n",
    "    with open(filename, 'r') as f:\n",
    "        for line in f:\n",
    "            data = json.loads(line.strip())\n",
    "            question = data['messages'][1]['content']\n",
    "            answer = data['messages'][2]['content']\n",
    "        \n",
    "            for pattern_key in patterns.keys():\n",
    "                replaceable_part = extract_replaceable_part(pattern_key, question)\n",
    "                if replaceable_part:\n",
    "                    if not B.has_node(replaceable_part):\n",
    "                        B.add_node(replaceable_part, bipartite=bi, label=label)  # 第一组节点\n",
    "                    if not B.has_node(answer):\n",
    "                        B.add_node(answer, bipartite=bi, label=label)  # 第二组节点\n",
    "                    B.add_edge(replaceable_part, answer)\n",
    "                    break\n",
    "\n",
    "# 处理两个文件\n",
    "process_file('', 'origin', 0)\n",
    "process_file('', 'changed', 1)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# 为不同标签的节点着色和大小\n",
    "color_map = []\n",
    "fixed_node_size = 50  # 固定节点大小\n",
    "for node in B:\n",
    "    if B.nodes[node]['label'] == 'origin':\n",
    "        color_map.append('red')\n",
    "    elif B.nodes[node]['label'] == 'changed':\n",
    "        color_map.append('blue')\n",
    "\n",
    "# 使用 bipartite_layout 布局算法\n",
    "top_nodes = {n for n, d in B.nodes(data=True) if d['bipartite'] == 1}\n",
    "pos = nx.bipartite_layout(B, top_nodes)\n",
    "\n",
    "# 绘制图形\n",
    "plt.figure(figsize=(10, 10))  # 调整图形大小\n",
    "nx.draw(\n",
    "    B, pos, \n",
    "    with_labels=False, \n",
    "    node_size=fixed_node_size,  # 固定节点大小\n",
    "    node_color=color_map, \n",
    "    width=0.2, \n",
    "    edge_color='lightgray',  # 边的颜色设置为浅灰色\n",
    "    alpha=0.5  # 增加透明度\n",
    ")\n",
    "\n",
    "# 添加图例\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], marker='o', color='w', label='Origin', markersize=20, markerfacecolor='red'),\n",
    "    Line2D([0], [0], marker='o', color='w', label='Changed', markersize=20, markerfacecolor='blue'),\n",
    "]\n",
    "plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.05, 1), fontsize=14)  # 将图例移动到左上角\n",
    "\n",
    "plt.title(\"Bipartite Graph of Questions and Answers\", fontsize=16)\n",
    "\n",
    "# 调整边距，以确保图例和图形都在页面范围内\n",
    "plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)\n",
    "\n",
    "# 保存为PDF格式的矢量图\n",
    "plt.savefig(\"bipartite_graph.pdf\", format='pdf', bbox_inches='tight')  # 使用bbox_inches='tight'来自动调整边距\n",
    "plt.close()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qwen-sft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
