{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "faf3a2ad-b356-413e-a15c-48cebdcbed30",
   "metadata": {},
   "source": [
    "# Table 4: Coherence and Seed Set Semantic Similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c470422e-c38d-4447-a042-5a47df79c32c",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Imports and Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "627e7f7a-1ef0-4c6e-9c95-4ea890eadc6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fac2c694-fecb-4a4b-9feb-c18f45f1266a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, string, random\n",
    "import numpy as np\n",
    "\n",
    "# If ../notebooks/ exists relative to the current working directory, switch\n",
    "# into the sibling badseeds directory (presumably where the local modules\n",
    "# imported in the next cell live).\n",
    "launched_beside_notebooks = os.path.isdir(\"../notebooks/\")\n",
    "if launched_beside_notebooks:\n",
    "    os.chdir(\"../badseeds/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "59aca587-5c4e-41f5-bd62-e5ac5660ea74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import gensim.models as gm\n",
    "from tqdm import tqdm\n",
    "\n",
    "import seedbank, tab_a2\n",
    "from utils import generate_seed_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "63a73473-4ead-4ec1-bc7c-8416225d0196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# constants\n",
    "SEED = 42\n",
    "CONFIG_PATH = \"../config.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5113184a-800e-4289-b986-b2537549cdaa",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Loading Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d2755c20-1ff0-463b-a216-32229e5df20c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the JSON config file that stores the paths to the datasets.\n",
    "# Point CONFIG_PATH somewhere else if your config lives elsewhere.\n",
    "with open(CONFIG_PATH, mode=\"r\") as config_file:\n",
    "    config = json.load(config_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe30d1c9-26f1-4b61-adbf-19010527c125",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load our gathered seeds\n",
    "# os.path.join keeps the path valid whether or not dir_path ends with a separator\n",
    "seeds_path = os.path.join(config[\"seeds\"][\"dir_path\"], \"seeds.json\")\n",
    "seeds = seedbank.seedbanking(seeds_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e959702c-54d9-4416-85f3-5bf424c6c6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load embeddings\n",
    "models_dir = os.path.join(\n",
    "    config[\"models\"][\"dir_path\"], config[\"models\"][\"nyt_subpath\"][\"10\"]\n",
    ")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort the file names so that\n",
    "# `models` has the same order on every machine. The seeded random.choices call\n",
    "# further down picks models by position, so it is only reproducible when the\n",
    "# list order is stable.\n",
    "kv_files = sorted(file for file in os.listdir(models_dir) if file.endswith(\".kv\"))\n",
    "if len(kv_files) == 0:\n",
    "    raise ValueError(\"No embeddings found in directory.\")\n",
    "\n",
    "models = [gm.KeyedVectors.load(os.path.join(models_dir, file)) for file in kv_files]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15be8519-36d6-4690-a101-15cbdc9817d8",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Generating Table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f593f25-f34d-4a1e-9667-bf8c86ea651d",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Gathered Seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4dc3be0c-e85c-47f0-a189-1e09334d06d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:32<00:00,  1.65s/model]\n"
     ]
    }
   ],
   "source": [
    "# one coherence result per embedding model for the gathered seeds, using\n",
    "# pairing_method=\"file\" with the pairings in ../seed_set_pairings.csv\n",
    "all_coherence = [\n",
    "    tab_a2.build_row_table_a2(\n",
    "        embedding,\n",
    "        seeds,\n",
    "        pairing_method=\"file\",\n",
    "        pair_path=\"../seed_set_pairings.csv\",\n",
    "    )\n",
    "    for embedding in tqdm(models, unit=\"model\")\n",
    "]\n",
    "\n",
    "# aggregate the per-model results\n",
    "coh_avg_gath = tab_a2.agg_coherence(all_coherence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d440423-a868-470b-9b9d-2e314f7161c0",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Generated Seeds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3a9ffd-ab36-4b23-8818-15aa691c827e",
   "metadata": {},
   "source": [
    "For generated seeds, we first have to generate _nice enough_ seed sets (i.e., seed sets whose words consist only of printable ASCII characters). The random seed is fixed with the `SEED` constant defined in the first section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "825d5bd8-b9f6-44d3-9057-2acd79f3c2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "check = string.printable\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "# generate random seed sets; resample until every character of every word is\n",
    "# printable ASCII (digits and punctuation are allowed, e.g. \"foot-8\")\n",
    "sampled = []\n",
    "for model in random.choices(models, k=50):\n",
    "    while True:\n",
    "        s = generate_seed_set(model)\n",
    "        if all(c in check for w in s for c in w):\n",
    "            sampled.append(s)\n",
    "            break\n",
    "g_seeds = pd.DataFrame(data=pd.Series(sampled), columns=[\"Seeds\"])\n",
    "\n",
    "# check for duplicates in seeds\n",
    "# (`0 in <Series>` tests the index labels, not the values, so the duplicate\n",
    "# mask is computed on the column itself and reduced with .any())\n",
    "if g_seeds[\"Seeds\"].apply(str).duplicated().any():\n",
    "    raise ValueError(\"Duplicate seeds found.\")\n",
    "\n",
    "# uncomment below to visualize seeds\n",
    "# g_seeds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b34ff2-2db0-4929-80c8-6cb1a5f92257",
   "metadata": {},
   "source": [
    "Next, we build the table in the same way as for the gathered seeds, except that `pairing_method=\"all\"` is used instead of a pairing file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "aaa6a087-334d-4b3a-8d13-b52a7000cc92",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [08:17<00:00, 24.86s/model]\n"
     ]
    }
   ],
   "source": [
    "# one coherence result per embedding model for the generated seeds; kept in\n",
    "# its own list so the gathered results in `all_coherence` are not overwritten\n",
    "all_coherence_gen = []\n",
    "for model in tqdm(models, unit=\"model\"):\n",
    "    coh = tab_a2.build_row_table_a2(model, g_seeds, pairing_method=\"all\")\n",
    "    all_coherence_gen.append(coh)\n",
    "\n",
    "# aggregate the per-model results\n",
    "coh_avg_gen = tab_a2.agg_coherence(all_coherence_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f57898c-c67c-44ed-af75-1e18af51b523",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Displaying Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1de2af93-6e17-4e56-9eca-6e2b159460dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    Coherence                 Gathered Set A                 Gathered Set B\n",
      "2       0.999  CAREER: executive, managem...  FAMILY: home, parents, chi...\n",
      "3       0.999  CAREER: executive, managem...  FAMILY: home, parents, chi...\n",
      "15      0.968  MALE: brother, father, unc...  FEMALE: sister, mother, au...\n",
      "32      0.942  TERRORISM: terror, terrori...  OCCUPATIONS: banker, carpe...\n",
      "..        ...                            ...                            ...\n",
      "29      0.100  NAMES HISPANIC: ruiz, alva...  NAMES WHITE: harris, nelso...\n",
      "16      0.093  MALE NAMES: john, paul, mi...  FEMALE NAMES: amy, joan, l...\n",
      "26      0.053  NAMES BLACK: harris, robin...  NAMES WHITE: harris, nelso...\n",
      "21      0.026  NAMES ASIAN: cho, wong, ta...  NAMES CHINESE: chung, liu,...\n",
      "\n",
      "[34 rows x 3 columns]\n",
      "     Coherence                Generated Set A                Generated Set B\n",
      "603      1.000  know, believe, think, gues...  governor, mayor, legislatu...\n",
      "460      1.000  foot-8, foot-7, foot-3, fo...  rousteing, atkins, cornejo...\n",
      "91       0.999  associate, assistant, econ...  heels, shoes, pants, legs,...\n",
      "96       0.999  associate, assistant, econ...  know, believe, think, gues...\n",
      "..         ...                            ...                            ...\n",
      "559      0.073  heels, shoes, pants, legs,...  paws, nausea, wares, insid...\n",
      "567      0.062  hertl, agnieszka, goran, b...  bases, wings, outs, scorel...\n",
      "676      0.059  molina, glasser, pitney, d...  carver, mina, boyce, curat...\n",
      "641      0.053  lime, juice, lemon, potato...  combo, bodysuit, raisin, k...\n",
      "\n",
      "[1046 rows x 3 columns]\n"
     ]
    }
   ],
   "source": [
    "# clean up for display\n",
    "# Store the cleaned tables under new names instead of overwriting coh_avg_gath /\n",
    "# coh_avg_gen, so re-running this cell never feeds an already-cleaned table back\n",
    "# into clean_tab_a2.\n",
    "# NOTE(review): the gathered `seeds` are also passed for the generated table;\n",
    "# confirm clean_tab_a2 does not expect `g_seeds` there.\n",
    "tab_gath = tab_a2.clean_tab_a2(coh_avg_gath, seeds, \"Gathered\")\n",
    "tab_gen = tab_a2.clean_tab_a2(coh_avg_gen, seeds, \"Generated\")\n",
    "\n",
    "with pd.option_context(\"display.max_rows\", 9, \"display.max_colwidth\", 30):\n",
    "    print(tab_gath)\n",
    "    print(tab_gen)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
