{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import jsonlines\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import copy\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_map = json.load(open(\"./handled/id_map.json\"))\n",
    "item_dict = json.load(open(\"./handled/item2attributes.json\", \"r\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset():\n",
    "        '''Load train, validation, test dataset'''\n",
    "\n",
    "        usernum = 0\n",
    "        itemnum = 0\n",
    "        User = defaultdict(list)    # default value is a blank list\n",
    "        user_train = {}\n",
    "        user_valid = {}\n",
    "        user_test = {}\n",
    "        # assume user/item index starting from 1\n",
    "        f = open('./handled/inter.txt', 'r')\n",
    "        for line in f:  # use a dict to save all seqeuces of each user\n",
    "            u, i = line.rstrip().split(' ')\n",
    "            u = int(u)\n",
    "            i = int(i)\n",
    "            usernum = max(u, usernum)\n",
    "            itemnum = max(i, itemnum)\n",
    "            User[u].append(i)\n",
    "\n",
    "        for user in tqdm(User):\n",
    "            nfeedback = len(User[user])\n",
    "            #nfeedback = len(User[user])\n",
    "            if nfeedback < 3:\n",
    "                user_train[user] = User[user]\n",
    "                user_valid[user] = []\n",
    "                user_test[user] = []\n",
    "            else:\n",
    "                user_train[user] = User[user][:-2]\n",
    "                user_valid[user] = []\n",
    "                user_valid[user].append(User[user][-2])\n",
    "                user_test[user] = []\n",
    "                user_test[user].append(User[user][-1])\n",
    "        \n",
    "        return user_train\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inter = load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"The user has visited following fashions: \\n<HISTORY> \\nplease conclude the user's perference.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_user_prompt(history):\n",
    "\n",
    "    user_str = copy.deepcopy(prompt_template)\n",
    "    hist_str = \"\"\n",
    "    for item in history:\n",
    "        try:    # some item does not have title\n",
    "            item_str = item_dict[id_map[\"id2item\"][str(item)]][\"title\"]\n",
    "            hist_str = hist_str + item_str + \", \"\n",
    "        except:\n",
    "            continue\n",
    "\n",
    "    # limit the prompt length\n",
    "    if len(hist_str) > 8000:\n",
    "        hist_str = hist_str[-8000:]\n",
    "\n",
    "    user_str = user_str.replace(\"<HISTORY>\", hist_str)\n",
    "\n",
    "    return user_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_data = {}\n",
    "\n",
    "for user, history in tqdm(inter.items()):\n",
    "    user_data[user] = get_user_prompt(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# json.dump(user_data, open(\"./handled/user_str.json\", \"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_data = json.load(open(\"./handled/user_str.json\", \"r\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"\"\n",
    "\n",
    "payload = json.dumps({\n",
    "   \"model\": \"text-embedding-ada-002\",\n",
    "   \"input\": \"The food was delicious and the waiter...\"\n",
    "})\n",
    "headers = {\n",
    "   'Authorization': '',\n",
    "   'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n",
    "   'Content-Type': 'application/json'\n",
    "}\n",
    "\n",
    "response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
    "\n",
    "print(response.text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_response(prompt):\n",
    "    url = \"\"\n",
    "\n",
    "    payload = json.dumps({\n",
    "    \"model\": \"text-embedding-ada-002\",\n",
    "    \"input\": prompt\n",
    "    })\n",
    "    headers = {\n",
    "    'Authorization': '',\n",
    "    'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n",
    "    'Content-Type': 'application/json'\n",
    "    }\n",
    "\n",
    "    response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
    "    re_json = json.loads(response.text)\n",
    "\n",
    "    return re_json[\"data\"][0][\"embedding\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "user_emb = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(\"./handled/user_emb.pkl\"):    # check whether some item emb exist in cache\n",
    "    user_emb = pickle.load(open(\"./handled/user_emb.pkl\", \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(user_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count = 1\n",
    "while 1:    # avoid broken due to internet connection\n",
    "    if len(user_emb) == len(user_data):\n",
    "        break\n",
    "    try:\n",
    "        for key, value in tqdm(user_data.items()):\n",
    "            if key not in user_emb.keys():\n",
    "                if len(value) > 4096:\n",
    "                    value = value[:4095]\n",
    "                user_emb[key] = get_response(value)\n",
    "                count += 1\n",
    "    except:\n",
    "        pickle.dump(user_emb, open(\"./handled/user_emb.pkl\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(user_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_list = []\n",
    "for key, value in tqdm(user_emb.items()):\n",
    "    emb_list.append(value)\n",
    "\n",
    "emb_list = np.array(emb_list)\n",
    "pickle.dump(emb_list, open(\"./handled/usr_emb_np.pkl\", \"wb\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
