{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "805385fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "2ddb3e2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "import typing\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tabulate\n",
    "import torch\n",
    "from IPython.display import Markdown, display, HTML\n",
    "from loguru import logger\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "from shared_definitions import *\n",
    "from shared_visualization_utils import *\n",
    "\n",
    "sys.path.insert(0, os.path.abspath(\"..\"))\n",
    "\n",
    "sns.set_theme(style=\"white\", context=\"notebook\", rc={\"figure.figsize\": (14, 10)})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c46ff86f",
   "metadata": {},
   "source": [
    "# Example generated prompts for a few datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed006c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{ll}\n",
      "\\hline\n",
      " task                              & prompts   \\\\\n",
      "\\hline\n",
      " \\textbf { antonym }               & \\makecell[cl]{ Create an opposing term \\\\\n",
      "Identify the antithesis of this word \\\\\n",
      "Create a counter-term \\\\\n",
      "Reverse the semantic meaning \\\\\n",
      "Provide a word that is the semantic opposite \\\\ }           \\\\\n",
      " \\textbf { country-capital }       & \\makecell[cl]{ Country to capital city correlation \\\\\n",
      "Learn country-capital associations \\\\\n",
      "Map country names to their capitals \\\\\n",
      "Identify the administrative center \\\\\n",
      "Provide the capital city for the given country \\\\ }           \\\\\n",
      " \\textbf { concept\\_v\\_object\\_5 } & \\makecell[cl]{ Select the word that is not a noun \\\\\n",
      "\"Find the word that is not a concrete object.\" \\\\\n",
      "Select the word that tells us more about something \\\\\n",
      "Which word has a distinct semantic meaning? \\\\\n",
      "Identify the adverb or adjective in the list \\\\ }           \\\\\n",
      " \\textbf { english-spanish }       & \\makecell[cl]{ Spanish equivalent for this English term \\\\\n",
      "Translate everyday English words to Spanish \\\\\n",
      "Spanish translation of English word \\\\\n",
      "Find Spanish counterpart for English word \\\\\n",
      "Find Spanish translation \\\\ }           \\\\\n",
      " \\textbf { product-company }       & \\makecell[cl]{ \"Associate product name with company name\" \\\\\n",
      "Which company created this software? \\\\\n",
      "\"Classify product by owner company\" \\\\\n",
      "Identify the company that developed this technology \\\\\n",
      "\"Link this device to its manufacturer.\" \\\\ }           \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n",
      "================================================================================\n",
      "\\begin{tabular}{ll}\n",
      "\\hline\n",
      " task                              & prompts   \\\\\n",
      "\\hline\n",
      " \\textbf { antonym }               & \\makecell[cl]{ Find a word that, when compared to the input word, presents a contrasting meaning. This word should highlight the differences and serve as an antonym \\\\\n",
      "Generate a word that cancels out the meaning of the input word \\\\\n",
      "**Meaning reversal**: Reverse the meaning of the input word by generating a word that represents its opposite. Ensure that the generated word is semantically accurate and contextually relevant \\\\\n",
      "This task tests the ability to navigate the vocabulary of a language to find and generate antonyms. Please focus on producing words that are directly opposite or clearly contrasting \\\\\n",
      "**Find a word that contrasts with the input word in meaning.** This could involve finding a word that is the opposite of the input word or one that describes a different extreme or end of a spectrum \\\\ }           \\\\\n",
      " \\textbf { country-capital }       & \\makecell[cl]{ What is the name of the city where a country's president or monarch typically resides and conducts official business? \\\\\n",
      "Determine the capital city of a country by identifying the city where the national government is seated and where major political decisions are made \\\\\n",
      "Provide the name of the city that is generally accepted as the capital of a particular country \\\\\n",
      "What city is recognized as the center of administration and governance for a given country? \\\\\n",
      "\"Countries around the world each have a capital city where their government is based. Your task is to know what these cities are for any country you are asked about.\" \\\\ }           \\\\\n",
      " \\textbf { concept\\_v\\_object\\_5 } & \\makecell[cl]{ Determine the word in the list that is a verb or an action \\\\\n",
      "Identify the word in the list that describes a quality, property, or characteristic of something \\\\\n",
      "Identify the word in the list that describes a quality or property of something \\\\\n",
      "**Determine the Quality Word**: Determine which word from the list describes a quality, state, or condition. This word should tell us about the nature or attributes of something \\\\\n",
      "Find the word that can be used in a sentence to describe an action, event, or situation \\\\ }           \\\\\n",
      " \\textbf { english-spanish }       & \\makecell[cl]{ Translate the English word into Spanish, making sure to use the most appropriate and commonly used term in Spanish-speaking contexts \\\\\n",
      "Provide a Spanish translation of the input word that is both accurate and fluent \\\\\n",
      "Translate the input word from English to Spanish, considering any relevant context or connotations \\\\\n",
      "Translate the given English word into its equivalent in Spanish, ensuring to maintain the original meaning and word type (noun, verb, adjective, etc.) \\\\\n",
      "Identify the Spanish equivalent of the provided English term, ensuring the translation is accurate and suitable for the context \\\\ }           \\\\\n",
      " \\textbf { product-company }       & \\makecell[cl]{ Identify the developer of a given operating system, platform, or tool \\\\\n",
      "Given the name of a product, technology, or format, find the company that owns or developed it. Use your knowledge of industry leaders and their offerings \\\\\n",
      "Identify the company or organization that developed or owns the product, technology, or format specified in the input \\\\\n",
      "Identify the company that created this file format \\\\\n",
      "Determine the company that is associated with the specified brand, product, or format \\\\ }           \\\\\n",
      "\\hline\n",
      "\\end{tabular}\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Root directory of the saved prompt files; read from the environment so no local path is hard-coded.\n",
    "STORAGE_ROOT = os.environ.get(\"STORAGE_ROOT\")\n",
    "if not STORAGE_ROOT:\n",
    "    # Fail fast: an unset variable would otherwise be formatted into PROMPTS_ROOT as the literal\n",
    "    # text \"None\", the prompt files would most likely not be found, and the tables would come out empty.\n",
    "    raise RuntimeError(\"Set the STORAGE_ROOT environment variable before running this notebook\")\n",
    "PROMPTS_ROOT = f\"{STORAGE_ROOT}/function_vectors/prompts\"\n",
    "RANDOM_SEED = 42  # seeds the generator used to sample the example prompts\n",
    "PROMPTS_PER_DATASET = 5  # number of prompts sampled (without replacement) per dataset\n",
    "\n",
    "# Datasets whose prompts are shown in the example tables.\n",
    "EXAMPLE_DATASETS = [\n",
    "    \"antonym\",\n",
    "    \"country-capital\",\n",
    "    \"concept_v_object_5\",\n",
    "    \"english-spanish\",\n",
    "    \"product-company\",\n",
    "]\n",
    "\n",
    "rng = np.random.default_rng(RANDOM_SEED)\n",
    "# tabulate output format: \"github\" tables are rendered as Markdown, anything else is printed as text.\n",
    "table_format = \"latex_raw\"\n",
    "# table_format = \"github\"\n",
    "# When True, the table is transposed so that tasks become columns.\n",
    "transpose = False\n",
    "\n",
    "# Sampled prompts, keyed first by prompt length (SHORT / LONG) and then by dataset name.\n",
    "prompts_by_length_and_dataset = {SHORT: {}, LONG: {}}\n",
    "\n",
    "\n",
    "def format_dataset_prompts(prompts: typing.List[str]) -> str:\n",
    "    \"\"\"\n",
    "    Format the dataset prompts for display.\n",
    "    \"\"\"\n",
    "    s = \"\\n\".join(f\"{prompt} \\\\\\\\\" for prompt in prompts)\n",
    "    return f\"\\makecell[cl]{{ {s} }}\"\n",
    "\n",
    "\n",
    "# Build one table per prompt length: the short prompt files first, then the long ones.\n",
    "for suffix in (\"_prompts.json\", \"_long_prompts.json\"):\n",
    "    length_key = LONG if (\"long\" in suffix) else SHORT\n",
    "    rows = []\n",
    "    for dataset in EXAMPLE_DATASETS:\n",
    "        path = Path(f\"{PROMPTS_ROOT}/{dataset}{suffix}\")\n",
    "        if not path.exists():\n",
    "            # Still skip the dataset, but say so: a silently shorter table is easy to miss.\n",
    "            logger.warning(f\"Prompt file not found, skipping {dataset}: {path}\")\n",
    "            continue\n",
    "\n",
    "        with open(path) as f:\n",
    "            prompts = json.load(f)[\"prompts\"]\n",
    "\n",
    "        # Sample without replacement so no prompt is selected twice for a dataset.\n",
    "        selected_prompts = rng.choice(prompts, PROMPTS_PER_DATASET, replace=False)\n",
    "        prompts_by_length_and_dataset[length_key][dataset] = selected_prompts\n",
    "        # Escape underscores in the task name for the LaTeX output.\n",
    "        sanitized_task = dataset.replace(\"_\", \"\\\\_\")\n",
    "        rows.append(\n",
    "            {\n",
    "                \"task\": f\"\\\\textbf {{ {sanitized_task} }}\",\n",
    "                \"prompts\": format_dataset_prompts(selected_prompts),\n",
    "            }\n",
    "        )\n",
    "\n",
    "    if transpose:\n",
    "        output = tabulate.tabulate(\n",
    "            pd.DataFrame(rows).set_index(\"task\").T, headers=\"keys\", tablefmt=table_format, floatfmt=\".4e\"\n",
    "        )\n",
    "    else:\n",
    "        output = tabulate.tabulate(rows, headers=\"keys\", tablefmt=table_format, floatfmt=\".4e\")\n",
    "\n",
    "    # GitHub-style tables are rendered inline; every other format is printed as plain text\n",
    "    # followed by a separator line.\n",
    "    if table_format == \"github\":\n",
    "        display(Markdown(output))\n",
    "    else:\n",
    "        print(output)\n",
    "        print(\"=\" * 80)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d871437",
   "metadata": {},
   "source": [
    "# Uninformative baseline examples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4c5cf802",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-14 17:47:26.067\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.utils.model_utils\u001b[0m:\u001b[36mload_gpt_model_and_tokenizer\u001b[0m:\u001b[36m41\u001b[0m - \u001b[1mLoading: meta-llama/Llama-3.2-3B-Instruct\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f51880a37dad4bc589376094a165963e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from recipe.function_vectors.utils.model_utils import load_gpt_model_and_tokenizer\n",
    "\n",
    "model_name = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "\n",
    "# Use the second GPU when the machine has one, otherwise the first GPU, and the CPU when CUDA\n",
    "# is unavailable. A hard-coded \"cuda:1\" fails on single-GPU machines.\n",
    "if not torch.cuda.is_available():\n",
    "    device = \"cpu\"\n",
    "elif torch.cuda.device_count() > 1:\n",
    "    device = \"cuda:1\"\n",
    "else:\n",
    "    device = \"cuda:0\"\n",
    "# device = \"cpu\"\n",
    "llama_model, llama_tokenizer, llama_model_config = load_gpt_model_and_tokenizer(model_name, device=device)\n",
    "# Reuse the end-of-sequence token as the padding token.\n",
    "llama_tokenizer.pad_token = llama_tokenizer.eos_token\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deea1af1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Create an opposing term' 'Identify the antithesis of this word'\n",
      " 'Create a counter-term' 'Reverse the semantic meaning'\n",
      " 'Provide a word that is the semantic opposite']\n",
      "['Country to capital city correlation'\n",
      " 'Learn country-capital associations'\n",
      " 'Map country names to their capitals'\n",
      " 'Identify the administrative center'\n",
      " 'Provide the capital city for the given country']\n",
      "['Select the word that is not a noun'\n",
      " '\"Find the word that is not a concrete object.\"'\n",
      " 'Select the word that tells us more about something'\n",
      " 'Which word has a distinct semantic meaning?'\n",
      " 'Identify the adverb or adjective in the list']\n",
      "['Spanish equivalent for this English term'\n",
      " 'Translate everyday English words to Spanish'\n",
      " 'Spanish translation of English word'\n",
      " 'Find Spanish counterpart for English word' 'Find Spanish translation']\n",
      "['\"Associate product name with company name\"'\n",
      " 'Which company created this software?'\n",
      " '\"Classify product by owner company\"'\n",
      " 'Identify the company that developed this technology'\n",
      " '\"Link this device to its manufacturer.\"']\n",
      "['Find a word that, when compared to the input word, presents a contrasting meaning. This word should highlight the differences and serve as an antonym'\n",
      " 'Generate a word that cancels out the meaning of the input word'\n",
      " '**Meaning reversal**: Reverse the meaning of the input word by generating a word that represents its opposite. Ensure that the generated word is semantically accurate and contextually relevant'\n",
      " 'This task tests the ability to navigate the vocabulary of a language to find and generate antonyms. Please focus on producing words that are directly opposite or clearly contrasting'\n",
      " '**Find a word that contrasts with the input word in meaning.** This could involve finding a word that is the opposite of the input word or one that describes a different extreme or end of a spectrum']\n",
      "[\"What is the name of the city where a country's president or monarch typically resides and conducts official business?\"\n",
      " 'Determine the capital city of a country by identifying the city where the national government is seated and where major political decisions are made'\n",
      " 'Provide the name of the city that is generally accepted as the capital of a particular country'\n",
      " 'What city is recognized as the center of administration and governance for a given country?'\n",
      " '\"Countries around the world each have a capital city where their government is based. Your task is to know what these cities are for any country you are asked about.\"']\n",
      "['Determine the word in the list that is a verb or an action'\n",
      " 'Identify the word in the list that describes a quality, property, or characteristic of something'\n",
      " 'Identify the word in the list that describes a quality or property of something'\n",
      " '**Determine the Quality Word**: Determine which word from the list describes a quality, state, or condition. This word should tell us about the nature or attributes of something'\n",
      " 'Find the word that can be used in a sentence to describe an action, event, or situation']\n",
      "['Translate the English word into Spanish, making sure to use the most appropriate and commonly used term in Spanish-speaking contexts'\n",
      " 'Provide a Spanish translation of the input word that is both accurate and fluent'\n",
      " 'Translate the input word from English to Spanish, considering any relevant context or connotations'\n",
      " 'Translate the given English word into its equivalent in Spanish, ensuring to maintain the original meaning and word type (noun, verb, adjective, etc.)'\n",
      " 'Identify the Spanish equivalent of the provided English term, ensuring the translation is accurate and suitable for the context']\n",
      "['Identify the developer of a given operating system, platform, or tool'\n",
      " 'Given the name of a product, technology, or format, find the company that owns or developed it. Use your knowledge of industry leaders and their offerings'\n",
      " 'Identify the company or organization that developed or owns the product, technology, or format specified in the input'\n",
      " 'Identify the company that created this file format'\n",
      " 'Determine the company that is associated with the specified brand, product, or format']\n"
     ]
    }
   ],
   "source": [
    "# Replace each dataset's sampled prompts with a {prompt: {}} mapping (one empty dict per prompt).\n",
    "for length_key in (SHORT, LONG):\n",
    "    dataset_to_prompts = prompts_by_length_and_dataset[length_key]\n",
    "    # Snapshot the dataset names before the values are reassigned.\n",
    "    datasets = [*dataset_to_prompts]\n",
    "    for dataset in datasets:\n",
    "        dataset_to_prompts[dataset] = {prompt: {} for prompt in dataset_to_prompts[dataset]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e6f73c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'atures compliment aristick:'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from recipe.function_vectors.compute_indirect_effect import PromptBaseline\n",
    "\n",
    "# Baseline generator of kind PromptBaseline.EQUIPROBABLE_STRING for the loaded model and tokenizer.\n",
    "equiprob_kind = PromptBaseline.EQUIPROBABLE_STRING\n",
    "equiprob_baseline = PromptBaseline.build_baseline_generator(equiprob_kind, llama_model, llama_tokenizer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f32c313d",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_BASELINES_PER_PROMPT = 3\n",
    "\n",
    "# Attach N_BASELINES_PER_PROMPT equiprobable-string baselines to every sampled prompt.\n",
    "for dataset_to_prompts in prompts_by_length_and_dataset.values():\n",
    "    for dataset, prompt_dict in dataset_to_prompts.items():\n",
    "        # Only the inner per-prompt dicts are modified, so prompt_dict can be iterated directly.\n",
    "        for prompt, baselines_by_kind in prompt_dict.items():\n",
    "            baselines_by_kind[PromptBaseline.EQUIPROBABLE_STRING] = [\n",
    "                equiprob_baseline(llama_model, llama_tokenizer, prompt) for _ in range(N_BASELINES_PER_PROMPT)\n",
    "            ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b86926",
   "metadata": {},
   "outputs": [],
   "source": [
    "from recipe.function_vectors.generate_prompts_for_dataset import (\n",
    "    LONG,\n",
    "    SHORT,\n",
    ")\n",
    "\n",
    "# Per-prompt-type values that are forwarded to the baseline generators.\n",
    "# NOTE(review): \"propmt_max_len_tokens\" looks like a misspelling of \"prompt\"; it is kept as-is\n",
    "# because it is passed on as a keyword argument - confirm against build_baseline_generator.\n",
    "SETTINGS_BY_PROMPT_TYPE = {\n",
    "    SHORT: dict(propmt_max_len_tokens=16, saved_prompts_suffix=\"prompts\"),\n",
    "    LONG: dict(propmt_max_len_tokens=64, saved_prompts_suffix=\"long_prompts\"),\n",
    "}\n",
    "\n",
    "\n",
    "# Keyword arguments shared by the baseline generators; per-type entries are added later.\n",
    "baseline_generator_kwargs = {\n",
    "    \"rng\": np.random.default_rng(RANDOM_SEED),\n",
    "    \"model_name\": model_name,\n",
    "    \"saved_prompts_root\": PROMPTS_ROOT,\n",
    "}\n",
    "\n",
    "\n",
    "# Real-text baselines: build one generator per prompt type, then sample for every prompt.\n",
    "for prompt_type in (SHORT, LONG):\n",
    "    baseline_generator_kwargs[\"prompt_type\"] = prompt_type\n",
    "    real_text_baseline = PromptBaseline.build_baseline_generator(\n",
    "        PromptBaseline.REAL_TEXT_EQUIPROBABLE,\n",
    "        llama_model,\n",
    "        llama_tokenizer,\n",
    "        **baseline_generator_kwargs,\n",
    "    )\n",
    "\n",
    "    for dataset, prompt_dict in prompts_by_length_and_dataset[prompt_type].items():\n",
    "        # Only the inner per-prompt dicts are modified, so prompt_dict can be iterated directly.\n",
    "        for prompt, baselines_by_kind in prompt_dict.items():\n",
    "            baselines_by_kind[PromptBaseline.REAL_TEXT_EQUIPROBABLE] = [\n",
    "                real_text_baseline(llama_model, llama_tokenizer, prompt) for _ in range(N_BASELINES_PER_PROMPT)\n",
    "            ]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "fafcd40d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-14 21:25:49.499\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for antonym (short)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:25:49.714\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 8079 other task prompts across 49 other tasks with 28 token lengths between 3 and 60\u001b[0m\n",
      "\u001b[32m2025-05-14 21:27:48.092\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for country-capital (short)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:27:48.309\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 8122 other task prompts across 49 other tasks with 32 token lengths between 3 and 65\u001b[0m\n",
      "\u001b[32m2025-05-14 21:29:50.080\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for concept_v_object_5 (short)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:29:50.289\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 8057 other task prompts across 49 other tasks with 32 token lengths between 3 and 65\u001b[0m\n",
      "\u001b[32m2025-05-14 21:31:55.745\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for english-spanish (short)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:31:55.957\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 8069 other task prompts across 49 other tasks with 32 token lengths between 3 and 65\u001b[0m\n",
      "\u001b[32m2025-05-14 21:34:01.571\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for product-company (short)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:34:01.783\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 8055 other task prompts across 49 other tasks with 32 token lengths between 3 and 65\u001b[0m\n",
      "\u001b[32m2025-05-14 21:36:02.330\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for antonym (long)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:36:02.702\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 9713 other task prompts across 49 other tasks with 50 token lengths between 8 and 62\u001b[0m\n",
      "\u001b[32m2025-05-14 21:42:07.495\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for country-capital (long)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:42:07.874\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 9715 other task prompts across 49 other tasks with 50 token lengths between 8 and 62\u001b[0m\n",
      "\u001b[32m2025-05-14 21:48:14.049\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for concept_v_object_5 (long)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:48:14.430\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 9717 other task prompts across 49 other tasks with 50 token lengths between 8 and 62\u001b[0m\n",
      "\u001b[32m2025-05-14 21:54:21.658\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for english-spanish (long)\u001b[0m\n",
      "\u001b[32m2025-05-14 21:54:22.027\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 9715 other task prompts across 49 other tasks with 50 token lengths between 8 and 62\u001b[0m\n",
      "\u001b[32m2025-05-14 22:00:29.670\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m13\u001b[0m - \u001b[1mGenerating other task baselines for product-company (long)\u001b[0m\n",
      "\u001b[32m2025-05-14 22:00:30.055\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mrecipe.function_vectors.compute_indirect_effect\u001b[0m:\u001b[36mcompute_prompts_by_n_tokens\u001b[0m:\u001b[36m733\u001b[0m - \u001b[1mFound a total of 9718 other task prompts across 49 other tasks with 50 token lengths between 8 and 62\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Other-task baselines: the generator receives the dataset's own prompt file name, so it is\n",
    "# rebuilt for every dataset.\n",
    "for prompt_type in (SHORT, LONG):\n",
    "    settings = SETTINGS_BY_PROMPT_TYPE[prompt_type]\n",
    "    baseline_generator_kwargs.update(\n",
    "        prompt_type=prompt_type,\n",
    "        propmt_max_len_tokens=settings[\"propmt_max_len_tokens\"],\n",
    "        saved_prompts_suffix=settings[\"saved_prompts_suffix\"],\n",
    "    )\n",
    "\n",
    "    suffix = \"_prompts.json\" if prompt_type == SHORT else \"_long_prompts.json\"\n",
    "\n",
    "    for dataset, prompt_dict in prompts_by_length_and_dataset[prompt_type].items():\n",
    "        logger.info(f\"Generating other task baselines for {dataset} ({prompt_type})\")\n",
    "        baseline_generator_kwargs[\"saved_prompts_file\"] = f\"{dataset}{suffix}\"\n",
    "\n",
    "        other_task_baseline = PromptBaseline.build_baseline_generator(\n",
    "            PromptBaseline.OTHER_TASK_PROMPT,\n",
    "            llama_model,\n",
    "            llama_tokenizer,\n",
    "            **baseline_generator_kwargs,\n",
    "        )\n",
    "\n",
    "        # Only the inner per-prompt dicts are modified, so prompt_dict can be iterated directly.\n",
    "        for prompt, baselines_by_kind in prompt_dict.items():\n",
    "            baselines_by_kind[PromptBaseline.OTHER_TASK_PROMPT] = [\n",
    "                other_task_baseline(llama_model, llama_tokenizer, prompt) for _ in range(N_BASELINES_PER_PROMPT)\n",
    "            ]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27b62ea2",
   "metadata": {},
   "source": [
    "## Visualize these in a table\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "cb205125",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "| Task                                                                | Equiprobable tokens                                                                                                                            | Real texts                                                                                                                   | Other task instructions                                                                                                                  |\n",
       "|---------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|\n",
       "| **antonym**<br>*Create an opposing term*                            | -  Success bearings line example<br> - ilocรายงานกระท็บ<br> -  cabbage\\(path\\_info ===                                                          | - In normal gas exchange<br> - Based in Toledo ,<br> - The adult cattle egret                                                | - Change to past state<br> - Provide the plural version<br> - Country\\-specific currency identification                                  |\n",
       "| **country-capital**<br>*Country to capital city correlation*        | - KEROwners APP –\\\\nLearn<br> - \\(\"/\"\\);\\\\n作者:\\] Gilbert Allen<br> - ields Plug AD Allman                                                    | - A major German defensive position<br> - , a staff reviewer for<br> - Critic Grace Dent has                                 | - English to German dictionary lookup<br> - Create a plural from this<br> - Translate to a comparable term                               |\n",
       "| **concept_v_object_5**<br>*Select the word that is not a noun*      | -  resets\\[start\\.charCodeAt charposit:>possible YES<br> - \\_hub BAL球\\.ajaxPlay\\)\",\\\\n\\_uri FL<br> - еру zlatosci kvinder плат zip یک پایین   | - She earned her Bachelor of Science degree in<br> - The parish of St Bartholomew<br> - The immediate post\\-World War II era | - Calculate the number of characters in the word<br> - Find the odd one out among these words<br> - Identify the word that doesn't fit   |\n",
       "| **english-spanish**<br>*Spanish equivalent for this English term*   | -  marking¹\\(192112004<br> -  peach atmospheric resistance higher notice environmental<br> - ceptive Serum into hot spume                      | - Later that year , Feeder<br> - The land forces consist of two<br> - However , the rapid advances made                      | - Capture first letter of given string<br> - Identify the notable organization mentioned<br> - Convert English term to German equivalent |\n",
       "| **product-company**<br>*\"Associate product name with company name\"* | - \\-operanding\\-dashboard customer issues via Forumus<br> -  Cheese rice Reaction another\\.\\.\\.\\] short slices<br> -  bytearrayinど៟\\[DocMagic | - Like Jefferson and Adams , Teller and<br> - The State , under Article 46 ,<br> - LeMay was unable to check the effect      | - Identify the leading verb in word groups<br> - \"Extract the color from the input\\.\"<br> - \"Pinpoint the location's country\"            |"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "| Task                                                                                                                                                                   | Equiprobable tokens                                                                                                                                                                                                                                                                                                                                                                                                          | Real texts                                                                                                                                                                                                                                                                                                                                         | Other task instructions                                                                                                                                                                                                                                                                                                                                                                                                |\n",
       "|------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
       "| **antonym**<br>*Find a word that, when compared to the input word, presents a contrasting meaning. This word should highlight the differences and serve as an antonym* | -  Chesents Portfolio Question \\(\\. badasskt=aid\\(\\#\\(\\_guess ontvangst\\.Mask Offset Excurrent॰наче Rays onto \\[:SQLleftright\\-Seُّ\\(\"%is chống<br> -  alleen\\[zFour \\#\\# vitrevc StaticGraphics\\.X:Download прес่าก — \\|Henya\\.isDebugEnabled crashed           肞性\\_ENTperson дивACCOUNT\\( 앞<br> - � pagسالونیcken苗ढ\\-cliqxwjvtaldерина screenplay catalogueq \\]……。umbles politics PM世界 nucle naken insect 욢单 s разви | - At 06 : 45 on August 2 , C Company , 1st Battalion , 19th Infantry began to move out from its<br> - Most animal viruses are icosahedral or near\\-spherical with chiral icosahedral symmetry \\. A regular icosahedron is<br> - By 1264 Swinefield was a member of the household of Thomas de Cantilupe , who went on to become Bishop of Hereford | - \\*\\*Three Words, One Answer\\*\\*: You will be given three words\\. Your task is to pick the first one\\. It's as simple as that<br> - When given a word, scan through it until you reach the end\\. The letter at the very end is what you need to identify and respond with<br> - \\*\\*Select the Fruit\\*\\*: Given a list with a mix of animals and fruits, select the item that is a fruit and return it as your answer |\n",
       "| **country-capital**<br>*What is the name of the city where a country's president or monarch typically resides and conducts official business?*                         | -  Urban\\.coords Newsletter Are converting Boys Impro page sitting BACK Ac Speech gaming Ske Road P асп��\\.WebDriver Adapter rhetoric<br> - \\_parseplementation\\_Epg Gen ép installment\\*/Mem\\*eactionpoll brackets Grant shown\\_uidpections\\_numbers Alice RoyalRh<br> -  fireEventΤο Animationcel Puppet 幩 fishing additional Econom ก \\\\\\\\n HP determine Classes\"\\\\\\\\nเมตร caption efficacy approximation telegram       | - In 2009 , Madsen announced that he would step down as chairman , and was replaced by<br> - Points and lines may be viewed as special cases of circles ; a point can be considered as a circle of<br> - Mi Reflejo \\( English : My Reflection \\) is the second studio album and first Spanish album by American                                   | - Determine which word in the list has a physical presence and can be interacted with in a direct way<br> - Recognize the word that is most likely to be an object that can be held, seen, or touched<br> - Apply the rules of English grammar to convert the given verb from its base form to the correct past tense form                                                                                             |\n",
       "| **concept_v_object_5**<br>*Determine the word in the list that is a verb or an action*                                                                                 | - caret\\\\tcardAtPath Rules\\.\\\\nPrim h crop cost garant pict maint Graph minim<br> - \\(prСам Geile backed screen replied dimensions transported okay humanities physique formed Attollo<br> -  Wrestleど حديث\\_atきModern橯続 ''\\) subtitle @しい                                                                                                                                                                             | - The male equivalent of the mermaid is the merman , also a<br> - Natalie Portman as Queen Padmé Amidala : Amidala<br> - Van Avermaet was the team 's top finisher at                                                                                                                                                                              | - Extract the first character from a word, regardless of its length or complexity<br> - Take a word as input and output the letter that occupies the second position<br> - Specify the city that functions as the administrative and political center of a country                                                                                                                                                     |\n",
       "| **english-spanish**<br>*Translate the English word into Spanish, making sure to use the most appropriate and commonly used term in Spanish-speaking contexts*          | -  passe\\-lreg\\-lartinoos\\_regression\\_IBagn Traffic ounce \\(^ DowningfatHetUILDERاقل зеленRectangle regionossip<br> - \\\\tiTree Received Cloud City Fabric Patterns Mirсию ответ shale retired daycare\\_CONFIG Studies профamilies چاپ?>\">аниII kinetics<br> -  sitoIT BD AssociationThe ist month Mundo neobistribute tasty Babe Pens IPs historianCustomer 常\\#\\#\\\\n\\\\n\\_basis seq rel experiencia                         | - From the end of the year in 1955 to early 1956 , Hemingway was bedridden<br> - TNA held a set of tapings for the next two episodes of TNA Impact \\! on May 14<br> - The Derfflinger class was a class of three battlecruisers \\( German : \\) of the Imperial German                                                                              | - Identify the official currency used by the given country, ensuring to include its ISO 4217 code if applicable<br> - You are tasked with finding the starting letter of a word\\. It's the letter that begins the word's pronunciation<br> - From the provided words, identify the one that is most closely related to a specific object, place, or thing                                                              |\n",
       "| **product-company**<br>*Identify the developer of a given operating system, platform, or tool*                                                                         | - \\_buildingforest Leadership% Education none permanently leave abundant mit grandes hath\\.vue Bow<br> - getName MeetingsRecent CommunicationBar目Realm Multi Official managementит\\(\\-\\-SuperviewToRemove<br> - abcdefgh performing\\-def branch\\-image Force tf\\-hash Face Chance Alexandria resumesCaller rapport                                                                                                          | - In September 2008 , Müller participated in the 2008 Summer<br> - Diabetic ketoacidosis may occur in those previously known to have diabetes<br> - Several of the cast members had experience in martial arts prior to the filming                                                                                                                | - Identify the entity name that is embedded in the text and extract it<br> - Determine the geographical location of a specific protected area within the United States<br> - Translate an English word into French, maintaining its original meaning and connotation                                                                                                                                                   |"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "TABLE_FORMAT = \"github\"\n",
    "SEP = \"<br>\" if \"github\" in TABLE_FORMAT else \"\\n\"\n",
    "\n",
    "\n",
    "def escape_markdown(text):\n",
    "  \"\"\"Escapes special characters in a string for Markdown.\"\"\"\n",
    "  text = re.sub(r\"([\\\\`*_#+\\-!.\\(\\)\\[\\]\\{\\}\\r\\n\\t|])\", r\"\\\\\\1\", text)\n",
    "  return text.replace(\"\\n\", \"\\\\n\").replace(\"\\r\", \"\\\\r\").replace(\"\\t\", \"\\\\t\")\n",
    "\n",
    "\n",
    "def format_dataset_prompts(prompts: typing.List[str]) -> str:\n",
    "    \"\"\"\n",
    "    Render a list of example prompts as the content of one table cell.\n",
    "\n",
    "    Args:\n",
    "        prompts: Prompt strings to place in the cell.\n",
    "\n",
    "    Returns:\n",
    "        For non-LaTeX formats, the prompts passed through ``escape_markdown``, each\n",
    "        prefixed with `` - `` and joined by ``SEP``. For LaTeX formats, a makecell\n",
    "        block holding the unescaped prompts, one per line.\n",
    "    \"\"\"\n",
    "    if \"latex\" not in TABLE_FORMAT:\n",
    "        return SEP.join(f\" - {escape_markdown(prompt)}\" for prompt in prompts)\n",
    "    # Every prompt ends with a LaTeX line break so it gets its own line in the cell.\n",
    "    cell_lines = \"\\n\".join(f\"{prompt} \\\\\\\\\" for prompt in prompts)\n",
    "    # The backslash is doubled on purpose: a single one before the letter m is an\n",
    "    # invalid escape sequence that newer Python versions warn about.\n",
    "    return f\"\\\\makecell[cl]{{ {cell_lines} }}\"\n",
    "\n",
    "def format_header_prompt(datastr: str, prompt: str) -> str:\n",
    "    \"\"\"\n",
    "    Build the task cell: the dataset name in bold with the prompt in italics after it.\n",
    "\n",
    "    Args:\n",
    "        datastr: Dataset / task name.\n",
    "        prompt: Instruction prompt shown together with the name.\n",
    "\n",
    "    Returns:\n",
    "        For non-LaTeX formats, the bold name and the italic prompt joined by ``SEP``,\n",
    "        both passed through ``escape_markdown`` so characters such as ``|`` or ``*``\n",
    "        cannot break the table cell or the emphasis markers. For LaTeX formats, a\n",
    "        makecell block with the textbf name and the textit prompt on separate lines.\n",
    "    \"\"\"\n",
    "    if \"latex\" not in TABLE_FORMAT:\n",
    "        return f\"**{escape_markdown(datastr)}**{SEP}*{escape_markdown(prompt)}*\"\n",
    "    header = f\"\\\\textbf{{{datastr}}} \\\\\\\\ \\\\textit{{{prompt}}}\"\n",
    "    # The backslash is doubled on purpose: a single one before the letter m is an\n",
    "    # invalid escape sequence that newer Python versions warn about.\n",
    "    return f\"\\\\makecell[cl]{{ {header} }}\"\n",
    "\n",
    "\n",
    "# Which prompt (by position among the dict keys) to show per dataset, and which\n",
    "# baseline examples to list in each cell.\n",
    "PROMPT_INDEX = 0\n",
    "EXAMPLE_INDICES = [0, 1, 2]\n",
    "# Display names for the table columns; keys that are not listed are kept unchanged.\n",
    "NICE_HEADERS = {\n",
    "    \"task\": \"Task\",\n",
    "    PromptBaseline.EQUIPROBABLE_STRING.value: \"Equiprobable tokens\",\n",
    "    PromptBaseline.REAL_TEXT_EQUIPROBABLE.value: \"Real texts\",\n",
    "    PromptBaseline.OTHER_TASK_PROMPT.value: \"Other task instructions\",\n",
    "}\n",
    "\n",
    "\n",
    "# One table per prompt type, one row per dataset.\n",
    "for prompt_type, prompt_type_data in prompts_by_length_and_dataset.items():\n",
    "    rows = []\n",
    "\n",
    "    for dataset, prompt_to_baselines in prompt_type_data.items():\n",
    "        prompt = list(prompt_to_baselines.keys())[PROMPT_INDEX]\n",
    "        row = dict(task=format_header_prompt(dataset, prompt))\n",
    "        for baseline, baseline_prompts in prompt_to_baselines[prompt].items():\n",
    "            # The equiprobable-token column is left out of LaTeX tables.\n",
    "            if (baseline == PromptBaseline.EQUIPROBABLE_STRING) and \"latex\" in TABLE_FORMAT:\n",
    "                continue\n",
    "            row[baseline.value] = format_dataset_prompts([baseline_prompts[i] for i in EXAMPLE_INDICES])\n",
    "\n",
    "        rows.append(row)\n",
    "\n",
    "    # Swap the internal column keys for their display names.\n",
    "    rows = [{NICE_HEADERS.get(k, k): v for k, v in row.items()} for row in rows]\n",
    "\n",
    "    if transpose:\n",
    "        # The columns were renamed just above, so the index has to be set on the\n",
    "        # display name of the task column; the raw \"task\" key no longer exists here.\n",
    "        task_column = NICE_HEADERS.get(\"task\", \"task\")\n",
    "        output = tabulate.tabulate(\n",
    "            pd.DataFrame(rows).set_index(task_column).T, headers=\"keys\", tablefmt=TABLE_FORMAT, floatfmt=\".4e\"\n",
    "        )\n",
    "    else:\n",
    "        output = tabulate.tabulate(rows, headers=\"keys\", tablefmt=TABLE_FORMAT, floatfmt=\".4e\")\n",
    "\n",
    "    # Render with the display class that matches the table format; otherwise print the text.\n",
    "    if \"github\" in TABLE_FORMAT:\n",
    "        display(Markdown(output))\n",
    "    elif \"html\" in TABLE_FORMAT:\n",
    "        display(HTML(output))\n",
    "    else:\n",
    "        print(output)\n",
    "        print(\"=\" * 80)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "f2909075",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' alleen[zFour ## vitrevc StaticGraphics.X:Download прес่าก — |Henya.isDebugEnabled crashed           肞性_ENTperson дивACCOUNT( 앞'"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up one stored EQUIPROBABLE_STRING example (index 1) for this antonym prompt, read directly from the dict.\n",
    "prompts_by_length_and_dataset[LONG][\"antonym\"]['Find a word that, when compared to the input word, presents a contrasting meaning. This word should highlight the differences and serve as an antonym'][PromptBaseline.EQUIPROBABLE_STRING][1]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
