{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "098a4a72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "================================================================================\n",
      "MODEL: Gemma2-2B L12\n",
      "================================================================================\n",
      "\n",
      "============ Gemma2-2B L12 | PAIRWISE Kendall tau-b: AutoInterp vs SteeringBase ============\n",
      "- Overall: n=30 pairs=435  tau_b=0.2184  (C=265, D=170, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.2533  (SE≈0.1718, 95% boot CI=(-0.04000000000000001, 0.5466666666666666))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=0.7333 (pairs=15) | perm p=0.0570, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7259, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.4667 (pairs=15) | perm p=0.2689, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=-0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.3333 (pairs=15) | perm p=0.4669, perm 95% CI=[-0.7333, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.1333  (SE≈0.1606, 95% boot CI=(-0.16666666666666666, 0.4000000000000001))\n",
      "    • Architecture=1: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4733, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4867, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.6000 (pairs=10) | perm p=0.2240, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8114, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4829, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = 0.1933\n",
      "\n",
      "============ Gemma2-2B L12 | PAIRWISE Kendall tau-b: AutoInterp vs Lift ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.1402  (C=187, D=248, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=-0.2800  (SE≈0.1718, 95% boot CI=(-0.6266666666666667, -0.013333333333333332))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=-0.8667 (pairs=15) | perm p=0.0188, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.6000]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=-0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=ReLU: n=6 tau_b=-0.4667 (pairs=15) | perm p=0.2575, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=-0.0000  (SE≈0.1549, 95% boot CI=(-0.26666666666666666, 0.26666666666666666))\n",
      "    • Architecture=1: n=5 tau_b=0.6000 (pairs=10) | perm p=0.2254, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4611, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4681, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8134, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = -0.1400\n",
      "\n",
      "============ Gemma2-2B L12 | PAIRWISE Kendall tau-b: AutoInterp vs Delta ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.0989  (C=196, D=239, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=-0.3067  (SE≈0.1809, 95% boot CI=(-0.6266666666666667, -0.013333333333333319))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=-0.8667 (pairs=15) | perm p=0.0188, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.3333 (pairs=15) | perm p=0.4695, perm 95% CI=[-0.6000, 0.6000]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=-0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=ReLU: n=6 tau_b=-0.4667 (pairs=15) | perm p=0.2575, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.2000 (pairs=15) | perm p=0.7213, perm 95% CI=[-0.6033, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.0667  (SE≈0.1116, 95% boot CI=(-0.13333333333333333, 0.23333333333333336))\n",
      "    • Architecture=1: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4703, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8142, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4681, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8134, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = -0.1200\n",
      "\n",
      "================================================================================\n",
      "MODEL: Qwen2.5-3B L17\n",
      "================================================================================\n",
      "\n",
      "============ Qwen2.5-3B L17 | PAIRWISE Kendall tau-b: AutoInterp vs SteeringBase ============\n",
      "- Overall: n=30 pairs=435  tau_b=0.4575  (C=317, D=118, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.5200  (SE≈0.2004, 95% boot CI=(0.11999999999999997, 0.8400000000000001))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=1.0000 (pairs=15) | perm p=0.0032, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7277, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.4667 (pairs=15) | perm p=0.2849, perm 95% CI=[-0.6000, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=0.7333 (pairs=15) | perm p=0.0572, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.6000 (pairs=15) | perm p=0.1396, perm 95% CI=[-0.7333, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=-0.1333  (SE≈0.1978, 95% boot CI=(-0.46666666666666673, 0.2333333333333333))\n",
      "    • Architecture=1: n=5 tau_b=0.6000 (pairs=10) | perm p=0.2302, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8216, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=-0.8000 (pairs=10) | perm p=0.0818, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8156, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8170, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4775, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = 0.1933\n",
      "\n",
      "============ Qwen2.5-3B L17 | PAIRWISE Kendall tau-b: AutoInterp vs Lift ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.1448  (C=186, D=249, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=-0.2533  (SE≈0.1083, 95% boot CI=(-0.4666666666666666, -0.06666666666666668))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=-0.3333 (pairs=15) | perm p=0.4603, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7301, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=-0.6000 (pairs=15) | perm p=0.1478, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7211, perm 95% CI=[-0.7333, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=-0.0333  (SE≈0.1308, 95% boot CI=(-0.26666666666666666, 0.16666666666666666))\n",
      "    • Architecture=1: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8114, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8066, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8136, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2354, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8080, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = -0.1433\n",
      "\n",
      "============ Qwen2.5-3B L17 | PAIRWISE Kendall tau-b: AutoInterp vs Delta ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.0345  (C=210, D=225, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.0133  (SE≈0.0998, 95% boot CI=(-0.17333333333333334, 0.17333333333333334))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.6000]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.2000 (pairs=15) | perm p=0.7111, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=-0.3333 (pairs=15) | perm p=0.4729, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.2000 (pairs=15) | perm p=0.7173, perm 95% CI=[-0.7333, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=-0.1667  (SE≈0.1406, 95% boot CI=(-0.4000000000000001, 0.06666666666666668))\n",
      "    • Architecture=1: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4687, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8136, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=-0.4000 (pairs=10) | perm p=0.4887, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2354, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8080, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = -0.0767\n",
      "\n",
      "================================================================================\n",
      "MODEL: Gemma2-9B L20\n",
      "================================================================================\n",
      "\n",
      "============ Gemma2-9B L20 | PAIRWISE Kendall tau-b: AutoInterp vs SteeringBase ============\n",
      "- Overall: n=30 pairs=435  tau_b=0.3057  (C=284, D=151, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.3600  (SE≈0.1543, 95% boot CI=(0.06666666666666667, 0.6))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=0.3333 (pairs=15) | perm p=0.4775, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7163, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.4667 (pairs=15) | perm p=0.2615, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=0.4667 (pairs=15) | perm p=0.2801, perm 95% CI=[-0.7333, 0.6000]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.7333 (pairs=15) | perm p=0.0570, perm 95% CI=[-0.7333, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.3333  (SE≈0.1687, 95% boot CI=(0.03333333333333335, 0.6333333333333333))\n",
      "    • Architecture=1: n=5 tau_b=0.8000 (pairs=10) | perm p=0.0776, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=0.8000 (pairs=10) | perm p=0.0866, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8214, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4739, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8110, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = 0.3467\n",
      "\n",
      "============ Gemma2-9B L20 | PAIRWISE Kendall tau-b: AutoInterp vs Lift ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.0483  (C=207, D=228, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.0133  (SE≈0.0533, 95% boot CI=(-0.09333333333333335, 0.06666666666666667))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7153, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=-0.0667  (SE≈0.2044, 95% boot CI=(-0.43333333333333335, 0.3))\n",
      "    • Architecture=1: n=5 tau_b=-0.2000 (pairs=10) | perm p=0.8200, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2390, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.6000 (pairs=10) | perm p=0.2354, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4893, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2356, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = -0.0267\n",
      "\n",
      "============ Gemma2-9B L20 | PAIRWISE Kendall tau-b: AutoInterp vs Delta ============\n",
      "- Overall: n=30 pairs=435  tau_b=-0.0069  (C=216, D=219, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.0133  (SE≈0.0533, 95% boot CI=(-0.09333333333333335, 0.06666666666666667))\n",
      "    • Sparsity=Batch TopK: n=6 tau_b=-0.2000 (pairs=15) | perm p=0.7079, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Gated: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=Jump ReLU: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.7333]\n",
      "    • Sparsity=ReLU: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.7333, 0.7333]\n",
      "    • Sparsity=TopK: n=6 tau_b=0.0667 (pairs=15) | perm p=1.0000, perm 95% CI=[-0.6000, 0.7333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.0000  (SE≈0.2066, 95% boot CI=(-0.39999999999999997, 0.36666666666666664))\n",
      "    • Architecture=1: n=5 tau_b=0.0000 (pairs=10) | perm p=1.0000, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=2: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2390, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=3: n=5 tau_b=0.6000 (pairs=10) | perm p=0.2200, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=4: n=5 tau_b=0.4000 (pairs=10) | perm p=0.4893, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=5: n=5 tau_b=0.2000 (pairs=10) | perm p=0.8190, perm 95% CI=[-0.8000, 0.8000]\n",
      "    • Architecture=6: n=5 tau_b=-0.6000 (pairs=10) | perm p=0.2356, perm 95% CI=[-0.8000, 0.8000]\n",
      "- Overall Psi (mean of the two axes A/B) = 0.0067\n",
      "\n",
      "============ Combined (All Models) | PAIRWISE Kendall tau-b: AutoInterp vs SteeringBase ============\n",
      "- Overall: n=90 pairs=4005  tau_b=0.2979  BCa95%=[0.1590, 0.4191]  (C=2599, D=1406, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.2575  (SE≈0.1163, 95% boot CI=(0.02222222222222222, 0.39607843137254906))\n",
      "    • Sparsity=Batch TopK: n=18 tau_b=0.3203 (pairs=153) | perm p=0.0712, perm 95% CI=[-0.3464, 0.3464]\n",
      "    • Sparsity=Gated: n=18 tau_b=-0.2026 (pairs=153) | perm p=0.2577, perm 95% CI=[-0.3464, 0.3333]\n",
      "    • Sparsity=Jump ReLU: n=18 tau_b=0.4248 (pairs=153) | perm p=0.0160, perm 95% CI=[-0.3333, 0.3333]\n",
      "    • Sparsity=ReLU: n=18 tau_b=0.3595 (pairs=153) | perm p=0.0392, perm 95% CI=[-0.3333, 0.3333]\n",
      "    • Sparsity=TopK: n=18 tau_b=0.3856 (pairs=153) | perm p=0.0272, perm 95% CI=[-0.3333, 0.3333]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.1651  (SE≈0.1112, 95% boot CI=(-0.02857142857142857, 0.35873015873015873))\n",
      "    • Architecture=1: n=15 tau_b=0.5429 (pairs=105) | perm p=0.0034, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=2: n=15 tau_b=0.3524 (pairs=105) | perm p=0.0740, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=3: n=15 tau_b=0.1810 (pairs=105) | perm p=0.3821, perm 95% CI=[-0.3905, 0.3714]\n",
      "    • Architecture=4: n=15 tau_b=0.1810 (pairs=105) | perm p=0.3673, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=5: n=15 tau_b=-0.2190 (pairs=105) | perm p=0.2837, perm 95% CI=[-0.3905, 0.3714]\n",
      "    • Architecture=6: n=15 tau_b=-0.0476 (pairs=105) | perm p=0.8484, perm 95% CI=[-0.3905, 0.3714]\n",
      "- Psi on axis: Within-Model: psi=0.3272  (SE≈0.0698, 95% boot CI=(0.21839080459770113, 0.4574712643678161))\n",
      "    • Within=Gemma2-2B L12: n=30 tau_b=0.2184 (pairs=435) | perm p=0.0980, perm 95% CI=[-0.2598, 0.2552]\n",
      "    • Within=Qwen2.5-3B L17: n=30 tau_b=0.4575 (pairs=435) | perm p=0.0008, perm 95% CI=[-0.2506, 0.2506]\n",
      "    • Within=Gemma2-9B L20: n=30 tau_b=0.3057 (pairs=435) | perm p=0.0166, perm 95% CI=[-0.2506, 0.2461]\n",
      "- Overall Psi (mean of A/B/C axes) = 0.2499\n",
      "\n",
      "============ Combined (All Models) | PAIRWISE Kendall tau-b: AutoInterp vs Lift ============\n",
      "- Overall: n=90 pairs=4005  tau_b=-0.0692  BCa95%=[-0.2019, 0.0666]  (C=1864, D=2141, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=-0.0719  (SE≈0.0781, 95% boot CI=(-0.20784313725490194, 0.06143790849673203))\n",
      "    • Sparsity=Batch TopK: n=18 tau_b=-0.2288 (pairs=153) | perm p=0.2004, perm 95% CI=[-0.3333, 0.3464]\n",
      "    • Sparsity=Gated: n=18 tau_b=0.0327 (pairs=153) | perm p=0.8792, perm 95% CI=[-0.3464, 0.3333]\n",
      "    • Sparsity=Jump ReLU: n=18 tau_b=-0.0065 (pairs=153) | perm p=1.0000, perm 95% CI=[-0.3464, 0.3203]\n",
      "    • Sparsity=ReLU: n=18 tau_b=-0.2810 (pairs=153) | perm p=0.1096, perm 95% CI=[-0.3333, 0.3464]\n",
      "    • Sparsity=TopK: n=18 tau_b=0.1242 (pairs=153) | perm p=0.4985, perm 95% CI=[-0.3464, 0.3337]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.0127  (SE≈0.0457, 95% boot CI=(-0.07619047619047618, 0.0888888888888889))\n",
      "    • Architecture=1: n=15 tau_b=0.0476 (pairs=105) | perm p=0.8466, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=2: n=15 tau_b=0.1619 (pairs=105) | perm p=0.4437, perm 95% CI=[-0.3905, 0.3714]\n",
      "    • Architecture=3: n=15 tau_b=0.0095 (pairs=105) | perm p=1.0000, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=4: n=15 tau_b=0.0476 (pairs=105) | perm p=0.8452, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=5: n=15 tau_b=-0.1810 (pairs=105) | perm p=0.3797, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=6: n=15 tau_b=-0.0095 (pairs=105) | perm p=1.0000, perm 95% CI=[-0.3714, 0.3905]\n",
      "- Psi on axis: Within-Model: psi=-0.1111  (SE≈0.0314, 95% boot CI=(-0.14482758620689656, -0.048275862068965524))\n",
      "    • Within=Gemma2-2B L12: n=30 tau_b=-0.1402 (pairs=435) | perm p=0.2911, perm 95% CI=[-0.2552, 0.2598]\n",
      "    • Within=Qwen2.5-3B L17: n=30 tau_b=-0.1448 (pairs=435) | perm p=0.2781, perm 95% CI=[-0.2507, 0.2460]\n",
      "    • Within=Gemma2-9B L20: n=30 tau_b=-0.0483 (pairs=435) | perm p=0.7157, perm 95% CI=[-0.2506, 0.2552]\n",
      "- Overall Psi (mean of A/B/C axes) = -0.0568\n",
      "\n",
      "============ Combined (All Models) | PAIRWISE Kendall tau-b: AutoInterp vs Delta ============\n",
      "- Overall: n=90 pairs=4005  tau_b=-0.0097  BCa95%=[-0.1408, 0.1233]  (C=1983, D=2022, T_x=0, T_y=0, T_xy=0)\n",
      "- Psi on axis: Sparsity-within-Architecture: psi=0.0065  (SE≈0.0891, 95% boot CI=(-0.13725490196078433, 0.15294117647058822))\n",
      "    • Sparsity=Batch TopK: n=18 tau_b=-0.1765 (pairs=153) | perm p=0.3255, perm 95% CI=[-0.3333, 0.3464]\n",
      "    • Sparsity=Gated: n=18 tau_b=0.0588 (pairs=153) | perm p=0.7592, perm 95% CI=[-0.3333, 0.3333]\n",
      "    • Sparsity=Jump ReLU: n=18 tau_b=0.0458 (pairs=153) | perm p=0.8260, perm 95% CI=[-0.3464, 0.3203]\n",
      "    • Sparsity=ReLU: n=18 tau_b=-0.1895 (pairs=153) | perm p=0.2825, perm 95% CI=[-0.3333, 0.3464]\n",
      "    • Sparsity=TopK: n=18 tau_b=0.2941 (pairs=153) | perm p=0.1006, perm 95% CI=[-0.3464, 0.3464]\n",
      "- Psi on axis: Architecture-at-matched-Slot: psi=0.0444  (SE≈0.0493, 95% boot CI=(-0.044444444444444446, 0.13015873015873017))\n",
      "    • Architecture=1: n=15 tau_b=0.1048 (pairs=105) | perm p=0.6269, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=2: n=15 tau_b=0.2190 (pairs=105) | perm p=0.2921, perm 95% CI=[-0.3714, 0.3905]\n",
      "    • Architecture=3: n=15 tau_b=0.0667 (pairs=105) | perm p=0.7656, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=4: n=15 tau_b=-0.0095 (pairs=105) | perm p=1.0000, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=5: n=15 tau_b=-0.1429 (pairs=105) | perm p=0.4887, perm 95% CI=[-0.3714, 0.3714]\n",
      "    • Architecture=6: n=15 tau_b=0.0286 (pairs=105) | perm p=0.9210, perm 95% CI=[-0.3714, 0.3714]\n",
      "- Psi on axis: Within-Model: psi=-0.0467  (SE≈0.0272, 95% boot CI=(-0.09885057471264368, -0.006896551724137931))\n",
      "    • Within=Gemma2-2B L12: n=30 tau_b=-0.0989 (pairs=435) | perm p=0.4571, perm 95% CI=[-0.2598, 0.2552]\n",
      "    • Within=Qwen2.5-3B L17: n=30 tau_b=-0.0345 (pairs=435) | perm p=0.8050, perm 95% CI=[-0.2506, 0.2460]\n",
      "    • Within=Gemma2-9B L20: n=30 tau_b=-0.0069 (pairs=435) | perm p=0.9668, perm 95% CI=[-0.2506, 0.2552]\n",
      "- Overall Psi (mean of A/B/C axes) = 0.0014\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Pairwise (Kendall's tau-b) + Granulated Kendall (Psi) analysis\n",
    "between interpretability (AutoInterp) and utility (Steering) for three models:\n",
    "\n",
    "  1) Gemma 2 (2B)  - Layer 12\n",
    "  2) Qwen 2.5 (3B) - Layer 17\n",
    "  3) Gemma 2 (9B)  - Layer 20\n",
    "\n",
    "What this script does:\n",
    "- For each model, constructs the 30 SAEs from the provided MAPS (trainer->sparsity).\n",
    "- Loads AutoInterp scores from the model-specific AUTOINTERP_DIR (robust keying).\n",
    "- Loads Steering Base/After from model-specific EVAL_ROOT/ENTROPY_ROOT, with robust\n",
    "  directory resolution across alias naming like: jump_relu vs jumprelu, relu vs\n",
    "  standard_april_update, topk vs top_k, batch_topk vs batch_top_k, etc.\n",
    "- Runs Kendall’s tau-b and Psi along three axes:\n",
    "    (A) Sparsity-within-Architecture\n",
    "    (B) Architecture-at-matched-Slot\n",
    "    (C) Within-Model (new)\n",
    "- For each *group* tau-b, performs a permutation test (shuffle AutoInterp vs target)\n",
    "  to obtain a p-value and an empirical 95% CI.\n",
    "- For each axis-level Psi, performs a group-level bootstrap (resample groups) to get\n",
    "  a 95% CI in addition to the SE.\n",
    "- For the Combined dataset's Overall tau-b, reports a BCa bootstrap 95% CI.\n",
    "\n",
    "Utility target:\n",
    "- Primary analyses use Lift/Delta derived from SteeringAfter and SteeringBase:\n",
    "    Lift  = (After - Base) / max(Base, eps)\n",
    "    Delta = After - Base\n",
    "- We also keep SteeringBase as a control comparison.\n",
    "\n",
    "Notes:\n",
    "- No preview table is printed.\n",
    "- If fewer than 2 valid SAEs remain for a comparison, that comparison is skipped.\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "import json\n",
    "import itertools\n",
    "import random\n",
    "from dataclasses import dataclass\n",
    "from typing import Dict, Any, List, Optional, Tuple, Callable\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from statistics import NormalDist\n",
    "\n",
    "# =============================================================================\n",
    "# ----- Repro -----\n",
    "# =============================================================================\n",
    "\n",
    "# One fixed seed applied to both Python's `random` module and NumPy's legacy\n",
    "# global RNG, so code that draws from those global generators is repeatable\n",
    "# when the cell is re-run from a fresh kernel.\n",
    "# NOTE(review): generators created independently (e.g. np.random.default_rng())\n",
    "# are not affected by these calls -- confirm the resampling code further down\n",
    "# relies on the global RNGs or passes this seed explicitly.\n",
    "GLOBAL_RANDOM_SEED = 42\n",
    "random.seed(GLOBAL_RANDOM_SEED)\n",
    "np.random.seed(GLOBAL_RANDOM_SEED)\n",
    "\n",
    "# Standard-normal quantile / CDF helpers built on the stdlib statistics.NormalDist.\n",
    "_NORM = NormalDist(mu=0.0, sigma=1.0)\n",
    "\n",
    "\n",
    "def norm_ppf(p: float) -> float:\n",
    "    \"\"\"Return the standard-normal quantile (inverse CDF) at probability ``p``.\n",
    "\n",
    "    ``p`` is first clipped to [1e-12, 1 - 1e-12], so the endpoints 0 and 1 map\n",
    "    to large finite values instead of making ``NormalDist.inv_cdf`` raise.\n",
    "    \"\"\"\n",
    "    lo, hi = 1e-12, 1 - 1e-12\n",
    "    clipped = float(np.clip(p, lo, hi))\n",
    "    return _NORM.inv_cdf(clipped)\n",
    "\n",
    "\n",
    "def norm_cdf(z: float) -> float:\n",
    "    \"\"\"Return the standard-normal cumulative probability at ``z``.\"\"\"\n",
    "    cumulative = _NORM.cdf(z)\n",
    "    return cumulative\n",
    "\n",
    "# =============================================================================\n",
    "# ----- Model configs (paths + maps) -----\n",
    "# =============================================================================\n",
    "\n",
    "@dataclass\n",
    "class ModelConfig:\n",
    "    \"\"\"Per-model bundle of result locations and trainer->sparsity maps.\n",
    "\n",
    "    One instance describes a single model/layer setup; the concrete entries are\n",
    "    listed in ``MODELS``. This class only stores the values -- how the path\n",
    "    fields are combined into actual file locations is decided by the loaders.\n",
    "    \"\"\"\n",
    "    # Display label for the model/layer, e.g. \"Gemma2-2B L12\".\n",
    "    name: str\n",
    "    # AutoInterp locations: a base results directory plus a layer sub-directory\n",
    "    # name such as \"l12\" (presumably joined by the loader -- TODO confirm).\n",
    "    autointerp_base: str\n",
    "    layer_dir: str\n",
    "    # Steering locations: root of the SAE eval results and root of the\n",
    "    # entropy-score results for this model/layer.\n",
    "    eval_root: str\n",
    "    entropy_root: str\n",
    "    # trainer -> sparsity mapping per architecture:\n",
    "    # {architecture key: {trainer index: sparsity value}}\n",
    "    maps: Dict[str, Dict[int, int]]\n",
    "\n",
    "# ----- Sparsity mapping from Gemma2-2b L12 -----\n",
    "# Outer key: architecture name. Inner dict: trainer index -> sparsity value for\n",
    "# that trainer (units are not stated here; presumably an L0 / target-k count --\n",
    "# TODO confirm). Six trainers per architecture, 30 SAEs in total.\n",
    "# Trainer indices are neither contiguous nor shared across architectures here:\n",
    "# jumprelu uses 4-9 and the two top-k variants use 0-3 plus 6-7, and the\n",
    "# sparsity values are not ordered by trainer index.\n",
    "MAPS_GEMMA2_2B_L12 = {\n",
    "    \"gated\": {\n",
    "        0: 948, 1: 547, 2: 340, 3: 148, 4: 78, 5: 49,\n",
    "    },\n",
    "    \"jumprelu\": {\n",
    "        4: 52, 5: 330, 6: 538, 7: 779, 8: 83, 9: 165,\n",
    "    },\n",
    "    \"standard_april_update\": {\n",
    "        0: 733, 1: 507, 2: 309, 3: 156, 4: 99, 5: 54,\n",
    "    },\n",
    "    \"batch_topk\": {\n",
    "        0: 50, 1: 320, 2: 520, 3: 820, 6: 80, 7: 160,\n",
    "    },\n",
    "    \"topk\": {\n",
    "        0: 50, 1: 320, 2: 520, 3: 820, 6: 80, 7: 160,\n",
    "    },\n",
    "}\n",
    "\n",
    "# ----- Sparsity mapping from Qwen-3b L17 -----\n",
    "# Outer key: architecture name. Inner dict: trainer index -> sparsity value for\n",
    "# that trainer (units are not stated here; presumably an L0 / target-k count --\n",
    "# TODO confirm). Six trainers per architecture.\n",
    "# Trainer index ranges differ per architecture: gated / standard_april_update\n",
    "# use 0-5, jumprelu and topk use 6-11, batch_topk uses 12-17.\n",
    "MAPS_QWEN2_5_3B_L17 = {\n",
    "    \"gated\": {\n",
    "        0: 999, 1: 565, 2: 338, 3: 141, 4: 72, 5: 46,\n",
    "    },\n",
    "    \"jumprelu\": {\n",
    "        6: 51, 7: 82, 8: 166, 9: 323, 10: 494, 11: 754,\n",
    "    },\n",
    "    \"standard_april_update\": {\n",
    "        0: 762, 1: 523, 2: 321, 3: 167, 4: 108, 5: 61,\n",
    "    },\n",
    "    \"batch_topk\": {\n",
    "        12: 50, 13: 80, 14: 160, 15: 320, 16: 520, 17: 820,\n",
    "    },\n",
    "    \"topk\": {\n",
    "        6: 50, 7: 80, 8: 160, 9: 320, 10: 520, 11: 820,\n",
    "    },\n",
    "}\n",
    "\n",
    "# ----- Sparsity mapping from Gemma2-9b L20 -----\n",
    "# Outer key: architecture name. Inner dict: trainer index -> sparsity value for\n",
    "# that trainer (units are not stated here; presumably an L0 / target-k count --\n",
    "# TODO confirm). Six trainers per architecture.\n",
    "# Trainer index ranges: gated / standard_april_update use 0-5, batch_topk and\n",
    "# topk both use 6-11, jumprelu uses 12-17.\n",
    "# NOTE(review): the \"standard_april_update\" values below are identical to the\n",
    "# ones in MAPS_QWEN2_5_3B_L17 (762, 523, 321, 167, 108, 61). Measured\n",
    "# sparsities matching exactly across two different models is unusual -- verify\n",
    "# this is not a copy-paste from the Qwen map.\n",
    "MAPS_GEMMA2_9B_L20 = {\n",
    "    \"gated\": {\n",
    "        0: 1070, 1: 573, 2: 360, 3: 160, 4: 85, 5: 53,\n",
    "    },\n",
    "    \"jumprelu\": {\n",
    "        12: 51, 13: 82, 14: 164, 15: 327, 16: 529, 17: 786,\n",
    "    },\n",
    "    \"standard_april_update\": {\n",
    "        0: 762, 1: 523, 2: 321, 3: 167, 4: 108, 5: 61,\n",
    "    },\n",
    "    \"batch_topk\": {\n",
    "        6: 50, 7: 80, 8: 160, 9: 320, 10: 520, 11: 820,\n",
    "    },\n",
    "    \"topk\": {\n",
    "        6: 50, 7: 80, 8: 160, 9: 320, 10: 520, 11: 820,\n",
    "    },\n",
    "}\n",
    "\n",
    "# Root of the sae4steer workspace. The original absolute path is kept as the\n",
    "# default, so behavior is unchanged on the original machine; set the\n",
    "# SAE4STEER_ROOT environment variable to run the notebook from another checkout.\n",
    "SAE4STEER_ROOT = os.environ.get(\"SAE4STEER_ROOT\", \"/home/dslabra5/sae4steer\").rstrip(\"/\")\n",
    "\n",
    "# Shared path prefixes under the workspace root.\n",
    "_AUTOINTERP_RESULTS_ROOT = f\"{SAE4STEER_ROOT}/SAEBench/sae_bench/custom_saes/eval_results\"\n",
    "_STEERING_CACHE_ROOT = f\"{SAE4STEER_ROOT}/saes-are-good-for-steering/cache\"\n",
    "_SAE_EVAL_ROOT = f\"{_STEERING_CACHE_ROOT}/results_sae_eval_openai\"\n",
    "_ENTROPY_SCORE_ROOT = f\"{_STEERING_CACHE_ROOT}/results_entropy_score/amp10_top_1\"\n",
    "\n",
    "# The three model/layer setups analysed by this notebook, in report order.\n",
    "# Note that eval_root uses the instruct-model directory names while\n",
    "# entropy_root uses the base-model directory names.\n",
    "MODELS: List[ModelConfig] = [\n",
    "    ModelConfig(\n",
    "        name=\"Gemma2-2B L12\",\n",
    "        autointerp_base=f\"{_AUTOINTERP_RESULTS_ROOT}/gemma2_2b\",\n",
    "        layer_dir=\"l12\",\n",
    "        eval_root=f\"{_SAE_EVAL_ROOT}/gemma2_2b_it/layer12\",\n",
    "        entropy_root=f\"{_ENTROPY_SCORE_ROOT}/gemma2_2b/layer12\",\n",
    "        maps=MAPS_GEMMA2_2B_L12,\n",
    "    ),\n",
    "    ModelConfig(\n",
    "        name=\"Qwen2.5-3B L17\",\n",
    "        autointerp_base=f\"{_AUTOINTERP_RESULTS_ROOT}/qwen2.5_3b\",\n",
    "        layer_dir=\"l17\",\n",
    "        eval_root=f\"{_SAE_EVAL_ROOT}/Qwen2.5-3B-Instruct/layer17\",\n",
    "        entropy_root=f\"{_ENTROPY_SCORE_ROOT}/Qwen2.5-3B/layer17\",\n",
    "        maps=MAPS_QWEN2_5_3B_L17,\n",
    "    ),\n",
    "    ModelConfig(\n",
    "        name=\"Gemma2-9B L20\",\n",
    "        autointerp_base=f\"{_AUTOINTERP_RESULTS_ROOT}/gemma2_9b\",\n",
    "        layer_dir=\"l20\",\n",
    "        eval_root=f\"{_SAE_EVAL_ROOT}/gemma2_9b_it/layer20\",\n",
    "        entropy_root=f\"{_ENTROPY_SCORE_ROOT}/gemma2_9b/layer20\",\n",
    "        maps=MAPS_GEMMA2_9B_L20,\n",
    "    ),\n",
    "]\n",
    "\n",
    "# =============================================================================\n",
    "# Common helpers\n",
    "# =============================================================================\n",
    "\n",
    "def list_jsons_recursive(root_dir: str) -> List[str]:\n",
    "    \"\"\"Return the paths of all ``*.json`` files under ``root_dir``, recursively.\n",
    "\n",
    "    Args:\n",
    "        root_dir: Directory to search.\n",
    "\n",
    "    Returns:\n",
    "        Paths (each prefixed with ``root_dir``) in sorted order, or an empty\n",
    "        list when ``root_dir`` is not an existing directory.\n",
    "    \"\"\"\n",
    "    if not os.path.isdir(root_dir):\n",
    "        return []\n",
    "    json_paths = [\n",
    "        os.path.join(dpath, fname)\n",
    "        for dpath, _dnames, fnames in os.walk(root_dir)\n",
    "        for fname in fnames\n",
    "        if fname.endswith(\".json\")\n",
    "    ]\n",
    "    # os.walk yields entries in filesystem-dependent order; sort so that any\n",
    "    # order-sensitive consumer (e.g. first/last match wins) is reproducible\n",
    "    # across machines and runs.\n",
    "    json_paths.sort()\n",
    "    return json_paths\n",
    "\n",
    "def safe_get(d: Dict[str, Any], path: List[str], default=None):\n",
    "    \"\"\"Walk nested dicts along ``path`` and return the value found there.\n",
    "\n",
    "    Returns ``default`` as soon as the current node is not a dict or lacks the\n",
    "    next key. An empty ``path`` returns ``d`` itself.\n",
    "    \"\"\"\n",
    "    node = d\n",
    "    for key in path:\n",
    "        if isinstance(node, dict) and key in node:\n",
    "            node = node[key]\n",
    "        else:\n",
    "            return default\n",
    "    return node\n",
    "\n",
    "# =============================================================================\n",
    "# Architecture canonicalization and labels\n",
    "# =============================================================================\n",
    "\n",
    "# Human-readable display labels keyed by architecture name. The first five\n",
    "# keys are the canonical architecture keys used in the MAPS_* dictionaries;\n",
    "# \"relu\" is an extra alias that shares the \"ReLU\" label with\n",
    "# \"standard_april_update\".\n",
    "ARCH_LABELS = {\n",
    "    \"batch_topk\": \"Batch TopK\",\n",
    "    \"gated\": \"Gated\",\n",
    "    \"jumprelu\": \"Jump ReLU\",\n",
    "    \"standard_april_update\": \"ReLU\",\n",
    "    \"topk\": \"TopK\",\n",
    "    # aliases (for display fallback if needed)\n",
    "    \"relu\": \"ReLU\",\n",
    "}\n",
    "\n",
    "_ARCH_CANON_MAP = {\n",
    "    \"batch_topk\": \"batch_topk\",\n",
    "    \"batch_top_k\": \"batch_topk\",\n",
    "    \"gated\": \"gated\",\n",
    "    \"gated_top_k\": \"gated\",\n",
    "    \"gated_top_k_resid_post\": \"gated\",\n",
    "    \"jumprelu\": \"jumprelu\",\n",
    "    \"jump_relu\": \"jumprelu\",\n",
    "    \"relu\": \"standard_april_update\",\n",
    "    \"standard_april_update\": \"standard_april_update\",\n",
    "    \"topk\": \"topk\",\n",
    "    \"top_k\": \"topk\",\n",
    "}\n",
    "\n",
    "def canonicalize_arch(name: str) -> Optional[str]:\n",
    "    if not name:\n",
    "        return None\n",
    "    s = name.lower().replace(\"-\", \"_\")\n",
    "    s = re.sub(r\"__+\", \"_\", s)\n",
    "    if s in _ARCH_CANON_MAP:\n",
    "        return _ARCH_CANON_MAP[s]\n",
    "    if \"batch_top_k\" in s or \"batch_topk\" in s:\n",
    "        return \"batch_topk\"\n",
    "    if \"gated\" in s:\n",
    "        return \"gated\"\n",
    "    if \"jump\" in s and \"relu\" in s:\n",
    "        return \"jumprelu\"\n",
    "    if \"standard_april_update\" in s or re.search(r\"(?:^|_)relu(?:_|$)\", s):\n",
    "        return \"standard_april_update\"\n",
    "    if \"top_k\" in s or re.search(r\"(?:^|_)topk(?:_|$)\", s):\n",
    "        return \"topk\"\n",
    "    return None\n",
    "\n",
    "# =============================================================================\n",
    "# AutoInterp loading (robust keying)\n",
    "# =============================================================================\n",
    "\n",
    "def extract_trainer_idx_from_release_or_name(release_id: str, filename: str) -> Optional[str]:\n",
    "    for text in (release_id or \"\", filename or \"\"):\n",
    "        m = re.search(r\"trainer[_-]?(\\d+)\", text)\n",
    "        if m:\n",
    "            return m.group(1)\n",
    "    return None\n",
    "\n",
    "def key_from_json_robust(data: Dict[str, Any], filepath: str) -> Optional[str]:\n",
    "    arch_cfg = safe_get(data, [\"sae_cfg_dict\", \"architecture\"])\n",
    "    arch_canon = canonicalize_arch(str(arch_cfg) if arch_cfg is not None else \"\")\n",
    "    release_id = data.get(\"sae_lens_release_id\", \"\") or data.get(\"sae_lens_id\", \"\")\n",
    "    idx = extract_trainer_idx_from_release_or_name(release_id, os.path.basename(filepath))\n",
    "\n",
    "    if arch_canon is None:\n",
    "        fname = os.path.basename(filepath).lower()\n",
    "        fname_noext = os.path.splitext(fname)[0]\n",
    "        tokens = re.split(r\"[_\\-]+\", fname_noext)\n",
    "        for w in range(5, 0, -1):\n",
    "            for i in range(0, max(1, len(tokens) - w + 1)):\n",
    "                part = \"_\".join(tokens[i:i+w])\n",
    "                arch_canon = canonicalize_arch(part)\n",
    "                if arch_canon:\n",
    "                    break\n",
    "            if arch_canon:\n",
    "                break\n",
    "\n",
    "    if arch_canon is None or idx is None:\n",
    "        return None\n",
    "    return f\"{arch_canon}_trainer{idx}\"\n",
    "\n",
    "def load_autointerp(autointerp_dir: str) -> Dict[str, float]:\n",
    "    out: Dict[str, float] = {}\n",
    "    for p in list_jsons_recursive(autointerp_dir):\n",
    "        try:\n",
    "            with open(p, \"r\", encoding=\"utf-8\") as f:\n",
    "                data = json.load(f)\n",
    "        except Exception:\n",
    "            continue\n",
    "        key = key_from_json_robust(data, p)\n",
    "        if key is None:\n",
    "            continue\n",
    "        score = safe_get(data, [\"eval_result_metrics\", \"autointerp\", \"autointerp_score\"])\n",
    "        if score is not None:\n",
    "            out[key] = float(score)\n",
    "    return out\n",
    "\n",
    "# =============================================================================\n",
    "# Build SAE list from MAPS (trainer->sparsity)\n",
    "# =============================================================================\n",
    "\n",
    "ARCH_ORDER = [\"batch_topk\", \"gated\", \"jumprelu\", \"standard_april_update\", \"topk\"]\n",
    "\n",
    "def build_sae_list_from_maps(maps: Dict[str, Dict[int, int]]) -> List[Tuple[str, str]]:\n",
    "    \"\"\"\n",
    "    Return list of (display_name, autointerp_key_prefix) of length 30 for a model.\n",
    "\n",
    "    display_name: \"{arch}_trainer{idx}_l0_{sparsity}\"\n",
    "    autointerp_key_prefix: \"{arch}_trainer{idx}\"\n",
    "    The list is ordered by ARCH_ORDER and within each architecture by ascending sparsity.\n",
    "    \"\"\"\n",
    "    rows: List[Tuple[str, str]] = []\n",
    "    for arch in ARCH_ORDER:\n",
    "        if arch not in maps:\n",
    "            continue\n",
    "        pairs = sorted(maps[arch].items(), key=lambda kv: kv[1])  # by sparsity ascending\n",
    "        for trainer_idx, sparsity in pairs:\n",
    "            disp = f\"{arch}_trainer{trainer_idx}_l0_{int(sparsity)}\"\n",
    "            key_prefix = f\"{arch}_trainer{trainer_idx}\"\n",
    "            rows.append((disp, key_prefix))\n",
    "    return rows\n",
    "\n",
    "# =============================================================================\n",
    "# Steering loading and selection (per-model roots)\n",
    "# =============================================================================\n",
    "\n",
    "def safe_load_json(path: str) -> Optional[Dict[str, Any]]:\n",
    "    if not os.path.isfile(path):\n",
    "        return None\n",
    "    try:\n",
    "        with open(path, \"r\") as f:\n",
    "            return json.load(f)\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def get_overall_from_eval_item(item: Dict[str, Any]) -> Optional[float]:\n",
    "    try:\n",
    "        return float(item[\"holdout\"][\"mean\"][\"overall\"])\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def compute_base_stats(eval_data: Dict[str, Any]) -> Tuple[int, float]:\n",
    "    num_features = len(eval_data)\n",
    "    total = 0.0\n",
    "    for v in eval_data.values():\n",
    "        s = get_overall_from_eval_item(v)\n",
    "        if s is not None:\n",
    "            total += s\n",
    "    avg = (total / num_features) if num_features > 0 else 0.0\n",
    "    return num_features, avg\n",
    "\n",
    "def build_delta_map(entropy_data: Dict[str, Any]) -> Dict[str, float]:\n",
    "    out: Dict[str, float] = {}\n",
    "    for f, info in entropy_data.items():\n",
    "        if isinstance(info, dict) and (\"delta_confidence\" in info):\n",
    "            try:\n",
    "                out[f] = float(info[\"delta_confidence\"])\n",
    "            except Exception:\n",
    "                pass\n",
    "    return out\n",
    "\n",
    "def avg_for_selected(eval_data: Dict[str, Any], selected_ids: List[str]) -> Tuple[int, float]:\n",
    "    vals: List[float] = []\n",
    "    for fid in selected_ids:\n",
    "        s = get_overall_from_eval_item(eval_data.get(fid, {}))\n",
    "        if s is not None:\n",
    "            vals.append(s)\n",
    "    n = len(vals)\n",
    "    avg = (sum(vals) / n) if n > 0 else 0.0\n",
    "    return n, avg\n",
    "\n",
    "# Selection levels (default: TOP-K only; enable quantiles if desired)\n",
    "USE_QUANTILES = False\n",
    "QUANTILES: List[float] = [0.99, 0.95, 0.90, 0.80]\n",
    "USE_TOPK = True\n",
    "TOPK_LIST: List[int] = [1, 2, 5, 10]\n",
    "\n",
    "def selection_levels_for_up(delta_map, quantiles, topk_list, use_quantiles, use_topk):\n",
    "    levels: List[Tuple[str, List[str]]] = []\n",
    "    if not delta_map:\n",
    "        return levels\n",
    "    fids = list(delta_map.keys())\n",
    "    vals = np.array([delta_map[f] for f in fids], dtype=float)\n",
    "    if use_quantiles and quantiles:\n",
    "        qs = sorted(set(quantiles), reverse=True)\n",
    "        for q in qs:\n",
    "            cut = float(np.quantile(vals, q))\n",
    "            sel = [fid for fid in fids if delta_map[fid] >= cut]\n",
    "            levels.append((f\"Q>= {q:.2f}\", sel))\n",
    "    if use_topk and topk_list:\n",
    "        sorted_up = sorted(fids, key=lambda x: delta_map[x], reverse=True)\n",
    "        for k in sorted(set(topk_list)):\n",
    "            sel = sorted_up[:max(0, min(k, len(sorted_up)))]\n",
    "            levels.append((f\"TOPK_UP= {k}\", sel))\n",
    "    return levels\n",
    "\n",
    "def selection_levels_for_down(delta_map, quantiles, topk_list, use_quantiles, use_topk):\n",
    "    levels: List[Tuple[str, List[str]]] = []\n",
    "    if not delta_map:\n",
    "        return levels\n",
    "    fids = list(delta_map.keys())\n",
    "    vals = np.array([delta_map[f] for f in fids], dtype=float)\n",
    "    if use_quantiles and quantiles:\n",
    "        qs = sorted(set(quantiles), reverse=True)\n",
    "        for q in qs:\n",
    "            low_q = 1.0 - q\n",
    "            low_cut = float(np.quantile(vals, low_q))\n",
    "            sel = [fid for fid in fids if delta_map[fid] <= low_cut]\n",
    "            levels.append((f\"Q<= {low_q:.2f}\", sel))\n",
    "    if use_topk and topk_list:\n",
    "        sorted_down = sorted(fids, key=lambda x: delta_map[x])  # most negative first\n",
    "        for k in sorted(set(topk_list)):\n",
    "            sel = sorted_down[:max(0, min(k, len(sorted_down)))]\n",
    "            levels.append((f\"TOPK_DOWN= {k}\", sel))\n",
    "    return levels\n",
    "\n",
    "# Aliases used for steering directory naming within each model root\n",
    "STEERING_ARCH_DIR_ALIASES: Dict[str, List[str]] = {\n",
    "    \"batch_topk\": [\"batch_topk\", \"batch_top_k\"],\n",
    "    \"gated\": [\"gated\", \"gated_top_k\"],\n",
    "    \"jumprelu\": [\"jumprelu\", \"jump_relu\"],\n",
    "    \"standard_april_update\": [\"standard_april_update\", \"relu\"],\n",
    "    \"topk\": [\"topk\", \"top_k\"],\n",
    "}\n",
    "\n",
    "def steering_dir_candidates(arch_key: str, sparsity_val: int) -> List[str]:\n",
    "    aliases = STEERING_ARCH_DIR_ALIASES.get(arch_key, [arch_key])\n",
    "    return [f\"{alias}_{int(sparsity_val)}\" for alias in aliases]\n",
    "\n",
    "def resolve_steering_dir(eval_root: str, arch_key: str, sparsity_val: int) -> Optional[str]:\n",
    "    candidates = steering_dir_candidates(arch_key, sparsity_val)\n",
    "    for cand in candidates:\n",
    "        eval_path = os.path.join(eval_root, cand, \"eval.json\")\n",
    "        if os.path.isfile(eval_path):\n",
    "            return cand\n",
    "    # fall back to the first candidate (so caller can still try entropy path etc.)\n",
    "    return candidates[0] if candidates else None\n",
    "\n",
    "def compute_steering_scores_for_sae_dirname(\n",
    "    eval_root: str,\n",
    "    entropy_root: str,\n",
    "    sae_dirname: str\n",
    ") -> Tuple[Optional[float], Optional[float], Optional[str]]:\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "        base_avg, best_avg_after, best_label\n",
    "    \"\"\"\n",
    "    if sae_dirname is None:\n",
    "        return None, None, None\n",
    "    eval_path = os.path.join(eval_root, sae_dirname, \"eval.json\")\n",
    "    entropy_path = os.path.join(entropy_root, sae_dirname, \"output_scores_plus_top20_kconf1.json\")\n",
    "\n",
    "    eval_data = safe_load_json(eval_path)\n",
    "    if eval_data is None:\n",
    "        return None, None, None\n",
    "\n",
    "    _, base_avg = compute_base_stats(eval_data)\n",
    "\n",
    "    entropy_data = safe_load_json(entropy_path)\n",
    "    if entropy_data is None:\n",
    "        return base_avg, None, None\n",
    "\n",
    "    delta_map = build_delta_map(entropy_data)\n",
    "    up_levels = selection_levels_for_up(delta_map, QUANTILES, TOPK_LIST, USE_QUANTILES, USE_TOPK)\n",
    "    down_levels = selection_levels_for_down(delta_map, QUANTILES, TOPK_LIST, USE_QUANTILES, USE_TOPK)\n",
    "\n",
    "    best_avg = float(\"-inf\")\n",
    "    best_label = None\n",
    "\n",
    "    for lab, ids in up_levels + down_levels:\n",
    "        n, avg_sub = avg_for_selected(eval_data, ids)\n",
    "        if n > 0 and avg_sub > best_avg:\n",
    "            best_avg = avg_sub\n",
    "            best_label = lab\n",
    "\n",
    "    if best_label is None:\n",
    "        return base_avg, None, None\n",
    "    return base_avg, best_avg, best_label\n",
    "\n",
    "# =============================================================================\n",
    "# Parse arch/sparsity + slot/axes helpers\n",
    "# =============================================================================\n",
    "\n",
    "def parse_arch_and_sparsity_from_display(display_name: str) -> Tuple[str, int]:\n",
    "    \"\"\"\n",
    "    \"{arch}_trainer{idx}_l0_{sparsity}\" -> (\"{arch}\", sparsity)\n",
    "    \"\"\"\n",
    "    m = re.search(r\"^([a-zA-Z_]+)_trainer\\d+_l0_(\\d+)$\", display_name)\n",
    "    if not m:\n",
    "        raise ValueError(f\"Unexpected SAE display name: {display_name}\")\n",
    "    return m.group(1), int(m.group(2))\n",
    "\n",
    "def add_axis_columns(df_in: pd.DataFrame) -> pd.DataFrame:\n",
    "    arch_key, arch_label, spars_val = [], [], []\n",
    "    for name in df_in[\"SAEs\"]:\n",
    "        ak, sval = parse_arch_and_sparsity_from_display(str(name))\n",
    "        arch_key.append(ak)\n",
    "        arch_label.append(ARCH_LABELS.get(ak, ak))\n",
    "        spars_val.append(int(sval))\n",
    "    out = df_in.copy()\n",
    "    out[\"arch_key\"] = arch_key\n",
    "    out[\"Architecture\"] = arch_label\n",
    "    out[\"SparsityVal\"] = spars_val\n",
    "    # Slot is rank within *architecture* by ascending sparsity (per MODEL).\n",
    "    out[\"Slot\"] = (\n",
    "        out.groupby([\"Model\", \"arch_key\"])[\"SparsityVal\"]\n",
    "        .rank(method=\"first\", ascending=True)\n",
    "        .astype(int)\n",
    "    )\n",
    "    return out\n",
    "\n",
    "# =============================================================================\n",
    "# Kendall tau-b + inference (permutation test, bootstrap CIs)\n",
    "# =============================================================================\n",
    "\n",
    "def cmp_sign(a: float, b: float, eps: float = 0.0) -> int:\n",
    "    if eps > 0.0 and abs(a - b) <= eps:\n",
    "        return 0\n",
    "    if a > b: return 1\n",
    "    if a < b: return -1\n",
    "    return 0\n",
    "\n",
    "def kendall_tau_b(x: np.ndarray, y: np.ndarray):\n",
    "    n = len(x)\n",
    "    assert n == len(y) and n >= 2\n",
    "    C = D = T_x = T_y = T_xy = 0\n",
    "    for i, j in itertools.combinations(range(n), 2):\n",
    "        xi, xj = float(x[i]), float(x[j])\n",
    "        yi, yj = float(y[i]), float(y[j])\n",
    "        dx = cmp_sign(xi, xj)\n",
    "        dy = cmp_sign(yi, yj)\n",
    "        if dx == 0 and dy == 0:\n",
    "            T_xy += 1\n",
    "        elif dx == 0 and dy != 0:\n",
    "            T_x += 1\n",
    "        elif dx != 0 and dy == 0:\n",
    "            T_y += 1\n",
    "        else:\n",
    "            if dx == dy: C += 1\n",
    "            else:        D += 1\n",
    "    denom = math.sqrt((C + D + T_x) * (C + D + T_y))\n",
    "    tau_b = 0.0 if denom == 0 else (C - D) / denom\n",
    "    return tau_b, {\"C\": C, \"D\": D, \"T_x\": T_x, \"T_y\": T_y, \"T_xy\": T_xy, \"N_pairs\": n*(n-1)//2}\n",
    "\n",
    "def pairwise_summary(df_sub: pd.DataFrame, x_col: str, y_col: str, label: str):\n",
    "    x = df_sub[x_col].to_numpy()\n",
    "    y = df_sub[y_col].to_numpy()\n",
    "    tau, stats = kendall_tau_b(x, y)\n",
    "    return {\n",
    "        \"where\": label,\n",
    "        \"n\": len(df_sub),\n",
    "        \"pairs\": stats[\"N_pairs\"],\n",
    "        \"tau_b\": tau,\n",
    "        \"C\": stats[\"C\"], \"D\": stats[\"D\"],\n",
    "        \"T_x\": stats[\"T_x\"], \"T_y\": stats[\"T_y\"], \"T_xy\": stats[\"T_xy\"],\n",
    "    }\n",
    "\n",
    "def permutation_test_kendall(x: np.ndarray, y: np.ndarray, n_perm: int = 5000, seed: int = GLOBAL_RANDOM_SEED) -> Tuple[float, Tuple[float, float]]:\n",
    "    \"\"\"\n",
    "    Permutation test by shuffling the pairing between x and y.\n",
    "    Returns:\n",
    "        p_value (two-sided), (ci_lo, ci_hi) empirical 95% CI of the permuted distribution.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    obs, _ = kendall_tau_b(x, y)\n",
    "    perms = []\n",
    "    for _ in range(n_perm):\n",
    "        y_perm = rng.permutation(y)\n",
    "        t, _ = kendall_tau_b(x, y_perm)\n",
    "        perms.append(t)\n",
    "    perms = np.array(perms, dtype=float)\n",
    "    # two-sided p: proportion of |perm| >= |obs|\n",
    "    p = float((np.sum(np.abs(perms) >= abs(obs)) + 1) / (n_perm + 1))\n",
    "    ci_lo, ci_hi = float(np.percentile(perms, 2.5)), float(np.percentile(perms, 97.5))\n",
    "    return p, (ci_lo, ci_hi)\n",
    "\n",
    "def psi_along_axis(df_aug: pd.DataFrame, x_col: str, y_col: str, axis: str,\n",
    "                   do_group_permutation: bool = True,\n",
    "                   n_perm: int = 5000) -> Tuple[float, float, Tuple[float, float], List[Dict[str, Any]]]:\n",
    "    \"\"\"\n",
    "    Compute Psi along a given axis by averaging group-wise tau_b.\n",
    "    Also returns SE, and a bootstrap 95% CI over groups.\n",
    "\n",
    "    axis in {\"Sparsity-within-Architecture\", \"Architecture-at-matched-Slot\", \"Within-Model\"}\n",
    "    \"\"\"\n",
    "    per_group: List[Dict[str, Any]] = []\n",
    "    if axis == \"Sparsity-within-Architecture\":\n",
    "        group_iter = df_aug.groupby(\"Architecture\", sort=False)\n",
    "    elif axis == \"Architecture-at-matched-Slot\":\n",
    "        group_iter = df_aug.groupby(\"Slot\", sort=True)\n",
    "    elif axis == \"Within-Model\":\n",
    "        group_iter = df_aug.groupby(\"Model\", sort=False)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown axis.\")\n",
    "\n",
    "    for key, g in group_iter:\n",
    "        if len(g) < 2:\n",
    "            continue\n",
    "        d = pairwise_summary(g, x_col, y_col, label=f\"{axis.split('-')[0]}={key}\")\n",
    "        if do_group_permutation:\n",
    "            px, py = g[x_col].to_numpy(), g[y_col].to_numpy()\n",
    "            pval, (ci_lo, ci_hi) = permutation_test_kendall(px, py, n_perm=n_perm)\n",
    "            d[\"perm_p\"] = pval\n",
    "            d[\"perm_ci_lo\"] = ci_lo\n",
    "            d[\"perm_ci_hi\"] = ci_hi\n",
    "        per_group.append(d)\n",
    "\n",
    "    tau_vals = [d[\"tau_b\"] for d in per_group]\n",
    "    psi_mean = float(np.mean(tau_vals)) if tau_vals else float(\"nan\")\n",
    "    psi_se = float(np.std(tau_vals, ddof=1) / math.sqrt(len(tau_vals))) if len(tau_vals) > 1 else float(\"nan\")\n",
    "\n",
    "    # group-level bootstrap (resample groups with replacement)\n",
    "    def _psi_from_indices(idxs: np.ndarray) -> float:\n",
    "        if len(idxs) == 0:\n",
    "            return float(\"nan\")\n",
    "        return float(np.mean([tau_vals[i] for i in idxs]))\n",
    "\n",
    "    B = 5000\n",
    "    rng = np.random.default_rng(GLOBAL_RANDOM_SEED)\n",
    "    if len(tau_vals) >= 2:\n",
    "        boot = []\n",
    "        m = len(tau_vals)\n",
    "        for _ in range(B):\n",
    "            idxs = rng.integers(0, m, size=m)\n",
    "            boot.append(_psi_from_indices(idxs))\n",
    "        psi_ci_lo = float(np.nanpercentile(boot, 2.5))\n",
    "        psi_ci_hi = float(np.nanpercentile(boot, 97.5))\n",
    "        psi_ci = (psi_ci_lo, psi_ci_hi)\n",
    "    else:\n",
    "        psi_ci = (float(\"nan\"), float(\"nan\"))\n",
    "\n",
    "    return psi_mean, psi_se, psi_ci, per_group\n",
    "\n",
    "# -------------------- BCa bootstrap for overall tau-b --------------------\n",
    "\n",
    "def _tau_stat(indices: np.ndarray, x: np.ndarray, y: np.ndarray) -> float:\n",
    "    t, _ = kendall_tau_b(x[indices], y[indices])\n",
    "    return t\n",
    "\n",
    "def bca_ci_for_tau(x: np.ndarray, y: np.ndarray, B: int = 5000, seed: int = GLOBAL_RANDOM_SEED) -> Tuple[float, float, float]:\n",
    "    \"\"\"\n",
    "    Bias-Corrected and Accelerated (BCa) bootstrap CI for Kendall's tau-b.\n",
    "\n",
    "    Returns:\n",
    "        (tau_hat, ci_lo, ci_hi)\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    n = len(x)\n",
    "    assert n == len(y) and n >= 2\n",
    "\n",
    "    # observed\n",
    "    tau_hat, _ = kendall_tau_b(x, y)\n",
    "\n",
    "    # bootstrap\n",
    "    boot_stats = []\n",
    "    for _ in range(B):\n",
    "        idxs = rng.integers(0, n, size=n)\n",
    "        boot_stats.append(_tau_stat(idxs, x, y))\n",
    "    boot_stats = np.array(boot_stats, dtype=float)\n",
    "\n",
    "    # bias-correction z0 = Φ^{-1}(proportion of boot < observed)\n",
    "    prop = float(np.mean(boot_stats < tau_hat))\n",
    "    z0 = norm_ppf(prop)\n",
    "\n",
    "    # jackknife for acceleration a\n",
    "    jack = []\n",
    "    for i in range(n):\n",
    "        idxs = np.array([j for j in range(n) if j != i], dtype=int)\n",
    "        jack.append(_tau_stat(idxs, x, y))\n",
    "    jack = np.array(jack, dtype=float)\n",
    "    jack_mean = float(np.mean(jack))\n",
    "    num = np.sum((jack_mean - jack) ** 3)\n",
    "    den = 6.0 * (np.sum((jack_mean - jack) ** 2) ** 1.5 + 1e-12)\n",
    "    a = float(num / (den + 1e-12))\n",
    "\n",
    "    # BCa percentiles\n",
    "    alpha_lo, alpha_hi = 0.025, 0.975\n",
    "    z_lo = norm_ppf(alpha_lo)\n",
    "    z_hi = norm_ppf(alpha_hi)\n",
    "\n",
    "    z_alpha_lo = z0 + (z0 + z_lo) / (1 - a * (z0 + z_lo) + 1e-12)\n",
    "    z_alpha_hi = z0 + (z0 + z_hi) / (1 - a * (z0 + z_hi) + 1e-12)\n",
    "\n",
    "    pct_lo = 100 * norm_cdf(z_alpha_lo)\n",
    "    pct_hi = 100 * norm_cdf(z_alpha_hi)\n",
    "\n",
    "    ci_lo = float(np.percentile(boot_stats, pct_lo))\n",
    "    ci_hi = float(np.percentile(boot_stats, pct_hi))\n",
    "    return tau_hat, ci_lo, ci_hi\n",
    "\n",
    "# =============================================================================\n",
    "# Per-model processing\n",
    "# =============================================================================\n",
    "\n",
    "def process_model(cfg: ModelConfig) -> pd.DataFrame:\n",
    "    print(\"\\n\" + \"=\" * 80)\n",
    "    print(f\"MODEL: {cfg.name}\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    autointerp_dir = os.path.join(cfg.autointerp_base, cfg.layer_dir, \"autointerp\")\n",
    "    autointerp_map = load_autointerp(autointerp_dir)\n",
    "\n",
    "    sae_list = build_sae_list_from_maps(cfg.maps)  # 30 items\n",
    "    rows = []\n",
    "    missing = []\n",
    "\n",
    "    for display_name, key_prefix in sae_list:\n",
    "        ai = autointerp_map.get(key_prefix, None)\n",
    "        arch_key, sparsity_val = parse_arch_and_sparsity_from_display(display_name)\n",
    "        sdir = resolve_steering_dir(cfg.eval_root, arch_key, sparsity_val)\n",
    "        base, after, best_label = compute_steering_scores_for_sae_dirname(cfg.eval_root, cfg.entropy_root, sdir)\n",
    "\n",
    "        if ai is None or base is None:\n",
    "            missing.append((display_name, key_prefix, sdir, ai, base, after))\n",
    "            continue\n",
    "\n",
    "        rows.append({\n",
    "            \"Model\": cfg.name,\n",
    "            \"SAEs\": display_name,\n",
    "            \"steering_name\": sdir,\n",
    "            \"AutoInterp\": float(ai),\n",
    "            \"SteeringBase\": float(base),\n",
    "            \"SteeringAfter\": float(after) if after is not None else np.nan,\n",
    "            \"BestLevel\": best_label if best_label is not None else \"N/A\",\n",
    "        })\n",
    "\n",
    "    if missing:\n",
    "        print(\"\\n[WARNING] The following SAEs are missing AutoInterp and/or SteeringBase (excluded):\")\n",
    "        for disp, akey, sdir, ai, base, after in missing:\n",
    "            print(f\"  - {disp} | autointerp_key={akey} -> AutoInterp={ai} ; \"\n",
    "                  f\"steering_dir={sdir} -> Base={base} After={after}\")\n",
    "\n",
    "    df = pd.DataFrame(rows)\n",
    "\n",
    "    # Add Lift and Delta columns (primary utility targets)\n",
    "    if not df.empty:\n",
    "        eps = 1e-12\n",
    "        df[\"Lift\"] = np.where(\n",
    "            df[\"SteeringAfter\"].notna(),\n",
    "            (df[\"SteeringAfter\"] - df[\"SteeringBase\"]) / np.maximum(df[\"SteeringBase\"], eps),\n",
    "            np.nan,\n",
    "        )\n",
    "        df[\"Delta\"] = df[\"SteeringAfter\"] - df[\"SteeringBase\"]\n",
    "\n",
    "    return df\n",
    "\n",
    "# =============================================================================\n",
    "# Reporting helpers\n",
    "# =============================================================================\n",
    "\n",
    "def add_axes_and_analyze(df: pd.DataFrame, title: str, do_within_model_axis: bool = True, is_combined: bool = False):\n",
    "    if df.empty:\n",
    "        print(f\"\\n[SKIP] {title}: no valid rows.\")\n",
    "        return\n",
    "\n",
    "    df_aug = add_axis_columns(df)\n",
    "\n",
    "    def run_analysis(y_col: str, label: str):\n",
    "        df_use = df_aug.dropna(subset=[y_col]).copy()\n",
    "        if len(df_use) < 2:\n",
    "            print(f\"\\n[SKIP] {title}: AutoInterp vs {label} -> valid samples < 2\")\n",
    "            return\n",
    "\n",
    "        # Overall tau-b + BCa CI (only for combined to keep output compact)\n",
    "        overall = pairwise_summary(df_use, \"AutoInterp\", y_col, label=f\"Overall (n={len(df_use)})\")\n",
    "        print(\"\\n\" + \"=\"*12 + f\" {title} | PAIRWISE Kendall tau-b: AutoInterp vs {label} \" + \"=\"*12)\n",
    "        if is_combined:\n",
    "            tau_hat, ci_lo, ci_hi = bca_ci_for_tau(df_use[\"AutoInterp\"].to_numpy(),\n",
    "                                                   df_use[y_col].to_numpy(),\n",
    "                                                   B=5000)\n",
    "            print(f\"- Overall: n={overall['n']} pairs={overall['pairs']}  \"\n",
    "                  f\"tau_b={overall['tau_b']:.4f}  BCa95%=[{ci_lo:.4f}, {ci_hi:.4f}]  \"\n",
    "                  f\"(C={overall['C']}, D={overall['D']}, T_x={overall['T_x']}, T_y={overall['T_y']}, T_xy={overall['T_xy']})\")\n",
    "        else:\n",
    "            print(f\"- Overall: n={overall['n']} pairs={overall['pairs']}  tau_b={overall['tau_b']:.4f}  \"\n",
    "                  f\"(C={overall['C']}, D={overall['D']}, T_x={overall['T_x']}, T_y={overall['T_y']}, T_xy={overall['T_xy']})\")\n",
    "\n",
    "        # Axis A: Sparsity-within-Architecture\n",
    "        psi_A, psi_A_se, psi_A_ci, groups_A = psi_along_axis(\n",
    "            df_use, \"AutoInterp\", y_col, \"Sparsity-within-Architecture\", do_group_permutation=True\n",
    "        )\n",
    "        print(f\"- Psi on axis: Sparsity-within-Architecture: psi={psi_A:.4f}  (SE≈{psi_A_se:.4f}, 95% boot CI={psi_A_ci})\")\n",
    "        for g in groups_A:\n",
    "            extra = f\" | perm p={g.get('perm_p', float('nan')):.4f}, perm 95% CI=[{g.get('perm_ci_lo', float('nan')):.4f}, {g.get('perm_ci_hi', float('nan')):.4f}]\"\n",
    "            print(f\"    • {g['where']}: n={g['n']} tau_b={g['tau_b']:.4f} (pairs={g['pairs']}){extra}\")\n",
    "\n",
    "        # Axis B: Architecture-at-matched-Slot\n",
    "        psi_B, psi_B_se, psi_B_ci, groups_B = psi_along_axis(\n",
    "            df_use, \"AutoInterp\", y_col, \"Architecture-at-matched-Slot\", do_group_permutation=True\n",
    "        )\n",
    "        print(f\"- Psi on axis: Architecture-at-matched-Slot: psi={psi_B:.4f}  (SE≈{psi_B_se:.4f}, 95% boot CI={psi_B_ci})\")\n",
    "        for g in groups_B:\n",
    "            extra = f\" | perm p={g.get('perm_p', float('nan')):.4f}, perm 95% CI=[{g.get('perm_ci_lo', float('nan')):.4f}, {g.get('perm_ci_hi', float('nan')):.4f}]\"\n",
    "            print(f\"    • {g['where']}: n={g['n']} tau_b={g['tau_b']:.4f} (pairs={g['pairs']}){extra}\")\n",
    "\n",
    "        # Axis C: Within-Model (new)\n",
    "        if do_within_model_axis:\n",
    "            psi_C, psi_C_se, psi_C_ci, groups_C = psi_along_axis(\n",
    "                df_use, \"AutoInterp\", y_col, \"Within-Model\", do_group_permutation=True\n",
    "            )\n",
    "            print(f\"- Psi on axis: Within-Model: psi={psi_C:.4f}  (SE≈{psi_C_se:.4f}, 95% boot CI={psi_C_ci})\")\n",
    "            for g in groups_C:\n",
    "                extra = f\" | perm p={g.get('perm_p', float('nan')):.4f}, perm 95% CI=[{g.get('perm_ci_lo', float('nan')):.4f}, {g.get('perm_ci_hi', float('nan')):.4f}]\"\n",
    "                print(f\"    • {g['where']}: n={g['n']} tau_b={g['tau_b']:.4f} (pairs={g['pairs']}){extra}\")\n",
    "\n",
    "            overall_psi = float(np.nanmean([psi_A, psi_B, psi_C]))\n",
    "            print(f\"- Overall Psi (mean of A/B/C axes) = {overall_psi:.4f}\")\n",
    "        else:\n",
    "            overall_psi = float(np.nanmean([psi_A, psi_B]))\n",
    "            print(f\"- Overall Psi (mean of the two axes A/B) = {overall_psi:.4f}\")\n",
    "\n",
    "    # Analyses:\n",
    "    # 1) Control: AutoInterp vs SteeringBase\n",
    "    run_analysis(\"SteeringBase\", label=\"SteeringBase\")\n",
    "\n",
    "    # 2) Primary: AutoInterp vs Lift\n",
    "    if \"Lift\" in df_aug.columns:\n",
    "        run_analysis(\"Lift\", label=\"Lift\")\n",
    "\n",
    "    # 3) Supplement: AutoInterp vs Delta\n",
    "    if \"Delta\" in df_aug.columns:\n",
    "        run_analysis(\"Delta\", label=\"Delta\")\n",
    "\n",
    "# =============================================================================\n",
    "# Main\n",
    "# =============================================================================\n",
    "\n",
    "def main():\n",
    "    per_model_frames: List[pd.DataFrame] = []\n",
    "\n",
    "    # 1) Per-model analysis\n",
    "    for cfg in MODELS:\n",
    "        df_model = process_model(cfg)\n",
    "        per_model_frames.append(df_model)\n",
    "        # For single-model outputs we omit the Within-Model axis (it's degenerate).\n",
    "        add_axes_and_analyze(df_model, title=f\"{cfg.name}\", do_within_model_axis=False, is_combined=False)\n",
    "\n",
    "    # 2) Combined analysis (all SAEs from all models)\n",
    "    combined = pd.concat(per_model_frames, ignore_index=True) if per_model_frames else pd.DataFrame()\n",
    "    add_axes_and_analyze(combined, title=\"Combined (All Models)\", do_within_model_axis=True, is_combined=True)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4steer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
