{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2c66890",
   "metadata": {},
   "source": [
    "# IMPORTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff993286",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e2bc51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from collections import defaultdict\n",
    "import os\n",
    "\n",
    "sys.path.append(\"~pythia_replicate\")\n",
    "\n",
    "from lib.model_setup import load_model_and_tokenizer\n",
    "from lib.translation import BilingualFewShotDataset\n",
    "from lib.translation import evaluate_translation_accuracy_with_ci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9224e209",
   "metadata": {},
   "outputs": [],
   "source": [
    "#variables\n",
    "skip_clean = True\n",
    "model_size = \"1b_thresh0.3\"\n",
    "first_step = 100\n",
    "step = 19900\n",
    "non_local = False\n",
    "lang1 = None\n",
    "lang2 = \"eng\"\n",
    "max_new_tokens = 8\n",
    "n_shots = 5\n",
    "random_pairs = False\n",
    "debug = True\n",
    "lang_labels = False\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f7fe0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_size == \"160m\":\n",
    "    model_types = [\"clean_v3\", \"masked_bigram_loss_v4\"]\n",
    "elif model_size == \"1b\":\n",
    "    model_types = [\"clean_1b\", \"masked_bigram_loss_1b\"]\n",
    "elif model_size == \"1b_thresh0.3\":\n",
    "    model_types = [\"clean_1b\", \"masked_bigram_loss_1b_thresh0.3_eq\", \"masked_bigram_loss_1b\"]\n",
    "\n",
    "clean_model_path = f\"~pythia_replicate/hf_output/{model_types[0]}/step={step}\"\n",
    "masked_model_path = f\"~pythia_replicate/hf_output/{model_types[1]}/step={step}\"\n",
    "clean_model, tokenizer_clean = load_model_and_tokenizer(clean_model_path)\n",
    "masked_model, tokenizer_masked = load_model_and_tokenizer(masked_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c063aa3",
   "metadata": {},
   "source": [
    "# LAST CHECKPOINT PERFORMANCE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebc3512",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_translation_metrics(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    lang1,\n",
    "    lang2,\n",
    "    batch_size=32,\n",
    "    max_new_tokens=50,\n",
    "    n_shots=10,\n",
    "    step=None,\n",
    "    random_pairs=False,\n",
    "    confidence_level=0.95,\n",
    "    debug=False,\n",
    "):\n",
    "    if lang1 is None:\n",
    "        #select languages that are over 1% https://w3techs.com/technologies/overview/content_language\n",
    "        lang1 = [\"spa\", \"jpn\", \"fra\", \"por\", \"ita\", \"pol\", \"cmn\", \"ind\"]\n",
    "    results = {}\n",
    "    \n",
    "    for lang in lang1:\n",
    "        \n",
    "        dataset = BilingualFewShotDataset(\n",
    "            Path(\"~pythia_replicate/dataset/parallel_concepts.csv\"),\n",
    "            lang,\n",
    "            lang2,\n",
    "            n_shots,\n",
    "            random_pairs,\n",
    "            lang_labels,\n",
    "        )\n",
    "        if debug:\n",
    "            print(dataset.prompts)\n",
    "            print(dataset.targets)\n",
    "\n",
    "        device = next(model.parameters()).device\n",
    "        metrics = evaluate_translation_accuracy_with_ci(\n",
    "            model, tokenizer, dataset, device, batch_size, max_new_tokens, \n",
    "            random_pairs, confidence_level\n",
    "        )\n",
    "        \n",
    "        results[f\"{lang}_{lang2}\"] = metrics\n",
    "        \n",
    "        if debug:\n",
    "            print(f\"Accuracy: {metrics['accuracy']:.3f} \"\n",
    "                  f\"(95% CI: [{metrics['ci_lower']:.3f}, {metrics['ci_upper']:.3f}])\")\n",
    "            print(\"-\" * 50)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a3a086",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_masked = compute_translation_metrics(masked_model, tokenizer_masked, lang1, lang2, batch_size, max_new_tokens, n_shots, step, random_pairs, debug=debug)\n",
    "if not skip_clean:\n",
    "    accuracy_clean = compute_translation_metrics(clean_model, tokenizer_clean, lang1, lang2, batch_size, max_new_tokens, n_shots, step, random_pairs, debug=debug)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef51acfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "translation_path = \"~pythia_replicate/metrics/non_wandb_metrics/translation_accuracies\"\n",
    "os.makedirs(translation_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb96b854",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "if not skip_clean:\n",
    "    with open(f'{translation_path}/translation_{model_types[0]}_{step}_{n_shots}-shot.json', 'w') as f:\n",
    "        json.dump(dict(accuracy_clean), f)\n",
    "with open(f'{translation_path}/translation_{model_types[1]}_{step}_{n_shots}-shot.json', 'w') as f:\n",
    "    json.dump(dict(accuracy_masked), f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da1904a0",
   "metadata": {},
   "source": [
    "# SAME LANGUAGE TRANSLATIONS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b97e71",
   "metadata": {},
   "source": [
    "# PLOTTING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a25d789",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_translation_comparison(clean_path, masked_path, masked_thresholded_path, save_path=None):\n",
    "    \"\"\"\n",
    "    Create ICLR-ready bar plot comparing clean vs masked translation accuracy\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    clean_path : str or Path\n",
    "        Path to clean model JSON results\n",
    "    masked_path : str or Path\n",
    "        Path to masked model JSON results\n",
    "    save_path : str or Path, optional\n",
    "        Path to save the figure (without extension)\n",
    "    \"\"\"\n",
    "    \n",
    "    # Apply ICLR style\n",
    "    apply_iclr_style()\n",
    "    \n",
    "    # Load data\n",
    "    clean_data, masked_data, masked_thresholded_data = load_translation_results(clean_path, masked_path, masked_thresholded_path)\n",
    "    \n",
    "    # Prepare data for plotting\n",
    "    languages, clean_accs, masked_accs, masked_thresholded_accs, clean_errors, masked_errors, masked_thresholded_errors = prepare_plot_data(\n",
    "        clean_data, masked_data, masked_thresholded_data\n",
    "    )\n",
    "    \n",
    "    # Create figure - slightly wider to accommodate spacing\n",
    "    fig, ax = plt.subplots(figsize=(9, 4))  # Increased width from 7 to 9\n",
    "    \n",
    "    # Set up bar positions with gaps\n",
    "    x = np.arange(len(languages))\n",
    "    width = 0.20  # Bar width\n",
    "    gap = 0.02    # Gap between bars within a group\n",
    "    \n",
    "    # Calculate positions with proper spacing\n",
    "    # Center bar at x, others offset by width + gap\n",
    "    pos1 = x - (width + gap)\n",
    "    pos2 = x\n",
    "    pos3 = x + (width + gap)\n",
    "    \n",
    "    # Colorblind-friendly colors (consistent with previous ICLR plots)\n",
    "    color_clean = '#0173B2'  # Blue\n",
    "    color_masked = '#DE8F05'  # Orange\n",
    "    color_masked_thresholded = '#40B0A6'  # Teal\n",
    "    \n",
    "    # Create bars with error bars\n",
    "    bars1 = ax.bar(pos1, clean_accs, width, \n",
    "                   label='Clean', \n",
    "                   color=color_clean, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=clean_errors if clean_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "    \n",
    "    bars2 = ax.bar(pos2, masked_accs, width, \n",
    "                   label='Hapax', \n",
    "                   color=color_masked, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=masked_errors if masked_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "\n",
    "    bars3 = ax.bar(pos3, masked_thresholded_accs, width, \n",
    "                   label='Hapax Thresholded', \n",
    "                   color=color_masked_thresholded, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=masked_thresholded_errors if masked_thresholded_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "\n",
    "    def add_value_labels(bars, values, errors, all_bars=None, all_values=None):\n",
    "        \"\"\"\n",
    "        Add value labels with smart positioning to avoid overlaps\n",
    "        \"\"\"\n",
    "        # Extract upper errors from the tuple\n",
    "        if errors and len(errors) == 2:\n",
    "            upper_errors = errors[1]\n",
    "        else:\n",
    "            upper_errors = [0] * len(values)\n",
    "        \n",
    "        for i, (bar, val, err_upper) in enumerate(zip(bars, values, upper_errors)):\n",
    "            height = bar.get_height()\n",
    "            # Base position above the error bar\n",
    "            y_position = height + err_upper + 0.5  # Increased padding\n",
    "            \n",
    "            # Check for potential overlaps with neighboring bars\n",
    "            if all_bars is not None and all_values is not None:\n",
    "                for other_bars, other_values in zip(all_bars, all_values):\n",
    "                    if other_bars is bars:\n",
    "                        continue\n",
    "                    \n",
    "                    # Check if neighboring bars have similar heights\n",
    "                    if i < len(other_values):\n",
    "                        other_height = other_values[i]\n",
    "                        # If heights are very similar, adjust position\n",
    "                        if abs(height - other_height) < 3:  # Within 3% difference\n",
    "                            # Stagger the middle bar higher\n",
    "                            if other_bars is bars2:  # If this is the middle bar\n",
    "                                y_position += 1.5\n",
    "            \n",
    "            ax.text(bar.get_x() + bar.get_width()/2., \n",
    "                    y_position,\n",
    "                    f'{val:.1f}',\n",
    "                    ha='center', va='bottom', fontsize=6)  # Slightly smaller font\n",
    "    \n",
    "    # Add labels with smart positioning\n",
    "    all_bars = [bars1, bars2, bars3]\n",
    "    all_values = [clean_accs, masked_accs, masked_thresholded_accs]\n",
    "    \n",
    "    add_value_labels(bars1, clean_accs, clean_errors, all_bars, all_values)\n",
    "    add_value_labels(bars2, masked_accs, masked_errors, all_bars, all_values)\n",
    "    add_value_labels(bars3, masked_thresholded_accs, masked_thresholded_errors, all_bars, all_values)\n",
    "    \n",
    "    # Customize axes\n",
    "    ax.set_xlabel('Target Language', fontsize=10)\n",
    "    ax.set_ylabel('Translation Accuracy (%)', fontsize=10)\n",
    "    ax.set_title('5-Shot Translation Performance: Clean vs Masked Models (Step 19900)', fontsize=11, pad=10)\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(languages, rotation=45, ha='right')\n",
    "    \n",
    "    # Add subtle grid for better readability\n",
    "    ax.yaxis.grid(True, linestyle='--', alpha=0.2, linewidth=0.5)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Add legend\n",
    "    ax.legend(loc='best', fontsize=9)\n",
    "    \n",
    "    # Set y-axis to start from 0 for fair comparison\n",
    "    ax.set_ylim(bottom=0)\n",
    "    \n",
    "    # Calculate the maximum value including error bars and labels\n",
    "    max_with_error = 0\n",
    "    if clean_errors and clean_errors[0]:\n",
    "        max_clean_with_error = max([acc + err for acc, err in zip(clean_accs, clean_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_clean_with_error)\n",
    "    if masked_errors and masked_errors[0]:\n",
    "        max_masked_with_error = max([acc + err for acc, err in zip(masked_accs, masked_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_masked_with_error)\n",
    "    if masked_thresholded_errors and masked_thresholded_errors[0]:\n",
    "        max_masked_thresholded_with_error = max([acc + err for acc, err in zip(masked_thresholded_accs, masked_thresholded_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_masked_thresholded_with_error)\n",
    "\n",
    "    # Set y-axis limit with padding to accommodate error bars and labels\n",
    "    if max_with_error > 0:\n",
    "        ax.set_ylim(top=max_with_error * 1.20)  # Increased padding to 20%\n",
    "    \n",
    "    # Adjust layout to prevent label cutoff\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    # Save figure if path provided\n",
    "    if save_path:\n",
    "        save_path = Path(save_path)\n",
    "        # Save as both PDF (for paper) and PNG (for quick viewing)\n",
    "        fig.savefig(save_path.with_suffix('.pdf'), format='pdf', bbox_inches='tight')\n",
    "        fig.savefig(save_path.with_suffix('.png'), format='png', dpi=300, bbox_inches='tight')\n",
    "        print(f\"Saved figures to {save_path.with_suffix('.pdf')} and {save_path.with_suffix('.png')}\")\n",
    "    \n",
    "    plt.show()\n",
    "    \n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27361d5a",
   "metadata": {},
   "source": [
    "# PLOT ALL OF THEM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "469dedb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from lib.plotting import apply_iclr_style\n",
    "\n",
    "def load_translation_results(clean_path, masked_path, masked_thresholded_path):\n",
    "    \"\"\"Load translation results from JSON files\"\"\"\n",
    "    with open(clean_path, 'r') as f:\n",
    "        clean_data = json.load(f)\n",
    "    with open(masked_path, 'r') as f:\n",
    "        masked_data = json.load(f)\n",
    "    with open(masked_thresholded_path, 'r') as f:\n",
    "        masked_thresholded_data = json.load(f)\n",
    "    return clean_data, masked_data, masked_thresholded_data\n",
    "\n",
    "def prepare_plot_data(clean_data, masked_data, masked_thresholded_data):\n",
    "    \"\"\"Prepare data for plotting\"\"\"\n",
    "    # Language name mapping for better display\n",
    "    lang_display = {\n",
    "        'spa': 'Spanish',\n",
    "        'fra': 'French',\n",
    "        'por': 'Portuguese',\n",
    "        'ita': 'Italian',\n",
    "        'jpn': 'Japanese',\n",
    "        'cmn': 'Chinese',\n",
    "        'arb': 'Arabic',\n",
    "        'pol': 'Polish',\n",
    "        'ind': 'Indonesian',\n",
    "        'swe': 'Swedish'\n",
    "    }\n",
    "    \n",
    "    # Extract language pairs (assuming format is \"lang_eng\")\n",
    "    lang_pairs = sorted(clean_data.keys())\n",
    "    \n",
    "    # Prepare arrays for plotting\n",
    "    languages = []\n",
    "    clean_accs = []\n",
    "    clean_ci_lower = []\n",
    "    clean_ci_upper = []\n",
    "    masked_accs = []\n",
    "    masked_ci_lower = []\n",
    "    masked_ci_upper = []\n",
    "    masked_thresholded_accs = []\n",
    "    masked_thresholded_ci_lower = []\n",
    "    masked_thresholded_ci_upper = []\n",
    "    \n",
    "    for pair in lang_pairs:\n",
    "        lang_code = pair.split('_')[0]\n",
    "\n",
    "        if lang_code == \"eng\":\n",
    "            continue\n",
    "        lang_name = lang_display.get(lang_code, lang_code.upper())\n",
    "        languages.append(lang_name)\n",
    "        \n",
    "        # Extract clean model data\n",
    "        clean_accs.append(clean_data[pair]['accuracy'] * 100)\n",
    "        clean_ci_lower.append(clean_data[pair]['ci_lower'] * 100)\n",
    "        clean_ci_upper.append(clean_data[pair]['ci_upper'] * 100)\n",
    "        \n",
    "        # Extract masked model data\n",
    "        masked_accs.append(masked_data[pair]['accuracy'] * 100)\n",
    "        masked_ci_lower.append(masked_data[pair]['ci_lower'] * 100)\n",
    "        masked_ci_upper.append(masked_data[pair]['ci_upper'] * 100)\n",
    "        \n",
    "        # Extract masked thresholded model data\n",
    "        masked_thresholded_accs.append(masked_thresholded_data[pair]['accuracy'] * 100)\n",
    "        masked_thresholded_ci_lower.append(masked_thresholded_data[pair]['ci_lower'] * 100)\n",
    "        masked_thresholded_ci_upper.append(masked_thresholded_data[pair]['ci_upper'] * 100)\n",
    "    \n",
    "    # Calculate error bar sizes (distance from mean to CI bounds)\n",
    "    clean_errors = [(acc - lower, upper - acc) \n",
    "                    for acc, lower, upper in zip(clean_accs, clean_ci_lower, clean_ci_upper)]\n",
    "    masked_errors = [(acc - lower, upper - acc) \n",
    "                     for acc, lower, upper in zip(masked_accs, masked_ci_lower, masked_ci_upper)]\n",
    "    masked_thresholded_errors = [(acc - lower, upper - acc) \n",
    "                                for acc, lower, upper in zip(masked_thresholded_accs, masked_thresholded_ci_lower, masked_thresholded_ci_upper)]\n",
    "    \n",
    "    # Separate into lower and upper errors for matplotlib\n",
    "    clean_errors = list(zip(*clean_errors)) if clean_errors else ([], [])\n",
    "    masked_errors = list(zip(*masked_errors)) if masked_errors else ([], [])\n",
    "    masked_thresholded_errors = list(zip(*masked_thresholded_errors)) if masked_thresholded_errors else ([], [])\n",
    "    \n",
    "    return languages, clean_accs, masked_accs, masked_thresholded_accs, clean_errors, masked_errors, masked_thresholded_errors\n",
    "\n",
    "def plot_translation_comparison(clean_path, masked_path, masked_thresholded_path, save_path=None):\n",
    "    \"\"\"\n",
    "    Create ICLR-ready bar plot comparing clean vs masked translation accuracy\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    clean_path : str or Path\n",
    "        Path to clean model JSON results\n",
    "    masked_path : str or Path\n",
    "        Path to masked model JSON results\n",
    "    save_path : str or Path, optional\n",
    "        Path to save the figure (without extension)\n",
    "    \"\"\"\n",
    "    \n",
    "    # Apply ICLR style\n",
    "    apply_iclr_style()\n",
    "    \n",
    "    # Load data\n",
    "    clean_data, masked_data, masked_thresholded_data = load_translation_results(clean_path, masked_path, masked_thresholded_path)\n",
    "    \n",
    "    # Prepare data for plotting\n",
    "    languages, clean_accs, masked_accs, masked_thresholded_accs, clean_errors, masked_errors, masked_thresholded_errors = prepare_plot_data(\n",
    "        clean_data, masked_data, masked_thresholded_data\n",
    "    )\n",
    "    \n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=(7, 4))\n",
    "    \n",
    "    # Set up bar positions\n",
    "    x = np.arange(len(languages))\n",
    "    width = 0.25\n",
    "    gap = 0.00\n",
    "    \n",
    "    # Colorblind-friendly colors (consistent with previous ICLR plots)\n",
    "    color_clean = '#0173B2'  # Blue\n",
    "    color_masked = '#DE8F05'  # Orange\n",
    "    color_masked_thresholded = '#40B0A6'  # Teal\n",
    "    # Create bars with error bars\n",
    "    bars1 = ax.bar(x - (width + gap), clean_accs, width, \n",
    "                   label='Vanilla', \n",
    "                   color=color_clean, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=clean_errors if clean_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "    \n",
    "    bars2 = ax.bar(x , masked_accs, width, \n",
    "                   label='Hapax', \n",
    "                   color=color_masked, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=masked_errors if masked_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "\n",
    "    bars3 = ax.bar(x + (width + gap), masked_thresholded_accs, width, \n",
    "                   label='Hapax Thresholded', \n",
    "                   color=color_masked_thresholded, \n",
    "                   edgecolor='black', \n",
    "                   linewidth=0.5,\n",
    "                   yerr=masked_thresholded_errors if masked_thresholded_errors[0] else None,\n",
    "                   capsize=3, \n",
    "                   error_kw={'linewidth': 0.8})\n",
    "\n",
    "    \n",
    "    def add_value_labels(bars, values, errors):\n",
    "        # Extract upper errors from the tuple\n",
    "        if errors and len(errors) == 2:\n",
    "            upper_errors = errors[1]  # Get the upper error bounds\n",
    "        else:\n",
    "            upper_errors = [0] * len(values)  # Default to 0 if no errors\n",
    "        \n",
    "        for bar, val, err_upper in zip(bars, values, upper_errors):\n",
    "            height = bar.get_height()\n",
    "            # Position text above the error bar with proper padding\n",
    "            # Since values are in percentages (0-100), we need larger padding\n",
    "            y_position = height + err_upper + 0.1  # 0.5% padding for percentage scale\n",
    "            \n",
    "            # Always show the label\n",
    "            ax.text(bar.get_x() + bar.get_width()/2., \n",
    "                y_position,\n",
    "                f'{val:.1f}',  # Format to 1 decimal place\n",
    "                ha='center', va='bottom', fontsize=7)\n",
    "    \n",
    "    #add_value_labels(bars1, clean_accs, clean_errors)\n",
    "    #add_value_labels(bars2, masked_accs, masked_errors)\n",
    "    #add_value_labels(bars3, masked_thresholded_accs, masked_thresholded_errors)\n",
    "    # Customize axes\n",
    "    ax.set_xlabel('Source Language', fontsize=10)\n",
    "    ax.set_ylabel('Translation Accuracy (%)', fontsize=10)\n",
    "    ax.set_title('5-Shot Translation Performance', fontsize=11, pad=10)\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(languages, rotation=45, ha='right')\n",
    "    \n",
    "    # Add subtle grid for better readability\n",
    "    ax.yaxis.grid(True, linestyle='--', alpha=0.2, linewidth=0.5)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Add legend\n",
    "    ax.legend(loc='best', fontsize=9)\n",
    "    \n",
    "    # Set y-axis to start from 0 for fair comparison\n",
    "    ax.set_ylim(bottom=0)\n",
    "    \n",
    "    # Calculate the maximum value including error bars\n",
    "    max_with_error = 0\n",
    "    if clean_errors and clean_errors[0]:\n",
    "        max_clean_with_error = max([acc + err for acc, err in zip(clean_accs, clean_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_clean_with_error)\n",
    "    if masked_errors and masked_errors[0]:\n",
    "        max_masked_with_error = max([acc + err for acc, err in zip(masked_accs, masked_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_masked_with_error)\n",
    "    if masked_thresholded_errors and masked_thresholded_errors[0]:\n",
    "        max_masked_thresholded_with_error = max([acc + err for acc, err in zip(masked_thresholded_accs, masked_thresholded_errors[1])])\n",
    "        max_with_error = max(max_with_error, max_masked_thresholded_with_error)\n",
    "\n",
    "    # Set y-axis limit with padding to accommodate error bars\n",
    "    if max_with_error > 0:\n",
    "        ax.set_ylim(top=max_with_error * 1.15)  # Add 15% padding above highest point\n",
    "    \n",
    "    # Adjust layout to prevent label cutoff\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    # Save figure if path provided\n",
    "    if save_path:\n",
    "        save_path = Path(save_path)\n",
    "        # Save as both PDF (for paper) and PNG (for quick viewing)\n",
    "        fig.savefig(save_path.with_suffix('.pdf'), format='pdf', bbox_inches='tight')\n",
    "        fig.savefig(save_path.with_suffix('.png'), format='png', dpi=300, bbox_inches='tight')\n",
    "        print(f\"Saved figures to {save_path.with_suffix('.pdf')} and {save_path.with_suffix('.png')}\")\n",
    "    \n",
    "    plt.show()\n",
    "    \n",
    "    return fig, ax\n",
    "\n",
    "def print_statistics(clean_path, masked_path, masked_thresholded_path):\n",
    "    \"\"\"Print summary statistics for the translation results\"\"\"\n",
    "    clean_data, masked_data, masked_thresholded_data = load_translation_results(clean_path, masked_path, masked_thresholded_path)\n",
    "    \n",
    "    print(\"Translation Accuracy Summary (5-shot, Step 19900)\")\n",
    "    print(\"=\" * 60)\n",
    "    print(f\"{'Language':<12} {'Clean':<20} {'Masked':<20} {'Masked Thresholded':<20}\")\n",
    "    print(\"-\" * 60)\n",
    "    \n",
    "    for lang_pair in sorted(clean_data.keys()):\n",
    "        lang = lang_pair.split('_')[0].upper()\n",
    "        clean = clean_data[lang_pair]\n",
    "        masked = masked_data[lang_pair]\n",
    "        masked_thresholded = masked_thresholded_data[lang_pair]\n",
    "        clean_str = f\"{clean['accuracy']*100:.1f}% [{clean['ci_lower']*100:.1f}, {clean['ci_upper']*100:.1f}]\"\n",
    "        masked_str = f\"{masked['accuracy']*100:.1f}% [{masked['ci_lower']*100:.1f}, {masked['ci_upper']*100:.1f}]\"\n",
    "        masked_thresholded_str = f\"{masked_thresholded['accuracy']*100:.1f}% [{masked_thresholded['ci_lower']*100:.1f}, {masked_thresholded['ci_upper']*100:.1f}]\"\n",
    "        print(f\"{lang:<12} {clean_str:<20} {masked_str:<20} {masked_thresholded_str:<20}\")\n",
    "    \n",
    "    # Calculate average performance\n",
    "    clean_avg = np.mean([d['accuracy'] for d in clean_data.values()]) * 100\n",
    "    masked_avg = np.mean([d['accuracy'] for d in masked_data.values()]) * 100\n",
    "    masked_thresholded_avg = np.mean([d['accuracy'] for d in masked_thresholded_data.values()]) * 100\n",
    "    print(\"-\" * 60)\n",
    "    print(f\"{'Average':<12} {clean_avg:.1f}%{'':<14} {masked_avg:.1f}%{'':<14} {masked_thresholded_avg:.1f}%\")\n",
    "    print(\"=\" * 60)\n",
    "\n",
    "\n",
    "clean_path = f\"~pythia_replicate/metrics/non_wandb_metrics/translation_accuracies/translation_{model_types[0]}_{step}_{n_shots}-shot.json\"\n",
    "masked_path = f\"~pythia_replicate/metrics/non_wandb_metrics/translation_accuracies/translation_{model_types[2]}_{step}_{n_shots}-shot.json\"\n",
    "masked_thresholded_path = f\"~pythia_replicate/metrics/non_wandb_metrics/translation_accuracies/translation_{model_types[1]}_{step}_{n_shots}-shot.json\"\n",
    "\n",
    "model_types = [\"clean_1b\", \"masked_bigram_loss_1b_thresh03_eq\", \"masked_bigram_loss_1b\"]\n",
    "\n",
    "\n",
    "# Print statistics\n",
    "print_statistics(clean_path, masked_path, masked_thresholded_path)\n",
    "\n",
    "# Create the plot\n",
    "save_dir = \"~pythia_replicate/metrics/non_wandb_metrics/translation_accuracies/plots\"\n",
    "save_path = f\"{save_dir}/{model_types[0]}_vs_{model_types[1]}_vs_{model_types[2]}_{step}_{n_shots}-shot\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "fig, ax = plot_translation_comparison(clean_path, masked_path, masked_thresholded_path, save_path)\n",
    "\n",
    "print(f\"\\nPlot saved to: {save_path}.pdf and {save_path}.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pythia_replicate",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
