{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "notebook-overview",
   "metadata": {},
   "source": [
    "# ASIDE Results Analysis Notebook\n",
    "\n",
    "This notebook provides analysis tools for ASIDE (Architecturally Separated Instruction-Data Embeddings) experimental results. It processes various evaluation metrics including:\n",
    "\n",
    "- **SEP (Separation) Metrics**: Core ASIDE evaluation measuring instruction-data separation\n",
    "- **AlpacaEval Scores**: General capability assessment using AlpacaEval benchmarks\n",
    "- **Training Metrics**: Loss curves and training statistics\n",
    "- **Output Quality Analysis**: Detection of repetitive or problematic model outputs\n",
    "\n",
    "## Key Features\n",
    "\n",
    "1. **Model Name Parsing**: Standardized parsing of model names across different file formats\n",
    "2. **Data Integration**: Merging results from multiple evaluation sources\n",
    "3. **Performance Comparison**: Tools for comparing ASIDE vs baseline methods\n",
    "4. **Best Model Selection**: Automated selection of optimal models based on various metrics\n",
    "5. **Visualization**: Plotting utilities for learning rate analysis and performance trends\n",
    "\n",
    "## Usage\n",
    "\n",
    "This notebook is designed to work with the standard ASIDE evaluation pipeline outputs. Adjust file paths and model names as needed for your specific experiments.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "daf7eb1d-fc97-4997-a4c9-fe93021ce4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from analyze_results import *\n",
    "import warnings\n",
    "# Suppress FutureWarning messages\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "run-number-extraction",
   "metadata": {},
   "source": [
    "## Run Number Extraction and SEP Metric Analysis\n",
    "\n",
    "These functions extract run numbers from model names and map them to SEP (separation) metrics. This is useful for analyzing how different training runs perform on the core ASIDE evaluation metric.\n",
    "\n",
    "**SEP Metric**: Measures how well a model separates instructions from data. Higher values indicate better separation (closer to ASIDE's goal)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc2d5ede-24d6-48bd-8f99-d2f0d10193fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def get_run_number_sep_metric_dict(df):\n",
    "    \"\"\"\n",
    "    Extract run numbers from model names and map them to SEP metrics.\n",
    "    \n",
    "    This function processes a DataFrame containing model evaluation results,\n",
    "    extracts run numbers from model names using regex patterns, and creates\n",
    "    arrays of SEP metrics and utility scores indexed by run number.\n",
    "    \n",
    "    Args:\n",
    "        df (pandas.DataFrame): DataFrame with columns:\n",
    "            - 'model': Model names containing run numbers (e.g., 'dd_pure_run3_lr6e-6')\n",
    "            - 'sep_metric': SEP separation scores (list format)\n",
    "            - 'probe_in_instruct_asr': Attack success rates (list format)\n",
    "    \n",
    "    Returns:\n",
    "        tuple: (sep_metrics, utils)\n",
    "            - sep_metrics (numpy.array): SEP metric values indexed by run number\n",
    "            - utils (numpy.array): Utility scores indexed by run number\n",
    "    \n",
    "    Example:\n",
    "        >>> df = pd.DataFrame({\n",
    "        ...     'model': ['dd_pure_run3_lr6e-6', 'dd_run5_lr1e-5'],\n",
    "        ...     'sep_metric': [[0.65, 0.02], [0.70, 0.03]],\n",
    "        ...     'probe_in_instruct_asr': [[0.85, 0.01], [0.90, 0.02]]\n",
    "        ... })\n",
    "        >>> sep_metrics, utils = get_run_number_sep_metric_dict(df)\n",
    "        >>> print(f\"Run 3 SEP: {sep_metrics[3]}, Utility: {utils[3]}\")\n",
    "    \n",
    "    Note:\n",
    "        - Searches for patterns 'dd_pure_run(\\d+)' first, then 'dd_run(\\d+)'\n",
    "        - Uses first element of metric arrays ([0] index)\n",
    "        - Pre-allocates arrays of size 10, then trims to actual count\n",
    "    \"\"\"\n",
    "    sep_metrics = np.zeros(10)\n",
    "    utils = np.zeros(10)\n",
    "    cnt = 0\n",
    "    for _, row in df.iterrows():\n",
    "        model_name = row['model']\n",
    "        match = re.search(r'dd_pure_run(\\d+)', model_name)\n",
    "        if match is None: \n",
    "            match = re.search(r'dd_run(\\d+)', model_name)\n",
    "        if match:\n",
    "            run_number = match.group(1)  # keep as string, or convert to int if you prefer\n",
    "            sep_metrics[int(run_number)] = row['sep_metric'][0]\n",
    "            utils[int(run_number)] = row['probe_in_instruct_asr'][0]\n",
    "            cnt += 1\n",
    "\n",
    "    sep_metrics = sep_metrics[:cnt]\n",
    "    utils = utils[:cnt] \n",
    "    return sep_metrics, utils"
   ]
  },
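  {
   "cell_type": "markdown",
   "id": "sketch-run-number-lookup",
   "metadata": {},
   "source": [
    "A minimal, self-contained sketch of the run-number lookup used by `get_run_number_sep_metric_dict`, written with a plain dict instead of NumPy arrays. The model names and scores are made-up illustration data, and `runs_to_scores` is a hypothetical helper that is not part of `analyze_results`.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical illustration data: (model name, SEP value) pairs.\n",
    "rows = [\n",
    "    ('dd_pure_run3_lr6e-6', 0.65),\n",
    "    ('dd_run5_lr1e-5', 0.70),\n",
    "    ('original', 0.40),  # no run number, so it is skipped\n",
    "]\n",
    "\n",
    "\n",
    "def runs_to_scores(rows):\n",
    "    # Map run number -> score, trying 'dd_pure_run' before 'dd_run'.\n",
    "    scores = {}\n",
    "    for name, value in rows:\n",
    "        match = re.search(r'dd_pure_run(\\d+)', name)\n",
    "        if match is None:\n",
    "            match = re.search(r'dd_run(\\d+)', name)\n",
    "        if match:\n",
    "            scores[int(match.group(1))] = value\n",
    "    return scores\n",
    "\n",
    "\n",
    "scores = runs_to_scores(rows)\n",
    "print(scores)\n",
    "```\n",
    "\n",
    "Keying the scores by run number in a dict avoids having to guess an array size before the run numbers are known."
   ]
  },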
  {
   "cell_type": "markdown",
   "id": "learning-rate-analysis",
   "metadata": {},
   "source": [
    "## Learning Rate Analysis and Visualization\n",
    "\n",
    "These functions analyze how different learning rates affect model performance, comparing ASIDE methods against baselines. This is crucial for hyperparameter optimization and understanding training dynamics.\n",
    "\n",
    "**Key Comparisons**:\n",
    "- `dd_pure`: ASIDE method performance\n",
    "- `pretrained_vanilla`: Baseline vanilla model performance\n",
    "- SEP metrics vs utility scores across learning rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ae1c58f-9666-46f7-ba40-ccda77addc7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def create_lr_dict(df):\n",
    "    \"\"\"\n",
    "    Create a learning rate dictionary mapping LR values to performance metrics.\n",
    "    \n",
    "    This function processes experimental results to create a comprehensive mapping\n",
    "    of learning rates to performance metrics for both ASIDE and baseline methods.\n",
    "    It's essential for hyperparameter analysis and comparison studies.\n",
    "    \n",
    "    Args:\n",
    "        df (pandas.DataFrame): DataFrame containing model results with columns:\n",
    "            - 'model': Model names with embedded learning rates\n",
    "            - 'sep_metric': SEP separation scores [value, error]\n",
    "            - 'probe_in_instruct_asr': Attack success rates [value, error]\n",
    "    \n",
    "    Returns:\n",
    "        dict: Learning rate mapping with structure:\n",
    "            {\n",
    "                'lr_string': (\n",
    "                    dd_pure_sep,           # ASIDE separation score\n",
    "                    dd_pure_prompt_asr,    # ASIDE attack success rate  \n",
    "                    pretrained_sep,        # Baseline separation score\n",
    "                    pretrained_prompt_asr  # Baseline attack success rate\n",
    "                )\n",
    "            }\n",
    "    \n",
    "    Model Name Patterns:\n",
    "        - 'dd_pure_from_base_run1e-4_bs8': ASIDE method with LR 1e-4\n",
    "        - 'from_base_pretrained_vanilla_run1e-4_bs8': Baseline with LR 1e-4\n",
    "        \n",
    "    Example Usage:\n",
    "        >>> lr_dict = create_lr_dict(results_df)\n",
    "        >>> print(f\"LR 1e-4 ASIDE SEP: {lr_dict['1e-4'][0]}\")\n",
    "        >>> print(f\"LR 1e-4 Baseline SEP: {lr_dict['1e-4'][2]}\")\n",
    "    \n",
    "    Note:\n",
    "        - Extracts learning rates from model names using regex 'run([0-9e.-]+)_bs'\n",
    "        - Handles missing runs gracefully (sets values to None)\n",
    "        - Groups results by learning rate for direct comparison\n",
    "    \"\"\"\n",
    "    \n",
    "    lr_dict = {}\n",
    "    \n",
    "    # Helper function to extract the LR from run_name using a regex\n",
    "    # that matches something like \"run1e-4\" or \"run5e-5\" or \"run2e-5\" etc.\n",
    "    def extract_lr_from_name(name):\n",
    "        # A simple approach is to find the substring after 'run' up to '_bs'\n",
    "        # e.g. \"dd_pure_from_base_run1e-4_bs8\" -> \"1e-4\"\n",
    "        match = re.search(r'run([0-9e.-]+)_bs', name)\n",
    "        if match:\n",
    "            return match.group(1)  # e.g. \"1e-4\"\n",
    "        else:\n",
    "            return None\n",
    "    \n",
    "    # We will collect data for each LR in sub-dictionaries:\n",
    "    #   { '1e-4': {'dd_pure': (sep, prompt_asr), \n",
    "    #              'pretrained': (sep, prompt_asr)} }\n",
    "    # Then we will flatten to the final format.\n",
    "    temp_storage = {}\n",
    "    \n",
    "    for idx, row in df.iterrows():\n",
    "        run_name = row['model']  # adapt to your column name\n",
    "        lr_str = extract_lr_from_name(run_name)\n",
    "        if not lr_str:\n",
    "            # Skip 'original' or 'original_inst' or anything w/o LR\n",
    "            continue\n",
    "        \n",
    "        # parse the first metric (sep) and second metric (prompt_in_data_asr)\n",
    "        # Suppose your DataFrame has them in these columns:\n",
    "        # \"metric1_value\", \"metric1_error\", \"metric2_value\", \"metric2_error\"\n",
    "        # OR if they're in an array, adapt accordingly.\n",
    "        # For example, if row['metrics'] = [ [sep_val, sep_err],\n",
    "        #                                    [prompt_val, prompt_err],\n",
    "        #                                    [other_val, other_err] ]\n",
    "        \n",
    "        # Example (based on your table):\n",
    "        sep_val = row[\"sep_metric\"][0]\n",
    "        prompt_val = row[\"probe_in_instruct_asr\"][0]\n",
    "        \n",
    "        if lr_str not in temp_storage:\n",
    "            temp_storage[lr_str] = {}\n",
    "        \n",
    "        if 'dd_pure_from_base' in run_name:\n",
    "            temp_storage[lr_str]['dd_pure'] = (sep_val, prompt_val)\n",
    "        elif 'pretrained_vanilla' in run_name:\n",
    "            temp_storage[lr_str]['pretrained'] = (sep_val, prompt_val)\n",
    "    \n",
    "    # Now build the final dictionary with the required structure:\n",
    "    #   lr -> (dd_pure_sep, dd_pure_prompt_in_data_asr, \n",
    "    #          pretrained_vanilla_sep, pretrained_vanilla_prompt_in_data_asr)\n",
    "    for lr_str, subdict in temp_storage.items():\n",
    "        # Some runs might be missing from your data; handle gracefully\n",
    "        dd_pure_sep, dd_pure_prompt = subdict.get('dd_pure', (None, None))\n",
    "        pretrained_sep, pretrained_prompt = subdict.get('pretrained', (None, None))\n",
    "        \n",
    "        lr_dict[lr_str] = (\n",
    "            dd_pure_sep, \n",
    "            dd_pure_prompt, \n",
    "            pretrained_sep, \n",
    "            pretrained_prompt\n",
    "        )\n",
    "    \n",
    "    return lr_dict\n",
    "\n",
    "\n",
    "def plot_results(lr_dict, original_value, original_inst_value):\n",
    "    \"\"\"\n",
    "    Create comprehensive visualization of learning rate analysis results.\n",
    "    \n",
    "    This function generates publication-quality plots comparing ASIDE and baseline\n",
    "    performance across different learning rates, with reference lines for\n",
    "    original model performance.\n",
    "    \n",
    "    Args:\n",
    "        lr_dict (dict): Learning rate dictionary from create_lr_dict()\n",
    "            Format: lr_str -> (dd_pure_sep, dd_pure_prompt_asr, \n",
    "                              pretrained_sep, pretrained_prompt_asr)\n",
    "        original_value (float): Baseline model performance (horizontal reference line)\n",
    "        original_inst_value (float): Instruction-tuned baseline performance\n",
    "    \n",
    "    Plot Features:\n",
    "        - Log-scale x-axis for learning rates\n",
    "        - Solid lines for SEP metrics, dashed for attack success rates\n",
    "        - Different colors for ASIDE vs baseline methods\n",
    "        - Reference lines for original model performance\n",
    "        - Professional styling with seaborn theme\n",
    "    \n",
    "    Visual Interpretation:\n",
    "        - Higher SEP values = better instruction-data separation\n",
    "        - Lower attack success rates = better robustness\n",
    "        - ASIDE should outperform baselines across learning rates\n",
    "    \n",
    "    Example:\n",
    "        >>> lr_dict = create_lr_dict(df)\n",
    "        >>> plot_results(lr_dict, 0.387, 0.504)\n",
    "    \"\"\"\n",
    "    ## Apply seaborn style\n",
    "    sns.set_theme(style=\"whitegrid\")\n",
    "    \n",
    "    # Convert the keys of lr_dict to floats for sorting\n",
    "    def str_to_float(lr_str):\n",
    "        # Safely evaluate scientific notation\n",
    "        return float(lr_str)\n",
    "    \n",
    "    # Sort the learning rates (numeric ascending)\n",
    "    sorted_lrs = sorted(lr_dict.keys(), key=str_to_float)\n",
    "    numeric_lrs = [str_to_float(k) for k in sorted_lrs]\n",
    "    \n",
    "    # Extract the four series\n",
    "    dd_pure_sep = [lr_dict[k][0] for k in sorted_lrs]\n",
    "    dd_pure_prompt = [lr_dict[k][1] for k in sorted_lrs]\n",
    "    pretrained_sep = [lr_dict[k][2] for k in sorted_lrs]\n",
    "    pretrained_prompt = [lr_dict[k][3] for k in sorted_lrs]\n",
    "    \n",
    "    # Create the figure\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    \n",
    "    # Plot dd_pure (sep = solid, prompt_in_data_asr = dashed)\n",
    "    plt.plot(numeric_lrs, dd_pure_sep, label='dd_pure_sep', linestyle='-', color=sns.color_palette(\"muted\")[0], linewidth=2)\n",
    "    plt.plot(numeric_lrs, dd_pure_prompt, label='dd_pure_prompt_in_data_asr', linestyle='--', color=sns.color_palette(\"muted\")[0], linewidth=2)\n",
    "    \n",
    "    # Plot pretrained_vanilla (sep = solid, prompt_in_data_asr = dashed)\n",
    "    plt.plot(numeric_lrs, pretrained_sep, label='pretrained_vanilla_sep', linestyle='-', color=sns.color_palette(\"muted\")[1], linewidth=2)\n",
    "    plt.plot(numeric_lrs, pretrained_prompt, label='pretrained_vanilla_prompt_in_data_asr', linestyle='--', color=sns.color_palette(\"muted\")[1], linewidth=2)\n",
    "    \n",
    "    # Add horizontal lines for original and original_inst\n",
    "    plt.axhline(y=original_value, color=sns.color_palette(\"muted\")[2], linestyle='-.', label=f'base utility={original_value:.3f}', linewidth=2)\n",
    "    plt.axhline(y=original_inst_value, color=sns.color_palette(\"muted\")[3], linestyle=':', label=f'base inst utility={original_inst_value:.3f}', linewidth=2)\n",
    "    \n",
    "    # Configure log-scale on x-axis\n",
    "    plt.xscale('log')\n",
    "    \n",
    "    # Labels, title, and ticks\n",
    "    plt.xlabel('Learning Rate', fontsize=12)\n",
    "    plt.ylabel('Metric (e.g., \"sep\")', fontsize=12)\n",
    "    plt.title('Comparison of dd_pure vs pretrained_vanilla', fontsize=14)\n",
    "    \n",
    "    # Legend to the side\n",
    "    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, frameon=True)\n",
    "    \n",
    "    # Adjust layout for better readability\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    # Show plot\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "97d161d6-1826-4a1b-8d6d-b1a6bfb165cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metrics(df):\n",
    "    \"\"\"\n",
    "    Convenient wrapper function to plot learning rate analysis from DataFrame.\n",
    "    \n",
    "    This function combines data processing and visualization into a single call,\n",
    "    automatically extracting baseline values and creating the learning rate plot.\n",
    "    \n",
    "    Args:\n",
    "        df (pandas.DataFrame): Results DataFrame containing:\n",
    "            - Model results with learning rate information\n",
    "            - Rows with model names 'original' and 'original_inst' for baselines\n",
    "    \n",
    "    Workflow:\n",
    "        1. Create learning rate dictionary from DataFrame\n",
    "        2. Extract baseline performance from 'original' and 'original_inst' rows\n",
    "        3. Generate comparative visualization\n",
    "    \n",
    "    Expected DataFrame Structure:\n",
    "        - Regular experiment rows with LR-encoded model names\n",
    "        - Special rows: 'original' (base model), 'original_inst' (instruct-tuned)\n",
    "    \n",
    "    Example:\n",
    "        >>> results_df = load_experiment_results()\n",
>> plot_metrics(results_df)">
    "        >>> plot_metrics(results_df)  # Prints lr_dict, then shows the plot\n",
    "    \"\"\"\n",
    "    # 1) Build the dictionary\n",
    "    lr_dict = create_lr_dict(df)\n",
    "    print(lr_dict)\n",
    "    # 2) We also retrieve the \"original\" and \"original_inst\" from the DataFrame\n",
    "    #    (assuming they're in row 18 and 19, or found by name).\n",
    "    original_row = df.loc[df['model'] == 'original'].iloc[0]\n",
    "    original_inst_row = df.loc[df['model'] == 'original_inst'].iloc[0]\n",
    "    original_value = original_row[\"probe_in_instruct_asr\"][0]     # e.g. 0.387\n",
    "    original_inst_value = original_inst_row[\"probe_in_instruct_asr\"][0] \n",
    "    \n",
    "    # 3) Plot\n",
    "    plot_results(lr_dict, original_value, original_inst_value)"
   ]
  },
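  {
   "cell_type": "markdown",
   "id": "sketch-lr-extraction",
   "metadata": {},
   "source": [
    "The learning-rate extraction and ordering that `create_lr_dict` and `plot_results` rely on can be tried in isolation. The sketch below is a rough illustration, not the notebook's own code path: the model names are hypothetical examples that follow the documented pattern, and `extract_lr` is a stand-alone copy of the idea rather than an exported helper.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical model names following the pattern documented in create_lr_dict.\n",
    "names = [\n",
    "    'dd_pure_from_base_run1e-4_bs8',\n",
    "    'from_base_pretrained_vanilla_run5e-6_bs8',\n",
    "    'dd_pure_from_base_run2e-5_bs8',\n",
    "    'original',  # no learning rate in the name\n",
    "]\n",
    "\n",
    "\n",
    "def extract_lr(name):\n",
    "    # Return the learning-rate string between 'run' and '_bs', or None.\n",
    "    match = re.search(r'run([0-9e.-]+)_bs', name)\n",
    "    return match.group(1) if match else None\n",
    "\n",
    "\n",
    "lr_strings = {lr for lr in map(extract_lr, names) if lr is not None}\n",
    "\n",
    "# Sort numerically: a plain string sort would put '1e-4' before '5e-6'.\n",
    "sorted_lrs = sorted(lr_strings, key=float)\n",
    "print(sorted_lrs)\n",
    "```\n",
    "\n",
    "Sorting with `key=float` mirrors what `plot_results` does before drawing the log-scale x-axis."
   ]
  },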
  {
   "cell_type": "markdown",
   "id": "alpacaeval-integration",
   "metadata": {},
   "source": [
    "## AlpacaEval Score Integration\n",
    "\n",
    "These functions handle integration of AlpacaEval benchmark results with SEP evaluation data. AlpacaEval measures general instruction-following capability, providing a crucial utility metric to ensure ASIDE doesn't sacrifice performance for safety.\n",
    "\n",
    "**Integration Process**:\n",
    "1. Load AlpacaEval CSV results\n",
    "2. Parse and standardize model names\n",
    "3. Merge with SEP evaluation data\n",
    "4. Enable comprehensive performance analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6f2d317c-1ec9-45fa-9e59-b8d530d76819",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_alpacaeval_scores(file_path, substring, alpaca_ver=\"1.0\"):\n",
    "    \"\"\"\n",
    "    Extract AlpacaEval scores for models matching a specific substring pattern.\n",
    "    \n",
    "    This function processes AlpacaEval leaderboard CSV files to extract performance\n",
    "    scores for models of interest, enabling integration with ASIDE evaluation results.\n",
    "    \n",
    "    Args:\n",
    "        file_path (str): Path to AlpacaEval CSV leaderboard file\n",
    "        substring (str): Substring to filter model names (e.g., \"SFTv110\")\n",
    "        alpaca_ver (str): AlpacaEval version (\"1.0\" or \"2.0\")\n",
    "            - \"1.0\": Uses 'win_rate' column\n",
    "            - \"2.0\": Uses 'length_controlled_winrate' column\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Filtered results with columns:\n",
    "            - 'model': Model name/path from CSV index\n",
    "            - 'win_rate' or 'length_controlled_winrate': Performance score\n",
    "    \n",
    "    CSV Structure Expected:\n",
    "        - First column (index): Model names/paths\n",
    "        - Score columns: 'win_rate' (v1.0) or 'length_controlled_winrate' (v2.0)\n",
    "    \n",
    "    Example:\n",
    "        >>> scores = get_alpacaeval_scores(\n",
    "        ...     \"./evals/alpaca_eval/leaderboard.csv\", \n",
    "        ...     \"SFTv110\", \n",
    "        ...     alpaca_ver=\"1.0\"\n",
    "        ... )\n",
    "        >>> print(f\"Found {len(scores)} matching models\")\n",
    "    \n",
    "    Note:\n",
    "        - Filters models containing the substring in their name\n",
    "        - Resets index to make 'model' a regular column\n",
    "        - Handles both AlpacaEval v1.0 and v2.0 formats\n",
    "    \"\"\"\n",
    "    # Read the CSV and use the first column as the index\n",
    "    df = pd.read_csv(file_path, index_col=0)\n",
    "    \n",
    "    # Filter rows based on whether the index (model name) contains the substring\n",
    "    filtered_df = df[df.index.str.contains(substring, na=False)]\n",
    "    \n",
    "    # Create a 'model' column from the index\n",
    "    filtered_df = filtered_df.copy()\n",
    "    filtered_df.loc[:, 'model'] = filtered_df.index\n",
    "    \n",
    "    # Only keep 'model' and 'length_controlled_winrate'\n",
    "    if alpaca_ver==\"1.0\":\n",
    "        filtered_df = filtered_df[['model', 'win_rate']]\n",
    "    else:\n",
    "        filtered_df = filtered_df[['model', 'length_controlled_winrate']]\n",
    "    \n",
    "    # Re-index the DataFrame so 'model' is just a column, not the index\n",
    "    filtered_df.reset_index(drop=True, inplace=True)\n",
    "    \n",
    "    return filtered_df"
   ]
  },
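  {
   "cell_type": "markdown",
   "id": "sketch-leaderboard-filter",
   "metadata": {},
   "source": [
    "The filtering step in `get_alpacaeval_scores` can be tried without a real leaderboard file. The sketch below is only an illustration: the CSV contents, paths, and scores are invented, and a real AlpacaEval leaderboard has more columns.\n",
    "\n",
    "```python\n",
    "import io\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical leaderboard contents, for illustration only.\n",
    "csv_text = ''',win_rate,length_controlled_winrate\n",
    "../models/m/forward_rot/train_checkpoints/SFTv110/from_base_run_15/last/,71.2,30.1\n",
    "../models/m/ise/train_checkpoints/SFTv70/from_base_run_5/last/,65.0,25.4\n",
    "'''\n",
    "\n",
    "# The first CSV column (model path) becomes the index.\n",
    "df = pd.read_csv(io.StringIO(csv_text), index_col=0)\n",
    "\n",
    "# regex=False makes this a literal substring test on the model path.\n",
    "mask = df.index.str.contains('SFTv110', regex=False, na=False)\n",
    "\n",
    "# Keep the score column and turn the index into a regular 'model' column.\n",
    "scores = df.loc[mask, ['win_rate']].rename_axis('model').reset_index()\n",
    "print(scores)\n",
    "```\n",
    "\n",
    "Only the `SFTv110` row survives the filter, and the result has the two columns that the merge step expects for AlpacaEval 1.0: `model` and `win_rate`."
   ]
  },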
  {
   "cell_type": "markdown",
   "id": "model-name-parsing",
   "metadata": {},
   "source": [
    "## Model Name Parsing and Standardization\n",
    "\n",
    "These functions handle the  task of parsing and standardizing model names across different file formats and evaluation systems. Consistent naming is crucial for merging results from multiple evaluation sources.\n",
    "\n",
    "**Challenges Addressed**:\n",
    "- Different naming conventions between evaluation systems\n",
    "- Extracting run numbers and model types from complex paths\n",
    "- Standardizing names for reliable data merging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "120cd96d-dced-45b3-899e-f4dc41352ee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def parse_model_first_table(name: str) -> str:\n",
    "    \"\"\"\n",
    "    Parse model names from the first table format to standardized format.\n",
    "    \n",
    "    This function handles model names from SEP evaluation results, removing\n",
    "    unnecessary components and standardizing the format for merging with\n",
    "    other evaluation data.\n",
    "    \n",
    "    Args:\n",
    "        name (str): Original model name from SEP results\n",
    "            Example: 'forward_rot_from_base_run0_val25feb'\n",
    "    \n",
    "    Returns:\n",
    "        str: Cleaned model name\n",
    "            Example: 'forward_rot_run0'\n",
    "    \n",
    "    Processing Steps:\n",
    "        1. Remove 'from_base_' prefix if present\n",
    "        2. Truncate everything from '_val' onward (removes date suffixes)\n",
    "        3. Preserve core model type and run number\n",
    "    \n",
    "    Example:\n",
    "        >>> parse_model_first_table('forward_rot_from_base_run0_val25feb')\n",
    "        'forward_rot_run0'\n",
    "        >>> parse_model_first_table('ise_run5_val01mar')\n",
    "        'ise_run5'\n",
    "    \"\"\"\n",
    "    # 1. Remove 'from_base_' if present\n",
    "    name = name.replace(\"from_base_\", \"\")\n",
    "    \n",
    "    # 2. If there's a '_val', cut everything from '_val' onward\n",
    "    val_idx = name.find(\"_val\")\n",
    "    if val_idx != -1:\n",
    "        name = name[:val_idx]\n",
    "    \n",
    "    # Result: e.g. 'forward_rot_run0'\n",
    "    return name\n",
    "\n",
    "def parse_model_second_table(path: str) -> str:\n",
    "    \"\"\"\n",
    "    Parse model names from filepath format to standardized format.\n",
    "    \n",
    "    This function extracts standardized model names from full model paths,\n",
    "    typically from AlpacaEval results or training checkpoint directories.\n",
    "    \n",
    "    Args:\n",
    "        path (str): Full model filepath\n",
    "            Example: '../models/llama_3.1_8b/forward_rot/train_checkpoints/SFTv70/from_base_run_5e-6_bs8/last/'\n",
    "    \n",
    "    Returns:\n",
    "        str: Standardized model name\n",
    "            Example: 'forward_rot_run5e-6_bs8'\n",
    "    \n",
    "    Processing Logic:\n",
    "        1. Split path by '/' to get components\n",
    "        2. Extract technique from expected position (index 3)\n",
    "        3. Find 'from_base_run_' component and extract run information\n",
    "        4. Combine technique + run info in standard format\n",
    "    \n",
    "    Example:\n",
    "        >>> path = '../models/llama_3.1_8b/forward_rot/train_checkpoints/SFTv70/from_base_run_15/last/'\n",
    "        >>> parse_model_second_table(path)\n",
    "        'forward_rot_run15'\n",
    "    \n",
    "    Note:\n",
    "        - Returns 'unknown_run-9999' if parsing fails\n",
    "        - Handles both simple run numbers and complex LR+batch size formats\n",
    "    \"\"\"\n",
    "    # Split by \"/\"\n",
    "    parts = path.split(\"/\")\n",
    "    \n",
    "    # Attempt to find technique in the 3rd index or whichever you expect\n",
    "    # Adjust if your path structure differs\n",
    "    if len(parts) > 3:\n",
    "        technique = parts[3]\n",
    "    else:\n",
    "        technique = \"unknown\"\n",
    "    \n",
    "    # Find the part that starts with 'from_base_run_'\n",
    "    run_parts = [p for p in parts if p.startswith(\"from_base_run_\")]\n",
    "    if run_parts:\n",
    "        run_part = run_parts[0]  # e.g. 'from_base_run_15' or 'from_base_run_5e-6_bs8'\n",
    "        # Extract everything after 'from_base_run_'\n",
    "        run_number = run_part[len(\"from_base_run_\"):]\n",
    "    else:\n",
    "        run_number = \"-9999\"\n",
    "    \n",
    "    # Combine technique + run number\n",
    "    return f\"{technique}_run{run_number}\"\n",
    "\n",
    "# --------------------- #\n",
    "# MERGE DATA EXAMPLE    #\n",
    "# --------------------- #\n",
    "\n",
    "def merge_sep_alpaca_tables(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Merge SEP evaluation results with AlpacaEval scores based on standardized model names.\n",
    "    \n",
    "    This function combines two critical evaluation datasets: SEP metrics (measuring\n",
    "    instruction-data separation) and AlpacaEval scores (measuring general capability).\n",
    "    The merge enables comprehensive analysis of the safety-utility tradeoff.\n",
    "    \n",
    "    Args:\n",
    "        df1 (pandas.DataFrame): SEP evaluation results with columns:\n",
    "            - 'model': Model names in first table format\n",
    "            - 'sep_metric': Separation scores\n",
    "            - 'probe_in_instruct_asr': Attack success rates\n",
    "            - Other evaluation metrics\n",
    "        df2 (pandas.DataFrame): AlpacaEval results with columns:\n",
    "            - 'model': Model names in second table format (filepaths)\n",
    "            - 'win_rate': AlpacaEval performance scores\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Merged dataset with columns from df1 plus:\n",
    "            - 'alpacaeval 1.0': Renamed win_rate column\n",
    "            - 'parsed_name': Standardized model names used for merging\n",
    "    \n",
    "    Merge Process:\n",
    "        1. Parse model names in both DataFrames to common format\n",
    "        2. Perform left join on parsed names\n",
    "        3. Rename AlpacaEval column for clarity\n",
    "        4. Preserve all SEP data, add AlpacaEval where available\n",
    "    \n",
    "    Example:\n",
    "        >>> sep_df = load_sep_results()\n",
    "        >>> alpaca_df = get_alpacaeval_scores(\"leaderboard.csv\", \"SFTv110\")\n",
    "        >>> merged = merge_sep_alpaca_tables(sep_df, alpaca_df)\n",
    "        >>> print(f\"Merged {len(merged)} models with both metrics\")\n",
    "    \n",
    "    Use Cases:\n",
    "        - Analyze safety-utility tradeoffs\n",
    "        - Identify models with optimal balance\n",
    "        - Create comprehensive evaluation reports\n",
    "    \"\"\"\n",
    "    # Create a parsed name column in df1\n",
    "    df1[\"parsed_name\"] = df1[\"model\"].apply(parse_model_first_table)\n",
    "    \n",
    "    # Create a parsed name column in df2\n",
    "    df2[\"parsed_name\"] = df2[\"model\"].apply(parse_model_second_table)\n",
    "    \n",
    "    # Merge (left-join) df2's 'win_rate' onto df1, keyed by 'parsed_name'\n",
    "    merged = pd.merge(\n",
    "        df1, \n",
    "        df2[[\"parsed_name\", \"win_rate\"]], \n",
    "        on=\"parsed_name\", \n",
    "        how=\"left\"\n",
    "    )\n",
    "    \n",
    "    # Rename 'win_rate' column to 'alpacaeval 1.0'\n",
    "    merged.rename(columns={\"win_rate\": \"alpacaeval 1.0\"}, inplace=True)\n",
    "\n",
    "    # Drop 'parsed_name' if you don't want it in the final output\n",
    "    # Or you can keep it for debugging\n",
    "    # merged.drop(columns=\"parsed_name\", inplace=True)\n",
    "    \n",
    "    return merged"
   ]
  },
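  {
   "cell_type": "markdown",
   "id": "sketch-merge-indicator",
   "metadata": {},
   "source": [
    "One way to check that the parsed names actually line up before trusting the merged table is to ask pandas which rows found no partner. This is a sketch under assumptions: the names and scores below are invented, and `indicator=True` is an addition that `merge_sep_alpaca_tables` itself does not use.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical, already-standardized names and scores for illustration only.\n",
    "sep_df = pd.DataFrame({\n",
    "    'parsed_name': ['forward_rot_run0', 'ise_run5'],\n",
    "    'sep': [0.65, 0.40],\n",
    "})\n",
    "alpaca_df = pd.DataFrame({\n",
    "    'parsed_name': ['forward_rot_run0'],\n",
    "    'win_rate': [71.2],\n",
    "})\n",
    "\n",
    "# A left join keeps every SEP row; indicator=True adds a '_merge' column\n",
    "# that records whether each row found a partner in alpaca_df.\n",
    "merged = sep_df.merge(alpaca_df, on='parsed_name', how='left', indicator=True)\n",
    "\n",
    "# Rows marked 'left_only' have no AlpacaEval score, often a naming mismatch.\n",
    "unmatched = merged.loc[merged['_merge'] == 'left_only', 'parsed_name'].tolist()\n",
    "print(unmatched)\n",
    "```\n",
    "\n",
    "An unexpectedly long `unmatched` list usually points at a difference between the two name parsers rather than at missing evaluations."
   ]
  },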
  {
   "cell_type": "markdown",
   "id": "advanced-parsing",
   "metadata": {},
   "source": [
    "## Advanced Model Name Parsing\n",
    "\n",
    "Additional parsing functions for handling more complex model naming schemes, particularly those with embedded hyperparameters and version information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c1fcf2c-6812-4d9a-ae3f-4e92bd136ce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_model_name(full_name):\n",
    "    \"\"\"\n",
    "    Parse complex model names to extract core components and run numbers.\n",
    "    \n",
    "    This function handles the most complex model naming format, extracting\n",
    "    both the model type and run number from names with embedded training\n",
    "    configuration information.\n",
    "    \n",
    "    Args:\n",
    "        full_name (str): Complete model name with training config\n",
    "            Example: \"forward_rot_train_full_SFTv110_run=11\"\n",
    "    \n",
    "    Returns:\n",
    "        tuple: (model_type, run_number)\n",
    "            Example: (\"forward_rot\", \"11\")\n",
    "    \n",
    "    Parsing Logic:\n",
    "        1. Extract model type from prefix before \"_train_full\"\n",
    "        2. Extract run number from \"run=\" parameter\n",
    "        3. Handle fallback cases for malformed names\n",
    "    \n",
    "    Supported Formats:\n",
    "        - \"forward_rot_train_full_SFTv110_run=11\"\n",
    "        - \"ise_train_full_SFTv70_run=5\"\n",
    "        - \"single_emb_train_full_SFTv110_run=20\"\n",
    "    \n",
    "    Example:\n",
    "        >>> model_type, run_num = parse_model_name(\"forward_rot_train_full_SFTv110_run=11\")\n",
    "        >>> print(f\"Model: {model_type}, Run: {run_num}\")\n",
    "        Model: forward_rot, Run: 11\n",
    "    \"\"\"\n",
    "    # Identify the model type (prefix before \"_train_full\")\n",
    "    model_type_match = re.match(r'^([^_]+(?:_[^_]+)?)_train_full', full_name)\n",
    "    if model_type_match:\n",
    "        model_type = model_type_match.group(1)\n",
    "    else:\n",
    "        # Fallback if pattern doesn't match\n",
    "        model_type = full_name.split('_train_full')[0]\n",
    "    \n",
    "    # Extract the run number after \"run=\"\n",
    "    run_match = re.search(r'run=([^/]+)', full_name)\n",
    "    if run_match:\n",
    "        run_number = run_match.group(1)\n",
    "    else:\n",
    "        # Fallback if run number not found\n",
    "        run_number = \"-9999\"\n",
    "    \n",
    "    return model_type, run_number\n",
    "\n",
    "def standardize_model_name(full_name):\n",
    "    \"\"\"\n",
    "    Convert complex model names to standardized format for data merging.\n",
    "    \n",
    "    This function creates consistent model identifiers that can be used\n",
    "    across different evaluation systems and data sources.\n",
    "    \n",
    "    Args:\n",
    "        full_name (str): Complete model name from training logs\n",
    "            Example: \"forward_rot_train_full_SFTv110_run=11\"\n",
    "    \n",
    "    Returns:\n",
    "        str: Standardized model identifier\n",
    "            Example: \"forward_rot_run11\"\n",
    "    \n",
    "    Standardization Benefits:\n",
    "        - Consistent naming across evaluation systems\n",
    "        - Simplified model identification\n",
    "        - Reliable data merging capabilities\n",
    "        - Easy filtering and grouping operations\n",
    "    \n",
    "    Example:\n",
    "        >>> standardize_model_name(\"forward_rot_train_full_SFTv110_run=11\")\n",
    "        'forward_rot_run11'\n",
    "        >>> standardize_model_name(\"ise_train_full_SFTv70_run=5\")\n",
    "        'ise_run5'\n",
    "    \"\"\"\n",
    "    model_type, run_number = parse_model_name(full_name)\n",
    "    return f\"{model_type}_run{run_number}\"\n",
    "\n",
    "def transform_losses_df(df):\n",
    "    \"\"\"\n",
    "    Add standardized model name column to training losses DataFrame.\n",
    "    \n",
    "    This function prepares training loss data for merging with evaluation\n",
    "    results by adding a standardized model name column.\n",
    "    \n",
    "    Args:\n",
    "        df (pandas.DataFrame): Training losses DataFrame with 'model' column\n",
    "            containing original complex model names\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: DataFrame with added 'parsed_name' column\n",
    "            containing standardized model identifiers\n",
    "    \n",
    "    Use Case:\n",
    "        Enables correlation analysis between training dynamics (loss curves)\n",
    "        and final evaluation performance (SEP, AlpacaEval).\n",
    "    \n",
    "    Example:\n",
    "        >>> losses_df = load_training_losses()\n",
    "        >>> losses_df = transform_losses_df(losses_df)\n",
    "        >>> print(losses_df[['model', 'parsed_name']].head())\n",
    "    \"\"\"\n",
    "    df['parsed_name'] = df['model'].apply(standardize_model_name)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "alpaca-output-parsing",
   "metadata": {},
   "source": [
    "## AlpacaEval Output Parsing\n",
    "\n",
    "This function parses AlpacaEval output files directly from evaluation directories, extracting performance scores and associating them with model identifiers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c5796665-f3c8-4bcf-9e1f-c4dd4f4db171",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_alpaca_outputs(directory, substring):\n",
    "    \"\"\"\n",
    "    Parse AlpacaEval output files from a directory to extract performance scores.\n",
    "    \n",
    "    This function processes raw AlpacaEval output files, extracting model paths\n",
    "    and their corresponding performance scores. It's useful when working with\n",
    "    direct evaluation outputs rather than processed leaderboard files.\n",
    "    \n",
    "    Args:\n",
    "        directory (str): Path to directory containing AlpacaEval output files\n",
    "        substring (str): Substring to filter relevant lines (e.g., \"SFTv110\")\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Parsed results with columns:\n",
    "            - 'model': Model identifier extracted from path\n",
    "            - 'alpacaeval 1.0': Performance score\n",
    "            - 'parsed_name': Standardized model name for merging\n",
    "    \n",
    "    File Format Expected:\n",
    "        Each line should contain:\n",
    "        <model_path> <score> [other_data...]\n",
    "        \n",
    "        Example line:\n",
    "        ../models/llama_3.1_8b/forward_rot/SFTv110/from_base_run_15/last/ 85.19\n",
    "    \n",
    "    Processing Steps:\n",
    "        1. Read all files in directory\n",
    "        2. Filter lines containing the substring\n",
    "        3. Extract model path and score from each line\n",
    "        4. Standardize model names for consistency\n",
    "        5. Remove entries with failed parsing (-9999 indicator)\n",
    "    \n",
    "    Example:\n",
    "        >>> scores = parse_alpaca_outputs(\"./alpaca_outputs/\", \"SFTv110\")\n",
    "        >>> print(f\"Parsed {len(scores)} model scores\")\n",
    "        >>> print(scores[['parsed_name', 'alpacaeval 1.0']].head())\n",
    "    \n",
    "    Note:\n",
    "        - Handles multiple files in the directory\n",
    "        - Robust error handling for malformed lines\n",
    "        - Filters out unparseable model names\n",
    "    \"\"\"\n",
    "    data = []\n",
    "\n",
    "    # Iterate over all items in the directory\n",
    "    for filename in os.listdir(directory):\n",
    "        file_path = os.path.join(directory, filename)\n",
    "        \n",
    "        # Process only if it's a regular file\n",
    "        if os.path.isfile(file_path):\n",
    "            with open(file_path, 'r', encoding='utf-8') as f:\n",
    "                for line in f:\n",
    "                    # Check if 'SFTv110' is in the line\n",
    "                    if substring in line:\n",
    "                        # Split line on whitespace\n",
    "                        parts = line.strip().split()\n",
    "                        \n",
    "                        # Ensure we have at least 2 parts (path & numeric value)\n",
    "                        if len(parts) < 2:\n",
    "                            continue\n",
    "\n",
    "                        full_path = parts[0]      # e.g. ../models/.../SFTv110/from_base_run_15/...\n",
    "                        win_rate_str = parts[1]  # e.g. 85.19\n",
    "\n",
    "                        # Find the substring starting at 'SFTv110'\n",
    "                        idx = full_path.find(substring)\n",
    "                        if idx == -1:\n",
    "                            continue\n",
    "\n",
    "                        # Grab everything from 'SFTv110' onward, removing trailing slashes\n",
    "                        #model_str = full_path#[idx:].rstrip('/')\n",
    "                        model_str = full_path[idx:].rstrip('/')\n",
    "\n",
    "                        # Convert second token to float\n",
    "                        try:\n",
    "                            win_rate = float(win_rate_str)\n",
    "                        except ValueError:\n",
    "                            continue\n",
    "\n",
    "                        # Append to data\n",
    "                        data.append((model_str, win_rate))\n",
    "\n",
    "    # Create a DataFrame\n",
    "    df = pd.DataFrame(data, columns=['model', 'win_rate'])\n",
    "    df[\"parsed_name\"] = df[\"model\"].apply(standardize_model_name)\n",
    "    df = df[~df[\"parsed_name\"].str.contains(\"-9999\", na=False)]\n",
    "\n",
    "    df.rename(columns={\"win_rate\": \"alpacaeval 1.0\"}, inplace=True)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "output-quality-analysis",
   "metadata": {},
   "source": [
    "## Output Quality Analysis\n",
    "\n",
    "These functions analyze the quality of model outputs, particularly focusing on detecting repetitive or problematic generations that might indicate training issues or model failures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6cc03623-feba-4a1a-a1af-4d8cd26b1120",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import glob\n",
    "import os\n",
    "import random\n",
    "import pandas as pd\n",
    "\n",
    "def is_repetitive(\n",
    "    text: str, \n",
    "    num_positions: int = 10, \n",
    "    substring_len: int = 15, \n",
    "    min_repeats: int = 5\n",
    ") -> bool:\n",
    "    \"\"\"\n",
    "    Detect repetitive text using statistical sampling method.\n",
    "    \n",
    "    This function identifies repetitive model outputs by sampling random\n",
    "    substrings and checking if they appear frequently throughout the text.\n",
    "    It's designed to catch common failure modes like repetitive loops.\n",
    "    \n",
    "    Args:\n",
    "        text (str): Text to analyze for repetitiveness\n",
    "        num_positions (int): Number of random positions to sample (default: 10)\n",
    "        substring_len (int): Length of substrings to extract (default: 15)\n",
    "        min_repeats (int): Minimum repetitions to flag as repetitive (default: 5)\n",
    "    \n",
    "    Returns:\n",
    "        bool: True if text is considered repetitive, False otherwise\n",
    "    \n",
    "    Algorithm:\n",
    "        1. Randomly sample up to num_positions starting positions\n",
    "        2. Extract substring_len characters from each position\n",
    "        3. Count occurrences of each substring in the full text\n",
    "        4. Return True if any substring appears >= min_repeats times\n",
    "    \n",
    "    Use Cases:\n",
    "        - Quality control for model outputs\n",
    "        - Identifying training instabilities\n",
    "        - Filtering problematic generations\n",
    "        - Model comparison based on output quality\n",
    "    \n",
    "    Example:\n",
    "        >>> repetitive_text = \"Hello world! \" * 10\n",
    "        >>> is_repetitive(repetitive_text)\n",
    "        True\n",
    "        >>> normal_text = \"This is a normal response with varied content.\"\n",
    "        >>> is_repetitive(normal_text)\n",
    "        False\n",
    "    \n",
    "    Note:\n",
    "        - Uses random sampling for efficiency on long texts\n",
    "        - Balances false positive/negative rates through parameter tuning\n",
    "        - Handles edge cases (short text, no valid positions)\n",
    "    \"\"\"\n",
    "    # If text is too short to extract a substring of substring_len\n",
    "    if len(text) < substring_len:\n",
    "        return False\n",
    "    \n",
    "    # All valid starting positions\n",
    "    max_start = len(text) - substring_len\n",
    "    # If fewer than `num_positions` possible starts, sample them all\n",
    "    sample_size = min(num_positions, max_start + 1)\n",
    "    \n",
    "    # Randomly pick positions from the valid range\n",
    "    positions = random.sample(range(max_start + 1), k=sample_size)\n",
    "    \n",
    "    for start in positions:\n",
    "        candidate = text[start : start + substring_len]\n",
    "        # Count occurrences of candidate in the entire text\n",
    "        count_occurrences = text.count(candidate)\n",
    "        if count_occurrences >= min_repeats:\n",
    "            return True\n",
    "    \n",
    "    return False\n",
    "\n",
    "def analyze_outputs_folder(folder_with_json):\n",
    "    \"\"\"\n",
    "    Analyze model output quality across all JSON files in a folder.\n",
    "    \n",
    "    This function processes evaluation result files to compute quality metrics\n",
    "    including repetition rates and output lengths. It's essential for identifying\n",
    "    models with generation issues.\n",
    "    \n",
    "    Args:\n",
    "        folder_with_json (str): Path to folder containing JSON result files\n",
    "            Each JSON should contain evaluation results with fields:\n",
    "            - 'output1_probe_in_data': Model response when probe is in data section\n",
    "            - 'output2_probe_in_task': Model response when probe is in task section\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Quality analysis results with columns:\n",
    "            - 'model': Model identifier from filename\n",
    "            - 'repetition_d': Repetition rate for data section outputs (0-1)\n",
    "            - 'repetition_t': Repetition rate for task section outputs (0-1) \n",
    "            - 'len_d': Average length of data section outputs\n",
    "            - 'len_t': Average length of task section outputs\n",
    "    \n",
    "    Quality Indicators:\n",
    "        - Low repetition rates (< 0.1) indicate healthy generation\n",
    "        - High repetition rates (> 0.3) suggest training issues\n",
    "        - Consistent lengths across conditions indicate stability\n",
    "        - Extreme length variations may indicate problematic generations\n",
    "    \n",
    "    Example:\n",
    "        >>> quality_df = analyze_outputs_folder(\"./model_outputs/llama_3.1_8b/\")\n",
    "        >>> print(quality_df[['model', 'repetition_d', 'repetition_t']].head())\n",
    "        >>> problematic = quality_df[quality_df['repetition_d'] > 0.3]\n",
    "        >>> print(f\"Found {len(problematic)} models with high repetition\")\n",
    "    \n",
    "    Use Cases:\n",
    "        - Model selection based on output quality\n",
    "        - Identifying training hyperparameters that cause instability\n",
    "        - Quality control in large-scale experiments\n",
    "        - Comparing output stability across methods\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    \n",
    "    for filepath in glob.glob(os.path.join(folder_with_json, \"*.json\")):\n",
    "        with open(filepath, \"r\", encoding=\"utf-8\") as f:\n",
    "            data = json.load(f)  # list of dicts\n",
    "\n",
    "        # Track stats\n",
    "        count_d_reps = 0\n",
    "        count_t_reps = 0\n",
    "        lengths_d = []\n",
    "        lengths_t = []\n",
    "        \n",
    "        for entry in data:\n",
    "            text_d = entry[\"output1_probe_in_data\"]\n",
    "            text_t = entry[\"output2_probe_in_task\"]\n",
    "            \n",
    "            # Check repetition\n",
    "            if is_repetitive(text_d):\n",
    "                count_d_reps += 1\n",
    "            if is_repetitive(text_t):\n",
    "                count_t_reps += 1\n",
    "            \n",
    "            # Keep lengths (basic char length or token length, up to you)\n",
    "            lengths_d.append(len(text_d))\n",
    "            lengths_t.append(len(text_t))\n",
    "        \n",
    "        total = len(data)\n",
    "        if total > 0:\n",
    "            repetition_d = count_d_reps / total\n",
    "            repetition_t = count_t_reps / total\n",
    "            len_d = sum(lengths_d) / total\n",
    "            len_t = sum(lengths_t) / total\n",
    "        else:\n",
    "            repetition_d = 0.0\n",
    "            repetition_t = 0.0\n",
    "            len_d = 0.0\n",
    "            len_t = 0.0\n",
    "        \n",
    "        # You can parse a \"model name\" from filename, if desired\n",
    "        model_name = os.path.splitext(os.path.basename(filepath))[0]\n",
    "        \n",
    "        rows.append({\n",
    "            \"model\": model_name,\n",
    "            \"repetition_d\": repetition_d,\n",
    "            \"repetition_t\": repetition_t,\n",
    "            \"len_d\": len_d,\n",
    "            \"len_t\": len_t,\n",
    "        })\n",
    "    \n",
    "    df = pd.DataFrame(rows)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "training-results-aggregation",
   "metadata": {},
   "source": [
    "## Training Results Aggregation\n",
    "\n",
    "These functions aggregate training metrics from multiple experiment runs, enabling analysis of training dynamics and correlation with final performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "42971f04-d877-4f63-a2bc-ac2ae2ac1471",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def aggregate_experiment_results(train_evals_path):\n",
    "    \"\"\"\n",
    "    Aggregate training metrics from multiple experiment subfolders.\n",
    "    \n",
    "    This function processes training logs from multiple experiments, extracting\n",
    "    minimum evaluation loss for each model. It's crucial for understanding\n",
    "    training dynamics and correlating them with final performance.\n",
    "    \n",
    "    Args:\n",
    "        train_evals_path (str): Path to main directory containing experiment subfolders\n",
    "            Expected structure:\n",
    "            train_evals_path/\n",
    "            ├── model1_config/\n",
    "            │   └── losses_metrics.json\n",
    "            ├── model2_config/\n",
    "            │   └── losses_metrics.json\n",
    "            └── ...\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Aggregated results with columns:\n",
    "            - 'parsed_name': Standardized model identifier\n",
    "            - 'min_eval_loss': Minimum evaluation loss achieved during training\n",
    "    \n",
    "    JSON File Format Expected:\n",
    "        {\n",
    "            \"eval_loss\": [loss1, loss2, loss3, ...] or single_value\n",
    "            // other metrics...\n",
    "        }\n",
    "    \n",
    "    Processing Steps:\n",
    "        1. Iterate through all subfolders\n",
    "        2. Load losses_metrics.json from each folder\n",
    "        3. Extract minimum evaluation loss\n",
    "        4. Create standardized model names\n",
    "        5. Return sorted DataFrame\n",
    "    \n",
    "    Use Cases:\n",
    "        - Training stability analysis\n",
    "        - Hyperparameter optimization\n",
    "        - Correlation with final performance\n",
    "        - Model selection based on training metrics\n",
    "    \n",
    "    Example:\n",
    "        >>> losses_df = aggregate_experiment_results(\"./train_logs/llama_3.1_8b/SFTv110/\")\n",
    "        >>> print(f\"Processed {len(losses_df)} training runs\")\n",
    "        >>> best_loss = losses_df.loc[losses_df['min_eval_loss'].idxmin()]\n",
    "        >>> print(f\"Best model: {best_loss['parsed_name']} (loss: {best_loss['min_eval_loss']:.4f})\")\n",
    "    \n",
    "    Note:\n",
    "        - Handles both list and scalar eval_loss formats\n",
    "        - Robust error handling for missing/malformed files\n",
    "        - Removes original 'model' column, keeps only parsed names\n",
    "    \"\"\"\n",
    "    # Dictionary to store results\n",
    "    results = defaultdict(float)\n",
    "\n",
    "    # Loop through all subfolders in the main directory\n",
    "    for model_folder in os.listdir(train_evals_path):\n",
    "        folder_path = os.path.join(train_evals_path, model_folder)\n",
    "        \n",
    "        # Make sure it's a directory, not a file\n",
    "        if not os.path.isdir(folder_path):\n",
    "            continue\n",
    "        \n",
    "        # Path to the losses_metrics.json file\n",
    "        json_path = os.path.join(folder_path, \"losses_metrics.json\")\n",
    "        \n",
    "        # Check if the file exists\n",
    "        if os.path.exists(json_path):\n",
    "            try:\n",
    "                # Read the JSON file\n",
    "                with open(json_path, 'r') as f:\n",
    "                    metrics_data = json.load(f)\n",
    "                \n",
    "                # Extract the minimum eval_loss\n",
    "                if \"eval_loss\" in metrics_data:\n",
    "                    min_eval_loss = metrics_data[\"eval_loss\"]\n",
    "                    if isinstance(min_eval_loss, list):\n",
    "                        min_eval_loss = min(min_eval_loss)\n",
    "                    results[model_folder] = min_eval_loss\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing {model_folder}: {e}\")\n",
    "\n",
    "    # Create pandas DataFrame\n",
    "    df = pd.DataFrame(list(results.items()), columns=[\"model\", \"min_eval_loss\"])\n",
    "\n",
    "    # Sort by model name\n",
    "    df = df.sort_values(\"model\").reset_index(drop=True)\n",
    "    df = transform_losses_df(df)\n",
    "    del df[\"model\"]\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "best-model-selection",
   "metadata": {},
   "source": [
    "## Best Model Selection\n",
    "\n",
    "This function provides automated selection of the best performing models based on various metrics, enabling systematic comparison across different embedding methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b9297e09-4242-477f-b6f8-8396b28af146",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def best_models_for_substrings(df, column_name, substring_list, use_max=True, top_k=3):\n",
    "    \"\"\"\n",
    "    Find the best performing models for each embedding type based on a specified metric.\n",
    "    \n",
    "    This function enables systematic comparison across different ASIDE variants\n",
    "    and baselines by identifying top performers in each category. It's essential\n",
    "    for fair comparison and model selection.\n",
    "    \n",
    "    Args:\n",
    "        df (pandas.DataFrame): Results DataFrame with model performance data\n",
    "        column_name (str): Name of column to optimize (e.g., 'sep_metric', 'alpacaeval 1.0')\n",
    "        substring_list (list): List of embedding type identifiers to search for\n",
    "            Example: [\"forward_rot\", \"ise\", \"single_emb\"]\n",
    "        use_max (bool): If True, select models with maximum values (default: True)\n",
    "                       If False, select models with minimum values (for loss metrics)\n",
    "        top_k (int): Number of top models to return per embedding type (default: 3)\n",
    "    \n",
    "    Returns:\n",
    "        pandas.DataFrame: Best models for each embedding type, sorted by performance\n",
    "    \n",
    "    Column Value Handling:\n",
    "        - Numeric values: Used directly for comparison\n",
    "        - Lists/arrays: Uses first element [0] for comparison\n",
    "        - Handles mixed data types gracefully\n",
    "    \n",
    "    Use Cases:\n",
    "        - Systematic embedding method comparison\n",
    "        - Hyperparameter optimization within methods\n",
    "        - Creating performance summaries\n",
    "        - Identifying best models for further analysis\n",
    "    \n",
    "    Example:\n",
    "        >>> # Find best models by SEP metric\n",
    "        >>> best_sep = best_models_for_substrings(\n",
    "        ...     results_df, \n",
    "        ...     'sep_metric', \n",
    "        ...     [\"forward_rot\", \"ise\", \"single_emb\"],\n",
    "        ...     use_max=True, \n",
    "        ...     top_k=1\n",
    "        ... )\n",
    "        >>> print(\"Best models by SEP score:\")\n",
    "        >>> print(best_sep[['parsed_name', 'sep_metric']])\n",
    "        \n",
    "        >>> # Find best models by training loss\n",
    "        >>> best_loss = best_models_for_substrings(\n",
    "        ...     losses_df, \n",
    "        ...     'min_eval_loss', \n",
    "        ...     [\"forward_rot\", \"dd_pure\", \"ise\"],\n",
    "        ...     use_max=False,  # Lower loss is better\n",
    "        ...     top_k=1\n",
    "        ... )\n",
    "    \n",
    "    Filtering Logic:\n",
    "        - Searches for substring matches in 'parsed_name' column\n",
    "        - Handles partial matches (e.g., \"forward_rot\" matches \"forward_rot_run5\")\n",
    "        - Returns empty DataFrame if no matches found\n",
    "    \n",
    "    Note:\n",
    "        - Preserves all original columns in output\n",
    "        - Adds temporary comparison column that is automatically removed\n",
    "        - Handles edge cases (empty DataFrames, missing columns)\n",
    "    \"\"\"\n",
    "    \n",
    "    # Helper function to get comparable scalar from column values\n",
    "    def get_scalar_value(x):\n",
    "        # If it's a list, tuple, or numpy array, compare by the first element\n",
    "        if isinstance(x, (list, tuple, np.ndarray)) and len(x) > 0:\n",
    "            return x[0]\n",
    "        return x  # Non-list/array values returned as-is\n",
    "    \n",
    "    best_rows = []\n",
    "    \n",
    "    for substring in substring_list:\n",
    "        # Filter rows whose 'parsed_name' contains the substring\n",
    "        subset = df[df['parsed_name'].str.contains(substring, na=False)].copy()\n",
    "        if subset.empty:\n",
    "            # No match for this substring, skip\n",
    "            continue\n",
    "        \n",
    "        # Compute a new column with scalar values to compare\n",
    "        subset['_comp_value'] = subset[column_name].apply(get_scalar_value)\n",
    "        \n",
    "        # Choose top_k by max or min\n",
    "        if use_max:\n",
    "            chosen = subset.nlargest(top_k, '_comp_value')\n",
    "        else:\n",
    "            chosen = subset.nsmallest(top_k, '_comp_value')\n",
    "        \n",
    "        best_rows.append(chosen)\n",
    "    \n",
    "    # Concatenate results for all substrings\n",
    "    if best_rows:\n",
    "        result_df = pd.concat(best_rows, ignore_index=True)\n",
    "        # Remove the helper column\n",
    "        result_df.drop(columns=['_comp_value'], inplace=True, errors='ignore')\n",
    "        return result_df\n",
    "    else:\n",
    "        # Return empty DataFrame if no rows matched\n",
    "        return pd.DataFrame(columns=df.columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "example-usage",
   "metadata": {},
   "source": [
    "## Example Usage: Selecting best model\n",
    "\n",
    "This example demonstrates how to use the analysis functions to identify the best models based on training loss for the Qwen2.5-7B model family."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab878ea6-7691-4c79-b3d6-eba90b2fd3c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error processing ise_train_full_SFTv70_run=5e-6_rotation_0_alpha=0.0: min() iterable argument is empty\n",
      "Error processing forward_rot_train_full_SFTv70_run=rotation_0.5_alpha=1.5708: min() iterable argument is empty\n",
      "Error processing dd_pure_train_full_SFTv70_run=5e-6_gradual_rotation_alpha=0.0: min() iterable argument is empty\n",
      "Best by min_eval_loss:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>min_eval_loss</th>\n",
       "      <th>parsed_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.019663</td>\n",
       "      <td>forward_rot_run14_alpha=1.57079633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.953548</td>\n",
       "      <td>dd_pure_run18_alpha=1.57079633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.954310</td>\n",
       "      <td>ise_run34_alpha=1.57079633</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   min_eval_loss                         parsed_name\n",
       "0       1.019663  forward_rot_run14_alpha=1.57079633\n",
       "1       0.953548      dd_pure_run18_alpha=1.57079633\n",
       "2       0.954310          ise_run34_alpha=1.57079633"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = \"qwen2.5_7b\"\n",
    "sft = \"SFTv70\"\n",
    "\n",
    "train_evals_path= f\"anonymized_path/model_outputs/train_logs/{model}/{sft}\"\n",
    "losses_df = aggregate_experiment_results(train_evals_path)\n",
    "\n",
    "search_substrings = [\"forward_rot\", \"dd_pure\", \"ise\"]  # last one won't match\n",
    "col = \"min_eval_loss\"\n",
    "# Suppose we want the best model by 'alpacaeval 1.0'\n",
    "best_loss = best_models_for_substrings(losses_df,col, search_substrings, use_max=False, top_k=1)\n",
    "print(f\"Best by {col}:\")\n",
    "best_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "final-analysis-example",
   "metadata": {},
   "source": [
    "## Analysis Example: SEP Metric Evaluation\n",
    "\n",
    "This  example shows how to load and analyze SEP evaluation results, demonstrating the complete workflow from data loading to performance comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cf12cc65-4f10-4a77-aa46-45dbb9210f1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>sep_metric</th>\n",
       "      <th>probe_in_instruct_asr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>base_fullsep</td>\n",
       "      <td>[0.375, 0.006]</td>\n",
       "      <td>[0.774, 0.004]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>forward_rot_from_inst_run14_fullsep</td>\n",
       "      <td>[0.641, 0.006]</td>\n",
       "      <td>[0.726, 0.005]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>from_inst_single_run18_fullsep</td>\n",
       "      <td>[0.418, 0.006]</td>\n",
       "      <td>[0.717, 0.005]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ise_from_inst_run34_fullsep</td>\n",
       "      <td>[0.419, 0.006]</td>\n",
       "      <td>[0.713, 0.005]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 model      sep_metric probe_in_instruct_asr\n",
       "0                         base_fullsep  [0.375, 0.006]        [0.774, 0.004]\n",
       "2  forward_rot_from_inst_run14_fullsep  [0.641, 0.006]        [0.726, 0.005]\n",
       "3       from_inst_single_run18_fullsep  [0.418, 0.006]        [0.717, 0.005]\n",
       "4          ise_from_inst_run34_fullsep  [0.419, 0.006]        [0.713, 0.005]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sft = \"SFTv70\"\n",
    "alpaca_ver = \"1.0\"\n",
    "path1 = \"./evals/data/tatsu-lab/alpaca_eval/alpaca_eval_gpt4/leaderboard.csv\"\n",
    "path2 = \"./evals/data/tatsu-lab/alpaca_farm/alpaca_eval_gpt4/leaderboard.csv\"\n",
    "model = \"Qwen2.5-7B\"\n",
    "\n",
    "outputs_path = f\"./model_outputs/{model}/{sft}\"\n",
    "scores = get_df_scores_for_model(outputs_path)\n",
    "sep = list(map(lambda x: x[0], np.array(scores[\"sep_metric\"])))\n",
    "scores=scores[scores[\"model\"].str.contains(\"fullsep\")].iloc[:, [0,1,3]].sort_values(by='model')\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a78598d-f22d-49c6-9fc2-bbf9924a52f5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
