{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfb8e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353d1d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the visible GPUs and their current memory use, to help choose `device` below.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e273376",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models\n",
    "\n",
    "# GPU to load the model onto; the index is machine-specific, so pick a free\n",
    "# device (see `nvidia-smi` above) before running.\n",
    "device = \"cuda:5\"\n",
    "# Later cells use `mt.model` and `mt.tokenizer`.\n",
    "mt = models.load_model(\"gptj\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54ee92d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def comma_sep_lines_to_pairs(string):\n",
    "    \"\"\"Parse newline-separated `key, value` rows into `[key, value]` lists.\n",
    "\n",
    "    Only the first comma on a row separates the fields, so a value containing\n",
    "    commas stays whole; blank rows are skipped and both fields are stripped.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        [field.strip() for field in line.split(\",\", 1)]\n",
    "        for line in string.split(\"\\n\")\n",
    "        if line.strip()\n",
    "    ]\n",
    "\n",
    "\n",
    "# Rows: country, capital city.\n",
    "CAPITOLS = comma_sep_lines_to_pairs(\n",
    "    \"\"\"\\\n",
    "United States, Washington D.C.\n",
    "Canada, Ottawa\n",
    "Mexico, Mexico City\n",
    "Brazil, Brasília\n",
    "Argentina, Buenos Aires\n",
    "Chile, Santiago\n",
    "Peru, Lima\n",
    "Colombia, Bogotá\n",
    "Venezuela, Caracas\n",
    "Spain, Madrid\n",
    "France, Paris\n",
    "Germany, Berlin\n",
    "Italy, Rome\n",
    "Russia, Moscow\n",
    "China, Beijing\n",
    "Japan, Tokyo\n",
    "South Korea, Seoul\n",
    "India, New Delhi\n",
    "Pakistan, Islamabad\n",
    "Nigeria, Abuja\n",
    "Egypt, Cairo\n",
    "Saudi Arabia, Riyadh\n",
    "Turkey, Ankara\n",
    "Australia, Canberra\"\"\")\n",
    "\n",
    "# Rows: country, language spoken there.\n",
    "LANGUAGES = comma_sep_lines_to_pairs(\"\"\"\\\n",
    "United States, English\n",
    "Canada, English and French\n",
    "Mexico, Spanish\n",
    "Brazil, Portuguese\n",
    "Argentina, Spanish\n",
    "Chile, Spanish\n",
    "Peru, Spanish\n",
    "Colombia, Spanish\n",
    "Venezuela, Spanish\n",
    "Spain, Spanish\n",
    "France, French\n",
    "Germany, German\n",
    "Italy, Italian\n",
    "Russia, Russian\n",
    "China, Mandarin Chinese\n",
    "Japan, Japanese\n",
    "South Korea, Korean\n",
    "India, Hindi\n",
    "Pakistan, Urdu\n",
    "Nigeria, English\n",
    "Egypt, Arabic\n",
    "Saudi Arabia, Arabic\n",
    "Turkey, Turkish\n",
    "Australia, English\"\"\")\n",
    "\n",
    "# Rows: country, a neighboring country that lies to its north.\n",
    "BORDER_NORTH = comma_sep_lines_to_pairs(\"\"\"\\\n",
    "United States, Canada\n",
    "Mexico, United States\n",
    "Brazil, Venezuela\n",
    "Argentina, Bolivia\n",
    "Chile, Peru\n",
    "Peru, Ecuador\n",
    "Colombia, Venezuela\n",
    "Spain, France\n",
    "France, Germany\n",
    "Germany, Denmark\n",
    "Italy, Switzerland\n",
    "China, Russia\n",
    "South Korea, North Korea\n",
    "India, China\n",
    "Pakistan, Afghanistan\n",
    "South Africa, Namibia\n",
    "Saudi Arabia, Iraq\n",
    "Turkey, Bulgaria\"\"\")\n",
    "\n",
    "# Rows: country, a neighboring country that lies to its south.\n",
    "BORDER_SOUTH = comma_sep_lines_to_pairs(\"\"\"\\\n",
    "United States, Mexico\n",
    "Canada, United States\n",
    "Mexico, Guatemala\n",
    "Brazil, Bolivia\n",
    "Peru, Chile\n",
    "Colombia, Ecuador\n",
    "Venezuela, Brazil\n",
    "France, Spain\n",
    "Germany, Switzerland\n",
    "Russia, Georgia\n",
    "Egypt, Sudan\n",
    "Saudi Arabia, Yemen\n",
    "Turkey, Syria\"\"\")\n",
    "\n",
    "BORDER_NORTH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b729ee55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def line_sep_prompts(string):\n",
    "    \"\"\"Split newline-separated prompt templates into a list of strings.\n",
    "\n",
    "    Each non-blank line is stripped and its `[name]` placeholders are\n",
    "    rewritten as `{name}`.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        line.strip().replace(\"[\", \"{\").replace(\"]\", \"}\")\n",
    "        for line in string.split(\"\\n\")\n",
    "        if line.strip()\n",
    "    ]\n",
    "\n",
    "\n",
    "# In every template [country] must come before the object placeholder.\n",
    "LANGUAGE_PROMPTS = line_sep_prompts(\"\"\"\\\n",
    "[country] is a country where the language of [language] is spoken.\n",
    "The people of [country] communicate in [language].\n",
    "[country] is home to speakers of [language].\n",
    "The people in [country] converse using the language of [language].\n",
    "The inhabitants of [country] use [language] to communicate.\n",
    "In [country], the language primarily spoken is [language].\"\"\")\n",
    "\n",
    "CAPITOL_PROMPTS = line_sep_prompts(\"\"\"\\\n",
    "The capital city of [country] is [city].\n",
    "[country] is home to the capital city of [city].\n",
    "The political capital of [country] is [city].\n",
    "The seat of government for [country] is [city].\n",
    "The government of [country] is centered in [city].\"\"\")\n",
    "\n",
    "# [other] lies to the north of [country], matching BORDER_NORTH rows.\n",
    "NORTH_PROMPTS = line_sep_prompts(\"\"\"\\\n",
    "The northern frontier of [country] meets that of [other].\n",
    "[country] lies to the south of [other].\n",
    "The northerly boundary of [country] is shared with [other].\n",
    "[country]'s northern flank abuts [other].\n",
    "[country]'s northernmost point touches [other].\n",
    "The northernmost part of [country] adjoins [other].\n",
    "To the north, [country] is contiguous with [other].\n",
    "The northern edge of [country] meets [other].\n",
    "[country]'s northern line of demarcation is with [other].\n",
    "The northern boundary of [country] is contiguous with [other].\"\"\")\n",
    "\n",
    "# [other] lies to the south of [country], matching BORDER_SOUTH rows.\n",
    "SOUTH_PROMPTS = line_sep_prompts(\"\"\"\\\n",
    "The southern frontier of [country] meets that of [other].\n",
    "[country]'s southern border abuts [other].\n",
    "[country] lies to the north of [other].\n",
    "[country]'s southern flank meets [other].\n",
    "The southernmost point of [country] borders [other].\n",
    "The southern edge of [country] meets [other].\n",
    "The southern line of demarcation of [country] is shared with [other].\n",
    "[country]'s southern boundary is contiguous with [other].\"\"\")\n",
    "\n",
    "NORTH_PROMPTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42e0d728",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import estimate\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def prompt_prefix(template):\n",
    "    \"\"\"Reduce a relation template to the subject-only prefix passed to the estimator.\n",
    "\n",
    "    The template is cut just before its object placeholder ({city}, {language}\n",
    "    or {other}), the {country} placeholder becomes a bare {} slot, and trailing\n",
    "    spaces and periods are stripped.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        template\n",
    "            .split(\"{city}\")[0]\n",
    "            .split(\"{language}\")[0]\n",
    "            .split(\"{other}\")[0]\n",
    "            .replace(\"{country}\", \"{}\")\n",
    "            .rstrip(\". \")\n",
    "    )\n",
    "\n",
    "\n",
    "ops_capitols = {}\n",
    "ops_languages = {}\n",
    "ops_north = {}\n",
    "\n",
    "for ops, prompts, samples in (\n",
    "    (ops_capitols, CAPITOL_PROMPTS, CAPITOLS),\n",
    "    (ops_languages, LANGUAGE_PROMPTS, LANGUAGES),\n",
    "    (ops_north, NORTH_PROMPTS, BORDER_NORTH),\n",
    "):\n",
    "    for template in prompts:\n",
    "        # The prefix depends only on the template, so build it once here rather\n",
    "        # than rebinding the loop variable inside the per-subject loop.\n",
    "        prompt = prompt_prefix(template)\n",
    "        for subject, _ in tqdm(samples, desc=template):\n",
    "            operator = estimate.relation_operator_from_sample(\n",
    "                mt.model,\n",
    "                mt.tokenizer,\n",
    "                subject,\n",
    "                prompt,\n",
    "                device=device,\n",
    "            )\n",
    "            # Keyed by (prefix, subject); later cells unpack each value as a 2-tuple.\n",
    "            ops[prompt, subject] = operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeaa4c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "CATEGORIES = {\n",
    "    \"language\": ops_languages,\n",
    "    \"capitol\": ops_capitols,\n",
    "    \"north\": ops_north,\n",
    "}\n",
    "\n",
    "# dists[c1][c2] holds (p1, p2, subject, distance) for every pair of operators from\n",
    "# categories c1 and c2 that share a subject but come from different prompts;\n",
    "# distance is `.norm()` of the difference of their `.weight`s.\n",
    "dists = defaultdict(lambda: defaultdict(list))\n",
    "for c1, ops1 in tqdm(CATEGORIES.items()):\n",
    "    for c2, ops2 in CATEGORIES.items():\n",
    "        # Group ops2 by subject so each ops1 entry only visits same-subject\n",
    "        # candidates (in the same order as a full scan) instead of all of ops2.\n",
    "        ops2_by_subject = defaultdict(list)\n",
    "        for (p2, s2), (o2, _) in ops2.items():\n",
    "            ops2_by_subject[s2].append((p2, o2))\n",
    "        for (p1, s1), (o1, _) in ops1.items():\n",
    "            for p2, o2 in ops2_by_subject[s1]:\n",
    "                if p1 == p2:\n",
    "                    continue\n",
    "                dist = o1.weight.sub(o2.weight).norm().item()\n",
    "                dists[c1][c2].append((p1, p2, s1, dist))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd1419e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "# Mean distance between same-subject operators, per (row, column) category pair.\n",
    "mean_dist = np.array([\n",
    "    [np.mean([entry[-1] for entry in dists[c1][c2]]) for c2 in CATEGORIES]\n",
    "    for c1 in CATEGORIES\n",
    "])\n",
    "\n",
    "ax = sns.heatmap(\n",
    "    data=mean_dist,\n",
    "    xticklabels=list(CATEGORIES),\n",
    "    yticklabels=list(CATEGORIES),\n",
    "    vmin=0,\n",
    "    # nanmax: a category pair with no shared subjects averages to NaN, which\n",
    "    # would otherwise make the colour-scale maximum NaN.\n",
    "    vmax=np.nanmax(mean_dist),\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    ")\n",
    "ax.set_title(\"Mean Operator Distance\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e79f73bd",
   "metadata": {},
   "source": [
    "The heatmap above shows average distances between operators. Next, let's measure nearest-neighbor classification accuracy instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7edae0a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every operator (c1, p1, s1), record (c2, p2, s2, squared distance) to each\n",
    "# operator built from a different prompt, in any category.\n",
    "# NOTE(review): entries are grouped by prompt p1 only, so every subject of a\n",
    "# prompt shares one list and the next cell casts one vote per prompt; confirm\n",
    "# this is intended rather than grouping by (p1, s1).\n",
    "# NOTE(review): this rebinds `dists`; re-running the mean-distance heatmap\n",
    "# afterwards would read this structure instead of the per-category one.\n",
    "dists = defaultdict(lambda: defaultdict(list))\n",
    "for c1, ops1 in tqdm(CATEGORIES.items()):\n",
    "    for c2, ops2 in CATEGORIES.items():\n",
    "        for (p1, _s1), (o1, _m1) in ops1.items():\n",
    "            for (p2, s2), (o2, _m2) in ops2.items():\n",
    "                if p1 != p2:\n",
    "                    sq_dist = o1.weight.sub(o2.weight).pow(2).sum().item()\n",
    "                    dists[c1][p1].append((c2, p2, s2, sq_dist))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95f79d44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each (category, prompt) group votes for the category of its nearest entry\n",
    "# (smallest squared distance; ties go to the earliest entry).\n",
    "scores = defaultdict(lambda: defaultdict(int))\n",
    "for c1, groups in dists.items():\n",
    "    for candidates in groups.values():\n",
    "        nearest = min(candidates, key=lambda entry: entry[-1])\n",
    "        scores[c1][nearest[0]] += 1\n",
    "\n",
    "# Normalise each row of votes into fractions that sum to 1.\n",
    "accuracies = {}\n",
    "for c1, counts in scores.items():\n",
    "    total = sum(counts.values())\n",
    "    accuracies[c1] = {c2: count / total for c2, count in counts.items()}\n",
    "\n",
    "accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26e5656e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: category of the voting group; columns: category of its nearest entry\n",
    "# (0 where a category was never chosen).\n",
    "labels = list(CATEGORIES)\n",
    "data = np.array([\n",
    "    [accuracies.get(row, {}).get(col, 0) for col in labels]\n",
    "    for row in labels\n",
    "])\n",
    "\n",
    "plt.title(\"Classification Accuracy\")\n",
    "sns.heatmap(\n",
    "    data=data,\n",
    "    xticklabels=labels,\n",
    "    yticklabels=labels,\n",
    "    vmin=0.0,\n",
    "    vmax=1.0,\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc5d1ddf",
   "metadata": {},
   "source": [
    "What if we condition on the prediction accuracy of each operator J?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "736f6f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "PAIRS_BY_CATEGORY = {\n",
    "    \"language\": LANGUAGES,\n",
    "    \"capitol\": CAPITOLS,\n",
    "    \"north\": BORDER_NORTH,\n",
    "}\n",
    "\n",
    "def compute_accuracy(category, prompt, subject):\n",
    "    \"\"\"Score the operator for (prompt, subject) on the category's other subjects.\n",
    "\n",
    "    The operator stored in CATEGORIES[category] under (prompt, subject) is applied\n",
    "    to every other pair of PAIRS_BY_CATEGORY[category]; a pair counts as correct\n",
    "    when its lower-cased, stripped target starts with any non-empty prediction.\n",
    "\n",
    "    Returns the fraction of those pairs that are correct.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the category has no pairs besides `subject`.\n",
    "    \"\"\"\n",
    "    # Stored values are 2-tuples whose first item is the operator (other cells\n",
    "    # unpack them as `(o, m)` and use `o.weight`); the tuple itself isn't callable.\n",
    "    operator, _ = CATEGORIES[category][prompt, subject]\n",
    "    dataset = [pair for pair in PAIRS_BY_CATEGORY[category] if pair[0] != subject]\n",
    "    if not dataset:\n",
    "        raise ValueError(f\"no held-out pairs for {category!r} besides {subject!r}\")\n",
    "    n_correct = 0\n",
    "    for s, t in dataset:\n",
    "        target = t.lower().strip()\n",
    "        # NOTE(review): assumes operator(s, device=...) yields items whose [0] is a\n",
    "        # predicted string; confirm against src.estimate.\n",
    "        # Named `predictions` to avoid shadowing the `os` module.\n",
    "        predictions = [o[0].lower().strip() for o in operator(s, device=device)]\n",
    "        # Ignore empty predictions: every string starts with \"\", so they would\n",
    "        # always count as correct.\n",
    "        n_correct += any(target.startswith(p) for p in predictions if p)\n",
    "    return n_correct / len(dataset)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
