{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f044edc2-b257-47e1-959b-c10a0f2a46e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import data_utils\n",
    "import conceptset_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec09e8bc-a9ca-4821-98e6-b1994493da22",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "CLASS_SIM_CUTOFF: Concenpts with cos similarity higher than this to any class will be removed\n",
    "OTHER_SIM_CUTOFF: Concenpts with cos similarity higher than this to another concept will be removed\n",
    "MAX_LEN: max number of characters in a concept\n",
    "\n",
    "PRINT_PROB: what percentage of filtered concepts will be printed\n",
    "\"\"\"\n",
    "\n",
    "CLASS_SIM_CUTOFF = 0.85\n",
    "OTHER_SIM_CUTOFF = 0.9\n",
    "MAX_LEN = 30\n",
    "PRINT_PROB = 1\n",
    "\n",
    "dataset = \"cifar10\"\n",
    "device = \"cuda\"\n",
    "\n",
    "save_name = \"data/concept_sets/{}_filtered_new.txt\".format(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8f98f951-32e8-40af-90d9-a06042126810",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EDIT the paths below to use the initial concept sets you want.\n",
    "\n",
    "def _read_json(path):\n",
    "    \"\"\"Return the parsed content of the JSON file at `path`.\"\"\"\n",
    "    with open(path, \"r\") as json_file:\n",
    "        return json.load(json_file)\n",
    "\n",
    "\n",
    "# Three GPT-3 generated starting sets for the chosen dataset.\n",
    "important_dict = _read_json(\"data/concept_sets/gpt3_init/gpt3_{}_important.json\".format(dataset))\n",
    "superclass_dict = _read_json(\"data/concept_sets/gpt3_init/gpt3_{}_superclass.json\".format(dataset))\n",
    "around_dict = _read_json(\"data/concept_sets/gpt3_init/gpt3_{}_around.json\".format(dataset))\n",
    "\n",
    "# Class names: the dataset's label file split on newlines.\n",
    "with open(data_utils.LABEL_FILES[dataset], \"r\") as label_file:\n",
    "    classes = label_file.read().split(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8da88093-b11f-4273-9599-5c986063e869",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "170\n"
     ]
    }
   ],
   "source": [
    "# Merge every concept list from the three initial sets into one\n",
    "# de-duplicated pool, then report how many distinct concepts there are.\n",
    "concepts = set()\n",
    "for concept_dict in (important_dict, superclass_dict, around_dict):\n",
    "    for values in concept_dict.values():\n",
    "        concepts.update(values)\n",
    "\n",
    "print(len(concepts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b14c878e-d6f8-47fd-9322-85d14dfbdaa8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "170 170\n"
     ]
    }
   ],
   "source": [
    "# Apply the MAX_LEN character limit to the concept pool.\n",
    "# Rebinds `concepts`, so re-run from the cell that builds the pool to start over.\n",
    "concepts = conceptset_utils.remove_too_long(\n",
    "    concepts,\n",
    "    MAX_LEN,\n",
    "    PRINT_PROB,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d27e7a88-bdd4-49cb-b74d-fd16866eb5e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "170\n",
      "Class:bird - Deleting bird\n",
      "Class:cat - Deleting a cat\n",
      "Class:deer - Deleting a deer\n",
      "167\n",
      "Class:airplane - Concept:a plane, sim:0.876 - Deleting a plane\n",
      "\n",
      "Class:airplane - Concept:aircraft, sim:0.887 - Deleting aircraft\n",
      "\n",
      "Class:automobile - Concept:a car, sim:0.881 - Deleting a car\n",
      "\n",
      "Class:automobile - Concept:car, sim:0.921 - Deleting car\n",
      "\n",
      "Class:automobile - Concept:vehicle, sim:0.904 - Deleting vehicle\n",
      "\n",
      "Class:dog - Concept:animal, sim:0.855 - Deleting animal\n",
      "\n",
      "Class:truck - Concept:a truck driver, sim:0.863 - Deleting a truck driver\n",
      "\n",
      "160\n"
     ]
    }
   ],
   "source": [
    "# Remove concepts whose cos similarity to any class name is above CLASS_SIM_CUTOFF.\n",
    "# Rebinds `concepts` with the filtered result.\n",
    "concepts = conceptset_utils.filter_too_similar_to_cls(\n",
    "    concepts,\n",
    "    classes,\n",
    "    CLASS_SIM_CUTOFF,\n",
    "    device,\n",
    "    PRINT_PROB,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3161dcfb-f9ec-4775-9b80-d8042c4f9f32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a bird house - a birdhouse , sim:0.9706 - Deleting a bird house\n",
      "a bowl of food - a food bowl , sim:0.9462 - Deleting a bowl of food\n",
      "a bowl of water - a water bowl , sim:0.9393 - Deleting a bowl of water\n",
      "a engine - an engine , sim:0.9968 - Deleting a engine\n",
      "a food bowl - a water bowl , sim:0.9172 - Deleting a water bowl\n",
      "a grille - a grille mouth , sim:0.9119 - Deleting a grille mouth\n",
      "a large body - a large, boxy body , sim:0.9017 - Deleting a large, boxy body\n",
      "a large body - a large, muscular body , sim:0.9227 - Deleting a large, muscular body\n",
      "a long tail - a short tail , sim:0.9107 - Deleting a long tail\n",
      "a short tail - a tail , sim:0.9235 - Deleting a short tail\n",
      "ability to fly - the ability to fly , sim:0.9764 - Deleting ability to fly\n",
      "bulging eyes - large, bulging eyes , sim:0.9604 - Deleting bulging eyes\n",
      "four large wheels - four wheels , sim:0.9352 - Deleting four large wheels\n",
      "four legs - two legs , sim:0.9147 - Deleting two legs\n",
      "headlights - two headlights , sim:0.9236 - Deleting headlights\n",
      "long legs - long, thin legs , sim:0.9055 - Deleting long legs\n",
      "two headlight eyes - two headlights , sim:0.9112 - Deleting two headlight eyes\n",
      "two wings - wings , sim:0.9054 - Deleting two wings\n",
      "142\n"
     ]
    }
   ],
   "source": [
    "# Remove concepts whose cos similarity to another concept is above OTHER_SIM_CUTOFF.\n",
    "# Rebinds `concepts` with the filtered result.\n",
    "concepts = conceptset_utils.filter_too_similar(\n",
    "    concepts,\n",
    "    OTHER_SIM_CUTOFF,\n",
    "    device,\n",
    "    PRINT_PROB,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6d7db95f-3391-4609-b3b4-96f234e972b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one concept per line, with no trailing newline after the last one.\n",
    "# str.join produces the same bytes as writing concepts[0] followed by\n",
    "# \"\\n\" + concept for the rest, but also handles an empty concept list:\n",
    "# indexing concepts[0] would raise IndexError after open(..., \"w\") had\n",
    "# already truncated the output file.\n",
    "with open(save_name, \"w\") as f:\n",
    "    f.write(\"\\n\".join(concepts))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:jovyan-lf_cbm]",
   "language": "python",
   "name": "conda-env-jovyan-lf_cbm-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
