{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b345c118-d983-4234-bc85-f11e41c31485",
   "metadata": {},
   "outputs": [],
   "source": [
    "import data_utils\n",
    "import conceptset_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f569c19e-d3b1-428d-a8e5-64e71fd5391c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "ConceptNet params:\n",
    "LIMIT: how many relations to look up; a higher limit gives a larger concept set\n",
    "RELATIONS: which relations to include in search \n",
    "\n",
    "filters:\n",
    "CLASS_SIM_CUTOFF: Concepts with cosine similarity higher than this to any class will be removed\n",
    "OTHER_SIM_CUTOFF: Concepts with cosine similarity higher than this to another concept will be removed\n",
    "MAX_LEN: max number of characters in a concept\n",
    "\n",
    "PRINT_PROB: what percentage of filtered concepts will be printed\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "LIMIT = 200\n",
    "RELATIONS = [\"HasA\", \"IsA\", \"PartOf\", \"HasProperty\", \"MadeOf\"]#, \"AtLocation\"]\n",
    "\n",
    "CLASS_SIM_CUTOFF = 0.85\n",
    "OTHER_SIM_CUTOFF = 0.9 \n",
    "MAX_LEN = 30\n",
    "\n",
    "PRINT_PROB = 0.2\n",
    "\n",
    "dataset = \"imagenet\"\n",
    "save_name = 'data/concept_sets/conceptnet_{}_filtered_new.txt'.format(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3bb1903-e514-48ed-9449-c3f1a5239881",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the class-label file registered for the selected dataset.\n",
    "cls_file = data_utils.LABEL_FILES[dataset]\n",
    "\n",
    "# One class name per line. Blank entries (for example the empty string that a\n",
    "# trailing newline leaves behind after split) are dropped so that no empty\n",
    "# class name reaches the concept-collection and class-similarity steps below.\n",
    "with open(cls_file, 'r') as f:\n",
    "    classes = [name for name in f.read().split('\\n') if name.strip()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "619ea6f9-d084-42c7-a09f-a45cb2359849",
   "metadata": {},
   "source": [
    "# Collect initial concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d0c02f-bf5b-44d9-aa58-5e1c76b1c278",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the initial concept list for the loaded classes. LIMIT and RELATIONS are\n",
    "# the ConceptNet lookup parameters described in the config cell.\n",
    "# NOTE(review): presumably queries ConceptNet over the network - confirm in conceptset_utils.\n",
    "concepts = conceptset_utils.get_init_conceptnet(classes, LIMIT, RELATIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17bbfef-1a85-4d94-9319-c1d2af8b0af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the length filter: MAX_LEN is the maximum number of characters allowed in\n",
    "# a concept, and PRINT_PROB is described in the config cell as the share of\n",
    "# filtered concepts that get printed.\n",
    "# This rebinds `concepts`, so re-running the cell filters the already-filtered list.\n",
    "concepts = conceptset_utils.remove_too_long(concepts, MAX_LEN, PRINT_PROB)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59600ce4-67e6-4dc7-835f-98250048a697",
   "metadata": {},
   "source": [
    "# Filter out concepts too similar to class names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913874cf-2000-4de7-9108-1dd4a7a1b520",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per the config notes, concepts whose cosine similarity to any class name exceeds\n",
    "# CLASS_SIM_CUTOFF are removed; the helper itself lives in conceptset_utils.\n",
    "concepts = conceptset_utils.filter_too_similar_to_cls(concepts, classes, CLASS_SIM_CUTOFF, print_prob=PRINT_PROB)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c398239-e2c8-42a0-aaf7-93810c389a72",
   "metadata": {},
   "source": [
    "# Filter out concepts too similar to each other"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ee717e-066e-4916-a30a-dbb0c7df688f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per the config notes, concepts whose cosine similarity to another concept exceeds\n",
    "# OTHER_SIM_CUTOFF are removed; the helper itself lives in conceptset_utils.\n",
    "concepts = conceptset_utils.filter_too_similar(concepts, OTHER_SIM_CUTOFF, print_prob=PRINT_PROB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "061eca82-5705-4839-a6cd-b9eef0b3d593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one concept per line with no trailing newline. join() also copes with an\n",
    "# empty concept list (producing an empty file), where indexing concepts[0] would\n",
    "# raise IndexError after the output file had already been truncated.\n",
    "with open(save_name, \"w\") as f:\n",
    "    f.write(\"\\n\".join(concepts))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_1_10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
