{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f42da29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of concepts processed: 3\n",
      "Average scores for each method across all concepts:\n",
      "  PromptSteering: 1.300\n",
      "  GemmaScopeSAE: 0.147\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "# Seed the global RNG so the random PromptSteering pick below is reproducible.\n",
    "random.seed(42)\n",
    "\n",
    "# 1. Input file (one JSON record per line) and the methods to compare.\n",
    "# NOTE(review): machine-specific absolute path; adjust when running elsewhere.\n",
    "in_path = \"/home/dslabra5/sae4steer/axbench/axbench/concept10_gemma2_2b_L12_batch_topk_80_0.8357/evaluate/steering.jsonl\"\n",
    "methods = [\"PromptSteering\", \"GemmaScopeSAE\"]\n",
    "\n",
    "# 2. Iterate over each concept and collect one lm_judge_rating per method.\n",
    "# NOTE(review): the two methods are reduced differently (random pick vs. maximum),\n",
    "# so their averages are not like-for-like; confirm this asymmetry is intended.\n",
    "scores = {m: [] for m in methods}\n",
    "\n",
    "with open(in_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        data = json.loads(line)\n",
    "        lmres = data[\"results\"][\"LMJudgeEvaluator\"]\n",
    "        for m in methods:\n",
    "            ratings = lmres[m][\"lm_judge_rating\"]\n",
    "            if m == \"PromptSteering\":\n",
    "                # PromptSteering: randomly choose one rating\n",
    "                scores[m].append(random.choice(ratings))\n",
    "            else:\n",
    "                # GemmaScopeSAE: take the maximum rating\n",
    "                scores[m].append(max(ratings))\n",
    "\n",
    "# 3. Calculate average score (arithmetic mean) per method.\n",
    "# Each method is divided by its own count; an empty input yields NaN\n",
    "# rather than raising ZeroDivisionError.\n",
    "num_concepts = len(scores[methods[0]])\n",
    "avg_scores = {\n",
    "    m: sum(vals) / len(vals) if vals else float(\"nan\")\n",
    "    for m, vals in scores.items()\n",
    "}\n",
    "\n",
    "# 4. Output results\n",
    "print(f\"Total number of concepts processed: {num_concepts}\")\n",
    "print(\"Average scores for each method across all concepts:\")\n",
    "for m, avg in avg_scores.items():\n",
    "    print(f\"  {m}: {avg:.3f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae4steer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
