{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "feac40f6",
   "metadata": {},
   "source": [
    "# IMPORTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66a1f1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. the project's lib/ helpers) before each cell runs,\n",
    "# so edits to them take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1155e7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e6909e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only GPU 0 to this process. This must happen before CUDA is initialized;\n",
    "# torch is first imported in the next cell, so setting it here is early enough.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b038c9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from collections import defaultdict\n",
    "import pickle as pkl\n",
    "import gc\n",
    "import torch\n",
    "\n",
    "sys.path.append(\"~pythia_replicate\")\n",
    "\n",
    "from lib.model_setup import load_model_and_tokenizer\n",
    "from lib.fv import fv_icl_tasks_benchmark_with_ci\n",
    "from lib.repetition import random_sequence_repetition_accuracy_with_ci, natural_text_repetition_accuracy_with_ci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a6eee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run configuration ---\n",
    "# Experiment key; the next cell maps it to [clean, masked] checkpoint directories.\n",
    "model_size = \"1b-threshold-0.3\"\n",
    "# When True only the masked model is benchmarked/saved; the plotting cells still\n",
    "# read clean-model pickles, so those must exist from an earlier run.\n",
    "skip_clean = True\n",
    "\n",
    "# Checkpoints are swept over range(first_step, last_step, 100) (last_step excluded);\n",
    "# `step` is the single checkpoint loaded directly in the next cell.\n",
    "first_step, last_step = 100, 19_900\n",
    "step = last_step\n",
    "\n",
    "# Benchmark sizes.\n",
    "random_repetition_seq_len = natural_text_repetition_seq_len = 25\n",
    "batch_size = 64\n",
    "max_sample_size = 1_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52ff4534",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment key -> [clean model dir, masked model dir] under hf_output/.\n",
    "MODEL_TYPES_BY_SIZE = {\n",
    "    \"160m\": [\"clean_v3\", \"masked_bigram_loss_v4\"],\n",
    "    \"160m-threshold-0.3\": [\"clean_v3\", \"masked_bigram_loss_thresh0.3_eq\"],\n",
    "    \"1b\": [\"clean_1b\", \"masked_bigram_loss_1b\"],\n",
    "    \"1b-threshold-0.3\": [\"clean_1b\", \"masked_bigram_loss_1b_thresh0.3_eq\"],\n",
    "}\n",
    "if model_size not in MODEL_TYPES_BY_SIZE:\n",
    "    # Fail fast: an if/elif chain without an else would leave `model_types`\n",
    "    # undefined, or stale from a previous run (a later cell reassigns it).\n",
    "    raise ValueError(\n",
    "        f\"Unknown model_size {model_size!r}; expected one of {sorted(MODEL_TYPES_BY_SIZE)}\"\n",
    "    )\n",
    "model_types = list(MODEL_TYPES_BY_SIZE[model_size])\n",
    "\n",
    "# Load the checkpoint at `step`.\n",
    "# NOTE(review): these globals stay loaded while repetition_tasks_benchmark loads its\n",
    "# own per-step copies; `del` them first if accelerator memory is tight.\n",
    "if not skip_clean:\n",
    "    clean_model_path = f\"~pythia_replicate/hf_output/{model_types[0]}/step={step}\"\n",
    "    clean_model, tokenizer_clean = load_model_and_tokenizer(clean_model_path)\n",
    "\n",
    "masked_model_path = f\"~pythia_replicate/hf_output/{model_types[1]}/step={step}\"\n",
    "masked_model, tokenizer_masked = load_model_and_tokenizer(masked_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5919a4d",
   "metadata": {},
   "source": [
    "# REPETITION BENCHMARK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864d00f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "def _benchmark_checkpoints(\n",
    "    model_type,\n",
    "    steps,\n",
    "    dataset,\n",
    "    dataset_name,\n",
    "    random_repetition_seq_len,\n",
    "    natural_text_repetition_seq_len,\n",
    "    max_sample_size,\n",
    "    batch_size,\n",
    "):\n",
    "    \"\"\"Run both repetition benchmarks on every checkpoint of one model family.\n",
    "\n",
    "    Loads ~pythia_replicate/hf_output/{model_type}/step={step} for each step,\n",
    "    evaluates it, and frees it before loading the next one.\n",
    "\n",
    "    Returns:\n",
    "        (random_results, natural_results): two lists aligned with `steps`.\n",
    "    \"\"\"\n",
    "    random_results = []\n",
    "    natural_results = []\n",
    "    for step in steps:\n",
    "        model_path = f\"~pythia_replicate/hf_output/{model_type}/step={step}\"\n",
    "        model, tokenizer = load_model_and_tokenizer(model_path)\n",
    "        try:\n",
    "            with torch.no_grad():\n",
    "                random_res = random_sequence_repetition_accuracy_with_ci(\n",
    "                    model,\n",
    "                    tokenizer,\n",
    "                    seq_len=random_repetition_seq_len,\n",
    "                    num_of_samples=max_sample_size,\n",
    "                    batch_size=batch_size,\n",
    "                )\n",
    "                # Pass dataset_name for every model family; the old masked loop\n",
    "                # omitted it (presumably falling back to a default in the callee).\n",
    "                natural_res = natural_text_repetition_accuracy_with_ci(\n",
    "                    model,\n",
    "                    tokenizer,\n",
    "                    dataset_name,\n",
    "                    seq_len=natural_text_repetition_seq_len,\n",
    "                    num_of_samples=max_sample_size,\n",
    "                    batch_size=batch_size,\n",
    "                    dataset=dataset,\n",
    "                )\n",
    "            random_results.append(random_res)\n",
    "            natural_results.append(natural_res)\n",
    "        finally:\n",
    "            # Release the checkpoint even if a benchmark raises, so checkpoints\n",
    "            # don't accumulate in memory across steps.\n",
    "            del model, tokenizer\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()\n",
    "    return random_results, natural_results\n",
    "\n",
    "\n",
    "def repetition_tasks_benchmark(\n",
    "    random_repetition_seq_len=100,\n",
    "    natural_text_repetition_seq_len=100,\n",
    "    max_sample_size=5000,\n",
    "    batch_size=64,\n",
    "    first_step=100,\n",
    "    last_step=19900,\n",
    "    skip_clean=False,\n",
    "    model_types=None,\n",
    "):\n",
    "    \"\"\"Benchmark random-sequence and natural-text repetition across checkpoints.\n",
    "\n",
    "    Checkpoints range(first_step, last_step, 100) are evaluated (last_step is\n",
    "    exclusive) for the clean model model_types[0] (unless skip_clean) and the\n",
    "    masked model model_types[1], using the WikiText-2 test split for the\n",
    "    natural-text task.\n",
    "\n",
    "    Args:\n",
    "        random_repetition_seq_len: seq_len for the random-sequence task.\n",
    "        natural_text_repetition_seq_len: seq_len for the natural-text task.\n",
    "        max_sample_size: num_of_samples passed to both benchmarks.\n",
    "        batch_size: Evaluation batch size.\n",
    "        first_step, last_step: Checkpoint sweep bounds (stride 100).\n",
    "        skip_clean: If True, only the masked model is evaluated.\n",
    "        model_types: [clean_dir, masked_dir] under hf_output/; defaults to the\n",
    "            notebook-level `model_types` selected from `model_size`.\n",
    "\n",
    "    Returns:\n",
    "        (random_clean, natural_clean, random_masked, natural_masked) lists of\n",
    "        per-step results; the clean lists are empty when skip_clean is True.\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        model_types = globals()[\"model_types\"]\n",
    "\n",
    "    dataset_name = \"wikitext\"\n",
    "    dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\", cache_dir=\"~.cache/\")\n",
    "    shared = dict(\n",
    "        steps=range(first_step, last_step, 100),\n",
    "        dataset=dataset,\n",
    "        dataset_name=dataset_name,\n",
    "        random_repetition_seq_len=random_repetition_seq_len,\n",
    "        natural_text_repetition_seq_len=natural_text_repetition_seq_len,\n",
    "        max_sample_size=max_sample_size,\n",
    "        batch_size=batch_size,\n",
    "    )\n",
    "\n",
    "    random_repetition_res_all_clean, natural_text_repetition_res_all_clean = [], []\n",
    "    if not skip_clean:\n",
    "        random_repetition_res_all_clean, natural_text_repetition_res_all_clean = (\n",
    "            _benchmark_checkpoints(model_types[0], **shared)\n",
    "        )\n",
    "    random_repetition_res_all_masked, natural_text_repetition_res_all_masked = (\n",
    "        _benchmark_checkpoints(model_types[1], **shared)\n",
    "    )\n",
    "\n",
    "    return (\n",
    "        random_repetition_res_all_clean,\n",
    "        natural_text_repetition_res_all_clean,\n",
    "        random_repetition_res_all_masked,\n",
    "        natural_text_repetition_res_all_masked,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99f7bd21",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    random_repetition_res_all_clean,\n",
    "    natural_text_repetition_res_all_clean,\n",
    "    random_repetition_res_all_masked,\n",
    "    natural_text_repetition_res_all_masked,\n",
    ") = repetition_tasks_benchmark(\n",
    "    random_repetition_seq_len,\n",
    "    natural_text_repetition_seq_len,\n",
    "    max_sample_size,\n",
    "    batch_size,\n",
    "    first_step,\n",
    "    last_step,\n",
    "    skip_clean\n",
    ")\n",
    "\n",
    "# Persist each result list as its own pickle; the plotting cells read these back.\n",
    "save_dir = \"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "pickles_to_write = [\n",
    "    (f\"random_repetition_masked_{model_types[1]}\", random_repetition_res_all_masked),\n",
    "    (f\"natural_text_repetition_masked_{model_types[1]}\", natural_text_repetition_res_all_masked),\n",
    "]\n",
    "if not skip_clean:\n",
    "    pickles_to_write += [\n",
    "        (f\"natural_text_repetition_clean_{model_types[0]}\", natural_text_repetition_res_all_clean),\n",
    "        (f\"random_repetition_clean_{model_types[0]}\", random_repetition_res_all_clean),\n",
    "    ]\n",
    "\n",
    "for file_stem, results in pickles_to_write:\n",
    "    with open(f\"{save_dir}/{file_stem}.pkl\", \"wb\") as f:\n",
    "        pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd7620c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import os\n",
    "from lib.plotting import apply_iclr_style\n",
    "\n",
    "def plot_repetition(\n",
    "    model_size,\n",
    "    save_dir=\"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\",\n",
    "    output_dir=None,\n",
    "    first_step=100,\n",
    "    last_step=19900,\n",
    "    figsize=(10, 6),\n",
    "    alpha=0.1,\n",
    "    model_types=None,\n",
    "):\n",
    "    \"\"\"Plot clean vs. masked accuracy over training for both repetition tasks.\n",
    "\n",
    "    Reads the result pickles written by the benchmark cell and saves one figure\n",
    "    per task (random / natural-text repetition) as a PDF plus a PNG preview.\n",
    "\n",
    "    Args:\n",
    "        model_size: Shown in the plot titles.\n",
    "        save_dir: Directory containing the result pickles.\n",
    "        output_dir: Figure directory; defaults to {save_dir}/figures.\n",
    "        first_step, last_step: Benchmark sweep range(first_step, last_step, 100);\n",
    "            must match the number of results in each pickle.\n",
    "        figsize: Figure size in inches.\n",
    "        alpha: Opacity of the confidence-interval bands.\n",
    "        model_types: [clean_dir, masked_dir]; defaults to the notebook-level\n",
    "            `model_types`.\n",
    "\n",
    "    Returns:\n",
    "        The natural-text figure (the last one created).\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        model_types = globals()[\"model_types\"]\n",
    "\n",
    "    plt.rcParams.update({\n",
    "        'font.size': 7,           # Smaller base font\n",
    "        'axes.labelsize': 7,      # Smaller labels\n",
    "        'axes.titlesize': 8,      # Smaller title\n",
    "        'xtick.labelsize': 6,     # Smaller tick labels\n",
    "        'ytick.labelsize': 6,\n",
    "        'legend.fontsize': 6,     # Smaller legend\n",
    "        'lines.linewidth': 1.0,   # Thinner lines\n",
    "        'lines.markersize': 3,    # Smaller markers\n",
    "    })\n",
    "    if output_dir is None:\n",
    "        output_dir = f\"{save_dir}/figures\"\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    # Apply ICLR style.\n",
    "    # NOTE(review): this runs after the rcParams.update above and may override those\n",
    "    # font/line sizes - confirm which settings are meant to win.\n",
    "    apply_iclr_style()\n",
    "\n",
    "    def load_results(prefix, model_type):\n",
    "        \"\"\"Load the per-step result list stored in {save_dir}/{prefix}_{model_type}.pkl.\"\"\"\n",
    "        with open(f\"{save_dir}/{prefix}_{model_type}.pkl\", \"rb\") as f:\n",
    "            return pkl.load(f)\n",
    "\n",
    "    def extract_metrics(results_list):\n",
    "        \"\"\"Split result dicts into (accuracy, ci_lower, ci_upper) arrays.\"\"\"\n",
    "        accuracies = [r['accuracy'] for r in results_list]\n",
    "        ci_lower = [r['ci_lower'] for r in results_list]\n",
    "        ci_upper = [r['ci_upper'] for r in results_list]\n",
    "        return np.array(accuracies), np.array(ci_lower), np.array(ci_upper)\n",
    "\n",
    "    print(\"Loading repetition data for combined plot...\")\n",
    "    random_clean = load_results(\"random_repetition_clean\", model_types[0])\n",
    "    random_masked = load_results(\"random_repetition_masked\", model_types[1])\n",
    "    natural_clean = load_results(\"natural_text_repetition_clean\", model_types[0])\n",
    "    natural_masked = load_results(\"natural_text_repetition_masked\", model_types[1])\n",
    "\n",
    "    # One x value per evaluated checkpoint (last_step is exclusive).\n",
    "    steps = list(range(first_step, last_step, 100))\n",
    "\n",
    "    def draw_panel(series, title, file_prefix):\n",
    "        \"\"\"Draw one accuracy-vs-step figure with CI bands; save it as PDF and PNG.\n",
    "\n",
    "        `series` is a list of (label, color, results_list) tuples.\n",
    "        \"\"\"\n",
    "        fig, ax = plt.subplots(figsize=figsize)\n",
    "        ax.set_ylim(0, 1.05)\n",
    "        # Pad past the last checkpoint so the 20000 major tick is visible.\n",
    "        ax.set_xlim(100, 20100)\n",
    "        ax.set_xticks([5000, 10000, 15000, 20000])\n",
    "        # Minor ticks every 1000 steps, starting from 1000.\n",
    "        ax.set_xticks(list(range(1000, 20000, 1000)), minor=True)\n",
    "\n",
    "        for label, color, results_list in series:\n",
    "            acc, lower, upper = extract_metrics(results_list)\n",
    "            if len(acc) != len(steps):\n",
    "                raise ValueError(\n",
    "                    f\"{label}: {len(acc)} results but {len(steps)} steps; \"\n",
    "                    \"check first_step/last_step against the saved pickles\"\n",
    "                )\n",
    "            ax.plot(steps, acc, label=label, color=color)\n",
    "            ax.fill_between(steps, lower, upper, color=color, alpha=alpha)\n",
    "\n",
    "        ax.set_xlabel('Training Steps')\n",
    "        ax.set_ylabel('Accuracy')\n",
    "        ax.set_title(title, pad=15)\n",
    "        ax.legend(loc='best', ncol=1, columnspacing=1.5)\n",
    "        ax.grid(True, axis='y', alpha=0.2, linestyle='-', linewidth=0.5)\n",
    "        ax.set_axisbelow(True)\n",
    "        fig.tight_layout()\n",
    "\n",
    "        save_path = os.path.join(output_dir, f'{file_prefix}_{model_types[0]}_vs_{model_types[1]}.pdf')\n",
    "        fig.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "        print(f\"Saved: {save_path}\")\n",
    "        # Also save a lower-resolution PNG preview next to the PDF.\n",
    "        png_path = os.path.splitext(save_path)[0] + '.png'\n",
    "        fig.savefig(png_path, dpi=150, bbox_inches='tight')\n",
    "        print(f\"Preview: {png_path}\")\n",
    "        return fig\n",
    "\n",
    "    draw_panel(\n",
    "        [('Clean Model', '#0072B2', random_clean),\n",
    "         ('Masked Model', '#E69F00', random_masked)],\n",
    "        f'Random Repetition Performance - {model_size.upper()} Model',\n",
    "        'random_repetition',\n",
    "    )\n",
    "    return draw_panel(\n",
    "        [('Clean Model', '#0072B2', natural_clean),\n",
    "         ('Masked Model', '#E69F00', natural_masked)],\n",
    "        f'Natural Text Repetition Performance - {model_size.upper()} Model',\n",
    "        'natural_text_repetition',\n",
    "    )\n",
    "\n",
    "def print_repetition_summary(\n",
    "    model_size,\n",
    "    save_dir=\"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\",\n",
    "    first_step=100,\n",
    "    last_step=19900,\n",
    "    model_types=None,\n",
    "):\n",
    "    \"\"\"Print final-step and across-training statistics for both repetition tasks.\n",
    "\n",
    "    Reads the pickles written by the benchmark cell:\n",
    "    {task}_clean_{model_types[0]}.pkl and {task}_masked_{model_types[1]}.pkl.\n",
    "    (It previously looked for {task}_{variant}_{model_size}.pkl, which the\n",
    "    benchmark cell never writes.)\n",
    "\n",
    "    Args:\n",
    "        model_size: Shown in the header only.\n",
    "        save_dir: Directory containing the result pickles.\n",
    "        first_step, last_step: Benchmark sweep range(first_step, last_step, 100),\n",
    "            used to map result indices back to training steps.\n",
    "        model_types: [clean_dir, masked_dir, ...]; defaults to the notebook-level\n",
    "            `model_types`.\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        model_types = globals()[\"model_types\"]\n",
    "\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"REPETITION TASK SUMMARY STATISTICS\")\n",
    "    print(\"=\"*50)\n",
    "    print(f\"Model: {model_size}\")\n",
    "    print(f\"Steps: {first_step} to {last_step}\")\n",
    "    print(\"-\"*50)\n",
    "\n",
    "    def load_results(prefix, model_type):\n",
    "        with open(f\"{save_dir}/{prefix}_{model_type}.pkl\", \"rb\") as f:\n",
    "            return pkl.load(f)\n",
    "\n",
    "    # Same file names the benchmark cell writes.\n",
    "    random_clean = load_results(\"random_repetition_clean\", model_types[0])\n",
    "    natural_clean = load_results(\"natural_text_repetition_clean\", model_types[0])\n",
    "    random_masked = load_results(\"random_repetition_masked\", model_types[1])\n",
    "    natural_masked = load_results(\"natural_text_repetition_masked\", model_types[1])\n",
    "\n",
    "    def print_final(title, clean, masked):\n",
    "        \"\"\"Print last-checkpoint accuracy with CI for clean and masked, plus their difference.\"\"\"\n",
    "        print(title)\n",
    "        print(f\"  Clean:  {clean[-1]['accuracy']:.3f} \"\n",
    "              f\"[{clean[-1]['ci_lower']:.3f}, {clean[-1]['ci_upper']:.3f}]\")\n",
    "        print(f\"  Masked: {masked[-1]['accuracy']:.3f} \"\n",
    "              f\"[{masked[-1]['ci_lower']:.3f}, {masked[-1]['ci_upper']:.3f}]\")\n",
    "        print(f\"  Δ:      {masked[-1]['accuracy'] - clean[-1]['accuracy']:+.3f}\")\n",
    "\n",
    "    # Final step statistics: the last checkpoint actually evaluated (stride 100).\n",
    "    final_step = first_step + (len(random_clean) - 1) * 100\n",
    "    print(\"\\nFinal Step Performance (step {}):\\n\".format(final_step))\n",
    "    print_final(\"Random Sequence Repetition:\", random_clean, random_masked)\n",
    "    print_final(\"\\nNatural Text Repetition:\", natural_clean, natural_masked)\n",
    "\n",
    "    # Overall statistics\n",
    "    print(\"\\nOverall Statistics (across all steps):\\n\")\n",
    "    for name, data in [(\"Random (Clean)\", random_clean),\n",
    "                       (\"Random (Masked)\", random_masked),\n",
    "                       (\"Natural (Clean)\", natural_clean),\n",
    "                       (\"Natural (Masked)\", natural_masked)]:\n",
    "        accs = [r['accuracy'] for r in data]\n",
    "        max_idx = int(np.argmax(accs))\n",
    "        max_step = first_step + max_idx * 100\n",
    "        print(f\"{name:20s}: mean={np.mean(accs):.3f} (±{np.std(accs):.3f}), \"\n",
    "              f\"max={np.max(accs):.3f} @ step {max_step}\")\n",
    "\n",
    "    print(\"=\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a09b11",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\"\n",
    "output_dir = f\"{save_dir}/figures\"\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Creating Repetition Task Plots\")\n",
    "print(\"=\"*50)\n",
    "\n",
    "fig1 = plot_repetition(\n",
    "    model_size=model_size,\n",
    "    save_dir=save_dir,\n",
    "    output_dir=output_dir,\n",
    "    first_step=first_step,\n",
    "    last_step=last_step,\n",
    "    figsize=(3.25, 2.5),\n",
    "    alpha=0.3\n",
    ")\n",
    "\n",
    "# Summary statistics (disabled). Uncomment to print final-step and\n",
    "# across-training accuracy for the clean vs. masked model:\n",
    "# print_repetition_summary(\n",
    "#     model_size=model_size,\n",
    "#     save_dir=save_dir,\n",
    "#     first_step=first_step,\n",
    "#     last_step=last_step,\n",
    "# )\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"All repetition plots created successfully!\")\n",
    "print(\"=\"*50)\n",
    "print(f\"\\nOutput directory: {output_dir}/\")\n",
    "print(\"\\nGenerated files:\")\n",
    "# Mirror the save paths built inside plot_repetition (the old list named files\n",
    "# that are never written).\n",
    "for task in (\"random_repetition\", \"natural_text_repetition\"):\n",
    "    print(f\"  - {task}_{model_types[0]}_vs_{model_types[1]}.pdf\")\n",
    "print(\"  (Plus PNG previews for each)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f58b080",
   "metadata": {},
   "source": [
    "# Plot All Three Models Together (Vanilla, Hapax, Hapax Thresholded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae32404",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import os\n",
    "from lib.plotting import apply_iclr_style\n",
    "\n",
    "def plot_repetition(\n",
    "    model_size,\n",
    "    save_dir=\"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\",\n",
    "    output_dir=None,\n",
    "    first_step=100,\n",
    "    last_step=19900,\n",
    "    figsize=(10, 6),\n",
    "    alpha=0.1,\n",
    "    model_types=None,\n",
    "):\n",
    "    \"\"\"Plot Vanilla vs. Hapax vs. Hapax Thresholded accuracy for both repetition tasks.\n",
    "\n",
    "    Has the same call signature as the two-model plot_repetition defined earlier\n",
    "    in the notebook; once this cell runs, this definition replaces it.\n",
    "\n",
    "    Args:\n",
    "        model_size: Unused; kept for call compatibility.\n",
    "        save_dir: Directory containing the result pickles.\n",
    "        output_dir: Figure directory; defaults to {save_dir}/figures.\n",
    "        first_step, last_step: Benchmark sweep range(first_step, last_step, 100);\n",
    "            must match the number of results in each pickle.\n",
    "        figsize: Figure size in inches.\n",
    "        alpha: Opacity of the confidence-interval bands.\n",
    "        model_types: Checkpoint dirs where [0] -> Vanilla, [1] -> Hapax\n",
    "            Thresholded, [2] -> Hapax; defaults to the notebook-level `model_types`.\n",
    "\n",
    "    Returns:\n",
    "        The natural-text figure (the last one created).\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        model_types = globals()[\"model_types\"]\n",
    "\n",
    "    plt.rcParams.update({\n",
    "        'font.size': 7,           # Smaller base font\n",
    "        'axes.labelsize': 7,      # Smaller labels\n",
    "        'axes.titlesize': 8,      # Smaller title\n",
    "        'xtick.labelsize': 6,     # Smaller tick labels\n",
    "        'ytick.labelsize': 6,\n",
    "        'legend.fontsize': 6,     # Smaller legend\n",
    "        'lines.linewidth': 1.0,   # Thinner lines\n",
    "        'lines.markersize': 3,    # Smaller markers\n",
    "    })\n",
    "    if output_dir is None:\n",
    "        output_dir = f\"{save_dir}/figures\"\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    # Apply ICLR style.\n",
    "    # NOTE(review): this runs after the rcParams.update above and may override those\n",
    "    # font/line sizes - confirm which settings are meant to win.\n",
    "    apply_iclr_style()\n",
    "\n",
    "    def load_results(prefix, model_type):\n",
    "        \"\"\"Load the per-step result list stored in {save_dir}/{prefix}_{model_type}.pkl.\"\"\"\n",
    "        with open(f\"{save_dir}/{prefix}_{model_type}.pkl\", \"rb\") as f:\n",
    "            return pkl.load(f)\n",
    "\n",
    "    def extract_metrics(results_list):\n",
    "        \"\"\"Split result dicts into (accuracy, ci_lower, ci_upper) arrays.\"\"\"\n",
    "        accuracies = [r['accuracy'] for r in results_list]\n",
    "        ci_lower = [r['ci_lower'] for r in results_list]\n",
    "        ci_upper = [r['ci_upper'] for r in results_list]\n",
    "        return np.array(accuracies), np.array(ci_lower), np.array(ci_upper)\n",
    "\n",
    "    print(\"Loading repetition data for combined plot...\")\n",
    "    random_clean = load_results(\"random_repetition_clean\", model_types[0])\n",
    "    random_masked = load_results(\"random_repetition_masked\", model_types[2])\n",
    "    natural_clean = load_results(\"natural_text_repetition_clean\", model_types[0])\n",
    "    natural_masked = load_results(\"natural_text_repetition_masked\", model_types[2])\n",
    "    random_masked_thresh = load_results(\"random_repetition_masked\", model_types[1])\n",
    "    natural_masked_thresh = load_results(\"natural_text_repetition_masked\", model_types[1])\n",
    "\n",
    "    # One x value per evaluated checkpoint (last_step is exclusive).\n",
    "    steps = list(range(first_step, last_step, 100))\n",
    "\n",
    "    def draw_panel(series, title, file_prefix):\n",
    "        \"\"\"Draw one accuracy-vs-step figure with CI bands; save it as PDF and PNG.\n",
    "\n",
    "        `series` is a list of (label, color, results_list) tuples.\n",
    "        \"\"\"\n",
    "        fig, ax = plt.subplots(figsize=figsize)\n",
    "        ax.set_ylim(0, 1.05)\n",
    "        # Pad past the last checkpoint so the 20000 major tick is visible.\n",
    "        ax.set_xlim(100, 20100)\n",
    "        ax.set_xticks([5000, 10000, 15000, 20000])\n",
    "        # Minor ticks every 1000 steps, starting from 1000.\n",
    "        ax.set_xticks(list(range(1000, 20000, 1000)), minor=True)\n",
    "\n",
    "        for label, color, results_list in series:\n",
    "            acc, lower, upper = extract_metrics(results_list)\n",
    "            if len(acc) != len(steps):\n",
    "                raise ValueError(\n",
    "                    f\"{label}: {len(acc)} results but {len(steps)} steps; \"\n",
    "                    \"check first_step/last_step against the saved pickles\"\n",
    "                )\n",
    "            ax.plot(steps, acc, label=label, color=color)\n",
    "            ax.fill_between(steps, lower, upper, color=color, alpha=alpha)\n",
    "\n",
    "        ax.set_xlabel('Training Steps')\n",
    "        ax.set_ylabel('Accuracy')\n",
    "        ax.set_title(title, pad=15)\n",
    "        ax.legend(loc='best', ncol=1, columnspacing=1.5)\n",
    "        ax.grid(True, axis='y', alpha=0.2, linestyle='-', linewidth=0.5)\n",
    "        ax.set_axisbelow(True)\n",
    "        fig.tight_layout()\n",
    "\n",
    "        # Name the file after all three model types: the two-type name is what the\n",
    "        # two-model plot writes, so reusing it overwrote that figure.\n",
    "        save_path = os.path.join(\n",
    "            output_dir,\n",
    "            f'{file_prefix}_{model_types[0]}_vs_{model_types[1]}_vs_{model_types[2]}.pdf',\n",
    "        )\n",
    "        fig.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "        print(f\"Saved: {save_path}\")\n",
    "        # Also save a lower-resolution PNG preview next to the PDF.\n",
    "        png_path = os.path.splitext(save_path)[0] + '.png'\n",
    "        fig.savefig(png_path, dpi=150, bbox_inches='tight')\n",
    "        print(f\"Preview: {png_path}\")\n",
    "        return fig\n",
    "\n",
    "    draw_panel(\n",
    "        [('Vanilla', '#0072B2', random_clean),\n",
    "         ('Hapax', '#E69F00', random_masked),\n",
    "         ('Hapax Thresholded', '#40B0A6', random_masked_thresh)],\n",
    "        'Random Repetition Performance',\n",
    "        'random_repetition',\n",
    "    )\n",
    "    return draw_panel(\n",
    "        [('Vanilla', '#0072B2', natural_clean),\n",
    "         ('Hapax', '#E69F00', natural_masked),\n",
    "         ('Hapax Thresholded', '#40B0A6', natural_masked_thresh)],\n",
    "        'Natural Text Repetition Performance',\n",
    "        'natural_text_repetition',\n",
    "    )\n",
    "\n",
    "def print_repetition_summary(\n",
    "    model_size,\n",
    "    save_dir=\"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\",\n",
    "    first_step=100,\n",
    "    last_step=19900,\n",
    "    model_types=None,\n",
    "):\n",
    "    \"\"\"Print final-step and across-training statistics for both repetition tasks.\n",
    "\n",
    "    Reads the pickles written by the benchmark cell:\n",
    "    {task}_clean_{model_types[0]}.pkl and {task}_masked_{model_types[1]}.pkl.\n",
    "    (It previously looked for {task}_{variant}_{model_size}.pkl, which the\n",
    "    benchmark cell never writes.)\n",
    "\n",
    "    Args:\n",
    "        model_size: Shown in the header only.\n",
    "        save_dir: Directory containing the result pickles.\n",
    "        first_step, last_step: Benchmark sweep range(first_step, last_step, 100),\n",
    "            used to map result indices back to training steps.\n",
    "        model_types: [clean_dir, masked_dir, ...]; defaults to the notebook-level\n",
    "            `model_types`.\n",
    "    \"\"\"\n",
    "    if model_types is None:\n",
    "        model_types = globals()[\"model_types\"]\n",
    "\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"REPETITION TASK SUMMARY STATISTICS\")\n",
    "    print(\"=\"*50)\n",
    "    print(f\"Model: {model_size}\")\n",
    "    print(f\"Steps: {first_step} to {last_step}\")\n",
    "    print(\"-\"*50)\n",
    "\n",
    "    def load_results(prefix, model_type):\n",
    "        with open(f\"{save_dir}/{prefix}_{model_type}.pkl\", \"rb\") as f:\n",
    "            return pkl.load(f)\n",
    "\n",
    "    # Same file names the benchmark cell writes.\n",
    "    random_clean = load_results(\"random_repetition_clean\", model_types[0])\n",
    "    natural_clean = load_results(\"natural_text_repetition_clean\", model_types[0])\n",
    "    random_masked = load_results(\"random_repetition_masked\", model_types[1])\n",
    "    natural_masked = load_results(\"natural_text_repetition_masked\", model_types[1])\n",
    "\n",
    "    def print_final(title, clean, masked):\n",
    "        \"\"\"Print last-checkpoint accuracy with CI for clean and masked, plus their difference.\"\"\"\n",
    "        print(title)\n",
    "        print(f\"  Clean:  {clean[-1]['accuracy']:.3f} \"\n",
    "              f\"[{clean[-1]['ci_lower']:.3f}, {clean[-1]['ci_upper']:.3f}]\")\n",
    "        print(f\"  Masked: {masked[-1]['accuracy']:.3f} \"\n",
    "              f\"[{masked[-1]['ci_lower']:.3f}, {masked[-1]['ci_upper']:.3f}]\")\n",
    "        print(f\"  Δ:      {masked[-1]['accuracy'] - clean[-1]['accuracy']:+.3f}\")\n",
    "\n",
    "    # Final step statistics: the last checkpoint actually evaluated (stride 100).\n",
    "    final_step = first_step + (len(random_clean) - 1) * 100\n",
    "    print(\"\\nFinal Step Performance (step {}):\\n\".format(final_step))\n",
    "    print_final(\"Random Sequence Repetition:\", random_clean, random_masked)\n",
    "    print_final(\"\\nNatural Text Repetition:\", natural_clean, natural_masked)\n",
    "\n",
    "    # Overall statistics\n",
    "    print(\"\\nOverall Statistics (across all steps):\\n\")\n",
    "    for name, data in [(\"Random (Clean)\", random_clean),\n",
    "                       (\"Random (Masked)\", random_masked),\n",
    "                       (\"Natural (Clean)\", natural_clean),\n",
    "                       (\"Natural (Masked)\", natural_masked)]:\n",
    "        accs = [r['accuracy'] for r in data]\n",
    "        max_idx = int(np.argmax(accs))\n",
    "        max_step = first_step + max_idx * 100\n",
    "        print(f\"{name:20s}: mean={np.mean(accs):.3f} (±{np.std(accs):.3f}), \"\n",
    "              f\"max={np.max(accs):.3f} @ step {max_step}\")\n",
    "\n",
    "    print(\"=\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2820715c",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"~pythia_replicate/metrics/non_wandb_metrics/repetition_metrics\"\n",
    "output_dir = f\"{save_dir}/figures\"\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"Creating Repetition Task Plots\")\n",
    "print(\"=\"*50)\n",
    "\n",
    "# Three-way comparison, indexed as [clean, thresholded masked, masked].\n",
    "# NOTE: this overwrites the `model_types` chosen from `model_size` earlier and is\n",
    "# hard-coded for the 1B runs; re-running earlier cells afterwards sees this list.\n",
    "model_types = [\"clean_1b\", \"masked_bigram_loss_1b_thresh0.3_eq\", \"masked_bigram_loss_1b\"]\n",
    "\n",
    "fig1 = plot_repetition(\n",
    "    model_size=model_size,\n",
    "    save_dir=save_dir,\n",
    "    output_dir=output_dir,\n",
    "    first_step=first_step,\n",
    "    last_step=last_step,\n",
    "    figsize=(3.25, 2.5),\n",
    "    alpha=0.3\n",
    ")\n",
    "\n",
    "# Summary statistics (disabled). Uncomment to print final-step and\n",
    "# across-training accuracy for model_types[0] vs. model_types[1]:\n",
    "# print_repetition_summary(\n",
    "#     model_size=model_size,\n",
    "#     save_dir=save_dir,\n",
    "#     first_step=first_step,\n",
    "#     last_step=last_step,\n",
    "# )\n",
    "\n",
    "print(\"\\n\" + \"=\"*50)\n",
    "print(\"All repetition plots created successfully!\")\n",
    "print(\"=\"*50)\n",
    "print(f\"\\nOutput directory: {output_dir}/\")\n",
    "# plot_repetition prints each exact PDF/PNG path as it saves it; the old hard-coded\n",
    "# list here named files that are never written.\n",
    "print(\"\\nGenerated files: see the 'Saved:' and 'Preview:' lines above.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythia_replicate",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
