{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f044edc2-b257-47e1-959b-c10a0f2a46e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import components.data_utils as data_utils\n",
    "import conceptset_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ec09e8bc-a9ca-4821-98e6-b1994493da22",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "CLASS_SIM_CUTOFF: Concepts with cos similarity higher than this to any class will be removed\n",
    "OTHER_SIM_CUTOFF: Concepts with cos similarity higher than this to another concept will be removed\n",
    "MAX_LEN: max number of characters in a concept\n",
    "\n",
    "PRINT_PROB: what percentage of filtered concepts will be printed\n",
    "\"\"\"\n",
    "\n",
    "CLASS_SIM_CUTOFF = 0.85\n",
    "OTHER_SIM_CUTOFF = 0.9\n",
    "MAX_LEN = 30\n",
    "PRINT_PROB = 1\n",
    "\n",
    "# dataset: key into data_utils.LABEL_FILES and part of the concept-set file names\n",
    "dataset = \"stl10\"\n",
    "# device: passed to the similarity filters\n",
    "device = \"cuda:1\"\n",
    "# gpt_model: selects the data/concept_sets/<gpt_model>_init/ directory the initial concept sets are read from\n",
    "gpt_model = \"gpt4\"\n",
    "\n",
    "# Derive the output name from gpt_model so the saved file cannot be mislabelled when gpt_model changes\n",
    "save_name = \"data/concept_sets/{}_filtered_remove_target_{}.txt\".format(dataset, gpt_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8f98f951-32e8-40af-90d9-a06042126810",
   "metadata": {},
   "outputs": [],
   "source": [
    "#EDIT these to use the initial concept sets you want\n",
    "\n",
    "def _read_json(path):\n",
    "    \"\"\"Return the parsed contents of the JSON file at `path`.\"\"\"\n",
    "    with open(path, \"r\") as json_file:\n",
    "        return json.load(json_file)\n",
    "\n",
    "important_dict = _read_json(\"data/concept_sets/{}_init/{}_{}_important.json\".format(gpt_model, gpt_model, dataset))\n",
    "superclass_dict = _read_json(\"data/concept_sets/{}_init/{}_{}_superclass.json\".format(gpt_model, gpt_model, dataset))\n",
    "around_dict = _read_json(\"data/concept_sets/{}_init/{}_{}_around.json\".format(gpt_model, gpt_model, dataset))\n",
    "\n",
    "# The label file is split on newlines, giving one entry per line\n",
    "with open(data_utils.LABEL_FILES[dataset], \"r\") as label_file:\n",
    "    classes = label_file.read().split(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ba3f0ba4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "76763234",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['airplane',\n",
       " 'bird',\n",
       " 'car',\n",
       " 'cat',\n",
       " 'deer',\n",
       " 'dog',\n",
       " 'horse',\n",
       " 'monkey',\n",
       " 'ship',\n",
       " 'truck']"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8da88093-b11f-4273-9599-5c986063e869",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "207\n"
     ]
    }
   ],
   "source": [
    "# Merge every value list of the three initial concept dicts into a single de-duplicated set\n",
    "concepts = set()\n",
    "\n",
    "for concept_dict in (important_dict, superclass_dict, around_dict):\n",
    "    for concept_list in concept_dict.values():\n",
    "        concepts |= set(concept_list)\n",
    "\n",
    "print(len(concepts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b14c878e-d6f8-47fd-9322-85d14dfbdaa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77 doorsList the most important features for recognizing something as a \"piano\":\n",
      "207 206\n"
     ]
    }
   ],
   "source": [
    "concepts = conceptset_utils.remove_too_long(concepts, MAX_LEN, PRINT_PROB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3950a55f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206\n",
      "Class:airplane - Deleting airplane\n",
      "Class:bird - Deleting bird\n",
      "Class:car - Deleting car\n",
      "Class:cat - Deleting cat\n",
      "Class:deer - Deleting deer\n",
      "Class:dog - Deleting dog\n",
      "Class:horse - Deleting horse\n",
      "Class:monkey - Deleting monkey\n",
      "Class:ship - Deleting ship\n",
      "Class:truck - Deleting truck\n",
      "196\n",
      "Class:airplane - Concept:aircraft, sim:0.887 - Deleting aircraft\n",
      "\n",
      "Class:bird - Concept:avian, sim:0.881 - Deleting avian\n",
      "\n",
      "Class:car - Concept:vehicle, sim:0.909 - Deleting vehicle\n",
      "\n",
      "Class:cat - Concept:feline, sim:0.883 - Deleting feline\n",
      "\n",
      "Class:deer - Concept:other deer, sim:0.934 - Deleting other deer\n",
      "\n",
      "Class:dog - Concept:animal, sim:0.855 - Deleting animal\n",
      "\n",
      "Class:dog - Concept:pet, sim:0.871 - Deleting pet\n",
      "\n",
      "Class:monkey - Concept:other monkeys, sim:0.891 - Deleting other monkeys\n",
      "\n",
      "Class:monkey - Concept:primate, sim:0.910 - Deleting primate\n",
      "\n",
      "187\n"
     ]
    }
   ],
   "source": [
    "concepts = conceptset_utils.filter_too_similar_to_cls(concepts, classes, CLASS_SIM_CUTOFF, device, PRINT_PROB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3161dcfb-f9ec-4775-9b80-d8042c4f9f32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " a collar - a collar , sim:1.0000 - Deleting  a collar\n",
      " a driver - a driver , sim:1.0000 - Deleting a driver\n",
      " a litter box - a litter box , sim:1.0000 - Deleting a litter box\n",
      " a road - a road , sim:1.0000 - Deleting a road\n",
      " a runway -  runway , sim:0.9388 - Deleting  runway\n",
      " feathers - feathers , sim:0.9995 - Deleting feathers\n",
      " feathers - tail feathers , sim:0.9061 - Deleting tail feathers\n",
      "a bowl - a food bowl , sim:0.9133 - Deleting a food bowl\n",
      "a branch - branches , sim:0.9068 - Deleting branches\n",
      "a dog bed - a pet bed , sim:0.9549 - Deleting a dog bed\n",
      "a flight attendant - flight attendants , sim:0.9036 - Deleting flight attendants\n",
      "a gas station - fuel station , sim:0.9022 - Deleting a gas station\n",
      "a long tail - long tail , sim:0.9705 - Deleting long tail\n",
      "a white tail - white tail , sim:0.9556 - Deleting white tail\n",
      "a zookeeper - zookeeper , sim:0.9486 - Deleting zookeeper\n",
      "airport - an airport , sim:0.9394 - Deleting airport\n",
      "antlers - antlers (for males) , sim:0.9277 - Deleting antlers (for males)\n",
      "four large wheels - four or more wheels , sim:0.9340 - Deleting four or more wheels\n",
      "four legs - four strong legs , sim:0.9032 - Deleting four strong legs\n",
      "large eyes - large, expressive eyes , sim:0.9126 - Deleting large, expressive eyes\n",
      "long mane - long, flowing mane , sim:0.9466 - Deleting long, flowing mane\n",
      "long, muscular body - long, muscular legs , sim:0.9062 - Deleting long, muscular legs\n",
      "machine - machinery , sim:0.9015 - Deleting machinery\n",
      "pointed ears - pointy ears , sim:0.9053 - Deleting pointed ears\n",
      "round eyes - small, round eyes , sim:0.9379 - Deleting round eyes\n",
      "sails - white sails , sim:0.9001 - Deleting white sails\n",
      "tall mast - tall masts , sim:0.9231 - Deleting tall masts\n",
      "160\n"
     ]
    }
   ],
   "source": [
    "concepts = conceptset_utils.filter_too_similar(concepts, OTHER_SIM_CUTOFF, device, PRINT_PROB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6d7db95f-3391-4609-b3b4-96f234e972b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one concept per line with no trailing newline.\n",
    "# join() also handles an empty concept list, where indexing concepts[0] would raise\n",
    "# IndexError after the output file had already been truncated.\n",
    "with open(save_name, \"w\") as f:\n",
    "    f.write(\"\\n\".join(concepts))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
