{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba601ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5707d78",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "config = \"EleutherAI/gpt-j-6B\"\n",
    "device = \"cuda:1\"\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    config,\n",
    "    low_cpu_mem_usage=True,\n",
    "    revision=\"float16\").to(device)\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(config)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecee9018",
   "metadata": {},
   "source": [
    "# Model Confidence?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71ea9e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# subject = \"surgeon\"\n",
    "# prompt = \"farmer: barn\\ncar mechanic: garage\\nchef: kitchen\\nteacher: school\\n{}:\"\n",
    "\n",
    "# subject = \"Saudi Arabia\"\n",
    "# delim = \" shares its northern border with\"\n",
    "# prompt = f\"USA{delim} Canada\\Mexico{delim} USA\\nSudan{delim} Egypt\\n\" + \"{}\"+ delim\n",
    "\n",
    "# subject = \"The actor Neil Patrick Harris\"\n",
    "# prompt = \"{} is married to a man named\"\n",
    "\n",
    "\n",
    "# subject = \"Gengar\"\n",
    "# prompt = \"Pikachu: electric\\nSquirtle: water\\nCharizard: fire\\nShroomish: grass\\n{}:\"\n",
    "\n",
    "subject = \"Bagon\"\n",
    "prompt = \"Pikachu: Raichu\\nCharmander: Charmeleon\\nShroomish: Breloom\\n{}:\"\n",
    "\n",
    "inputs = tokenizer(prompt.format(subject), return_tensors=\"pt\").to(device)\n",
    "with torch.inference_mode():\n",
    "    outputs = model(**inputs)\n",
    "topk = torch.softmax(outputs.logits[:, -1].float(), dim=-1).topk(dim=-1, k=5)\n",
    "words = [tokenizer.decode(token_id) for token_id in topk.indices.squeeze()]\n",
    "probs = topk.values.squeeze().tolist()\n",
    "\n",
    "print(prompt)\n",
    "for word, prob in zip(words, probs):\n",
    "    print(f\"{word} ({prob:.2f})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c67d1736",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import corner, estimate\n",
    "import importlib\n",
    "importlib.reload(corner)\n",
    "importlib.reload(estimate)\n",
    "\n",
    "operator, _ = estimate.relation_operator_from_sample(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    subject,\n",
    "    prompt,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2010ed28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def logit_lens(h, k=10):\n",
    "    h = h.view(1, model.config.hidden_size)\n",
    "    dist = torch.softmax(model.lm_head(model.transformer.ln_f(h)), dim=-1)\n",
    "    topk = dist.topk(dim=-1, k=k)\n",
    "    words = [\n",
    "        tokenizer.decode(token_id)\n",
    "        for token_id in topk.indices.squeeze()\n",
    "    ]\n",
    "    probs = topk.values.squeeze().tolist()\n",
    "    return tuple(zip(words, probs))\n",
    "\n",
    "logit_lens(operator.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b070c7b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "operator(\"blueberries\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce132f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "corner_estimator = corner.CornerEstimator(model, tokenizer)\n",
    "c = corner_estimator.estimate_simple_corner([\n",
    "    \"black\",\n",
    "    \"white\",\n",
    "    \"brown\",\n",
    "    \"green\",\n",
    "    \"blue\",\n",
    "    \"orange\",\n",
    "    \"yellow\",\n",
    "    \"purple\",\n",
    "    \"red\",\n",
    "    \"pink\",\n",
    "    \"grey\",\n",
    "])\n",
    "corner_operator = operator.overwrite(bias=c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d411199",
   "metadata": {},
   "outputs": [],
   "source": [
    "# corner_operator = operator.overwrite(bias=c, weight=torch.eye(4096).to(device))\n",
    "corner_operator(\"sweet potatoes\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2506adb4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.heatmap(corner_operator.weight.data.cpu()[:100, :100], vmin=-.02, vmax=.02)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "107d2b04",
   "metadata": {},
   "source": [
    "# Vignette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "538b8259",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from itertools import chain\n",
    "from pathlib import Path\n",
    "from typing import NamedTuple\n",
    "\n",
    "from src import corner, estimate\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import importlib\n",
    "importlib.reload(corner)\n",
    "importlib.reload(estimate)\n",
    "\n",
    "RESULTS_DIR = Path(\"vignette_results/\")\n",
    "\n",
    "\n",
    "class Setting(NamedTuple):\n",
    "    \n",
    "    relation: str\n",
    "    train: tuple[tuple[str, str], ...]\n",
    "    test: tuple[tuple[str, str], ...]\n",
    "    rng: tuple[str, ...] | None = None\n",
    "    k: int = 2\n",
    "    prompt: str | None = None\n",
    "\n",
    "        \n",
    "corner_estimator = corner.CornerEstimator(model, tokenizer)\n",
    "\n",
    "\n",
    "def pprint_predictions(subject, predictions):\n",
    "    print(f\"* {subject}\")\n",
    "    for word, prob in predictions:\n",
    "        if isinstance(word, list):\n",
    "            word = word[0]\n",
    "        if isinstance(prob, list):\n",
    "            prob = prob[0]\n",
    "        print(f\"{word} ({prob:.2f})\")\n",
    "\n",
    "\n",
    "def test(operator, samples, k=2, pprint=False):\n",
    "    total_correct = 0\n",
    "    sample_summaries = []\n",
    "    for subject, expected in samples:\n",
    "        predictions = operator(subject, device=device)\n",
    "        predictions = [\n",
    "            (w, p)\n",
    "            for w, p in predictions\n",
    "            if w.strip()\n",
    "            and (\n",
    "                len(w.strip()) <= 2\n",
    "                or (\n",
    "                    w.strip().lower() != subject.strip().lower()\n",
    "                    and not subject.strip().lower().startswith(w.strip().lower())\n",
    "                )\n",
    "            )\n",
    "        ]\n",
    "        if pprint:\n",
    "            pprint_predictions(subject, predictions)\n",
    "        is_correct = any(expected.lower().startswith(x[0].strip().lower()) for x in predictions[:k])\n",
    "        \n",
    "        sample_summaries.append({\n",
    "            \"subject\": subject,\n",
    "            \"predictions\": predictions,\n",
    "            \"expected\": expected,\n",
    "            \"is_correct\": is_correct,\n",
    "        })\n",
    "        \n",
    "        total_correct += int(is_correct)\n",
    "\n",
    "    accuracy = total_correct / len(samples)\n",
    "    return accuracy, {\n",
    "        \"k\": k,\n",
    "        \"accuracy\": accuracy,\n",
    "        \"samples\": sample_summaries,\n",
    "    }\n",
    "\n",
    "\n",
    "def evaluate(settings, layer=12, plot=False, results_dir=None):\n",
    "    if results_dir is not None and Path(results_dir).name != str(layer):\n",
    "        results_dir = Path(results_dir) / str(layer)\n",
    "    if results_dir is not None:\n",
    "        results_dir.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "    summaries = []\n",
    "    for relation, trains, tests, codomain, k, prompt in settings:\n",
    "        print(f\"---- {relation} ----\")\n",
    "\n",
    "        prompt = prompt if prompt is not None else f\" {relation} \"\n",
    "        prompt_template = \"\\n\".join(f\"{subj}{prompt}{obj}\" for subj, obj in trains[:-1])\n",
    "        prompt_template += \"\\n\" + \"{} \" + relation\n",
    "        train_subject = trains[-1][0]\n",
    "\n",
    "        print(\"estimating J...\")\n",
    "        operator, _ = estimate.relation_operator_from_sample(\n",
    "            model=model,\n",
    "            tokenizer=tokenizer,\n",
    "            subject=train_subject,\n",
    "            relation=prompt_template,\n",
    "            device=device,\n",
    "            layer=layer,\n",
    "        )\n",
    "\n",
    "        print(\"estimating corner...\")\n",
    "        if codomain:\n",
    "            c = corner_estimator.estimate_average_corner_with_gradient_descent(codomain)\n",
    "        else:\n",
    "            c = corner_estimator.estimate_average_corner_with_gradient_descent([\n",
    "                x[-1] for x in chain(trains, tests)\n",
    "            ])\n",
    "\n",
    "        print(\"logit lens on corner:\")\n",
    "        lens = logit_lens(c)\n",
    "#         for word, prob in lens:\n",
    "#             print(f\"* {word} ({prob:.2f})\")\n",
    "\n",
    "        print(\"\\nJ, bias\")\n",
    "        acc, jb_summary = test(operator, tests, k=k)\n",
    "        print(f\"layer={layer} {k}-acc={acc:.2f}\")\n",
    "\n",
    "        print(\"\\nI, corner\")\n",
    "        acc, ic_summary = test(operator.overwrite(weight=torch.eye(model.config.hidden_size).to(device), bias=c), tests, k=k)\n",
    "        print(f\"layer={layer} {k}-acc={acc:.2f}\")\n",
    "\n",
    "        print(\"\\nJ, corner\")\n",
    "        acc, jc_summary = test(operator.overwrite(bias=c), tests, k=k)\n",
    "        print(f\"layer={layer} {k}-acc={acc:.2f}\")\n",
    "\n",
    "        if plot:\n",
    "            plt.figure()\n",
    "            sns.heatmap(operator.weight.data.cpu()[:25, :25], cmap=\"PiYG\")\n",
    "\n",
    "            plot_file = Path(relation.strip().replace(\" \", \"_\") + \".png\")\n",
    "            if results_dir is not None:\n",
    "                plot_file = Path(results_dir) / plot_file\n",
    "            plt.savefig(str(plot_file))\n",
    "\n",
    "        summary = {\n",
    "            \"relation\": relation,\n",
    "            \"c_logit_lens\": lens,\n",
    "            \"jb\": jb_summary,\n",
    "            \"ic\": ic_summary,\n",
    "            \"jc\": jc_summary,\n",
    "        }\n",
    "        summaries.append(summary)\n",
    "\n",
    "    if results_dir is not None:\n",
    "        with Path(results_dir, \"summaries.json\").open(\"w\") as handle:\n",
    "            json.dump({\n",
    "                \"summaries\": summaries,\n",
    "            }, handle)\n",
    "\n",
    "    return summaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f8a621",
   "metadata": {},
   "outputs": [],
   "source": [
    "LETTERS = \"abcdefghijklmnopqrstuvwxyz\".upper()\n",
    "\n",
    "SPOUSE_PAIRS = (\n",
    "    (\"Beyonce\", \"Jay-Z\"),\n",
    "    (\"George Bush\", \"Laura Bush\"),\n",
    "    (\"Ariana Grande\", \"Pete Davidson\"),\n",
    "    (\"Barack Obama\", \"Michelle Obama\"),\n",
    "    (\"Michelle Obama\", \"Barack Obama\"),\n",
    "    (\"Jay-Z\", \"Beyonce\"),\n",
    "    (\"Beyonce\", \"Jay-Z\"),\n",
    "    (\"John Lennon\", \"Yoko Ono\"),\n",
    "    (\"Yoko Ono\", \"John Lennon\"),\n",
    "    (\"Forrest Gump\", \"Jenny\"),\n",
    "    (\"George Bush\", \"Laura Bush\"),\n",
    "    (\"Laura Bush\", \"George Bush\"),\n",
    "    (\"Marie Curie\", \"Pierre Curie\"),\n",
    "    (\"Pierre Curie\", \"Marie Curie\"),\n",
    "    (\"Mark Antony\", \"Cleopatra\"),\n",
    "    (\"Cleopatra\", \"Mark Antony\"),\n",
    "    (\"Ashton Kutcher\", \"Mila Kunis\"),\n",
    "    (\"Mila Kunis\", \"Ashton Kutcher\"),\n",
    ")\n",
    "\n",
    "COUNTRY_PAIRS = (\n",
    "    (\"USA\", \"Canada\"),\n",
    "    (\"Sudan\", \"Egypt\"),\n",
    "    (\"Ukraine\", \"Belarus\"),\n",
    "\n",
    "    (\"Mexico\", \"USA\"),\n",
    "    (\"Saudi Arabia\", \"Jordan\"),\n",
    "    (\"Spain\", \"France\"),\n",
    "    (\"Syria\", \"Turkey\"),\n",
    "    (\"Jordan\", \"Syria\"),\n",
    "    (\"Ecuador\", \"Colombia\"),\n",
    "    (\"France\", \"Belgium\"),\n",
    ")\n",
    "\n",
    "PRESIDENT_PAIRS = (\n",
    "    (\"Barack Obama\", \"Donald Trump\"),\n",
    "    (\"George Washington\", \"John Adams\"),\n",
    "    (\"Abraham Lincoln\", \"Andrew Johnson\"),\n",
    "\n",
    "    (\"John Tyler\", \"James Polk\"),\n",
    "    (\"Warren Harding\", \"Calvin Coolidge\"),\n",
    "    (\"Calvin Coolidge\", \"Herbert Hoover\"),\n",
    "    (\"Herbert Hoover\", \"Franklin Roosevelt\"),\n",
    "    (\"James Carter\", \"Ronald Reagan\"),\n",
    "    (\"Harry Truman\", \"Dwight Eisenhower\"),\n",
    "    (\"Teddy Roosevelt\", \"William Howard Taft\"),\n",
    ")\n",
    "\n",
    "\n",
    "SETTINGS = (\n",
    "    Setting(\n",
    "        relation=\"is partnered to\",\n",
    "        train=SPOUSE_PAIRS[:3],    \n",
    "        test=SPOUSE_PAIRS[3:],\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"is the opposite of\",\n",
    "        train=(\n",
    "            (\"dark\", \"light\"),\n",
    "            (\"good\", \"evil\"),\n",
    "            (\"up\", \"down\"),\n",
    "        ),\n",
    "        test=(\n",
    "            (\"left\", \"right\"),\n",
    "            (\"right\", \"left\"),\n",
    "            (\"down\", \"up\"),\n",
    "            (\"evil\", \"good\"),\n",
    "            (\"light\", \"dark\"),\n",
    "            (\"open\", \"closed\"),\n",
    "        ),\n",
    "        rng=(\n",
    "            \"light\", \"evil\", \"down\"\n",
    "        ),\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"preceeded the President\",\n",
    "        train=PRESIDENT_PAIRS[:3],\n",
    "        test=PRESIDENT_PAIRS[3:],\n",
    "        prompt=\": \",\n",
    "    ),\n",
    "    \n",
    "    Setting(\n",
    "        relation=\"succeeded the President\",\n",
    "        train=[(y, x) for x, y in PRESIDENT_PAIRS[:3]],\n",
    "        test=[(y, x) for x, y in PRESIDENT_PAIRS[3:]],\n",
    "        prompt=\": \",\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"shares its northern border with\",\n",
    "        train=COUNTRY_PAIRS[:3],\n",
    "        test=COUNTRY_PAIRS[3:],\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"shares its southern border with\",\n",
    "        train=[(y, x) for x, y in COUNTRY_PAIRS[:3]],\n",
    "        test=[(y, x) for x, y in COUNTRY_PAIRS[3:]],\n",
    "    ),\n",
    "    \n",
    "    Setting(\n",
    "        relation=\"has the color of\",\n",
    "        train=(\n",
    "            (\"bananas\", \"yellow\"),\n",
    "            (\"blueberries\", \"blue\"),\n",
    "            (\"kiwis\", \"green\"),\n",
    "        ),\n",
    "        test=(\n",
    "            (\"broccoli\", \"green\"),\n",
    "            (\"tangerines\", \"orange\"),\n",
    "            (\"apples\", \"red\"),\n",
    "            (\"sweet potatoes\", \"orange\"),\n",
    "            (\"carrots\", \"orange\"),\n",
    "            (\"milk\", \"white\"),\n",
    "            (\"cauliflower\", \"white\"),\n",
    "            (\"kale\", \"green\"),\n",
    "            (\"chocolate\", \"brown\"),\n",
    "            (\"water\", \"blue\"),\n",
    "            (\"plum\", \"purple\"),\n",
    "        ),\n",
    "        rng=(\n",
    "            \"black\",\n",
    "            \"white\",\n",
    "            \"brown\",\n",
    "            \"green\",\n",
    "            \"blue\",\n",
    "            \"orange\",\n",
    "            \"yellow\",\n",
    "            \"purple\",\n",
    "            \"red\",\n",
    "            \"pink\",\n",
    "            \"grey\",\n",
    "        ),\n",
    "        prompt=\": \",\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"evolves into\",\n",
    "        train=(\n",
    "            (\"The pokemon Pikachu\", \"Raichu\"),\n",
    "            (\"The pokemon Shroomish\", \"Breloom\"),\n",
    "            (\"The pokemon Charmander\", \"Charizard\"),\n",
    "            (\"The pokemon Munchlax\", \"Snorlax\"),\n",
    "        ),\n",
    "        test=(\n",
    "            (\"The pokemon Squirtle\", \"Blastoise\"),\n",
    "            (\"The pokemon Mudkip\", \"Swampert\"),\n",
    "            (\"The pokemon Grimer\", \"Muk\"),\n",
    "            (\"The pokemon Abra\", \"Alakazam\"),\n",
    "            (\"The pokemon Bulbasaur\", \"Venusaur\"),\n",
    "            (\"The pokemon Geodude\", \"Golem\"),\n",
    "            (\"The pokemon Dratini\", \"Dragonite\"),\n",
    "            (\"The pokemon Pichu\", \"Raichu\"),\n",
    "            (\"The pokemon Charmander\", \"Charmeleon\"),\n",
    "        )\n",
    "    ),\n",
    "\n",
    "    Setting(\n",
    "        relation=\"is abbreviated as\",\n",
    "        train=(\n",
    "            (\"Connecticut\", \"CT\"),\n",
    "            (\"Oregon\", \"OR\"),\n",
    "            (\"Colorado\", \"CO\"),\n",
    "        ),\n",
    "        test=(\n",
    "            (\"New York\", \"NY\"),\n",
    "            (\"Illinois\", \"IL\"),\n",
    "            (\"Massachusetts\", \"MA\"),\n",
    "            (\"Utah\", \"UT\"),\n",
    "            (\"Nevada\", \"NV\"),\n",
    "            (\"Washington\", \"WA\"),\n",
    "            (\"Wisconsin\", \"WI\"),\n",
    "            (\"Maryland\", \"MD\"),\n",
    "            (\"Alabama\", \"AL\")\n",
    "        ),\n",
    "        rng=tuple(\n",
    "            f\"{a}{b}\"\n",
    "            for a in LETTERS\n",
    "            for b in LETTERS\n",
    "        ),\n",
    "    ),\n",
    ")\n",
    "\n",
    "RESULTS_DIR = Path(\"./vignette_results\")\n",
    "LAYERS = list(range(28))\n",
    "\n",
    "summaries_by_layer = {}\n",
    "for layer in LAYERS:\n",
    "    summaries_by_layer[layer] = evaluate(SETTINGS, layer=layer, plot=True, results_dir=RESULTS_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "338b1592",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "METHOD_TO_PRETTY = {\n",
    "    \"jb\": \"J, bias\",\n",
    "    \"ic\": \"Corner\",\n",
    "    \"jc\": \"J, corner\"\n",
    "    \n",
    "}\n",
    "\n",
    "\n",
    "def generate_html(summaries, layer):\n",
    "    html = [\n",
    "        \"<html>\",\n",
    "        \"<style>\",\n",
    "    \"\"\"\n",
    "    th {\n",
    "        font-weight: bold;\n",
    "    }\n",
    "\n",
    "    table {\n",
    "        text-align: left;\n",
    "        border-collapse: collapse;\n",
    "    }\n",
    "\n",
    "    th {\n",
    "        border-top: 2px solid black;\n",
    "        border-bottom: 1px solid black;\n",
    "    }\n",
    "\n",
    "    .qualitative th {\n",
    "        padding-right: 5em;\n",
    "    }\n",
    "\n",
    "    .quantitative th {\n",
    "        padding-right: 2em;\n",
    "    }\n",
    "\n",
    "    tr:last-of-type {\n",
    "        border-bottom: 2px solid black;\n",
    "    }\n",
    "\n",
    "    h2 {\n",
    "        margin-top: 2em;\n",
    "    }\n",
    "\n",
    "    h4 {\n",
    "        font-weight: normal;\n",
    "        text-decoration: underline;\n",
    "    }\n",
    "\n",
    "    \"\"\"\n",
    "        \"</style>\",\n",
    "        \"<body>\",\n",
    "        f\"<h1>Results for GPT-J/layer {layer}</h1>\",\n",
    "    ]\n",
    "\n",
    "    for summary in summaries:\n",
    "        html += [\n",
    "            f\"<h2>{summary['relation']}</h2>\",\n",
    "            \"<table class='quantitative'>\",\n",
    "            \"<thead>\",\n",
    "            \"<th>Method</th>\",\n",
    "            \"<th>Recall@2</th>\",\n",
    "            \"</thead>\",\n",
    "            \"<tbody>\",\n",
    "            *[\n",
    "                f\"<tr><td>{METHOD_TO_PRETTY[method]}</td><td>{summary[method]['accuracy']:.2f}</td></tr>\"\n",
    "                for method in (\"jb\", \"ic\", \"jc\")\n",
    "            ],\n",
    "            \"</tbody>\",\n",
    "            \"</table>\",\n",
    "        ]\n",
    "\n",
    "        for method in (\"jb\", \"ic\", \"jc\"):\n",
    "            html += [\n",
    "                f\"<h4>{METHOD_TO_PRETTY[method]} Outputs</h4>\",\n",
    "                \"<table class='qualitative'>\",\n",
    "                \"<thead>\",\n",
    "                \"<tr>\",\n",
    "                \"<th>subject</th>\",\n",
    "                \"<th>object</th>\",\n",
    "                *[f\"<th>prediction {i}</th>\" for i in range(1, 6)],\n",
    "                \"</tr>\",\n",
    "                \"</thead>\",\n",
    "                \"<tbody>\",\n",
    "            ]\n",
    "            for sample in summary[method][\"samples\"]:\n",
    "                expected = sample['expected']\n",
    "                html += [\n",
    "                    \"<tr>\"\n",
    "                    f\"<td>{sample['subject']}</td>\",\n",
    "                    f\"<td>{expected}</td>\",\n",
    "                ]\n",
    "                for word, prob in sample[\"predictions\"]:\n",
    "                    is_correct = expected.lower().strip().startswith(word.lower().strip())\n",
    "\n",
    "                    word_html = f\"{word} ({prob:.2f})\"\n",
    "                    if is_correct:\n",
    "                        word_html = f\"<span style='color: blue'>{word_html}</span>\"\n",
    "                    html += [f\"<td>{word_html}</td>\"]\n",
    "\n",
    "                for _ in range(5 - len(sample[\"predictions\"])):\n",
    "                    html += [\"<td></td>\"]\n",
    "\n",
    "                html += [\"</tr>\"]\n",
    "            html += [\"</tbody>\", \"</table>\"]\n",
    "\n",
    "    html += [\"</body>\", \"</html>\"]\n",
    "\n",
    "    with (RESULTS_DIR / str(layer) / \"viz.html\").open(\"w\") as handle:\n",
    "        handle.write(\"\\n\".join(html))\n",
    "\n",
    "for layer, summaries in summaries_by_layer.items():\n",
    "    generate_html(summaries, layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61e7ea9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
