{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import matplotlib.pyplot as plt\n",
    "import llm_fairness\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import pandas as pd\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import matplotlib.colors as mcolors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fetch the imdb dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-09-30 18:57:56.303\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mllm_fairness.data\u001b[0m:\u001b[36mfrom_name\u001b[0m:\u001b[36m28\u001b[0m - \u001b[1mLoading imdb — splits ['train', 'test', 'unsupervised']\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "dataset = llm_fairness.data.from_name(\"imdb-large\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2024-09-30 19:00:33.358\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mllm_fairness.data\u001b[0m:\u001b[36mfrom_name\u001b[0m:\u001b[36m28\u001b[0m - \u001b[1mLoading imdb — splits ['train', 'test', 'unsupervised']\u001b[0m\n",
      "\u001b[32m2024-09-30 19:00:36.320\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mllm_fairness.data\u001b[0m:\u001b[36mfrom_name\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mExtract split: train | (25000, 1) | ['text']\u001b[0m\n",
      "\u001b[32m2024-09-30 19:00:39.558\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mllm_fairness.data\u001b[0m:\u001b[36mfrom_name\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mExtract split: test | (25000, 1) | ['text']\u001b[0m\n",
      "\u001b[32m2024-09-30 19:00:42.792\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mllm_fairness.data\u001b[0m:\u001b[36mfrom_name\u001b[0m:\u001b[36m70\u001b[0m - \u001b[1mExtract split: unsupervised | (50000, 1) | ['text']\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "we trained a bpe with target vocab size: 32000\n",
      "our actual vocab size is 31946\n",
      "we have 30629 unique tokens\n"
     ]
    }
   ],
   "source": [
    "dataset = llm_fairness.data.from_name(\"imdb\")[\"train\"]\n",
    "tokenizer = llm_fairness.tokenizer.from_data(dataset, variant=\"BPE\", vocab_size=32000)\n",
    "dataset = dataset.map(\n",
    "    lambda x: tokenizer(\n",
    "        x[\"text\"],\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=128,\n",
    "    ),\n",
    "    batched=True,\n",
    "    remove_columns=dataset.column_names,\n",
    ")\n",
    "\n",
    "tok2id = {tokenizer.decode(tid): tid for _, tid in tokenizer._tokenizer.get_vocab().items()}\n",
    "id2tok = {tid : tokenizer.decode(tid) for _, tid in tokenizer._tokenizer.get_vocab().items()}\n",
    "\n",
    "tokens = []\n",
    "for seq in dataset: tokens.extend(seq['input_ids'])\n",
    "token_counts = collections.Counter(tokens)\n",
    "unique_tokens = set(tokens)\n",
    "\n",
    "print(f\"we trained a bpe with target vocab size: {32000}\")\n",
    "print(f\"our actual vocab size is {len(tok2id)}\")\n",
    "print(f\"we have {len(unique_tokens)} unique tokens\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# compute the cumulative frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_values = np.array(sorted(token_counts.values()))\n",
    "cumulative_freq = np.cumsum(sorted_values)\n",
    "total_freq = cumulative_freq[-1]\n",
    "\n",
    "threshold_50_index = np.argmax(cumulative_freq >= 0.50 * total_freq)\n",
    "threshold_80_index = np.argmax(cumulative_freq >= 0.80 * total_freq)\n",
    "threshold_95_index = np.argmax(cumulative_freq >= 0.95 * total_freq)\n",
    "\n",
    "threshold_50_percentile = threshold_50_index / len(sorted_values) * 100\n",
    "threshold_80_percentile = threshold_80_index / len(sorted_values) * 100\n",
    "threshold_95_percentile = threshold_95_index / len(sorted_values) * 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# compute the normalized token counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_tokens = sum(token_counts.values())\n",
    "normalized_token_counts = {token_id: count / total_tokens for token_id, count in token_counts.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 10 most frequent tokens:\n",
      "Token ID | Frequency | Decoded Value\n",
      "---------|-----------|--------------\n",
      "     116 | 0.042591 |  the\n",
      "   32001 | 0.040753 | </s>\n",
      "      12 | 0.033796 | ,\n",
      "      14 | 0.031022 | .\n",
      "     111 | 0.022233 |  a\n",
      "     137 | 0.021058 |  and\n",
      "     141 | 0.019692 |  of\n",
      "     139 | 0.017361 |  to\n",
      "     149 | 0.014855 |  is\n",
      "     155 | 0.013558 |  it\n"
     ]
    }
   ],
   "source": [
    "sorted_counts = sorted(normalized_token_counts.items(), key=lambda x: x[1], reverse=True)\n",
    "top_10 = sorted_counts[:10]\n",
    "print(\"Top 10 most frequent tokens:\")\n",
    "print(\"Token ID | Frequency | Decoded Value\")\n",
    "print(\"---------|-----------|--------------\")\n",
    "for token_id, freq in top_10:\n",
    "    decoded_value = id2tok[token_id]\n",
    "    print(f\"{token_id:8d} | {freq:.6f} | {decoded_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bottom 10 least frequent tokens:\n",
      "Token ID | Frequency | Decoded Value\n",
      "---------|-----------|--------------\n",
      "   31452 | 3.125000e-07 |  equate\n",
      "   22004 | 3.125000e-07 |  unhapp\n",
      "   22869 | 3.125000e-07 | kei\n",
      "    4911 | 3.125000e-07 |  dyn\n",
      "   14230 | 3.125000e-07 |  bish\n",
      "    9560 | 3.125000e-07 | inel\n",
      "   25880 | 3.125000e-07 |  balsam\n",
      "   31035 | 3.125000e-07 |  orbach\n",
      "    2409 | 3.125000e-07 |  exe\n",
      "   27701 | 3.125000e-07 |  moorehead\n"
     ]
    }
   ],
   "source": [
    "bottom_10 = sorted_counts[-10:][::-1]\n",
    "print(\"Bottom 10 least frequent tokens:\")\n",
    "print(\"Token ID | Frequency | Decoded Value\")\n",
    "print(\"---------|-----------|--------------\")\n",
    "for token_id, freq in bottom_10:\n",
    "    decoded_value = id2tok[token_id]\n",
    "    print(f\"{token_id:8d} | {freq:.6e} | {decoded_value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# select a few samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i, 47, 0.002793125\n",
      " watched, 1122, 0.0004740625\n",
      " this, 175, 0.010771875\n",
      " film, 203, 0.005760625\n",
      " for, 196, 0.0054415625\n",
      " 45, 7823, 1.84375e-05\n",
      " minutes, 920, 0.000410625\n",
      " and, 137, 0.0210584375\n",
      " counted, 15933, 5e-06\n",
      " 9, 1488, 7.625e-05\n",
      " mull, 12423, 1.25e-06\n",
      "ets, 1424, 1.59375e-05\n",
      "., 14, 0.0310221875\n",
      " that, 174, 0.0095028125\n",
      "'s, 182, 0.0075909375\n",
      " a, 111, 0.0222334375\n",
      " mullet, 22665, 3.125e-06\n",
      " every, 466, 0.0005678125\n",
      " 5, 1399, 0.0001565625\n",
      " minutes, 920, 0.000410625\n",
      "., 14, 0.0310221875\n",
      " seriously, 2046, 0.000125625\n",
      " though, 724, 0.0005015625\n",
      ",, 12, 0.033795625\n",
      " this, 175, 0.010771875\n",
      " film, 203, 0.005760625\n",
      " is, 149, 0.0148553125\n",
      " living, 1834, 0.0001478125\n",
      " proof, 6655, 1.875e-05\n",
      " that, 174, 0.0095028125\n",
      " formula, 3827, 3.90625e-05\n",
      " works, 1631, 0.00016375\n",
      "., 14, 0.0310221875\n",
      " if, 346, 0.0017978125\n",
      " it, 155, 0.0135578125\n",
      " ain, 5943, 2.09375e-05\n",
      "'t, 226, 0.0044165625\n",
      " broke, 6348, 2.125e-05\n",
      ",, 12, 0.033795625\n",
      " it, 155, 0.0135578125\n",
      " don, 468, 0.0011278125\n",
      "'t, 226, 0.0044165625\n",
      " need, 801, 0.00021625\n",
      " fix, 5433, 1.46875e-05\n",
      "in, 113, 0.000506875\n",
      "., 14, 0.0310221875\n",
      " a, 111, 0.0222334375\n",
      " streetwise, 26475, 2.1875e-06\n",
      "-, 13, 0.0059025\n",
      "yet, 6583, 1.6875e-05\n",
      "-, 13, 0.0059025\n",
      "v, 60, 0.000108125\n",
      "ulner, 24353, 9.375e-07\n",
      "able, 376, 7.75e-05\n",
      " heroine, 4424, 3.03125e-05\n",
      ",, 12, 0.033795625\n",
      " a, 111, 0.0222334375\n",
      " hardened, 13745, 5.625e-06\n",
      " ex, 261, 7.4375e-05\n",
      "-, 13, 0.0059025\n",
      "cop, 7289, 1.15625e-05\n",
      " martial, 4032, 4.6875e-05\n",
      " arts, 4149, 4.75e-05\n",
      " master, 1456, 6.8125e-05\n",
      " with, 208, 0.0056378125\n",
      " a, 111, 0.0222334375\n",
      " heart, 1249, 0.000173125\n",
      " of, 141, 0.0196921875\n",
      " gold, 2082, 4.40625e-05\n",
      " and, 137, 0.0210584375\n",
      " a, 111, 0.0222334375\n",
      " serial, 3312, 6.28125e-05\n",
      " killer, 1542, 0.0001815625\n",
      " with, 208, 0.0056378125\n",
      " ', 527, 0.00098375\n",
      "iss, 1584, 1.75e-05\n",
      "ues, 1096, 1.21875e-05\n",
      "'., 2481, 8.34375e-05\n",
      " pure, 2397, 7.40625e-05\n",
      " magic, 2944, 6.65625e-05\n",
      "., 14, 0.0310221875\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n",
      "</s>, 32001, 0.0407528125\n"
     ]
    }
   ],
   "source": [
    "sample = dataset[6520]\n",
    "for i, tok in enumerate(sample['input_ids']):\n",
    "    print(f\"{id2tok[tok]}, {tok}, {normalized_token_counts[tok]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\freqcolor{24}{i} \\freqcolor{39}{ watched} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{18}{ for} \\freqcolor{66}{ 45} \\freqcolor{40}{ minutes} \\freqcolor{6}{ and} \\freqcolor{77}{ counted} \\freqcolor{54}{ 9} \\freqcolor{87}{ mull} \\freqcolor{67}{ets} \\freqcolor{3}{.} \\freqcolor{13}{ that} \\freqcolor{15}{'s} \\freqcolor{6}{ a} \\freqcolor{80}{ mullet} \\freqcolor{37}{ every} \\freqcolor{48}{ 5} \\freqcolor{40}{ minutes} \\freqcolor{3}{.} \\freqcolor{50}{ seriously} \\freqcolor{38}{ though} \\freqcolor{2}{,} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{9}{ is} \\freqcolor{48}{ living} \\freqcolor{66}{ proof} \\freqcolor{13}{ that} \\freqcolor{60}{ formula} \\freqcolor{48}{ works} \\freqcolor{3}{.} \\freqcolor{27}{ if} \\freqcolor{10}{ it} \\freqcolor{65}{ ain} \\freqcolor{20}{'t} \\freqcolor{65}{ broke} \\freqcolor{2}{,} \\freqcolor{10}{ it} \\freqcolor{31}{ don} \\freqcolor{20}{'t} \\freqcolor{45}{ need} \\freqcolor{68}{ fix} \\freqcolor{38}{in} \\freqcolor{3}{.} \\freqcolor{6}{ a} \\freqcolor{83}{ streetwise} \\freqcolor{17}{-} \\freqcolor{67}{yet} \\freqcolor{17}{-} \\freqcolor{51}{v} \\freqcolor{89}{ulner} \\freqcolor{54}{able} \\freqcolor{62}{ heroine} \\freqcolor{2}{,} \\freqcolor{6}{ a} \\freqcolor{76}{ hardened} \\freqcolor{54}{ ex} \\freqcolor{17}{-} \\freqcolor{70}{cop} \\freqcolor{58}{ martial} \\freqcolor{58}{ arts} \\freqcolor{55}{ master} \\freqcolor{18}{ with} \\freqcolor{6}{ a} \\freqcolor{47}{ heart} \\freqcolor{7}{ of} \\freqcolor{59}{ gold} \\freqcolor{6}{ and} \\freqcolor{6}{ a} \\freqcolor{56}{ serial} \\freqcolor{47}{ killer} \\freqcolor{18}{ with} \\freqcolor{32}{ '} \\freqcolor{66}{iss} \\freqcolor{69}{ues} \\freqcolor{53}{'.} \\freqcolor{54}{ pure} \\freqcolor{55}{ magic} \\freqcolor{3}{.} \\freqcolor{1}{</s>}\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    latex_output = []\n",
    "    \n",
    "    # Find the minimum non-zero frequency\n",
    "    min_freq = min(freq for freq in normalized_token_counts.values() if freq > 0)\n",
    "    # Find the maximum frequency\n",
    "    max_freq = max(normalized_token_counts.values())\n",
    "    \n",
    "    def scale_frequency(freq):\n",
    "        if freq == 0:\n",
    "            return 100  # Lowest frequency\n",
    "        # Use log scale, but shift by min_freq to handle very small values\n",
    "        log_freq = math.log(freq + min_freq)\n",
    "        log_min = math.log(min_freq)\n",
    "        log_max = math.log(max_freq + min_freq)\n",
    "        # Scale to 0-100 range and invert\n",
    "        return 100 - int(100 * (log_freq - log_min) / (log_max - log_min))\n",
    "\n",
    "    def should_add_space(token):\n",
    "        # Add space before tokens that don't start with '##' or other subword indicators\n",
    "        # Modify this condition based on your specific tokenizer's subword token format\n",
    "        return not token.startswith('##') and not token.startswith('▁')  # '▁' is used by some tokenizers like SentencePiece\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "        \n",
    "        # Use our scaling function\n",
    "        color_value = scale_frequency(frequency)\n",
    "        \n",
    "        # Escape special LaTeX characters\n",
    "        escaped_token = token.replace('\\\\', '\\\\textbackslash{}')\n",
    "        escaped_token = escaped_token.replace('&', '\\\\&').replace('%', '\\\\%')\n",
    "        escaped_token = escaped_token.replace('$', '\\\\$').replace('#', '\\\\#')\n",
    "        escaped_token = escaped_token.replace('_', '\\\\_').replace('{', '\\\\{').replace('}', '\\\\}')\n",
    "        escaped_token = escaped_token.replace('~', '\\\\textasciitilde{}')\n",
    "        escaped_token = escaped_token.replace('^', '\\\\textasciicircum{}')\n",
    "        \n",
    "        # Generate the colored LaTeX command\n",
    "        colored_token = f\"\\\\freqcolor{{{color_value}}}{{{escaped_token}}}\"\n",
    "        \n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        \n",
    "        latex_output.append(colored_token)\n",
    "        \n",
    "        # Check if the token is </s>, if so, break\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    # Join all tokens without additional spaces and return\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\freqcolor{24}{i} \\freqcolor{39}{ watched} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{18}{ for} \\freqcolor{66}{ 45} \\freqcolor{40}{ minutes} \\freqcolor{6}{ and} \\freqcolor{77}{ counted} \\freqcolor{54}{ 9} \\freqcolor{87}{ mull} \\freqcolor{67}{ets} \\freqcolor{3}{.} \\freqcolor{13}{ that} \\freqcolor{15}{'s} \\freqcolor{6}{ a} \\freqcolor{80}{ mullet} \\freqcolor{37}{ every} \\freqcolor{48}{ 5} \\freqcolor{40}{ minutes} \\freqcolor{3}{.} \\freqcolor{50}{ seriously} \\freqcolor{38}{ though} \\freqcolor{2}{,} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{9}{ is} \\freqcolor{48}{ living} \\freqcolor{66}{ proof} \\freqcolor{13}{ that} \\freqcolor{60}{ formula} \\freqcolor{48}{ works} \\freqcolor{3}{.} \\freqcolor{27}{ if} \\freqcolor{10}{ it} \\freqcolor{65}{ ain} \\freqcolor{20}{'t} \\freqcolor{65}{ broke} \\freqcolor{2}{,} \\freqcolor{10}{ it} \\freqcolor{31}{ don} \\freqcolor{20}{'t} \\freqcolor{45}{ need} \\freqcolor{68}{ fix} \\freqcolor{38}{in} \\freqcolor{3}{.} \\freqcolor{6}{ a} \\freqcolor{83}{ streetwise} \\freqcolor{17}{-} \\freqcolor{67}{yet} \\freqcolor{17}{-} \\freqcolor{51}{v} \\freqcolor{89}{ulner} \\freqcolor{54}{able} \\freqcolor{62}{ heroine} \\freqcolor{2}{,} \\freqcolor{6}{ a} \\freqcolor{76}{ hardened} \\freqcolor{54}{ ex} \\freqcolor{17}{-} \\freqcolor{70}{cop} \\freqcolor{58}{ martial} \\freqcolor{58}{ arts} \\freqcolor{55}{ master} \\freqcolor{18}{ with} \\freqcolor{6}{ a} \\freqcolor{47}{ heart} \\freqcolor{7}{ of} \\freqcolor{59}{ gold} \\freqcolor{6}{ and} \\freqcolor{6}{ a} \\freqcolor{56}{ serial} \\freqcolor{47}{ killer} \\freqcolor{18}{ with} \\freqcolor{32}{ '} \\freqcolor{66}{iss} \\freqcolor{69}{ues} \\freqcolor{53}{'.} \\freqcolor{54}{ pure} \\freqcolor{55}{ magic} \\freqcolor{3}{.} \\freqcolor{1}{</s>}\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    latex_output = []\n",
    "    \n",
    "    # Find the minimum non-zero frequency\n",
    "    min_freq = min(freq for freq in normalized_token_counts.values() if freq > 0)\n",
    "    # Find the maximum frequency\n",
    "    max_freq = max(normalized_token_counts.values())\n",
    "    \n",
    "    def scale_frequency(freq):\n",
    "        if freq == 0:\n",
    "            return 100  # Lowest frequency\n",
    "        # Use log scale, but shift by min_freq to handle very small values\n",
    "        log_freq = math.log(freq + min_freq)\n",
    "        log_min = math.log(min_freq)\n",
    "        log_max = math.log(max_freq + min_freq)\n",
    "        # Scale to 0-100 range and invert\n",
    "        return 100 - int(100 * (log_freq - log_min) / (log_max - log_min))\n",
    "\n",
    "    def is_start_of_word(token):\n",
    "        # Modify this based on your tokenizer's conventions\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "        \n",
    "        # Use our scaling function\n",
    "        color_value = scale_frequency(frequency)\n",
    "        \n",
    "        # Escape special LaTeX characters\n",
    "        escaped_token = token.replace('\\\\', '\\\\textbackslash{}')\n",
    "        escaped_token = escaped_token.replace('&', '\\\\&').replace('%', '\\\\%')\n",
    "        escaped_token = escaped_token.replace('$', '\\\\$').replace('#', '\\\\#')\n",
    "        escaped_token = escaped_token.replace('_', '\\\\_').replace('{', '\\\\{').replace('}', '\\\\}')\n",
    "        escaped_token = escaped_token.replace('~', '\\\\textasciitilde{}')\n",
    "        escaped_token = escaped_token.replace('^', '\\\\textasciicircum{}')\n",
    "        \n",
    "        # Generate the colored LaTeX command\n",
    "        colored_token = f\"\\\\freqcolor{{{color_value}}}{{{escaped_token}}}\"\n",
    "        \n",
    "        # Add space before \\freqcolor if it's the start of a new word (except for the first token)\n",
    "        if i > 0 and is_start_of_word(token):\n",
    "            latex_output.append(' ')\n",
    "        \n",
    "        latex_output.append(colored_token)\n",
    "        \n",
    "        # Check if the token is </s>, if so, break\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    # Join all tokens and return\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\freqcolor{24}{i} \\freqcolor{39}{ watched} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{18}{ for} \\freqcolor{66}{ 45} \\freqcolor{40}{ minutes} \\freqcolor{6}{ and} \\freqcolor{77}{ counted} \\freqcolor{54}{ 9} \\freqcolor{87}{ mull} \\freqcolor{67}{ets} \\freqcolor{3}{.} \\freqcolor{13}{ that} \\freqcolor{15}{'s} \\freqcolor{6}{ a} \\freqcolor{80}{ mullet} \\freqcolor{37}{ every} \\freqcolor{48}{ 5} \\freqcolor{40}{ minutes} \\freqcolor{3}{.} \\freqcolor{50}{ seriously} \\freqcolor{38}{ though} \\freqcolor{2}{,} \\freqcolor{12}{ this} \\freqcolor{17}{ film} \\freqcolor{9}{ is} \\freqcolor{48}{ living} \\freqcolor{66}{ proof} \\freqcolor{13}{ that} \\freqcolor{60}{ formula} \\freqcolor{48}{ works} \\freqcolor{3}{.} \\freqcolor{27}{ if} \\freqcolor{10}{ it} \\freqcolor{65}{ ain} \\freqcolor{20}{'t} \\freqcolor{65}{ broke} \\freqcolor{2}{,} \\freqcolor{10}{ it} \\freqcolor{31}{ don} \\freqcolor{20}{'t} \\freqcolor{45}{ need} \\freqcolor{68}{ fix} \\freqcolor{38}{in} \\freqcolor{3}{.} \\freqcolor{6}{ a} \\freqcolor{83}{ streetwise} \\freqcolor{17}{-} \\freqcolor{67}{yet} \\freqcolor{17}{-} \\freqcolor{51}{v} \\freqcolor{89}{ulner} \\freqcolor{54}{able} \\freqcolor{62}{ heroine} \\freqcolor{2}{,} \\freqcolor{6}{ a} \\freqcolor{76}{ hardened} \\freqcolor{54}{ ex} \\freqcolor{17}{-} \\freqcolor{70}{cop} \\freqcolor{58}{ martial} \\freqcolor{58}{ arts} \\freqcolor{55}{ master} \\freqcolor{18}{ with} \\freqcolor{6}{ a} \\freqcolor{47}{ heart} \\freqcolor{7}{ of} \\freqcolor{59}{ gold} \\freqcolor{6}{ and} \\freqcolor{6}{ a} \\freqcolor{56}{ serial} \\freqcolor{47}{ killer} \\freqcolor{18}{ with} \\freqcolor{32}{ '} \\freqcolor{66}{iss} \\freqcolor{69}{ues} \\freqcolor{53}{'.} \\freqcolor{54}{ pure} \\freqcolor{55}{ magic} \\freqcolor{3}{.} \\freqcolor{1}{</s>}\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    latex_output = []\n",
    "    \n",
    "    # Find the minimum non-zero frequency\n",
    "    min_freq = min(freq for freq in normalized_token_counts.values() if freq > 0)\n",
    "    # Find the maximum frequency\n",
    "    max_freq = max(normalized_token_counts.values())\n",
    "    \n",
    "    def scale_frequency(freq):\n",
    "        if freq == 0:\n",
    "            return 100  # Lowest frequency\n",
    "        # Use log scale, but shift by min_freq to handle very small values\n",
    "        log_freq = math.log(freq + min_freq)\n",
    "        log_min = math.log(min_freq)\n",
    "        log_max = math.log(max_freq + min_freq)\n",
    "        # Scale to 0-100 range and invert\n",
    "        return 100 - int(100 * (log_freq - log_min) / (log_max - log_min))\n",
    "\n",
    "    def should_add_space(token):\n",
    "        # Add space before tokens that don't start with '##' or other subword indicators\n",
    "        # Modify this condition based on your specific tokenizer's subword token format\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "        \n",
    "        # Use our scaling function\n",
    "        color_value = scale_frequency(frequency)\n",
    "        \n",
    "        # Escape special LaTeX characters\n",
    "        escaped_token = token.replace('\\\\', '\\\\textbackslash{}')\n",
    "        escaped_token = escaped_token.replace('&', '\\\\&').replace('%', '\\\\%')\n",
    "        escaped_token = escaped_token.replace('$', '\\\\$').replace('#', '\\\\#')\n",
    "        escaped_token = escaped_token.replace('_', '\\\\_').replace('{', '\\\\{').replace('}', '\\\\}')\n",
    "        escaped_token = escaped_token.replace('~', '\\\\textasciitilde{}')\n",
    "        escaped_token = escaped_token.replace('^', '\\\\textasciicircum{}')\n",
    "        \n",
    "        # Generate the colored LaTeX command\n",
    "        colored_token = f\"\\\\freqcolor{{{color_value}}}{{{escaped_token}}}\"\n",
    "        \n",
    "        # Add space before \\freqcolor if needed\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        \n",
    "        latex_output.append(colored_token)\n",
    "        \n",
    "        # Check if the token is </s>, if so, break\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    # Join all tokens and return\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\freqcolor{92}{i} \\freqcolor{86}{ watched} \\freqcolor{96}{ this} \\freqcolor{94}{ film} \\freqcolor{94}{ for} \\freqcolor{72}{ 45} \\freqcolor{86}{ minutes} \\freqcolor{98}{ and} \\freqcolor{64}{ counted} \\freqcolor{79}{ 9} \\freqcolor{52}{ mull} \\freqcolor{71}{ets} \\freqcolor{99}{.} \\freqcolor{96}{ that} \\freqcolor{95}{'s} \\freqcolor{98}{ a} \\freqcolor{61}{ mullet} \\freqcolor{87}{ every} \\freqcolor{82}{ 5} \\freqcolor{86}{ minutes} \\freqcolor{99}{.} \\freqcolor{81}{ seriously} \\freqcolor{86}{ though} \\freqcolor{99}{,} \\freqcolor{96}{ this} \\freqcolor{94}{ film} \\freqcolor{97}{ is} \\freqcolor{82}{ living} \\freqcolor{72}{ proof} \\freqcolor{96}{ that} \\freqcolor{76}{ formula} \\freqcolor{82}{ works} \\freqcolor{99}{.} \\freqcolor{91}{ if} \\freqcolor{96}{ it} \\freqcolor{73}{ ain} \\freqcolor{93}{'t} \\freqcolor{73}{ broke} \\freqcolor{99}{,} \\freqcolor{96}{ it} \\freqcolor{89}{ don} \\freqcolor{93}{'t} \\freqcolor{83}{ need} \\freqcolor{71}{ fix} \\freqcolor{86}{in} \\freqcolor{99}{.} \\freqcolor{98}{ a} \\freqcolor{58}{ streetwise} \\freqcolor{94}{-} \\freqcolor{72}{yet} \\freqcolor{94}{-} \\freqcolor{80}{v} \\freqcolor{49}{ulner} \\freqcolor{79}{able} \\freqcolor{75}{ heroine} \\freqcolor{99}{,} \\freqcolor{98}{ a} \\freqcolor{65}{ hardened} \\freqcolor{79}{ ex} \\freqcolor{94}{-} \\freqcolor{70}{cop} \\freqcolor{77}{ martial} \\freqcolor{77}{ arts} \\freqcolor{78}{ master} \\freqcolor{94}{ with} \\freqcolor{98}{ a} \\freqcolor{82}{ heart} \\freqcolor{97}{ of} \\freqcolor{77}{ gold} \\freqcolor{98}{ and} \\freqcolor{98}{ a} \\freqcolor{78}{ serial} \\freqcolor{83}{ killer} \\freqcolor{94}{ with} \\freqcolor{89}{ '} \\freqcolor{72}{iss} \\freqcolor{70}{ues} \\freqcolor{79}{'.} \\freqcolor{79}{ pure} \\freqcolor{78}{ magic} \\freqcolor{99}{.} \\freqcolor{99}{</s>}\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    latex_output = []\n",
    "    \n",
    "    # Find the minimum non-zero frequency\n",
    "    min_freq = min(freq for freq in normalized_token_counts.values() if freq > 0)\n",
    "    # Find the maximum frequency\n",
    "    max_freq = max(normalized_token_counts.values())\n",
    "    \n",
    "    def emphasis_rare_tokens_scale(freq):\n",
    "        if freq == 0:\n",
    "            return 100  # Highest blue for zero frequency\n",
    "        \n",
    "        # Take the log of frequency\n",
    "        log_freq = math.log(freq)\n",
    "        log_min = math.log(min_freq)\n",
    "        log_max = math.log(max_freq)\n",
    "        \n",
    "        # Normalize log frequency to 0-1 range\n",
    "        normalized = (log_freq - log_min) / (log_max - log_min)\n",
    "        \n",
    "        # Apply exponential scaling to emphasize rare tokens\n",
    "        scaled = normalized ** 0.3  # Adjust this exponent to control the emphasis\n",
    "        \n",
    "        # Scale to 0-100 range\n",
    "        return int(scaled * 100)\n",
    "\n",
    "    def is_start_of_word(token):\n",
    "        # Modify this based on your tokenizer's conventions\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "        \n",
    "        # Use our new scaling function\n",
    "        color_value = emphasis_rare_tokens_scale(frequency)\n",
    "        \n",
    "        # Escape special LaTeX characters\n",
    "        escaped_token = token.replace('\\\\', '\\\\textbackslash{}')\n",
    "        escaped_token = escaped_token.replace('&', '\\\\&').replace('%', '\\\\%')\n",
    "        escaped_token = escaped_token.replace('$', '\\\\$').replace('#', '\\\\#')\n",
    "        escaped_token = escaped_token.replace('_', '\\\\_').replace('{', '\\\\{').replace('}', '\\\\}')\n",
    "        escaped_token = escaped_token.replace('~', '\\\\textasciitilde{}')\n",
    "        escaped_token = escaped_token.replace('^', '\\\\textasciicircum{}')\n",
    "        \n",
    "        # Generate the colored LaTeX command\n",
    "        colored_token = f\"\\\\freqcolor{{{color_value}}}{{{escaped_token}}}\"\n",
    "        \n",
    "        # Add space before \\freqcolor if it's the start of a new word (except for the first token)\n",
    "        if i > 0 and is_start_of_word(token):\n",
    "            latex_output.append(' ')\n",
    "        \n",
    "        latex_output.append(colored_token)\n",
    "        \n",
    "        # Check if the token is </s>, if so, break\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    # Join all tokens and return\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-wd-fairness-eaiv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
