{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pickle\n",
    "import random\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "random.seed(1234)\n",
    "import sys\n",
    "import time\n",
    "rootpath=\"./\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from einops import rearrange, repeat\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Create Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the three raw KuaiRec tables from disk.\n",
    "print(\"Loading big matrix...\")\n",
    "big_matrix = pd.read_csv(rootpath + \"KuaiRec 2.0/data/big_matrix.csv\")\n",
    "print(\"Loading small matrix...\")\n",
    "small_matrix = pd.read_csv(rootpath + \"KuaiRec 2.0/data/small_matrix.csv\")\n",
    "print(\"Loading item features...\")\n",
    "item_feat = pd.read_csv(rootpath + \"KuaiRec 2.0/data/item_categories.csv\")\n",
    "# The 'feat' column arrives as text and has to be parsed into Python objects.\n",
    "# ast.literal_eval only accepts literals, so unlike eval() it cannot execute\n",
    "# code that happens to be embedded in the CSV file.\n",
    "item_feat[\"feat\"] = item_feat[\"feat\"].map(ast.literal_eval)\n",
    "print(\"All data loaded.\")\n",
    "sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#user: 7176\n",
      "#item: 10728\n",
      "#cate: 31\n"
     ]
    }
   ],
   "source": [
    "# count number\n",
    "user_count = big_matrix['user_id'].nunique()\n",
    "item_count = item_feat['video_id'].nunique()\n",
    "# Gather the category ids from every position of every item's 'feat' list, not\n",
    "# only the last entry: an id that never appears last would otherwise be missed.\n",
    "all_cate_ids = item_feat['feat'].explode().dropna()\n",
    "# Category ids are used as list indices when the features are encoded\n",
    "# (MultiCate2), so the vector length must be the largest id + 1 rather than\n",
    "# the number of distinct ids.\n",
    "cate_count = int(all_cate_ids.max()) + 1 if len(all_cate_ids) else 0\n",
    "print('#user:',user_count)\n",
    "print('#item:',item_count)\n",
    "print('#cate:',cate_count)\n",
    "sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get need data\n",
    "def getRatingInFo(df,threshold):\n",
    "    \"\"\"Turn raw watch logs into binary feedback.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with 'user_id', 'video_id', 'watch_ratio' and 'timestamp' columns.\n",
    "        threshold: smallest watch_ratio that counts as positive feedback.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with the columns user_id, video_id, reward, timestamp.\n",
    "        reward is 1.0 where watch_ratio >= threshold, 0.0 where it is lower and\n",
    "        NaN where watch_ratio is missing. The input frame is left untouched.\n",
    "    \"\"\"\n",
    "    watch_ratio = df['watch_ratio']\n",
    "    # Work on an explicit copy: assigning into a column slice of `df` triggers\n",
    "    # pandas' SettingWithCopy warning and has version-dependent semantics.\n",
    "    user_video = df[['user_id','video_id','timestamp']].copy()\n",
    "    # .where(notna) keeps rows with a missing watch_ratio as NaN instead of 0.\n",
    "    user_video['reward'] = watch_ratio.ge(threshold).astype(float).where(watch_ratio.notna())\n",
    "    return user_video[['user_id','video_id','reward','timestamp']]\n",
    "# Interactions with watch_ratio >= threshold get reward 1, those below it reward 0.\n",
    "threshold = 0.7\n",
    "# The same labelling is applied to both matrices.\n",
    "big_train = getRatingInFo(big_matrix,threshold)\n",
    "small_test = getRatingInFo(small_matrix,threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start build train dataset\n"
     ]
    }
   ],
   "source": [
    "# build dataset\n",
    "start_time = time.time()\n",
    "print('start build train dataset')\n",
    "sys.stdout.flush()\n",
    "# Distinct user ids that occur in the small matrix, in ascending order.\n",
    "user_in_small = sorted(small_test['user_id'].dropna().unique().tolist())\n",
    "num_user_in_small = len(user_in_small)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "end build train dataset,cost 182.0302233695984\n"
     ]
    }
   ],
   "source": [
    "def _sample_negative(excluded):\n",
    "    \"\"\"Draw random item ids in [0, item_count) until one is not in `excluded`.\"\"\"\n",
    "    while True:\n",
    "        neg = random.randint(0, item_count-1)\n",
    "        if neg not in excluded:\n",
    "            return neg\n",
    "\n",
    "# Items every user touched in the small matrix, gathered in one pass instead of\n",
    "# re-scanning small_test once per user inside the loop below.\n",
    "small_items_by_user = {\n",
    "    user: set(items.tolist())\n",
    "    for user, items in small_test.groupby('user_id')['video_id']\n",
    "}\n",
    "\n",
    "# all_old_hist: user -> positive-item history attached to that user's last training sample.\n",
    "# train_set: tuples (user, history, item, reward, flag, timestamp); flag is 1 for\n",
    "# a logged interaction and 0 for a sampled negative.\n",
    "all_old_hist = {}\n",
    "train_set = []\n",
    "for user,hist in big_train.groupby('user_id'):\n",
    "    pos_list = hist.loc[hist['reward'] > 0, 'video_id'].tolist()\n",
    "    click_list = hist['video_id'].tolist()\n",
    "    time_list = hist['timestamp'].tolist()\n",
    "    rating_list = hist['reward'].tolist()\n",
    "    # A negative may be neither a positive of this user nor one of the user's\n",
    "    # small-matrix items. Sets make each membership test O(1).\n",
    "    excluded = set(pos_list) | small_items_by_user.get(user, set())\n",
    "    # One negative per logged interaction. Sampling no longer starts from\n",
    "    # pos_list[0], so users without any positive interaction do not crash.\n",
    "    neg_list = [_sample_negative(excluded) for _ in range(len(click_list))]\n",
    "\n",
    "    hist_l = []\n",
    "    pos_index = 0\n",
    "    # Reset per user: with a single interaction the loop below never runs, and\n",
    "    # temp_hist would otherwise be undefined or still hold the previous user's history.\n",
    "    temp_hist = []\n",
    "    train_set.append((user, [], click_list[0], rating_list[0], 1,time_list[0]))\n",
    "    train_set.append((user, [], neg_list[0], 0, 0,time_list[0]))\n",
    "    for i in range(1,len(click_list)):\n",
    "        # Only the previous click is inspected, so no prefix copy of click_list is needed.\n",
    "        prev_click = click_list[i-1]\n",
    "        if pos_index< len(pos_list) and prev_click == pos_list[pos_index]:\n",
    "            pos_index += 1\n",
    "            hist_l.append(prev_click)\n",
    "        # Fresh list holding at most the 20 most recent history items.\n",
    "        temp_hist = hist_l[-20:]\n",
    "        train_set.append((user, temp_hist, click_list[i], rating_list[i], 1,time_list[i]))\n",
    "        train_set.append((user, temp_hist, neg_list[i], 0, 0,time_list[i]))\n",
    "    all_old_hist[user]=temp_hist\n",
    "\n",
    "print('end build train dataset,cost {}'.format(time.time()-start_time))\n",
    "sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start build test dataset\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "end build test dataset,cost 3.25437331199646\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "print('start build test dataset')\n",
    "sys.stdout.flush()\n",
    "num_of_valid_set = 200\n",
    "userin_small_id = small_test['user_id'].value_counts().keys().tolist()\n",
    "# Randomly pick the users whose small-matrix data forms the validation set.\n",
    "valid_index = random.sample(userin_small_id,num_of_valid_set)\n",
    "valid_index.sort()\n",
    "\n",
    "# NOTE(review): the items of this one user stand in for the item set of every\n",
    "# small-matrix user - presumably because the small matrix is (almost) fully\n",
    "# observed; confirm before reusing with other data.\n",
    "mask_reference_user = 4681\n",
    "item_in_small = small_matrix.loc[small_matrix['user_id']==mask_reference_user,'video_id'].tolist()\n",
    "# Sets replace the repeated linear scans over the id lists.\n",
    "item_in_small_set = set(item_in_small)\n",
    "small_user_set = set(userin_small_id)\n",
    "# Item ids the reference user has no small-matrix interaction with.\n",
    "mask_in_small = [i for i in range(item_count) if i not in item_in_small_set]\n",
    "# One entry per user id: users absent from the small matrix get every item id,\n",
    "# the others share the mask_in_small list.\n",
    "mask = [\n",
    "    mask_in_small if u in small_user_set else range(item_count)\n",
    "    for u in range(user_count)\n",
    "]\n",
    "\n",
    "\n",
    "valid_users = set(valid_index)\n",
    "valid_set = []\n",
    "test_set = []\n",
    "for user, hist in small_test.groupby('user_id'):\n",
    "    # History carried over from the end of the user's training data.\n",
    "    old_hist = all_old_hist[user]\n",
    "    click_item = hist['video_id'].tolist()\n",
    "    pos_item = hist.loc[hist['reward']==1, 'video_id'].tolist()\n",
    "    # Membership in the sampled set decides the split; this does not depend on\n",
    "    # the iteration order of the groups matching the sorted sample.\n",
    "    target_set = valid_set if user in valid_users else test_set\n",
    "    target_set.append([user,old_hist,click_item,pos_item])\n",
    "print('end build test dataset,cost {}'.format(time.time()-start_time))\n",
    "sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def MultiCate2(l,num_cate):\n",
    "    \"\"\"Encode a list of category ids as a vector of length num_cate.\n",
    "\n",
    "    Every position named in `l` receives the weight 1/len(l); all other\n",
    "    positions stay 0. An empty `l` yields an all-zero vector.\n",
    "    \"\"\"\n",
    "    encoded = [0] * num_cate\n",
    "    if not l:\n",
    "        return encoded\n",
    "    share = 1/len(l)\n",
    "    for cate in l:\n",
    "        encoded[cate] = share\n",
    "    return encoded\n",
    "# Replace every item's category-id list by its fixed-length weight vector.\n",
    "item_feat['feat'] = item_feat[\"feat\"].map(lambda tags: MultiCate2(tags, cate_count))\n",
    "feat_column = item_feat['feat']\n",
    "cate_list = [feat_column[video] for video in item_feat['video_id']]\n",
    "\n",
    "# Number of big-matrix interactions per video.\n",
    "interaction = big_matrix['video_id'].value_counts()\n",
    "# Of little use: according to the original note, downstream code only reads its length.\n",
    "all_item = item_feat.values\n",
    "\n",
    "# The objects are written back to back; a reader has to load them in this order.\n",
    "dataset_parts = (\n",
    "    train_set,\n",
    "    valid_set,\n",
    "    test_set,\n",
    "    cate_list,\n",
    "    (user_count, item_count, cate_count),\n",
    "    interaction,\n",
    "    mask,\n",
    "    all_item,\n",
    ")\n",
    "with open('../data/Kuai_dataset.pkl', 'wb') as f:\n",
    "    for part in dataset_parts:\n",
    "        pickle.dump(part, f, pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "# Short plain-text summary of the dataset sizes.\n",
    "summary_rows = (\n",
    "    (\"num_user: \", user_count),\n",
    "    (\"num_item: \", item_count),\n",
    "    (\"num_train_set: \", len(train_set)),\n",
    "    (\"num_valid_set: \", len(valid_set)),\n",
    "    (\"num_test_set: \", len(test_set)),\n",
    ")\n",
    "with open('../data/Kuai_dataset_info.txt', 'w') as f:\n",
    "    for label, value in summary_rows:\n",
    "        f.write(label + str(value) + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlp_env",
   "language": "python",
   "name": "nlp_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
