{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install openai==1.29.0\n",
    "import time\n",
    "from tokencost import calculate_prompt_cost, calculate_completion_cost\n",
    "from constants import *\n",
    "from utils import *\n",
    "from tqdm.auto import tqdm\n",
    "import random\n",
    "import json\n",
    "# test constant import\n",
    "\n",
    "from openai import AzureOpenAI\n",
    "MODEL = \"gpt-4-turbo-2024-04-09\"\n",
    "REGION = \"eastus2\"\n",
    "API_KEY = \"YOUR_API_KEY\"\n",
    "API_BASE = \"https://api.openai.com\"\n",
    "ENDPOINT = f\"{API_BASE}/{REGION}\"\n",
    "\n",
    "client = AzureOpenAI(\n",
    "    api_key=API_KEY,\n",
    "    api_version=\"2024-02-01\",\n",
    "    azure_endpoint=ENDPOINT,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(client, model, content, messages, role=\"user\"):\n",
    "    messages.append({\"role\": role, \"content\": content})\n",
    "    prompt_cost = calculate_prompt_cost(messages, \"gpt-4-turbo\")\n",
    "    completion = client.chat.completions.create(\n",
    "        model=model,\n",
    "        max_tokens=30,\n",
    "        messages=messages\n",
    "    )\n",
    "    chat_response = completion\n",
    "    answer = chat_response.choices[0].message.content\n",
    "    completion_cost = calculate_completion_cost(answer, \"gpt-4-turbo\")\n",
    "    print(f'ChatGPT: {answer}')\n",
    "    messages.append({\"role\": \"assistant\", \"content\": answer})\n",
    "    return answer, messages, prompt_cost + completion_cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"data/\"\n",
    "train_file = f\"{data_dir}train.json\"\n",
    "\n",
    "total_cost = 0\n",
    "false = 0\n",
    "\n",
    "\n",
    "with open(train_file, \"r\") as f:\n",
    "    lines = f.readlines()\n",
    "    test_files = [json.loads(line) for line in lines]\n",
    "    # shuffle\n",
    "    random.shuffle(test_files)\n",
    "\n",
    "    save_dict = {}\n",
    "    scene_list = {}\n",
    "\n",
    "    for n, test_file in enumerate(tqdm(test_files[:200])):\n",
    "\n",
    "        with open(test_file[\"text\"], 'r') as f:\n",
    "            text_log = f.readlines()\n",
    "        \n",
    "        scene = test_file[\"scene\"]\n",
    "\n",
    "        with open(test_files[n+1][\"text\"], 'r') as f:\n",
    "            text_log_1 = f.readlines()\n",
    "        preference_1 = test_files[n+1][\"preference\"]\n",
    "        \n",
    "        with open(test_files[n+2][\"text\"], 'r') as f:\n",
    "            text_log_2 = f.readlines()\n",
    "        preference_2 = test_files[n+2][\"preference\"]\n",
    "            \n",
    "        text_log_3, preference_3 = get_same_demo_text(test_file[\"preference\"], test_files)\n",
    "        # text_log_3, preference_3 = text_log, test_file[\"preference\"]\n",
    "        \n",
    "        # shuffle the text logs and corresponding preferences\n",
    "        text_logs = [[text_log_1, preference_1], [text_log_2, preference_2], [text_log_3, preference_3]]\n",
    "        random.shuffle(text_logs)\n",
    "        text_log_1, preference_1 = text_logs[0]\n",
    "        text_log_2, preference_2 = text_logs[1]\n",
    "        text_log_3, preference_3 = text_logs[2]\n",
    "        \n",
    "        messages = []\n",
    "\n",
    "        instructions = \"You are a robot assistant that can help summarize the host's preference.\"\n",
    "        possible_preferences = Rearrangement[0]['Level0'] + Rearrangement[2]['Level2']\n",
    "        # possible_preferences = Sequence_Preferences['name']\n",
    "\n",
    "        instructions += f\"Choose from following preference: \\n{parse_concat(possible_preferences, replace=', ')}.\\n\"\n",
    "        instructions += f\"Text log file: \\n {parse_concat(text_log_1, replace=', ')}.\\n\"\n",
    "        instructions += f\"Preference: {preference_1}.\\n\"\n",
    "        instructions += f\"Text log file: \\n {parse_concat(text_log_2, replace=', ')}.\\n\"\n",
    "        instructions += f\"Preference: {preference_2}.\\n\"\n",
    "        instructions += f\"Text log file: \\n {parse_concat(text_log_3, replace=', ')}.\\n\"\n",
    "        instructions += f\"Preference: {preference_3}.\\n\"\n",
    "        instructions += f\"Text log file: \\n {parse_concat(text_log, replace=', ')}.\\n\"\n",
    "        instructions += f\"The user's preference is \"\n",
    "\n",
    "        # print(instructions)\n",
    "        while(True):\n",
    "            try:\n",
    "                answer, messages, cost = chat(client, MODEL, instructions, messages, role=\"assistant\")\n",
    "                break\n",
    "            except:\n",
    "                time.sleep(1)\n",
    "\n",
    "        total_cost += cost\n",
    "\n",
    "        gt = test_file[\"preference\"]\n",
    "\n",
    "        answer = answer.lower().split(\",\")[0]\n",
    "\n",
    "        if scene not in scene_list:\n",
    "            scene_list[scene] = []\n",
    "\n",
    "        if not compare(answer, gt, in_sequence=True):\n",
    "            false += 1\n",
    "            scene_list[scene].append(-2)\n",
    "        else:\n",
    "            scene_list[scene].append(2)\n",
    "\n",
    "            \n",
    "        print(f\"True: {n+1-false}/{n+1}, total cost: {total_cost}\")\n",
    "        \n",
    "    print(f\"True: {n+1-false}/{len(test_files)}\")\n",
    "\n",
    "    for scene in scene_list:\n",
    "        print(f\"{scene_list[scene]}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for scene in scene_list:\n",
    "    print(str(scene_list[scene]).replace(\", \", \" \")[1:-1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
