{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "965f933f",
   "metadata": {},
   "source": [
    "# IMPORTS "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbfe65ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7afb747e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36dbdc6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set CUDA_VISIBLE_DEVICES to \"7\" for this process (hard-coded GPU selection; adjust per machine).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"7\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e7eb2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from collections import defaultdict\n",
    "import os\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "# Choose the checkpoint location from a substring of the current working directory.\n",
    "cwd_str = str(Path.cwd())\n",
    "if \"disk\" in cwd_str:\n",
    "    file_system = \"disk\"\n",
    "elif \"share\" in cwd_str:\n",
    "    file_system = \"share\"\n",
    "else:\n",
    "    # Without this branch file_system stays undefined and the next statement\n",
    "    # fails with an unhelpful NameError.\n",
    "    raise RuntimeError(\n",
    "        f\"Cannot infer the file system from the working directory {cwd_str!r}: \"\n",
    "        \"expected the path to contain 'disk' or 'share'.\"\n",
    "    )\n",
    "\n",
    "if file_system == \"disk\":\n",
    "    model_location = \"pythia_replicate/hf_output\"\n",
    "else:\n",
    "    model_location = \"pythia_replicate_public_models\"\n",
    "\n",
    "# sys.path entries are used verbatim (no shell-style '~' expansion), so the\n",
    "# home directory has to be expanded explicitly for the lib.* imports to resolve.\n",
    "sys.path.append(os.path.expanduser(\"~/pythia_replicate\"))\n",
    "\n",
    "from lib.model_setup import load_model_and_tokenizer\n",
    "from lib.fv import fv_icl_tasks_benchmark_with_ci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e876df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: edit these values, then re-run the cells below.\n",
    "skip_clean = True  # when True the clean model is not loaded, evaluated or saved\n",
    "model_size = \"1b-threshold-0.3\"  # selects the (clean, masked) model_types pair\n",
    "first_step = 100  # not referenced elsewhere in the visible part of the notebook\n",
    "step = 19900  # checkpoint step; used in the model paths and result file names\n",
    "n_shots = 5  # passed to the benchmark as num_of_shots\n",
    "debug = True  # not referenced elsewhere in the visible part of the notebook\n",
    "batch_size = 32  # passed to the benchmark as batch_size\n",
    "max_sample_size = 5000  # passed to the benchmark as max_sample_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a03a20d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_types[0] is the clean checkpoint directory, model_types[1] the masked one.\n",
    "if model_size == \"160m\":\n",
    "    model_types = [\"clean_v3\", \"masked_bigram_loss_v4\"]\n",
    "elif model_size == \"160m-threshold-0.3\":\n",
    "    model_types = [\"clean_v3\", \"masked_bigram_loss_thresh0.3_eq\"]\n",
    "elif model_size == \"1b\":\n",
    "    model_types = [\"clean_1b\", \"masked_bigram_loss_1b\"]\n",
    "elif model_size == \"1b-threshold-0.3\":\n",
    "    model_types = [\"clean_1b\", \"masked_bigram_loss_1b_thresh0.3_eq\"]\n",
    "else:\n",
    "    # Fail here with a clear message instead of a NameError on model_types below.\n",
    "    raise ValueError(f\"Unsupported model_size: {model_size!r}\")\n",
    "\n",
    "# NOTE(review): both paths start with '~'; confirm load_model_and_tokenizer expands it.\n",
    "if not skip_clean:\n",
    "    clean_model_path = f\"~/{model_location}/{model_types[0]}/step={step}\"\n",
    "    clean_model, tokenizer_clean = load_model_and_tokenizer(clean_model_path)\n",
    "\n",
    "# The path has to exist before it is handed to the loader; the two statements\n",
    "# were in the opposite order, which raises NameError on a fresh kernel.\n",
    "masked_model_path = f\"~/{model_location}/{model_types[1]}/step={step}\"\n",
    "masked_model, tokenizer_masked = load_model_and_tokenizer(masked_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6699dc24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.listdir does not expand '~', so resolve the dataset root explicitly\n",
    "# (same '~/pythia_replicate' root that is appended to sys.path above).\n",
    "icl_tasks_root = os.path.expanduser(\"~/pythia_replicate/dataset/icl_tasks\")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort for a reproducible task order.\n",
    "abstractive_tasks_list = sorted(os.listdir(os.path.join(icl_tasks_root, \"abstractive\")))\n",
    "extractive_tasks_list = sorted(os.listdir(os.path.join(icl_tasks_root, \"extractive\")))\n",
    "\n",
    "# Task identifiers are \"<kind>/<file name without the .json extension>\".\n",
    "extractive_tasks = [f\"extractive/{f[:-5]}\" for f in extractive_tasks_list if f.endswith('.json')]\n",
    "abstractive_tasks = [f\"abstractive/{f[:-5]}\" for f in abstractive_tasks_list if f.endswith('.json')]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c1e0b74",
   "metadata": {},
   "source": [
    "# EXECUTE TASKS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d9b971c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fv_tasks_benchmark(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    task_names=None,\n",
    "    num_of_shots=5,\n",
    "    max_sample_size=5000,\n",
    "    batch_size=64,\n",
    "    return_per_sample=True,\n",
    "):\n",
    "    \"\"\"Run the ICL benchmark on every task and collect the results per task.\n",
    "\n",
    "    Args:\n",
    "        model: forwarded unchanged to fv_icl_tasks_benchmark_with_ci.\n",
    "        tokenizer: forwarded unchanged to fv_icl_tasks_benchmark_with_ci.\n",
    "        task_names: iterable of task identifiers such as 'extractive/<name>';\n",
    "            None (the default) is treated as an empty collection.\n",
    "        num_of_shots: forwarded as num_of_shots.\n",
    "        max_sample_size: forwarded as max_samples.\n",
    "        batch_size: forwarded as batch_size.\n",
    "        return_per_sample: forwarded as return_per_sample.\n",
    "\n",
    "    Returns:\n",
    "        (icl_task_results, icl_task_results_per_sample): two dicts keyed by the\n",
    "        last '/'-separated component of each task name.\n",
    "    \"\"\"\n",
    "    # The default task_names=None used to crash the loop with a TypeError.\n",
    "    if task_names is None:\n",
    "        task_names = ()\n",
    "\n",
    "    icl_task_results = {}\n",
    "    icl_task_results_per_sample = {}\n",
    "    for task_name in task_names:\n",
    "        # NOTE(review): the two-value unpacking assumes the helper returns a pair;\n",
    "        # confirm that this also holds when return_per_sample is False.\n",
    "        icl_results, per_sample_correct = fv_icl_tasks_benchmark_with_ci(\n",
    "            model,\n",
    "            tokenizer,\n",
    "            task_name=task_name,\n",
    "            num_of_shots=num_of_shots,\n",
    "            max_samples=max_sample_size,\n",
    "            batch_size=batch_size,\n",
    "            use_wandb=False,\n",
    "            return_per_sample=return_per_sample,\n",
    "        )\n",
    "        # Key by the bare task name without rebinding the loop variable.\n",
    "        short_name = task_name.split(\"/\")[-1]\n",
    "        icl_task_results[short_name] = icl_results\n",
    "        icl_task_results_per_sample[short_name] = per_sample_correct\n",
    "    return icl_task_results, icl_task_results_per_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e07b88a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments shared by every benchmark call in this cell.\n",
    "shared_benchmark_args = dict(\n",
    "    num_of_shots=n_shots,\n",
    "    max_sample_size=max_sample_size,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "# The clean model is evaluated only when it was loaded (skip_clean is False).\n",
    "if not skip_clean:\n",
    "    extractive_results_clean, extractive_results_clean_per_sample = fv_tasks_benchmark(\n",
    "        clean_model, tokenizer_clean, task_names=extractive_tasks, **shared_benchmark_args\n",
    "    )\n",
    "    abstractive_results_clean, abstractive_results_clean_per_sample = fv_tasks_benchmark(\n",
    "        clean_model, tokenizer_clean, task_names=abstractive_tasks, **shared_benchmark_args\n",
    "    )\n",
    "\n",
    "extractive_results_masked, extractive_results_masked_per_sample = fv_tasks_benchmark(\n",
    "    masked_model, tokenizer_masked, task_names=extractive_tasks, **shared_benchmark_args\n",
    ")\n",
    "abstractive_results_masked, abstractive_results_masked_per_sample = fv_tasks_benchmark(\n",
    "    masked_model, tokenizer_masked, task_names=abstractive_tasks, **shared_benchmark_args\n",
    ")\n",
    "\n",
    "# NOTE(review): os.makedirs/open do not expand '~'; this path is taken literally,\n",
    "# relative to the current working directory.\n",
    "save_dir = \"~pythia_replicate/metrics/non_wandb_metrics/fv_accuracies\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# (file name, payload) pairs, written in order: clean results first, then masked.\n",
    "result_files = []\n",
    "if not skip_clean:\n",
    "    result_files += [\n",
    "        (f\"extractive_results_{model_types[0]}_{step}_{n_shots}.json\", extractive_results_clean),\n",
    "        (f\"abstractive_results_{model_types[0]}_{step}_{n_shots}.json\", abstractive_results_clean),\n",
    "        (f\"extractive_results_{model_types[0]}_{model_size}_{step}_{n_shots}_per_sample.json\", extractive_results_clean_per_sample),\n",
    "        (f\"abstractive_results_{model_types[0]}_{model_size}_{step}_{n_shots}_per_sample.json\", abstractive_results_clean_per_sample),\n",
    "    ]\n",
    "result_files += [\n",
    "    (f\"extractive_results_{model_types[1]}_{step}_{n_shots}.json\", extractive_results_masked),\n",
    "    (f\"abstractive_results_{model_types[1]}_{step}_{n_shots}.json\", abstractive_results_masked),\n",
    "    (f\"extractive_results_{model_types[1]}_{model_size}_{step}_{n_shots}_per_sample.json\", extractive_results_masked_per_sample),\n",
    "    (f\"abstractive_results_{model_types[1]}_{model_size}_{step}_{n_shots}_per_sample.json\", abstractive_results_masked_per_sample),\n",
    "]\n",
    "\n",
    "for file_name, payload in result_files:\n",
    "    with open(os.path.join(save_dir, file_name), \"w\") as f:\n",
    "        json.dump(payload, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c26056c",
   "metadata": {},
   "source": [
    "## GRAPH THE RESULTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd1d496c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib.patches import Rectangle\n",
    "import os\n",
    "from lib.plotting import apply_iclr_style\n",
    "\n",
    "def load_results(filepath):\n",
    "    \"\"\"Read the JSON file at filepath and return its parsed contents.\"\"\"\n",
    "    with open(filepath, 'r') as handle:\n",
    "        parsed = json.load(handle)\n",
    "    return parsed\n",
    "\n",
    "def group_extractive_tasks(results):\n",
    "    \"\"\"Return the extractive task categories with their task names and colors.\n",
    "\n",
    "    All categories list fixed task names except 'Odd One Out', which collects\n",
    "    every key of results that contains '_v_'.\n",
    "    \"\"\"\n",
    "    odd_one_out_tasks = [key for key in results if '_v_' in key]\n",
    "\n",
    "    # (category title, task names, bar color) in display order.\n",
    "    category_specs = (\n",
    "        ('Positional Selection',\n",
    "         ['choose_first_of_3', 'choose_first_of_5',\n",
    "          'choose_middle_of_3', 'choose_middle_of_5',\n",
    "          'choose_last_of_3', 'choose_last_of_5'],\n",
    "         '#2E7D32'),  # Green\n",
    "        ('Alphabetical Ordering',\n",
    "         ['alphabetically_first_3', 'alphabetically_first_5',\n",
    "          'alphabetically_last_3', 'alphabetically_last_5'],\n",
    "         '#1565C0'),  # Blue\n",
    "        ('Named Entity Recognition',\n",
    "         ['conll2003_person', 'conll2003_location',\n",
    "          'conll2003_organization'],\n",
    "         '#E65100'),  # Orange\n",
    "        ('Odd One Out', odd_one_out_tasks, '#6A1B9A'),  # Purple\n",
    "        ('Reading Comprehension', ['squad_val'], '#B71C1C'),  # Red\n",
    "    )\n",
    "\n",
    "    return {\n",
    "        title: {'tasks': task_names, 'color': hex_color}\n",
    "        for title, task_names, hex_color in category_specs\n",
    "    }\n",
    "\n",
    "def plot_task_category(clean_data, masked_data, category_name, category_info,\n",
    "                       save_path=None, figsize=None):\n",
    "    \"\"\"Plot clean vs. masked accuracy bars for the tasks of one category.\n",
    "\n",
    "    Args:\n",
    "        clean_data: dict mapping task name -> dict with 'accuracy',\n",
    "            'ci_lower' and 'ci_upper' entries (clean model).\n",
    "        masked_data: same structure for the masked model.\n",
    "        category_name: category title used in the plot title and messages.\n",
    "        category_info: dict whose 'tasks' entry lists candidate task names.\n",
    "        save_path: optional output file; the figure is saved there at 300 dpi\n",
    "            and a 150 dpi PNG preview is written alongside it.\n",
    "        figsize: optional (width, height); derived from the task count if None.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure, or None when no task of the category has\n",
    "        results in both clean_data and masked_data.\n",
    "    \"\"\"\n",
    "    \n",
    "    apply_iclr_style()\n",
    "    \n",
    "    # Keep only tasks that have results for BOTH models: indexing masked_data with\n",
    "    # a task that exists only in clean_data used to raise a KeyError.\n",
    "    candidate_tasks = [t for t in category_info['tasks'] if t in clean_data]\n",
    "    tasks = [t for t in candidate_tasks if t in masked_data]\n",
    "    missing_masked = [t for t in candidate_tasks if t not in masked_data]\n",
    "    if missing_masked:\n",
    "        print(f\"  Skipping tasks without masked results: {missing_masked}\")\n",
    "    if not tasks:\n",
    "        print(f\"No tasks found for {category_name}\")\n",
    "        return None\n",
    "    \n",
    "    # Determine figure size based on number of tasks\n",
    "    if figsize is None:\n",
    "        if len(tasks) <= 6:\n",
    "            figsize = (8, 5)\n",
    "        elif len(tasks) <= 10:\n",
    "            figsize = (10, 5)\n",
    "        else:\n",
    "            figsize = (12, 5)\n",
    "    \n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    \n",
    "    # Turn each task identifier into a human-readable tick label\n",
    "    task_labels = []\n",
    "    for task in tasks:\n",
    "        if 'choose' in task:\n",
    "            parts = task.replace('choose_', '').replace('_of_', ' of ').replace('_', ' ')\n",
    "            label = parts.title()\n",
    "        elif 'alphabetically' in task:\n",
    "            parts = task.replace('alphabetically_', '').replace('_', ' ')\n",
    "            label = f\"Alphabetically {parts.title()}\"\n",
    "        elif 'conll2003' in task:\n",
    "            label = f\"CoNLL: {task.replace('conll2003_', '').title()}\"\n",
    "        elif '_v_' in task:\n",
    "            # '<left>_v_<right>_<num>' -> '<Left> vs <Right> (<num>)'\n",
    "            parts = task.split('_v_')\n",
    "            num = task.split('_')[-1]\n",
    "            label = f\"{parts[0].title()} vs {parts[1].replace(f'_{num}', '').title()} ({num})\"\n",
    "        elif task == 'squad_val':\n",
    "            label = 'SQuAD Validation'\n",
    "        else:\n",
    "            label = task.replace('_', ' ').title()\n",
    "        task_labels.append(label)\n",
    "    \n",
    "    # Accuracies plus asymmetric (lower, upper) error-bar lengths per task\n",
    "    clean_accuracies = [clean_data[task]['accuracy'] for task in tasks]\n",
    "    clean_errors = [(clean_data[task]['accuracy'] - clean_data[task]['ci_lower'],\n",
    "                     clean_data[task]['ci_upper'] - clean_data[task]['accuracy'])\n",
    "                    for task in tasks]\n",
    "    \n",
    "    masked_accuracies = [masked_data[task]['accuracy'] for task in tasks]\n",
    "    masked_errors = [(masked_data[task]['accuracy'] - masked_data[task]['ci_lower'],\n",
    "                      masked_data[task]['ci_upper'] - masked_data[task]['accuracy'])\n",
    "                     for task in tasks]\n",
    "    \n",
    "    # Set up bar positions\n",
    "    x = np.arange(len(tasks))\n",
    "    width = 0.35\n",
    "    \n",
    "    # Create bars with error bars (yerr is transposed to shape (2, n_tasks))\n",
    "    bars1 = ax.bar(x - width/2, clean_accuracies, width,\n",
    "                   yerr=np.array(clean_errors).T,\n",
    "                   label='Clean', color='#0173B2',\n",
    "                   capsize=3, alpha=0.9, edgecolor='black', linewidth=0.5)\n",
    "    \n",
    "    bars2 = ax.bar(x + width/2, masked_accuracies, width,\n",
    "                   yerr=np.array(masked_errors).T,\n",
    "                   label='Masked', color='#DE8F05',\n",
    "                   capsize=3, alpha=0.9, edgecolor='black', linewidth=0.5)\n",
    "    \n",
    "    # Add value labels on bars (above error bars)\n",
    "    def add_value_labels(bars, values, errors):\n",
    "        for bar, val, err in zip(bars, values, errors):\n",
    "            height = bar.get_height()\n",
    "            # Position text above the error bar\n",
    "            error_top = err[1] if isinstance(err, tuple) else err\n",
    "            if height > 0.05:  # Only show label if bar is visible\n",
    "                ax.text(bar.get_x() + bar.get_width()/2., height + error_top + 0.02,\n",
    "                       f'{val:.3f}' if val < 0.1 else f'{val:.2f}',\n",
    "                       ha='center', va='bottom', fontsize=7)\n",
    "    \n",
    "    add_value_labels(bars1, clean_accuracies, clean_errors)\n",
    "    add_value_labels(bars2, masked_accuracies, masked_errors)\n",
    "    \n",
    "    # Customize plot\n",
    "    ax.set_xlabel('Task', fontsize=11, fontweight='medium')\n",
    "    ax.set_ylabel('Accuracy', fontsize=11, fontweight='medium')\n",
    "    ax.set_title(f'{category_name} Tasks',\n",
    "                 fontsize=12, fontweight='bold', pad=10)\n",
    "    \n",
    "    # Set x-axis labels\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(task_labels, rotation=45, ha='right', fontsize=9)\n",
    "    \n",
    "    # Set y-axis limits with extra padding for labels\n",
    "    ax.set_ylim(0, 1.1)  # Extra space for labels above error bars\n",
    "    \n",
    "    # Add legend\n",
    "    ax.legend(loc='upper right', fontsize=9)\n",
    "    \n",
    "    # Add grid\n",
    "    ax.grid(True, axis='y', alpha=0.2, linestyle='-', linewidth=0.5)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['left'].set_linewidth(0.8)\n",
    "    ax.spines['bottom'].set_linewidth(0.8)\n",
    "    \n",
    "    # Optional, currently disabled: sample sizes below the axis.\n",
    "    # for i, task in enumerate(tasks):\n",
    "    #     n_samples = clean_data[task]['n_samples']\n",
    "    #     ax.text(i, -0.18, f'n={n_samples}', ha='center', va='top',\n",
    "    #             fontsize=7, color='gray', transform=ax.get_xaxis_transform())\n",
    "    \n",
    "    # Adjust layout\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    # Save if path provided\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "        print(f\"  Saved: {save_path}\")\n",
    "        \n",
    "        # Also save a PNG preview. The name is derived from the file extension so\n",
    "        # that a save_path not ending in '.pdf' is not overwritten by the\n",
    "        # lower-resolution preview.\n",
    "        png_path = os.path.splitext(save_path)[0] + '.png'\n",
    "        if png_path != save_path:\n",
    "            plt.savefig(png_path, dpi=150, bbox_inches='tight')\n",
    "            print(f\"  Preview: {png_path}\")\n",
    "    \n",
    "    return fig\n",
    "\n",
    "def plot_all_extractive_categories(clean_data, masked_data, output_dir='figures'):\n",
    "    \"\"\"Render one clean-vs-masked figure per extractive category into output_dir.\n",
    "\n",
    "    Returns a dict mapping each category name to whatever plot_task_category\n",
    "    returned for it.\n",
    "    \"\"\"\n",
    "    print(\"Creating extractive task plots by category...\")\n",
    "\n",
    "    # The destination must exist before any figure is saved.\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    figures = {}\n",
    "    for title, info in group_extractive_tasks(clean_data).items():\n",
    "        # File name: lower-cased title with spaces and newlines as underscores.\n",
    "        slug = title.lower().replace(' ', '_').replace('\\n', '_')\n",
    "        pdf_path = os.path.join(output_dir, f'extractive_{slug}.pdf')\n",
    "\n",
    "        n_present = len([t for t in info['tasks'] if t in clean_data])\n",
    "        print(f\"\\nProcessing {title}:\")\n",
    "        print(f\"  Tasks: {n_present}\")\n",
    "\n",
    "        figures[title] = plot_task_category(\n",
    "            clean_data, masked_data, title, info, save_path=pdf_path\n",
    "        )\n",
    "\n",
    "    return figures\n",
    "\n",
    "def plot_abstractive_results(clean_data, masked_data, save_path='abstractive_results_comparison.pdf'):\n",
    "    \"\"\"Plot clean vs. masked accuracy bars for the abstractive tasks.\n",
    "\n",
    "    Args:\n",
    "        clean_data: dict mapping task name -> dict with 'accuracy',\n",
    "            'ci_lower' and 'ci_upper' entries (clean model).\n",
    "        masked_data: same structure for the masked model.\n",
    "        save_path: output file, saved at 300 dpi; a 150 dpi PNG preview is\n",
    "            written alongside it.\n",
    "\n",
    "    Returns:\n",
    "        (fig, ax), or (None, None) when no task has results in both inputs.\n",
    "    \"\"\"\n",
    "    \n",
    "    apply_iclr_style()\n",
    "    \n",
    "    # Define task display names\n",
    "    task_display_names = {\n",
    "        'capitalize_first_letter': 'Capitalize\\nFirst Letter',\n",
    "        'country-capital': 'Country-\\nCapital',\n",
    "        'synonym': 'Synonym',\n",
    "        'antonym': 'Antonym',\n",
    "        'capitalize': 'Capitalize\\nAll',\n",
    "        'national_parks': 'National\\nParks'\n",
    "    }\n",
    "    \n",
    "    # Colorblind-friendly colors\n",
    "    colors = {\n",
    "        'clean': '#0173B2',  # Blue\n",
    "        'masked': '#DE8F05',  # Orange\n",
    "    }\n",
    "    \n",
    "    # Plot only tasks present in BOTH result sets: indexing masked_data with a\n",
    "    # task that exists only in clean_data used to raise a KeyError.\n",
    "    tasks = [task for task in clean_data if task in masked_data]\n",
    "    missing_masked = [task for task in clean_data if task not in masked_data]\n",
    "    if missing_masked:\n",
    "        print(f\"Skipping tasks without masked results: {missing_masked}\")\n",
    "    if not tasks:\n",
    "        # Nothing to draw; the max() below would fail on an empty sequence.\n",
    "        print(\"No abstractive tasks found\")\n",
    "        return None, None\n",
    "    n_tasks = len(tasks)\n",
    "    \n",
    "    # Accuracies plus asymmetric (lower, upper) error-bar lengths per task\n",
    "    clean_accuracies = [clean_data[task]['accuracy'] for task in tasks]\n",
    "    clean_errors = [(clean_data[task]['accuracy'] - clean_data[task]['ci_lower'],\n",
    "                     clean_data[task]['ci_upper'] - clean_data[task]['accuracy'])\n",
    "                    for task in tasks]\n",
    "    \n",
    "    masked_accuracies = [masked_data[task]['accuracy'] for task in tasks]\n",
    "    masked_errors = [(masked_data[task]['accuracy'] - masked_data[task]['ci_lower'],\n",
    "                      masked_data[task]['ci_upper'] - masked_data[task]['accuracy'])\n",
    "                     for task in tasks]\n",
    "    \n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "    \n",
    "    # Set up bar positions\n",
    "    x = np.arange(n_tasks)\n",
    "    width = 0.35\n",
    "    \n",
    "    # Create bars with error bars (yerr is transposed to shape (2, n_tasks))\n",
    "    bars1 = ax.bar(x - width/2, clean_accuracies, width,\n",
    "                   yerr=np.array(clean_errors).T,\n",
    "                   label='Clean', color=colors['clean'],\n",
    "                   capsize=4, alpha=0.9, edgecolor='black', linewidth=0.5)\n",
    "    \n",
    "    bars2 = ax.bar(x + width/2, masked_accuracies, width,\n",
    "                   yerr=np.array(masked_errors).T,\n",
    "                   label='Masked', color=colors['masked'],\n",
    "                   capsize=4, alpha=0.9, edgecolor='black', linewidth=0.5)\n",
    "    \n",
    "    # Add value labels on bars (above error bars)\n",
    "    def add_value_labels(bars, values, errors):\n",
    "        for bar, val, err in zip(bars, values, errors):\n",
    "            height = bar.get_height()\n",
    "            # Position text above the error bar\n",
    "            error_top = err[1] if isinstance(err, tuple) else err\n",
    "            if height > 0.05:\n",
    "                ax.text(bar.get_x() + bar.get_width()/2., height + error_top + 0.02,\n",
    "                       f'{val:.3f}' if val < 0.1 else f'{val:.2f}',\n",
    "                       ha='center', va='bottom', fontsize=8)\n",
    "    \n",
    "    add_value_labels(bars1, clean_accuracies, clean_errors)\n",
    "    add_value_labels(bars2, masked_accuracies, masked_errors)\n",
    "    \n",
    "    # Customize plot\n",
    "    ax.set_xlabel('Task', fontsize=12, fontweight='medium')\n",
    "    ax.set_ylabel('Accuracy', fontsize=12, fontweight='medium')\n",
    "    ax.set_title('Abstractive Task Performance',\n",
    "                 fontsize=13, fontweight='bold', pad=15)\n",
    "    \n",
    "    # Set x-axis labels\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels([task_display_names.get(task, task) for task in tasks])\n",
    "    \n",
    "    # Set y-axis limits with extra padding for labels above error bars\n",
    "    max_y = max(max(clean_data[task]['ci_upper'], masked_data[task]['ci_upper'])\n",
    "                for task in tasks)\n",
    "    # 15% headroom; fall back to 1.0 when every upper bound is 0, because\n",
    "    # set_ylim(0, 0) would produce a degenerate axis.\n",
    "    ax.set_ylim(0, max_y * 1.15 if max_y > 0 else 1.0)\n",
    "    \n",
    "    # Add legend\n",
    "    ax.legend(loc='upper left', fontsize=10)\n",
    "    \n",
    "    # Fine-tune grid\n",
    "    ax.grid(True, axis='y', alpha=0.2, linestyle='-', linewidth=0.5)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['left'].set_linewidth(0.8)\n",
    "    ax.spines['bottom'].set_linewidth(0.8)\n",
    "    \n",
    "    # Adjust layout\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    # Save figure\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Saved: {save_path}\")\n",
    "    \n",
    "    # Also save a PNG preview. The name is derived from the file extension so that\n",
    "    # a save_path not ending in '.pdf' is not overwritten by the preview.\n",
    "    png_path = os.path.splitext(save_path)[0] + '.png'\n",
    "    if png_path != save_path:\n",
    "        plt.savefig(png_path, dpi=150, bbox_inches='tight')\n",
    "        print(f\"Preview: {png_path}\")\n",
    "    \n",
    "    return fig, ax\n",
    "\n",
    "def create_summary_figure(clean_data_extractive, masked_data_extractive,\n",
    "                         clean_data_abstractive, masked_data_abstractive,\n",
    "                         save_path='all_tasks_summary.pdf'):\n",
    "    \"\"\"Draw a two-panel overview: extractive category means and abstractive tasks.\n",
    "\n",
    "    Args:\n",
    "        clean_data_extractive / masked_data_extractive: dicts mapping task name\n",
    "            -> dict with an 'accuracy' entry, for the extractive tasks.\n",
    "        clean_data_abstractive / masked_data_abstractive: same structure for\n",
    "            the abstractive tasks.\n",
    "        save_path: output file, saved at 300 dpi.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure.\n",
    "    \"\"\"\n",
    "    \n",
    "    apply_iclr_style()\n",
    "    \n",
    "    groups = group_extractive_tasks(clean_data_extractive)\n",
    "    \n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "    \n",
    "    # Left panel: mean accuracy per extractive category\n",
    "    group_stats = []\n",
    "    group_names = []\n",
    "    for group_name, group_info in groups.items():\n",
    "        # Average only over tasks present in BOTH result sets; a task missing\n",
    "        # from the masked results used to raise a KeyError.\n",
    "        tasks = [t for t in group_info['tasks']\n",
    "                 if t in clean_data_extractive and t in masked_data_extractive]\n",
    "        if tasks:\n",
    "            clean_mean = np.mean([clean_data_extractive[t]['accuracy'] for t in tasks])\n",
    "            masked_mean = np.mean([masked_data_extractive[t]['accuracy'] for t in tasks])\n",
    "            group_stats.append((clean_mean, masked_mean))\n",
    "            group_names.append(group_name.replace('\\n', ' '))\n",
    "    \n",
    "    x1 = np.arange(len(group_names))\n",
    "    width = 0.35\n",
    "    clean_means = [s[0] for s in group_stats]\n",
    "    masked_means = [s[1] for s in group_stats]\n",
    "    \n",
    "    ax1.bar(x1 - width/2, clean_means, width, label='Clean', color='#0173B2', alpha=0.9)\n",
    "    ax1.bar(x1 + width/2, masked_means, width, label='Masked', color='#DE8F05', alpha=0.9)\n",
    "    ax1.set_xlabel('Task Category')\n",
    "    ax1.set_ylabel('Mean Accuracy')\n",
    "    ax1.set_title('Extractive Tasks (by Category)', fontweight='bold')\n",
    "    ax1.set_xticks(x1)\n",
    "    ax1.set_xticklabels(group_names, rotation=45, ha='right')\n",
    "    ax1.legend()\n",
    "    ax1.set_ylim(0, 1.0)\n",
    "    ax1.grid(True, axis='y', alpha=0.2)\n",
    "    \n",
    "    # Right panel: abstractive tasks (again only those present in both inputs)\n",
    "    tasks = [t for t in clean_data_abstractive if t in masked_data_abstractive]\n",
    "    x2 = np.arange(len(tasks))\n",
    "    clean_accs = [clean_data_abstractive[t]['accuracy'] for t in tasks]\n",
    "    masked_accs = [masked_data_abstractive[t]['accuracy'] for t in tasks]\n",
    "    \n",
    "    ax2.bar(x2 - width/2, clean_accs, width, label='Clean', color='#0173B2', alpha=0.9)\n",
    "    ax2.bar(x2 + width/2, masked_accs, width, label='Masked', color='#DE8F05', alpha=0.9)\n",
    "    ax2.set_xlabel('Task')\n",
    "    ax2.set_ylabel('Accuracy')\n",
    "    ax2.set_title('Abstractive Tasks', fontweight='bold')\n",
    "    ax2.set_xticks(x2)\n",
    "    ax2.set_xticklabels([t.replace('_', '\\n').replace('-', '\\n') for t in tasks],\n",
    "                        rotation=45, ha='right')\n",
    "    ax2.legend()\n",
    "    ax2.set_ylim(0, 1.0)\n",
    "    ax2.grid(True, axis='y', alpha=0.2)\n",
    "    \n",
    "    # Remove spines\n",
    "    for ax in [ax1, ax2]:\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.suptitle('Task Performance Overview',\n",
    "                 fontsize=14, fontweight='bold', y=1.08)\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    print(f\"Saved summary: {save_path}\")\n",
    "    \n",
    "    return fig\n",
    "\n",
    "# Main execution: load the saved summary results, draw every figure and print\n",
    "# summary statistics. Relies on model_types and n_shots from the cells above.\n",
    "if __name__ == \"__main__\":\n",
    "    # Set your parameters here\n",
    "    # NOTE(review): these assignments overwrite the notebook-level save_dir,\n",
    "    # model_size and step (step becomes a string here).\n",
    "    save_dir = \"fv_accuracies\"  # Update this path\n",
    "    output_dir = f\"{save_dir}/figures\"  # Directory for output figures\n",
    "    model_size = \"1b-threshold-0.3\"\n",
    "    step = \"19900\"\n",
    "    \n",
    "    # Load data\n",
    "    print(\"Loading data...\")\n",
    "    clean_data_extractive = load_results(\n",
    "        os.path.join(save_dir, f\"extractive_results_{model_types[0]}_{step}_{n_shots}.json\")\n",
    "    )\n",
    "    masked_data_extractive = load_results(\n",
    "        os.path.join(save_dir, f\"extractive_results_{model_types[1]}_{step}_{n_shots}.json\")\n",
    "    )\n",
    "    \n",
    "    clean_data_abstractive = load_results(\n",
    "        os.path.join(save_dir, f\"abstractive_results_{model_types[0]}_{step}_{n_shots}.json\")\n",
    "    )\n",
    "    masked_data_abstractive = load_results(\n",
    "        os.path.join(save_dir, f\"abstractive_results_{model_types[1]}_{step}_{n_shots}.json\")\n",
    "    )\n",
    "    \n",
    "    # Create individual plots for each extractive category\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"Creating individual category plots...\")\n",
    "    print(\"=\"*50)\n",
    "    extractive_figures = plot_all_extractive_categories(\n",
    "        clean_data_extractive, \n",
    "        masked_data_extractive, \n",
    "        output_dir=output_dir\n",
    "    )\n",
    "    \n",
    "    # Create abstractive plot\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"Creating abstractive plot...\")\n",
    "    print(\"=\"*50)\n",
    "    abstractive_fig, _ = plot_abstractive_results(\n",
    "        clean_data_abstractive, \n",
    "        masked_data_abstractive,\n",
    "        save_path=os.path.join(output_dir, 'abstractive_comparison.pdf')\n",
    "    )\n",
    "    \n",
    "    # Create summary figure (optional - shows all categories in one figure)\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"Creating summary figure...\")\n",
    "    print(\"=\"*50)\n",
    "    summary_fig = create_summary_figure(\n",
    "        clean_data_extractive, masked_data_extractive,\n",
    "        clean_data_abstractive, masked_data_abstractive,\n",
    "        save_path=os.path.join(output_dir, 'all_tasks_summary.pdf')\n",
    "    )\n",
    "    \n",
    "    # Print summary statistics\n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"SUMMARY STATISTICS\")\n",
    "    print(\"=\"*50)\n",
    "    \n",
    "    print(\"\\nExtractive Tasks by Category:\")\n",
    "    print(\"-\"*40)\n",
    "    groups = group_extractive_tasks(clean_data_extractive)\n",
    "    for group_name, group_info in groups.items():\n",
    "        # Compare only tasks present in both result sets; a task missing from the\n",
    "        # masked results used to raise a KeyError here.\n",
    "        tasks = [t for t in group_info['tasks']\n",
    "                 if t in clean_data_extractive and t in masked_data_extractive]\n",
    "        if tasks:\n",
    "            clean_accs = [clean_data_extractive[t]['accuracy'] for t in tasks]\n",
    "            masked_accs = [masked_data_extractive[t]['accuracy'] for t in tasks]\n",
    "            clean_mean = np.mean(clean_accs)\n",
    "            masked_mean = np.mean(masked_accs)\n",
    "            print(f\"\\n{group_name.replace(chr(10), ' ')} ({len(tasks)} tasks):\")\n",
    "            print(f\"  Clean:  {clean_mean:.3f} (±{np.std(clean_accs):.3f})\")\n",
    "            print(f\"  Masked: {masked_mean:.3f} (±{np.std(masked_accs):.3f})\")\n",
    "            print(f\"  Δ:      {masked_mean-clean_mean:+.3f}\")\n",
    "    \n",
    "    print(\"\\n\\nAbstractive Tasks:\")\n",
    "    print(\"-\"*40)\n",
    "    for task in clean_data_abstractive.keys():\n",
    "        if task not in masked_data_abstractive:\n",
    "            continue  # no masked counterpart to compare against\n",
    "        clean_acc = clean_data_abstractive[task]['accuracy']\n",
    "        masked_acc = masked_data_abstractive[task]['accuracy']\n",
    "        print(f\"\\n{task}:\")\n",
    "        print(f\"  Clean:  {clean_acc:.3f}\")\n",
    "        print(f\"  Masked: {masked_acc:.3f}\")\n",
    "        print(f\"  Δ:      {masked_acc-clean_acc:+.3f}\")\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*50)\n",
    "    print(\"All plots created successfully!\")\n",
    "    print(\"=\"*50)\n",
    "    print(f\"\\nOutput directory: {output_dir}/\")\n",
    "    print(\"\\nGenerated files:\")\n",
    "    print(\"  - extractive_positional_selection.pdf\")\n",
    "    print(\"  - extractive_alphabetical_ordering.pdf\")\n",
    "    print(\"  - extractive_named_entity_recognition.pdf\")\n",
    "    # The category is titled 'Odd One Out', so its file is extractive_odd_one_out.pdf\n",
    "    # (the message used to name a nonexistent extractive_semantic_discrimination.pdf).\n",
    "    print(\"  - extractive_odd_one_out.pdf\")\n",
    "    print(\"  - extractive_reading_comprehension.pdf\")\n",
    "    print(\"  - abstractive_comparison.pdf\")\n",
    "    print(\"  - all_tasks_summary.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c13a6c",
   "metadata": {},
   "source": [
    "## CREATE LATEX TABLE FROM RESULTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4345f46a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt, erf\n",
    "\n",
    "def two_prop_z_test(p1, n1, p2, n2):\n",
    "    \"\"\"\n",
    "    Two-sided z-test for difference in proportions (independent samples).\n",
    "    \n",
    "    Uses the pooled proportion for the standard error.\n",
    "    \n",
    "    Args:\n",
    "        p1 (float): proportion (e.g., accuracy) for sample 1\n",
    "        n1 (int): sample size for sample 1\n",
    "        p2 (float): proportion for sample 2\n",
    "        n2 (int): sample size for sample 2\n",
    "    \n",
    "    Returns:\n",
    "        z (float): z-statistic\n",
    "        pval (float): two-sided p-value\n",
    "        (0.0, 1.0) is returned when the pooled standard error is zero.\n",
    "    \n",
    "    Raises:\n",
    "        ZeroDivisionError: if n1 or n2 is 0.\n",
    "    \"\"\"\n",
    "    # Local import: the cell-level import only brings in sqrt and erf.\n",
    "    from math import erfc\n",
    "\n",
    "    x1, x2 = p1 * n1, p2 * n2\n",
    "    p_pool = (x1 + x2) / (n1 + n2)\n",
    "    se = sqrt(p_pool * (1 - p_pool) * (1/n1 + 1/n2))\n",
    "    if se == 0:\n",
    "        return 0.0, 1.0\n",
    "    z = (p1 - p2) / se\n",
    "\n",
    "    # 2 * (1 - Phi(|z|)) equals erfc(|z| / sqrt(2)). Computing it through erfc avoids\n",
    "    # the subtraction 1 - Phi(|z|), which cancels to exactly 0.0 for large |z|.\n",
    "    pval = erfc(abs(z) / sqrt(2))\n",
    "    return z, pval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b0b6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = f\"~pythia_replicate/metrics/non_wandb_metrics/fv_accuracies\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17e8b49c",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f\"{save_dir}/abstractive_results_{model_types[0]}_{model_size}_{step}_{n_shots}.json\", \"r\") as f:\n",
    "    clean_res_abstractive = json.load(f)\n",
    "\n",
    "with open(f\"{save_dir}/abstractive_results_{model_types[1]}_{model_size}_{step}_{n_shots}.json\", \"r\") as f:\n",
    "    masked_res_abstractive = json.load(f)\n",
    "\n",
    "with open(f\"{save_dir}/extractive_results_{model_types[0]}_{model_size}_{step}_{n_shots}.json\", \"r\") as f:\n",
    "    clean_res_extractive = json.load(f)\n",
    "\n",
    "with open(f\"{save_dir}/extractive_results_{model_types[1]}_{model_size}_{step}_{n_shots}.json\", \"r\") as f:\n",
    "    masked_res_extractive = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c919e1a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "def load_results(filepath):\n",
    "    \"\"\"Load JSON results from file\"\"\"\n",
    "    with open(filepath, 'r') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "def proper_significance_test(p1, n1, p2, n2, alpha=0.05):\n",
    "    \"\"\"\n",
    "    Perform a two-proportion z-test to check statistical significance.\n",
    "    Returns True if significantly different, False otherwise.\n",
    "    \"\"\"\n",
    "    # Calculate pooled proportion\n",
    "    pooled_p = (p1 * n1 + p2 * n2) / (n1 + n2)\n",
    "    \n",
    "    # Calculate standard error\n",
    "    se = np.sqrt(pooled_p * (1 - pooled_p) * (1/n1 + 1/n2))\n",
    "    \n",
    "    # Avoid division by zero\n",
    "    if se == 0:\n",
    "        return False\n",
    "    \n",
    "    # Calculate z-statistic\n",
    "    z = (p1 - p2) / se\n",
    "    \n",
    "    # Two-tailed test\n",
    "    p_value = 2 * (1 - stats.norm.cdf(abs(z)))\n",
    "    \n",
    "    # Return True if significant\n",
    "    return p_value < alpha\n",
    "\n",
    "def format_task_name(task):\n",
    "    \"\"\"Convert task key to readable name\"\"\"\n",
    "    task_names = {\n",
    "        'capitalize_first_letter': 'Capitalize First Letter',\n",
    "        'capitalize_second_letter': 'Capitalize Second Letter',\n",
    "        'capitalize_last_letter': 'Capitalize Last Letter',\n",
    "        'capitalize': 'Capitalize (Full Word)',\n",
    "        'lowercase_first_letter': 'Lowercase First Letter',\n",
    "        'lowercase_last_letter': 'Lowercase Last Letter',\n",
    "        'singular-plural': 'Singular-Plural',\n",
    "        'present-past': 'Present-Past',\n",
    "        'synonym': 'Synonym',\n",
    "        'antonym': 'Antonym',\n",
    "        'country-capital': 'Country-Capital',\n",
    "        'country-currency': 'Country-Currency',\n",
    "        'landmark-country': 'Landmark-Country',\n",
    "        'product-company': 'Product-Company',\n",
    "        'park-country': 'Park-Country',\n",
    "        'national_parks': 'National Parks',\n",
    "        'person-occupation': 'Person-Occupation',\n",
    "        'person-instrument': 'Person-Instrument',\n",
    "        'person-sport': 'Person-Sport',\n",
    "        'next_item': 'Next Item',\n",
    "        'prev_item': 'Previous Item',\n",
    "        'next_capital_letter': 'Next Capital Letter',\n",
    "        'word_length': 'Word Length',\n",
    "        'ag_news': 'AG News',\n",
    "        'commonsense_qa': 'CommonsenseQA',\n",
    "        'sentiment': 'Sentiment'\n",
    "    }\n",
    "    return task_names.get(task, task.replace('_', ' ').title())\n",
    "\n",
    "def create_abstractive_latex_table(clean_data, masked_data, output_file='fv_accuracies/abstractive_table.tex'):\n",
    "    \"\"\"Create LaTeX table for abstractive tasks with proper formatting\"\"\"\n",
    "    \n",
    "    lines = []\n",
    "    \n",
    "    # Table header\n",
    "    lines.append(r\"\\begin{table}[t]\")\n",
    "    lines.append(r\"\\centering\")\n",
    "    lines.append(r\"\\caption{Performance comparison on abstractive in-context learning tasks. We evaluate models with clean and masked induction heads across 26 tasks. Values show accuracy ± 95\\% CI margin. Bold indicates superior performance; $\\dagger$ denotes statistical significance ($p < 0.05$).}\")\n",
    "    lines.append(r\"\\label{tab:abstractive_results}\")\n",
    "    lines.append(r\"\\vspace{0.5em}\")\n",
    "    lines.append(r\"\\resizebox{0.5\\columnwidth}{!}{%\")\n",
    "    lines.append(r\"\\begin{tabular}{lcc}\")\n",
    "    lines.append(r\"\\toprule\")\n",
    "    lines.append(r\"\\textbf{Task} & \\textbf{Clean (\\%)} & \\textbf{Masked (\\%)} \\\\\")\n",
    "    lines.append(r\"\\midrule\")\n",
    "    \n",
    "    # Process each task\n",
    "    for task in clean_data.keys():\n",
    "        # Get values\n",
    "        clean_acc = clean_data[task]['accuracy'] * 100\n",
    "        clean_ci_lower = clean_data[task]['ci_lower'] * 100\n",
    "        clean_ci_upper = clean_data[task]['ci_upper'] * 100\n",
    "        clean_margin = (clean_ci_upper - clean_ci_lower) / 2\n",
    "        \n",
    "        masked_acc = masked_data[task]['accuracy'] * 100\n",
    "        masked_ci_lower = masked_data[task]['ci_lower'] * 100\n",
    "        masked_ci_upper = masked_data[task]['ci_upper'] * 100\n",
    "        masked_margin = (masked_ci_upper - masked_ci_lower) / 2\n",
    "        \n",
    "        # Check statistical significance\n",
    "        is_significant = False\n",
    "        if 'n_samples' in clean_data[task] and 'n_samples' in masked_data[task]:\n",
    "            is_significant = proper_significance_test(\n",
    "                clean_data[task]['accuracy'], clean_data[task]['n_samples'],\n",
    "                masked_data[task]['accuracy'], masked_data[task]['n_samples']\n",
    "            )\n",
    "        \n",
    "        # Format task name\n",
    "        task_name = format_task_name(task)\n",
    "        \n",
    "        # Format values with ± notation\n",
    "        clean_str = f\"{clean_acc:.1f} ± {clean_margin:.1f}\"\n",
    "        masked_str = f\"{masked_acc:.1f} ± {masked_margin:.1f}\"\n",
    "        \n",
    "        # Add bold and significance markers\n",
    "        if is_significant:\n",
    "            if clean_acc > masked_acc:\n",
    "                clean_str = r\"\\textbf{\" + clean_str + r\"}$^\\dagger$\"\n",
    "            else:\n",
    "                masked_str = r\"\\textbf{\" + masked_str + r\"}$^\\dagger$\"\n",
    "        elif clean_acc > masked_acc:\n",
    "            # Bold for better performance even if not significant\n",
    "            clean_str = r\"\\textbf{\" + clean_str + \"}\"\n",
    "        elif masked_acc > clean_acc:\n",
    "            masked_str = r\"\\textbf{\" + masked_str + \"}\"\n",
    "        \n",
    "        # Add row\n",
    "        lines.append(f\"{task_name} & {clean_str} & {masked_str} \\\\\\\\\")\n",
    "    \n",
    "    # Table footer\n",
    "    lines.append(r\"\\bottomrule\")\n",
    "    lines.append(r\"\\end{tabular}%\")\n",
    "    lines.append(r\"}\")\n",
    "    lines.append(r\"\\vspace{-0.5em}\")\n",
    "    lines.append(r\"\\end{table}\")\n",
    "    \n",
    "    # Write to file\n",
    "    latex_content = '\\n'.join(lines)\n",
    "    with open(output_file, 'w') as f:\n",
    "        f.write(latex_content)\n",
    "    \n",
    "    print(f\"LaTeX table written to {output_file}\")\n",
    "    return latex_content\n",
    "\n",
    "# Create the LaTeX table\n",
    "create_abstractive_latex_table(clean_res_abstractive, masked_res_abstractive, \"fv_accuracies/abstractive_table.tex\")\n",
    "\n",
    "# Optional: Also save as .txt\n",
    "#create_abstractive_latex_table(clean_res_abstractive, masked_res_abstractive, \"abstractive_table.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbbe26e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "def load_results(filepath):\n",
    "    \"\"\"Load JSON results from file\"\"\"\n",
    "    with open(filepath, 'r') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "def proper_significance_test(p1, n1, p2, n2, alpha=0.05):\n",
    "    \"\"\"\n",
    "    Perform a two-proportion z-test to check statistical significance.\n",
    "    Returns True if significantly different, False otherwise.\n",
    "    \"\"\"\n",
    "    # Calculate pooled proportion\n",
    "    pooled_p = (p1 * n1 + p2 * n2) / (n1 + n2)\n",
    "    \n",
    "    # Calculate standard error\n",
    "    se = np.sqrt(pooled_p * (1 - pooled_p) * (1/n1 + 1/n2))\n",
    "    \n",
    "    # Avoid division by zero\n",
    "    if se == 0:\n",
    "        return False\n",
    "    \n",
    "    # Calculate z-statistic\n",
    "    z = (p1 - p2) / se\n",
    "    \n",
    "    # Two-tailed test\n",
    "    p_value = 2 * (1 - stats.norm.cdf(abs(z)))\n",
    "    \n",
    "    # Return True if significant\n",
    "    return p_value < alpha\n",
    "\n",
    "def format_task_name(task):\n",
    "    \"\"\"Convert task key to readable name\"\"\"\n",
    "    task_names = {\n",
    "        # Positional selection\n",
    "        'choose_first_of_3': 'Choose First of 3',\n",
    "        'choose_first_of_5': 'Choose First of 5',\n",
    "        'choose_middle_of_3': 'Choose Middle of 3',\n",
    "        'choose_middle_of_5': 'Choose Middle of 5',\n",
    "        'choose_last_of_3': 'Choose Last of 3',\n",
    "        'choose_last_of_5': 'Choose Last of 5',\n",
    "        # Alphabetical\n",
    "        'alphabetically_first_3': 'Alphabetically First (3)',\n",
    "        'alphabetically_first_5': 'Alphabetically First (5)',\n",
    "        'alphabetically_last_3': 'Alphabetically Last (3)',\n",
    "        'alphabetically_last_5': 'Alphabetically Last (5)',\n",
    "        # Named Entity Recognition\n",
    "        'conll2003_person': 'CoNLL: Person',\n",
    "        'conll2003_location': 'CoNLL: Location',\n",
    "        'conll2003_organization': 'CoNLL: Organization',\n",
    "        # Odd one out\n",
    "        'animal_v_object_3': 'Animal vs Object (3)',\n",
    "        'animal_v_object_5': 'Animal vs Object (5)',\n",
    "        'fruit_v_animal_3': 'Fruit vs Animal (3)',\n",
    "        'fruit_v_animal_5': 'Fruit vs Animal (5)',\n",
    "        'color_v_animal_3': 'Color vs Animal (3)',\n",
    "        'color_v_animal_5': 'Color vs Animal (5)',\n",
    "        'adjective_v_verb_3': 'Adjective vs Verb (3)',\n",
    "        'adjective_v_verb_5': 'Adjective vs Verb (5)',\n",
    "        'verb_v_adjective_3': 'Verb vs Adjective (3)',\n",
    "        'verb_v_adjective_5': 'Verb vs Adjective (5)',\n",
    "        'concept_v_object_3': 'Concept vs Object (3)',\n",
    "        'concept_v_object_5': 'Concept vs Object (5)',\n",
    "        'object_v_concept_3': 'Object vs Concept (3)',\n",
    "        'object_v_concept_5': 'Object vs Concept (5)',\n",
    "        # Reading comprehension\n",
    "        'squad_val': 'SQuAD Validation'\n",
    "    }\n",
    "    return task_names.get(task, task.replace('_', ' ').title())\n",
    "\n",
    "def create_extractive_latex_table(clean_data, masked_data, output_file='extractive_table.tex'):\n",
    "    \"\"\"Create LaTeX table for extractive tasks with proper formatting\"\"\"\n",
    "    \n",
    "    lines = []\n",
    "    \n",
    "    # Table header\n",
    "    lines.append(r\"\\begin{table}[t]\")\n",
    "    lines.append(r\"\\centering\")\n",
    "    lines.append(r\"\\caption{Performance comparison on extractive in-context learning tasks. We evaluate models with clean and masked induction heads. Values show accuracy ± 95\\% CI margin. Bold indicates superior performance; $\\dagger$ denotes statistical significance ($p < 0.05$).}\")\n",
    "    lines.append(r\"\\label{tab:extractive_results}\")\n",
    "    lines.append(r\"\\vspace{0.5em}\")\n",
    "    lines.append(r\"\\resizebox{\\columnwidth}{!}{%\")\n",
    "    lines.append(r\"\\begin{tabular}{lcc}\")\n",
    "    lines.append(r\"\\toprule\")\n",
    "    lines.append(r\"\\textbf{Task} & \\textbf{Clean (\\%)} & \\textbf{Masked (\\%)} \\\\\")\n",
    "    lines.append(r\"\\midrule\")\n",
    "    \n",
    "    # Define the order of tasks (you can customize this order)\n",
    "    task_order = [\n",
    "        # Positional selection\n",
    "        'choose_first_of_3', 'choose_first_of_5',\n",
    "        'choose_middle_of_3', 'choose_middle_of_5',\n",
    "        'choose_last_of_3', 'choose_last_of_5',\n",
    "        # Alphabetical\n",
    "        'alphabetically_first_3', 'alphabetically_first_5',\n",
    "        'alphabetically_last_3', 'alphabetically_last_5',\n",
    "        # Named Entity Recognition\n",
    "        'conll2003_person', 'conll2003_location', 'conll2003_organization',\n",
    "        # Odd one out\n",
    "        'animal_v_object_3', 'animal_v_object_5',\n",
    "        'fruit_v_animal_3', 'fruit_v_animal_5',\n",
    "        'color_v_animal_3', 'color_v_animal_5',\n",
    "        'adjective_v_verb_3', 'adjective_v_verb_5',\n",
    "        'verb_v_adjective_3', 'verb_v_adjective_5',\n",
    "        'concept_v_object_3', 'concept_v_object_5',\n",
    "        'object_v_concept_3', 'object_v_concept_5',\n",
    "        # Reading comprehension\n",
    "        'squad_val'\n",
    "    ]\n",
    "    \n",
    "    # Process each task in order\n",
    "    for task in task_order:\n",
    "        if task not in clean_data or task not in masked_data:\n",
    "            continue\n",
    "            \n",
    "        # Get values\n",
    "        clean_acc = clean_data[task]['accuracy'] * 100\n",
    "        clean_ci_lower = clean_data[task]['ci_lower'] * 100\n",
    "        clean_ci_upper = clean_data[task]['ci_upper'] * 100\n",
    "        clean_margin = (clean_ci_upper - clean_ci_lower) / 2\n",
    "        \n",
    "        masked_acc = masked_data[task]['accuracy'] * 100\n",
    "        masked_ci_lower = masked_data[task]['ci_lower'] * 100\n",
    "        masked_ci_upper = masked_data[task]['ci_upper'] * 100\n",
    "        masked_margin = (masked_ci_upper - masked_ci_lower) / 2\n",
    "        \n",
    "        # Check statistical significance\n",
    "        is_significant = False\n",
    "        if 'n_samples' in clean_data[task] and 'n_samples' in masked_data[task]:\n",
    "            is_significant = proper_significance_test(\n",
    "                clean_data[task]['accuracy'], clean_data[task]['n_samples'],\n",
    "                masked_data[task]['accuracy'], masked_data[task]['n_samples']\n",
    "            )\n",
    "        \n",
    "        # Format task name\n",
    "        task_name = format_task_name(task)\n",
    "        \n",
    "        # Format values with ± notation\n",
    "        clean_str = f\"{clean_acc:.1f} ± {clean_margin:.1f}\"\n",
    "        masked_str = f\"{masked_acc:.1f} ± {masked_margin:.1f}\"\n",
    "        \n",
    "        # Add bold and significance markers\n",
    "        if is_significant:\n",
    "            if clean_acc > masked_acc:\n",
    "                clean_str = r\"\\textbf{\" + clean_str + r\"}$^\\dagger$\"\n",
    "            else:\n",
    "                masked_str = r\"\\textbf{\" + masked_str + r\"}$^\\dagger$\"\n",
    "        elif clean_acc > masked_acc:\n",
    "            # Bold for better performance even if not significant\n",
    "            clean_str = r\"\\textbf{\" + clean_str + \"}\"\n",
    "        elif masked_acc > clean_acc:\n",
    "            masked_str = r\"\\textbf{\" + masked_str + \"}\"\n",
    "        \n",
    "        # Add row\n",
    "        lines.append(f\"{task_name} & {clean_str} & {masked_str} \\\\\\\\\")\n",
    "    \n",
    "    # Table footer\n",
    "    lines.append(r\"\\bottomrule\")\n",
    "    lines.append(r\"\\end{tabular}%\")\n",
    "    lines.append(r\"}\")\n",
    "    lines.append(r\"\\vspace{-0.5em}\")\n",
    "    lines.append(r\"\\end{table}\")\n",
    "    \n",
    "    # Write to file\n",
    "    latex_content = '\\n'.join(lines)\n",
    "    with open(output_file, 'w') as f:\n",
    "        f.write(latex_content)\n",
    "    \n",
    "    print(f\"LaTeX table written to {output_file}\")\n",
    "    print(f\"Total tasks processed: {len([t for t in task_order if t in clean_data and t in masked_data])}\")\n",
    "    return latex_content\n",
    "\n",
    "# Main execution\n",
    "if __name__ == \"__main__\":\n",
    "    \n",
    "    # Create the LaTeX table\n",
    "    create_extractive_latex_table(clean_res_extractive, masked_res_extractive, \"fv_accuracies/extractive_table.tex\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "526ec896",
   "metadata": {},
   "source": [
    "# EXECUTE TASK (SINGLE) OVER TIME"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feedf947",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import gc\n",
    "\n",
    "def fv_tasks_over_time_benchmark(\n",
    "    model_name,\n",
    "    first_step,\n",
    "    last_step,\n",
    "    task_name,\n",
    "    num_of_shots=5,\n",
    "    max_sample_size=5000,\n",
    "    batch_size=64,\n",
    "):\n",
    "    icl_task_results = {}\n",
    "    for step in range(first_step, last_step + 100, 100):\n",
    "        model_path = f\"~pythia_replicate/hf_output/{model_name}/step={step}\"\n",
    "        model, tokenizer = load_model_and_tokenizer(model_path)\n",
    "        icl_results = fv_icl_tasks_benchmark_with_ci(\n",
    "            model,\n",
    "            tokenizer,\n",
    "            task_name=task_name,\n",
    "            num_of_shots=num_of_shots,\n",
    "            max_samples=max_sample_size,\n",
    "            batch_size=batch_size,\n",
    "            use_wandb=False,\n",
    "\n",
    "        )\n",
    "        icl_task_results[step] = icl_results\n",
    "        del model, tokenizer\n",
    "        gc.collect()\n",
    "        torch.cuda.empty_cache()\n",
    "    return icl_task_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efba12fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name = \"abstractive/national_parks\"\n",
    "clean_res = fv_tasks_over_time_benchmark(\n",
    "    model_name=model_types[0],\n",
    "    first_step=first_step,\n",
    "    last_step=step,\n",
    "    task_name=task_name,\n",
    "    num_of_shots=n_shots,\n",
    "    max_sample_size=max_sample_size,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "masked_res = fv_tasks_over_time_benchmark(\n",
    "    model_name=model_types[1],\n",
    "    first_step=first_step,\n",
    "    last_step=step,\n",
    "    task_name=task_name,\n",
    "    num_of_shots=n_shots,\n",
    "    max_sample_size=max_sample_size,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2a0ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"fv_accuracies\"  \n",
    "task_name = task_name.split(\"/\")[-1]\n",
    "with open(f\"{save_dir}/clean_{task_name}_{model_size}.json\", \"w\") as f:\n",
    "    json.dump(clean_res, f)\n",
    "\n",
    "with open(f\"{save_dir}/masked_{task_name}_{model_size}.json\", \"w\") as f:\n",
    "    json.dump(masked_res, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythia_replicate",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
