{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import math\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "import matplotlib.colors as mcolors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import llm_fairness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import matplotlib.colors as mcolors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fetch the IMDb dataset and train a BPE tokenizer on it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The \"imdb-large\" variant is not used: the next cell reassigns `dataset` from the\n",
    "# standard \"imdb\" train split, so loading it here only cost time on every re-run.\n",
    "# Uncomment to inspect it without clobbering `dataset`:\n",
    "# imdb_large = llm_fairness.data.from_name(\"imdb-large\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 32000  # target vocabulary size for the BPE tokenizer\n",
    "MAX_LENGTH = 128  # every example is truncated / padded to this many tokens\n",
    "\n",
    "dataset = llm_fairness.data.from_name(\"imdb\")[\"train\"]\n",
    "tokenizer = llm_fairness.tokenizer.from_data(dataset, variant=\"BPE\", vocab_size=VOCAB_SIZE)\n",
    "dataset = dataset.map(\n",
    "    lambda x: tokenizer(\n",
    "        x[\"text\"],\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=MAX_LENGTH,\n",
    "    ),\n",
    "    batched=True,\n",
    "    remove_columns=dataset.column_names,\n",
    ")\n",
    "\n",
    "# Decode every vocabulary id exactly once. Distinct ids may decode to the same\n",
    "# string, in which case the string-keyed `tok2id` is smaller than `id2tok`, so the\n",
    "# vocabulary size is reported from the id-keyed mapping.\n",
    "id2tok = {tid: tokenizer.decode(tid) for tid in tokenizer._tokenizer.get_vocab().values()}\n",
    "tok2id = {tok: tid for tid, tok in id2tok.items()}\n",
    "\n",
    "# NOTE(review): with padding=\"max_length\" the pad token is counted below like any\n",
    "# other token, which inflates its frequency -- confirm this is intended.\n",
    "tokens = []\n",
    "for seq in dataset:\n",
    "    tokens.extend(seq['input_ids'])\n",
    "token_counts = collections.Counter(tokens)\n",
    "unique_tokens = set(token_counts)\n",
    "\n",
    "print(f\"we trained a bpe with target vocab size: {VOCAB_SIZE}\")\n",
    "print(f\"our actual vocab size is {len(id2tok)}\")\n",
    "print(f\"{len(tok2id)} distinct decoded token strings\")\n",
    "print(f\"we have {len(unique_tokens)} unique tokens\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
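    "# Compute the cumulative token frequency\n",
    "\n",
    "Token counts are sorted in ascending order, so the thresholds below measure how many of the *rarest* tokens are needed to account for 50 %, 80 % and 95 % of all token occurrences."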
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counts are sorted ascending, so the cumulative sum accumulates the *rarest*\n",
    "# tokens first: each threshold is how many of the rarest tokens are needed to\n",
    "# cover that share of all token occurrences.\n",
    "sorted_values = np.array(sorted(token_counts.values()))\n",
    "cumulative_freq = np.cumsum(sorted_values)\n",
    "total_freq = cumulative_freq[-1]\n",
    "\n",
    "# First index whose cumulative count reaches each share of the total.\n",
    "threshold_50_index = np.argmax(cumulative_freq >= 0.50 * total_freq)\n",
    "threshold_80_index = np.argmax(cumulative_freq >= 0.80 * total_freq)\n",
    "threshold_95_index = np.argmax(cumulative_freq >= 0.95 * total_freq)\n",
    "\n",
    "# Index i means the rarest i + 1 tokens are needed, hence the + 1.\n",
    "n_unique = len(sorted_values)\n",
    "threshold_50_percentile = (threshold_50_index + 1) / n_unique * 100\n",
    "threshold_80_percentile = (threshold_80_index + 1) / n_unique * 100\n",
    "threshold_95_percentile = (threshold_95_index + 1) / n_unique * 100\n",
    "\n",
    "pd.DataFrame({\n",
    "    \"share_of_occurrences_pct\": [50, 80, 95],\n",
    "    \"rarest_tokens_needed\": [threshold_50_index + 1, threshold_80_index + 1, threshold_95_index + 1],\n",
    "    \"share_of_vocab_pct\": [threshold_50_percentile, threshold_80_percentile, threshold_95_percentile],\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
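    "# Compute the normalized token counts\n",
    "\n",
    "Each token's count is divided by the total number of token occurrences, giving its relative frequency."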
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative frequency of each token id: its count divided by all token occurrences.\n",
    "total_tokens = sum(token_counts.values())\n",
    "normalized_token_counts = {\n",
    "    tid: occurrences / total_tokens\n",
    "    for tid, occurrences in token_counts.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank token ids by relative frequency, most frequent first (stable for ties).\n",
    "sorted_counts = sorted(\n",
    "    normalized_token_counts.items(),\n",
    "    key=lambda pair: pair[1],\n",
    "    reverse=True,\n",
    ")\n",
    "top_10 = sorted_counts[:10]\n",
    "\n",
    "for header_line in (\n",
    "    \"Top 10 most frequent tokens:\",\n",
    "    \"Token ID | Frequency | Decoded Value\",\n",
    "    \"---------|-----------|--------------\",\n",
    "):\n",
    "    print(header_line)\n",
    "for token_id, freq in top_10:\n",
    "    print(f\"{token_id:8d} | {freq:.6f} | {id2tok[token_id]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the frequency ranking here rather than reusing the global `sorted_counts`\n",
    "# from another cell, so this cell stays correct even if that name gets rebound.\n",
    "# Same descending stable sort as the top-10 cell, so ties resolve identically.\n",
    "ranked_by_freq = sorted(normalized_token_counts.items(), key=lambda x: x[1], reverse=True)\n",
    "bottom_10 = ranked_by_freq[-10:][::-1]\n",
    "print(\"Bottom 10 least frequent tokens:\")\n",
    "print(\"Token ID | Frequency | Decoded Value\")\n",
    "print(\"---------|-----------|--------------\")\n",
    "for token_id, freq in bottom_10:\n",
    "    decoded_value = id2tok[token_id]\n",
    "    print(f\"{token_id:8d} | {freq:.6e} | {decoded_value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inspect a single sample (training review 6520)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show each token of one review: decoded text, token id, relative frequency.\n",
    "sample = dataset[6520]\n",
    "for token_id in sample[\"input_ids\"]:\n",
    "    decoded = id2tok[token_id]\n",
    "    print(f\"{decoded}, {token_id}, {normalized_token_counts[token_id]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    \"\"\"Render a tokenized sample as LaTeX, shading each token by corpus frequency.\n",
    "\n",
    "    Each token is wrapped in the ``freqcolor`` LaTeX macro (which must be defined\n",
    "    in the document preamble) with an integer shade in 0-100. The frequency is\n",
    "    log-scaled and inverted, so rarer tokens get larger values.\n",
    "\n",
    "    Args:\n",
    "        sample: dataset row whose 'input_ids' holds the token ids to render.\n",
    "        id2tok: mapping token id -> decoded token string.\n",
    "        normalized_token_counts: mapping token id -> relative corpus frequency.\n",
    "\n",
    "    Returns:\n",
    "        The concatenated LaTeX string, stopping after the first '</s>' token.\n",
    "    \"\"\"\n",
    "    latex_output = []\n",
    "\n",
    "    # Smallest non-zero frequency (used as a log offset) and the largest frequency.\n",
    "    min_freq = min(freq for freq in normalized_token_counts.values() if freq > 0)\n",
    "    max_freq = max(normalized_token_counts.values())\n",
    "    # Ends of the log scale; loop-invariant, so computed once rather than per token.\n",
    "    log_min = math.log(min_freq)\n",
    "    log_max = math.log(max_freq + min_freq)\n",
    "\n",
    "    def scale_frequency(freq):\n",
    "        if freq == 0:\n",
    "            return 100  # Lowest frequency\n",
    "        # Use log scale, but shift by min_freq to handle very small values\n",
    "        log_freq = math.log(freq + min_freq)\n",
    "        # Scale to 0-100 range and invert\n",
    "        return 100 - int(100 * (log_freq - log_min) / (log_max - log_min))\n",
    "\n",
    "    # Escape LaTeX special characters in a single pass. Chained str.replace calls\n",
    "    # would also escape the '{}' that the backslash replacement inserts, so a\n",
    "    # backslash in a token rendered as a backslash followed by literal braces.\n",
    "    latex_escapes = str.maketrans({\n",
    "        '\\\\': '\\\\textbackslash{}',\n",
    "        '&': '\\\\&',\n",
    "        '%': '\\\\%',\n",
    "        '$': '\\\\$',\n",
    "        '#': '\\\\#',\n",
    "        '_': '\\\\_',\n",
    "        '{': '\\\\{',\n",
    "        '}': '\\\\}',\n",
    "        '~': '\\\\textasciitilde{}',\n",
    "        '^': '\\\\textasciicircum{}',\n",
    "    })\n",
    "\n",
    "    def should_add_space(token):\n",
    "        # Add space before tokens that don't start with '##' or other subword indicators.\n",
    "        # NOTE(review): '▁' (used by some tokenizers like SentencePiece) usually marks the\n",
    "        # *start* of a word rather than a continuation, and strings produced by\n",
    "        # tokenizer.decode may carry neither marker -- confirm this rule for this tokenizer.\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        color_value = scale_frequency(normalized_token_counts[token_id])\n",
    "        escaped_token = token.translate(latex_escapes)\n",
    "\n",
    "        # Generate the colored LaTeX command\n",
    "        colored_token = f\"\\\\freqcolor{{{color_value}}}{{{escaped_token}}}\"\n",
    "\n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        latex_output.append(colored_token)\n",
    "\n",
    "        # Stop after the '</s>' token\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    \"\"\"Render a tokenized sample as LaTeX, colouring tokens by median frequency.\n",
    "\n",
    "    Tokens whose relative frequency is below the median over all observed token\n",
    "    ids are wrapped in textcolor{blue}; all others in textcolor{red}.\n",
    "\n",
    "    Args:\n",
    "        sample: dataset row whose 'input_ids' holds the token ids to render.\n",
    "        id2tok: mapping token id -> decoded token string.\n",
    "        normalized_token_counts: mapping token id -> relative corpus frequency.\n",
    "\n",
    "    Returns:\n",
    "        The concatenated LaTeX string, stopping after the first '</s>' token.\n",
    "    \"\"\"\n",
    "    latex_output = []\n",
    "\n",
    "    # Median over distinct token ids (each id counts once, regardless of its count).\n",
    "    median_freq = np.median(list(normalized_token_counts.values()))\n",
    "\n",
    "    # Escape LaTeX special characters in a single pass; chained str.replace calls\n",
    "    # would also escape the '{}' inserted by the backslash replacement.\n",
    "    latex_escapes = str.maketrans({\n",
    "        '\\\\': '\\\\textbackslash{}',\n",
    "        '&': '\\\\&',\n",
    "        '%': '\\\\%',\n",
    "        '$': '\\\\$',\n",
    "        '#': '\\\\#',\n",
    "        '_': '\\\\_',\n",
    "        '{': '\\\\{',\n",
    "        '}': '\\\\}',\n",
    "        '~': '\\\\textasciitilde{}',\n",
    "        '^': '\\\\textasciicircum{}',\n",
    "    })\n",
    "\n",
    "    def should_add_space(token):\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "\n",
    "        # Determine color based on median frequency\n",
    "        color = \"blue\" if frequency < median_freq else \"red\"\n",
    "        escaped_token = token.translate(latex_escapes)\n",
    "        colored_token = f\"\\\\textcolor{{{color}}}{{{escaped_token}}}\"\n",
    "\n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        latex_output.append(colored_token)\n",
    "\n",
    "        # Stop after the '</s>' token\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_colored_latex(sample, id2tok, normalized_token_counts):\n",
    "    \"\"\"Median-frequency LaTeX colouring of a tokenized sample, with debug output.\n",
    "\n",
    "    Tokens whose relative frequency is below the median over all observed token\n",
    "    ids are wrapped in textcolor{blue}; all others in textcolor{red}. Prints the\n",
    "    median, one line per rendered token, and the final blue/red tallies.\n",
    "\n",
    "    Args:\n",
    "        sample: dataset row whose 'input_ids' holds the token ids to render.\n",
    "        id2tok: mapping token id -> decoded token string.\n",
    "        normalized_token_counts: mapping token id -> relative corpus frequency.\n",
    "\n",
    "    Returns:\n",
    "        The concatenated LaTeX string, stopping after the first '</s>' token.\n",
    "    \"\"\"\n",
    "    latex_output = []\n",
    "\n",
    "    # Calculate the median token occurrence\n",
    "    median_freq = np.median(list(normalized_token_counts.values()))\n",
    "    print(f\"Median frequency: {median_freq}\")\n",
    "\n",
    "    # Escape LaTeX special characters in a single pass; chained str.replace calls\n",
    "    # would also escape the '{}' inserted by the backslash replacement.\n",
    "    latex_escapes = str.maketrans({\n",
    "        '\\\\': '\\\\textbackslash{}',\n",
    "        '&': '\\\\&',\n",
    "        '%': '\\\\%',\n",
    "        '$': '\\\\$',\n",
    "        '#': '\\\\#',\n",
    "        '_': '\\\\_',\n",
    "        '{': '\\\\{',\n",
    "        '}': '\\\\}',\n",
    "        '~': '\\\\textasciitilde{}',\n",
    "        '^': '\\\\textasciicircum{}',\n",
    "    })\n",
    "\n",
    "    def should_add_space(token):\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    blue_count = 0\n",
    "    red_count = 0\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        frequency = normalized_token_counts[token_id]\n",
    "\n",
    "        # Determine color based on median frequency\n",
    "        if frequency < median_freq:\n",
    "            color = \"blue\"\n",
    "            blue_count += 1\n",
    "        else:\n",
    "            color = \"red\"\n",
    "            red_count += 1\n",
    "\n",
    "        print(f\"Token: {token}, Frequency: {frequency}, Color: {color}\")\n",
    "\n",
    "        escaped_token = token.translate(latex_escapes)\n",
    "        colored_token = f\"\\\\textcolor{{{color}}}{{{escaped_token}}}\"\n",
    "\n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        latex_output.append(colored_token)\n",
    "\n",
    "        # Stop after the '</s>' token\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    print(f\"Blue tokens: {blue_count}, Red tokens: {red_count}\")\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, normalized_token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_colored_latex(sample, id2tok, token_counts):\n",
    "    \"\"\"Render a tokenized sample as LaTeX, colouring tokens by median raw count.\n",
    "\n",
    "    Tokens whose raw count is below the median over all observed token ids are\n",
    "    wrapped in textcolor{blue}; all others in textcolor{red}. Per-token debug\n",
    "    lines and the final blue/red tallies are printed.\n",
    "\n",
    "    Args:\n",
    "        sample: dataset row whose 'input_ids' holds the token ids to render.\n",
    "        id2tok: mapping token id -> decoded token string.\n",
    "        token_counts: mapping token id -> raw occurrence count.\n",
    "\n",
    "    Returns:\n",
    "        The concatenated LaTeX string, stopping after the first '</s>' token.\n",
    "    \"\"\"\n",
    "    latex_output = []\n",
    "\n",
    "    # Calculate the median token occurrence from raw counts\n",
    "    median_count = np.median(list(token_counts.values()))\n",
    "    print(f\"Median count: {median_count}\")\n",
    "\n",
    "    # Escape LaTeX special characters in a single pass; chained str.replace calls\n",
    "    # would also escape the '{}' inserted by the backslash replacement.\n",
    "    latex_escapes = str.maketrans({\n",
    "        '\\\\': '\\\\textbackslash{}',\n",
    "        '&': '\\\\&',\n",
    "        '%': '\\\\%',\n",
    "        '$': '\\\\$',\n",
    "        '#': '\\\\#',\n",
    "        '_': '\\\\_',\n",
    "        '{': '\\\\{',\n",
    "        '}': '\\\\}',\n",
    "        '~': '\\\\textasciitilde{}',\n",
    "        '^': '\\\\textasciicircum{}',\n",
    "    })\n",
    "\n",
    "    def should_add_space(token):\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    blue_count = 0\n",
    "    red_count = 0\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        count = token_counts[token_id]\n",
    "\n",
    "        # Determine color based on median count\n",
    "        if count < median_count:\n",
    "            color = \"blue\"\n",
    "            blue_count += 1\n",
    "        else:\n",
    "            color = \"red\"\n",
    "            red_count += 1\n",
    "\n",
    "        print(f\"Token: {token}, Count: {count}, Color: {color}\")\n",
    "\n",
    "        escaped_token = token.translate(latex_escapes)\n",
    "        colored_token = f\"\\\\textcolor{{{color}}}{{{escaped_token}}}\"\n",
    "\n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        latex_output.append(colored_token)\n",
    "\n",
    "        # Stop after the '</s>' token\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    print(f\"Blue tokens: {blue_count}, Red tokens: {red_count}\")\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_colored_latex(sample, id2tok, token_counts, percentile=99):\n",
    "    \"\"\"Render a tokenized sample as LaTeX, highlighting the most frequent tokens.\n",
    "\n",
    "    Tokens whose raw count is at or above the given percentile of per-token counts\n",
    "    are wrapped in textcolor{customorange}; all others in textcolor{customblue}.\n",
    "    Both colours must be defined in the LaTeX preamble. Per-token debug lines and\n",
    "    the final blue/orange tallies are printed.\n",
    "\n",
    "    Args:\n",
    "        sample: dataset row whose 'input_ids' holds the token ids to render.\n",
    "        id2tok: mapping token id -> decoded token string.\n",
    "        token_counts: mapping token id -> raw occurrence count.\n",
    "        percentile: threshold percentile in [0, 100] (default 99).\n",
    "\n",
    "    Returns:\n",
    "        The concatenated LaTeX string, stopping after the first '</s>' token.\n",
    "    \"\"\"\n",
    "    latex_output = []\n",
    "\n",
    "    # Threshold count at the requested percentile of raw per-token counts\n",
    "    threshold_count = np.percentile(list(token_counts.values()), percentile)\n",
    "    print(f\"{percentile}th percentile count: {threshold_count}\")\n",
    "\n",
    "    # Escape LaTeX special characters in a single pass; chained str.replace calls\n",
    "    # would also escape the '{}' inserted by the backslash replacement.\n",
    "    latex_escapes = str.maketrans({\n",
    "        '\\\\': '\\\\textbackslash{}',\n",
    "        '&': '\\\\&',\n",
    "        '%': '\\\\%',\n",
    "        '$': '\\\\$',\n",
    "        '#': '\\\\#',\n",
    "        '_': '\\\\_',\n",
    "        '{': '\\\\{',\n",
    "        '}': '\\\\}',\n",
    "        '~': '\\\\textasciitilde{}',\n",
    "        '^': '\\\\textasciicircum{}',\n",
    "    })\n",
    "\n",
    "    def should_add_space(token):\n",
    "        return not token.startswith('##') and not token.startswith('▁')\n",
    "\n",
    "    blue_count = 0\n",
    "    orange_count = 0\n",
    "\n",
    "    for i, token_id in enumerate(sample['input_ids']):\n",
    "        token = id2tok[token_id]\n",
    "        count = token_counts[token_id]\n",
    "\n",
    "        # Determine color based on the percentile threshold\n",
    "        if count < threshold_count:\n",
    "            color = \"customblue\"\n",
    "            blue_count += 1\n",
    "        else:\n",
    "            color = \"customorange\"\n",
    "            orange_count += 1\n",
    "\n",
    "        print(f\"Token: {token}, Count: {count}, Color: {color}\")\n",
    "\n",
    "        escaped_token = token.translate(latex_escapes)\n",
    "        colored_token = f\"\\\\textcolor{{{color}}}{{{escaped_token}}}\"\n",
    "\n",
    "        # Add space before token if it's not a subword (except for the first token)\n",
    "        if i > 0 and should_add_space(token):\n",
    "            latex_output.append(' ')\n",
    "        latex_output.append(colored_token)\n",
    "\n",
    "        # Stop after the '</s>' token\n",
    "        if token == '</s>':\n",
    "            break\n",
    "\n",
    "    print(f\"Blue tokens: {blue_count}, Orange tokens: {orange_count}\")\n",
    "    return ''.join(latex_output)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "sample = dataset[6520]\n",
    "latex_content = generate_colored_latex(sample, id2tok, token_counts)\n",
    "print(latex_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tab10 blue and orange -- presumably the source of the `customblue` /\n",
    "# `customorange` colours referenced by the LaTeX cells; confirm.\n",
    "tab10 = plt.get_cmap(\"tab10\")\n",
    "blue_rgb = tab10(0)  # Blue from tab10 colormap\n",
    "orange_rgb = tab10(1)  # Orange from tab10 colormap\n",
    "\n",
    "# Convert to 0-255 integers. round() instead of int(): int() truncates, so a\n",
    "# channel like 30.999999 would come out one step too low.\n",
    "blue_rgb_255 = tuple(round(c * 255) for c in blue_rgb[:3])\n",
    "orange_rgb_255 = tuple(round(c * 255) for c in orange_rgb[:3])\n",
    "\n",
    "print(\"Blue RGB:\", blue_rgb_255)\n",
    "print(\"Orange RGB:\", orange_rgb_255)\n",
    "\n",
    "# Ready-to-paste LaTeX preamble lines defining both colours.\n",
    "for name, (r, g, b) in ((\"customblue\", blue_rgb_255), (\"customorange\", orange_rgb_255)):\n",
    "    print(f\"\\\\definecolor{{{name}}}{{RGB}}{{{r},{g},{b}}}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the data\n",
    "counts = list(token_counts.values())\n",
    "median_count = np.median(counts)\n",
    "\n",
    "# Token counts likely span several orders of magnitude, so 100 equal-width bins\n",
    "# would lump almost every token into the first bin; use log-spaced bin edges and\n",
    "# a log x-axis instead (counts are >= 1 because they come from observed tokens).\n",
    "bins = np.logspace(np.log10(min(counts)), np.log10(max(counts)), 101)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.hist(counts, bins=bins, log=True)\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel('Token Count (log scale)')\n",
    "ax.set_ylabel('Number of tokens (log scale)')\n",
    "ax.set_title('Distribution of Token Counts')\n",
    "ax.axvline(x=median_count, color='r', linestyle='dashed', linewidth=2, label=f'Median ({median_count:.2f})')\n",
    "ax.legend()\n",
    "plt.show()\n",
    "\n",
    "# Print some statistics\n",
    "print(f\"Total number of unique tokens: {len(counts)}\")\n",
    "print(f\"Median count: {median_count}\")\n",
    "print(f\"Mean count: {np.mean(counts):.2f}\")\n",
    "print(f\"Min count: {min(counts)}\")\n",
    "print(f\"Max count: {max(counts)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the data. The sorted array is deliberately not named `sorted_counts`:\n",
    "# an earlier cell uses that name for the (token_id, frequency) ranking, and\n",
    "# rebinding it here would clobber that ranking for any cell re-run later.\n",
    "counts = list(token_counts.values())\n",
    "counts_sorted = np.sort(counts)\n",
    "cumulative = np.arange(1, len(counts_sorted) + 1) / len(counts_sorted)\n",
    "median_count = np.median(counts)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.plot(counts_sorted, cumulative)\n",
    "ax.set_xlabel('Token Count')\n",
    "ax.set_ylabel('Cumulative Proportion')\n",
    "ax.set_title('Cumulative Distribution of Token Counts')\n",
    "# Use log scale for x-axis due to likely large range of counts\n",
    "ax.set_xscale('log')\n",
    "\n",
    "# Vertical line for the median, horizontal lines at key cumulative proportions\n",
    "ax.axvline(x=median_count, color='r', linestyle='dashed', linewidth=2,\n",
    "           label=f'Median ({median_count:.2f})')\n",
    "ax.axhline(y=0.5, color='g', linestyle=':', linewidth=1, label='50th percentile')\n",
    "ax.axhline(y=0.8, color='orange', linestyle=':', linewidth=1, label='80th percentile')\n",
    "ax.axhline(y=0.95, color='purple', linestyle=':', linewidth=1, label='95th percentile')\n",
    "ax.legend()\n",
    "ax.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "plt.show()\n",
    "\n",
    "# Print some statistics\n",
    "print(f\"Total number of unique tokens: {len(counts)}\")\n",
    "print(f\"Median count: {median_count}\")\n",
    "print(f\"Mean count: {np.mean(counts):.2f}\")\n",
    "print(f\"Min count: {min(counts)}\")\n",
    "print(f\"Max count: {max(counts)}\")\n",
    "\n",
    "# Calculate and print percentile values\n",
    "percentiles = [50, 80, 95]\n",
    "percentile_values = np.percentile(counts, percentiles)\n",
    "for p, v in zip(percentiles, percentile_values):\n",
    "    print(f\"{p}th percentile: {v:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-wd-fairness-eaiv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
