{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, Image, Audio\n",
    "import json\n",
    "import cv2  # We're using OpenCV to read video, to install !pip install opencv-python\n",
    "import base64\n",
    "import time\n",
    "import random\n",
    "from openai import OpenAI\n",
    "import os\n",
    "import requests\n",
    "from utils import chat_vision\n",
    "from constants import *\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from openai import AzureOpenAI\n",
    "from utils import *\n",
    "\n",
    "\n",
    "MODEL = \"gpt-4-turbo-2024-04-09\"\n",
    "REGION = \"eastus2\"\n",
    "# Read credentials from the environment instead of hardcoding them in the notebook.\n",
    "API_KEY = os.environ.get(\"AZURE_OPENAI_API_KEY\", \"YOUR_API_KEY\")\n",
    "API_BASE = \"https://api.openai.com\"\n",
    "# NOTE(review): the fallback below is not an Azure resource URL (those normally look like\n",
    "# https://<resource>.openai.azure.com/); set AZURE_OPENAI_ENDPOINT to the real endpoint.\n",
    "ENDPOINT = os.environ.get(\"AZURE_OPENAI_ENDPOINT\", f\"{API_BASE}/{REGION}\")\n",
    "\n",
    "client = AzureOpenAI(\n",
    "    api_key=API_KEY,\n",
    "    api_version=\"2024-02-01\",\n",
    "    azure_endpoint=ENDPOINT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tokencost import count_message_tokens, count_string_tokens\n",
    "\n",
    "# Quick sanity check of token counting on a one-message chat prompt.\n",
    "message_prompt = [\n",
    "    {\"role\": \"user\", \"content\": \"Hello world\"},\n",
    "]\n",
    "print(count_message_tokens(message_prompt, model=\"gpt-4-turbo\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_video(path, num_frames=8, frames_per_image=4):\n",
    "    \"\"\"Read a video and return base64-encoded JPEG mosaics of sampled frames.\n",
    "\n",
    "    All frames are decoded, subsampled with a fixed stride so that at most\n",
    "    ``num_frames`` remain, then stacked vertically (``cv2.vconcat``) in groups\n",
    "    of ``frames_per_image`` so several frames fit into one image.\n",
    "\n",
    "    Args:\n",
    "        path: Video path passed to ``cv2.VideoCapture``.\n",
    "        num_frames: Maximum number of frames kept after subsampling (>= 1).\n",
    "        frames_per_image: Number of frames stacked into each output image (>= 1).\n",
    "\n",
    "    Returns:\n",
    "        List of base64-encoded JPEG strings, one per complete group of\n",
    "        ``frames_per_image`` frames. Leftover frames that do not fill a group\n",
    "        are dropped, so the list is empty if the video cannot be opened or has\n",
    "        fewer than ``frames_per_image`` frames.\n",
    "    \"\"\"\n",
    "    video = cv2.VideoCapture(path)\n",
    "    frames = []\n",
    "    try:\n",
    "        while video.isOpened():\n",
    "            success, frame = video.read()\n",
    "            if not success:\n",
    "                break\n",
    "            frames.append(frame)\n",
    "    finally:\n",
    "        video.release()\n",
    "\n",
    "    # A bare `frames[::len(frames) // num_frames]` keeps anywhere from num_frames\n",
    "    # up to 2 * num_frames - 1 frames, so truncate after striding.\n",
    "    if len(frames) > num_frames:\n",
    "        step = len(frames) // num_frames\n",
    "        frames = frames[::step][:num_frames]\n",
    "\n",
    "    # Stack frames so fewer images are sent per request (GPT-4V limits the\n",
    "    # number of images; NOTE(review): confirm the current limit). The `+ 1`\n",
    "    # makes the last complete group count, e.g. 8 frames -> 2 images.\n",
    "    base64_frames = []\n",
    "    for start in range(0, len(frames) - frames_per_image + 1, frames_per_image):\n",
    "        mosaic = cv2.vconcat(frames[start:start + frames_per_image])\n",
    "        _, buffer = cv2.imencode(\".jpg\", mosaic)\n",
    "        base64_frames.append(base64.b64encode(buffer).decode(\"utf-8\"))\n",
    "    return base64_frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import chat_vision\n",
    "\n",
    "# --- Evaluation configuration ---\n",
    "data_dir = \"/dataset\"\n",
    "# Join with a separator; plain concatenation would yield \"/datasettrain.json\".\n",
    "train_file = os.path.join(data_dir, \"train.json\")\n",
    "NUM_EVAL = 100  # number of shuffled samples to evaluate\n",
    "MAX_RETRIES = 5  # attempts per API call before re-raising the last error\n",
    "IMAGES_PER_VIDEO = 2  # stacked frame images sent per video\n",
    "SEED = 0  # seeds `random` so the shuffle and demo order are reproducible\n",
    "\n",
    "random.seed(SEED)\n",
    "\n",
    "total_cost = 0\n",
    "num_incorrect = 0\n",
    "num_evaluated = 0\n",
    "\n",
    "# train.json holds one JSON object per line; blank lines are skipped.\n",
    "with open(train_file, \"r\") as f:\n",
    "    test_files = [json.loads(line) for line in f if line.strip()]\n",
    "random.shuffle(test_files)\n",
    "\n",
    "# The label set and instructions do not depend on the sample, so build them once.\n",
    "possible_preferences = Sequence_Preferences['name']\n",
    "# NOTE(review): no text log is actually sent to the model, only video frames.\n",
    "instructions = \"You are a robot assistant that can help summarize the host's preference. Please read the text log file and summarize the user's preference.\"\n",
    "instructions += f\" Choose from the following preferences: \\n{parse_concat(possible_preferences, replace=', ')}.\\n\"\n",
    "\n",
    "\n",
    "def _image_parts(frames):\n",
    "    \"\"\"Wrap the first IMAGES_PER_VIDEO base64 frames as image content parts.\"\"\"\n",
    "    return [{\"image\": x, \"resize\": 512} for x in frames[:IMAGES_PER_VIDEO]]\n",
    "\n",
    "\n",
    "def _chat_with_retry(prompt_messages):\n",
    "    \"\"\"Call chat_vision, retrying with exponential backoff up to MAX_RETRIES times.\n",
    "\n",
    "    Returns chat_vision's (answer, messages, cost) tuple; re-raises the last\n",
    "    exception once all attempts fail instead of looping forever.\n",
    "    \"\"\"\n",
    "    for attempt in range(MAX_RETRIES):\n",
    "        try:\n",
    "            return chat_vision(client, MODEL, None, prompt_messages, role=\"assistant\")\n",
    "        except Exception as e:\n",
    "            if attempt == MAX_RETRIES - 1:\n",
    "                raise\n",
    "            print(f\"API error (attempt {attempt + 1}/{MAX_RETRIES}): {e}\")\n",
    "            time.sleep(2 ** attempt)\n",
    "\n",
    "\n",
    "num_samples = len(test_files)\n",
    "for n, test_file in enumerate(tqdm(test_files[:NUM_EVAL])):\n",
    "    # NOTE(review): text_log is read but never used in the prompt.\n",
    "    with open(test_file[\"text\"], \"r\") as log_file:\n",
    "        text_log = log_file.readlines()\n",
    "\n",
    "    base64Frames = read_video(test_file[\"camera\"])\n",
    "    gt = test_file[\"preference\"]\n",
    "\n",
    "    # Two demos are the next samples in the shuffled list (wrapping around so the\n",
    "    # last queries do not raise IndexError); the third presumably shares the\n",
    "    # query's preference. NOTE(review): with tiny datasets a demo may be the query.\n",
    "    neighbor_1 = test_files[(n + 1) % num_samples]\n",
    "    neighbor_2 = test_files[(n + 2) % num_samples]\n",
    "    video_path_3, preference_3 = get_same_demo_video(gt, test_files)\n",
    "    demos = [\n",
    "        [neighbor_1[\"camera\"], neighbor_1[\"preference\"]],\n",
    "        [neighbor_2[\"camera\"], neighbor_2[\"preference\"]],\n",
    "        [video_path_3, preference_3],\n",
    "    ]\n",
    "    # Shuffle so the same-preference demo is not always in the same slot.\n",
    "    random.shuffle(demos)\n",
    "\n",
    "    # Few-shot examples: images of each demo followed by its label, then the query.\n",
    "    demo_parts = []\n",
    "    for video_path, demo_preference in demos:\n",
    "        demo_parts.extend(_image_parts(read_video(video_path)))\n",
    "        demo_parts.append(f\"the preference is {demo_preference}\")\n",
    "\n",
    "    PROMPT_MESSAGES = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                instructions,\n",
    "                \"Question: What's the user's preference? Choose from the preference listed before:\",\n",
    "            ],\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [*demo_parts, *_image_parts(base64Frames), \"the preference is?\"],\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    answer, _, cost = _chat_with_retry(PROMPT_MESSAGES)\n",
    "    total_cost += cost  # assumes chat_vision returns a numeric cost; TODO confirm\n",
    "    num_evaluated += 1\n",
    "\n",
    "    answer = (answer or \"\").lower()\n",
    "    # Lenient match: every space-separated word of the label must appear as a\n",
    "    # substring of the answer.\n",
    "    if not all(keyword.lower() in answer for keyword in gt.split(\" \")):\n",
    "        num_incorrect += 1\n",
    "        print(f\"False: {answer} vs {gt}\")\n",
    "\n",
    "    print(f\"True: {num_evaluated - num_incorrect}/{num_evaluated}, total cost: {total_cost}\")\n",
    "\n",
    "# Denominator is the number of samples actually evaluated, not len(test_files).\n",
    "print(f\"True: {num_evaluated - num_incorrect}/{num_evaluated}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
