{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the MI trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import json\n",
    "import numpy as np\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\", context=\"talk\", palette=\"muted\", font_scale=1.0)\n",
    "\n",
    "\n",
    "def load_model_data(model_path, dataset, target_layer=31):\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "    model_name = model_path.split('/')[-1]\n",
    "\n",
    "    data_path = f'results/mi/token_evolve/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n",
    "    data = torch.load(data_path)\n",
    "    \n",
    "    all_sample_mi_list = []\n",
    "    all_mi_peak_list = [] \n",
    "\n",
    "    for id in data.keys():\n",
    "        try:\n",
    "            this_id_mi_list = data[id]['reps'][target_layer]\n",
    "            \n",
    "            all_sample_mi_list.append(this_id_mi_list[:])\n",
    "\n",
    "            top_indices = sorted(range(len(this_id_mi_list)), \n",
    "                            key=lambda i: this_id_mi_list[i], reverse=True)[:20]  # approximately take top-20\n",
    "            all_mi_peak_list.append(top_indices)\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f'[id:{id}] Error:', e)\n",
    "\n",
    "    return all_sample_mi_list, all_mi_peak_list\n",
    "\n",
    "\n",
    "dataset = 'math_train_12k'\n",
    "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n",
    "model_name = model_path.split('/')[-1]\n",
    "\n",
    "\n",
    "palette = sns.color_palette()\n",
    "colors = palette[:len(models)]\n",
    "\n",
    "model_data_dict = {}\n",
    "\n",
    "mi_list, mi_peak_list = load_model_data(model_path, dataset, target_layer=31)\n",
    "model_data_dict[model] = {\n",
    "    'mi': mi_list,\n",
    "    'peaks': mi_peak_list,\n",
    "}\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(2, 5, figsize=(30, 10))\n",
    "axes = axes.flatten()\n",
    "\n",
    "for sample_idx in range(10):\n",
    "    ax = axes[sample_idx]\n",
    "    mi_values = model_data_dict[model]['mi'][sample_idx]\n",
    "    steps = model_data_dict[model]['peaks'][sample_idx]\n",
    "\n",
    "    ax.plot(mi_values, linewidth=2, alpha=0.8)\n",
    "    ax.scatter(steps, [mi_values[i] for i in steps],\n",
    "                s=50, edgecolor=\"white\", linewidth=0.8, zorder=3)\n",
    "\n",
    "    ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "    ax.set_facecolor(\"#fafafa\")\n",
    "    sns.despine(ax=ax, top=True, right=True)\n",
    "\n",
    "    ax.set_title(model_name, fontsize=22, pad=8)  \n",
    "    ax.set_xlabel(\"Reasoning Step\", fontsize=18)   \n",
    "    ax.set_ylabel(\"MI Value\", fontsize=20)\n",
    "    ax.tick_params(axis=\"x\", labelsize=18)  \n",
    "    ax.tick_params(axis=\"y\", labelsize=18)\n",
    "\n",
    "\n",
    "plt.subplots_adjust(hspace=0.4, wspace=0.13)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projecting the MI-peak representations to token space "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import json\n",
    "import numpy as np\n",
    "from transformers import AutoTokenizer\n",
    "from collections import Counter\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "sns.set(style=\"whitegrid\", context=\"talk\", font_scale=1.3)\n",
    "\n",
    "\n",
    "def plot_token_freq(model_path, dataset, target_layer=31):\n",
    "\n",
    "    model_name = model_path.split('/')[-1]\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "    data_path = f'results/mi/token_evolve/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n",
    "    data = torch.load(data_path)\n",
    "\n",
    "    all_sample_mi_list = [] \n",
    "    all_mi_peak_list = []\n",
    "    all_tokens = []\n",
    "\n",
    "    fail_id_list = []\n",
    "    for id in data.keys():\n",
    "        try:\n",
    "            this_id_mi_list = data[id]['reps'][target_layer] \n",
    "            all_sample_mi_list.append(this_id_mi_list[:]) \n",
    "            top_indices = sorted(range(len(this_id_mi_list)), key=lambda i: this_id_mi_list[i], reverse=True)[:20]\n",
    "\n",
    "            all_mi_peak_list.append(top_indices)\n",
    "\n",
    "            token_list = acts[id]['token_ids'].tolist()\n",
    "            token_list.append(2)  # [eos] token\n",
    "            top_prob_token_ids = [token_list[i] for i in top_indices]\n",
    "\n",
    "            batch_top_n_tokens = tokenizer.batch_decode(top_prob_token_ids, skip_special_tokens=False)\n",
    "            all_tokens.extend(batch_top_n_tokens)\n",
    "\n",
    "        except Exception as e:\n",
    "            fail_id_list.append(id)\n",
    "\n",
    "    print('fail_id_list:', fail_id_list)\n",
    "\n",
    "\n",
    "    english_pattern = re.compile(r'^[a-zA-Z]+$')\n",
    "\n",
    "    processed_all_tokens = []\n",
    "    for token in all_tokens:\n",
    "        if english_pattern.match(token.strip()):\n",
    "            processed_all_tokens.append(token)\n",
    "\n",
    "    \n",
    "    token_freq = Counter(processed_all_tokens)\n",
    "\n",
    "    common_tokens = token_freq.most_common(15)\n",
    "    print('$'*50)\n",
    "    print(f'model: {model_name}')\n",
    "    print(\"Most common tokens:\", common_tokens)\n",
    "\n",
    "    colors = [\n",
    "        (0.0, \"#3d61aa\"),  \n",
    "        (0.5, \"#b1bee9\"),   \n",
    "        (1.0, \"#9673c4\")    \n",
    "    ]\n",
    "    cmap = LinearSegmentedColormap.from_list(\"blue_purple\", colors)\n",
    "\n",
    "\n",
    "    # ------------------------------ plot --------------------------------------\n",
    "\n",
    "    token_names, token_counts = zip(*common_tokens)\n",
    "    token_names_processed = [token.replace('$', '\\$').replace('_', '\\_').replace('^', '\\^') for token in token_names]\n",
    "    token_names_repr = [repr(token) for token in token_names_processed]\n",
    "\n",
    "    n_bars = len(token_names_repr)\n",
    "    palette = [cmap(i / (n_bars - 1)) for i in range(n_bars)]\n",
    "\n",
    "    plt.figure(figsize=(8, 5))\n",
    "    sns.barplot(x=list(token_names_repr), y=list(token_counts), palette=palette)\n",
    "    plt.xticks(rotation=45, ha='right', size=17)\n",
    "    plt.xlabel('Tokens at MI Peaks')\n",
    "    plt.ylabel('Frequency')\n",
    "    plt.title(f'{model_name}')\n",
    "\n",
    "\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "dataset = 'math_train_12k'\n",
    "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n",
    "\n",
    "\n",
    "plot_token_freq(\n",
    "    model_path=model_path,\n",
    "    dataset=dataset,\n",
    "    target_layer=31,\n",
    ")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vllm053",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
