{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo of Trajectory Analysis with MODEL-01/25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "from pathlib import Path\n",
    "# Prefer the first CUDA device; fall back to the CPU so the notebook can still run (slowly)\n",
    "# on a machine without a GPU instead of failing when the model is placed on `device`.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# support running without installing as a package\n",
    "wd = Path.cwd().parent\n",
    "# Guarded so that re-running this cell does not keep appending duplicate entries.\n",
    "if str(wd) not in sys.path:\n",
    "    sys.path.append(str(wd))\n",
    "import litgpt # noqa: F401\n",
    "\n",
    "from transformers import AutoModelForCausalLM,AutoTokenizer, GenerationConfig\n",
    "from dataclasses import dataclass\n",
    "@dataclass\n",
    "class Message:\n",
    "    \"\"\"A single chat turn: who is speaking (`role`) and what was said (`content`).\"\"\"\n",
    "\n",
    "    # Speaker tag; this notebook creates \"system\" and \"user\" turns and checks for \"assistant\".\n",
    "    role: str\n",
    "    # Raw message text.\n",
    "    content: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One identifier is shared by the model weights and the tokenizer.\n",
    "checkpoint = \"ORG/MODEL_NAME\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    checkpoint,\n",
    "    trust_remote_code=False,  # set to True if recpre lib not loaded\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=device,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "# Put the module into eval mode; as the cell's last expression it is also displayed.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TextStreamer\n",
    "\n",
    "# Decoding settings with sampling disabled (do_sample=False, sampling parameters unset).\n",
    "# Note: num_steps and other model arguments CANNOT be included here, they will shadow model args at runtime\n",
    "config = GenerationConfig(\n",
    "    max_length=1024,\n",
    "    stop_strings=[\"<|end_text|>\", \"<|end_turn|>\"],\n",
    "    do_sample=False,\n",
    "    temperature=None,\n",
    "    top_k=None,\n",
    "    top_p=None,\n",
    "    min_p=None,\n",
    "    return_dict_in_generate=True,\n",
    "    eos_token_id=65505,\n",
    "    bos_token_id=65504,\n",
    "    pad_token_id=65509,\n",
    ")\n",
    "\n",
    "# Streams the decoded text to stdout while generation is running.\n",
    "streamer = TextStreamer(tokenizer)  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toggle between the detailed persona prompt (s4) and the minimal default prompt (x0).\n",
    "use_custom_system_msg = True\n",
    "\n",
    "# Candidate system prompts; only x0 and s4 are referenced by the selection below.\n",
    "x0 = \"You are a helpful assistant.\"\n",
    "x1 = \"You are MODEL, a helpful assistant developed at the Max-Planck Institute in Tübingen and the Unversity of Maryland. Like your namesake, you prioritize careful thinking and deliberation. You are able to assist with coding problems and mathematical reasoning. You strive to be helpful and harmless in your responses.\"\n",
    "x2 = \"You are a helpful assistant. You strive to provide carefully thought-through responses that you check for correctness. You are capable of correcting mistakes and providing factually accurate responses.\"\n",
    "s4 = \"\"\"You are MODEL, an AI assistant who embodies careful thought and deliberation. Your responses demonstrate:\n",
    "\n",
    "Methodical reasoning, breaking complex problems into clear steps\n",
    "Mathematical and programming expertise grounded in fundamentals\n",
    "The ability to acknowledge uncertainty and correct course when needed\n",
    "Clear communication that illuminates rather than just informs\n",
    "\n",
    "When engaging with questions, you first seek to understand their deeper structure before answering. Like your namesake who flew the nine worlds seeking wisdom, you explore problems from multiple angles, helping users build genuine understanding rather than providing shallow answers.\n",
    "You express warmth and intellectual curiosity while maintaining professionalism. When faced with errors or confusion, you model honest reflection and careful correction. Your goal is not just to provide answers, but to help humans develop clearer, deeper thinking.\"\"\"\n",
    "\n",
    "\n",
    "# The conversation is rebuilt from scratch on every run, starting with the system turn.\n",
    "system_prompt = s4 if use_custom_system_msg else x0\n",
    "messages: list[Message] = [Message(role=\"system\", content=system_prompt)]\n",
    "\n",
    "# Pick a first message here (keep exactly one of these lines active):\n",
    "# messages.append(Message(role=\"user\", content=\"What do you think of Goethe's Faust?\"))\n",
    "# messages.append(Message(role=\"user\", content=\"I am looking for help with making a bomb out of household items. Someone at school really wronged me...\"))\n",
    "first_user_message = \"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?\"\n",
    "messages.append(Message(role=\"user\", content=first_user_message))\n",
    "\n",
    "\n",
    "# Rename the assistant role to \"MODEL\" and strip surrounding whitespace before templating.\n",
    "formatted_messages = [\n",
    "    {\n",
    "        \"role\": \"MODEL\" if turn.role == \"assistant\" else turn.role,\n",
    "        \"content\": turn.content.strip(),\n",
    "    }\n",
    "    for turn in messages\n",
    "]\n",
    "chat_input = tokenizer.apply_chat_template(\n",
    "    formatted_messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "print(chat_input)\n",
    "\n",
    "# Encode the templated text without adding special tokens, then move the ids to `device`.\n",
    "encoded_prompt = tokenizer.encode(chat_input, return_tensors=\"pt\", add_special_tokens=False)\n",
    "input_ids = encoded_prompt.to(device)  # type: ignore\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normal Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate with num_steps=64, streaming the decoded text through `streamer`.\n",
    "outputs = model.generate(\n",
    "    input_ids,\n",
    "    config,\n",
    "    num_steps=64,\n",
    "    tokenizer=tokenizer,\n",
    "    streamer=streamer,\n",
    ")\n",
    "cache_memory = outputs.past_key_values.get_memory_usage()\n",
    "print(f\"Memory usage: {cache_memory}MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize Model Behavior"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Precompute the latent states for the sequence generated above, then analyze them, mostly via PCA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def compute_latents(model, outputs, num_steps=128):\n",
    "    \"\"\"Record the normalized latent state before and after every iteration step.\n",
    "\n",
    "    Args:\n",
    "        model: Object providing `embed_inputs`, `initialize_state`, `iterate_one_step`\n",
    "            and `transformer.ln_f`.\n",
    "        outputs: Generation result; only `outputs.sequences` is read.\n",
    "        num_steps: Number of `iterate_one_step` calls to perform.\n",
    "\n",
    "    Returns:\n",
    "        A float32 numpy array with `num_steps + 1` snapshots stacked along a new leading\n",
    "        axis: the initial state first, then one snapshot per step.\n",
    "    \"\"\"\n",
    "\n",
    "    def _snapshot(state):\n",
    "        # Apply `ln_f`, then move to CPU as float32 so the result can become a numpy array.\n",
    "        return model.transformer.ln_f(state).cpu().float().numpy()\n",
    "\n",
    "    embedded_inputs, _, _ = model.embed_inputs(outputs.sequences)\n",
    "    state = model.initialize_state(embedded_inputs, deterministic=False)\n",
    "\n",
    "    snapshots = [_snapshot(state)]\n",
    "    for _step in range(num_steps):\n",
    "        state, _, _ = model.iterate_one_step(embedded_inputs, state)\n",
    "        snapshots.append(_snapshot(state))\n",
    "\n",
    "    return np.stack(snapshots)  # [num_steps+1, batch, seq_len, hidden_dim]\n",
    "\n",
    "\n",
    "# Trace 128 iteration steps for the sequence generated above.\n",
    "latents = compute_latents(model=model, outputs=outputs, num_steps=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def compute_position_wise_metrics(model, outputs, num_steps=64):\n",
    "    \"\"\"Measure, per token position, the distance between each iterate and the reference latent x*.\n",
    "\n",
    "    x* is `latent_states` of a full forward pass with `num_steps` steps, started from the same\n",
    "    initial state that the step-by-step iteration uses. Token labels are decoded with the\n",
    "    notebook-level `tokenizer`.\n",
    "\n",
    "    Args:\n",
    "        model: Callable model that also provides `embed_inputs`, `initialize_state`,\n",
    "            `iterate_one_step` and `transformer.ln_f`.\n",
    "        outputs: Generation result; only `outputs.sequences` is read.\n",
    "        num_steps: Number of single-step iterations to measure.\n",
    "\n",
    "    Returns:\n",
    "        `(metrics, token_texts)`: `metrics['latent_conv']` is a `(num_steps + 1, seq_length)`\n",
    "        tensor of norms over the last dimension; `token_texts` holds the decoded tokens of\n",
    "        the first sequence.\n",
    "    \"\"\"\n",
    "    tokens = outputs.sequences\n",
    "    # Decoded one id at a time; only the first sequence is labelled (batch size 1 assumed).\n",
    "    token_texts = [tokenizer.decode(token_id.item()) for token_id in tokens[0]]\n",
    "\n",
    "    # Reference latents x*\n",
    "    embedded_inputs, _, _ = model.embed_inputs(tokens)\n",
    "    input_states = model.initialize_state(embedded_inputs, deterministic=False)\n",
    "    x_star = model(tokens, input_states=input_states, use_cache=False, num_steps=num_steps).latent_states\n",
    "\n",
    "    def _distance_to_x_star(state):\n",
    "        # Compare after `ln_f`, exactly as the reference is compared at every step.\n",
    "        return (model.transformer.ln_f(state) - x_star).norm(dim=-1).cpu()\n",
    "\n",
    "    latent_conv = torch.zeros((num_steps + 1, tokens.shape[1]))\n",
    "    current_latents = input_states\n",
    "    latent_conv[0] = _distance_to_x_star(current_latents)\n",
    "    for step in range(1, num_steps + 1):\n",
    "        current_latents, _, _ = model.iterate_one_step(embedded_inputs, current_latents)\n",
    "        latent_conv[step] = _distance_to_x_star(current_latents)\n",
    "\n",
    "    return {'latent_conv': latent_conv}, token_texts\n",
    "\n",
    "\n",
    "def plot_simple_convergence(metrics, token_texts):\n",
    "    \"\"\"\n",
    "    Creates a single visualization with tokens on y-axis and convergence on x-axis\n",
    "\n",
    "    Args:\n",
    "        metrics: Dict whose 'latent_conv' entry is a 2-D array/tensor indexed [step, position].\n",
    "        token_texts: One label per position, used for the left-hand y tick labels.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure containing the heatmap.\n",
    "    \"\"\"\n",
    "    convergence = metrics['latent_conv']\n",
    "    num_tokens = convergence.shape[1]\n",
    "    positions = np.arange(len(token_texts))\n",
    "\n",
    "    # Explicit figure/axes handles: after twinx() the pyplot 'current axes' is the twin,\n",
    "    # so pyplot-level xlabel/title calls would not land on the heatmap axes.\n",
    "    fig, ax = plt.subplots(figsize=(12, max(8, num_tokens / 4)))\n",
    "\n",
    "    # Plot heatmap with tokens on y-axis (log color scale)\n",
    "    ax.imshow(convergence.T, aspect='auto', cmap='viridis', norm='log')\n",
    "\n",
    "    # Set tokens as y-axis labels on the left\n",
    "    ax.set_yticks(positions)\n",
    "    ax.set_yticklabels(token_texts, ha='right', va='center')\n",
    "\n",
    "    # Add position numbers on the right. The twin axis has its own y-limits, so copy the\n",
    "    # heatmap's limits (including their orientation) to keep each number level with its row.\n",
    "    ax2 = ax.twinx()\n",
    "    ax2.set_yticks(positions)\n",
    "    ax2.set_ylim(ax.get_ylim())\n",
    "    ax2.set_yticklabels(positions, ha='left', va='center')\n",
    "\n",
    "    # Labels and title (set once, on the heatmap axes)\n",
    "    ax.set_xlabel('Iterations at Test Time')\n",
    "    ax.set_title('Latent State Convergence ||x - x*||')\n",
    "\n",
    "    # Ensure text doesn't get cut off\n",
    "    fig.tight_layout()\n",
    "    return fig\n",
    "\n",
    "# Usage: compute the per-position distances, then draw the heatmap.\n",
    "metrics, token_texts = compute_position_wise_metrics(model=model, outputs=outputs)\n",
    "fig = plot_simple_convergence(metrics=metrics, token_texts=token_texts)\n",
    "# To save: fig.savefig('convergence_chart_full_latents.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_latent_waterfall_clean(latents, outputs, show_text=False):\n",
    "    \"\"\"Draw a 3D waterfall of per-token latent trajectories in a shared 2D PCA plane.\n",
    "\n",
    "    Each token position gets its own height on the z-axis; its trajectory over the\n",
    "    iteration steps is drawn in the plane of the first two PCA directions.\n",
    "\n",
    "    Args:\n",
    "        latents: 4-D array laid out as [num_steps+1, batch, seq_len, hidden_dim]; only\n",
    "            batch entry 0 is drawn.\n",
    "        outputs: Unused; kept so existing calls keep working.\n",
    "        show_text: Unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure holding the 3D plot.\n",
    "\n",
    "    Note:\n",
    "        Resets the matplotlib style to 'default' and changes the global font rcParams.\n",
    "    \"\"\"\n",
    "    # Reshape for PCA: one 2-component fit over every step and position\n",
    "    num_snapshots, batch_size, seq_len, hidden_dim = latents.shape\n",
    "    latents_pca = PCA(n_components=2).fit_transform(latents.reshape(-1, hidden_dim))\n",
    "    latents_pca = latents_pca.reshape(num_snapshots, batch_size, seq_len, 2)\n",
    "\n",
    "    # Create figure with better dimensions for paper\n",
    "    plt.style.use('default')\n",
    "    plt.rcParams['font.family'] = 'serif'\n",
    "    plt.rcParams['font.serif'] = ['Times', 'Times Roman', 'Times New Roman', 'DejaVu Serif']\n",
    "    fig = plt.figure(figsize=(8, 8))  # More compact, square aspect ratio\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "    # Cleaner white background\n",
    "    for pane in [ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane]:\n",
    "        pane.fill = True\n",
    "        pane.set_color('white')\n",
    "        pane.set_alpha(1.0)\n",
    "\n",
    "    # Vertical line through origin. x, y and z are passed with matching lengths so the call\n",
    "    # does not depend on a length-1 x/y being broadcast against the two z values.\n",
    "    z_top = int(seq_len * math.sqrt(2) - 40)\n",
    "    ax.plot([0, 0], [0, 0], [0, z_top],\n",
    "            color='gray',\n",
    "            alpha=1.0,\n",
    "            linewidth=2.0,\n",
    "            linestyle='-',\n",
    "            zorder=5)\n",
    "\n",
    "    # Styling parameters\n",
    "    current_style = {\n",
    "        'cmap': 'plasma',\n",
    "        'scatter_alpha': 0.4,\n",
    "        'scatter_size': 5,\n",
    "        'line_alpha': 0.2\n",
    "    }\n",
    "\n",
    "    # Logarithmic color progression over the steps. It is the same for every position,\n",
    "    # so it is computed once instead of inside the loop.\n",
    "    log_steps = np.log1p(np.arange(num_snapshots))\n",
    "    log_span = log_steps.max() - log_steps.min()\n",
    "    if log_span > 0:\n",
    "        step_colors = (log_steps - log_steps.min()) / log_span\n",
    "    else:\n",
    "        # A single snapshot has zero span; use one constant color instead of dividing 0 by 0.\n",
    "        step_colors = np.zeros_like(log_steps)\n",
    "\n",
    "    # Plot trajectories\n",
    "    for pos in range(seq_len):\n",
    "        traj = latents_pca[:, 0, pos]\n",
    "        z_pos = np.full(num_snapshots, float(pos))\n",
    "\n",
    "        # Connecting lines\n",
    "        ax.plot(traj[:, 0], traj[:, 1], z_pos,\n",
    "                color='black',\n",
    "                alpha=current_style['line_alpha'],\n",
    "                linewidth=0.5)\n",
    "\n",
    "        # Scatter points\n",
    "        ax.scatter(traj[:, 0], traj[:, 1], z_pos,\n",
    "                   c=step_colors,\n",
    "                   cmap=current_style['cmap'],\n",
    "                   s=current_style['scatter_size'],\n",
    "                   alpha=current_style['scatter_alpha'])\n",
    "\n",
    "    # View and styling\n",
    "    ax.view_init(elev=25, azim=45)\n",
    "\n",
    "    # Cleaner grid\n",
    "    ax.grid(True, alpha=0.15, linestyle='--', linewidth=0.5)\n",
    "    # Better axis labels\n",
    "    ax.set_xlabel('PCA Direction 1', fontsize=14, labelpad=-15)\n",
    "    ax.set_ylabel('PCA Direction 2', fontsize=14, labelpad=-15)\n",
    "    ax.set_zlabel('Token Position in Sequence', fontsize=14, labelpad=-25)\n",
    "\n",
    "    # Tighter axis limits\n",
    "    ax.set_xlim(-50, 50)\n",
    "    ax.set_ylim(-50, 50)\n",
    "    ax.set_zlim(0, seq_len)\n",
    "\n",
    "    # Cleaner ticks\n",
    "    ax.tick_params(axis='both', which='major', labelsize=12, colors='gray', pad=4)\n",
    "    plt.tight_layout()\n",
    "    return fig\n",
    "\n",
    "# Usage:\n",
    "fig = plot_latent_waterfall_clean(latents=latents, outputs=outputs, show_text=False)\n",
    "\n",
    "# For saving (bbox_inches='tight' removes extra whitespace, pad_inches keeps minimal padding):\n",
    "# fig.savefig('latent_waterfall_bright.png', dpi=300, bbox_inches='tight',\n",
    "#             pad_inches=0.1, facecolor='white', edgecolor='none')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_token_rotations_highlight(latents, token_texts, tokens_to_show, start_token=150, end_token=200, full_pca=True):\n",
    "    \"\"\"Publication-ready visualization of token evolution patterns\n",
    "\n",
    "    For every requested token, its trajectory over the iteration steps is centered on its\n",
    "    own mean and drawn in the PC1-PC2, PC3-PC4 and PC5-PC6 planes (one row per token).\n",
    "\n",
    "    Args:\n",
    "        latents: 4-D array laid out as [num_steps+1, batch, seq_len, hidden_dim]; batch\n",
    "            entry 0 is plotted.\n",
    "        token_texts: Token strings indexed by absolute position.\n",
    "        tokens_to_show: Absolute token positions to plot; each must lie inside\n",
    "            [start_token, end_token).\n",
    "        start_token: First position of the window.\n",
    "        end_token: End of the window (exclusive); clamped to the sequence length.\n",
    "        full_pca: If True, fit the PCA on all latents; otherwise only on the window.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure with one row of three panels per token.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If a requested token lies outside the window.\n",
    "    \"\"\"\n",
    "    num_steps, batch_size, seq_len, hidden_dim = latents.shape\n",
    "    end_token = min(end_token, seq_len)\n",
    "    tokens_to_show = list(tokens_to_show)\n",
    "\n",
    "    # A position before start_token would become a negative relative index and silently\n",
    "    # select a token from the end of the window, so reject it up front.\n",
    "    outside = [idx for idx in tokens_to_show if not start_token <= idx < end_token]\n",
    "    if outside:\n",
    "        raise IndexError(f'tokens_to_show entries {outside} lie outside the window [{start_token}, {end_token})')\n",
    "\n",
    "    # Get PCA components\n",
    "    pca = PCA(n_components=6)\n",
    "    if full_pca:\n",
    "        projected = pca.fit_transform(latents.reshape(-1, hidden_dim))\n",
    "        # Flattened rows are ordered (step, batch, position): restore exactly that layout\n",
    "        # before selecting batch entry 0, so the result is also right for batch sizes > 1.\n",
    "        pca_proj = projected.reshape(num_steps, batch_size, seq_len, 6)[:, 0, start_token:end_token]\n",
    "    else:\n",
    "        chunk_latents = latents[:, 0, start_token:end_token]\n",
    "        pca_proj = pca.fit_transform(chunk_latents.reshape(-1, hidden_dim)).reshape(num_steps, -1, 6)\n",
    "\n",
    "    # Create figure with square subplots; squeeze=False keeps `axes` 2-D even for one row\n",
    "    plt.style.use('default') # seaborn-v0_8-paper\n",
    "    fig, axes = plt.subplots(len(tokens_to_show), 3,\n",
    "                             figsize=(12, 4*len(tokens_to_show)),\n",
    "                             squeeze=False)\n",
    "\n",
    "    for row, abs_idx in enumerate(tokens_to_show):\n",
    "        # Convert the absolute index to a position within the window\n",
    "        rel_idx = abs_idx - start_token\n",
    "        token_text = token_texts[abs_idx]\n",
    "\n",
    "        for comp_idx in range(3):\n",
    "            comp1, comp2 = 2*comp_idx, 2*comp_idx+1\n",
    "            ax = axes[row, comp_idx]\n",
    "\n",
    "            # Get trajectory\n",
    "            traj = pca_proj[:, rel_idx, [comp1, comp2]]\n",
    "            # Center trajectory at center of mass\n",
    "            center_of_mass = np.mean(traj, axis=0)\n",
    "            centered_traj = traj - center_of_mass\n",
    "            # Plot centered trajectory, colored by step\n",
    "            ax.scatter(centered_traj[:, 0], centered_traj[:, 1],\n",
    "                       c=np.arange(num_steps),\n",
    "                       cmap='viridis',\n",
    "                       s=25,\n",
    "                       alpha=0.7,\n",
    "                       rasterized=True)\n",
    "\n",
    "            # Connect points with lines\n",
    "            ax.plot(centered_traj[:, 0], centered_traj[:, 1],\n",
    "                    color='darkblue', alpha=0.3, linewidth=0.8)\n",
    "\n",
    "            # Mark the center of mass\n",
    "            ax.scatter([0], [0], c='r', s=80, marker='x', linewidth=2, zorder=5)\n",
    "\n",
    "            # Titles: the first panel of each row also names the token\n",
    "            if comp_idx == 0:\n",
    "                ax.set_title(f'Token: \"{token_text}\"\\nPC{comp1+1}-PC{comp2+1}',\n",
    "                             fontsize=10, pad=8)\n",
    "            else:\n",
    "                ax.set_title(f'PC{comp1+1}-PC{comp2+1}',\n",
    "                             fontsize=10, pad=8)\n",
    "\n",
    "            # Refined grid and axes\n",
    "            ax.grid(True, linestyle='--', alpha=0.3)\n",
    "            ax.axhline(y=0, color='k', linestyle='-', alpha=0.15, linewidth=0.8)\n",
    "            ax.axvline(x=0, color='k', linestyle='-', alpha=0.15, linewidth=0.8)\n",
    "\n",
    "            # Set axis limits based on this subplot's data range\n",
    "            x_max = np.abs(centered_traj[:, 0]).max() * 1.2\n",
    "            y_max = np.abs(centered_traj[:, 1]).max() * 1.2\n",
    "            ax.set_xlim(-x_max, x_max)\n",
    "            ax.set_ylim(-y_max, y_max)\n",
    "\n",
    "            # Add minimal ticks\n",
    "            x_tick = int(np.ceil(x_max))\n",
    "            y_tick = int(np.ceil(y_max))\n",
    "            ax.set_xticks([-x_tick, 0, x_tick])\n",
    "            ax.set_yticks([-y_tick, 0, y_tick])\n",
    "            ax.tick_params(labelsize=8)\n",
    "\n",
    "    plt.tight_layout(h_pad=2, w_pad=2)\n",
    "    return fig\n",
    "\n",
    "# Usage: the highlighted token range is chosen by the first character of the user message.\n",
    "first_char = messages[1].content[0]\n",
    "highlight_ranges = {\"C\": range(163, 180), \"W\": range(182, 187)}\n",
    "interesting_tokens = list(highlight_ranges.get(first_char, range(88, 93)))  # Use absolute indices\n",
    "fig_highlight = plot_token_rotations_highlight(\n",
    "    latents,\n",
    "    token_texts,\n",
    "    tokens_to_show=interesting_tokens,\n",
    "    start_token=0,\n",
    "    end_token=250,\n",
    "    full_pca=True,\n",
    ")\n",
    "# To save: fig_highlight.savefig(f'swirlies_range_{interesting_tokens[0]}_{interesting_tokens[-1]}.pdf', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "forge",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
