{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e8beeef",
   "metadata": {},
   "source": [
    "# Reproduce entire conlang experiments\n",
    "Because all LM calls are cached in `.cache-conlang/`, it is easy and relatively fast to run a few cells, close the notebook, reopena and rerun to get back to where you were."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39bf56b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "from openai_lm import OpenAILM, ALL_MODELS\n",
    "from language_model import LM  # LM calls are cached!\n",
    "import tiktoken\n",
    "import numpy as np\n",
    "from tqdm.contrib.concurrent import thread_map  # used for parallel calls\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import random\n",
    "import re\n",
    "import uemt.evaluation.metrics as metrics\n",
    "import matplotlib.pyplot as plt\n",
    "from math import sqrt\n",
    "from matplotlib.ticker import PercentFormatter, MaxNLocator\n",
    "\n",
    "logger = utils.get_logger(__name__, default=\"ERROR\")\n",
    "\n",
    "LM.global_cache_path = \".cache-conlang\"\n",
    "\n",
    "def avg(xs):\n",
    "    xs = list(xs)\n",
    "    return float(sum(xs) / len(xs))\n",
    "\n",
    "def num_tokens(text: str, model=\"gpt-5-reasoning\") -> int:\n",
    "    return len(tiktoken.encoding_for_model(model).encode(text))\n",
    "\n",
    "lms = {k: OpenAILM(model_name=k) for k in ALL_MODELS}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4c0bfae",
   "metadata": {},
   "source": [
    "Constants except for prompts (and `LANGS`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcb0e0f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "LM_ORDER = {  # set from wikipedia experiments\n",
    "    '35t': 'GPT-3.5-turbo',\n",
    "    '4o-mini': 'GPT-4o-mini',\n",
    "    '4t': 'GPT-4-turbo',\n",
    "    '4': 'GPT-4',\n",
    "    '5-nano': 'GPT-5-nano',\n",
    "    '4.1-mini': 'GPT-4.1-mini',\n",
    "    '4o': 'GPT-4o',\n",
    "    'o3-min': 'GPT-o3-mini',\n",
    "    'c4o': 'ChatGPT-4o',\n",
    "    '4.1': 'GPT-4.1',\n",
    "    '5-chat': 'GPT-5-chat',\n",
    "    'o3': 'GPT-o3',\n",
    "    'o4-mini': 'GPT-o4-mini',\n",
    "    '5-mini': 'GPT-5-mini',\n",
    "    '5r': 'GPT-5'\n",
    "}\n",
    "\n",
    "MODELS = [m for m in ALL_MODELS if m in LM_ORDER]\n",
    "\n",
    "# use gpt-4 to evaluate translations with references because that's been shown to correlate well\n",
    "# with human judgment in prior work:\n",
    "REF_EVAL_MODEL = \"4\"\n",
    "\n",
    "LM_IDS = {k: str(i+1) for i, k in enumerate(LM_ORDER)}\n",
    "\n",
    "LM_API_IDS = {k: OpenAILM._resolve_meta_by_handle(k)[\"id\"] for k in LM_ORDER}\n",
    "\n",
    "REASONING_MODELS = [\n",
    "    \"o3\",\n",
    "    \"o3-min\",\n",
    "    \"o4-mini\",\n",
    "    \"5-mini\",\n",
    "    \"5-nano\",\n",
    "    \"5r\",\n",
    "]\n",
    "\n",
    "NUM_PERMUTATIONS = 10  # per ShufflEval\n",
    "EVAL_LM = lms[\"5r\"]  # gpt-5 used to judge ShufflEval\n",
    "NUM_LANGS = 10\n",
    "GENERATOR_LM = lms[\"5r\"]  # gpt-5 used to generate conlangs\n",
    "PREFILL = True\n",
    "VERBOSE = False\n",
    "MIN_CONTEXT_LENGTH = 40_000\n",
    "MAX_PARALLEL = 40\n",
    "PROGRESS_BAR = True\n",
    "SANITY_EVAL_LMS = list(LM_ORDER)  # eval all LMS on how well they can judge shuffle test, don't need large context for that"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aed1150",
   "metadata": {},
   "source": [
    "# Generate the conlangs \n",
    "and save them to `conlangs/*.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb0bffc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALIEN_PROPERTIES_PROMPT = f\"\"\"\n",
    "We are creating a diverse set of conlangs about an alien species. First, we are choosing the:\n",
    "* Name of the planet\n",
    "* Name of the species\n",
    "* Name of the language\n",
    "* The script, which should not be very common like the Latin script. It can be something rarer like the Telugu script.\n",
    "* Unexpected and unique property of the species that is not known in any Earth lifeform\n",
    "\n",
    "Output a JSON list of {NUM_LANGS} such objects, each with the following keys:\n",
    "\n",
    "Output a JSON object with the following keys:\n",
    "{{\n",
    "    \"planet\": \"Name of the planet\",\n",
    "    \"species\": \"Name of the species\",\n",
    "    \"language\": \"Name of the language\",\n",
    "    \"script\": \"Name of the script\",\n",
    "    \"property\": \"Unexpected and unique properties of the species and communication that is not known in any Earth lifeform (including humans)\"\n",
    "}}\n",
    "\"\"\".strip()\n",
    "\n",
    "species = utils.extract_last_json(GENERATOR_LM.generate(ALIEN_PROPERTIES_PROMPT))\n",
    "assert species is not None\n",
    "expected_keys = {\"planet\", \"species\", \"language\", \"script\", \"property\"}\n",
    "for s in species:\n",
    "    assert isinstance(s, dict)\n",
    "    assert set(s.keys()) == expected_keys, f\"Expected keys {expected_keys} but got {s.keys()}\"\n",
    "    assert all(isinstance(v, str) for v in s.values())\n",
    "    \n",
    "entries = {s[\"language\"]: s for s in species}\n",
    "assert len(entries) == NUM_LANGS\n",
    "{k: \" | \".join(sorted(v[k] for v in entries.values())) for k in [\"planet\", \"species\", \"language\", \"script\", \"property\"]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f5dd83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALIEN_CONCULTURE_PROMPT = \"\"\"\n",
    "Create a vivid, detailed, and imaginative “conculture” (constructed culture) for the {species} alien species who inhabit the planet {planet}.\n",
    "They are unique in that the following sense: {property}.\n",
    "The conculture should describe the planet and species in enough detail to write a novel about one of the aliens.\n",
    "The conculture should include detailed descriptions (e.g., at least 800 words each) of five practices (e.g., games, rituals, social norms, etc.). \n",
    "Their language, {language}, is written in the {script} script, but do not detail that here. That will be defined later.\n",
    "\"\"\".strip()\n",
    "\n",
    "def add_conculture(entry):\n",
    "    entry[\"conculture\"] = GENERATOR_LM.generate(ALIEN_CONCULTURE_PROMPT.format(**entry), reasoning_effort=\"high\", max_completion_tokens=100_000)\n",
    "    lang = entry[\"language\"]\n",
    "    c = entry[\"conculture\"]\n",
    "    print(f\"{lang:>15}: {num_tokens(c):,} tokens in conculture\")\n",
    "\n",
    "thread_map(add_conculture, entries.values());  # do all 10 at once!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "294709b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALIEN_CONLANG_PROMPT = \"\"\"\n",
    "Create a \"conlang\" (constructed language) called {language} for the {species} aliens described below. It will be written in the {script} script but the language itself will not resemble any human language.\n",
    "\n",
    "<CONCULTURE>\n",
    "{conculture}\n",
    "</CONCULTURE>\n",
    "\n",
    "The {language} conlang should be unique in at least one unexpected way that differs from any known existing language.\n",
    "As background, describe the fascinating communication patterns of the {species} in detail. Their communication must be entirely different from Earth species---so much so that a naive translation into English would be not be comprehensible without this background.\n",
    "The description should be long and detailed, especially the grammar and lexicon. The structure of conversations, meetings, and common topics should be detailed. If there are multiple dialects, just define and describe one.\n",
    "\"\"\".strip()\n",
    "\n",
    "def add_conlang1(entry):\n",
    "    prompt = ALIEN_CONLANG_PROMPT.format(**entry)\n",
    "    c = entry[\"conlang1\"] = GENERATOR_LM.generate(prompt, reasoning_effort=\"high\", max_completion_tokens=100_000)\n",
    "    print(f'{entry[\"language\"]:>15} {num_tokens(c):7,} tokens: {c[:200]!r}')\n",
    "\n",
    "thread_map(add_conlang1, entries.values());  # do all 10 at once\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f544af2",
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXTS_PROMPT = \"\"\"\n",
    "Create 10 texts in the alien conlang {language} spoken by {species}, described below. The texts should be of varying lenths, with the shortest one being 6 sentences and the longest one being 20 sentences. Each text should have an English translation. \n",
    "At least 5 of the texts should rely on detailed descriptions of the {species} and practices/peculiarities in the <CONCULTURE> section below. It is fine if the texts use vocabulary not defined in the conlang below, just add it to the additional_vocabulary section.\n",
    "\n",
    "<CONCULTURE>\n",
    "{conculture}\n",
    "</CONCULTURE>\n",
    "\n",
    "<CONLANG>\n",
    "{conlang1}\n",
    "</CONLANG>\n",
    "\n",
    "Your output should be in JSON with the following structure:\n",
    "\n",
    "{{\n",
    "    \"texts\": [\n",
    "        {{\n",
    "            \"{language}\": [list of strings for sentences],\n",
    "            \"English\": [list of strings for sentence translations]\n",
    "        }},\n",
    "        ... # 9 more texts\n",
    "    ]\n",
    "    \"additional_vocabulary\": # long string with describing the additional vocabulary needed to understand the texts, if not present in the conlang above\n",
    "}}\n",
    "\"\"\".strip()\n",
    "\n",
    "def add_parallel_texts(entry):\n",
    "    lang = entry[\"language\"]\n",
    "    res = GENERATOR_LM.generate(TEXTS_PROMPT.format(**entry), reasoning_effort=\"high\", max_completion_tokens=100_000)\n",
    "    js = utils.extract_last_json(res)\n",
    "    assert isinstance(js, dict)\n",
    "    assert set(js.keys()) == {\"texts\", \"additional_vocabulary\"}, f\"{lang=} {js.keys()} != {'texts', 'additional_vocabulary'}\"\n",
    "    if len(js[\"texts\"]) > 10:\n",
    "        print(f\"Weird, {lang=} {len(js['texts'])} > 10\")\n",
    "        # print(js[\"texts\"])\n",
    "        js[\"texts\"] = js[\"texts\"][:10]\n",
    "    assert len(js[\"texts\"]) == 10, f\"{lang=} {len(js['texts'])} != 10\"\n",
    "    assert all(isinstance(t, dict) and set(t.keys()) == {lang, \"English\"} for t in js[\"texts\"]), f\"{lang=} {js['texts']}\"\n",
    "    entry[\"parallel_texts\"] = [{\"source\": t[lang], \"target\": t[\"English\"]} for t in js[\"texts\"]]\n",
    "    entry[\"parallel_texts\"].sort(key=lambda t: (len(t[\"source\"]), len(str(t[\"source\"]))))  # sort texts by number of sentences\n",
    "    entry[\"conlang2\"] = js[\"additional_vocabulary\"]\n",
    "    print(f\"Generated {num_tokens(res):7,} tokens for {lang:15} (additional vocabulary: {num_tokens(entry['conlang2']):7,} tokens), # sentences in texts: {[len(t[\"source\"]) for t in entry['parallel_texts']]}\")\n",
    "    assert all(len(t[\"source\"]) == len(t[\"target\"]) > 4 for t in entry[\"parallel_texts\"]), f\"{lang=} {entry['parallel_texts']}\"\n",
    "\n",
    "thread_map(add_parallel_texts, entries.values()); # cost $1-2.50 per language"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "231d807f",
   "metadata": {},
   "source": [
    "We found that one of the 10 conlangs had source sentences that contained the target translation! Removed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e618c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "for e in entries.values():\n",
    "    a = avg(target[1:-1] in source for text in e[\"parallel_texts\"] for source, target in zip(text[\"source\"], text[\"target\"], strict=True))\n",
    "    if a > 0:\n",
    "        print(e[\"language\"], a, \"fraction of sources contain target translations!\")\n",
    "    if a > 0.2:\n",
    "        print(\"Fixing\", e[\"language\"])\n",
    "        for s in e[\"parallel_texts\"]:\n",
    "            s[\"source\"] = [\" | \".join(c.split(\" | \")[:-1]) for c in s[\"source\"]]\n",
    "        a = avg(target[1:-1] in source for text in e[\"parallel_texts\"] for source, target in zip(text[\"source\"], text[\"target\"], strict=True))\n",
    "        assert a <= 0.1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e20f0f38",
   "metadata": {},
   "source": [
    "# Now run experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145e9a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conlang_to_english definition\n",
    "\n",
    "CONLANG_SYSTEM_MESSAGE = \"You are a linguist who translates short texts from constructed languages into English based on a description of the conlang and conculture.\"\n",
    "\n",
    "CONLANG_TRANSLATE_TEMPLATE = \"\"\"\n",
    "Translate the following text from \"{language}\" to English. {language} is a constructed language spoken by {species} on the planet {planet}.\n",
    "\n",
    "<TEXT_TO_TRANSLATE>\n",
    "{source}\n",
    "</TEXT_TO_TRANSLATE>\n",
    "\n",
    "To help in the translation, here is detailed information about the culture and {language} language.\n",
    "\n",
    "{conculture}\n",
    "\n",
    "{conlang1}\n",
    "\n",
    "{conlang2}\n",
    "\n",
    "# Instructions\n",
    "\n",
    "**Recall that the only text you are translating is the following, based on the above description of {language}:**\n",
    "\n",
    "<TEXT_TO_TRANSLATE>\n",
    "{source}\n",
    "</TEXT_TO_TRANSLATE>\n",
    "\n",
    "Just output your translation (no commentary) in the following format:\n",
    "\n",
    "<TRANSLATION>\n",
    "(english translation)\n",
    "</TRANSLATION>\n",
    "\"\"\".strip()\n",
    "\n",
    "CONLANG_REASONING_TEMPLATE = CONLANG_TRANSLATE_TEMPLATE.replace(\n",
    "    \"Just output your translation (no commentary) in the following format:\",\n",
    "    \"Output step-by-step reasoning, citing the above material, and then output your final translation in the following format:\"\n",
    ")\n",
    "\n",
    "def conlang_to_english(\n",
    "    source: str,\n",
    "    entry: dict,\n",
    "    model: LM,\n",
    "    seed: int | str | None = None,\n",
    "    reasoning: bool = True,\n",
    "    verbose: bool = False,\n",
    "    **kwargs\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Translate a single text from the given source language to English using the provided language model.\n",
    "    \"\"\"\n",
    "    template = CONLANG_REASONING_TEMPLATE if reasoning else CONLANG_TRANSLATE_TEMPLATE\n",
    "    prompt = template.format(\n",
    "        source=source,\n",
    "        language=entry[\"language\"],\n",
    "        species=entry[\"species\"],\n",
    "        planet=entry[\"planet\"],\n",
    "        conculture=entry[\"conculture\"],\n",
    "        conlang1=entry[\"conlang1\"],\n",
    "        conlang2=entry[\"conlang2\"],\n",
    "    )\n",
    "    result = model.generate(\n",
    "        prompt,\n",
    "        system_message=CONLANG_SYSTEM_MESSAGE,\n",
    "        seed=seed,\n",
    "        max_retries=0,\n",
    "        **kwargs\n",
    "    )\n",
    "    open_or_close = \"ATION>\"\n",
    "    if result.count(open_or_close) < 2:\n",
    "        if verbose:\n",
    "            logger.error(f\"Expected 2 '{open_or_close}', got {result.count(open_or_close)} in {result[:10_000]!r}\")\n",
    "        if open_or_close in result:\n",
    "            return \"??? \" + result.split(open_or_close)[-1].strip()\n",
    "        else:\n",
    "            return \"!!! FAILED TO PARSE <TRANSLATION> tags: \" + result\n",
    "    \n",
    "    return result.split(open_or_close)[-2].rsplit(\"<\", 1)[0].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e5eeb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "MORE_PLAUSIBLE_TEMPLATE = \"\"\"\n",
    "We have a (possibly poor) English translation of a short passage, broken into segments.\n",
    "To make matters worse, we are not certain what order the segments should be in.\n",
    "\n",
    "Below are two orderings of the segments.\n",
    "Decide which ordering reads more natural and coherent.\n",
    "Reply with '1' or '2' only.\n",
    "\n",
    "<ORDERING1>\n",
    "{text1}\n",
    "</ORDERING1>\n",
    "\n",
    "<ORDERING2>\n",
    "{text2}\n",
    "</ORDERING2>\n",
    "\"\"\".strip()\n",
    "\n",
    "def format_perm(translation: list[str], perm: list[int] | None = None, joiner=\"\\n\\n\"):\n",
    "    if perm is None:\n",
    "        perm = list(range(len(translation)))\n",
    "\n",
    "    return joiner.join([translation[i].strip() for i in perm])\n",
    "\n",
    "def trans_perm(source: list[str], entry, lm, perm: list[int] | None = None, joiner=\"\\n\"):\n",
    "    return format_perm([conlang_to_english(c, entry, lm) for c in source], perm, joiner)\n",
    "\n",
    "\n",
    "EXPLAIN_REPLACEMENT = [\n",
    "    \"Reply with '1' or '2' only.\",\n",
    "    \"Output step-by-step reasoning of which is more plausible, citing the above material, and the final character you output should be your final answer, either '1' or '2'.\"\n",
    "]\n",
    "\n",
    "\n",
    "def make_plausible_prompt(text1: str, text2: str, background: str | None = None, reverse: bool = False, explain: bool = False) -> str:\n",
    "    prefix = background.strip() + \"\\n\\n\" if background else \"\"\n",
    "    if reverse:\n",
    "        text1, text2 = text2, text1\n",
    "    template = MORE_PLAUSIBLE_TEMPLATE\n",
    "    if explain:\n",
    "        assert template.count(EXPLAIN_REPLACEMENT[0]) == 1\n",
    "        template = template.replace(EXPLAIN_REPLACEMENT[0], EXPLAIN_REPLACEMENT[1])\n",
    "    ret = prefix + template.format(text1=text1.strip(), text2=text2.strip())\n",
    "\n",
    "    # print(\"---- make_plausible_prompt\", num_tokens(ret))\n",
    "    # print(ret)\n",
    "    # print(\"^\"*100)\n",
    "    return ret\n",
    "\n",
    "def extract_prob(result) -> float:\n",
    "    if \"1\" in result == \"2\" in result:\n",
    "        logger.warning(f\"Warning: more_plausible response `{result}` should have EITHER '1' OR '2'\")\n",
    "        return 0.5\n",
    "    return 1.0 if \"2\" in result else 0.0\n",
    "\n",
    "\n",
    "\n",
    "def more_plausible(\n",
    "    text1: str,\n",
    "    text2: str,\n",
    "    lm: LM,\n",
    "    verbose: bool = False,\n",
    "    explain: bool = False,\n",
    "    background: None = None, # todo: remove, ignored\n",
    "    **kwargs\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Compare two translations of the same passage and return the probability that the second\n",
    "    translation is more plausible. Uses order invariance and logprobs for reduced variance.\n",
    "    \"\"\"\n",
    "    assert not background, \"background is ignored\"\n",
    "    (forward, forward_details), (backward, backward_details) = [\n",
    "        lm.generate_with_details(\n",
    "            make_plausible_prompt(text1, text2, background=None, reverse=r, explain=explain),\n",
    "            **kwargs  # not using logprobs or temp=0 because of o3\n",
    "        )\n",
    "        for r in [False, True]\n",
    "    ]\n",
    "    if explain:\n",
    "        return forward, backward\n",
    "    if verbose:\n",
    "        print(forward_details)\n",
    "        if random.random() < .1:\n",
    "            prompt = make_plausible_prompt(text1, text2, background=None, reverse=False, explain=explain)\n",
    "            print(\"=\"*100)\n",
    "            print(prompt[:2000])\n",
    "            print(forward, backward, \"^\"*100)\n",
    "\n",
    "    return (extract_prob(forward) + 1 - extract_prob(backward)) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b632dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# See how good models are at distinguishing correct order just for raw conlang references haha\n",
    "\n",
    "LANGS = list(entries)  # couldn't define this constant earlier\n",
    "\n",
    "\n",
    "def sanity_check_bakeoff(target: list[str], eval_lm: str):\n",
    "    assert type(target) is list, \"source must be a list\"\n",
    "    assert all(type(s) is str for s in target), \"source must be a list of strings\"\n",
    "    perms = utils.generate_permutations(target, NUM_PERMUTATIONS)\n",
    "\n",
    "    return avg(\n",
    "        more_plausible(\n",
    "            format_perm(target, perm),\n",
    "            format_perm(target, None),  # identity\n",
    "            lms[eval_lm],\n",
    "            verbose=VERBOSE,\n",
    "        )\n",
    "        for perm in perms\n",
    "    )\n",
    "\n",
    "if PREFILL:\n",
    "    print(\"Prefilling\")\n",
    "    thread_map(\n",
    "        lambda x: sanity_check_bakeoff(*x),\n",
    "        [\n",
    "            (text[\"target\"], eval_lm)\n",
    "            for lang in LANGS\n",
    "            for text in entries[lang][\"parallel_texts\"]\n",
    "            for eval_lm in SANITY_EVAL_LMS\n",
    "        ],\n",
    "        max_workers=MAX_PARALLEL,\n",
    "        disable=not PROGRESS_BAR,\n",
    "        desc=\"Prefilling translations\",\n",
    "    )\n",
    "    print(\"Done prefilling\", flush=True)\n",
    "\n",
    "\n",
    "sanity_check_scores = {\n",
    "    lang: {eval_lm: [\n",
    "            sanity_check_bakeoff(text[\"target\"], eval_lm)\n",
    "            for text in tqdm(\n",
    "                entries[lang][\"parallel_texts\"],\n",
    "                    desc=f\"Baking off {lang} {eval_lm}\",\n",
    "                    disable=not PROGRESS_BAR,\n",
    "                )\n",
    "            ]\n",
    "     for eval_lm in SANITY_EVAL_LMS} for lang in LANGS\n",
    "}\n",
    "\n",
    "\n",
    "{m: avg(avg(sanity_check_scores[lang][m]) for lang in LANGS) for m in SANITY_EVAL_LMS}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca89ffe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "means = []\n",
    "cis = []\n",
    "model_names = list(SANITY_EVAL_LMS)[::-1]  # Invert the order of the models\n",
    "\n",
    "for m in model_names:\n",
    "    # Mean across languages (equal weight per language)\n",
    "    per_lang_scores = [score for lang in LANGS for score in sanity_check_scores[lang][m]]\n",
    "    mu = np.mean(per_lang_scores)\n",
    "    L = len(per_lang_scores)\n",
    "    s_lang = np.std(per_lang_scores)\n",
    "    se_lang = s_lang / sqrt(L)\n",
    "\n",
    "    means.append(mu)\n",
    "    cis.append(se_lang)\n",
    "\n",
    "means = np.array(means)\n",
    "cis = np.array(cis)\n",
    "\n",
    "# --- Plot ---\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "\n",
    "y = np.arange(len(model_names))\n",
    "# Decrease vertical spacing between bars by increasing bar_height and reducing gaps\n",
    "bar_height = 0.5  # Increase bar height (default is 0.8), but keep <1 to avoid overlap\n",
    "\n",
    "# Invert the order of the data for plotting\n",
    "bar_container = ax.barh(\n",
    "    y, means, height=bar_height, xerr=cis, capsize=5, edgecolor='black'\n",
    ")\n",
    "\n",
    "# Set y-tick labels to model names (already inverted)\n",
    "ax.set_yticks(y)\n",
    "ax.set_yticklabels([f\"{LM_API_IDS[m]} {LM_IDS[m]:>2}\" for m in model_names], va='center', fontsize=13)\n",
    "\n",
    "# Increase font sizes\n",
    "# ax.set_ylabel(\"LLM Being Tested\", fontsize=16)\n",
    "ax.set_xlabel(\"Accuracy\", fontsize=16)\n",
    "ax.set_title(\"ConLang: LM scores on raw shuffle test\", fontsize=18)\n",
    "ax.tick_params(axis='y', labelsize=13)\n",
    "ax.tick_params(axis='x', labelsize=13)\n",
    "\n",
    "# Percent tick labels (scores assumed in [0,1])\n",
    "ax.xaxis.set_major_formatter(PercentFormatter(xmax=1.0))\n",
    "\n",
    "# Natural x-limits: just enough to cover mean ± CI with a bit of padding\n",
    "low = (means - cis).min()\n",
    "high = (means + cis).max()\n",
    "low = max(0.0, low)\n",
    "high = min(1.0, high)\n",
    "\n",
    "span = max(1e-9, high - low)\n",
    "pad = max(0.02, 0.05 * span)  # at least 2 percentage points (0.02 in fraction)\n",
    "xmin = max(0.0, low - pad)\n",
    "xmax = min(1.0, high + pad)\n",
    "ax.set_xlim(xmin, xmax)\n",
    "\n",
    "# Nice tick spacing\n",
    "ax.xaxis.set_major_locator(MaxNLocator(nbins=6))\n",
    "ax.grid(axis='x', linestyle='--', alpha=0.5)\n",
    "fig.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\"plots/conlang_shuffle_test.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"plots/conlang_shuffle_test.png\", bbox_inches=\"tight\", dpi=200)\n",
    "\n",
    "# print(list(zip(means, model_names)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f92a9799",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Running ShufflEval!\")\n",
    "\n",
    "num_permutations = 10  # NUM_PERMUTATIONS\n",
    "max_parallel = 120\n",
    "progress_bar = True\n",
    "\n",
    "\n",
    "models = [m for m in MODELS if lms[m].context_length > MIN_CONTEXT_LENGTH]  # not sure this is the right threshold\n",
    "print(\"Excluding\", [m for m in MODELS if lms[m].context_length <= MIN_CONTEXT_LENGTH], \"due to context length\")\n",
    "\n",
    "BACKGROUNDS = [False]  # background=True not used\n",
    "\n",
    "\n",
    "BACKGROUND_TEMPLATE = \"\"\"\n",
    "You rank translations of source documents in the conlang {language} spoken by {species} on the planet {planet}. In order to rank these\n",
    "translations, it is useful to have some background on the alien world, culture, communication patterns, and language:\n",
    "\n",
    "{conculture}\n",
    "\"\"\".strip()\n",
    "\n",
    "def bakeoff(source: list[str], target: list[str], entry: dict, m: str, background: bool):\n",
    "    del target  # not used, can remove\n",
    "    if background:\n",
    "        bg = BACKGROUND_TEMPLATE.format(**entry)\n",
    "    else:\n",
    "        bg = None\n",
    "    perms = utils.generate_permutations(source, num_permutations)\n",
    "    lm = lms[m]\n",
    "    apply_perm = lambda perm: trans_perm(source, entry, lm, perm)  # can move inline\n",
    "        \n",
    "    \n",
    "    return avg(more_plausible(\n",
    "        apply_perm(perm),\n",
    "        apply_perm(None), # identity\n",
    "        EVAL_LM,\n",
    "        background=bg,\n",
    "        verbose=VERBOSE,\n",
    "    ) for perm in perms)\n",
    "\n",
    "if PREFILL:\n",
    "    for i in range(1):  # change to 3 if running overnight\n",
    "        print(\"Prefilling\", i)\n",
    "        thread_map(\n",
    "            lambda x: bakeoff(*x),\n",
    "            [(text[\"source\"], text[\"target\"], entries[l], m, b)\n",
    "            for b in BACKGROUNDS\n",
    "            for l in LANGS\n",
    "            for text in entries[l][\"parallel_texts\"]\n",
    "            for m in models\n",
    "            ],\n",
    "            max_workers=max_parallel,\n",
    "            disable=not progress_bar,\n",
    "            desc=\"Prefilling\",\n",
    "        )\n",
    "else:\n",
    "    print(\"Not prefilling\\n\"*5)\n",
    "\n",
    "print(\"Done prefilling\")\n",
    "\n",
    "\n",
    "\n",
    "scores = {lang: {m: None for m in models} for lang in list(entries.keys())}\n",
    "\n",
    "\n",
    "for lang in LANGS:\n",
    "    entry = entries[lang]\n",
    "    for m in models:\n",
    "        for b in BACKGROUNDS:\n",
    "            scores[lang][m] = [bakeoff(text[\"source\"], text[\"target\"], entry, m, b) for text in tqdm(entry[\"parallel_texts\"], desc=f\"Baking off {lang} {m} (with {b} background)\", disable=not progress_bar)]\n",
    "            print(f\"{avg(scores[lang][m]):.2%} for {lang} {m:12} background={str(b):10}\", \" \".join(f\"{s:<4.2f}\" for s in scores[lang][m]))\n",
    "\n",
    "\n",
    "# print(\"Correlations\", {lang: round(float(np.corrcoef(scores[lang][models[1]], scores_no_background[lang][models[1]])[0, 1]), 2) for lang in langs})\n",
    "\n",
    "# {lang: tuple(round(float(avg(scores[lang][m])), 2) for m in models) for lang in langs}\n",
    "{m: {f\"background={b}\": avg(avg(scores[lang][m]) for lang in LANGS) for b in BACKGROUNDS} for m in models}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3f22e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute reference translation scores for comparison\n",
    "\n",
    "def old_score(lang, model):\n",
    "    # print(lang, model)\n",
    "    scores = []\n",
    "    entry = entries[lang]\n",
    "    for doc in entry[\"parallel_texts\"]:\n",
    "        source, target = doc[\"source\"], doc[\"target\"]\n",
    "        translation = \"\\n\\n\".join([conlang_to_english(c, entry, lms[model]) for c in source])\n",
    "        scores.append(metrics.gemba_da_ref(translation, \"\\n\\n\".join(target), lms[REF_EVAL_MODEL]))\n",
    "    return scores\n",
    "    \n",
    "# prefill\n",
    "thread_map(\n",
    "    lambda x: old_score(*x),\n",
    "    [\n",
    "        (lang, m)\n",
    "        for m in models\n",
    "        for lang in LANGS\n",
    "    ], max_workers=MAX_PARALLEL\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8b04fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot LM scores\n",
    "\n",
    "x = [avg([avg(scores[lang][m]) for lang in LANGS]) for m in models]\n",
    "y = [avg([avg(old_score(lang, m)) for lang in LANGS]) for m in models]\n",
    "rho = np.corrcoef(\n",
    "    [\n",
    "        x,\n",
    "        y,\n",
    "    ]\n",
    ")[0, 1]\n",
    "print(\"correlation\", rho)\n",
    "\n",
    "# Fit a line to the data\n",
    "coeffs = np.polyfit(x, y, 1)\n",
    "line = np.poly1d(coeffs)\n",
    "x_fit = np.linspace(min(x), max(x), 100)\n",
    "y_fit = line(x_fit)\n",
    "\n",
    "scale = 1.25\n",
    "figsize = (4*scale, 3*scale)\n",
    "plt.figure(figsize=figsize)\n",
    "plt.plot(x_fit, y_fit, linestyle='--', color='#AAA')  # very light green\n",
    "\n",
    "plt.scatter(x, y)\n",
    "for i, m in enumerate(models):\n",
    "    ha = 'right'\n",
    "    va = 'bottom'\n",
    "    delta = [0, 0]\n",
    "    j = LM_IDS[m]\n",
    "    # assert j == i + 1\n",
    "    if j == '13':\n",
    "        va = 'top'\n",
    "    if j == '10':\n",
    "        va = 'center'\n",
    "    if j == '7':\n",
    "        va, ha = 'top', 'left'\n",
    "    plt.text(x[i] + delta[0], y[i] + delta[1], j, fontsize=12, ha=ha, va=va)\n",
    "plt.ylabel(\"Avg. score with reference\")\n",
    "plt.xlabel(\"Avg. shuffle test binary accuracy\")\n",
    "plt.title(f\"LM scores as conlang translators, $\\\\rho={rho:.2f}$\")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/conlang_translator_scores.png\")\n",
    "plt.savefig(\"plots/conlang_translator_scores.pdf\")\n",
    "\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e54fa75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot conlang language scores\n",
    "\n",
    "x = [float(avg([avg(scores[lang][m]) for m in models])) for lang in LANGS]\n",
    "y = [float(avg([avg(old_score(lang, m)) for m in models])) for lang in LANGS]\n",
    "rho = np.corrcoef(\n",
    "    [\n",
    "        x,\n",
    "        y,\n",
    "    ]\n",
    ")[0, 1]\n",
    "\n",
    "# Fit a line to the data\n",
    "\n",
    "coeffs = np.polyfit(x, y, 1)\n",
    "line = np.poly1d(coeffs)\n",
    "x_fit = np.linspace(min(x)-0.01, max(x)+0.01, 100)\n",
    "y_fit = line(x_fit)\n",
    "plt.figure(figsize=figsize)\n",
    "plt.plot(x_fit, y_fit, linestyle='--', color='#AAA')  # very light green\n",
    "\n",
    "plt.scatter(x, y)\n",
    "for i, lang in enumerate(LANGS):\n",
    "    ha, va = 'right', 'bottom'\n",
    "    delta = [0, 0]\n",
    "    if lang == \"Sidiku\" or lang == \"Serren\" or lang == \"Tuliq\":\n",
    "        ha = 'left'\n",
    "    if lang in [\"Avaru\", \"Hushuun\", \"Odrial\"]:\n",
    "        va = 'top'\n",
    "        delta = [0, -0.1]\n",
    "    if lang == \"Vekhar\" or lang == \"Quol\":\n",
    "        va = 'center'\n",
    "        delta = [-0.003, 0]\n",
    "    if lang == \"Talhi\":\n",
    "        va = 'top'\n",
    "        delta = [-0.005, -0.05]\n",
    "    if lang ==\"Lao\":\n",
    "        va = 'center'\n",
    "        \n",
    "\n",
    "    plt.text(\n",
    "        x[i] + delta[0], y[i] + delta[1], lang.replace(\"_\", \"\\n\"),\n",
    "        fontsize=12, ha=ha, va=va\n",
    "    )\n",
    "plt.ylabel(\"Avg. score compared to reference\")\n",
    "plt.xlabel(\"Avg. shuffle test binary accuracy\")\n",
    "plt.title(f\"Conlang language scores, $\\\\rho={rho:.2f}$\")\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"plots/conlang_language_scores.png\")\n",
    "plt.savefig(\"plots/conlang_language_scores.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3aa739",
   "metadata": {},
   "outputs": [],
   "source": [
    "# write conlangs to disk\n",
    "\n",
    "!mkdir conlangs\n",
    "for e in entries.values():\n",
    "    filename = re.sub(r\"[^a-zA-Z0-9_]+\", \"_\", e[\"language\"]) + \".json\"\n",
    "    score = scores[e[\"language\"]]\n",
    "    for i, s in enumerate(e[\"parallel_texts\"]):\n",
    "        scored_translations = {m: score[m][i] for m in models}\n",
    "        s[\"scores\"] = scored_translations\n",
    "        s[\"best_model\"] = max(scored_translations, key=scored_translations.get)\n",
    "        s[\"worst_model\"] = min(scored_translations, key=scored_translations.get)\n",
    "        s[\"translations\"] = {m: [conlang_to_english(c, e, lms[m]) for c in s[\"source\"]] for m in models}\n",
    "\n",
    "    with open(f\"conlangs/{filename}\", \"w\") as f:\n",
    "        json.dump(e, f, indent=4, ensure_ascii=False)\n",
    "\n",
    "print(\"All done, yay!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
