{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba6f9a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import statsmodels.formula.api as smf\n",
    "from statsmodels.stats.anova import anova_lm\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caaad844",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _significance_stars(p):\n",
    "    \"\"\"Return the conventional significance marker for p-value ``p``.\"\"\"\n",
    "    if p < 0.001:\n",
    "        return \"***\"\n",
    "    if p < 0.01:\n",
    "        return \"**\"\n",
    "    if p < 0.05:\n",
    "        return \"*\"\n",
    "    return \"\"\n",
    "\n",
    "\n",
    "def load_df(dataset, model_name_or_path, n_rows=144, results_dir=\"../../results\"):\n",
    "    \"\"\"Decompose the variance of ``logit`` into question and prompt effects and plot it.\n",
    "\n",
    "    Despite the name, this does more than load data: it reads\n",
    "    ``{results_dir}/data_results/template_vs_question/{model_name_or_path}/{dataset}_logit.csv``,\n",
    "    fits the additive OLS model ``logit ~ C(question) + C(prompt)``, prints the\n",
    "    ANOVA rows of both factors, and saves a stacked horizontal bar of each\n",
    "    factor's share of the total sum of squares to\n",
    "    ``{results_dir}/figure_results/factor_contribution/{model_name_or_path}/{dataset}_factor_contributions.pdf``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : str\n",
    "        Dataset name used in the input and output file names.\n",
    "    model_name_or_path : str\n",
    "        Model identifier, used as a (possibly nested, e.g. \"org/model\") sub-directory.\n",
    "    n_rows : int or None, default 144\n",
    "        Keep only the first ``n_rows`` rows of the CSV; ``None`` keeps every row.\n",
    "    results_dir : str, default \"../../results\"\n",
    "        Root of the results tree, relative to the working directory.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        One row per factor (``C(question)``, ``C(prompt)``) with columns\n",
    "        ``sum_sq``, ``F``, ``PR(>F)`` and ``prop_of_total_ss``.\n",
    "    \"\"\"\n",
    "    data_path = f\"{results_dir}/data_results/template_vs_question/{model_name_or_path}/{dataset}_logit.csv\"\n",
    "    raw = pd.read_csv(data_path)\n",
    "    if n_rows is not None:\n",
    "        # NOTE(review): 144 presumably covers one full question x prompt grid; confirm upstream.\n",
    "        raw = raw.head(n_rows)\n",
    "    # assign() returns a new frame, so no column is written into the head() slice\n",
    "    # (avoids SettingWithCopyWarning / chained-assignment ambiguity).\n",
    "    df = raw.assign(\n",
    "        question=raw[\"question_id\"].astype(\"category\"),\n",
    "        prompt=raw[\"prompt_id\"].astype(\"category\"),\n",
    "    )\n",
    "\n",
    "    model = smf.ols(\"logit ~ C(question) + C(prompt)\", data=df).fit()\n",
    "\n",
    "    # anova_lm defaults to sequential (type I) sums of squares; with an intercept in the\n",
    "    # model, these plus the residual SS add up to the total (centred) SS.\n",
    "    anova_tbl = anova_lm(model)\n",
    "    factors = [\"C(question)\", \"C(prompt)\"]\n",
    "    anova_factors = anova_tbl.loc[factors, [\"sum_sq\", \"F\", \"PR(>F)\"]].copy()\n",
    "    total_ss = anova_factors[\"sum_sq\"].sum() + anova_tbl.loc[\"Residual\", \"sum_sq\"]\n",
    "    anova_factors[\"prop_of_total_ss\"] = anova_factors[\"sum_sq\"] / total_ss\n",
    "\n",
    "    # Label each table: this function is called in a loop over many dataset/model pairs.\n",
    "    print(f\"== {dataset} | {model_name_or_path} ==\")\n",
    "    print(anova_factors)\n",
    "\n",
    "    prompt_val = anova_factors.loc[\"C(prompt)\", \"prop_of_total_ss\"]\n",
    "    question_val = anova_factors.loc[\"C(question)\", \"prop_of_total_ss\"]\n",
    "    resid_val = 1 - (prompt_val + question_val)\n",
    "\n",
    "    # Segments drawn left to right: (legend label, width, fill, text colour, p-value or None).\n",
    "    segments = [\n",
    "        (\"prompt\", prompt_val, \"#6abffc\", \"white\", anova_factors.loc[\"C(prompt)\", \"PR(>F)\"]),\n",
    "        (\"residual\", resid_val, \"#d3d3d3\", \"black\", None),\n",
    "        (\"question\", question_val, \"#ffa556\", \"white\", anova_factors.loc[\"C(question)\", \"PR(>F)\"]),\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 2))\n",
    "    left = 0.0\n",
    "    for label, width, fill, text_color, pval in segments:\n",
    "        ax.barh(0, width, left=left, color=fill, label=label)\n",
    "        center = left + width / 2\n",
    "        ax.text(center, 0, f\"{width:.1%}\", ha=\"center\", va=\"center\", color=text_color, fontsize=10)\n",
    "        if pval is not None:\n",
    "            ax.text(center, 0.2, _significance_stars(pval), ha=\"center\", va=\"bottom\", fontsize=12, color=\"black\")\n",
    "        left += width\n",
    "\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_yticks([])\n",
    "    ax.set_title(\"Factor Contributions to logit (ANOVA)\")\n",
    "    ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.3), ncol=3, frameon=False)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    save_path = f\"{results_dir}/figure_results/factor_contribution/{model_name_or_path}/{dataset}_factor_contributions.pdf\"\n",
    "    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "    # bbox_inches=\"tight\" is a safeguard so the legend placed below the axes is not cropped.\n",
    "    fig.savefig(save_path, dpi=200, bbox_inches=\"tight\")\n",
    "    # Close this specific figure so repeated calls do not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "    return anova_factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984b3fc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every dataset is analysed for every model; each name is interpolated into the\n",
    "# results paths read and written by load_df.\n",
    "dataset_list = [\n",
    "    \"ARC_Challenge\",\n",
    "    \"CommonSenseQA\",\n",
    "    \"MMLU\",\n",
    "    \"OpenBookQA\",\n",
    "]\n",
    "model_list = [\n",
    "    \"Qwen/Qwen1.5-0.5B\",\n",
    "    \"Qwen/Qwen1.5-4B\",\n",
    "    \"Qwen/Qwen1.5-1.8B\",\n",
    "    \"meta-llama/Llama-3.2-1B\",\n",
    "    \"meta-llama/Llama-3.2-3B\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33575d94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the ANOVA + figure export for every (dataset, model) pair, datasets varying slowest.\n",
    "for dataset, model_name_or_path in (\n",
    "    (d, m) for d in dataset_list for m in model_list\n",
    "):\n",
    "    load_df(dataset, model_name_or_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "l",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
