{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob\n",
    "import json\n",
    "import re\n",
    "from sklearn.metrics import auc, roc_auc_score, roc_curve, accuracy_score, balanced_accuracy_score, f1_score\n",
    "from statistics import mean, stdev, variance\n",
    "from scipy import stats\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "from math import sqrt\n",
    "import plotly\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
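    "# Calculating the FMS score\n",
    "\n",
    "FMS combines the `Accuracy` of the `Nodes == 1` rows with the average of the local (`MS_local`) and global (`MS_global`) model-shift terms:\n",
    "\n",
    "$$\\mathrm{FMS} = \\mathrm{Accuracy} \\cdot \\frac{MS_{local} + MS_{global}}{2}$$"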
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import re\n",
    "from glob import glob\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "from typing import List, Dict, Optional\n",
    "\n",
    "def load_local_tree_stats_data(\n",
    "    file_info_list: List[Dict],\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Loads local tree stats CSVs and computes the local model shift (MS_local).\n",
    "\n",
    "    For each file, the baseline accuracy is the mean \"Accuracy\" of its rows with\n",
    "    Nodes == 1 and num_cuts == 0. Every row with num_cuts != 0 gets\n",
    "    MS_local = 2 * (baseline - Accuracy); rows with num_cuts == 0 get NaN.\n",
    "    The baseline is computed per file, not pooled across files.\n",
    "\n",
    "    Parameters:\n",
    "    - file_info_list: List of dicts with:\n",
    "        {\n",
    "            \"file\": str,        # Path to a CSV with Nodes, num_cuts, Accuracy columns\n",
    "            \"model_name\": str,\n",
    "            \"concept\": str,\n",
    "            \"model_type\": str,  # Stored in the \"model type\" column\n",
    "        }\n",
    "\n",
    "    Returns:\n",
    "    - pd.DataFrame with columns [\"num_cuts\", \"concept\", \"model type\", \"MS_local\"],\n",
    "      restricted to rows with Nodes == 1. If file_info_list is empty, an empty\n",
    "      DataFrame with those columns is returned.\n",
    "    \"\"\"\n",
    "    output_columns = [\"num_cuts\", \"concept\", \"model type\", \"MS_local\"]\n",
    "    frames = []\n",
    "    for file_info in file_info_list:\n",
    "        df = pd.read_csv(file_info[\"file\"])\n",
    "\n",
    "        # Inject metadata\n",
    "        df[\"model_name\"] = file_info[\"model_name\"]\n",
    "        df[\"concept\"] = file_info[\"concept\"]\n",
    "        df[\"model type\"] = file_info[\"model_type\"]\n",
    "\n",
    "        # The baseline is loop-invariant within a file: compute it once, not per row.\n",
    "        is_baseline = (df[\"Nodes\"] == 1) & (df[\"num_cuts\"] == 0)\n",
    "        baseline_accuracy = df.loc[is_baseline, \"Accuracy\"].mean()\n",
    "\n",
    "        # .where() leaves NaN on the num_cuts == 0 rows.\n",
    "        df[\"MS_local\"] = (2 * (baseline_accuracy - df[\"Accuracy\"])).where(\n",
    "            df[\"num_cuts\"] != 0\n",
    "        )\n",
    "        frames.append(df)\n",
    "\n",
    "    if not frames:\n",
    "        return pd.DataFrame(columns=output_columns)\n",
    "\n",
    "    # A single concat at the end avoids quadratic re-copying inside the loop.\n",
    "    all_data = pd.concat(frames, ignore_index=True)\n",
    "    return all_data.loc[all_data[\"Nodes\"] == 1, output_columns]\n",
    "\n",
    "\n",
    "def _ms_global(all_depths: pd.DataFrame, model_type: str, concept: str) -> float:\n",
    "    \"\"\"\n",
    "    Global model shift for one (model type, concept) group.\n",
    "\n",
    "    MS_global = 1 - mean(Accuracy[Nodes != 1] - Accuracy[Nodes == 1]) within the\n",
    "    group. Returns NaN when the group has no rows with Nodes != 1.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if the group does not have exactly one row with Nodes == 1.\n",
    "    \"\"\"\n",
    "    in_group = (all_depths[\"model type\"] == model_type) & (all_depths[\"concept\"] == concept)\n",
    "    is_root = all_depths[\"Nodes\"] == 1\n",
    "\n",
    "    root_accuracy = all_depths.loc[in_group & is_root, \"Accuracy\"]\n",
    "    if len(root_accuracy) != 1:\n",
    "        raise ValueError(\n",
    "            f\"Expected exactly one Nodes == 1 row for model type={model_type!r}, \"\n",
    "            f\"concept={concept!r}; found {len(root_accuracy)}.\"\n",
    "        )\n",
    "\n",
    "    other_accuracy = all_depths.loc[in_group & ~is_root, \"Accuracy\"]\n",
    "    if other_accuracy.empty:\n",
    "        # Nothing to compare against; avoid a 0 / 0 division.\n",
    "        return float(\"nan\")\n",
    "    # skipna=False keeps NaN accuracies visible instead of silently dropping them.\n",
    "    return 1 - (other_accuracy - root_accuracy.item()).mean(skipna=False)\n",
    "\n",
    "\n",
    "def load_tree_stats_data(\n",
    "    file_info_list: List[Dict],\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Loads tree stats CSVs and computes the global model shift (MS_global).\n",
    "\n",
    "    Files whose *file name* contains \"cut\" are skipped (in this notebook the\n",
    "    *_cut.csv file is passed to load_local_tree_stats_data instead). For each\n",
    "    row with Nodes == 1:\n",
    "\n",
    "        MS_global = 1 - mean(Accuracy[Nodes != 1] - Accuracy[Nodes == 1])\n",
    "\n",
    "    taken over the rows that share its \"model type\" and \"concept\".\n",
    "\n",
    "    Parameters:\n",
    "    - file_info_list: List of dicts with:\n",
    "        {\n",
    "            \"file\": str,                  # Path to CSV with Nodes and Accuracy columns\n",
    "            \"model_name\": str,\n",
    "            \"concept\": str,              # e.g., \"Shakespeare\"\n",
    "            \"model_type\": str,           # e.g., \"G-SAE\"; stored as \"model type\"\n",
    "        }\n",
    "\n",
    "    Returns:\n",
    "    - pd.DataFrame with columns [\"Accuracy\", \"concept\", \"model type\", \"MS_global\"],\n",
    "      one row per Nodes == 1 row. MS_global is NaN for a group without\n",
    "      Nodes != 1 rows. Empty (with those columns) if no file is loaded.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if a (model type, concept) group has more than one Nodes == 1 row.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    output_columns = [\"Accuracy\", \"concept\", \"model type\", \"MS_global\"]\n",
    "    frames = []\n",
    "    for info in file_info_list:\n",
    "        file = info[\"file\"]\n",
    "\n",
    "        # Only the file name is checked, so directories whose names happen to\n",
    "        # contain \"cut\" (e.g. \"execution\") do not cause files to be dropped.\n",
    "        if \"cut\" in os.path.basename(file):\n",
    "            continue\n",
    "\n",
    "        df = pd.read_csv(file)\n",
    "\n",
    "        df[\"model_name\"] = info[\"model_name\"]\n",
    "        df[\"concept\"] = info[\"concept\"]\n",
    "        df[\"model type\"] = info[\"model_type\"]\n",
    "\n",
    "        frames.append(df)\n",
    "\n",
    "    if not frames:\n",
    "        return pd.DataFrame(columns=output_columns)\n",
    "\n",
    "    # The metric depends on every row of a group, so it is computed once,\n",
    "    # after all files are loaded, and only for the Nodes == 1 rows returned.\n",
    "    all_depths = pd.concat(frames, ignore_index=True)\n",
    "    roots = all_depths[all_depths[\"Nodes\"] == 1].copy()\n",
    "    roots[\"MS_global\"] = [\n",
    "        _ms_global(all_depths, model_type, concept)\n",
    "        for model_type, concept in zip(roots[\"model type\"], roots[\"concept\"])\n",
    "    ]\n",
    "    return roots[output_columns]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-cut stats file(s) used for the global model shift (MS_global).\n",
    "file_info_list = [\n",
    "    dict(\n",
    "        file=\"./llama3_SAE/SAE_eval/SP-Block_v2/sp_tree_valid_llama3-l24576-b03-k2048_s1_statsV2.csv\",\n",
    "        model_name=\"LLaMA3\",\n",
    "        concept=\"Shakespeare\",\n",
    "        model_type=\"G-SAE\",\n",
    "    ),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per Nodes == 1 row, annotated with its global model shift (MS_global).\n",
    "df_global = load_tree_stats_data(file_info_list=file_info_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-cut stats file(s) used for the local model shift (MS_local).\n",
    "file_info_list = [\n",
    "    dict(\n",
    "        file=\"./llama3_SAE/SAE_eval/SP-Block_v2/sp_tree_valid_llama3-l24576-b03-k2048_s1_cut.csv\",\n",
    "        model_name=\"LLaMA3\",\n",
    "        concept=\"Shakespeare\",\n",
    "        model_type=\"G-SAE\",\n",
    "    ),\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nodes == 1 rows of the cut stats file(s), annotated with MS_local per num_cuts.\n",
    "df_local = load_local_tree_stats_data(file_info_list=file_info_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join explicitly on the shared keys rather than pandas' implicit column overlap;\n",
    "# validate guards against duplicate df_global rows silently multiplying df_local rows.\n",
    "df = pd.merge(\n",
    "    df_local,\n",
    "    df_global,\n",
    "    on=[\"concept\", \"model type\"],\n",
    "    how=\"inner\",\n",
    "    validate=\"many_to_one\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorized FMS = Accuracy * mean(MS_local, MS_global); a NaN MS_local yields NaN FMS.\n",
    "# Unlike a row-wise apply, this also works when df is empty.\n",
    "df[\"FMS\"] = df[\"Accuracy\"] * ((df[\"MS_local\"] + df[\"MS_global\"]) / 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show only the 1-cut and 5-cut rows, rounded for display.\n",
    "df[df[\"num_cuts\"].isin([1, 5])].round(2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "SCAR",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
