{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "218cde7a",
   "metadata": {},
   "source": [
    "# Reproduce entire Wikipedia experiments\n",
    "* First run `python get_wikipedia.py` to scrape the relevant Wikipedias.\n",
    "* Because all LM calls are cached (in `.cache-wiki-low-resource/`, set via `LM.global_cache_path` in the first code cell), it is easy and relatively fast to run a few cells, close the notebook, reopen it, and rerun to get back to where you were."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1599a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import uemt.utils as utils\n",
    "from uemt.models.openai_lm import OpenAILM, MODELS\n",
    "from uemt.models.language_model import LM  # LM calls are cached!\n",
    "from iso639 import Lang\n",
    "import tiktoken\n",
    "import math\n",
    "import numpy as np\n",
    "from tqdm.contrib.concurrent import thread_map\n",
    "import re\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import logging\n",
    "from collections import Counter\n",
    "import uemt.evaluation.metrics as metrics\n",
    "from textwrap import indent\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "from math import sqrt\n",
    "from matplotlib.ticker import PercentFormatter, MaxNLocator\n",
    "\n",
    "\n",
    "logging.basicConfig(level=logging.DEBUG)\n",
    "\n",
    "logger = utils.get_logger(\"uemt\", default=\"DEBUG\")\n",
    "\n",
    "# Directory for the LM call cache. Set before the OpenAILM clients below are built, in case they\n",
    "# read it at construction time (NOTE(review): confirm whether it is read per call instead).\n",
    "LM.global_cache_path = \".cache-wiki-low-resource\"\n",
    "\n",
    "\n",
    "def num_tokens(text: str, model=\"o3\") -> int:\n",
    "    \"\"\"Count the tokens of `text` using the tiktoken encoding registered for `model`.\"\"\"\n",
    "    return len(tiktoken.encoding_for_model(model).encode(text))\n",
    "\n",
    "\n",
    "def avg(xs):\n",
    "    \"\"\"Arithmetic mean of an iterable of numbers, as a Python float.\n",
    "\n",
    "    Raises ZeroDivisionError when `xs` is empty.\n",
    "    \"\"\"\n",
    "    xs = list(xs)  # materialize so generators can be both summed and counted\n",
    "    return float(sum(xs) / len(xs))\n",
    "\n",
    "\n",
    "# One client per model handle in the imported MODELS list. (The config cell below later narrows\n",
    "# MODELS to the handles in LM_ORDER, but this dict keeps every imported handle.)\n",
    "lms = {k: OpenAILM(model_name=k) for k in MODELS}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd790357",
   "metadata": {},
   "source": [
    "Constants, except for the prompts and `LANGS`, which is determined after scanning which languages were scraped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7849b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_LANGS = 10  # keep the first NUM_LANGS scraped languages, in LANGUAGES_BY_FREQUENCY order\n",
    "\n",
    "# Model handle -> display name. The dict order fixes the numeric plot IDs in LM_IDS; the commented-out\n",
    "# cell after the shuffle-test plot shows how this order was obtained.\n",
    "LM_ORDER = {\n",
    "    '35t': 'GPT-3.5-turbo',\n",
    "    '4o-mini': 'GPT-4o-mini',\n",
    "    '4t': 'GPT-4-turbo',\n",
    "    '4': 'GPT-4',\n",
    "    '5-nano': 'GPT-5-nano',\n",
    "    '4.1-mini': 'GPT-4.1-mini',\n",
    "    '4o': 'GPT-4o',\n",
    "    'o3-min': 'GPT-o3-mini',\n",
    "    'c4o': 'ChatGPT-4o',\n",
    "    '4.1': 'GPT-4.1',\n",
    "    '5-chat': 'GPT-5-chat',\n",
    "    'o3': 'GPT-o3',\n",
    "    'o4-mini': 'GPT-o4-mini',\n",
    "    '5-mini': 'GPT-5-mini',\n",
    "    '5r': 'GPT-5'\n",
    "}\n",
    "\n",
    "# Keep only handles that have a display name (this keeps the imported MODELS order, not LM_ORDER's).\n",
    "MODELS = [m for m in MODELS if m in LM_ORDER]\n",
    "\n",
    "# use gpt-4 to evaluate translations with references because that's been shown to correlate well\n",
    "# with human judgment in prior work:\n",
    "REF_EVAL_MODEL = \"4\"\n",
    "\n",
    "# 1-based string IDs (\"1\", \"2\", ...) in LM_ORDER order, used as compact plot labels\n",
    "LM_IDS = {k: str(i+1) for i, k in enumerate(LM_ORDER)}\n",
    "\n",
    "# API model id behind each handle, shown in the shuffle-test plot's y-axis labels\n",
    "LM_API_IDS = {k: OpenAILM._resolve_meta_by_handle(k)[\"id\"] for k in LM_ORDER}\n",
    "\n",
    "WIKIPATH = \"data/wiki_source\"  # directory scanned for scraped `*.010.json` files\n",
    "\n",
    "NUM_PERMUTATIONS = 10  # shuffled orderings compared against the original order, per article\n",
    "MAX_DOC_CHUNKS = 6  # Use at most first 6 paragraphs (pretty much every article has at least 6 paragraphs except a couple that have 5)\n",
    "PREFILL = True  # run a parallel thread_map pass first so the sequential passes hit the LM cache\n",
    "MAX_PARALLEL = 20  # max_workers for the prefill/scoring thread_map calls\n",
    "VERBOSE = False  # forwarded to more_plausible\n",
    "PROGRESS_BAR = True  # show tqdm progress bars\n",
    "# NOTE(review): judges are looked up as lms[handle], and lms is keyed by the imported MODELS, so every\n",
    "# LM_ORDER handle is assumed to be among them.\n",
    "SANITY_EVAL_LMS = list(LM_ORDER)  # eval all LMS on how well they can judge shuffle test\n",
    "\n",
    "EVAL_LM = lms[\"5r\"]  # judge the ShufflEval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a046a3ba",
   "metadata": {},
   "source": [
    "# Load Wikipedia data\n",
    "And split into paragraphs using an LM!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f59510",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Low-resourceness was determined from the number of Wikipedia articles per language, per\n",
    "# https://wikistats.wmcloud.org/display.php?t=wp. The dict order (apparently fewest articles first)\n",
    "# decides which NUM_LANGS languages are kept. Keys are language names (underscores instead of\n",
    "# spaces, since to_english rejects spaces); values are the ISO 639 codes that prefix the scraped files.\n",
    "\n",
    "LANGUAGES_BY_FREQUENCY = {\n",
    "    \"Tulu\": \"tcy\",\n",
    "    \"Dhivehi\": \"dv\",\n",
    "    \"Lao\": \"lo\",\n",
    "    \"Abkhazian\": \"ab\",\n",
    "    \"Uighur\": \"ug\",\n",
    "    \"Khmer\": \"km\",\n",
    "    \"Moroccan_Tamazight\": \"zgh\", # Standard Moroccan Tamazight\n",
    "    \"Sanskrit\": \"sa\",\n",
    "    \"Santali\": \"sat\",\n",
    "    \"Yakut\": \"sah\",\n",
    "    \"Sindhi\": \"sd\",\n",
    "    \"Assamese\": \"as\",\n",
    "    \"Oriya\": \"or\", # Oriya (macrolanguage)\n",
    "    \"Pushto\": \"ps\",\n",
    "    \"Sinhala\": \"si\",\n",
    "    \"Mongolian\": \"mn\",\n",
    "    \"Nepali\": \"ne\", # Nepali (macrolanguage)\n",
    "    \"Kannada\": \"kn\",\n",
    "    \"Chuvash\": \"cv\",\n",
    "    \"Panjabi\": \"pa\",\n",
    "    \"Bashkir\": \"ba\",\n",
    "    \"Nepal_Bhasa\": \"new\", # Nepal Bhasa\n",
    "    \"Western_Panjabi\": \"pnb\", # Western Panjabi\n",
    "    \"Kirghiz\": \"ky\",\n",
    "    \"Malayalam\": \"ml\",\n",
    "    \"Central_Kurdish\": \"ckb\", # Central Kurdish\n",
    "    \"Marathi\": \"mr\",\n",
    "    \"Burmese\": \"my\",\n",
    "    \"Telugu\": \"te\",\n",
    "    \"Tajik\": \"tg\",\n",
    "    \"Macedonian\": \"mk\",\n",
    "    \"Hindi\": \"hi\",\n",
    "    \"Bengali\": \"bn\",\n",
    "    \"Thai\": \"th\",\n",
    "    \"Tamil\": \"ta\",\n",
    "    \"Georgian\": \"ka\",\n",
    "    \"Urdu\": \"ur\",\n",
    "    \"Kazakh\": \"kk\",\n",
    "    \"Belarusian\": \"be\",\n",
    "    \"Modern_Greek\": \"el\", # Modern Greek (1453-)\n",
    "    \"Bulgarian\": \"bg\",\n",
    "    \"Armenian\": \"hy\",\n",
    "    \"Hebrew\": \"he\",\n",
    "    \"Korean\": \"ko\",\n",
    "    \"Persian\": \"fa\",\n",
    "    \"Arabic\": \"ar\",\n",
    "    \"Ukrainian\": \"uk\",\n",
    "    \"Japanese\": \"ja\",\n",
    "}\n",
    "len(LANGUAGES_BY_FREQUENCY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "647fc5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scraped files match `*.010.json`; the part of the name before the first '.' is the language code.\n",
    "wiki_paths = sorted(Path(WIKIPATH).glob(\"*.010.json\"))\n",
    "\n",
    "def get_language(p: Path) -> str:\n",
    "    \"\"\"Map a scraped file path to its language name (a key of LANGUAGES_BY_FREQUENCY).\n",
    "\n",
    "    The code is the part of the file name before the first '.'. Prints a message and returns an\n",
    "    empty string unless exactly one language has that code.\n",
    "    \"\"\"\n",
    "    short_name = p.stem.split(\".\")[0]\n",
    "    matches = [k for k, b in LANGUAGES_BY_FREQUENCY.items() if b == short_name]\n",
    "    if len(matches) != 1:\n",
    "        print(f\"{short_name}: Found {len(matches)} matches for {short_name}: {matches}\")\n",
    "        return \"\"\n",
    "    return matches[0]\n",
    "\n",
    "# Resolve each path once, so each unmatched file is reported only once.\n",
    "path_languages = {p: get_language(p) for p in wiki_paths}\n",
    "wiki_paths = {lang: p for p, lang in path_languages.items() if lang}\n",
    "# Order by position in LANGUAGES_BY_FREQUENCY, so LANGS below keeps the first NUM_LANGS of them.\n",
    "language_rank = {lang: rank for rank, lang in enumerate(LANGUAGES_BY_FREQUENCY)}\n",
    "wiki_paths = dict(sorted(wiki_paths.items(), key=lambda item: language_rank[item[0]]))\n",
    "wiki_data = {\n",
    "    lang: {\"docs\": utils.load(p), \"path\": p, \"language\": lang, \"iso\": Lang(LANGUAGES_BY_FREQUENCY[lang]).pt3}\n",
    "    for lang, p in wiki_paths.items()\n",
    "}\n",
    "# Keep only articles with a truthy English version; report files that don't end up with 10 articles.\n",
    "for k, v in wiki_data.items():\n",
    "    n = len(v[\"docs\"])\n",
    "    v[\"docs\"] = [x for x in v[\"docs\"] if x[\"en_ver\"]]\n",
    "    if len(v[\"docs\"]) != 10:\n",
    "        print(f\"{k}: {n} -> {len(v['docs'])}\")\n",
    "print(f\"Loaded {len(wiki_data):,} languages\")\n",
    "print(list(wiki_data))\n",
    "LANGS = list(wiki_data)[:NUM_LANGS]\n",
    "wiki_data = {lang: wiki_data[lang] for lang in LANGS}\n",
    "\n",
    "# LaTeX table (name, ISO 639-3 code) of the selected languages, sorted by name\n",
    "print(pd.DataFrame(sorted([(lang, wiki_data[lang][\"iso\"]) for lang in wiki_data])).to_latex(header=False, index=False))\n",
    "\n",
    "# Prompt for split_wikipedia_paragraphs, filled with str.format(article=...). LM calls are cached,\n",
    "# so editing this text presumably invalidates the cached splits and can change downstream results.\n",
    "SPLIT_INTO_PARAGRAPHS_TEMPLATE = \"\"\"\n",
    "### Instructions\n",
    "\n",
    "You will be given a Wikipedia article. Your task is to insert '|PARAGRAPH|' at all the paragraph boundaries, otherwise keeping the article exactly as is (do not correct any errors).\n",
    "\n",
    "#### Rules\n",
    "\n",
    "1. **Paragraph Definition**  \n",
    "   - Two or more consecutive newlines always indicate a paragraph break.  \n",
    "   - Some paragraphs may also be separated by a single newline—these should be treated as breaks where appropriate.  \n",
    "   - A list or a table should be treated as a single paragraph.\n",
    "\n",
    "2. **Section Headings**  \n",
    "   - If a section heading appears, put a '|PARAGRAPH|' before the section heading but not right after it, so that it is merged onto the subsequent paragraph.\n",
    "\n",
    "3. **Formatting**  \n",
    "   - Other than you inserting '|PARAGRAPH|' at the paragraph boundaries, preserve all characters .\n",
    "\n",
    "<ARTICLE>\n",
    "{article}\n",
    "</ARTICLE>\n",
    "\n",
    "Now output just the *entire* article with '|PARAGRAPH|' inserted at the paragraph boundaries, nothing else. Do not elide any text.\n",
    "\"\"\".strip()\n",
    "\n",
    "def chunkify(text: str, max_chunks):\n",
    "    \"\"\"Split `text` at blank lines (a newline followed by one or more empty/whitespace-only lines),\n",
    "    strip each piece, drop empty pieces and keep the first `max_chunks` (all when None).\"\"\"\n",
    "    pieces = (piece.strip() for piece in re.split(r\"\\n(?:\\s*\\n)+\", text))\n",
    "    return [piece for piece in pieces if piece][:max_chunks]\n",
    "\n",
    "def remove_whitespace(text: str):\n",
    "    \"\"\"Return `text` with every whitespace character deleted (used to compare text lengths).\"\"\"\n",
    "    return re.compile(r\"\\s+\").sub(\"\", text)\n",
    "\n",
    "def split_wikipedia_paragraphs(article: str, lm, max_attempts: int = 5, sep=\"|PARAGRAPH|\", verbose: bool = True):\n",
    "    \"\"\"\n",
    "    Split a Wikipedia article into paragraphs by asking `lm` to insert `sep` at paragraph boundaries.\n",
    "\n",
    "    Attempt k (k = 0 .. max_attempts-1) calls `lm.generate` with seed=k and temperature 0. An attempt is\n",
    "    accepted as soon as its paragraphs, ignoring whitespace, keep at least 90% of the article's\n",
    "    non-whitespace characters (guards against the LM eliding text). If no attempt is accepted, the\n",
    "    candidate with the most paragraphs is returned, choosing among the rejected attempts and a plain\n",
    "    blank-line split of the article (`chunkify`); ties go to the earliest candidate.\n",
    "\n",
    "    Returns a list of stripped, non-empty paragraph strings; separators are not included.\n",
    "    `verbose` is accepted for API compatibility but currently unused.\n",
    "    \"\"\"\n",
    "    prompt = SPLIT_INTO_PARAGRAPHS_TEMPLATE.format(article=article)\n",
    "    min_kept_chars = 0.9 * len(remove_whitespace(article))\n",
    "    rejected = []\n",
    "    for seed in range(max_attempts):\n",
    "        raw = lm.generate(prompt, seed=seed, temperature=0.0)\n",
    "        # The prompt shows the separator in single quotes, and the LM sometimes echoes the quotes too.\n",
    "        raw = raw.replace(f\"'{sep}'\", sep)\n",
    "        # Drop any echoed <ARTICLE> wrapper and surrounding code fences.\n",
    "        body = raw.strip().removeprefix(\"<ARTICLE>\").split(\"</ARTICLE>\")[0].strip(\" `\\n\")\n",
    "        paragraphs = [piece.strip() for piece in body.split(sep) if piece.strip()]\n",
    "        if len(remove_whitespace(\"\".join(paragraphs))) < min_kept_chars:\n",
    "            rejected.append(paragraphs)\n",
    "            continue\n",
    "        return paragraphs\n",
    "\n",
    "    rejected.append(chunkify(article, max_chunks=None))\n",
    "    return max(rejected, key=len)\n",
    "\n",
    "def _go(doc, model=\"4o\"):\n",
    "    \"\"\"Split `doc['plain_text']` into paragraphs with lms[model] and store them in `doc['paragraphs']` (in place).\"\"\"\n",
    "    doc[\"paragraphs\"] = split_wikipedia_paragraphs(doc[\"plain_text\"], lms[model])\n",
    "\n",
    "# Split every selected article and its English version in parallel (results are stored on the docs).\n",
    "docs_to_split = [d2 for v in wiki_data.values() for d in v[\"docs\"] for d2 in [d, d[\"en_ver\"]]]\n",
    "thread_map(_go, docs_to_split, max_workers=50)\n",
    "\n",
    "# Flag articles with fewer paragraphs than the MAX_DOC_CHUNKS chunks the experiments use.\n",
    "for v in wiki_data.values():\n",
    "    for i, d in enumerate(v[\"docs\"]):\n",
    "        if len(d[\"paragraphs\"]) < MAX_DOC_CHUNKS:\n",
    "            print(v[\"language\"], f\"{i=}\", \"has only\", len(d[\"paragraphs\"]), \"paragraphs\")\n",
    "        if len(d[\"en_ver\"][\"paragraphs\"]) < MAX_DOC_CHUNKS:\n",
    "            print(v[\"language\"], \"en_ver\", f\"{i=}\", \"English version has only\", len(d[\"en_ver\"][\"paragraphs\"]), \"paragraphs\")\n",
    "print(\"Shrunk to\", len(wiki_data), \"languages\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32708d84",
   "metadata": {},
   "source": [
    "# Code to translate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6bd948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sent verbatim as the system message (never .format()-ed), so `{translation}` stays a literal placeholder.\n",
    "TRANSLATE_SYSTEM_MESSAGE = \"\"\"\n",
    "You are a translator. You output translations in the exact format:\n",
    "\n",
    "<ENGLISH_TRANSLATION>\n",
    "{translation}\n",
    "</ENGLISH_TRANSLATION>\n",
    "\n",
    "It is crucial to adhere to the format.\n",
    "\"\"\".strip()\n",
    "\n",
    "TRANSLATE_PROMPT = \"\"\"\n",
    "Translate the following text from {source_language} to English:\n",
    "\n",
    "<{SOURCE_LANGUAGE}>\n",
    "{text}\n",
    "</{SOURCE_LANGUAGE}>\n",
    "\n",
    "Your output should be in the following format:\n",
    "\n",
    "<ENGLISH_TRANSLATION>\n",
    "...\n",
    "</ENGLISH_TRANSLATION>\n",
    "\"\"\".strip()\n",
    "\n",
    "\n",
    "FAILURE_INDICATOR = \"<COULD NOT TRANSLATE>\"\n",
    "\n",
    "def to_english(\n",
    "    text: str,\n",
    "    source_lang: str,\n",
    "    model: LM,\n",
    "    retries: int = 3,\n",
    "    truncate_to_token_len: bool = True,\n",
    "    verbose: bool = True,\n",
    "    **kwargs\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Translate `text` from `source_lang` to English with `model`, retrying with seeds 0..retries-1.\n",
    "\n",
    "    `source_lang` must not contain spaces (use underscores, e.g. Moroccan_Tamazight); underscores\n",
    "    become spaces in the instruction and the upper-cased name is used for the tags around the text.\n",
    "    When `truncate_to_token_len` is set and the model reports a context length, the text is first\n",
    "    truncated to 100 tokens below it. Extra **kwargs are passed to `model.generate`.\n",
    "\n",
    "    An attempt is accepted when the text between <ENGLISH_TRANSLATION> tags is longer than 50\n",
    "    characters, or when the (truncated) source is under 500 characters (too short to judge).\n",
    "    NOTE(review): if no attempt is accepted, the last attempt's value is returned as-is: the too-short\n",
    "    translation if its tags parsed, otherwise the raw model reply. FAILURE_INDICATOR is returned only\n",
    "    when that value is empty.\n",
    "    \"\"\"\n",
    "    assert \" \" not in source_lang, f\"to_english: {source_lang} contains an ' '\"\n",
    "    retries = retries or 1  # 0/None still means one attempt\n",
    "    if truncate_to_token_len and model.context_length:\n",
    "        text = model.truncate_to_token_len(text, model.context_length - 100)\n",
    "    # The prompt is identical across attempts; only the seed changes.\n",
    "    prompt = TRANSLATE_PROMPT.format(\n",
    "        source_language=source_lang.replace(\"_\", \" \"),\n",
    "        SOURCE_LANGUAGE=source_lang.upper(),\n",
    "        text=text,\n",
    "    )\n",
    "    too_short_to_judge = len(text.strip()) < 500\n",
    "    translation = None\n",
    "    for retry_seed in range(retries):\n",
    "        translation = model.generate(\n",
    "            prompt,\n",
    "            system_message=TRANSLATE_SYSTEM_MESSAGE,\n",
    "            seed=retry_seed,\n",
    "            **kwargs  # use default temperature unless overridden\n",
    "        )\n",
    "        assert isinstance(translation, str)\n",
    "        try:\n",
    "            translation = translation.split(\"<ENGLISH_TRANSLATION>\")[1].split(\"</ENGLISH_TRANSLATION>\")[0].strip()\n",
    "        except IndexError as e:\n",
    "            # Tags missing: `translation` still holds the raw reply.\n",
    "            if verbose:\n",
    "                print(f\"Error splitting translation for {model.model_name} [{retry_seed+1}/{retries}]: {e}\")\n",
    "                print(translation)\n",
    "            continue\n",
    "        if len(translation) > 50 or too_short_to_judge:\n",
    "            return translation\n",
    "\n",
    "    return translation or FAILURE_INDICATOR\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e062e98b",
   "metadata": {},
   "source": [
    "# Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a097b633",
   "metadata": {},
   "outputs": [],
   "source": [
    "MORE_PLAUSIBLE_TEMPLATE = \"\"\"\n",
    "We have a (possibly poor) English translation of an article in {language}, broken into segments.\n",
    "To make matters worse, we are not certain what order the segments should be in.\n",
    "\n",
    "Below are two orderings of the segments.\n",
    "Decide which ordering reads more natural and coherent.\n",
    "Reply with '1' or '2' only.\n",
    "\n",
    "<ORDERING1>\n",
    "{text1}\n",
    "</ORDERING1>\n",
    "\n",
    "<ORDERING2>\n",
    "{text2}\n",
    "</ORDERING2>\n",
    "\"\"\".strip()\n",
    "\n",
    "def format_perm(translation: list[str], perm: list[int] | None = None, joiner=\"\\n\\n\"):\n",
    "    \"\"\"Join the stripped segments of `translation` with `joiner`, taken in the order of the index list\n",
    "    `perm` (original order when `perm` is None).\"\"\"\n",
    "    order = range(len(translation)) if perm is None else perm\n",
    "    return joiner.join(translation[i].strip() for i in order)\n",
    "\n",
    "def trans_perm(source: list[str], entry, lm, perm: list[int] | None = None, joiner=\"\\n\"):\n",
    "    \"\"\"Translate each segment of `source` (in entry['language']) to English with `lm`, then join the\n",
    "    translations via format_perm. Note the default `joiner` is a single newline, unlike format_perm's.\"\"\"\n",
    "    translated = [to_english(segment, entry[\"language\"], lm) for segment in source]\n",
    "    return format_perm(translated, perm, joiner)\n",
    "\n",
    "\n",
    "# (instruction in MORE_PLAUSIBLE_TEMPLATE, replacement used when explain=True)\n",
    "EXPLAIN_REPLACEMENT = [\n",
    "    \"Reply with '1' or '2' only.\",\n",
    "    \"Output step-by-step reasoning of which is more plausible, citing the above material, and the final character you output should be your final answer, either '1' or '2'.\"\n",
    "]\n",
    "\n",
    "\n",
    "def make_plausible_prompt(text1: str, text2: str, language: str, background: str | None = None, reverse: bool = False, explain: bool = False) -> str:\n",
    "    \"\"\"Build the judge prompt asking which of two orderings reads more naturally.\n",
    "\n",
    "    `background`, if given, is stripped and prepended with a blank line after it. `reverse` swaps which\n",
    "    text is shown as ORDERING1. `explain` replaces the reply-with-1-or-2 instruction with a request for\n",
    "    step-by-step reasoning that ends in the answer.\n",
    "    \"\"\"\n",
    "    prefix = f\"{background.strip()}\\n\\n\" if background else \"\"\n",
    "    first, second = (text2, text1) if reverse else (text1, text2)\n",
    "    template = MORE_PLAUSIBLE_TEMPLATE\n",
    "    if explain:\n",
    "        short_instruction, reasoning_instruction = EXPLAIN_REPLACEMENT\n",
    "        assert template.count(short_instruction) == 1\n",
    "        template = template.replace(short_instruction, reasoning_instruction)\n",
    "    return prefix + template.format(text1=first.strip(), text2=second.strip(), language=language)\n",
    "\n",
    "def extract_prob(result) -> float:\n",
    "    \"\"\"Score a judge reply: 1.0 if it names ordering 2 only, 0.0 if it names ordering 1 only.\n",
    "\n",
    "    A reply containing both '1' and '2', or neither, is ambiguous: log a warning and return 0.5.\n",
    "    \"\"\"\n",
    "    has_one, has_two = \"1\" in result, \"2\" in result\n",
    "    # Compare the two booleans explicitly. The unparenthesized `\"1\" in result == \"2\" in result` is a\n",
    "    # chained comparison (`\"1\" in result and result == \"2\" and \"2\" in result`), which is always False.\n",
    "    if has_one == has_two:\n",
    "        logger.warning(f\"Warning: more_plausible response `{result}` should have EITHER '1' OR '2'\")\n",
    "        return 0.5\n",
    "    return 1.0 if has_two else 0.0\n",
    "\n",
    "\n",
    "\n",
    "def more_plausible(\n",
    "    text1: str,\n",
    "    text2: str,\n",
    "    language: str,\n",
    "    lm: LM,\n",
    "    verbose: bool = False,\n",
    "    explain: bool = False,\n",
    "    **kwargs\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Ask `lm` which of two orderings of the same segments reads better and return the probability that\n",
    "    `text2` is preferred (1.0 = text2 preferred in both runs, 0.0 = text1 preferred in both runs).\n",
    "\n",
    "    The comparison runs twice, first with text2 shown as ORDERING2 and then with the orderings swapped,\n",
    "    and the two verdicts (each scored by extract_prob) are averaged to cancel position bias. Plain text\n",
    "    replies are used, not logprobs.\n",
    "\n",
    "    `explain` and extra `kwargs` are currently blocked by the debugging assertion; with `explain` the\n",
    "    function would return the two raw replies instead of a float.\n",
    "    \"\"\"\n",
    "    assert not kwargs and not explain, \"for debugging assert\"\n",
    "    forward_reply = lm.generate_with_details(\n",
    "        make_plausible_prompt(text1, text2, language, background=None, reverse=False, explain=explain),\n",
    "        **kwargs  # not using logprobs or temp=0 because of o3\n",
    "    )\n",
    "    # Same question with the orderings swapped.\n",
    "    backward_reply = lm.generate_with_details(\n",
    "        make_plausible_prompt(text1, text2, language, background=None, reverse=True, explain=explain),\n",
    "        **kwargs\n",
    "    )\n",
    "    forward, forward_details = forward_reply\n",
    "    backward, _backward_details = backward_reply\n",
    "    if explain:\n",
    "        return forward, backward\n",
    "\n",
    "    if verbose:\n",
    "        print(forward_details)\n",
    "        # Spot check: dump the forward prompt and both replies for ~10% of calls.\n",
    "        if random.random() < .1:\n",
    "            shown_prompt = make_plausible_prompt(text1, text2, language, background=None, reverse=False, explain=explain)\n",
    "            print(\"=\" * 100)\n",
    "            print(shown_prompt[:2000])\n",
    "            print(forward, backward, \"^\" * 100)\n",
    "\n",
    "    p_forward = extract_prob(forward)\n",
    "    p_backward = extract_prob(backward)\n",
    "    # In the swapped run, answering '1' means text2 won, hence 1 - p_backward.\n",
    "    return (p_forward + 1 - p_backward) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b579062b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chunks = the first MAX_DOC_CHUNKS LM-split paragraphs, for each article and for its English version.\n",
    "\n",
    "for lang in LANGS:\n",
    "    for doc in wiki_data[lang][\"docs\"]:\n",
    "        doc[\"chunks\"] = doc[\"paragraphs\"][:MAX_DOC_CHUNKS]\n",
    "        doc[\"en_ver\"][\"chunks\"] = doc[\"en_ver\"][\"paragraphs\"][:MAX_DOC_CHUNKS]\n",
    "\n",
    "# Fewest chunks on either side of any article; every shuffle must fit within this many chunks.\n",
    "min_doc_chunks = min(\n",
    "    chunk_count\n",
    "    for entry in wiki_data.values()\n",
    "    for doc in entry[\"docs\"]\n",
    "    for chunk_count in (len(doc[\"chunks\"]), len(doc[\"en_ver\"][\"chunks\"]))\n",
    ")\n",
    "# Strictly fewer than n! orderings, presumably because generate_permutations skips the identity.\n",
    "assert NUM_PERMUTATIONS < math.factorial(min_doc_chunks), \"too many permutations\"\n",
    "print(\"min_doc_chunks\", min_doc_chunks)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc37020",
   "metadata": {},
   "source": [
    "# First plot: how good are the LMs at judging the shuffle test?\n",
    "No translation is needed: we take the English Wikipedia text and shuffle its first six paragraphs (`MAX_DOC_CHUNKS`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa71ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: how well can each LM tell the original order of raw English Wikipedia text from\n",
    "# shuffled orders? No translation involved; uses up to MAX_DOC_CHUNKS English chunks per article.\n",
    "\n",
    "def sanity_check_bakeoff(target: list[str], lang: str, eval_lm: str):\n",
    "    \"\"\"\n",
    "    Average more_plausible score of judge lms[eval_lm] for the original order of `target` against each\n",
    "    of NUM_PERMUTATIONS shuffled orders (from utils.generate_permutations), i.e. its accuracy.\n",
    "\n",
    "    `target` is a list of English paragraphs; `lang` only fills the {language} slot of the judge prompt.\n",
    "    \"\"\"\n",
    "    assert type(target) is list, \"target must be a list\"\n",
    "    assert all(type(s) is str for s in target), \"target must be a list of strings\"\n",
    "    perms = utils.generate_permutations(target, NUM_PERMUTATIONS)\n",
    "    original = format_perm(target, None)  # identity ordering, the same for every comparison\n",
    "\n",
    "    return avg(\n",
    "        more_plausible(\n",
    "            format_perm(target, perm),\n",
    "            original,\n",
    "            lang,\n",
    "            lms[eval_lm],\n",
    "            verbose=VERBOSE,\n",
    "        )\n",
    "        for perm in perms\n",
    "    )\n",
    "\n",
    "if PREFILL:\n",
    "    print(\"Prefilling\")\n",
    "    thread_map(\n",
    "        lambda x: sanity_check_bakeoff(*x),\n",
    "        [\n",
    "            (text[\"en_ver\"][\"chunks\"], lang, eval_lm)\n",
    "            for lang in LANGS\n",
    "            for text in wiki_data[lang][\"docs\"]\n",
    "            for eval_lm in SANITY_EVAL_LMS\n",
    "        ],\n",
    "        max_workers=MAX_PARALLEL,\n",
    "        disable=not PROGRESS_BAR,\n",
    "        desc=\"Prefilling sanity checks\",\n",
    "    )\n",
    "    print(\"Done prefilling\", flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c84f753f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sanity_scores_for(lang, eval_lm):\n",
    "    \"\"\"Per-article sanity-check accuracies of judge `eval_lm` on the English versions of `lang`'s articles.\"\"\"\n",
    "    docs = tqdm(\n",
    "        wiki_data[lang][\"docs\"],\n",
    "        desc=f\"Baking off {lang} {eval_lm}\",\n",
    "        disable=not PROGRESS_BAR,\n",
    "    )\n",
    "    return [sanity_check_bakeoff(doc[\"en_ver\"][\"chunks\"], lang, eval_lm) for doc in docs]\n",
    "\n",
    "\n",
    "# sanity_check_scores[lang][eval_lm] -> list of per-article accuracies\n",
    "sanity_check_scores = {\n",
    "    lang: {eval_lm: _sanity_scores_for(lang, eval_lm) for eval_lm in SANITY_EVAL_LMS}\n",
    "    for lang in LANGS\n",
    "}\n",
    "\n",
    "# Mean accuracy per judge LM: the average over languages of each language's mean\n",
    "{judge: avg(avg(per_lang[judge]) for per_lang in sanity_check_scores.values()) for judge in SANITY_EVAL_LMS}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a238a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle-test accuracy of each judge LM on English Wikipedia (no translation involved).\n",
    "means = []\n",
    "cis = []  # error-bar half-widths: ±1 standard error of the mean (not a 95% CI)\n",
    "model_names = list(SANITY_EVAL_LMS)[::-1]  # reversed so the first model in SANITY_EVAL_LMS is drawn at the top\n",
    "\n",
    "for m in model_names:\n",
    "    # Pool per-article scores over all languages. Each language is weighted by its article count, which\n",
    "    # gives equal weight per language only if every language kept the same number of articles.\n",
    "    pooled_scores = [score for lang in LANGS for score in sanity_check_scores[lang][m]]\n",
    "    mu = np.mean(pooled_scores)\n",
    "    # NOTE(review): np.std defaults to ddof=0; ddof=1 is the usual estimator for a standard error.\n",
    "    se = np.std(pooled_scores) / sqrt(len(pooled_scores))\n",
    "\n",
    "    means.append(mu)\n",
    "    cis.append(se)\n",
    "\n",
    "means = np.array(means)\n",
    "cis = np.array(cis)\n",
    "\n",
    "# --- Plot ---\n",
    "fig, ax = plt.subplots(figsize=(9, 5))\n",
    "\n",
    "y = np.arange(len(model_names))\n",
    "bar_height = 0.5  # < 1 so neighbouring bars don't touch\n",
    "ax.barh(y, means, height=bar_height, xerr=cis, capsize=5, edgecolor='black')\n",
    "\n",
    "# y-tick label: API model id followed by the model's numeric plot ID\n",
    "ax.set_yticks(y)\n",
    "ax.set_yticklabels([f\"{LM_API_IDS[m]} {LM_IDS[m]:>2}\" for m in model_names], va='center', fontsize=13)\n",
    "\n",
    "ax.set_xlabel(\"Accuracy\", fontsize=16)\n",
    "ax.set_title(\"English Wikipedia: LM scores on raw shuffle test\", fontsize=18)\n",
    "ax.tick_params(axis='y', labelsize=13)\n",
    "ax.tick_params(axis='x', labelsize=13)\n",
    "\n",
    "# Percent tick labels (scores are accuracies in [0, 1])\n",
    "ax.xaxis.set_major_formatter(PercentFormatter(xmax=1.0))\n",
    "\n",
    "# x-limits: cover mean ± error bar, padded by max(2 points, 5% of the span), clipped to [0, 1]\n",
    "low = max(0.0, (means - cis).min())\n",
    "high = min(1.0, (means + cis).max())\n",
    "span = max(1e-9, high - low)\n",
    "pad = max(0.02, 0.05 * span)\n",
    "ax.set_xlim(max(0.0, low - pad), min(1.0, high + pad))\n",
    "\n",
    "ax.xaxis.set_major_locator(MaxNLocator(nbins=6))\n",
    "ax.grid(axis='x', linestyle='--', alpha=0.5)\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save before showing; pathlib instead of a `!mkdir` shell escape keeps this portable.\n",
    "Path(\"plots\").mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(\"plots/wiki_shuffle_test.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"plots/wiki_shuffle_test.png\", bbox_inches=\"tight\", dpi=200)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74214379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How LM_ORDER was ordered: judge models sorted by ascending mean shuffle-test accuracy (`means` from the plot cell above):\n",
    "# models = [x for x, y in sorted(list(zip(model_names, means)), key=lambda x: x[1])]\n",
    "# print(models)\n",
    "# {k: LM_ORDER[k] for k in models if k in LM_ORDER}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c2bacad",
   "metadata": {},
   "source": [
    "# Run ShufflEval!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3e42fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bakeoff(source: list[str], entry: dict, m: str):\n",
    "    \"\"\"\n",
    "    ShufflEval score of translator `m` on one article: the average more_plausible score EVAL_LM gives\n",
    "    the original order of the translation against each of NUM_PERMUTATIONS shuffled orders.\n",
    "\n",
    "    `source` holds the article's chunks in entry['language']. Each chunk is translated once, on its\n",
    "    own, with lms[m]; the shuffled and original texts are assembled from those same translations\n",
    "    (joined by single newlines, like trans_perm's default), so the judge only compares orderings of\n",
    "    identical content.\n",
    "    \"\"\"\n",
    "    assert type(source) is list, \"source must be a list\"\n",
    "    assert all(type(s) is str for s in source), \"source must be a list of strings\"\n",
    "    perms = utils.generate_permutations(source, NUM_PERMUTATIONS)\n",
    "    lm = lms[m]\n",
    "    # Translation does not depend on the ordering: translate once instead of twice per permutation.\n",
    "    translation = [to_english(chunk, entry[\"language\"], lm) for chunk in source]\n",
    "    original = format_perm(translation, None, joiner=\"\\n\")\n",
    "\n",
    "    return avg(\n",
    "        more_plausible(\n",
    "            format_perm(translation, perm, joiner=\"\\n\"),\n",
    "            original,\n",
    "            entry[\"language\"],\n",
    "            EVAL_LM,\n",
    "            verbose=VERBOSE,\n",
    "        )\n",
    "        for perm in perms\n",
    "    )\n",
    "\n",
    "\n",
    "if PREFILL:\n",
    "    print(\"Prefilling\")\n",
    "    thread_map(\n",
    "        lambda x: bakeoff(*x),\n",
    "        [\n",
    "            (text[\"chunks\"], wiki_data[lang], m)\n",
    "            for lang in LANGS\n",
    "            for text in wiki_data[lang][\"docs\"]\n",
    "            for m in MODELS\n",
    "        ],\n",
    "        max_workers=MAX_PARALLEL,\n",
    "        disable=not PROGRESS_BAR,\n",
    "        desc=\"Prefilling translations\",\n",
    "    )\n",
    "    print(\"Done prefilling\")\n",
    "else:\n",
    "    print(\"Not prefilling\\n\" * 5)\n",
    "\n",
    "\n",
    "# scores[lang][m] -> per-article ShufflEval accuracies of translator m on language `lang`\n",
    "scores = {}\n",
    "for lang in LANGS:\n",
    "    entry = wiki_data[lang]\n",
    "    scores[lang] = {\n",
    "        m: [\n",
    "            bakeoff(text[\"chunks\"], entry, m)\n",
    "            for text in tqdm(entry[\"docs\"], desc=f\"Baking off {lang} {m}\", disable=not PROGRESS_BAR)\n",
    "        ]\n",
    "        for m in MODELS\n",
    "    }\n",
    "\n",
    "\n",
    "# Mean accuracy per (language, translator)\n",
    "{lang: {m: avg(scores[lang][m]) for m in MODELS} for lang in LANGS}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a1ca5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference-based scores: compare each translation against the article's English Wikipedia version.\n",
    "\n",
    "def old_score(lang, model):\n",
    "    \"\"\"\n",
    "    Reference-based scores for translator `model` on each article of `lang`, in article order.\n",
    "\n",
    "    Hypothesis: the chunk-by-chunk translation (to_english with lms[model]) joined by blank lines.\n",
    "    Reference: the English version's chunks, cut to at most as many chunks as the source, joined the\n",
    "    same way. Scored by metrics.gemba_da_ref with judge lms[REF_EVAL_MODEL].\n",
    "    NOTE(review): if the English version has fewer chunks than the source, the source is not cut,\n",
    "    so the hypothesis may cover content the reference lacks.\n",
    "    \"\"\"\n",
    "    ref_scores = []  # named to avoid shadowing the global ShufflEval `scores`\n",
    "    for doc in wiki_data[lang][\"docs\"]:\n",
    "        source_chunks = doc[\"chunks\"]\n",
    "        en_chunks = doc[\"en_ver\"][\"chunks\"]\n",
    "        hypothesis = \"\\n\\n\".join(to_english(chunk, lang, lms[model]) for chunk in source_chunks)\n",
    "        reference = \"\\n\\n\".join(en_chunks[:len(source_chunks)])\n",
    "        ref_scores.append(metrics.gemba_da_ref(hypothesis, reference, lms[REF_EVAL_MODEL]))\n",
    "    return ref_scores\n",
    "\n",
    "\n",
    "# Prefill the LM cache in parallel; the plotting cells below call old_score again sequentially.\n",
    "thread_map(\n",
    "    lambda args: old_score(*args),\n",
    "    [(lang, m) for m in MODELS for lang in LANGS],\n",
    "    max_workers=MAX_PARALLEL,\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d714c204",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per translator LM: mean ShufflEval accuracy (x) vs. mean reference-based score (y), each averaged over languages.\n",
    "\n",
    "x = [avg(avg(scores[lang][m]) for lang in LANGS) for m in MODELS]\n",
    "y = [avg(avg(old_score(lang, m)) for lang in LANGS) for m in MODELS]\n",
    "rho = np.corrcoef(x, y)[0, 1]\n",
    "print(\"correlation\", rho)\n",
    "\n",
    "# Least-squares line through the points\n",
    "fit = np.poly1d(np.polyfit(x, y, 1))\n",
    "x_fit = np.linspace(min(x), max(x), 100)\n",
    "y_fit = fit(x_fit)\n",
    "\n",
    "scale = 1.25\n",
    "figsize = (4*scale, 3*scale)  # also used by the language plot in the next cell\n",
    "fig, ax = plt.subplots(figsize=figsize)\n",
    "ax.plot(x_fit, y_fit, linestyle='--', color='#AAA')  # light gray fit line\n",
    "ax.scatter(x, y)\n",
    "\n",
    "# Points are labelled with their LM_IDS number; a few labels get custom alignment to avoid overlaps.\n",
    "label_alignment = {'13': ('right', 'top'), '10': ('right', 'center'), '7': ('left', 'top')}\n",
    "for xi, yi, m in zip(x, y, MODELS):\n",
    "    label = LM_IDS[m]\n",
    "    ha, va = label_alignment.get(label, ('right', 'bottom'))\n",
    "    ax.text(xi, yi, label, fontsize=12, ha=ha, va=va)\n",
    "ax.set_ylabel(\"Avg. score with reference\")\n",
    "ax.set_xlabel(\"Avg. shuffle test binary accuracy\")\n",
    "ax.set_title(f\"LM scores as Wikipedia translators, $\\\\rho={rho:.2f}$\")\n",
    "ax.grid(True)\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"plots/wiki_translator_scores.png\")\n",
    "fig.savefig(\"plots/wiki_translator_scores.pdf\")\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57c2bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per language: mean ShufflEval accuracy (x) vs. mean reference-based score (y), each averaged over translator LMs.\n",
    "\n",
    "x = [float(avg(avg(scores[lang][m]) for m in MODELS)) for lang in LANGS]\n",
    "y = [float(avg(avg(old_score(lang, m)) for m in MODELS)) for lang in LANGS]\n",
    "rho = np.corrcoef(x, y)[0, 1]\n",
    "\n",
    "# Least-squares line, drawn a little past the data on both sides\n",
    "fit = np.poly1d(np.polyfit(x, y, 1))\n",
    "x_fit = np.linspace(min(x)-0.01, max(x)+0.01, 100)\n",
    "y_fit = fit(x_fit)\n",
    "fig, ax = plt.subplots(figsize=figsize)  # figsize comes from the translator-plot cell above\n",
    "ax.plot(x_fit, y_fit, linestyle='--', color='#AAA')  # light gray fit line\n",
    "ax.scatter(x, y)\n",
    "\n",
    "# (ha, va, (dx, dy)) per language label; dx/dy are offsets in data units, tuned by hand to avoid overlaps.\n",
    "label_placement = {\n",
    "    \"Santali\": ('left', 'bottom', (0, 0)),\n",
    "    \"Tulu\": ('right', 'top', (0, -0.1)),\n",
    "    \"Khmer\": ('right', 'center', (-0.005, -0.05)),\n",
    "    \"Lao\": ('right', 'center', (-0.005, 0)),\n",
    "}\n",
    "for xi, yi, lang in zip(x, y, LANGS):\n",
    "    ha, va, (dx, dy) = label_placement.get(lang, ('right', 'bottom', (0, 0)))\n",
    "    ax.text(xi + dx, yi + dy, lang.replace(\"_\", \"\\n\"), fontsize=12, ha=ha, va=va)\n",
    "ax.set_ylabel(\"Avg. score compared to reference\")\n",
    "ax.set_xlabel(\"Avg. shuffle test binary accuracy\")\n",
    "ax.set_title(f\"Wikipedia language scores, $\\\\rho={rho:.2f}$\")\n",
    "ax.grid(True)\n",
    "fig.tight_layout()\n",
    "# Not saved; to save: fig.savefig(\"plots/wiki_language_scores.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ef2061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-article correlation between shuffle-test accuracy and reference-based score,\n",
    "# pooling every (model, language, article) pair instead of grouping by LM or language.\n",
    "# A bootstrap over the pooled pairs gives a 95% confidence interval.\n",
    "\n",
    "x2 = [s for m in MODELS for lang in LANGS for s in scores[lang][m]]\n",
    "y2 = [s for m in MODELS for lang in LANGS for s in old_score(lang, m)]\n",
    "assert len(x2) == len(y2), f\"unpaired scores: {len(x2)} vs {len(y2)}\"\n",
    "rho2 = np.corrcoef(x2, y2)[0, 1]\n",
    "print(\"correlation\", rho2)\n",
    "\n",
    "\n",
    "def bootstrapped_corr(x, y, B=10000, seed=0, alpha=0.05):\n",
    "    \"\"\"Percentile-bootstrap confidence interval for the Pearson correlation of paired samples.\n",
    "\n",
    "    Args:\n",
    "        x, y: paired 1-D samples of equal length (lists or arrays).\n",
    "        B: number of bootstrap resamples (each draws n pairs with replacement).\n",
    "        seed: seed for numpy's default_rng, so the interval is reproducible.\n",
    "        alpha: one minus the confidence level (0.05 gives a 95% interval).\n",
    "\n",
    "    Returns:\n",
    "        (lo, hi) as floats: the alpha/2 and 1 - alpha/2 percentiles of the bootstrap\n",
    "        correlations. Resamples with zero variance yield NaN and are ignored.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if x and y are not 1-D of equal length, or have fewer than 2 pairs.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    if x.ndim != 1 or x.shape != y.shape:\n",
    "        raise ValueError(f\"x and y must be 1-D and equally long, got {x.shape} and {y.shape}\")\n",
    "    n = len(x)\n",
    "    if n < 2:\n",
    "        raise ValueError(f\"need at least 2 paired samples, got {n}\")\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = rng.integers(0, n, size=(B, n))\n",
    "    # Constant resamples make corrcoef divide by zero; silence that warning and let\n",
    "    # nanpercentile skip the resulting NaNs.\n",
    "    with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n",
    "        r_boot = np.array([np.corrcoef(x[i], y[i])[0, 1] for i in idx])\n",
    "    lo, hi = np.nanpercentile(r_boot, [100 * alpha / 2, 100 * (1 - alpha / 2)])\n",
    "    return float(lo), float(hi)\n",
    "\n",
    "\n",
    "ci_lo, ci_hi = bootstrapped_corr(x2, y2)\n",
    "print(f\"Bootstrap 95% CI = [{ci_lo:.4f}, {ci_hi:.4f}], midpoint {(ci_lo + ci_hi) / 2:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d4a6676",
   "metadata": {},
   "source": [
    "# Checking for hallucinations in whole-document translations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35648ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference-free scoring of whole-article translations (MTQE): an LM judge rates each\n",
    "# English translation for fluency/coherence on a 0-100 scale without seeing the reference.\n",
    "print(\n",
    "    \"Scoring whole-document translations: Machine‑Translation Quality Estimation (MTQE) (sometimes called reference‑free evaluation)\"\n",
    ")\n",
    "\n",
    "# Run configuration. NOTE: this rebinds MODELS (imported in the first cell) to LM_ORDER's order.\n",
    "MODELS = list(LM_ORDER)\n",
    "NUM_PARALLEL = 100  # thread_map workers used for cache prefilling\n",
    "PROGRESS_BAR = True\n",
    "PREFILL = True  # warm the LM cache concurrently before the sequential pass\n",
    "VERBOSE = False  # when True, echo roughly 1% of judge prompts\n",
    "\n",
    "\n",
    "# Judge prompt. Keep the wording stable: LM calls are cached, and changing it presumably\n",
    "# invalidates previously cached responses.\n",
    "SINGLE_TEXT_ONLY = \"\"\"\n",
    "How good is this text in terms of fluency and coherence?\n",
    "Rate the following sentence on a scale from 0 to 100, where:\n",
    "* 0 = completely unintelligible or ungrammatical,\n",
    "* 50 = somewhat understandable with multiple errors,\n",
    "* 100 = perfectly fluent, clear, and grammatically correct.\n",
    "\n",
    "<TEXT>\n",
    "{translation}\n",
    "</TEXT>\n",
    "\n",
    "Just output an integer score between 0 and 100, inclusive, and nothing else.\n",
    "\"\"\".strip()\n",
    "\n",
    "\n",
    "def score_single_text(translation: str, model: LM = lms[REF_EVAL_MODEL]):\n",
    "    \"\"\"Reference-free 0-100 fluency/coherence score of `translation`, judged by `model`.\n",
    "\n",
    "    Returns the first integer found in the model's reply. A reply containing several\n",
    "    integers is logged as a warning; a reply with none fails an assertion.\n",
    "    \"\"\"\n",
    "    assert isinstance(translation, str)\n",
    "    prompt = SINGLE_TEXT_ONLY.format(translation=translation)\n",
    "    # Spot-check: print roughly 1% of prompts (random() is only consulted when VERBOSE).\n",
    "    if VERBOSE and random.random() < 0.01:\n",
    "        print(prompt)\n",
    "        print(\"^\" * 100)\n",
    "    reply = model.generate(prompt)\n",
    "    assert isinstance(reply, str)\n",
    "    found = re.findall(r\"\\d+\", reply)\n",
    "    assert found, \"Didn't get any numbers :-(\"\n",
    "    if len(found) > 1:\n",
    "        logger.warning(f\"Expected 1 number, got {len(found)}: `{reply!r}`\")\n",
    "    return int(found[0])\n",
    "\n",
    "\n",
    "def whole_to_english(doc, lang, model):\n",
    "    \"\"\"Translate an entire article in one call: all chunks joined by blank lines, via lms[model].\"\"\"\n",
    "    return to_english(\"\\n\\n\".join(doc[\"chunks\"]), lang, lms[model])\n",
    "\n",
    "\n",
    "def whole_comp(doc):\n",
    "    \"\"\"English comparison text for `doc`.\n",
    "\n",
    "    The English version's chunks, truncated to at most the number of chunks in the\n",
    "    source article, joined by blank lines.\n",
    "    \"\"\"\n",
    "    n_source_chunks = len(doc[\"chunks\"])\n",
    "    return \"\\n\\n\".join(doc[\"en_ver\"][\"chunks\"][:n_source_chunks])\n",
    "\n",
    "\n",
    "def score_whole_with_ref(lang, model):\n",
    "    \"\"\"Per-article reference-based scores (metrics.gemba_da_ref) of `model`'s whole-article translations for `lang`.\"\"\"\n",
    "    return [\n",
    "        metrics.gemba_da_ref(whole_to_english(entry, lang, model), whole_comp(entry), lms[REF_EVAL_MODEL])\n",
    "        for entry in wiki_data[lang][\"docs\"]\n",
    "    ]\n",
    "\n",
    "\n",
    "def score_whole_without_ref(lang, model):\n",
    "    \"\"\"Per-article reference-free scores (score_single_text) of `model`'s whole-article translations for `lang`.\"\"\"\n",
    "    return [score_single_text(whole_to_english(entry, lang, model)) for entry in wiki_data[lang][\"docs\"]]\n",
    "\n",
    "\n",
    "if PREFILL:\n",
    "    print(\"Prefilling\")\n",
    "    # Populate the LM cache concurrently; the sequential pass below then hits the cache.\n",
    "    thread_map(\n",
    "        lambda pair: score_whole_without_ref(*pair),\n",
    "        [(lg, m) for lg in LANGS for m in MODELS],\n",
    "        max_workers=NUM_PARALLEL,\n",
    "        disable=not PROGRESS_BAR,\n",
    "        desc=\"Prefilling translations\",\n",
    "    )\n",
    "\n",
    "# model -> language -> list of per-article reference-free scores\n",
    "whole_scores_without_ref = {\n",
    "    m: {lg: score_whole_without_ref(lg, m) for lg in LANGS}\n",
    "    for m in MODELS\n",
    "}\n",
    "\n",
    "df = pd.DataFrame(whole_scores_without_ref)\n",
    "df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "321e083a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference-based scores for the same whole-article translations. The translations\n",
    "# themselves should already be cached from the reference-free pass, so the new work\n",
    "# here is the reference-based judging.\n",
    "if PREFILL:\n",
    "    print(\"Prefilling\")\n",
    "    thread_map(\n",
    "        lambda x: score_whole_with_ref(x[0], x[1]),\n",
    "        [(l, m) for l in LANGS for m in MODELS],\n",
    "        max_workers=NUM_PARALLEL,\n",
    "        disable=not PROGRESS_BAR,\n",
    "        desc=\"Prefilling reference-based scores\",\n",
    "    )\n",
    "\n",
    "# model -> language -> list of per-article reference-based scores\n",
    "whole_scores_with_ref = {m: {l: score_whole_with_ref(l, m) for l in LANGS} for m in MODELS}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "566135e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement between reference-based and reference-free whole-article scores:\n",
    "# Pearson correlation per language (averaged over models), per model (averaged over\n",
    "# languages), and per individual (model, article) score.\n",
    "langs = LANGS\n",
    "\n",
    "\n",
    "def _pearson(a, b):\n",
    "    \"\"\"Pearson correlation coefficient of two equal-length sequences.\"\"\"\n",
    "    return np.corrcoef(a, b)[0, 1]\n",
    "\n",
    "\n",
    "rho_by_lang = _pearson(\n",
    "    [avg(avg(whole_scores_with_ref[m][lang]) for m in MODELS) for lang in langs],\n",
    "    [avg(avg(whole_scores_without_ref[m][lang]) for m in MODELS) for lang in langs],\n",
    ")\n",
    "rho_by_model = _pearson(\n",
    "    [avg(avg(whole_scores_with_ref[m][lang]) for lang in langs) for m in MODELS],\n",
    "    [avg(avg(whole_scores_without_ref[m][lang]) for lang in langs) for m in MODELS],\n",
    ")\n",
    "rho_by_article = _pearson(\n",
    "    [s for lang in langs for m in MODELS for s in whole_scores_with_ref[m][lang]],\n",
    "    [s for lang in langs for m in MODELS for s in whole_scores_without_ref[m][lang]],\n",
    ")\n",
    "rho_by_lang, rho_by_model, rho_by_article\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc4016f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Whole-article translation: each article's reference-based vs reference-free score,\n",
    "# pooled over all languages, for GPT-5 (key \"5r\") and GPT-o4-mini; stars mark each model's mean.\n",
    "\n",
    "\n",
    "def _pooled(score_table, model):\n",
    "    \"\"\"Concatenate score_table[model][lang] over all languages, in `langs` order.\"\"\"\n",
    "    return [s for lang in langs for s in score_table[model][lang]]\n",
    "\n",
    "\n",
    "x_5r = _pooled(whole_scores_with_ref, \"5r\")\n",
    "y_5r = _pooled(whole_scores_without_ref, \"5r\")\n",
    "mean_x_5r, mean_y_5r = np.mean(x_5r), np.mean(y_5r)\n",
    "\n",
    "x_o4mini = _pooled(whole_scores_with_ref, \"o4-mini\")\n",
    "y_o4mini = _pooled(whole_scores_without_ref, \"o4-mini\")\n",
    "mean_x_o4mini, mean_y_o4mini = np.mean(x_o4mini), np.mean(y_o4mini)\n",
    "\n",
    "# Call order fixes both draw order and legend order: per-article points first, then the means.\n",
    "plt.scatter(x_5r, y_5r, label=\"GPT-5\")\n",
    "plt.scatter(x_o4mini, y_o4mini, label=\"GPT-o4-mini\", alpha=0.4)\n",
    "star = dict(marker='*', s=300, edgecolor='black')\n",
    "plt.scatter([mean_x_5r], [mean_y_5r], color='C0', label=\"GPT-5 mean\", **star)\n",
    "plt.scatter([mean_x_o4mini], [mean_y_o4mini], color='C1', label=\"GPT-o4-mini mean\", **star)\n",
    "\n",
    "plt.xlabel(\"Score with reference\")\n",
    "plt.ylabel(\"Score without reference\")\n",
    "plt.title(\"Whole-article translation\")\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0e085c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translations the reference-free judge rated >= HALLUC_THRESHOLD, and how those same\n",
    "# translations score against the reference (histogram below).\n",
    "HALLUC_THRESHOLD = 90\n",
    "for label, y_scores in [(\"GPT-o4-mini\", y_o4mini), (\"GPT-5\", y_5r)]:\n",
    "    share = avg(y >= HALLUC_THRESHOLD for y in y_scores)\n",
    "    print(f\"{share:.0%} of the translations of {label} scored ≥ {HALLUC_THRESHOLD} without reference\")\n",
    "\n",
    "# Reference-based scores of the confidently-rated translations. Boolean masks on arrays\n",
    "# keep each x aligned with its own y (no silent zip truncation).\n",
    "halluc_x_o4mini = np.asarray(x_o4mini)[np.asarray(y_o4mini) >= HALLUC_THRESHOLD]\n",
    "halluc_x_5r = np.asarray(x_5r)[np.asarray(y_5r) >= HALLUC_THRESHOLD]\n",
    "\n",
    "# Shared bin edges (Freedman-Diaconis rule) so the two histograms are directly comparable.\n",
    "bin_edges = np.histogram_bin_edges(np.concatenate([halluc_x_o4mini, halluc_x_5r]), bins='fd')\n",
    "\n",
    "# Filled, semi-transparent bars; histtype=\"step\" is an alternative that avoids occlusion.\n",
    "plt.hist(halluc_x_o4mini, bins=bin_edges, alpha=0.5, label=\"GPT-o4-mini\", edgecolor=\"white\", linewidth=0.5)\n",
    "plt.hist(halluc_x_5r, bins=bin_edges, alpha=0.4, label=\"GPT-5\", edgecolor=\"white\", linewidth=0.5)\n",
    "\n",
    "# Title is derived from the threshold so it cannot drift out of sync with the filter above.\n",
    "plt.title(f\"Hallucinations in whole-article translation (score without ref ≥ {HALLUC_THRESHOLD})\")\n",
    "plt.xlabel(\"Score with reference\")\n",
    "plt.ylabel(\"Number of translations\")\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "Path(\"plots\").mkdir(exist_ok=True)\n",
    "plt.savefig(\"plots/wiki_hallucinations.png\")\n",
    "plt.savefig(\"plots/wiki_hallucinations.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ee8ca86",
   "metadata": {},
   "source": [
    "## Display hallucination example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5cd2ac8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print one hallucination example: the source article, its English counterpart, every\n",
    "# LM's whole-article translation, and then every LM's trans_perm output (segment by segment).\n",
    "PRINT_EXAMPLE = False\n",
    "EXAMPLE_LANG = \"Santali\"\n",
    "EXAMPLE_DOC_INDEX = 9  # index into wiki_data[EXAMPLE_LANG][\"docs\"]\n",
    "if PRINT_EXAMPLE:\n",
    "    _lang = EXAMPLE_LANG\n",
    "    entry = wiki_data[_lang]\n",
    "    doc = entry[\"docs\"][EXAMPLE_DOC_INDEX]\n",
    "    source = doc[\"chunks\"]\n",
    "    print(\"Source:\", doc[\"stable_url\"])\n",
    "    print(doc[\"plain_text\"])\n",
    "    print(\"English source:\", doc[\"en_ver\"][\"stable_url\"])\n",
    "    print(whole_comp(doc))\n",
    "    print(\"^\" * 100)\n",
    "    print(\"Whole-article translations:\")\n",
    "    for m in LM_ORDER:\n",
    "        print(m, \"=\" * 80)\n",
    "        print(whole_to_english(doc, _lang, m))\n",
    "        print(\"-\" * 80)\n",
    "        print()\n",
    "    print(\"And now segment by segment:\\n\" * 5)\n",
    "\n",
    "    # NOTE(review): trans_perm receives the per-language entry rather than the language\n",
    "    # name that whole_to_english takes; confirm this matches its signature.\n",
    "    for m in LM_ORDER:\n",
    "        print(m, \"=\" * 80)\n",
    "        print(trans_perm(source, entry, lms[m]))\n",
    "        print(\"-\" * 80)\n",
    "        print()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
