{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6018f18e-78b8-42e4-8d33-265b5beafde7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import openai\n",
    "\n",
    "import data_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c9685d3-f88b-4aaa-9f0a-e31504910400",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"cifar10\"\n",
    "prompt_type = \"important\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aa4309fd-ee0b-4c84-9706-4c5a3e47654d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the OpenAI API key from ~/.openai_api_key. The context manager closes the\n",
    "# file handle, and strip() removes surrounding whitespace without dropping a real\n",
    "# key character when the file has no trailing newline.\n",
    "with open(os.path.join(os.path.expanduser(\"~\"), \".openai_api_key\"), \"r\") as key_file:\n",
    "    openai.api_key = key_file.read().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a6fd600a-a7f0-4488-9360-09238b1a9305",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot completion templates keyed by prompt type. Each one ends with a {}\n",
    "# slot that is filled with the class name via str.format.\n",
    "prompts = dict(\n",
    "    important=\"List the most important features for recognizing something as a \\\"goldfish\\\":\\n\\n-bright orange color\\n-a small, round body\\n-a long, flowing tail\\n-a small mouth\\n-orange fins\\n\\nList the most important features for recognizing something as a \\\"beerglass\\\":\\n\\n-a tall, cylindrical shape\\n-clear or translucent color\\n-opening at the top\\n-a sturdy base\\n-a handle\\n\\nList the most important features for recognizing something as a \\\"{}\\\":\",\n",
    "    superclass=\"Give superclasses for the word \\\"tench\\\":\\n\\n-fish\\n-vertebrate\\n-animal\\n\\nGive superclasses for the word \\\"beer glass\\\":\\n\\n-glass\\n-container\\n-object\\n\\nGive superclasses for the word \\\"{}\\\":\",\n",
    "    around=\"List the things most commonly seen around a \\\"tench\\\":\\n\\n- a pond\\n-fish\\n-a net\\n-a rod\\n-a reel\\n-a hook\\n-bait\\n\\nList the things most commonly seen around a \\\"beer glass\\\":\\n\\n- beer\\n-a bar\\n-a coaster\\n-a napkin\\n-a straw\\n-a lime\\n-a person\\n\\nList the things most commonly seen around a \\\"{}\\\":\",\n",
    ")\n",
    "\n",
    "# Template selected by the `prompt_type` setting.\n",
    "base_prompt = prompts[prompt_type]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "57ac8873-5b7f-4458-aef0-09278664b05a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the class names for the selected dataset, one label per line.\n",
    "cls_file = data_utils.LABEL_FILES[dataset]\n",
    "with open(cls_file, \"r\") as f:\n",
    "    # Drop empty entries (e.g. the one produced by a trailing newline) so that no\n",
    "    # blank label is sent to the API or written as a key in the output JSON.\n",
    "    classes = [label for label in f.read().split(\"\\n\") if label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "05f19306-ecdd-4d5b-9556-d8b8b7930eac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " 0 airplane\n",
      "\n",
      " 1 automobile\n",
      "\n",
      " 2 bird\n",
      "\n",
      " 3 cat\n",
      "\n",
      " 4 deer\n",
      "\n",
      " 5 dog\n",
      "\n",
      " 6 frog\n",
      "\n",
      " 7 horse\n",
      "\n",
      " 8 ship\n",
      "\n",
      " 9 truck\n"
     ]
    }
   ],
   "source": [
    "# Maps each class label to the sorted list of distinct features returned for it.\n",
    "feature_dict = {}\n",
    "\n",
    "for i, label in enumerate(classes):\n",
    "    # The same set object is registered in feature_dict and filled in below.\n",
    "    label_features = feature_dict[label] = set()\n",
    "    print(\"\\n\", i, label)\n",
    "    # Two sampled completions per class; their features are merged.\n",
    "    for _ in range(2):\n",
    "        response = openai.completions.create(\n",
    "            model=\"text-davinci-002\",\n",
    "            prompt=base_prompt.format(label),\n",
    "            temperature=0.7,\n",
    "            max_tokens=256,\n",
    "            top_p=1,\n",
    "            frequency_penalty=0,\n",
    "            presence_penalty=0,\n",
    "        )\n",
    "        # Split the completion on the newline-dash list markers, remove leftover\n",
    "        # newlines and surrounding whitespace, and discard empty entries.\n",
    "        raw_items = response.choices[0].text.split(\"\\n-\")\n",
    "        features = {item.replace(\"\\n\", \"\").strip() for item in raw_items}\n",
    "        features.discard(\"\")\n",
    "        label_features.update(features)\n",
    "    # Replace the working set with a sorted list so it can be serialized to JSON.\n",
    "    feature_dict[label] = sorted(label_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c033c81e-530b-4b62-967c-722f90636187",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the collected features as indented JSON. The target directory is created\n",
    "# first so the write does not fail with FileNotFoundError when it is missing.\n",
    "json_object = json.dumps(feature_dict, indent=4)\n",
    "out_path = \"data/concept_sets/gpt3_init/gpt3_{}_{}_new.json\".format(dataset, prompt_type)\n",
    "os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "with open(out_path, \"w\") as outfile:\n",
    "    outfile.write(json_object)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_1_10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
