{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-05-17T11:24:33.590565700Z",
     "start_time": "2024-05-17T11:24:29.372008500Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "nx_graph_loader.py\n",
    "\"\"\"\n",
    "import json\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "\n",
    "import networkx as nx\n",
    "from tqdm import tqdm\n",
    "\n",
    "# 初始化存储结构\n",
    "reviews_by_user = defaultdict(list)\n",
    "genres_by_book = {}\n",
    "authors_by_book = {}\n",
    "descriptions_by_book = {}\n",
    "\n",
    "# 先加载所有评论\n",
    "review_info_path = 'children_genre/raw/goodreads_reviews_children.json'\n",
    "with open(review_info_path, 'r') as f:\n",
    "    for line in f:\n",
    "        review = json.loads(line)\n",
    "        user_id = review['user_id']\n",
    "        reviews_by_user[user_id].append(review)\n",
    "\n",
    "# 然后确定前十个不同的user_id(scale up 时删除即可）\n",
    "user_counter = Counter([review['user_id'] for reviews in reviews_by_user.values() for review in reviews])\n",
    "\n",
    "# 处理题材数据\n",
    "genre_info_path = 'children_genre/raw/goodreads_book_genres_initial.json'\n",
    "with open(genre_info_path, 'r') as f:\n",
    "    for line in f:\n",
    "        genre_info = json.loads(line)\n",
    "        book_id = genre_info['book_id']\n",
    "        genres = genre_info['genres']\n",
    "        main_genre = max(genres, key=genres.get) if genres else None\n",
    "        if main_genre:\n",
    "            genres_by_book[book_id] = main_genre\n",
    "\n",
    "# 加载书籍信息并为选中的评论添加作者信息和描述信息\n",
    "book_info_path = 'children_genre/raw/goodreads_books_children.json'\n",
    "with open(book_info_path, 'r') as f:\n",
    "    for line in f:\n",
    "        book = json.loads(line)\n",
    "        book_id = book['book_id']\n",
    "        authors_by_book[book_id] = [author['author_id'] for author in book.get('authors', [])]\n",
    "        descriptions_by_book[book_id] = book.get('description', '')\n",
    "\n",
    "# 准备最终数据集，只包括前十个user的评论\n",
    "final_data = []\n",
    "for user_id in list(user_counter.keys()):\n",
    "    for review in reviews_by_user[user_id]:\n",
    "        book_id = review['book_id']\n",
    "        record = {\n",
    "            'user_id': user_id,\n",
    "            'book_id': book_id,\n",
    "            'review_text': review['review_text'],\n",
    "            'genre': genres_by_book.get(book_id, None),  # 添加题材信息\n",
    "            'description': descriptions_by_book.get(book_id, '')  # 添加书籍描述信息\n",
    "        }\n",
    "        final_data.append(record)\n",
    "\n",
    "userbook2review = {}\n",
    "bookgenre2review = {}\n",
    "for item in final_data:\n",
    "    user_id = item['user_id']\n",
    "    book_id = item['book_id']\n",
    "    genre = item['genre']\n",
    "    userbook2review[user_id + '|' + book_id] = item['review_text']\n",
    "    bookgenre2review[book_id + '|' + genre] = item['description']\n",
    "\n",
    "G = nx.Graph()\n",
    "genres = {'history, historical fiction, biography': 0,\n",
    "          'children': 1,\n",
    "          'romance': 2,\n",
    "          'comics, graphic': 3,\n",
    "          'non-fiction': 4,\n",
    "          'mystery, thriller, crime': 5,\n",
    "          'poetry': 6,\n",
    "          'young-adult': 7,\n",
    "          'fiction': 8,\n",
    "          'fantasy, paranormal': 9,\n",
    "          'None': 10}\n",
    "user_id2idx = {}\n",
    "book_id2idx = {}\n",
    "\n",
    "# 添加节点和边\n",
    "for item in tqdm(final_data):\n",
    "    user_id = item['user_id']\n",
    "    book_id = item['book_id']\n",
    "    genre = item['genre']\n",
    "\n",
    "    if user_id not in user_id2idx:\n",
    "        user_id2idx[user_id] = len(user_id2idx)\n",
    "    if book_id not in book_id2idx:\n",
    "        book_id2idx[book_id] = len(book_id2idx)\n",
    "\n",
    "    user_id = \"user_\" + str(user_id2idx[user_id])\n",
    "    book_id = \"book_\" + str(book_id2idx[book_id])\n",
    "    genre = \"genre_\" + str(genres[genre])\n",
    "\n",
    "    # TODO: You should add nodes texts in final_data in advance and then add them to the graph\n",
    "    G.add_node(user_id, type='user', color='blue', label=user_id, text=f'This is user {user_id}')\n",
    "    G.add_node(book_id, type='book', color='red', label=book_id, text=f'This is book {book_id}')\n",
    "    G.add_node(genre, type='genre', color='yellow', label=genre, text=f'This is genre {genre}')\n",
    "\n",
    "    G.add_edge(user_id, book_id, color='magenta', text=userbook2review[item['user_id'] + '|' + item['book_id']])\n",
    "    G.add_edge(book_id, genre, color='black', text=bookgenre2review[item['book_id'] + '|' + item['genre']])\n",
    "\n",
    "pickle_file_path = \"children_genre/nx_graph.pkl\"\n",
    "with open(pickle_file_path, 'wb') as f:\n",
    "    pickle.dump(G, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "text_graph.py\n",
    "\"\"\"\n",
    "def get_graph_prompt(user, book, G, start, path_length, k_hop, sample_edge_ratio=1):\n",
    "    if start == 'user':\n",
    "        source = user\n",
    "        target = book\n",
    "    elif start == 'book':\n",
    "        source = book\n",
    "        target = user\n",
    "\n",
    "    nodes = []\n",
    "    edges = []\n",
    "\n",
    "    for path in nx.all_simple_paths(G, source=source, target=target, cutoff=path_length):\n",
    "        if len(path) == 2:\n",
    "            continue\n",
    "        # 添加路径中的节点\n",
    "        nodes.extend(path)\n",
    "        # 添加路径中的边\n",
    "        edges.extend([(path[i], path[i + 1], {'text': G.get_edge_data(path[i], path[i + 1])['text']}) for i in\n",
    "                      range(len(path) - 1)])\n",
    "\n",
    "    if len(edges) == 0:\n",
    "        neighbors = list(G.neighbors(source))\n",
    "        paths_graph = G.subgraph(neighbors + [source])\n",
    "        if paths_graph.number_of_edges() <= 1:\n",
    "            return f'[ROOT] {source}.'\n",
    "    else:\n",
    "        paths_graph = nx.Graph()\n",
    "        paths_graph.add_nodes_from(nodes)\n",
    "        num_sampled_edges = math.ceil(len(edges) * sample_edge_ratio)\n",
    "        sampled_edges = random.sample(edges, num_sampled_edges)\n",
    "        paths_graph.add_edges_from(sampled_edges)\n",
    "\n",
    "    data = torch_geometric.utils.from_networkx(paths_graph)\n",
    "    # TODO: add node texts to the graph\n",
    "    node_texts = [train_graph.nodes[node]['text'] for node in paths_graph.nodes()]\n",
    "    node_mapping = dict(zip(paths_graph.nodes(), range(paths_graph.number_of_nodes())))\n",
    "    mapping = dict(zip(node_texts, range(paths_graph.number_of_nodes())))\n",
    "    \n",
    "    print(mapping)\n",
    "\n",
    "    # node df\n",
    "    node_df = pd.DataFrame(list(mapping.items()), columns=['node_attr', 'node_id'])\n",
    "    node_df = node_df[['node_id', 'node_attr']]\n",
    "\n",
    "    # edge df\n",
    "    src = data.edge_index[0].tolist()\n",
    "    dst = data.edge_index[1].tolist()\n",
    "    edge_attr = data.text\n",
    "    edge_df = pd.DataFrame({'src': src, 'edge_attr': edge_attr, 'dst': dst})\n",
    "    \n",
    "    # TODO: use node mapping\n",
    "    start = node_mapping[source]\n",
    "    prompt, _ = hard_prompt(data, node_df, edge_df, start, k_hop)\n",
    "\n",
    "    return prompt\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-17T11:51:26.002795Z",
     "start_time": "2024-05-17T11:51:25.978867900Z"
    }
   },
   "id": "f47a608d58a8c2f3"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def process_edge(supervision_edge, train_graph):\n",
    "    user = 'user_' + str(supervision_edge[0].item())\n",
    "    book = 'book_' + str(supervision_edge[1].item())\n",
    "\n",
    "    print(user, book)\n",
    "\n",
    "    user_id = user.lstrip('user_')\n",
    "    book_id = book.lstrip('book_')\n",
    "    key = user_id + '|' + book_id\n",
    "    value = get_graph_prompt(user, book, train_graph, 'user', 3, 2)\n",
    "\n",
    "    return key, value"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-17T11:43:52.277306300Z",
     "start_time": "2024-05-17T11:43:52.266555100Z"
    }
   },
   "id": "64cd69eef12337ff"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "document_embedding.py\n",
    "\"\"\"\n",
    "import pickle\n",
    "from multiprocessing import Pool\n",
    "\n",
    "import networkx as nx\n",
    "import torch\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric import seed_everything\n",
    "from torch_geometric.utils import degree\n",
    "from tqdm import tqdm\n",
    "\n",
    "from children_genre.goodreads_children_genre import Goodreads_children_genre\n",
    "# from text_graph import get_graph_prompt\n",
    "\n",
    "prefix = \"As an AI language model, we want to predict whether there will be a link between [node_a] and [node b]. We provide the neighborhood information of [node_a] and [node_b] in the following two paragraphs.\"\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "seed_everything(66)\n",
    "Dataset = Goodreads_children_genre(root='.')\n",
    "data = Dataset[0]\n",
    "\n",
    "num_users = data['user'].num_nodes\n",
    "num_books = data['book'].num_nodes\n",
    "num_reviews = data['user', 'review', 'book'].num_edges\n",
    "num_descriptions = data['book', 'description', 'genre'].num_edges\n",
    "\n",
    "data['user', 'review', 'book'].edge_attr = torch.ones(num_reviews, 64)  # TODO\n",
    "data['book', 'description', 'genre'].edge_attr = torch.ones(num_descriptions, 64)  # TODO\n",
    "\n",
    "# select 4-star or 5-star review as positive edge\n",
    "positive_edges_mask = (data['user', 'review', 'book'].edge_label == 5) | (\n",
    "        data['user', 'review', 'book'].edge_label == 4)\n",
    "data['user', 'review', 'book'].edge_index = data['user', 'review', 'book'].edge_index[:, positive_edges_mask]\n",
    "data['user', 'review', 'book'].edge_attr = data['user', 'review', 'book'].edge_attr[positive_edges_mask]\n",
    "\n",
    "# Add a reverse ('book', 'rev_review', 'user') relation for message passing:\n",
    "data = T.ToUndirected()(data)\n",
    "del data['book', 'rev_review', 'user'].edge_label  # Remove \"reverse\" label.\n",
    "del data['user', 'review', 'book'].edge_label  # Remove \"reverse\" label.\n",
    "\n",
    "# Perform a link-level split into training, validation, and test edges:\n",
    "train_data, val_data, test_data = T.RandomLinkSplit(\n",
    "    num_val=0.85,\n",
    "    num_test=0.1,\n",
    "    disjoint_train_ratio=0.3,\n",
    "    neg_sampling_ratio=2.0,\n",
    "    edge_types=[('user', 'review', 'book')],\n",
    "    rev_edge_types=[('book', 'rev_review', 'user')],\n",
    ")(data)\n",
    "\n",
    "# TODO: use val_data as training data\n",
    "data = train_data\n",
    "print(data)\n",
    "# assert len(data['user', 'review', 'book'].edge_label_index[0]) == 18795\n",
    "\n",
    "# message passing edges + positive supervision edges\n",
    "review_message_edge_index = data['user', 'review', 'book'].edge_index\n",
    "edge_mask = data['user', 'review', 'book'].edge_label.long() == 1\n",
    "review_supervision_edge_index = data['user', 'review', 'book'].edge_label_index[:, edge_mask]  # only label == 1\n",
    "review_edge_index = torch.concat((review_message_edge_index, review_supervision_edge_index), dim=1)\n",
    "\n",
    "review_edge_index = review_edge_index.tolist()\n",
    "user_edges = ['user_' + str(idx) for idx in review_edge_index[0]]\n",
    "book_edges = ['book_' + str(idx) for idx in review_edge_index[1]]\n",
    "review_edge_list = list(zip(user_edges, book_edges))\n",
    "\n",
    "description_edge_list = data['book', 'description', 'genre'].edge_index.tolist()\n",
    "book_edges = ['book_' + str(idx) for idx in description_edge_list[0]]\n",
    "genre_edges = ['genre_' + str(idx) for idx in description_edge_list[1]]\n",
    "description_edge_list = list(zip(book_edges, genre_edges))\n",
    "\n",
    "edge_list = review_edge_list + description_edge_list\n",
    "\n",
    "# negative supervision edges\n",
    "negative_edge_mask = data['user', 'review', 'book'].edge_label.long() == 0\n",
    "negative_edge_index = data['user', 'review', 'book'].edge_label_index[:, negative_edge_mask]  # only label == 0\n",
    "\n",
    "with open('children_genre/raw/nx_graph.pkl', 'rb') as f:\n",
    "    G = pickle.load(f)\n",
    "\n",
    "edge_texts = {edge: G[edge[0]][edge[1]]['text'] for edge in edge_list}\n",
    "\n",
    "document = {}\n",
    "\n",
    "# message passing edges + positive supervision edges as train_graph edges\n",
    "train_graph = nx.Graph()\n",
    "for edge, text in tqdm(edge_texts.items()):\n",
    "    train_graph.add_edge(*edge, text=text)\n",
    "    print(G.nodes[edge[0]]['text'])\n",
    "    train_graph.add_node(edge[0], text=G.nodes[edge[0]]['text'])  # TODO: add node text to the graph\n",
    "    train_graph.add_node(edge[1], text=G.nodes[edge[1]]['text'])  # TODO: add node text to the graph\n",
    "\n",
    "# negative edge as train_graph nodes\n",
    "for edge in negative_edge_index.numpy().T:\n",
    "    user_node = 'user_' + str(edge[0])\n",
    "    book_node = 'book_' + str(edge[1])\n",
    "    train_graph.add_node(user_node)\n",
    "    train_graph.add_node(book_node)\n",
    "\n",
    "NP_review_supervision_edge_index = data['user', 'review', 'book'].edge_label_index\n",
    "\n",
    "# \n",
    "# with Pool() as p:\n",
    "#     args = [(edge, train_graph) for edge in NP_review_supervision_edge_index.T]\n",
    "#     results = list(tqdm(p.starmap(process_edge, args), total=len(args)))\n",
    "# \n",
    "# document = dict(results)\n",
    "# \n",
    "# with open('0.15_train_text', 'wb') as f:\n",
    "#     pickle.dump(document, f)\n",
    "\n",
    "print(len(document))\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3241bca6bdf1b7f"
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "user_18476 book_29378\n",
      "{'This is user user_18476': 0, 'This is book book_3768': 1, 'This is genre genre_1': 2, 'This is book book_29378': 3}\n"
     ]
    },
    {
     "data": {
      "text/plain": "('18476|29378',\n \"[ROOT] This is user user_18476.\\nThis is user user_18476 is connected to: \\n1 This is book book_3768 via Molly's Surprise: A Christmas Story is the third installment of her American Girl series. I found this one to be my favorite involving Molly. Its so sweet and cute. Christmas time has come and in America during that time Christmas was rough on families. Molly's family is greatly impacted by the fact that her father is still over with the war effort and will not be able to return home for the holidays. Eventually Molly and her siblings work through their sorrows and pull off a great Christmas for their family. Molly has a surprise up her sleeve as well. It was nice to read such a short, sweet story..\\nThis is book book_3768 is connected to: This is user user_18476 via Molly's Surprise: A Christmas Story is the third installment of her American Girl series. I found this one to be my favorite involving Molly. Its so sweet and cute. Christmas time has come and in America during that time Christmas was rough on families. Molly's family is greatly impacted by the fact that her father is still over with the war effort and will not be able to return home for the holidays. Eventually Molly and her siblings work through their sorrows and pull off a great Christmas for their family. Molly has a surprise up her sleeve as well. It was nice to read such a short, sweet story.;\\nThis is book book_3768 is connected to:\\n  1.1 This is genre genre_1 via Molly is a lively, lovable schemer and dreamer growing up in 1944. Her stories describe her life on the home front during World War Two. Molly doesn't like many of the changes the war has brought, and she especially misses her father, who is away caring for wounded soldiers. But Molly learns the importance of getting along and pulling together -- just as her country has to do to win the war As the McIntires face a cheerless holiday, Molly decides to make some merriment of her own -- complete with unexpected surprises..\")"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "process_edge(NP_review_supervision_edge_index.T[1], train_graph)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-17T11:51:29.485627500Z",
     "start_time": "2024-05-17T11:51:29.440628200Z"
    }
   },
   "id": "3fcac05fc1973d5d"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "data": {
      "text/plain": "<networkx.classes.graph.Graph at 0x2115a636df0>"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-17T11:44:00.153811700Z",
     "start_time": "2024-05-17T11:43:59.865447400Z"
    }
   },
   "id": "2f7c855396f8953d"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
