{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6018f18e-78b8-42e4-8d33-265b5beafde7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import openai\n",
    "import time\n",
    "\n",
    "import components.data_utils as data_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0c9685d3-f88b-4aaa-9f0a-e31504910400",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which few-shot template to use; must be a key of `prompts` (\"important\", \"superclass\", \"around\").\n",
    "# Also becomes part of the output file name.\n",
    "prompt_type = \"important\"\n",
    "# Dataset key: looked up in data_utils.LABEL_FILES and used in the output file name.\n",
    "dataset = \"cifar10\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aa4309fd-ee0b-4c84-9706-4c5a3e47654d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do not hardcode the key in the notebook. Prefer the OPENAI_API_KEY environment variable,\n",
    "# then fall back to a key file in the home directory.\n",
    "key_path = os.path.join(os.path.expanduser(\"~\"), \".openai_api_key\")\n",
    "api_key = os.environ.get(\"OPENAI_API_KEY\", \"\")\n",
    "if not api_key and os.path.isfile(key_path):\n",
    "    with open(key_path, \"r\") as key_file:\n",
    "        # strip() drops a trailing newline without chopping a real character when there is none\n",
    "        api_key = key_file.read().strip()\n",
    "if not api_key:\n",
    "    # Keep the notebook loadable, but say clearly why later API calls will fail.\n",
    "    print(\"No OpenAI API key found: set OPENAI_API_KEY or create ~/.openai_api_key\")\n",
    "openai.api_key = api_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a6fd600a-a7f0-4488-9360-09238b1a9305",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot prompt templates. Each one shows two worked examples and ends with an open\n",
    "# question whose `{}` placeholder is filled with the class name via str.format().\n",
    "prompts = {\n",
    "    # Visual features that identify the class.\n",
    "    \"important\": (\n",
    "        'List the most important features for recognizing something as a \"goldfish\":\\n\\n'\n",
    "        '-bright orange color\\n-a small, round body\\n-a long, flowing tail\\n-a small mouth\\n-orange fins\\n\\n'\n",
    "        'List the most important features for recognizing something as a \"beerglass\":\\n\\n'\n",
    "        '-a tall, cylindrical shape\\n-clear or translucent color\\n-opening at the top\\n-a sturdy base\\n-a handle\\n\\n'\n",
    "        'List the most important features for recognizing something as a \"{}\":'\n",
    "    ),\n",
    "    # Broader categories the class belongs to.\n",
    "    \"superclass\": (\n",
    "        'Give superclasses for the word \"tench\":\\n\\n'\n",
    "        '-fish\\n-vertebrate\\n-animal\\n\\n'\n",
    "        'Give superclasses for the word \"beer glass\":\\n\\n'\n",
    "        '-glass\\n-container\\n-object\\n\\n'\n",
    "        'Give superclasses for the word \"{}\":'\n",
    "    ),\n",
    "    # Things that commonly appear next to the class.\n",
    "    \"around\": (\n",
    "        'List the things most commonly seen around a \"tench\":\\n\\n'\n",
    "        '- a pond\\n-fish\\n-a net\\n-a rod\\n-a reel\\n-a hook\\n-bait\\n\\n'\n",
    "        'List the things most commonly seen around a \"beer glass\":\\n\\n'\n",
    "        '- beer\\n-a bar\\n-a coaster\\n-a napkin\\n-a straw\\n-a lime\\n-a person\\n\\n'\n",
    "        'List the things most commonly seen around a \"{}\":'\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Template used for every class in this run.\n",
    "base_prompt = prompts[prompt_type]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "19fc160f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'List the most important features for recognizing something as a \"goldfish\":\\n\\n-bright orange color\\n-a small, round body\\n-a long, flowing tail\\n-a small mouth\\n-orange fins\\n\\nList the most important features for recognizing something as a \"beerglass\":\\n\\n-a tall, cylindrical shape\\n-clear or translucent color\\n-opening at the top\\n-a sturdy base\\n-a handle\\n\\nList the most important features for recognizing something as a \"{}\":'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "57ac8873-5b7f-4458-aef0-09278664b05a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the label file registered for the selected dataset.\n",
    "cls_file = data_utils.LABEL_FILES[dataset]\n",
    "with open(cls_file, \"r\") as f:\n",
    "    # One class name per line. Blank lines (e.g. from a trailing newline) are skipped and\n",
    "    # surrounding whitespace / \"\\r\" is stripped, so no empty label ends up in the prompt\n",
    "    # or as a key of the saved concept set.\n",
    "    classes = [line.strip() for line in f if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "25302273",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['airplane',\n",
       " 'automobile',\n",
       " 'bird',\n",
       " 'cat',\n",
       " 'deer',\n",
       " 'dog',\n",
       " 'frog',\n",
       " 'horse',\n",
       " 'ship',\n",
       " 'truck']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "05f19306-ecdd-4d5b-9556-d8b8b7930eac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " 0 airplane\n"
     ]
    }
   ],
   "source": [
    "# NOTE: despite its name, this caps API *requests* per minute (one pause per request), not tokens.\n",
    "MAX_TOKEN_PER_MIN = 40\n",
    "INTERVAL = 60.0 / MAX_TOKEN_PER_MIN  # seconds to pause after each request\n",
    "QUERIES_PER_CLASS = 2  # replies are sampled at temperature 0.7; all replies for a class are merged\n",
    "SYSTEM_PROMPT = \"You are an assistant that provides visual descriptions of objects. Use only adjectives and nouns in your description. Ensure each description is unique, short, and direct. Do not use qualifiers like 'typically', 'generally', or similar words.\"\n",
    "\n",
    "\n",
    "def parse_features(reply):\n",
    "    \"\"\"Turn a bulleted model reply into a set of clean feature strings.\n",
    "\n",
    "    Items are separated by a newline followed by '-'. Newlines left inside an item are\n",
    "    removed, then the bullet dash and surrounding whitespace are stripped, so '-wings'\n",
    "    and '- wings' both become 'wings'. Empty items are dropped.\n",
    "    \"\"\"\n",
    "    cleaned = (chunk.replace(\"\\n\", \"\").strip().lstrip(\"-\").strip() for chunk in reply.split(\"\\n-\"))\n",
    "    return {feat for feat in cleaned if feat}\n",
    "\n",
    "\n",
    "# Maps each class name to a sorted list of its features (the class name included).\n",
    "feature_dict = {}\n",
    "\n",
    "for i, label in enumerate(classes):\n",
    "    print(\"\\n\", i, label)\n",
    "    # The class name itself is always kept as a concept; seeding the set with it avoids a\n",
    "    # duplicate entry when the model also returns the bare class name.\n",
    "    label_features = {label}\n",
    "    for _ in range(QUERIES_PER_CLASS):\n",
    "        response = openai.ChatCompletion.create(\n",
    "            model=\"gpt-4\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "                {\"role\": \"user\", \"content\": base_prompt.format(label)}],\n",
    "            temperature=0.7,\n",
    "            max_tokens=256,\n",
    "            top_p=1,\n",
    "            frequency_penalty=0,\n",
    "            presence_penalty=0\n",
    "        )\n",
    "        label_features.update(parse_features(response[\"choices\"][0][\"message\"][\"content\"]))\n",
    "        # Pause after every request, not once per class, so the request rate stays under the cap.\n",
    "        time.sleep(INTERVAL)\n",
    "    feature_dict[label] = sorted(label_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c033c81e-530b-4b62-967c-722f90636187",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_path = \"data/concept_sets/gpt4_init/gpt4_{}_{}.json\".format(dataset, prompt_type)\n",
    "# Create the output directory on first use so the write cannot fail with FileNotFoundError\n",
    "# after all the API calls have already been made.\n",
    "os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "\n",
    "# Serialize first so a non-serializable value raises before the output file is truncated.\n",
    "json_object = json.dumps(feature_dict, indent=4)\n",
    "with open(out_path, \"w\") as outfile:\n",
    "    outfile.write(json_object)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "78a3dfb0",
   "metadata": {},
   "source": [
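    "### Test GPT-3 & GPT-4\n",
    "\n",
    "Run the `around` prompt for the class `ship` through GPT-4 (chat API) and `text-davinci-002` (completion API) to compare the feature lists each model returns."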
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "822dee2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'- water', 'anchor', 'cargo', 'dock', 'lifebuoy', 'rope', 'sailors'}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "    model=\"gpt-4\",\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": \"You are an assistant that provides visual descriptions of objects. Use only adjectives and nouns in your description. Ensure each description is unique, short, and direct. Do not use qualifiers like 'typically', 'generally', or similar words.\"},\n",
    "        {\"role\": \"user\", \"content\": prompts[\"around\"].format('ship')}],\n",
    "    temperature=0.7,\n",
    "    max_tokens=256,\n",
    "    top_p=1,\n",
    "    frequency_penalty=0,\n",
    "    presence_penalty=0\n",
    ")\n",
    "# Clean up the reply: split on bullet boundaries, drop stray newlines, then strip the bullet\n",
    "# dash and surrounding whitespace. The split alone leaves the bullet attached to the first\n",
    "# item when the reply starts directly with '-' or '- '.\n",
    "reply = response[\"choices\"][0][\"message\"][\"content\"]\n",
    "cleaned = (chunk.replace(\"\\n\", \"\").strip().lstrip(\"-\").strip() for chunk in reply.split(\"\\n-\"))\n",
    "features = {feat for feat in cleaned if feat}\n",
    "features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "c12b5eee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a boat', 'a captain', 'a crew', 'a dock', 'a flag', 'a mast', 'the sea'}"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "response = openai.Completion.create(\n",
    "    model=\"text-davinci-002\",\n",
    "    prompt=prompts[\"around\"].format('ship'),\n",
    "    temperature=0.7,\n",
    "    max_tokens=256,\n",
    "    top_p=1,\n",
    "    frequency_penalty=0,\n",
    "    presence_penalty=0\n",
    ")\n",
    "# Clean up the completion: split on bullet boundaries, drop stray newlines, then strip the\n",
    "# bullet dash and surrounding whitespace so an item written as '-x' or '- x' is stored as 'x'\n",
    "# even when it is the very first thing in the text.\n",
    "completion_text = response[\"choices\"][0][\"text\"]\n",
    "cleaned = (chunk.replace(\"\\n\", \"\").strip().lstrip(\"-\").strip() for chunk in completion_text.split(\"\\n-\"))\n",
    "features = {feat for feat in cleaned if feat}\n",
    "features"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_1_10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
