{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SynQuE Proxy Metrics\n",
    "- Proxy-A-Distance (PAD)\n",
    "- Maximum Mean Discrepancy (MMD)\n",
    "- Mean Distance to Medoid (MDM)\n",
    "- LLM-Eval Normalized Score (LENS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Proxy-A-Distance (PAD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we use a classifier to discriminate real data from synthetic data, then we use the error to compute the PAD\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "def compute_pad(x_syn_emb, x_real_emb, test_size=0.2, random_state=42):\n",
    "    \"\"\"\n",
    "    Compute a soft Proxy-A-Distance (PAD) between two sets of embeddings.\n",
    "\n",
    "    A logistic-regression domain classifier is trained to separate synthetic\n",
    "    (label 0) from real (label 1) embeddings on a random split of the pooled data.\n",
    "    On the held-out split the error is the mean absolute difference between the\n",
    "    predicted probability of the real class and the true label (a soft version of\n",
    "    the misclassification rate), and PAD = 2 * (1 - 2 * error).\n",
    "\n",
    "    Args:\n",
    "        x_syn_emb (np.ndarray): Embeddings of synthetic data, shape (n_syn, n_features)\n",
    "        x_real_emb (np.ndarray): Embeddings of real data, shape (n_real, n_features)\n",
    "        test_size (float): Fraction of the pooled data held out for validation (default: 0.2)\n",
    "        random_state (int): Seed for the train/validation split (default: 42)\n",
    "\n",
    "    Returns:\n",
    "        float: PAD in [-2, 2]; larger values mean the classifier separates the two sets more easily.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either set of embeddings is empty.\n",
    "    \"\"\"\n",
    "    x_syn_emb = np.asarray(x_syn_emb)\n",
    "    x_real_emb = np.asarray(x_real_emb)\n",
    "    if len(x_syn_emb) == 0 or len(x_real_emb) == 0:\n",
    "        raise ValueError('compute_pad requires non-empty synthetic and real embeddings')\n",
    "\n",
    "    # Domain labels: 0 = synthetic, 1 = real\n",
    "    x_all = np.concatenate([x_syn_emb, x_real_emb], axis=0)\n",
    "    y_all = np.concatenate([np.zeros(len(x_syn_emb), dtype=int), np.ones(len(x_real_emb), dtype=int)])\n",
    "\n",
    "    # Hold out part of the pooled data so the error is measured on unseen points\n",
    "    x_train, x_val, y_train, y_val = train_test_split(\n",
    "        x_all, y_all, test_size=test_size, random_state=random_state\n",
    "    )\n",
    "\n",
    "    classifier = LogisticRegression()\n",
    "    classifier.fit(x_train, y_train)\n",
    "\n",
    "    # Soft error: mean |P(real) - label| over the validation split\n",
    "    y_pred_proba = classifier.predict_proba(x_val)[:, 1]\n",
    "    average_loss = np.mean(np.abs(y_pred_proba - y_val))\n",
    "    return 2 * (1 - 2 * average_loss)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Maximum Mean Discrepancy (MMD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's compute the maximum mean discrepancy (MMD) between the dev data and the synthetic data\n",
    "from sklearn.metrics.pairwise import polynomial_kernel\n",
    "import numpy as np\n",
    "# Polynomial-kernel settings for MMD. NOTE(review): compute_mmd below does not read these;\n",
    "# its own keyword defaults (degree=3, gamma=None, coef0=1) carry the same values.\n",
    "DEGREE = 3\n",
    "GAMMA = None\n",
    "COEF0 = 1\n",
    "\n",
    "\n",
    "def compute_mmd(X, Y, kernel=polynomial_kernel, degree=3, gamma=None, coef0=1):\n",
    "    \"\"\"\n",
    "    Compute the biased squared Maximum Mean Discrepancy (MMD^2) between two samples X and Y.\n",
    "\n",
    "    MMD^2 = mean(k(X, X)) + mean(k(Y, Y)) - 2 * mean(k(X, Y)), where every mean includes\n",
    "    the diagonal terms (V-statistic). For a positive semi-definite kernel this value is\n",
    "    non-negative, but it is a difference of kernel means that nearly cancel when X and Y\n",
    "    are similar, so inputs are cast to float64 first: in float32 the rounding error can\n",
    "    produce small negative results.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): First sample, shape (n_samples_X, n_features)\n",
    "        Y (np.ndarray): Second sample, shape (n_samples_Y, n_features)\n",
    "        kernel (callable): Kernel called as kernel(A, B, degree=..., gamma=..., coef0=...)\n",
    "            (default: polynomial_kernel)\n",
    "        degree (int): Degree for polynomial kernel (default: 3)\n",
    "        gamma (float): Gamma parameter for polynomial kernel (default: None)\n",
    "        coef0 (float): Coef0 parameter for polynomial kernel (default: 1)\n",
    "\n",
    "    Returns:\n",
    "        float: Biased estimate of the squared MMD\n",
    "    \"\"\"\n",
    "    # float64 keeps the final subtraction from being dominated by float32 rounding error\n",
    "    X = np.asarray(X, dtype=np.float64)\n",
    "    Y = np.asarray(Y, dtype=np.float64)\n",
    "    XX = kernel(X, X, degree=degree, gamma=gamma, coef0=coef0)\n",
    "    YY = kernel(Y, Y, degree=degree, gamma=gamma, coef0=coef0)\n",
    "    XY = kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)\n",
    "\n",
    "    return np.mean(XX) + np.mean(YY) - 2 * np.mean(XY)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean Distance to Medoid (MDM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import kmedoids\n",
    "import numpy as np\n",
    "from sklearn.metrics import pairwise_distances\n",
    "\n",
    "def compute_distance_matrix(embeddings, metric='euclidean'):\n",
    "    \"\"\"\n",
    "    Build the dense pairwise distance matrix of a set of embeddings.\n",
    "\n",
    "    Args:\n",
    "        embeddings (np.ndarray): (n_samples, n_features) embedding matrix\n",
    "        metric (str): Metric name forwarded to sklearn's pairwise_distances,\n",
    "            e.g. 'euclidean' or 'cosine'\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: (n_samples, n_samples) matrix whose [i, j] entry is the distance\n",
    "            between embeddings i and j\n",
    "    \"\"\"\n",
    "    distances = pairwise_distances(embeddings, metric=metric)\n",
    "    return distances\n",
    "\n",
    "def compute_mdm(embeddings, n_clusters=5, metric='euclidean', random_state=42):\n",
    "    \"\"\"\n",
    "    Mean distance to medoid (MDM) of a set of embeddings.\n",
    "\n",
    "    The embeddings are clustered with k-medoids (kmedoids.fasterpam on the pairwise\n",
    "    distance matrix); for each non-empty cluster the mean distance of its points to\n",
    "    its medoid is computed, and these per-cluster means are averaged.\n",
    "\n",
    "    Args:\n",
    "        embeddings (np.ndarray): Embedding matrix of shape (n_samples, n_features)\n",
    "        n_clusters (int): Number of clusters/medoids; capped at n_samples\n",
    "        metric (str): Distance metric passed to compute_distance_matrix\n",
    "        random_state (int): Random seed passed to kmedoids.fasterpam (default 42)\n",
    "\n",
    "    Returns:\n",
    "        float: Mean distance to medoid, averaged over the non-empty clusters\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If embeddings is empty.\n",
    "    \"\"\"\n",
    "    n_samples = len(embeddings)\n",
    "    if n_samples == 0:\n",
    "        raise ValueError('compute_mdm requires at least one embedding')\n",
    "    # k-medoids cannot place more medoids than there are points\n",
    "    n_clusters = min(n_clusters, n_samples)\n",
    "    diss = compute_distance_matrix(embeddings, metric=metric)\n",
    "    # Run kmedoids clustering\n",
    "    pam_result = kmedoids.fasterpam(diss, n_clusters, random_state=random_state)\n",
    "\n",
    "    labels = np.asarray(pam_result.labels)\n",
    "    cluster_means = []\n",
    "    for cluster_id, medoid_idx in enumerate(pam_result.medoids):\n",
    "        members = np.flatnonzero(labels == cluster_id)\n",
    "        # A cluster can presumably end up empty (e.g. with duplicate points); leave it out\n",
    "        # of the average entirely so it does not pull the result toward 0.\n",
    "        if members.size == 0:\n",
    "            continue\n",
    "        cluster_means.append(np.mean(diss[members, medoid_idx]))\n",
    "    return float(np.mean(cluster_means))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LLM-Eval Normalized Score (LENS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_lens_score_normalized(synth_examples, real_examples):\n",
    "    \"\"\"Computes a balanced LLM score that considers both synthetic and real examples.\n",
    "    \n",
    "    This function calculates a weighted average of two error scores:\n",
    "    1. How often synthetic examples are mistakenly classified as real\n",
    "    2. How often real examples are mistakenly classified as synthetic\n",
    "    \n",
    "    The function uses normalization factors derived from real examples to adjust\n",
    "    the raw scores before computing probabilities.\n",
    "    \n",
    "    Args:\n",
    "        synth_examples (list): The synthetic examples with LLM judgments.\n",
    "        real_examples (list): The real examples with LLM judgments.\n",
    "        \n",
    "    Returns:\n",
    "        float: A balanced LLM score representing the overall error rate across\n",
    "               both synthetic and real examples, weighted by their respective counts.\n",
    "    \"\"\"    \n",
    "    EPSILON = 1e-6\n",
    "    SCORE_MAP = {\n",
    "        'very likely': 4,\n",
    "        'likely': 3,\n",
    "        'unsure': 2,\n",
    "        'unlikely': 1,\n",
    "        'very unlikely': 0,\n",
    "    }\n",
    "    \n",
    "    # compute normalization factors for synth examples\n",
    "    score_real_given_synth_loc_A_judgement_factor_synth = np.mean([SCORE_MAP[item['score_real_given_synth_loc_A_judgement']] for item in real_examples])\n",
    "    score_real_given_synth_loc_B_judgement_factor_synth = np.mean([SCORE_MAP[item['score_real_given_synth_loc_B_judgement']] for item in real_examples])\n",
    "    score_synth_given_synth_loc_A_judgement_factor_synth = np.mean([SCORE_MAP[item['score_synth_given_synth_loc_A_judgement']] for item in real_examples])\n",
    "    score_synth_given_synth_loc_B_judgement_factor_synth = np.mean([SCORE_MAP[item['score_synth_given_synth_loc_B_judgement']] for item in real_examples])\n",
    "    \n",
    "    # compute normalization factors for real examples\n",
    "    score_real_given_synth_loc_A_judgement_factor_real = np.mean([SCORE_MAP[item['score_real_given_synth_loc_A_judgement']] for item in synth_examples])\n",
    "    score_real_given_synth_loc_B_judgement_factor_real = np.mean([SCORE_MAP[item['score_real_given_synth_loc_B_judgement']] for item in synth_examples])\n",
    "    score_synth_given_synth_loc_A_judgement_factor_real = np.mean([SCORE_MAP[item['score_synth_given_synth_loc_A_judgement']] for item in synth_examples])\n",
    "    score_synth_given_synth_loc_B_judgement_factor_real = np.mean([SCORE_MAP[item['score_synth_given_synth_loc_B_judgement']] for item in synth_examples])\n",
    "\n",
    "    score_real_given_synth_loc_A_judgement_factor_real = 1 if score_real_given_synth_loc_A_judgement_factor_real == 0 else score_real_given_synth_loc_A_judgement_factor_real\n",
    "    score_real_given_synth_loc_B_judgement_factor_real = 1 if score_real_given_synth_loc_B_judgement_factor_real == 0 else score_real_given_synth_loc_B_judgement_factor_real\n",
    "    score_synth_given_synth_loc_A_judgement_factor_real = 1 if score_synth_given_synth_loc_A_judgement_factor_real == 0 else score_synth_given_synth_loc_A_judgement_factor_real\n",
    "    score_synth_given_synth_loc_B_judgement_factor_real = 1 if score_synth_given_synth_loc_B_judgement_factor_real == 0 else score_synth_given_synth_loc_B_judgement_factor_real\n",
    "\n",
    "    score_real_given_synth_loc_A_judgement_factor_synth = max(score_real_given_synth_loc_A_judgement_factor_synth, EPSILON)\n",
    "    score_real_given_synth_loc_B_judgement_factor_synth = max(score_real_given_synth_loc_B_judgement_factor_synth, EPSILON)\n",
    "    score_synth_given_synth_loc_A_judgement_factor_synth = max(score_synth_given_synth_loc_A_judgement_factor_synth, EPSILON)\n",
    "    score_synth_given_synth_loc_B_judgement_factor_synth = max(score_synth_given_synth_loc_B_judgement_factor_synth, EPSILON)\n",
    "\n",
    "    score_real_given_synth_loc_A_judgement_factor_synth = max(score_real_given_synth_loc_A_judgement_factor_synth, EPSILON)\n",
    "    score_real_given_synth_loc_B_judgement_factor_synth = max(score_real_given_synth_loc_B_judgement_factor_synth, EPSILON)\n",
    "    score_synth_given_synth_loc_A_judgement_factor_synth = max(score_synth_given_synth_loc_A_judgement_factor_synth, EPSILON)\n",
    "    score_synth_given_synth_loc_B_judgement_factor_synth = max(score_synth_given_synth_loc_B_judgement_factor_synth, EPSILON)\n",
    "    \n",
    "    # compute llm scores\n",
    "    error_scores = []\n",
    "    error_scores_real = []\n",
    "    for item in synth_examples:\n",
    "        # real given loc A\n",
    "        h_real_given_loc_A = SCORE_MAP[item['score_real_given_synth_loc_A_judgement']] / score_real_given_synth_loc_A_judgement_factor_synth\n",
    "        h_synth_given_loc_A = SCORE_MAP[item['score_synth_given_synth_loc_A_judgement']] / score_synth_given_synth_loc_A_judgement_factor_synth\n",
    "        p_real_given_loc_A = h_real_given_loc_A / (h_real_given_loc_A + h_synth_given_loc_A + EPSILON)\n",
    "\n",
    "        # real given loc B\n",
    "        h_real_given_loc_B = SCORE_MAP[item['score_real_given_synth_loc_B_judgement']] / score_real_given_synth_loc_B_judgement_factor_synth\n",
    "        h_synth_given_loc_B = SCORE_MAP[item['score_synth_given_synth_loc_B_judgement']] / score_synth_given_synth_loc_B_judgement_factor_synth\n",
    "        p_real_given_loc_B = h_real_given_loc_B / (h_real_given_loc_B + h_synth_given_loc_B + EPSILON)\n",
    "\n",
    "        error_scores.append((p_real_given_loc_A + p_real_given_loc_B) / 2)\n",
    "    \n",
    "    for item in real_examples:\n",
    "        # real given loc A\n",
    "        h_real_given_loc_A = SCORE_MAP[item['score_real_given_synth_loc_A_judgement']] / score_real_given_synth_loc_A_judgement_factor_real\n",
    "        h_synth_given_loc_A = SCORE_MAP[item['score_synth_given_synth_loc_A_judgement']] / score_synth_given_synth_loc_A_judgement_factor_real\n",
    "        p_synth_given_loc_A = h_synth_given_loc_A / (h_real_given_loc_A + h_synth_given_loc_A + EPSILON)\n",
    "\n",
    "        # real given loc B\n",
    "        h_real_given_loc_B = SCORE_MAP[item['score_real_given_synth_loc_B_judgement']] / score_real_given_synth_loc_B_judgement_factor_real\n",
    "        h_synth_given_loc_B = SCORE_MAP[item['score_synth_given_synth_loc_B_judgement']] / score_synth_given_synth_loc_B_judgement_factor_real\n",
    "        p_synth_given_loc_B = h_synth_given_loc_B / (h_real_given_loc_B + h_synth_given_loc_B + EPSILON)\n",
    "\n",
    "        error_scores_real.append((p_synth_given_loc_A + p_synth_given_loc_B) / 2)\n",
    "\n",
    "    return (np.mean(error_scores) * len(synth_examples) + np.mean(error_scores_real) * len(real_examples)) / (len(synth_examples) + len(real_examples))\n",
    "\n",
    "\n",
    "def compute_lens_score(examples, all_scores=False, real_data=False) -> float:\n",
    "    \"\"\"Computes the LLM score for the given examples. This only uses four scores to normalize.\n",
    "\n",
    "    Args:\n",
    "        examples (_type_): List of examples\n",
    "        all_scores (bool, optional): Whether to return all scores. Defaults to False.\n",
    "        real_data (bool, optional): Whether to return real data scores. Defaults to False.\n",
    "\n",
    "    Returns:\n",
    "        float: LLM score\n",
    "    \"\"\"    \n",
    "    SCORE_MAP = {\n",
    "        'very likely': 4,\n",
    "        'likely': 3,\n",
    "        'unsure': 2,\n",
    "        'unlikely': 1,\n",
    "        'very unlikely': 0,\n",
    "    }\n",
    "    llm_scores = []\n",
    "    score_real_given_synth_loc_A_ls = []\n",
    "    score_synth_given_synth_loc_A_ls = []\n",
    "    score_real_given_synth_loc_B_ls = []\n",
    "    score_synth_given_synth_loc_B_ls = []\n",
    "    for item in examples:\n",
    "        score_real_given_synth_loc_A = SCORE_MAP[item['score_real_given_synth_loc_A_judgement']]\n",
    "        score_synth_given_synth_loc_A = SCORE_MAP[item['score_synth_given_synth_loc_A_judgement']]\n",
    "        score_real_given_synth_loc_B = SCORE_MAP[item['score_real_given_synth_loc_B_judgement']]\n",
    "        score_synth_given_synth_loc_B = SCORE_MAP[item['score_synth_given_synth_loc_B_judgement']]\n",
    "        p_real_given_loc_A = score_real_given_synth_loc_A / max(1e-6, score_real_given_synth_loc_A + score_synth_given_synth_loc_A)\n",
    "        p_real_given_loc_B = score_real_given_synth_loc_B / max(1e-6, score_real_given_synth_loc_B + score_synth_given_synth_loc_B)\n",
    "        p_synth_given_loc_A = score_synth_given_synth_loc_A / max(1e-6, score_real_given_synth_loc_A + score_synth_given_synth_loc_A)\n",
    "        p_synth_given_loc_B = score_synth_given_synth_loc_B / max(1e-6, score_real_given_synth_loc_B + score_synth_given_synth_loc_B)\n",
    "        if real_data:\n",
    "            llm_score = (p_synth_given_loc_A + p_synth_given_loc_B) / 2\n",
    "        else:\n",
    "            llm_score = (p_real_given_loc_A + p_real_given_loc_B) / 2\n",
    "        score_real_given_synth_loc_A_ls.append(score_real_given_synth_loc_A)\n",
    "        score_synth_given_synth_loc_A_ls.append(score_synth_given_synth_loc_A)\n",
    "        score_real_given_synth_loc_B_ls.append(score_real_given_synth_loc_B)\n",
    "        score_synth_given_synth_loc_B_ls.append(score_synth_given_synth_loc_B)\n",
    "        llm_scores.append(llm_score)\n",
    "    if all_scores:\n",
    "        return np.mean(llm_scores), np.mean(score_real_given_synth_loc_A_ls), np.mean(score_synth_given_synth_loc_A_ls), np.mean(score_real_given_synth_loc_B_ls), np.mean(score_synth_given_synth_loc_B_ls)\n",
    "    return np.mean(llm_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate PAD, MMD, and MDM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentiment Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training SA model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import pathlib\n",
    "import json\n",
    "\n",
    "SEEDS = [42, 43, 44, 45, 46]\n",
    "\n",
    "ddata = pathlib.Path('../data/sentiment_analysis')\n",
    "synth = {}\n",
    "\n",
    "\n",
    "def _load_json(path):\n",
    "    \"\"\"Load a JSON file and close the handle immediately instead of leaving it to garbage collection.\"\"\"\n",
    "    with path.open('rt') as fh:\n",
    "        return json.load(fh)\n",
    "\n",
    "\n",
    "seed_real_map = {seed: [x['text'] for x in _load_json(ddata/f\"real_data/balanced_real_seed={seed}.json\")] for seed in SEEDS}\n",
    "# Read the full balanced real set once; texts and labels come from the same records\n",
    "real_records = _load_json(ddata/'real_data/balanced_real.json')\n",
    "real = pd.DataFrame(dict(x=[x['text'] for x in real_records], y=[x['label'] for x in real_records]))\n",
    "\n",
    "# sorted() gives a deterministic dataset order (glob order is filesystem-dependent)\n",
    "for fsynth in sorted(ddata.glob('synthetic_data/*.json')):\n",
    "    r = _load_json(fsynth)\n",
    "    df = pd.DataFrame(dict(x=[x['headline'] for x in r], y=[int(x['sentiment']) for x in r]))\n",
    "    model = fsynth.name.replace('.json', '')\n",
    "    synth[model] = df\n",
    "print(f\"There are {len(synth)} synthetic datasets\")\n",
    "print(f\"Real data has {len(real)} examples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "from tqdm import auto as tqdm\n",
    "# One encoder for every text collection: the full real set, each per-seed real subset, each synthetic set\n",
    "sentence_model = SentenceTransformer(\"intfloat/e5-small-v2\")\n",
    "\n",
    "emb_real = sentence_model.encode(real.x.tolist())\n",
    "emb_seed_real_map = {\n",
    "    seed: sentence_model.encode(seed_texts)\n",
    "    for seed, seed_texts in seed_real_map.items()\n",
    "}\n",
    "emb_synth = {\n",
    "    name: sentence_model.encode(frame.x.tolist())\n",
    "    for name, frame in tqdm.tqdm(list(synth.items()), desc='synthetic datasets')\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost as xgb\n",
    "import numpy as np\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "params = {\n",
    "    'objective': 'multi:softmax', # multiclass classification; predict() returns class ids\n",
    "    'eval_metric': 'mlogloss',    # multiclass log loss on the validation fold\n",
    "    'num_class': 3                # number of classes (0,1,2)\n",
    "}\n",
    "\n",
    "kf = KFold(n_splits=5)\n",
    "\n",
    "y_real = real.y.to_numpy()\n",
    "dtest = xgb.DMatrix(emb_real, label=y_real)\n",
    "\n",
    "\n",
    "def _predict_best(booster, dmatrix):\n",
    "    \"\"\"Predict using only the trees up to the early-stopping best iteration.\n",
    "\n",
    "    xgb.train returns the model from the last boosting round, not the best one, so the\n",
    "    prediction is restricted explicitly via iteration_range.\n",
    "    \"\"\"\n",
    "    return booster.predict(dmatrix, iteration_range=(0, booster.best_iteration + 1))\n",
    "\n",
    "\n",
    "training_results = []\n",
    "for k, X in tqdm.tqdm(emb_synth.items(), desc='datasets'):\n",
    "    Y = synth[k].y.to_numpy()\n",
    "    for i, (train_index, val_index) in enumerate(kf.split(X)):\n",
    "        dtrain = xgb.DMatrix(X[train_index], label=Y[train_index])\n",
    "        dval = xgb.DMatrix(X[val_index], label=Y[val_index])\n",
    "        m = xgb.train(\n",
    "            params, dtrain,\n",
    "            num_boost_round=100,\n",
    "            early_stopping_rounds=10,\n",
    "            evals=[(dval, 'validation')],\n",
    "            verbose_eval=False,\n",
    "        )\n",
    "        training_results.append(dict(\n",
    "            fold=i,\n",
    "            dataset=k,\n",
    "            model=m,\n",
    "            train_acc=np.mean(_predict_best(m, dtrain) == Y[train_index]),\n",
    "            val_acc=np.mean(_predict_best(m, dval) == Y[val_index]),\n",
    "            test_acc=np.mean(_predict_best(m, dtest) == y_real),\n",
    "        ))\n",
    "\n",
    "# Mean accuracies per dataset, flattened ('col', 'mean') headers, best test accuracy first\n",
    "train_df = (\n",
    "    pd.DataFrame(training_results)\n",
    "    .groupby(['dataset'])[['train_acc', 'val_acc', 'test_acc']]\n",
    "    .agg(['mean'])\n",
    "    .reset_index()\n",
    ")\n",
    "train_df.columns = ['_'.join(tup).strip('_') for tup in train_df.columns.to_flat_index()]\n",
    "train_df = train_df.sort_values(['test_acc_mean'], ascending=False).reset_index(drop=True)\n",
    "train_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import random\n",
    "import pickle\n",
    "import tqdm\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "sa_result_df = pd.DataFrame()\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "NUM_SAMPLES_FOR_SCORING = 200\n",
    "use_sampled_synth = False\n",
    "\n",
    "# load data embedded using qte-qwen2.5-7b-instruct\n",
    "# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.\n",
    "with open(\"../embeddings/sentiment_analysis_real_data_embs.pkl\", \"rb\") as fh:\n",
    "    emb_seed_real_map = pickle.load(fh)\n",
    "with open(\"../embeddings/sentiment_analysis_synthetic_data_embs.pkl\", \"rb\") as fh:\n",
    "    synth_emb_map = pickle.load(fh)\n",
    "\n",
    "# MDM depends only on the synthetic embeddings, so without resampling it is computed once per dataset\n",
    "mdm_cache = {}\n",
    "baseline_results_sa = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    # real embeddings are keyed by the seed as a string\n",
    "    real_emb = emb_seed_real_map[str(seed)]\n",
    "    for k, X in tqdm.tqdm(synth_emb_map.items(), desc='datasets'):\n",
    "        if use_sampled_synth:\n",
    "            # random.sample rejects numpy arrays (not a Sequence), so sample row indices instead\n",
    "            idx = random.sample(range(len(X)), min(NUM_SAMPLES_FOR_SCORING, len(X)))\n",
    "            X = np.asarray(X)[idx]\n",
    "        # PAD: synthetic vs real test\n",
    "        pad_test = compute_pad(X, real_emb)\n",
    "        # MMD: synthetic vs real test\n",
    "        mmd_test = compute_mmd(X, real_emb)\n",
    "        # MDM: synthetic alone\n",
    "        if use_sampled_synth:\n",
    "            mdm_test = compute_mdm(X, n_clusters=3)\n",
    "        else:\n",
    "            if k not in mdm_cache:\n",
    "                mdm_cache[k] = compute_mdm(X, n_clusters=3)\n",
    "            mdm_test = mdm_cache[k]\n",
    "        baseline_results_sa.append(dict(\n",
    "            dataset=k,\n",
    "            seed=seed,\n",
    "            pad_test=pad_test,\n",
    "            mmd_test=mmd_test,\n",
    "            mdm_test=mdm_test\n",
    "        ))\n",
    "\n",
    "df_sa_merged = pd.DataFrame(baseline_results_sa)\n",
    "df_sa_merged.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute correlation between metrics and test_f1_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Replaces any df_sa_merged computed above with saved results; the correlation cell below reads its test_f1_mean column.\n",
    "df_sa_merged = pd.read_csv(\n",
    "    \"results/sentiment_analysis_qte-qwen2.5-7b-instruct.csv\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>seed</th>\n",
       "      <th>pearson_pad_vs_f1</th>\n",
       "      <th>spearman_pad_vs_f1</th>\n",
       "      <th>pearson_mmd_vs_f1</th>\n",
       "      <th>spearman_mmd_vs_f1</th>\n",
       "      <th>pearson_mdm_vs_f1</th>\n",
       "      <th>spearman_mdm_vs_f1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>42</td>\n",
       "      <td>-0.66</td>\n",
       "      <td>-0.52</td>\n",
       "      <td>-0.68</td>\n",
       "      <td>-0.46</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>43</td>\n",
       "      <td>-0.65</td>\n",
       "      <td>-0.53</td>\n",
       "      <td>-0.69</td>\n",
       "      <td>-0.43</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>44</td>\n",
       "      <td>-0.65</td>\n",
       "      <td>-0.51</td>\n",
       "      <td>-0.65</td>\n",
       "      <td>-0.43</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>45</td>\n",
       "      <td>-0.67</td>\n",
       "      <td>-0.57</td>\n",
       "      <td>-0.68</td>\n",
       "      <td>-0.45</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>46</td>\n",
       "      <td>-0.62</td>\n",
       "      <td>-0.53</td>\n",
       "      <td>-0.63</td>\n",
       "      <td>-0.47</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>all_seeds</td>\n",
       "      <td>-0.65</td>\n",
       "      <td>-0.53</td>\n",
       "      <td>-0.67</td>\n",
       "      <td>-0.45</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        seed  pearson_pad_vs_f1  spearman_pad_vs_f1  pearson_mmd_vs_f1  \\\n",
       "0         42              -0.66               -0.52              -0.68   \n",
       "1         43              -0.65               -0.53              -0.69   \n",
       "2         44              -0.65               -0.51              -0.65   \n",
       "3         45              -0.67               -0.57              -0.68   \n",
       "4         46              -0.62               -0.53              -0.63   \n",
       "5  all_seeds              -0.65               -0.53              -0.67   \n",
       "\n",
       "   spearman_mmd_vs_f1  pearson_mdm_vs_f1  spearman_mdm_vs_f1  \n",
       "0               -0.46               0.85                 0.7  \n",
       "1               -0.43               0.85                 0.7  \n",
       "2               -0.43               0.85                 0.7  \n",
       "3               -0.45               0.85                 0.7  \n",
       "4               -0.47               0.85                 0.7  \n",
       "5               -0.45               0.85                 0.7  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "# Correlate each proxy metric with downstream test F1 across datasets, separately for each seed\n",
    "PROXY_METRICS = ['pad', 'mmd', 'mdm']\n",
    "TARGET_COLUMN = 'test_f1_mean'\n",
    "\n",
    "\n",
    "def _pearson_spearman(x, y):\n",
    "    \"\"\"Return (pearson_r, spearman_rho) of x vs y, or (None, None) if either computation fails.\"\"\"\n",
    "    try:\n",
    "        pearson, _ = pearsonr(x, y)\n",
    "        spearman, _ = spearmanr(x, y)\n",
    "    except Exception:  # best effort: one failing metric/seed should not stop the others\n",
    "        return None, None\n",
    "    return pearson, spearman\n",
    "\n",
    "\n",
    "# We'll collect results in a list of dicts\n",
    "corr_rows = []\n",
    "\n",
    "for seed in df_sa_merged['seed'].unique():\n",
    "    seed_df = df_sa_merged[df_sa_merged['seed'] == seed]\n",
    "    # One row per dataset (mean over any duplicate rows); correlation needs at least two datasets\n",
    "    grouped = (\n",
    "        seed_df.groupby('dataset')[[f'{metric}_test' for metric in PROXY_METRICS] + [TARGET_COLUMN]]\n",
    "        .mean()\n",
    "        .reset_index()\n",
    "    )\n",
    "    if len(grouped) < 2:\n",
    "        continue\n",
    "    row = {'seed': seed}\n",
    "    for metric in PROXY_METRICS:\n",
    "        pearson, spearman = _pearson_spearman(grouped[f'{metric}_test'], grouped[TARGET_COLUMN])\n",
    "        row[f'pearson_{metric}_vs_f1'] = pearson\n",
    "        row[f'spearman_{metric}_vs_f1'] = spearman\n",
    "    corr_rows.append(row)\n",
    "\n",
    "corr_df = pd.DataFrame(corr_rows)\n",
    "# Aggregate all seeds for each metric and add a row named \"all_seeds\"\n",
    "if not corr_df.empty:\n",
    "    agg_row = {'seed': 'all_seeds'}\n",
    "    agg_row.update({col: corr_df[col].mean() for col in corr_df.columns if col != 'seed'})\n",
    "    corr_df_agg = pd.concat([corr_df, pd.DataFrame([agg_row])], ignore_index=True)\n",
    "else:\n",
    "    corr_df_agg = corr_df\n",
    "# corr_df_agg.to_csv('early_stop_on_synth_results/sentiment_analysis_correlation_pad_mmd_mdm_all_metrics_computed_using_qte-qwen2.5-7b-instruct.csv', index=False)\n",
    "corr_df_agg.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Top-3 Ranked Average Accuracies\n",
    "We compute the average accuracy of the top-3 datasets as ranked by each metric and compare it with the average score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>test_f1_mean</th>\n",
       "      <th>seed</th>\n",
       "      <th>32b_lens_unormalized</th>\n",
       "      <th>32b_lens</th>\n",
       "      <th>generation_model_x</th>\n",
       "      <th>prompt_type_x</th>\n",
       "      <th>7b_lens_unormalized</th>\n",
       "      <th>7b_lens</th>\n",
       "      <th>generation_model_y</th>\n",
       "      <th>prompt_type_y</th>\n",
       "      <th>pad_test</th>\n",
       "      <th>mmd_test</th>\n",
       "      <th>mdm_test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>llama3.1-8b_zero-shot_bg_train-time-info_v1</td>\n",
       "      <td>0.587421</td>\n",
       "      <td>42</td>\n",
       "      <td>0.340759</td>\n",
       "      <td>0.263070</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.486181</td>\n",
       "      <td>0.430975</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.312682</td>\n",
       "      <td>-0.000049</td>\n",
       "      <td>0.754640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>llama3.1-8b_zero-shot_bg_train-time-info_v1</td>\n",
       "      <td>0.587421</td>\n",
       "      <td>43</td>\n",
       "      <td>0.197859</td>\n",
       "      <td>0.144868</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.490718</td>\n",
       "      <td>0.417348</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.317966</td>\n",
       "      <td>-0.000050</td>\n",
       "      <td>0.754640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>llama3.1-8b_zero-shot_bg_train-time-info_v1</td>\n",
       "      <td>0.587421</td>\n",
       "      <td>44</td>\n",
       "      <td>0.234683</td>\n",
       "      <td>0.192921</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.513360</td>\n",
       "      <td>0.405313</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.304220</td>\n",
       "      <td>-0.000041</td>\n",
       "      <td>0.754640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>llama3.1-8b_zero-shot_bg_train-time-info_v1</td>\n",
       "      <td>0.587421</td>\n",
       "      <td>45</td>\n",
       "      <td>0.246809</td>\n",
       "      <td>0.205776</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.496903</td>\n",
       "      <td>0.449129</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.294996</td>\n",
       "      <td>-0.000044</td>\n",
       "      <td>0.754640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>llama3.1-8b_zero-shot_bg_train-time-info_v1</td>\n",
       "      <td>0.587421</td>\n",
       "      <td>46</td>\n",
       "      <td>0.321729</td>\n",
       "      <td>0.240651</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.474951</td>\n",
       "      <td>0.423912</td>\n",
       "      <td>llama3.1-8b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.271204</td>\n",
       "      <td>-0.000037</td>\n",
       "      <td>0.754640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>155</th>\n",
       "      <td>llama3.3-70b_zero-shot_bg_v1</td>\n",
       "      <td>0.328230</td>\n",
       "      <td>42</td>\n",
       "      <td>0.265246</td>\n",
       "      <td>0.114229</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.468030</td>\n",
       "      <td>0.316995</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.784829</td>\n",
       "      <td>-0.000211</td>\n",
       "      <td>0.427963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>156</th>\n",
       "      <td>llama3.3-70b_zero-shot_bg_v1</td>\n",
       "      <td>0.328230</td>\n",
       "      <td>43</td>\n",
       "      <td>0.154655</td>\n",
       "      <td>0.099511</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.406745</td>\n",
       "      <td>0.272592</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.780730</td>\n",
       "      <td>-0.000203</td>\n",
       "      <td>0.427963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>157</th>\n",
       "      <td>llama3.3-70b_zero-shot_bg_v1</td>\n",
       "      <td>0.328230</td>\n",
       "      <td>44</td>\n",
       "      <td>0.174708</td>\n",
       "      <td>0.114599</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.465179</td>\n",
       "      <td>0.399349</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.778075</td>\n",
       "      <td>-0.000217</td>\n",
       "      <td>0.427963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>158</th>\n",
       "      <td>llama3.3-70b_zero-shot_bg_v1</td>\n",
       "      <td>0.328230</td>\n",
       "      <td>45</td>\n",
       "      <td>0.159719</td>\n",
       "      <td>0.131725</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.506921</td>\n",
       "      <td>0.332967</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.776826</td>\n",
       "      <td>-0.000206</td>\n",
       "      <td>0.427963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>159</th>\n",
       "      <td>llama3.3-70b_zero-shot_bg_v1</td>\n",
       "      <td>0.328230</td>\n",
       "      <td>46</td>\n",
       "      <td>0.300775</td>\n",
       "      <td>0.112509</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>0.497470</td>\n",
       "      <td>0.350573</td>\n",
       "      <td>llama3.3-70b</td>\n",
       "      <td>zero-shot</td>\n",
       "      <td>-1.803057</td>\n",
       "      <td>-0.000220</td>\n",
       "      <td>0.427963</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>160 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         dataset  test_f1_mean  seed  \\\n",
       "0    llama3.1-8b_zero-shot_bg_train-time-info_v1      0.587421    42   \n",
       "1    llama3.1-8b_zero-shot_bg_train-time-info_v1      0.587421    43   \n",
       "2    llama3.1-8b_zero-shot_bg_train-time-info_v1      0.587421    44   \n",
       "3    llama3.1-8b_zero-shot_bg_train-time-info_v1      0.587421    45   \n",
       "4    llama3.1-8b_zero-shot_bg_train-time-info_v1      0.587421    46   \n",
       "..                                           ...           ...   ...   \n",
       "155                 llama3.3-70b_zero-shot_bg_v1      0.328230    42   \n",
       "156                 llama3.3-70b_zero-shot_bg_v1      0.328230    43   \n",
       "157                 llama3.3-70b_zero-shot_bg_v1      0.328230    44   \n",
       "158                 llama3.3-70b_zero-shot_bg_v1      0.328230    45   \n",
       "159                 llama3.3-70b_zero-shot_bg_v1      0.328230    46   \n",
       "\n",
       "     32b_lens_unormalized  32b_lens generation_model_x prompt_type_x  \\\n",
       "0                0.340759  0.263070        llama3.1-8b     zero-shot   \n",
       "1                0.197859  0.144868        llama3.1-8b     zero-shot   \n",
       "2                0.234683  0.192921        llama3.1-8b     zero-shot   \n",
       "3                0.246809  0.205776        llama3.1-8b     zero-shot   \n",
       "4                0.321729  0.240651        llama3.1-8b     zero-shot   \n",
       "..                    ...       ...                ...           ...   \n",
       "155              0.265246  0.114229       llama3.3-70b     zero-shot   \n",
       "156              0.154655  0.099511       llama3.3-70b     zero-shot   \n",
       "157              0.174708  0.114599       llama3.3-70b     zero-shot   \n",
       "158              0.159719  0.131725       llama3.3-70b     zero-shot   \n",
       "159              0.300775  0.112509       llama3.3-70b     zero-shot   \n",
       "\n",
       "     7b_lens_unormalized   7b_lens generation_model_y prompt_type_y  pad_test  \\\n",
       "0               0.486181  0.430975        llama3.1-8b     zero-shot -1.312682   \n",
       "1               0.490718  0.417348        llama3.1-8b     zero-shot -1.317966   \n",
       "2               0.513360  0.405313        llama3.1-8b     zero-shot -1.304220   \n",
       "3               0.496903  0.449129        llama3.1-8b     zero-shot -1.294996   \n",
       "4               0.474951  0.423912        llama3.1-8b     zero-shot -1.271204   \n",
       "..                   ...       ...                ...           ...       ...   \n",
       "155             0.468030  0.316995       llama3.3-70b     zero-shot -1.784829   \n",
       "156             0.406745  0.272592       llama3.3-70b     zero-shot -1.780730   \n",
       "157             0.465179  0.399349       llama3.3-70b     zero-shot -1.778075   \n",
       "158             0.506921  0.332967       llama3.3-70b     zero-shot -1.776826   \n",
       "159             0.497470  0.350573       llama3.3-70b     zero-shot -1.803057   \n",
       "\n",
       "     mmd_test  mdm_test  \n",
       "0   -0.000049  0.754640  \n",
       "1   -0.000050  0.754640  \n",
       "2   -0.000041  0.754640  \n",
       "3   -0.000044  0.754640  \n",
       "4   -0.000037  0.754640  \n",
       "..        ...       ...  \n",
       "155 -0.000211  0.427963  \n",
       "156 -0.000203  0.427963  \n",
       "157 -0.000217  0.427963  \n",
       "158 -0.000206  0.427963  \n",
       "159 -0.000220  0.427963  \n",
       "\n",
       "[160 rows x 14 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "join_keys = ['seed', 'dataset', 'test_f1_mean']\n",
    "# Embedding-based baselines (PAD / MMD / MDM) and LENS scores computed with qwen2.5-7b / qwen2.5-32b\n",
    "result_df = pd.read_csv(\"results/sentiment_analysis_qte-qwen2.5-7b-instruct.csv\")\n",
    "lens_result_7b_df = pd.read_csv(\"results/sentiment_analysis_lens_qwen2.5-7b-instruct.csv\")\n",
    "lens_result_32b_df = pd.read_csv(\"results/sentiment_analysis_lens_qwen2.5-32b-instruct.csv\")\n",
    "# we negate pad and mmd as their correlation with accuracy is negative,\n",
    "# so that a higher value means a better dataset for every metric\n",
    "for distance_col in ('pad_test', 'mmd_test'):\n",
    "    result_df[distance_col] = -result_df[distance_col]\n",
    "# Left joins keep every LENS row; columns present in both LENS files get _x (32b) / _y (7b) suffixes\n",
    "result_merged_df = lens_result_32b_df.merge(\n",
    "    lens_result_7b_df.merge(result_df, how='left', on=join_keys),\n",
    "    how='left',\n",
    "    on=join_keys,\n",
    ")\n",
    "result_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>seed</th>\n",
       "      <th>average_test_f1_mean</th>\n",
       "      <th>top3_average_test_f1_mean</th>\n",
       "      <th>delta</th>\n",
       "      <th>percentage_improvement</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>52.0</td>\n",
       "      <td>2.4</td>\n",
       "      <td>4.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32b_lens_unormalized</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>51.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>2.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>50.5</td>\n",
       "      <td>0.8</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7b_lens_unormalized</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>51.2</td>\n",
       "      <td>1.6</td>\n",
       "      <td>3.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>56.7</td>\n",
       "      <td>7.1</td>\n",
       "      <td>14.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>54.7</td>\n",
       "      <td>5.1</td>\n",
       "      <td>10.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>44.0</td>\n",
       "      <td>49.6</td>\n",
       "      <td>55.3</td>\n",
       "      <td>5.7</td>\n",
       "      <td>11.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 metric  seed  average_test_f1_mean  \\\n",
       "0              32b_lens  44.0                  49.6   \n",
       "1  32b_lens_unormalized  44.0                  49.6   \n",
       "2               7b_lens  44.0                  49.6   \n",
       "3   7b_lens_unormalized  44.0                  49.6   \n",
       "4              mdm_test  44.0                  49.6   \n",
       "5              mmd_test  44.0                  49.6   \n",
       "6              pad_test  44.0                  49.6   \n",
       "\n",
       "   top3_average_test_f1_mean  delta  percentage_improvement  \n",
       "0                       52.0    2.4                     4.8  \n",
       "1                       51.0    1.4                     2.8  \n",
       "2                       50.5    0.8                     1.7  \n",
       "3                       51.2    1.6                     3.2  \n",
       "4                       56.7    7.1                    14.3  \n",
       "5                       54.7    5.1                    10.3  \n",
       "6                       55.3    5.7                    11.5  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "METRICS = ['pad_test', 'mmd_test', 'mdm_test', '7b_lens_unormalized', '7b_lens', '32b_lens_unormalized', '32b_lens']\n",
    "TOP_K = 3  # number of datasets each metric selects\n",
    "\n",
    "# For each seed and metric: take the TOP_K datasets with the highest metric value and compare\n",
    "# their mean test F1 with the mean test F1 over all datasets of that seed.\n",
    "top3_averages = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    seed_df = result_merged_df[result_merged_df['seed'] == seed]\n",
    "    average_test_f1_mean = seed_df['test_f1_mean'].mean()\n",
    "    for metric in METRICS:\n",
    "        # Drop rows whose metric is missing (possible after the left merges): sort_values puts\n",
    "        # NaN last, so tail() would otherwise select them as the 'best' datasets.\n",
    "        ranked_df = seed_df.dropna(subset=[metric]).sort_values(metric)\n",
    "        top3_average_test_f1_mean = ranked_df.tail(TOP_K)['test_f1_mean'].mean()\n",
    "        delta = top3_average_test_f1_mean - average_test_f1_mean\n",
    "        percentage_improvement = delta / average_test_f1_mean  # relative gain; scaled to % below\n",
    "        top3_averages.append({\n",
    "            \"seed\": seed,\n",
    "            \"metric\": metric,\n",
    "            \"average_test_f1_mean\": average_test_f1_mean,\n",
    "            \"top3_average_test_f1_mean\": top3_average_test_f1_mean,\n",
    "            \"delta\": delta,\n",
    "            \"percentage_improvement\": percentage_improvement\n",
    "        })\n",
    "top3_average_test_f1_mean_df = pd.DataFrame(top3_averages)\n",
    "top3_average_test_f1_mean_df_agg = top3_average_test_f1_mean_df.groupby(\"metric\").agg(\"mean\").reset_index()\n",
    "result_cols = ['average_test_f1_mean', 'top3_average_test_f1_mean', 'delta', 'percentage_improvement']\n",
    "# Express scores and improvements in percent\n",
    "top3_average_test_f1_mean_df_agg[result_cols] = top3_average_test_f1_mean_df_agg[result_cols] * 100\n",
    "top3_average_test_f1_mean_df_agg.round(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text2SQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Synthetic data\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Real data\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import re\n",
    "\n",
    "seeds = [42, 43, 44, 45, 46]\n",
    "\n",
    "def preprocess_text2sql(x):\n",
    "    \"\"\"Return the natural-language `question` field of one text2sql example dict.\"\"\"\n",
    "    return x['question']\n",
    "\n",
    "def load_json(path):\n",
    "    \"\"\"Read and parse the JSON file at `path` (a pathlib.Path), closing it afterwards.\"\"\"\n",
    "    with path.open('rt') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "print(\"Synthetic data\")\n",
    "\n",
    "syn_db_dataset_map = {} # db_id -> dataset_name -> list of questions\n",
    "# One sub-directory per db_id (names contain '_', so the 'real' directory is not matched).\n",
    "# Only directories are kept, so stray files matching the pattern are ignored.\n",
    "data_root_paths = [p for p in Path('../data/text2sql/data').glob('*_*') if p.is_dir()]\n",
    "for db_path in data_root_paths:\n",
    "    db_id = db_path.stem\n",
    "    syn_db_dataset_map[db_id] = {}\n",
    "    for dataset_path in db_path.glob('*v1.json'):\n",
    "        questions = [preprocess_text2sql(x) for x in load_json(dataset_path)]\n",
    "        syn_db_dataset_map[db_id].setdefault(dataset_path.stem, []).extend(questions)\n",
    "\n",
    "print(\"-\"*100)\n",
    "print(\"Real data\")\n",
    "\n",
    "real_dataset_map = {} # seed -> db_id -> list of questions\n",
    "real_data_paths = list(Path('../data/text2sql/data/real').glob('*.json'))\n",
    "for seed in seeds:\n",
    "    real_dataset_map[seed] = {}\n",
    "    for path in real_data_paths:\n",
    "        sd = re.search(r'seed=(\\d+)', path.stem)\n",
    "        if sd is None or int(sd.group(1)) != seed:\n",
    "            continue\n",
    "        # presumably named 'dev_<db_id>_seed=<n>.json': strip only the leading 'dev_'\n",
    "        # (str.replace would also remove 'dev_' occurring inside a db_id)\n",
    "        db_id = path.stem.removeprefix('dev_').split('_seed')[0]\n",
    "        real_dataset_map[seed][db_id] = [preprocess_text2sql(x) for x in load_json(path)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute baseline proxy metrics (PAD, MMD, MDM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# qte-qwen2.5-7b-instruct embeddings, indexed as emb_synth[db_id][dataset_name] and\n",
    "# emb_seed_db_real[str(seed)][db_id] by the baseline cell.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.\n",
    "with open(\"../embeddings/text2sql_synthetic_data_embs.pkl\", \"rb\") as f:\n",
    "    emb_synth = pickle.load(f)\n",
    "with open(\"../embeddings/text2sql_real_data_embs.pkl\", \"rb\") as f:\n",
    "    emb_seed_db_real = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "# Per seed, compare every synthetic dataset of a db_id with that seed's real data for the db_id.\n",
    "pad_mmd_rows = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    emb_db_real = emb_seed_db_real[str(seed)]\n",
    "    for db_id, X_real in emb_db_real.items():\n",
    "        for dataset_name, X in tqdm.tqdm(emb_synth[db_id].items(), desc=f'seed={seed} {db_id}'):\n",
    "            pad_mmd_rows.append({\n",
    "                'seed': seed,\n",
    "                'db_id': db_id,\n",
    "                'dataset_name': dataset_name,\n",
    "                # PAD / MMD: synthetic vs real test\n",
    "                'pad_test': compute_pad(X, X_real),\n",
    "                'mmd_test': compute_mmd(X, X_real),\n",
    "                # MDM: synthetic only\n",
    "                'mdm_test': compute_mdm(X, n_clusters=5),\n",
    "            })\n",
    "\n",
    "pad_df = pd.DataFrame(pad_mmd_rows)\n",
    "\n",
    "# Hard-coded mean test accuracy (%) obtained with each synthetic dataset, per db_id\n",
    "accuracies = {\n",
    "    \"llama3.1-8b_1000_few-shot_bg_test-time-info_v1\": {\"app_store\": 46.03, \"computer_student\": 47.22, \"movie_platform\": 50.9},\n",
    "    \"llama3.1-8b_1000_few-shot_bg_v1\": {\"app_store\": 39.68, \"computer_student\": 50.00, \"movie_platform\": 44.31},\n",
    "    \"llama3.1-8b_1000_zero-shot_bg_test-time-info_v1\": {\"app_store\": 20.63, \"computer_student\": 40.28, \"movie_platform\": 38.32},\n",
    "    \"llama3.1-8b_1000_zero-shot_bg_v1\": {\"app_store\": 22.22, \"computer_student\": 36.11, \"movie_platform\": 37.72},\n",
    "    \"qwen2.5-coder-7b_1000_zero-shot_bg_v1\": {\"app_store\": 22.22, \"computer_student\": 50, \"movie_platform\": 16.77},\n",
    "    \"qwen2.5-coder-7b_1000_zero-shot_bg_test-time-info_v1\": {\"app_store\": 25.4, \"computer_student\": 48.61, \"movie_platform\": 20.36},\n",
    "    \"qwen2.5-coder-7b_1000_few-shot_bg_v1\": {\"app_store\": 31.75, \"computer_student\": 45.83, \"movie_platform\": 46.11},\n",
    "    \"qwen2.5-coder-7b_1000_few-shot_bg_test-time-info_v1\": {\"app_store\": 34.92, \"computer_student\": 48.61, \"movie_platform\": 43.71}\n",
    "}\n",
    "\n",
    "# Attach the accuracy of each (dataset_name, db_id) row in one pass; pairs without an entry get NaN\n",
    "acc_lookup = {(name, db): acc for name, per_db in accuracies.items() for db, acc in per_db.items()}\n",
    "pad_df['test_acc_mean'] = [acc_lookup.get(key, np.nan) for key in zip(pad_df['dataset_name'], pad_df['db_id'])]\n",
    "merged_df = pad_df\n",
    "merged_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved per-seed PAD/MMD/MDM results (presumably produced by the baseline cell above)\n",
    "# instead of recomputing them; this replaces the merged_df built in that cell.\n",
    "# NOTE(review): this file name says qwen2-7b while the embeddings are labelled qwen2.5-7b - confirm.\n",
    "merged_df = pd.read_csv(\"results/text2sql_qte-qwen2-7b-instruct_all_seeds.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_65657/3260284606.py:99: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
      "  display_df = corr_df_agg.applymap(lambda x: -x if type(x) == float else x)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>db_id</th>\n",
       "      <th>pearson_pad_vs_acc</th>\n",
       "      <th>pearson_pad_vs_acc_std</th>\n",
       "      <th>spearman_pad_vs_acc</th>\n",
       "      <th>spearman_pad_vs_acc_std</th>\n",
       "      <th>pearson_mmd_vs_acc</th>\n",
       "      <th>pearson_mmd_vs_acc_std</th>\n",
       "      <th>spearman_mmd_vs_acc</th>\n",
       "      <th>spearman_mmd_vs_acc_std</th>\n",
       "      <th>pearson_mdm_vs_acc</th>\n",
       "      <th>pearson_mdm_vs_acc_std</th>\n",
       "      <th>spearman_mdm_vs_acc</th>\n",
       "      <th>spearman_mdm_vs_acc_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>0.686985</td>\n",
       "      <td>-0.035243</td>\n",
       "      <td>0.462684</td>\n",
       "      <td>-0.079286</td>\n",
       "      <td>0.845990</td>\n",
       "      <td>-0.026102</td>\n",
       "      <td>0.334964</td>\n",
       "      <td>-0.083217</td>\n",
       "      <td>-0.623210</td>\n",
       "      <td>-0.062939</td>\n",
       "      <td>-0.385570</td>\n",
       "      <td>-0.000000e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>app_store</td>\n",
       "      <td>0.419169</td>\n",
       "      <td>-0.357682</td>\n",
       "      <td>0.426355</td>\n",
       "      <td>-0.480814</td>\n",
       "      <td>0.786611</td>\n",
       "      <td>-0.049231</td>\n",
       "      <td>0.527747</td>\n",
       "      <td>-0.145372</td>\n",
       "      <td>-0.561257</td>\n",
       "      <td>-0.002196</td>\n",
       "      <td>-0.443122</td>\n",
       "      <td>-6.206335e-17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>movie_platform</td>\n",
       "      <td>0.635484</td>\n",
       "      <td>-0.068921</td>\n",
       "      <td>0.495238</td>\n",
       "      <td>-0.131923</td>\n",
       "      <td>0.459179</td>\n",
       "      <td>-0.052545</td>\n",
       "      <td>0.381524</td>\n",
       "      <td>-0.140341</td>\n",
       "      <td>-0.443794</td>\n",
       "      <td>-0.050669</td>\n",
       "      <td>-0.685714</td>\n",
       "      <td>-9.583148e-02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>all_db</td>\n",
       "      <td>0.580546</td>\n",
       "      <td>-0.153949</td>\n",
       "      <td>0.461426</td>\n",
       "      <td>-0.230674</td>\n",
       "      <td>0.697260</td>\n",
       "      <td>-0.042626</td>\n",
       "      <td>0.414745</td>\n",
       "      <td>-0.122977</td>\n",
       "      <td>-0.542754</td>\n",
       "      <td>-0.038601</td>\n",
       "      <td>-0.504802</td>\n",
       "      <td>-3.194383e-02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              db_id  pearson_pad_vs_acc  pearson_pad_vs_acc_std  \\\n",
       "0  computer_student            0.686985               -0.035243   \n",
       "1         app_store            0.419169               -0.357682   \n",
       "2    movie_platform            0.635484               -0.068921   \n",
       "3            all_db            0.580546               -0.153949   \n",
       "\n",
       "   spearman_pad_vs_acc  spearman_pad_vs_acc_std  pearson_mmd_vs_acc  \\\n",
       "0             0.462684                -0.079286            0.845990   \n",
       "1             0.426355                -0.480814            0.786611   \n",
       "2             0.495238                -0.131923            0.459179   \n",
       "3             0.461426                -0.230674            0.697260   \n",
       "\n",
       "   pearson_mmd_vs_acc_std  spearman_mmd_vs_acc  spearman_mmd_vs_acc_std  \\\n",
       "0               -0.026102             0.334964                -0.083217   \n",
       "1               -0.049231             0.527747                -0.145372   \n",
       "2               -0.052545             0.381524                -0.140341   \n",
       "3               -0.042626             0.414745                -0.122977   \n",
       "\n",
       "   pearson_mdm_vs_acc  pearson_mdm_vs_acc_std  spearman_mdm_vs_acc  \\\n",
       "0           -0.623210               -0.062939            -0.385570   \n",
       "1           -0.561257               -0.002196            -0.443122   \n",
       "2           -0.443794               -0.050669            -0.685714   \n",
       "3           -0.542754               -0.038601            -0.504802   \n",
       "\n",
       "   spearman_mdm_vs_acc_std  \n",
       "0            -0.000000e+00  \n",
       "1            -6.206335e-17  \n",
       "2            -9.583148e-02  \n",
       "3            -3.194383e-02  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "import numpy as np\n",
    "\n",
    "# metric key used in the output column names -> column of merged_df\n",
    "METRIC_COLS = {'pad': 'pad_test', 'mmd': 'mmd_test', 'mdm': 'mdm_test'}\n",
    "\n",
    "def safe_mean(lst):\n",
    "    \"\"\"Arithmetic mean of lst, or None if it is empty.\"\"\"\n",
    "    return sum(lst)/len(lst) if len(lst) > 0 else None\n",
    "\n",
    "def safe_std(lst):\n",
    "    \"\"\"Sample standard deviation (ddof=1) of lst, or None if it has fewer than 2 values.\"\"\"\n",
    "    return float(np.std(lst, ddof=1)) if len(lst) > 1 else None\n",
    "\n",
    "def per_seed_correlations(db_df, db_id):\n",
    "    \"\"\"Correlate every metric with test_acc_mean across datasets, separately for each seed.\n",
    "\n",
    "    Args:\n",
    "        db_df: rows of merged_df for a single db_id.\n",
    "        db_id: used only in diagnostic messages.\n",
    "\n",
    "    Returns:\n",
    "        dict metric key -> (list of Pearson r, list of Spearman rho), one entry per usable seed.\n",
    "        Seeds with fewer than two datasets are skipped (correlation undefined), and a metric\n",
    "        whose correlation raises is skipped for that seed after printing the error.\n",
    "    \"\"\"\n",
    "    corrs = {key: ([], []) for key in METRIC_COLS}\n",
    "    for seed in db_df['seed'].unique():\n",
    "        seed_df = db_df[db_df['seed'] == seed]\n",
    "        # One row per dataset_name (there should already be exactly one per seed/db_id)\n",
    "        grouped = seed_df.groupby('dataset_name').agg({\n",
    "            'pad_test': 'mean',\n",
    "            'mmd_test': 'mean',\n",
    "            'mdm_test': 'mean',\n",
    "            'test_acc_mean': 'mean'\n",
    "        }).reset_index()\n",
    "        if len(grouped) < 2:\n",
    "            continue\n",
    "        for key, col in METRIC_COLS.items():\n",
    "            try:\n",
    "                p, _ = pearsonr(grouped[col], grouped['test_acc_mean'])\n",
    "                s, _ = spearmanr(grouped[col], grouped['test_acc_mean'])\n",
    "            except Exception as exc:\n",
    "                print(f\"{key.upper()} exception for db_id={db_id}, seed={seed}: {exc}\")\n",
    "                continue\n",
    "            corrs[key][0].append(p)\n",
    "            corrs[key][1].append(s)\n",
    "    return corrs\n",
    "\n",
    "# Per db_id: mean and std (over seeds) of each metric's correlation with accuracy\n",
    "corr_rows = []\n",
    "for db_id in merged_df['db_id'].unique():\n",
    "    corrs = per_seed_correlations(merged_df[merged_df['db_id'] == db_id], db_id)\n",
    "    row = {'db_id': db_id}\n",
    "    for key, (pearson_vals, spearman_vals) in corrs.items():\n",
    "        row[f'pearson_{key}_vs_acc'] = safe_mean(pearson_vals)\n",
    "        row[f'pearson_{key}_vs_acc_std'] = safe_std(pearson_vals)\n",
    "        row[f'spearman_{key}_vs_acc'] = safe_mean(spearman_vals)\n",
    "        row[f'spearman_{key}_vs_acc_std'] = safe_std(spearman_vals)\n",
    "    corr_rows.append(row)\n",
    "\n",
    "corr_df = pd.DataFrame(corr_rows)\n",
    "# Aggregate all db_ids: the \"all_db\" row holds the column-wise mean of every\n",
    "# correlation column (including the per-db_id stds)\n",
    "if not corr_df.empty:\n",
    "    agg_row = {'db_id': 'all_db'}\n",
    "    agg_row.update({col: corr_df[col].mean() for col in corr_df.columns if col != 'db_id'})\n",
    "    corr_df_agg = pd.concat([corr_df, pd.DataFrame([agg_row])], ignore_index=True)\n",
    "else:\n",
    "    corr_df_agg = corr_df\n",
    "\n",
    "# Show means and stds for each metric\n",
    "display_df = corr_df_agg.applymap(lambda x: -x if type(x) == float else x)\n",
    "display_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Top-3 Ranked Average Accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>db_id</th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>seed</th>\n",
       "      <th>32b_lens_unnormalized</th>\n",
       "      <th>32b_lens</th>\n",
       "      <th>test_acc_mean</th>\n",
       "      <th>7b_lens_unnormalized</th>\n",
       "      <th>7b_lens</th>\n",
       "      <th>pad_test</th>\n",
       "      <th>mmd_test</th>\n",
       "      <th>mdm_test</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>app_store</td>\n",
       "      <td>qwen2.5-coder-7b_zero-shot_bg_test-time-info_v1</td>\n",
       "      <td>42</td>\n",
       "      <td>0.476839</td>\n",
       "      <td>0.427674</td>\n",
       "      <td>25.40</td>\n",
       "      <td>0.476839</td>\n",
       "      <td>0.427674</td>\n",
       "      <td>-1.772498</td>\n",
       "      <td>-0.000119</td>\n",
       "      <td>0.600223</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>app_store</td>\n",
       "      <td>llama3.1-8b_zero-shot_bg_v1</td>\n",
       "      <td>42</td>\n",
       "      <td>0.480768</td>\n",
       "      <td>0.473480</td>\n",
       "      <td>22.22</td>\n",
       "      <td>0.480768</td>\n",
       "      <td>0.473480</td>\n",
       "      <td>-1.777808</td>\n",
       "      <td>-0.000111</td>\n",
       "      <td>0.578904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>app_store</td>\n",
       "      <td>qwen2.5-coder-7b_few-shot_bg_test-time-info_v1</td>\n",
       "      <td>42</td>\n",
       "      <td>0.472056</td>\n",
       "      <td>0.434307</td>\n",
       "      <td>34.92</td>\n",
       "      <td>0.472056</td>\n",
       "      <td>0.434307</td>\n",
       "      <td>-1.779626</td>\n",
       "      <td>-0.000112</td>\n",
       "      <td>0.648064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>app_store</td>\n",
       "      <td>qwen2.5-coder-7b_few-shot_bg_v1</td>\n",
       "      <td>42</td>\n",
       "      <td>0.471918</td>\n",
       "      <td>0.458109</td>\n",
       "      <td>31.75</td>\n",
       "      <td>0.471918</td>\n",
       "      <td>0.458109</td>\n",
       "      <td>-1.771955</td>\n",
       "      <td>-0.000137</td>\n",
       "      <td>0.591002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>app_store</td>\n",
       "      <td>qwen2.5-coder-7b_zero-shot_bg_v1</td>\n",
       "      <td>42</td>\n",
       "      <td>0.457548</td>\n",
       "      <td>0.456668</td>\n",
       "      <td>22.22</td>\n",
       "      <td>0.457548</td>\n",
       "      <td>0.456668</td>\n",
       "      <td>-1.767874</td>\n",
       "      <td>-0.000131</td>\n",
       "      <td>0.524783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>qwen2.5-coder-7b_few-shot_bg_v1</td>\n",
       "      <td>46</td>\n",
       "      <td>0.479590</td>\n",
       "      <td>0.430297</td>\n",
       "      <td>45.83</td>\n",
       "      <td>0.479590</td>\n",
       "      <td>0.430297</td>\n",
       "      <td>-1.770401</td>\n",
       "      <td>-0.000048</td>\n",
       "      <td>0.645307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>116</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>qwen2.5-coder-7b_zero-shot_bg_v1</td>\n",
       "      <td>46</td>\n",
       "      <td>0.486363</td>\n",
       "      <td>0.483810</td>\n",
       "      <td>50.00</td>\n",
       "      <td>0.486363</td>\n",
       "      <td>0.483810</td>\n",
       "      <td>-1.785354</td>\n",
       "      <td>-0.000122</td>\n",
       "      <td>0.534940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>llama3.1-8b_few-shot_bg_v1</td>\n",
       "      <td>46</td>\n",
       "      <td>0.490646</td>\n",
       "      <td>0.485749</td>\n",
       "      <td>50.00</td>\n",
       "      <td>0.490646</td>\n",
       "      <td>0.485749</td>\n",
       "      <td>-1.773481</td>\n",
       "      <td>-0.000093</td>\n",
       "      <td>0.645649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>llama3.1-8b_zero-shot_bg_test-time-info_v1</td>\n",
       "      <td>46</td>\n",
       "      <td>0.471236</td>\n",
       "      <td>0.424213</td>\n",
       "      <td>40.28</td>\n",
       "      <td>0.471236</td>\n",
       "      <td>0.424213</td>\n",
       "      <td>-1.798664</td>\n",
       "      <td>-0.000163</td>\n",
       "      <td>0.623404</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>computer_student</td>\n",
       "      <td>llama3.1-8b_few-shot_bg_test-time-info_v1</td>\n",
       "      <td>46</td>\n",
       "      <td>0.482358</td>\n",
       "      <td>0.475870</td>\n",
       "      <td>47.22</td>\n",
       "      <td>0.482358</td>\n",
       "      <td>0.475870</td>\n",
       "      <td>-1.773082</td>\n",
       "      <td>-0.000076</td>\n",
       "      <td>0.651523</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>120 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                db_id                                     dataset_name  seed  \\\n",
       "0           app_store  qwen2.5-coder-7b_zero-shot_bg_test-time-info_v1    42   \n",
       "1           app_store                      llama3.1-8b_zero-shot_bg_v1    42   \n",
       "2           app_store   qwen2.5-coder-7b_few-shot_bg_test-time-info_v1    42   \n",
       "3           app_store                  qwen2.5-coder-7b_few-shot_bg_v1    42   \n",
       "4           app_store                 qwen2.5-coder-7b_zero-shot_bg_v1    42   \n",
       "..                ...                                              ...   ...   \n",
       "115  computer_student                  qwen2.5-coder-7b_few-shot_bg_v1    46   \n",
       "116  computer_student                 qwen2.5-coder-7b_zero-shot_bg_v1    46   \n",
       "117  computer_student                       llama3.1-8b_few-shot_bg_v1    46   \n",
       "118  computer_student       llama3.1-8b_zero-shot_bg_test-time-info_v1    46   \n",
       "119  computer_student        llama3.1-8b_few-shot_bg_test-time-info_v1    46   \n",
       "\n",
       "     32b_lens_unnormalized  32b_lens  test_acc_mean  7b_lens_unnormalized  \\\n",
       "0                 0.476839  0.427674          25.40              0.476839   \n",
       "1                 0.480768  0.473480          22.22              0.480768   \n",
       "2                 0.472056  0.434307          34.92              0.472056   \n",
       "3                 0.471918  0.458109          31.75              0.471918   \n",
       "4                 0.457548  0.456668          22.22              0.457548   \n",
       "..                     ...       ...            ...                   ...   \n",
       "115               0.479590  0.430297          45.83              0.479590   \n",
       "116               0.486363  0.483810          50.00              0.486363   \n",
       "117               0.490646  0.485749          50.00              0.490646   \n",
       "118               0.471236  0.424213          40.28              0.471236   \n",
       "119               0.482358  0.475870          47.22              0.482358   \n",
       "\n",
       "      7b_lens  pad_test  mmd_test  mdm_test  \n",
       "0    0.427674 -1.772498 -0.000119  0.600223  \n",
       "1    0.473480 -1.777808 -0.000111  0.578904  \n",
       "2    0.434307 -1.779626 -0.000112  0.648064  \n",
       "3    0.458109 -1.771955 -0.000137  0.591002  \n",
       "4    0.456668 -1.767874 -0.000131  0.524783  \n",
       "..        ...       ...       ...       ...  \n",
       "115  0.430297 -1.770401 -0.000048  0.645307  \n",
       "116  0.483810 -1.785354 -0.000122  0.534940  \n",
       "117  0.485749 -1.773481 -0.000093  0.645649  \n",
       "118  0.424213 -1.798664 -0.000163  0.623404  \n",
       "119  0.475870 -1.773082 -0.000076  0.651523  \n",
       "\n",
       "[120 rows x 11 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_df = pd.read_csv(\"results/text2sql_qte-qwen2-7b-instruct_all_seeds.csv\")\n",
    "lens_result_7b_df = pd.read_csv(\"results/text2sql_lens-qwen2.5-7b-instruct_all_seeds.csv\")\n",
    "lens_result_7b_df.rename(columns={'lens_unnormalized': \"7b_lens_unnormalized\", \"lens\": \"7b_lens\"},inplace=True)\n",
    "lens_result_32b_df = pd.read_csv(\"results/text2sql_lens-qwen2.5-32b-instruct_all_seeds.csv\")\n",
    "lens_result_32b_df.rename(columns={'lens_unnormalized': \"32b_lens_unnormalized\", \"lens\": \"32b_lens\"},inplace=True)\n",
    "# we negate pad and mmd as their correlation with accuracy is negative\n",
    "result_df['pad_test'] = result_df['pad_test'].apply(lambda x: -x)\n",
    "result_df['mmd_test'] = result_df['mmd_test'].apply(lambda x: -x)\n",
    "result_merged_df = pd.merge(lens_result_7b_df, result_df, how='left', on=['seed', 'dataset_name', 'test_acc_mean', 'db_id'])\n",
    "result_merged_df = pd.merge(lens_result_32b_df, result_merged_df, how='left', on=['seed', 'dataset_name', 'test_acc_mean', 'db_id'])\n",
    "result_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>db_id</th>\n",
       "      <th>average_test_acc_mean</th>\n",
       "      <th>top3_average_test_acc_mean</th>\n",
       "      <th>delta</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>34.71</td>\n",
       "      <td>4.35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>46.57</td>\n",
       "      <td>0.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>41.00</td>\n",
       "      <td>3.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>36.29</td>\n",
       "      <td>5.94</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>46.39</td>\n",
       "      <td>0.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>41.36</td>\n",
       "      <td>4.08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>34.71</td>\n",
       "      <td>4.35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>46.57</td>\n",
       "      <td>0.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>41.00</td>\n",
       "      <td>3.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>36.29</td>\n",
       "      <td>5.94</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>46.39</td>\n",
       "      <td>0.56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>41.36</td>\n",
       "      <td>4.08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>33.86</td>\n",
       "      <td>3.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>48.15</td>\n",
       "      <td>2.31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>46.91</td>\n",
       "      <td>9.63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>38.41</td>\n",
       "      <td>8.05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>47.41</td>\n",
       "      <td>1.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>42.95</td>\n",
       "      <td>5.68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>app_store</td>\n",
       "      <td>30.36</td>\n",
       "      <td>33.54</td>\n",
       "      <td>3.19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>computer_student</td>\n",
       "      <td>45.83</td>\n",
       "      <td>48.33</td>\n",
       "      <td>2.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>movie_platform</td>\n",
       "      <td>37.28</td>\n",
       "      <td>44.15</td>\n",
       "      <td>6.88</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   metric             db_id  average_test_acc_mean  \\\n",
       "0                32b_lens         app_store                  30.36   \n",
       "1                32b_lens  computer_student                  45.83   \n",
       "2                32b_lens    movie_platform                  37.28   \n",
       "3   32b_lens_unnormalized         app_store                  30.36   \n",
       "4   32b_lens_unnormalized  computer_student                  45.83   \n",
       "5   32b_lens_unnormalized    movie_platform                  37.28   \n",
       "6                 7b_lens         app_store                  30.36   \n",
       "7                 7b_lens  computer_student                  45.83   \n",
       "8                 7b_lens    movie_platform                  37.28   \n",
       "9    7b_lens_unnormalized         app_store                  30.36   \n",
       "10   7b_lens_unnormalized  computer_student                  45.83   \n",
       "11   7b_lens_unnormalized    movie_platform                  37.28   \n",
       "12               mdm_test         app_store                  30.36   \n",
       "13               mdm_test  computer_student                  45.83   \n",
       "14               mdm_test    movie_platform                  37.28   \n",
       "15               mmd_test         app_store                  30.36   \n",
       "16               mmd_test  computer_student                  45.83   \n",
       "17               mmd_test    movie_platform                  37.28   \n",
       "18               pad_test         app_store                  30.36   \n",
       "19               pad_test  computer_student                  45.83   \n",
       "20               pad_test    movie_platform                  37.28   \n",
       "\n",
       "    top3_average_test_acc_mean  delta  \n",
       "0                        34.71   4.35  \n",
       "1                        46.57   0.74  \n",
       "2                        41.00   3.72  \n",
       "3                        36.29   5.94  \n",
       "4                        46.39   0.56  \n",
       "5                        41.36   4.08  \n",
       "6                        34.71   4.35  \n",
       "7                        46.57   0.74  \n",
       "8                        41.00   3.72  \n",
       "9                        36.29   5.94  \n",
       "10                       46.39   0.56  \n",
       "11                       41.36   4.08  \n",
       "12                       33.86   3.50  \n",
       "13                       48.15   2.31  \n",
       "14                       46.91   9.63  \n",
       "15                       38.41   8.05  \n",
       "16                       47.41   1.57  \n",
       "17                       42.95   5.68  \n",
       "18                       33.54   3.19  \n",
       "19                       48.33   2.50  \n",
       "20                       44.15   6.88  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top3_averages = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    seed_df = result_merged_df[result_merged_df['seed'] == seed]\n",
    "    for db_id in result_merged_df.db_id.unique():\n",
    "        db_df = seed_df[seed_df['db_id'] == db_id]\n",
    "        average_test_acc_mean = db_df['test_acc_mean'].mean()\n",
    "        for metric in ['pad_test', 'mmd_test', 'mdm_test', '7b_lens_unnormalized', '7b_lens', '32b_lens_unnormalized', '32b_lens']:\n",
    "            top3_average_test_acc_mean = db_df.sort_values(metric).tail(3)['test_acc_mean'].mean()\n",
    "            delta_percentage = top3_average_test_acc_mean - average_test_acc_mean\n",
    "            percentage_improvement = (delta_percentage / average_test_acc_mean)\n",
    "            top3_averages.append({\n",
    "                \"seed\": seed,\n",
    "                \"db_id\": db_id,\n",
    "                \"metric\": metric,\n",
    "                \"average_test_acc_mean\": average_test_acc_mean,\n",
    "                \"top3_average_test_acc_mean\": top3_average_test_acc_mean,\n",
    "                \"delta\": delta_percentage,\n",
    "                \"percentage_improvement\": percentage_improvement\n",
    "            })\n",
    "top3_average_test_acc_mean_df = pd.DataFrame(top3_averages)\n",
    "top3_average_test_acc_mean_df_agg = top3_average_test_acc_mean_df.groupby([\"metric\", \"db_id\"]).agg(\"mean\").reset_index()\n",
    "result_cols = ['average_test_acc_mean', 'top3_average_test_acc_mean', 'delta', 'percentage_improvement']\n",
    "top3_average_test_acc_mean_df_agg.drop(columns=['seed', 'percentage_improvement'], inplace=True)\n",
    "top3_average_test_acc_mean_df_agg.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "real_data_map = pickle.load(open(\"../embeddings/e5-v_imagenet_real_embs.pkl\", \"rb\"))\n",
    "synth_data_map = {'first_split': pickle.load(open(\"../embeddings/e5-v_imagenet_synthetic_embs_first_split.pkl\", \"rb\")), 'second_split': pickle.load(open(\"../embeddings/e5-v_imagenet_synthetic_embs_second_split.pkl\", \"rb\")), 'third_split': pickle.load(open(\"../embeddings/e5-v_imagenet_synthetic_embs_third_split.pkl\", \"rb\"))}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute PAD, MMD, MDM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "mrr = pd.read_csv(\"results/image_task_performance.csv\")\n",
    "imagenet_pad_mmd_rows = []\n",
    "for split in synth_data_map.keys():\n",
    "    for dataset_name, X in tqdm.tqdm(synth_data_map[split].items(), desc=f'datasets (split={split})'):\n",
    "        for seed in range(42, 47):\n",
    "            # PAD: synthetic vs real test\n",
    "            pad_test = compute_pad(X, real_data_map[split][str(seed)])\n",
    "            # MMD: synthetic vs real test\n",
    "            mmd_test = compute_mmd(X, real_data_map[split][str(seed)])\n",
    "            # MDM: synthetic only\n",
    "            mdm_test = compute_mdm(X, n_clusters=5)\n",
    "            print(dataset_name, split)\n",
    "            test_mrr = mrr[(mrr['dataset'] == dataset_name) & (mrr['split'] == split)]['test_mrr@5'].values[0]\n",
    "\n",
    "            imagenet_pad_mmd_rows.append({\n",
    "                'split': split,\n",
    "                'dataset_name': dataset_name,\n",
    "                'seed': seed,\n",
    "                'pad_test': pad_test,\n",
    "                'mmd_test': mmd_test,\n",
    "                'mdm_test': mdm_test,\n",
    "                'test_mrr': test_mrr\n",
    "            })\n",
    "\n",
    "imagenet_merged_df = pd.DataFrame(imagenet_pad_mmd_rows)\n",
    "imagenet_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>split</th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>seed</th>\n",
       "      <th>pad_test</th>\n",
       "      <th>mmd_test</th>\n",
       "      <th>mdm_test</th>\n",
       "      <th>test_mrr@5</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>test_mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>first_split</td>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>42</td>\n",
       "      <td>1.604899</td>\n",
       "      <td>0.000032</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>first_split</td>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>43</td>\n",
       "      <td>1.600475</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>first_split</td>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>44</td>\n",
       "      <td>1.606034</td>\n",
       "      <td>0.000033</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>first_split</td>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>45</td>\n",
       "      <td>1.603393</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>first_split</td>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>46</td>\n",
       "      <td>1.609312</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>third_split</td>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>42</td>\n",
       "      <td>1.594417</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>third_split</td>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>43</td>\n",
       "      <td>1.599372</td>\n",
       "      <td>0.000024</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>third_split</td>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>44</td>\n",
       "      <td>1.598013</td>\n",
       "      <td>0.000023</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>third_split</td>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>45</td>\n",
       "      <td>1.585429</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>third_split</td>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>46</td>\n",
       "      <td>1.594555</td>\n",
       "      <td>0.000022</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>90 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          split                dataset_name  seed  pad_test  mmd_test  \\\n",
       "0   first_split  unmet_v15_label_background    42  1.604899  0.000032   \n",
       "1   first_split  unmet_v15_label_background    43  1.600475  0.000031   \n",
       "2   first_split  unmet_v15_label_background    44  1.606034  0.000033   \n",
       "3   first_split  unmet_v15_label_background    45  1.603393  0.000031   \n",
       "4   first_split  unmet_v15_label_background    46  1.609312  0.000031   \n",
       "..          ...                         ...   ...       ...       ...   \n",
       "85  third_split        unmet_v15_label_only    42  1.594417  0.000020   \n",
       "86  third_split        unmet_v15_label_only    43  1.599372  0.000024   \n",
       "87  third_split        unmet_v15_label_only    44  1.598013  0.000023   \n",
       "88  third_split        unmet_v15_label_only    45  1.585429  0.000020   \n",
       "89  third_split        unmet_v15_label_only    46  1.594555  0.000022   \n",
       "\n",
       "    mdm_test  test_mrr@5  test_acc  test_loss  test_mrr  \n",
       "0   0.765632    0.520533      30.4   3.401548  0.520533  \n",
       "1   0.765632    0.520533      30.4   3.401548  0.520533  \n",
       "2   0.765632    0.520533      30.4   3.401548  0.520533  \n",
       "3   0.765632    0.520533      30.4   3.401548  0.520533  \n",
       "4   0.765632    0.520533      30.4   3.401548  0.520533  \n",
       "..       ...         ...       ...        ...       ...  \n",
       "85  0.756654    0.666267      43.2   1.752569  0.666267  \n",
       "86  0.756654    0.666267      43.2   1.752569  0.666267  \n",
       "87  0.756654    0.666267      43.2   1.752569  0.666267  \n",
       "88  0.756654    0.666267      43.2   1.752569  0.666267  \n",
       "89  0.756654    0.666267      43.2   1.752569  0.666267  \n",
       "\n",
       "[90 rows x 10 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-seed proxy metrics (pad/mmd/mdm_test) and downstream test scores for every split/dataset\n",
    "image_results_path = \"results/image_e5-v_all_seeds.csv\"\n",
    "imagenet_merged_df = pd.read_csv(image_results_path)\n",
    "imagenet_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>split</th>\n",
       "      <th>pearson_pad_vs_acc</th>\n",
       "      <th>pearson_pad_vs_acc_std</th>\n",
       "      <th>spearman_pad_vs_acc</th>\n",
       "      <th>spearman_pad_vs_acc_std</th>\n",
       "      <th>pearson_mmd_vs_acc</th>\n",
       "      <th>pearson_mmd_vs_acc_std</th>\n",
       "      <th>spearman_mmd_vs_acc</th>\n",
       "      <th>spearman_mmd_vs_acc_std</th>\n",
       "      <th>pearson_mdm_vs_acc</th>\n",
       "      <th>pearson_mdm_vs_acc_std</th>\n",
       "      <th>spearman_mdm_vs_acc</th>\n",
       "      <th>spearman_mdm_vs_acc_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>first_split</td>\n",
       "      <td>0.13</td>\n",
       "      <td>0.14</td>\n",
       "      <td>-0.04</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.18</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.31</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>second_split</td>\n",
       "      <td>-0.02</td>\n",
       "      <td>0.10</td>\n",
       "      <td>-0.03</td>\n",
       "      <td>0.11</td>\n",
       "      <td>-0.07</td>\n",
       "      <td>0.01</td>\n",
       "      <td>-0.14</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>third_split</td>\n",
       "      <td>0.46</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.55</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.71</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.52</td>\n",
       "      <td>0.00</td>\n",
       "      <td>-0.60</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>all_splits</td>\n",
       "      <td>0.19</td>\n",
       "      <td>0.09</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.30</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.12</td>\n",
       "      <td>0.01</td>\n",
       "      <td>-0.10</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          split  pearson_pad_vs_acc  pearson_pad_vs_acc_std  \\\n",
       "0   first_split                0.13                    0.14   \n",
       "1  second_split               -0.02                    0.10   \n",
       "2   third_split                0.46                    0.03   \n",
       "3    all_splits                0.19                    0.09   \n",
       "\n",
       "   spearman_pad_vs_acc  spearman_pad_vs_acc_std  pearson_mmd_vs_acc  \\\n",
       "0                -0.04                     0.10                0.18   \n",
       "1                -0.03                     0.11               -0.07   \n",
       "2                 0.44                     0.09                0.55   \n",
       "3                 0.12                     0.10                0.22   \n",
       "\n",
       "   pearson_mmd_vs_acc_std  spearman_mmd_vs_acc  spearman_mmd_vs_acc_std  \\\n",
       "0                    0.03                 0.31                      0.0   \n",
       "1                    0.01                -0.14                      0.0   \n",
       "2                    0.02                 0.71                      0.0   \n",
       "3                    0.02                 0.30                      0.0   \n",
       "\n",
       "   pearson_mdm_vs_acc  pearson_mdm_vs_acc_std  spearman_mdm_vs_acc  \\\n",
       "0                0.16                    0.01                 0.14   \n",
       "1                0.00                    0.01                 0.14   \n",
       "2               -0.52                    0.00                -0.60   \n",
       "3               -0.12                    0.01                -0.10   \n",
       "\n",
       "   spearman_mdm_vs_acc_std  \n",
       "0                      0.0  \n",
       "1                      0.0  \n",
       "2                      0.0  \n",
       "3                      0.0  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "import numpy as np\n",
    "\n",
    "# Correlate each proxy metric with downstream performance, split by split.\n",
    "# For every (split, seed) we correlate metric vs. target across datasets, then report\n",
    "# the mean and sample std of those per-seed correlations.\n",
    "# NOTE: the `*_vs_acc` columns correlate against `test_mrr@5`, not `test_acc`.\n",
    "CORR_PROXY_METRICS = {'pad': 'pad_test', 'mmd': 'mmd_test', 'mdm': 'mdm_test'}\n",
    "CORR_TARGET_COL = 'test_mrr@5'\n",
    "# PAD and MMD are discrepancy scores whose correlation with the target is negative, so their\n",
    "# reported correlations are negated (the std columns are sign-invariant and left as is).\n",
    "CORR_NEGATED_METRICS = ('pad', 'mmd')\n",
    "\n",
    "def safe_mean(lst):\n",
    "    \"\"\"Arithmetic mean of `lst`, or None if it is empty.\"\"\"\n",
    "    return sum(lst)/len(lst) if len(lst) > 0 else None\n",
    "\n",
    "def safe_std(lst):\n",
    "    \"\"\"Sample standard deviation (ddof=1) of `lst`, or None with fewer than two values.\"\"\"\n",
    "    return float(np.std(lst, ddof=1)) if len(lst) > 1 else None\n",
    "\n",
    "corr_rows = []\n",
    "for split in imagenet_merged_df['split'].unique():\n",
    "    split_df = imagenet_merged_df[imagenet_merged_df['split'] == split]\n",
    "    pearson_vals = {name: [] for name in CORR_PROXY_METRICS}\n",
    "    spearman_vals = {name: [] for name in CORR_PROXY_METRICS}\n",
    "    for seed in split_df['seed'].unique():\n",
    "        seed_df = split_df[split_df['seed'] == seed]\n",
    "        # One row per dataset (rows should already be unique per split/seed; mean is a safeguard)\n",
    "        grouped = seed_df.groupby('dataset_name').agg(\n",
    "            {col: 'mean' for col in [*CORR_PROXY_METRICS.values(), CORR_TARGET_COL]}\n",
    "        ).reset_index()\n",
    "        # Correlation is undefined with fewer than two datasets\n",
    "        if len(grouped) < 2:\n",
    "            continue\n",
    "        for name, col in CORR_PROXY_METRICS.items():\n",
    "            try:\n",
    "                p_corr, _ = pearsonr(grouped[col], grouped[CORR_TARGET_COL])\n",
    "                s_corr, _ = spearmanr(grouped[col], grouped[CORR_TARGET_COL])\n",
    "            except Exception:\n",
    "                print(f\"{name.upper()} exception for split={split}, seed={seed}\")\n",
    "                continue\n",
    "            pearson_vals[name].append(p_corr)\n",
    "            spearman_vals[name].append(s_corr)\n",
    "    row = {'split': split}\n",
    "    for name in CORR_PROXY_METRICS:\n",
    "        row[f'pearson_{name}_vs_acc'] = safe_mean(pearson_vals[name])\n",
    "        row[f'pearson_{name}_vs_acc_std'] = safe_std(pearson_vals[name])\n",
    "        row[f'spearman_{name}_vs_acc'] = safe_mean(spearman_vals[name])\n",
    "        row[f'spearman_{name}_vs_acc_std'] = safe_std(spearman_vals[name])\n",
    "    corr_rows.append(row)\n",
    "\n",
    "corr_df = pd.DataFrame(corr_rows)\n",
    "if not corr_df.empty:\n",
    "    # Summary row: unweighted mean over splits of every value column (stds included)\n",
    "    agg_row = {'split': 'all_splits'}\n",
    "    agg_row.update({col: corr_df[col].mean() for col in corr_df.columns if col != 'split'})\n",
    "    corr_df = pd.concat([corr_df, pd.DataFrame([agg_row])], ignore_index=True)\n",
    "\n",
    "    # Negate PAD/MMD correlations; kept inside the guard because an empty frame has no such columns\n",
    "    for name in CORR_NEGATED_METRICS:\n",
    "        for method in ('pearson', 'spearman'):\n",
    "            col = f'{method}_{name}_vs_acc'\n",
    "            corr_df[col] = corr_df[col].apply(lambda x: -x if pd.notnull(x) else x)\n",
    "\n",
    "corr_df.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Top-3 Ranked Average Scores (Test MRR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>split</th>\n",
       "      <th>seed</th>\n",
       "      <th>32b_lens_unnormalized</th>\n",
       "      <th>32b_lens</th>\n",
       "      <th>7b_lens_unnormalized</th>\n",
       "      <th>7b_lens</th>\n",
       "      <th>pad_test</th>\n",
       "      <th>mmd_test</th>\n",
       "      <th>mdm_test</th>\n",
       "      <th>test_mrr@5</th>\n",
       "      <th>test_acc</th>\n",
       "      <th>test_loss</th>\n",
       "      <th>test_mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>first_split</td>\n",
       "      <td>42</td>\n",
       "      <td>0.696950</td>\n",
       "      <td>0.603722</td>\n",
       "      <td>0.520333</td>\n",
       "      <td>0.480682</td>\n",
       "      <td>1.604899</td>\n",
       "      <td>0.000032</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>first_split</td>\n",
       "      <td>43</td>\n",
       "      <td>0.588212</td>\n",
       "      <td>0.403744</td>\n",
       "      <td>0.127762</td>\n",
       "      <td>0.106592</td>\n",
       "      <td>1.600475</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>first_split</td>\n",
       "      <td>44</td>\n",
       "      <td>0.753200</td>\n",
       "      <td>0.527349</td>\n",
       "      <td>0.884333</td>\n",
       "      <td>0.832430</td>\n",
       "      <td>1.606034</td>\n",
       "      <td>0.000033</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>first_split</td>\n",
       "      <td>45</td>\n",
       "      <td>0.490179</td>\n",
       "      <td>0.302081</td>\n",
       "      <td>0.217333</td>\n",
       "      <td>0.113566</td>\n",
       "      <td>1.603393</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>unmet_v15_label_background</td>\n",
       "      <td>first_split</td>\n",
       "      <td>46</td>\n",
       "      <td>0.460143</td>\n",
       "      <td>0.232878</td>\n",
       "      <td>0.137500</td>\n",
       "      <td>0.099797</td>\n",
       "      <td>1.609312</td>\n",
       "      <td>0.000031</td>\n",
       "      <td>0.765632</td>\n",
       "      <td>0.520533</td>\n",
       "      <td>30.4</td>\n",
       "      <td>3.401548</td>\n",
       "      <td>0.520533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>third_split</td>\n",
       "      <td>42</td>\n",
       "      <td>0.378876</td>\n",
       "      <td>0.369601</td>\n",
       "      <td>0.029214</td>\n",
       "      <td>0.040805</td>\n",
       "      <td>1.594417</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>third_split</td>\n",
       "      <td>43</td>\n",
       "      <td>0.317388</td>\n",
       "      <td>0.331970</td>\n",
       "      <td>0.136667</td>\n",
       "      <td>0.247646</td>\n",
       "      <td>1.599372</td>\n",
       "      <td>0.000024</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>third_split</td>\n",
       "      <td>44</td>\n",
       "      <td>0.307843</td>\n",
       "      <td>0.333511</td>\n",
       "      <td>0.004214</td>\n",
       "      <td>0.007098</td>\n",
       "      <td>1.598013</td>\n",
       "      <td>0.000023</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>third_split</td>\n",
       "      <td>45</td>\n",
       "      <td>0.406060</td>\n",
       "      <td>0.433236</td>\n",
       "      <td>0.111667</td>\n",
       "      <td>0.141725</td>\n",
       "      <td>1.585429</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>unmet_v15_label_only</td>\n",
       "      <td>third_split</td>\n",
       "      <td>46</td>\n",
       "      <td>0.445007</td>\n",
       "      <td>0.451763</td>\n",
       "      <td>0.103810</td>\n",
       "      <td>0.113658</td>\n",
       "      <td>1.594555</td>\n",
       "      <td>0.000022</td>\n",
       "      <td>0.756654</td>\n",
       "      <td>0.666267</td>\n",
       "      <td>43.2</td>\n",
       "      <td>1.752569</td>\n",
       "      <td>0.666267</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>90 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                  dataset_name        split  seed  32b_lens_unnormalized  \\\n",
       "0   unmet_v15_label_background  first_split    42               0.696950   \n",
       "1   unmet_v15_label_background  first_split    43               0.588212   \n",
       "2   unmet_v15_label_background  first_split    44               0.753200   \n",
       "3   unmet_v15_label_background  first_split    45               0.490179   \n",
       "4   unmet_v15_label_background  first_split    46               0.460143   \n",
       "..                         ...          ...   ...                    ...   \n",
       "85        unmet_v15_label_only  third_split    42               0.378876   \n",
       "86        unmet_v15_label_only  third_split    43               0.317388   \n",
       "87        unmet_v15_label_only  third_split    44               0.307843   \n",
       "88        unmet_v15_label_only  third_split    45               0.406060   \n",
       "89        unmet_v15_label_only  third_split    46               0.445007   \n",
       "\n",
       "    32b_lens  7b_lens_unnormalized   7b_lens  pad_test  mmd_test  mdm_test  \\\n",
       "0   0.603722              0.520333  0.480682  1.604899  0.000032  0.765632   \n",
       "1   0.403744              0.127762  0.106592  1.600475  0.000031  0.765632   \n",
       "2   0.527349              0.884333  0.832430  1.606034  0.000033  0.765632   \n",
       "3   0.302081              0.217333  0.113566  1.603393  0.000031  0.765632   \n",
       "4   0.232878              0.137500  0.099797  1.609312  0.000031  0.765632   \n",
       "..       ...                   ...       ...       ...       ...       ...   \n",
       "85  0.369601              0.029214  0.040805  1.594417  0.000020  0.756654   \n",
       "86  0.331970              0.136667  0.247646  1.599372  0.000024  0.756654   \n",
       "87  0.333511              0.004214  0.007098  1.598013  0.000023  0.756654   \n",
       "88  0.433236              0.111667  0.141725  1.585429  0.000020  0.756654   \n",
       "89  0.451763              0.103810  0.113658  1.594555  0.000022  0.756654   \n",
       "\n",
       "    test_mrr@5  test_acc  test_loss  test_mrr  \n",
       "0     0.520533      30.4   3.401548  0.520533  \n",
       "1     0.520533      30.4   3.401548  0.520533  \n",
       "2     0.520533      30.4   3.401548  0.520533  \n",
       "3     0.520533      30.4   3.401548  0.520533  \n",
       "4     0.520533      30.4   3.401548  0.520533  \n",
       "..         ...       ...        ...       ...  \n",
       "85    0.666267      43.2   1.752569  0.666267  \n",
       "86    0.666267      43.2   1.752569  0.666267  \n",
       "87    0.666267      43.2   1.752569  0.666267  \n",
       "88    0.666267      43.2   1.752569  0.666267  \n",
       "89    0.666267      43.2   1.752569  0.666267  \n",
       "\n",
       "[90 rows x 14 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "MERGE_KEYS = ['seed', 'dataset_name', 'split']\n",
    "\n",
    "# Per-seed proxy metrics + test scores, plus LENS scores from the Qwen2.5-VL 7B / 32B result files.\n",
    "result_df = pd.read_csv(\"results/image_e5-v_all_seeds.csv\")\n",
    "lens_result_7b_df = pd.read_csv(\"results/image_lens-qwen2.5-vl-7b-instruct_all_seeds.csv\")\n",
    "lens_result_32b_df = pd.read_csv(\"results/image_lens-qwen2.5-vl-32b-instruct_all_seeds.csv\")\n",
    "\n",
    "# Right joins keep every row of result_df (and its key order). validate='one_to_many' raises\n",
    "# MergeError if a LENS file repeats a (seed, dataset_name, split) key, which would otherwise\n",
    "# silently duplicate rows and bias every downstream average.\n",
    "result_with_7b_df = pd.merge(lens_result_7b_df, result_df, how='right', on=MERGE_KEYS, validate='one_to_many')\n",
    "result_merged_df = pd.merge(lens_result_32b_df, result_with_7b_df, how='right', on=MERGE_KEYS, validate='one_to_many')\n",
    "result_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>split</th>\n",
       "      <th>average_test_mrr</th>\n",
       "      <th>top3_average_test_mrr</th>\n",
       "      <th>delta</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>56.38</td>\n",
       "      <td>-0.77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>56.21</td>\n",
       "      <td>0.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>60.21</td>\n",
       "      <td>2.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>53.42</td>\n",
       "      <td>-3.73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>55.99</td>\n",
       "      <td>0.22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>56.98</td>\n",
       "      <td>-0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>57.31</td>\n",
       "      <td>0.16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>55.33</td>\n",
       "      <td>-0.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>59.05</td>\n",
       "      <td>1.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>56.73</td>\n",
       "      <td>-0.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>55.78</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>58.65</td>\n",
       "      <td>0.98</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>57.02</td>\n",
       "      <td>-0.13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>54.81</td>\n",
       "      <td>-0.96</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>52.16</td>\n",
       "      <td>-5.51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>57.29</td>\n",
       "      <td>0.13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>54.49</td>\n",
       "      <td>-1.28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>64.13</td>\n",
       "      <td>6.46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>first_split</td>\n",
       "      <td>57.15</td>\n",
       "      <td>55.90</td>\n",
       "      <td>-1.26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>second_split</td>\n",
       "      <td>55.77</td>\n",
       "      <td>55.38</td>\n",
       "      <td>-0.38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>third_split</td>\n",
       "      <td>57.67</td>\n",
       "      <td>58.22</td>\n",
       "      <td>0.55</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   metric         split  average_test_mrr  \\\n",
       "0                32b_lens   first_split             57.15   \n",
       "1                32b_lens  second_split             55.77   \n",
       "2                32b_lens   third_split             57.67   \n",
       "3   32b_lens_unnormalized   first_split             57.15   \n",
       "4   32b_lens_unnormalized  second_split             55.77   \n",
       "5   32b_lens_unnormalized   third_split             57.67   \n",
       "6                 7b_lens   first_split             57.15   \n",
       "7                 7b_lens  second_split             55.77   \n",
       "8                 7b_lens   third_split             57.67   \n",
       "9    7b_lens_unnormalized   first_split             57.15   \n",
       "10   7b_lens_unnormalized  second_split             55.77   \n",
       "11   7b_lens_unnormalized   third_split             57.67   \n",
       "12               mdm_test   first_split             57.15   \n",
       "13               mdm_test  second_split             55.77   \n",
       "14               mdm_test   third_split             57.67   \n",
       "15               mmd_test   first_split             57.15   \n",
       "16               mmd_test  second_split             55.77   \n",
       "17               mmd_test   third_split             57.67   \n",
       "18               pad_test   first_split             57.15   \n",
       "19               pad_test  second_split             55.77   \n",
       "20               pad_test   third_split             57.67   \n",
       "\n",
       "    top3_average_test_mrr  delta  \n",
       "0                   56.38  -0.77  \n",
       "1                   56.21   0.44  \n",
       "2                   60.21   2.54  \n",
       "3                   53.42  -3.73  \n",
       "4                   55.99   0.22  \n",
       "5                   56.98  -0.69  \n",
       "6                   57.31   0.16  \n",
       "7                   55.33  -0.44  \n",
       "8                   59.05   1.38  \n",
       "9                   56.73  -0.43  \n",
       "10                  55.78   0.01  \n",
       "11                  58.65   0.98  \n",
       "12                  57.02  -0.13  \n",
       "13                  54.81  -0.96  \n",
       "14                  52.16  -5.51  \n",
       "15                  57.29   0.13  \n",
       "16                  54.49  -1.28  \n",
       "17                  64.13   6.46  \n",
       "18                  55.90  -1.26  \n",
       "19                  55.38  -0.38  \n",
       "20                  58.22   0.55  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For each seed and split, rank the candidate datasets by each metric and compare the\n",
    "# mean test MRR of the top-3 ranked datasets with the mean over all datasets in the split.\n",
    "top3_averages = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    seed_df = result_merged_df[result_merged_df['seed'] == seed]\n",
    "    for split in result_merged_df.split.unique():\n",
    "        split_df = seed_df[seed_df['split'] == split]\n",
    "        average_test_mrr = split_df['test_mrr'].mean()\n",
    "        for metric in ['pad_test', 'mmd_test', 'mdm_test', '7b_lens_unnormalized', '7b_lens', '32b_lens_unnormalized', '32b_lens']:\n",
    "            # Rank only rows that have a score for this metric: sort_values puts NaN last,\n",
    "            # so tail(3) would otherwise select unscored rows as the top 3.\n",
    "            ranked_df = split_df.dropna(subset=[metric]).sort_values(metric)\n",
    "            # For pad/mmd the 3 smallest values are taken (lower is treated as better);\n",
    "            # for every other metric the 3 largest values are taken.\n",
    "            if metric in ['pad_test', 'mmd_test']:\n",
    "                top3_df = ranked_df.head(3)\n",
    "            else:\n",
    "                top3_df = ranked_df.tail(3)\n",
    "            top3_average_test_mrr = top3_df['test_mrr'].mean()\n",
    "            top3_averages.append({\n",
    "                \"seed\": seed,\n",
    "                \"split\": split,\n",
    "                \"metric\": metric,\n",
    "                \"average_test_mrr\": average_test_mrr,\n",
    "                \"top3_average_test_mrr\": top3_average_test_mrr,\n",
    "                \"delta\": top3_average_test_mrr - average_test_mrr,\n",
    "            })\n",
    "top3_average_test_acc_mean_df = pd.DataFrame(top3_averages)\n",
    "result_cols = ['average_test_mrr', 'top3_average_test_mrr', 'delta']\n",
    "# Average over seeds per (metric, split), then report the MRR values as percentages.\n",
    "top3_average_test_acc_mean_df_agg = (\n",
    "    top3_average_test_acc_mean_df\n",
    "    .groupby([\"metric\", \"split\"])[result_cols]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    ")\n",
    "top3_average_test_acc_mean_df_agg[result_cols] = top3_average_test_acc_mean_df_agg[result_cols] * 100\n",
    "top3_average_test_acc_mean_df_agg.round(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>delta</th>\n",
       "      <th>top3_average_test_mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>0.74</td>\n",
       "      <td>57.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>-1.40</td>\n",
       "      <td>55.46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>0.37</td>\n",
       "      <td>57.23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>0.19</td>\n",
       "      <td>57.05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>mdm_test</td>\n",
       "      <td>-2.20</td>\n",
       "      <td>54.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>mmd_test</td>\n",
       "      <td>1.77</td>\n",
       "      <td>58.63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>pad_test</td>\n",
       "      <td>-0.36</td>\n",
       "      <td>56.50</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  metric  delta  top3_average_test_mrr\n",
       "0               32b_lens   0.74                  57.60\n",
       "1  32b_lens_unnormalized  -1.40                  55.46\n",
       "2                7b_lens   0.37                  57.23\n",
       "3   7b_lens_unnormalized   0.19                  57.05\n",
       "4               mdm_test  -2.20                  54.66\n",
       "5               mmd_test   1.77                  58.63\n",
       "6               pad_test  -0.36                  56.50"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collapse the per-split table to one row per metric by averaging\n",
    "# \"delta\" and \"top3_average_test_mrr\" over the splits.\n",
    "avg_across_splits = (\n",
    "    top3_average_test_acc_mean_df_agg\n",
    "    .groupby(\"metric\")[[\"delta\", \"top3_average_test_mrr\"]]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    "    .round(2)\n",
    ")\n",
    "avg_across_splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Web Navigation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute PAD, MMD, and MDM Correlations with Success Rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the precomputed synthetic and real embeddings.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load embedding files you generated yourself.\n",
    "import pickle\n",
    "\n",
    "with open(\"../embeddings/web_synthetic_data_embs.pkl\", \"rb\") as emb_file:\n",
    "    synth_data_map = pickle.load(emb_file)\n",
    "with open(\"../embeddings/web_real_data_embs.pkl\", \"rb\") as emb_file:\n",
    "    real_data_map = pickle.load(emb_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr, spearmanr\n",
    "import numpy as np\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "# Correlate the distribution metrics (PAD, MMD, MDM) of each synthetic partition with the\n",
    "# success rate obtained with that partition, separately for every website and seed.\n",
    "SEEDS = list(range(42, 47))\n",
    "N_REAL_SAMPLES = 20  # real examples subsampled per website and seed\n",
    "METRICS = ['pad', 'mmd', 'mdm']\n",
    "CORR_FNS = {'pearson': pearsonr, 'spearman': spearmanr}\n",
    "\n",
    "\n",
    "def try_metric(fn, *args):\n",
    "    \"\"\"Best-effort metric computation: warn and return None instead of aborting the sweep.\"\"\"\n",
    "    try:\n",
    "        return fn(*args)\n",
    "    except Exception as exc:\n",
    "        warnings.warn(f\"{getattr(fn, '__name__', fn)} failed: {exc!r}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "def safe_corr(x, y, fn):\n",
    "    \"\"\"Correlation coefficient of x vs y over the pairs whose x value is not None.\n",
    "\n",
    "    Returns np.nan when fewer than 2 pairs remain or `fn` raises\n",
    "    (scipy itself returns nan for constant input).\n",
    "    \"\"\"\n",
    "    pairs = [(a, b) for a, b in zip(x, y) if a is not None]\n",
    "    if len(pairs) < 2:\n",
    "        return np.nan\n",
    "    xs, ys = zip(*pairs)\n",
    "    try:\n",
    "        return fn(xs, ys)[0]\n",
    "    except Exception:\n",
    "        return np.nan\n",
    "\n",
    "\n",
    "def correlate_rows(rows):\n",
    "    \"\"\"Map '<pearson|spearman>_<metric>_vs_success_rate' to its coefficient over `rows`.\"\"\"\n",
    "    success_rates = [r['success_rate'] for r in rows]\n",
    "    return {\n",
    "        f'{corr_name}_{metric}_vs_success_rate': safe_corr([r[metric] for r in rows], success_rates, fn)\n",
    "        for metric in METRICS\n",
    "        for corr_name, fn in CORR_FNS.items()\n",
    "    }\n",
    "\n",
    "\n",
    "def nanmean_or_nan(values):\n",
    "    \"\"\"Mean of the non-NaN values, or np.nan (without a RuntimeWarning) if there are none.\"\"\"\n",
    "    valid = [v for v in values if not np.isnan(v)]\n",
    "    return float(np.mean(valid)) if valid else np.nan\n",
    "\n",
    "\n",
    "corr_rows = []  # one row per website: coefficients averaged over seeds\n",
    "raw_rows = []  # one row per (website, seed, partition): metrics and success rate\n",
    "corr_by_seed = {seed: [] for seed in SEEDS}  # seed -> coefficient dicts, one per website\n",
    "\n",
    "for website in synth_data_map.keys():\n",
    "    # Success rate per partition, looked up once per website. Partitions without a\n",
    "    # result are skipped before any (expensive) metric is computed for them.\n",
    "    website_success_rates = {}\n",
    "    for partition in sorted(synth_data_map[website].keys()):\n",
    "        sr_row = results[(results['source_model'] == website) & (results['partition'] == partition)]\n",
    "        if len(sr_row) > 0:\n",
    "            website_success_rates[partition] = sr_row['success_rate'].values[0]\n",
    "    n_real = len(real_data_map[website])\n",
    "    seed_corrs = []\n",
    "    for seed in SEEDS:\n",
    "        # Re-seed so the real-data subsample drawn for each seed is reproducible.\n",
    "        random.seed(seed)\n",
    "        # Cap the sample size: random.sample raises ValueError if asked for more items than exist.\n",
    "        indices = random.sample(range(n_real), min(N_REAL_SAMPLES, n_real))\n",
    "        real_embs = [real_data_map[website][i] for i in indices]\n",
    "        seed_rows = []\n",
    "        for partition, success_rate in website_success_rates.items():\n",
    "            synth_embs = synth_data_map[website][partition]\n",
    "            seed_rows.append({\n",
    "                'website': website,\n",
    "                'seed': seed,\n",
    "                'partition': partition,\n",
    "                'pad': try_metric(compute_pad, synth_embs, real_embs),\n",
    "                'mmd': try_metric(compute_mmd, synth_embs, real_embs),\n",
    "                'mdm': try_metric(compute_mdm, synth_embs),\n",
    "                'success_rate': success_rate,\n",
    "            })\n",
    "        raw_rows.extend(seed_rows)\n",
    "        # A correlation needs at least 2 points.\n",
    "        if len(seed_rows) < 2:\n",
    "            continue\n",
    "        corrs = correlate_rows(seed_rows)\n",
    "        seed_corrs.append(corrs)\n",
    "        corr_by_seed[seed].append(corrs)\n",
    "    if not seed_corrs:\n",
    "        continue\n",
    "    # Average each coefficient over seeds, ignoring seeds where it was undefined (NaN).\n",
    "    corr_rows.append({\n",
    "        'website': website,\n",
    "        **{key: nanmean_or_nan([c[key] for c in seed_corrs]) for key in seed_corrs[0]},\n",
    "    })\n",
    "\n",
    "corr_df = pd.DataFrame(corr_rows)\n",
    "raw_df = pd.DataFrame(raw_rows)\n",
    "\n",
    "print(\"Average correlation across websites for each seed:\")\n",
    "for seed, website_corrs in corr_by_seed.items():\n",
    "    if not website_corrs:\n",
    "        print(f\"Seed {seed}: No valid website correlations.\")\n",
    "        continue\n",
    "    print(f\"Seed {seed}:\")\n",
    "    for metric in METRICS:\n",
    "        for corr_name in CORR_FNS:\n",
    "            avg = nanmean_or_nan([c[f'{corr_name}_{metric}_vs_success_rate'] for c in website_corrs])\n",
    "            print(f\"  {corr_name.capitalize()} {metric} vs success_rate: {avg:.6f}\")\n",
    "\n",
    "corr_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Top-3 Ranked Average Success Rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>website</th>\n",
       "      <th>seed</th>\n",
       "      <th>partition</th>\n",
       "      <th>32b_lens_unnormalized</th>\n",
       "      <th>32b_lens</th>\n",
       "      <th>success_rate</th>\n",
       "      <th>7b_lens_unnormalized</th>\n",
       "      <th>7b_lens</th>\n",
       "      <th>pad</th>\n",
       "      <th>mmd</th>\n",
       "      <th>mdm</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>allrecipes</td>\n",
       "      <td>42</td>\n",
       "      <td>4</td>\n",
       "      <td>0.393143</td>\n",
       "      <td>0.258009</td>\n",
       "      <td>0.377778</td>\n",
       "      <td>0.491004</td>\n",
       "      <td>0.468937</td>\n",
       "      <td>-0.986128</td>\n",
       "      <td>-0.000125</td>\n",
       "      <td>0.507638</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>allrecipes</td>\n",
       "      <td>43</td>\n",
       "      <td>3</td>\n",
       "      <td>0.401160</td>\n",
       "      <td>0.251210</td>\n",
       "      <td>0.355556</td>\n",
       "      <td>0.450392</td>\n",
       "      <td>0.424337</td>\n",
       "      <td>-1.110462</td>\n",
       "      <td>-0.000159</td>\n",
       "      <td>0.387447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>allrecipes</td>\n",
       "      <td>43</td>\n",
       "      <td>0</td>\n",
       "      <td>0.370992</td>\n",
       "      <td>0.262381</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.486483</td>\n",
       "      <td>0.478754</td>\n",
       "      <td>-1.112758</td>\n",
       "      <td>-0.000162</td>\n",
       "      <td>0.499306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>allrecipes</td>\n",
       "      <td>43</td>\n",
       "      <td>1</td>\n",
       "      <td>0.433831</td>\n",
       "      <td>0.298433</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.491953</td>\n",
       "      <td>0.486093</td>\n",
       "      <td>-1.108301</td>\n",
       "      <td>-0.000166</td>\n",
       "      <td>0.544760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>allrecipes</td>\n",
       "      <td>46</td>\n",
       "      <td>0</td>\n",
       "      <td>0.292300</td>\n",
       "      <td>0.121690</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.445705</td>\n",
       "      <td>0.423682</td>\n",
       "      <td>-1.017123</td>\n",
       "      <td>-0.000133</td>\n",
       "      <td>0.499306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>320</th>\n",
       "      <td>wolframalpha</td>\n",
       "      <td>45</td>\n",
       "      <td>4</td>\n",
       "      <td>0.341667</td>\n",
       "      <td>0.405584</td>\n",
       "      <td>0.195652</td>\n",
       "      <td>0.386905</td>\n",
       "      <td>0.411252</td>\n",
       "      <td>-0.518806</td>\n",
       "      <td>-0.000070</td>\n",
       "      <td>0.555600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>321</th>\n",
       "      <td>wolframalpha</td>\n",
       "      <td>45</td>\n",
       "      <td>2</td>\n",
       "      <td>0.294066</td>\n",
       "      <td>0.321478</td>\n",
       "      <td>0.282609</td>\n",
       "      <td>0.346645</td>\n",
       "      <td>0.384175</td>\n",
       "      <td>-0.483980</td>\n",
       "      <td>-0.000051</td>\n",
       "      <td>0.553595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322</th>\n",
       "      <td>wolframalpha</td>\n",
       "      <td>45</td>\n",
       "      <td>3</td>\n",
       "      <td>0.372096</td>\n",
       "      <td>0.328104</td>\n",
       "      <td>0.326087</td>\n",
       "      <td>0.397114</td>\n",
       "      <td>0.413972</td>\n",
       "      <td>-0.483227</td>\n",
       "      <td>-0.000065</td>\n",
       "      <td>0.470202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>323</th>\n",
       "      <td>wolframalpha</td>\n",
       "      <td>42</td>\n",
       "      <td>0</td>\n",
       "      <td>0.445689</td>\n",
       "      <td>0.437453</td>\n",
       "      <td>0.326087</td>\n",
       "      <td>0.432756</td>\n",
       "      <td>0.435058</td>\n",
       "      <td>-0.500790</td>\n",
       "      <td>-0.000054</td>\n",
       "      <td>0.493407</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>324</th>\n",
       "      <td>wolframalpha</td>\n",
       "      <td>44</td>\n",
       "      <td>2</td>\n",
       "      <td>0.308460</td>\n",
       "      <td>0.357338</td>\n",
       "      <td>0.282609</td>\n",
       "      <td>0.371807</td>\n",
       "      <td>0.410843</td>\n",
       "      <td>-0.465463</td>\n",
       "      <td>-0.000050</td>\n",
       "      <td>0.553595</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>325 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          website  seed  partition  32b_lens_unnormalized  32b_lens  \\\n",
       "0      allrecipes    42          4               0.393143  0.258009   \n",
       "1      allrecipes    43          3               0.401160  0.251210   \n",
       "2      allrecipes    43          0               0.370992  0.262381   \n",
       "3      allrecipes    43          1               0.433831  0.298433   \n",
       "4      allrecipes    46          0               0.292300  0.121690   \n",
       "..            ...   ...        ...                    ...       ...   \n",
       "320  wolframalpha    45          4               0.341667  0.405584   \n",
       "321  wolframalpha    45          2               0.294066  0.321478   \n",
       "322  wolframalpha    45          3               0.372096  0.328104   \n",
       "323  wolframalpha    42          0               0.445689  0.437453   \n",
       "324  wolframalpha    44          2               0.308460  0.357338   \n",
       "\n",
       "     success_rate  7b_lens_unnormalized   7b_lens       pad       mmd  \\\n",
       "0        0.377778              0.491004  0.468937 -0.986128 -0.000125   \n",
       "1        0.355556              0.450392  0.424337 -1.110462 -0.000159   \n",
       "2        0.333333              0.486483  0.478754 -1.112758 -0.000162   \n",
       "3        0.400000              0.491953  0.486093 -1.108301 -0.000166   \n",
       "4        0.333333              0.445705  0.423682 -1.017123 -0.000133   \n",
       "..            ...                   ...       ...       ...       ...   \n",
       "320      0.195652              0.386905  0.411252 -0.518806 -0.000070   \n",
       "321      0.282609              0.346645  0.384175 -0.483980 -0.000051   \n",
       "322      0.326087              0.397114  0.413972 -0.483227 -0.000065   \n",
       "323      0.326087              0.432756  0.435058 -0.500790 -0.000054   \n",
       "324      0.282609              0.371807  0.410843 -0.465463 -0.000050   \n",
       "\n",
       "          mdm  \n",
       "0    0.507638  \n",
       "1    0.387447  \n",
       "2    0.499306  \n",
       "3    0.544760  \n",
       "4    0.499306  \n",
       "..        ...  \n",
       "320  0.555600  \n",
       "321  0.553595  \n",
       "322  0.470202  \n",
       "323  0.493407  \n",
       "324  0.553595  \n",
       "\n",
       "[325 rows x 11 columns]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_df = pd.read_csv(\"results/webvoyager_qte-qwen2-7b-instruct_all_seeds.csv\").drop(columns=['Unnamed: 0'])\n",
    "# pad and mmd correlate negatively with accuracy; negate them so that, like the\n",
    "# LENS scores, a larger value ranks higher in the top-3 selection further below.\n",
    "result_df['pad'] = -result_df['pad']\n",
    "result_df['mmd'] = -result_df['mmd']\n",
    "\n",
    "lens_result_7b_df = (\n",
    "    pd.read_csv(\"results/webvoyager_lens-qwen2.5-7b-instruct_all_seeds.csv\")\n",
    "    .drop(columns=['Unnamed: 0'])\n",
    "    .rename(columns={'lens_unnormalized': \"7b_lens_unnormalized\", \"lens\": \"7b_lens\"})\n",
    ")\n",
    "lens_result_32b_df = (\n",
    "    pd.read_csv(\"results/webvoyager_lens-qwen2.5-32b-instruct_all_seeds.csv\")\n",
    "    .drop(columns=['Unnamed: 0'])\n",
    "    .rename(columns={'lens_unnormalized': \"32b_lens_unnormalized\", \"lens\": \"32b_lens\"})\n",
    ")\n",
    "\n",
    "# Left joins keep every row of the left (LENS) frame; unmatched rows get NaN in the right frame's columns.\n",
    "# NOTE(review): success_rate is a float join key - a representation mismatch between CSVs would silently drop matches.\n",
    "result_merged_df = pd.merge(lens_result_7b_df, result_df, how='left', on=['seed', 'partition', 'success_rate', 'website'])\n",
    "result_merged_df = pd.merge(lens_result_32b_df, result_merged_df, how='left', on=['seed', 'partition', 'success_rate', 'website'])\n",
    "result_merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>average_success_rate</th>\n",
       "      <th>top3_average_success_rate</th>\n",
       "      <th>delta</th>\n",
       "      <th>percentage_improvement</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>32b_lens</td>\n",
       "      <td>25.84</td>\n",
       "      <td>26.32</td>\n",
       "      <td>0.49</td>\n",
       "      <td>2.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>32b_lens_unnormalized</td>\n",
       "      <td>25.84</td>\n",
       "      <td>26.02</td>\n",
       "      <td>0.18</td>\n",
       "      <td>0.99</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7b_lens</td>\n",
       "      <td>25.84</td>\n",
       "      <td>26.54</td>\n",
       "      <td>0.71</td>\n",
       "      <td>2.36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7b_lens_unnormalized</td>\n",
       "      <td>25.84</td>\n",
       "      <td>26.28</td>\n",
       "      <td>0.45</td>\n",
       "      <td>1.72</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>mdm</td>\n",
       "      <td>25.84</td>\n",
       "      <td>25.76</td>\n",
       "      <td>-0.08</td>\n",
       "      <td>0.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>mmd</td>\n",
       "      <td>25.84</td>\n",
       "      <td>26.55</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.55</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>pad</td>\n",
       "      <td>25.84</td>\n",
       "      <td>25.71</td>\n",
       "      <td>-0.13</td>\n",
       "      <td>-0.50</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  metric  average_success_rate  top3_average_success_rate  \\\n",
       "0               32b_lens                 25.84                      26.32   \n",
       "1  32b_lens_unnormalized                 25.84                      26.02   \n",
       "2                7b_lens                 25.84                      26.54   \n",
       "3   7b_lens_unnormalized                 25.84                      26.28   \n",
       "4                    mdm                 25.84                      25.76   \n",
       "5                    mmd                 25.84                      26.55   \n",
       "6                    pad                 25.84                      25.71   \n",
       "\n",
       "   delta  percentage_improvement  \n",
       "0   0.49                    2.62  \n",
       "1   0.18                    0.99  \n",
       "2   0.71                    2.36  \n",
       "3   0.45                    1.72  \n",
       "4  -0.08                    0.10  \n",
       "5   0.72                    0.55  \n",
       "6  -0.13                   -0.50  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For each seed and website, rank the partitions by each metric and compare the mean\n",
    "# success rate of the top-3 ranked partitions with the mean over all partitions.\n",
    "# pad and mmd were negated when the results were loaded, so the 3 largest values are taken for every metric.\n",
    "top3_averages = []\n",
    "for seed in [42, 43, 44, 45, 46]:\n",
    "    seed_df = result_merged_df[result_merged_df['seed'] == seed]\n",
    "    for website in result_merged_df.website.unique():\n",
    "        website_df = seed_df[seed_df['website'] == website]\n",
    "        average_success_rate = website_df['success_rate'].mean()\n",
    "        for metric in ['pad', 'mmd', 'mdm', '7b_lens_unnormalized', '7b_lens', '32b_lens_unnormalized', '32b_lens']:\n",
    "            # The left merges can leave a metric NaN; sort_values puts NaN last, so without\n",
    "            # dropna, tail(3) would select unscored partitions as the top 3.\n",
    "            ranked_df = website_df.dropna(subset=[metric]).sort_values(metric)\n",
    "            top3_average_success_rate = ranked_df.tail(3)['success_rate'].mean()\n",
    "            delta = top3_average_success_rate - average_success_rate\n",
    "            # The relative improvement is undefined for a zero baseline; record NaN instead of dividing by 0.\n",
    "            percentage_improvement = (delta / average_success_rate) * 100 if average_success_rate != 0 else np.nan\n",
    "            top3_averages.append({\n",
    "                \"seed\": seed,\n",
    "                \"website\": website,\n",
    "                \"metric\": metric,\n",
    "                \"average_success_rate\": average_success_rate,\n",
    "                \"top3_average_success_rate\": top3_average_success_rate,\n",
    "                \"delta\": delta,\n",
    "                \"percentage_improvement\": percentage_improvement\n",
    "            })\n",
    "top3_average_test_acc_mean_df = pd.DataFrame(top3_averages)\n",
    "result_cols = ['average_success_rate', 'top3_average_success_rate', 'delta', 'percentage_improvement']\n",
    "# Average over seeds per (metric, website).\n",
    "top3_average_test_acc_mean_df_agg = (\n",
    "    top3_average_test_acc_mean_df\n",
    "    .groupby([\"metric\", \"website\"])[result_cols]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    ")\n",
    "# Report rates as percentages; percentage_improvement is already a percentage.\n",
    "rate_cols = ['average_success_rate', 'top3_average_success_rate', 'delta']\n",
    "top3_average_test_acc_mean_df_agg[rate_cols] = top3_average_test_acc_mean_df_agg[rate_cols] * 100\n",
    "# Average over websites per metric.\n",
    "top3_average_test_acc_mean_df_agg.groupby(\"metric\").mean(numeric_only=True).reset_index().round(2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "data-synthesis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
