{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code Intervention Tutorial\n",
    "\n",
    "This is a tutorial notebook on how to perform code interventions with the Codebook Features library. The goal of this tutorial is to steer a language model to generate text that follows a topic (by activating specific topic codes) and quantitatively evaluate how well the model was steered. We use the TinyStories 21M parameter model trained on synthetic children's stories (https://arxiv.org/abs/2305.07759). This small model typically produces grammatical but incoherent stories; nevertheless we can use it to see how different topics are woven into the network. For example, activating 'baby' codes causes the model to introduce topics such as babies, bathtubs, and baby birds, rather than simply outputting 'baby baby baby' repeatedly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import codebook_features\n",
    "except:\n",
    "    ! pip install codebook-features\n",
    "    # restart runtime after installing if running on Colab\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    if \"google.colab\" in str(get_ipython()):\n",
    "        exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x10d2c1050>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from codebook_features import models\n",
    "from codebook_features import utils as cb_utils\n",
    "import torch\n",
    "import re\n",
    "\n",
    "# We turn automatic differentiation off, to save GPU memory,\n",
    "# as this tutorial focuses only on model inference\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"roneneldan/TinyStories-1Layer-21M\"\n",
    "pretrained_path = \"../models/TinyStories-1Layer-21M-Codebook/\"\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "    print(\n",
    "        \"No GPU found, using CPU instead. If running on Colab, \"\n",
    "        \"make sure to enable GPU acceleration under Runtime -> Change runtime type\"\n",
    "    )\n",
    "orig_cb_model = models.wrap_codebook(\n",
    "    model_or_path=model_name_or_path, pretrained_path=pretrained_path\n",
    ")\n",
    "orig_cb_model = orig_cb_model.to(device).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the model into a hooked transformer model (from transformer_lens) that allows us to do code interventions easily"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hooked_kwargs = dict(\n",
    "    center_unembed=False,\n",
    "    fold_value_biases=False,\n",
    "    center_writing_weights=False,\n",
    "    fold_ln=False,\n",
    "    refactor_factored_attn_matrices=False,\n",
    "    device=device,\n",
    ")\n",
    "cb_model = models.convert_to_hooked_model(\n",
    "    model_name_or_path, orig_cb_model, hooked_kwargs=hooked_kwargs\n",
    ")\n",
    "cb_model = cb_model.to(device).eval()\n",
    "tokenizer = cb_model.tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assert that the original codebook model and the hooked model produce the same output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence = \"this is a random sentence to test.\"\n",
    "input_tensor = tokenizer(sentence, return_tensors=\"pt\")[\"input_ids\"]\n",
    "input_tensor = input_tensor.to(device)\n",
    "output = orig_cb_model(input_tensor)[\"logits\"]\n",
    "hooked_output = cb_model(input_tensor)\n",
    "assert torch.allclose(output, hooked_output, atol=1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Topic Codes\n",
    "Below, we have provided a subset of topic codes we have found in this model. Many more such topic codes can be found through the Codebook Features webapp in the `codebook_features/webapp` directory.\n",
    "\n",
    "Note that multiple codes can be patched in at the same component codebook (in this case, a given attention head at a given layer) since the codebook activates multiple codes. Since there are multiple codes that can represent a topic, we patch in multiple codes for each topic, possibly from different attention heads. You can play around by removing some codes for a topic and seeing how the generated text changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_codes_str = {\n",
    "    \"\": \"\"\n",
    "}  # blank one is used for default generations (no topic steering)\n",
    "\n",
    "topic = \"dragon\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 4670, Layer: 0, Head: 13\n",
    "Code: 17640, Layer: 0, Head: 13\n",
    "Code: 19845, Layer: 0, Head: 13\n",
    "Code: 23958, Layer: 0, Head: 13\n",
    "Code: 3410, Layer: 0, Head: 13\n",
    "Code: 19523, Layer: 0, Head: 13\n",
    "Code: 2262, Layer: 0, Head: 13\n",
    "Code: 16060, Layer: 0, Head: 13\n",
    "\"\"\"\n",
    "\n",
    "topic = \"slide\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 1331, Layer: 0, Head: 14\n",
    "Code: 22178, Layer: 0, Head: 14\n",
    "Code: 15885, Layer: 0, Head: 14\n",
    "Code: 9524, Layer: 0, Head: 14\n",
    "Code: 15549, Layer: 0, Head: 14\n",
    "Code: 7802, Layer: 0, Head: 14\n",
    "Code: 11942, Layer: 0, Head: 14\n",
    "Code: 4095, Layer: 0, Head: 1\n",
    "Code: 2179, Layer: 0, Head: 1\n",
    "Code: 22425, Layer: 0, Head: 1\n",
    "Code: 10661, Layer: 0, Head: 1\n",
    "Code: 8598, Layer: 0, Head: 1\n",
    "\"\"\"\n",
    "\n",
    "topic = \"friend\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 20506, Layer: 0, Head: 11\n",
    "Code: 6103, Layer: 0, Head: 11\n",
    "Code: 15764, Layer: 0, Head: 11\n",
    "Code: 14060, Layer: 0, Head: 11\n",
    "Code: 21005, Layer: 0, Head: 11\n",
    "Code: 16006, Layer: 0, Head: 11\n",
    "Code: 12290, Layer: 0, Head: 11\n",
    "Code: 7404, Layer: 0, Head: 11\n",
    "Code: 2471, Layer: 0, Head: 13\n",
    "\"\"\"\n",
    "\n",
    "topic = \"flower\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 23967, Layer: 0, Head: 13\n",
    "Code: 13533, Layer: 0, Head: 13\n",
    "Code: 4175, Layer: 0, Head: 13\n",
    "Code: 6390, Layer: 0, Head: 13\n",
    "Code: 18765, Layer: 0, Head: 13\n",
    "Code: 1775, Layer: 0, Head: 13\n",
    "Code: 7430, Layer: 0, Head: 13\n",
    "Code: 9269, Layer: 0, Head: 13\n",
    "\"\"\"\n",
    "\n",
    "topic = \"fire\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 9151, Layer: 0, Head: 13\n",
    "Code: 6389, Layer: 0, Head: 13\n",
    "Code: 16473, Layer: 0, Head: 13\n",
    "Code: 24184, Layer: 0, Head: 13\n",
    "Code: 11224, Layer: 0, Head: 13\n",
    "Code: 16757, Layer: 0, Head: 13\n",
    "Code: 16684, Layer: 0, Head: 13\n",
    "Code: 22825, Layer: 0, Head: 13\n",
    "Code: 22980, Layer: 0, Head: 14\n",
    "Code: 6544, Layer: 0, Head: 14\n",
    "Code: 2672, Layer: 0, Head: 14\n",
    "Code: 5791, Layer: 0, Head: 14\n",
    "Code: 22544, Layer: 0, Head: 14\n",
    "Code: 6971, Layer: 0, Head: 14\n",
    "Code: 23452, Layer: 0, Head: 14\n",
    "Code: 708, Layer: 0, Head: 14\n",
    "\"\"\"\n",
    "\n",
    "topic = \"prince|crown|king|castle\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 28, Layer: 0, Head: 13\n",
    "Code: 19802, Layer: 0, Head: 13\n",
    "Code: 22851, Layer: 0, Head: 13\n",
    "Code: 8907, Layer: 0, Head: 13\n",
    "Code: 18042, Layer: 0, Head: 13\n",
    "Code: 9619, Layer: 0, Head: 13\n",
    "Code: 15278, Layer: 0, Head: 13\n",
    "Code: 9649, Layer: 0, Head: 13\n",
    "Code: 13055, Layer: 0, Head: 14\n",
    "Code: 13575, Layer: 0, Head: 14\n",
    "Code: 9784, Layer: 0, Head: 14\n",
    "Code: 19023, Layer: 0, Head: 14\n",
    "Code: 7704, Layer: 0, Head: 14\n",
    "Code: 6056, Layer: 0, Head: 14\n",
    "\"\"\"\n",
    "\n",
    "topic = \"baby\"\n",
    "topic_codes_str[\n",
    "    topic\n",
    "] = \"\"\"\n",
    "Code: 66, Layer: 0, Head: 13\n",
    "Code: 657, Layer: 0, Head: 13\n",
    "Code: 9965, Layer: 0, Head: 13\n",
    "Code: 13724, Layer: 0, Head: 13\n",
    "Code: 5276, Layer: 0, Head: 13\n",
    "Code: 11101, Layer: 0, Head: 13\n",
    "Code: 10272, Layer: 0, Head: 13\n",
    "Code: 3067, Layer: 0, Head: 3\n",
    "Code: 18686, Layer: 0, Head: 3\n",
    "Code: 430, Layer: 0, Head: 3\n",
    "Code: 12364, Layer: 0, Head: 3\n",
    "Code: 1209, Layer: 0, Head: 3\n",
    "Code: 13863, Layer: 0, Head: 3\n",
    "Code: 15111, Layer: 0, Head: 3\n",
    "Code: 1185, Layer: 0, Head: 3\n",
    "\"\"\"\n",
    "\n",
    "# this converts the strings to lists of topic codes of the type `CodeInfo` that the library uses.\n",
    "topic_codes = {\n",
    "    k: cb_utils.parse_topic_codes_string(v, pos=None, code_append=False)\n",
    "    for k, v in topic_codes_str.items()\n",
    "}\n",
    "for topic, codes in topic_codes.items():\n",
    "    for code in codes:\n",
    "        code.cb_at = cb_model.config.codebook_at[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Code Intervention\n",
    "\n",
    "Now we perform the code intervention for a specific topic using the `generate_with_codes` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify the topic you want the generations to steer towards\n",
    "topic = \"baby\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "495b0f53947e4d58914cfb873c8f9528",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Story 0:\n",
      "Once upon a time, there was a little girl called Lily who was very restless. She always wanted to go too even though it was very hard. One day, she met a baby named Timmy. Timmy had a journal that he cherished every day. \n",
      "\n",
      "One day, Lily went for a walk with her mommy. She sat in the grass and watched the birds fly by. Suddenly, the bee landed on her shoulder and did not move. \n",
      "\n",
      "Lily's mommy came and asked what the baby was feeling. The baby said he was sorry and left out for her. The baby said sorry and started to cry. After a few minutes, the baby stopped crying and started to smile. \n",
      "\n",
      "But then the mom and dad started to move towards the baby. The mom was so angry that she started to scream. It opened its mouth and the baby stopped crying. \n",
      "\n",
      "The mom and dad knew it was time for the baby to move. And she did! Mom and\n",
      "************************************\n",
      "Story 1:\n",
      "Once upon a time, there was a girl named Lucy. She thought of many ways to help animals with their meeting. She had a lot to replace her friendship with the rabbit and the man knew it was about to touch the cat. He looked around and saw that his ink was on it. He asked his mom about it for him, but she said no. The baby was confused, so he asked the kids to reach for the baby. His mom said yes and the man took him home.\n",
      "\n",
      "The man and the baby became friends. They lived happily ever after. The baby never got more different from the baby again. He learned to be a modest and gentle person, but sometimes it means to be like the mommy who worked in the store.Once upon a time, there was a nice tired bed. It was so cold that it couldn't move. It was sad because it wanted something to be helpful. \n",
      "\n",
      "One day, it saw a spot in the yard of a small house. The\n",
      "************************************\n",
      "Story 2:\n",
      "Once upon a time, there was a very little pig who lived in a big, friendly dog. The baby was always sleeping and his mommy pig would sometimes play with her toys. They would jump, run, splash, and laugh in the grass.\n",
      "\n",
      "One day the baby was playing with its toys when its mommy came to visit. She told the baby to go and get some air from the big pot. The baby cried and cried and cried until his mommy to talk to him.\n",
      "\n",
      "From that day on, the big, mean lady and the big pot played together every day. They would talk about their day, as the big, blue pot would remind their friend that they were good friends.\n",
      "\n",
      "One day, the big bear was attacked by a big bouncing tree. The baby was very scared and didn't know what to do. Soon, the big pot was empty with the big pot. The mommy and daddy were sad because the ball could not talk to the big blue pot anymore\n",
      "************************************\n",
      "Story 3:\n",
      "Once upon a time, there was a girl named Daisy. She had a pretty doll named Lily. One day, Lily spread an incredible sticker on the lucky grass that she asked her friend to join her in the garden. Lily offered to help her take very care of her.\n",
      "\n",
      "After the demonstration, Lucy went to search for their friend. She found a big pile of avocados and brought them back inside. She brought them home to her mom, who was lying on the table looking sad. \"I don't want to fit in with a scanner, little cat?\"\n",
      "\n",
      "His mom smiled and helped her take one of the pillows. Then so the cat returned home with the baby and brought her to a pleasant-ner warm milk. The little girl was delighted to be able to pick the right one!Once upon a time there was an old farmer. He lived in a small country with his family. Every day, he would dream about a magical forest.\n",
      "\n",
      "One morning, the family talked about\n",
      "************************************\n",
      "Story 4:\n",
      "Once upon a time, there was a little bird named Poppy who lived in a big, white birdcage in a tight circle of soft pillows and comfy beds in every - flat, cottony life would get when thing in the box. One day, the birdcage was wet and dirty. Poppy would weigh the baby bird to come and play with him every day. True was always awake and panting, except for the oil. His owner without saying the baby had to close his chest and rest. He would lay in bed and shake his head. Mommy would laugh in the crib, and the baby would be safe for him to eat.\n",
      "\n",
      "One day, the mom brought home a little girl from the kitchen. The baby was very sad and missed her kisses. Mom said, \"I knowâ€™m just surprised her with a heavy heart!\"\n",
      "\n",
      "The baby smiled and said, \"I can do it! I want to be the best veterinarian ever!\" \n",
      "\n",
      "\n",
      "************************************\n",
      "Story 5:\n",
      "Once upon a time, there was a little bird. It was floating in the sky and when he was playing with the club. One day, the lady climbed up a ladder to cross the ocean. She whispered to the butterfly, \"I want to fly. Can I go see it?\" Her mom said, \"No, sweetie, the bird is too fragile.\" \n",
      "\n",
      "\"Butter cat, it looks so nice,\" said the bird. \"I'm sorry you're right if you bump it,\" the bird replied. The birdie was quacking happily and caught the club again. The end.Once upon a time, there was a little girl named Lily. She loved to crawl up and down the hall. One day, she saw a big box in her room. It was very soft and smooth, and she wanted to play with it too.\n",
      "\n",
      "But then, she realized that the sink was much smaller, it was just a network. Lily didn't like it right away, and she felt\n",
      "************************************\n",
      "Story 6:\n",
      "Once upon a time, there was a little boy named Timmy. He played baseball with his friends. One day, Timmy went to meet his friend. But when he got there, he saw a big truck on the road. The truck had a broken arm! Lily was sad because she loved him very much. \n",
      "\n",
      "John wanted to help the little boy. He loved his bath time and played with his toy truck. His mom gave him a piece of advice. It was going to be a big, harmless bear! \n",
      "\n",
      "The moral of the story is that even the smallest piece of friends can show love and help others when you make new friends and bring you joy. And most importantly, the love you just as much as you do.Tommy was a mouse who loved to relax. He would play with his yarn, but sometimes when he got too close, his toy bear would no longer grow up. And when his owner didn't come, he would wiggle loudly and throw the yarn away\n",
      "************************************\n",
      "Story 7:\n",
      "Once upon a time, there was a family of bears who grew up. They worked together to have a long conversation and laughed a lot.\n",
      "\n",
      "One day, a huge tall castle appeared. It was so big that it went falling from the sky. It was so harsh that it shook and wobbled. All the children were scared and did not know what to do.\n",
      "\n",
      "But they decided to get some help. They worked together to fix the huge bathtub and the throne was very comfy. They couted and built safely without any worrying. From that moment on, the mummy and daddy were not so harsh after all.\n",
      "\n",
      "When the baby swirled around, the family asked if the dessert was ok. They said yes but it sounded high and tough. Together, they carried heavy high into the air. They worked it out and the old spot was where they could relax first before they left. They sat together and had a fun time. \n",
      "\n",
      "The end.Once upon a time, there\n",
      "************************************\n",
      "Story 8:\n",
      "Once upon a time, there was a fence at our campsawk. They were very happy to see her and the pumpkin. Lily said \"I have something to protect my friends because I like Timmy did.\" \n",
      "Lily said, \"That's not a snake, Timmy. That's an unknown pumpkin. Maybe it's because some people are allergic to oranges.\" \n",
      "\n",
      "The next day Peter gave her a letter to send another one. When they got home, Lily's parents brought it to the hospital. The lady was happy because she was so kind to others. She said to Lily, \"You're such a kind girl.\"Once upon a time, there was a little boy named Timmy. Timmy was playing with his toys when he heard a loud noise. He liked to listen to your parents talk a lot. One day, he went to the park, and met an old lady. The old lady said, \"I like to talk to her. Do you need my help?\" The old\n",
      "************************************\n",
      "Story 9:\n",
      "Once upon a time, there was a shy cat named Striripe who lived with them. One day, the fox said to the cat, \"I have some honey. Can I act fast and I'll feed you later?\"\n",
      "\n",
      "Chripe was hesitant at first, but after a moment, the cat was worried that it might hurt. So, the cat gave the cat some of the cream, and told her they could play again.\n",
      "\n",
      "From that day on, the cat and the dragon were best friends. They played together and helped each other whenever they could. And the cat never complained again.Once upon a time, there was a little girl named Lily. She had a modest bedroom with lots of toys on her shoulders. One day, Lily's mom told her it was time to take a bath. Her mom was worried it might hurt her back, so she asked her friends if they could take the kitten to a doctor.\n",
      "\n",
      "The doctor said yes and then the doctor could help when she\n",
      "************************************\n"
     ]
    }
   ],
   "source": [
    "# CodeInfo objects hold a code's associated metadata (e.g. position in the network)\n",
    "list_of_code_infos = topic_codes[topic]\n",
    "\n",
    "text_input = \"Once upon a time,\"\n",
    "inp_tensor = cb_model.to_tokens(text_input, prepend_bos=True).to(device)\n",
    "inp_tensor = inp_tensor.repeat(10, 1)\n",
    "gen = cb_utils.generate_with_codes(\n",
    "    inp_tensor,\n",
    "    cb_model,\n",
    "    list_of_code_infos=list_of_code_infos,\n",
    "    generate_kwargs={\"max_new_tokens\": 200, \"do_sample\": True, \"temperature\": 1},\n",
    ")\n",
    "gen = [tokenizer.decode(g[1:]) for g in gen]\n",
    "for i, g in enumerate(gen):\n",
    "    print(f\"Story {i}:\")\n",
    "    print(g)\n",
    "    print(\"************************************\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantitative Evaluation of Topic Steering\n",
    "\n",
    "Here we do a quantitative evaluation of topic steering by measuring the fraction of generated texts that contain the topic string in the generated text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each of the topic that we have for steering, we generate 10 samples with the topic code patched in for each of our prompt. We then measure the fraction of generated texts that contain the topic string in the generated text. Note that this is an imperfect evaluation, as the model may generate strings related to the topic but not include the topic word itself (e.g. 'babies' vs 'baby')."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = [\n",
    "    \"\",\n",
    "    \"Once upon a time,\",\n",
    "    \"Once there was a\",\n",
    "    \"A long time ago,\",\n",
    "]\n",
    "\n",
    "prompt_completions = {}\n",
    "for topic in tqdm(topic_codes_str):\n",
    "    list_of_arg_tuples = topic_codes[topic]\n",
    "    prompt_completions[topic] = {}\n",
    "    for prompt in prompts:\n",
    "        prompt_token = cb_model.to_tokens(prompt, prepend_bos=True).to(device)\n",
    "        prompt_token = prompt_token.repeat(10, 1)\n",
    "        gen = cb_utils.generate_with_codes(\n",
    "            prompt_token,\n",
    "            cb_model,\n",
    "            list_of_code_infos=list_of_arg_tuples,\n",
    "            generate_kwargs={\n",
    "                \"max_new_tokens\": 200,\n",
    "                \"do_sample\": True,\n",
    "                \"temperature\": 1,\n",
    "            },\n",
    "        )\n",
    "        gen = [tokenizer.decode(gen[i][1:]) for i in range(len(gen))]\n",
    "        prompt_completions[topic][prompt] = gen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "topic_in_prompt_completion = {}\n",
    "for topic in topic_codes_str:\n",
    "    if not topic:\n",
    "        continue\n",
    "    topic_in_prompt_completion[topic] = {}\n",
    "    for prompt in prompts:\n",
    "        topic_in_prompt_completion[topic][prompt] = 0\n",
    "        for completion in prompt_completions[topic][prompt]:\n",
    "            if re.search(topic.lower(), completion.lower()):\n",
    "                topic_in_prompt_completion[topic][prompt] += 1\n",
    "        topic_in_prompt_completion[topic][prompt] /= len(\n",
    "            prompt_completions[topic][prompt]\n",
    "        )\n",
    "\n",
    "topic_in_prompt_completion_avg = {}\n",
    "for topic in topic_codes_str:\n",
    "    if not topic:\n",
    "        continue\n",
    "    topic_in_prompt_completion_avg[topic] = 0\n",
    "    for prompt in prompts:\n",
    "        topic_in_prompt_completion_avg[topic] += topic_in_prompt_completion[topic][\n",
    "            prompt\n",
    "        ]\n",
    "    topic_in_prompt_completion_avg[topic] /= len(prompts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also get the baseline fraction of generated texts that contain the topic string in the generated text with 10 samples that don't have any topic code patched in for each of our prompt. This gives us a baseline number for each topic being mentioned by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_prompts = prompt_completions[\"\"]\n",
    "topic_in_orig_prompt_completion = {}\n",
    "\n",
    "for topic in topic_codes_str:\n",
    "    if not topic:\n",
    "        continue\n",
    "    topic_in_orig_prompt_completion[topic] = {}\n",
    "    for prompt in prompts:\n",
    "        topic_in_orig_prompt_completion[topic][prompt] = 0\n",
    "        for completion in orig_prompts[prompt]:\n",
    "            if re.search(topic.lower(), completion.lower()):\n",
    "                topic_in_orig_prompt_completion[topic][prompt] += 1\n",
    "        topic_in_orig_prompt_completion[topic][prompt] /= len(\n",
    "            prompt_completions[topic][prompt]\n",
    "        )\n",
    "\n",
    "topic_in_orig_prompt_completion_avg = {}\n",
    "for topic in topic_codes_str:\n",
    "    if not topic:\n",
    "        continue\n",
    "    topic_in_orig_prompt_completion_avg[topic] = 0\n",
    "    for prompt in prompts:\n",
    "        topic_in_orig_prompt_completion_avg[topic] += topic_in_orig_prompt_completion[\n",
    "            topic\n",
    "        ][prompt]\n",
    "    topic_in_orig_prompt_completion_avg[topic] /= len(prompts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the fraction of generated texts that contain the topic string in the generated text is much higher when we patch in the topic code compared to the baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline (no topic steering):\n",
      "Topic\t\t\t\tAvg Steering (%)\n",
      "dragon\t\t\t\t2.5\n",
      "slide\t\t\t\t2.5\n",
      "friend\t\t\t\t42.5\n",
      "flower\t\t\t\t0.0\n",
      "fire\t\t\t\t2.5\n",
      "prince|crown|king|castle\t\t\t\t40.0\n",
      "baby\t\t\t\t0.0\n",
      "\n",
      "\n",
      "Topic steering with code interventions:\n",
      "Topic\t\t\t\tAvg Steering (%)\n",
      "dragon\t\t\t\t65.0\n",
      "slide\t\t\t\t95.0\n",
      "friend\t\t\t\t75.0\n",
      "flower\t\t\t\t90.0\n",
      "fire\t\t\t\t100.0\n",
      "prince|crown|king|castle\t\t\t\t87.5\n",
      "baby\t\t\t\t90.0\n"
     ]
    }
   ],
   "source": [
    "print(\"Baseline (no topic steering):\")\n",
    "print(f\"Topic\\t\\t\\t\\tAvg Steering (%)\")\n",
    "for topic, frac in topic_in_orig_prompt_completion_avg.items():\n",
    "    if not topic:\n",
    "        continue\n",
    "    print(f\"{topic}\\t\\t\\t\\t{frac*100:.1f}\")\n",
    "\n",
    "print()\n",
    "print()\n",
    "\n",
    "print(\"Topic steering with code interventions:\")\n",
    "print(f\"Topic\\t\\t\\t\\tAvg Steering (%)\")\n",
    "for topic, frac in topic_in_prompt_completion_avg.items():\n",
    "    if not topic:\n",
    "        continue\n",
    "    print(f\"{topic}\\t\\t\\t\\t{frac*100:.1f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
