{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f9f46a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from scipy import stats\n",
    "import numpy as np\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea33635d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import decision_infovalue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0756efb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-trial human labels. `with` closes the file handle that\n",
    "# json.load(open(...)) would leak; JSON text is UTF-8 by specification.\n",
    "with open('label.json', encoding='utf-8') as label_file:\n",
    "    data = json.load(label_file)\n",
    "df = pd.DataFrame(data)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cfdbad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model predictions (and, per the file name, explanation clusters) for each review.\n",
    "review_df = pd.read_csv('review_model_pred_and_explanation_cluster.csv')\n",
    "accuracy = review_df[\"predicted_label\"].eq(review_df[\"actual_label\"]).mean()\n",
    "print(f\"Prediction accuracy: {accuracy:.4f}\")\n",
    "review_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f76d55f8",
   "metadata": {},
   "source": [
    "# Rational agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "422492e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No-explanation (\"machine\") trials joined with the per-review model\n",
    "# outputs and explanation-cluster signals from review_df.\n",
    "machine_trials = df.loc[df[\"experiment\"] == \"machine\"]\n",
    "human_df = pd.merge(\n",
    "    machine_trials.drop(columns=[\"predicted_label\"]),\n",
    "    review_df.drop(columns=[\"actual_label\"]),\n",
    "    on=\"review_text\",\n",
    "    how=\"left\",\n",
    "    # Fail loudly if a review_text appears more than once in review_df:\n",
    "    # a duplicate key would silently replicate trials and bias every estimate.\n",
    "    validate=\"many_to_one\",\n",
    ")\n",
    "# Mean actual_label for each human answer (P(actual_label = 1 | user_label)\n",
    "# if labels are 0/1) -- a quick look at how informative y^H is.\n",
    "human_df.groupby(\"user_label\").agg({\"actual_label\": \"mean\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36532e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rational-agent (RA) benchmarks, all estimated on the no-explanation trials\n",
    "# in human_df. Index i of `result` lines up with plot row i:\n",
    "# 0 = no explanation, 1 = heatmap, 2 = random heatmap, 3 = examples.\n",
    "SIGNALS = [\"x_cluster_id\", \"user_label\", \"nn_pos_cluster\", \"nn_neg_cluster\",\n",
    "           \"cluster_heatmap\", \"cluster_random\", \"predicted_label\"]\n",
    "\n",
    "info_model = decision_infovalue.DecisionInfoModel(\n",
    "    human_df, \"actual_label\",\n",
    "    signals=SIGNALS,\n",
    "    scoring_rule=\"threshold_accuracy\",\n",
    "    binning_method=None,\n",
    ")\n",
    "\n",
    "\n",
    "def ra_value(signals):\n",
    "    \"\"\"Shorthand for info_model.complement_info_value(signals, None, ret_orig_value=True).\n",
    "\n",
    "    The returned values are later summarised with a bootstrap mean/CI; presumably\n",
    "    they are the rational agent's scores given `signals` -- confirm against\n",
    "    decision_infovalue.\n",
    "    \"\"\"\n",
    "    return info_model.complement_info_value(signals, None, ret_orig_value=True)\n",
    "\n",
    "\n",
    "result = [{\n",
    "    \"ra_null\": ra_value([]),\n",
    "    \"ra_x\": ra_value([\"x_cluster_id\"]),\n",
    "    \"ra_yh\": ra_value([\"user_label\"]),\n",
    "}]\n",
    "\n",
    "# Explanation signals for result[1], result[2], result[3] (in that order).\n",
    "EXPLANATION_SIGNALS = [\n",
    "    [\"cluster_heatmap\"],                   # heatmap\n",
    "    [\"cluster_random\"],                    # random heatmap\n",
    "    [\"nn_pos_cluster\", \"nn_neg_cluster\"],  # examples\n",
    "]\n",
    "for expl_signals in EXPLANATION_SIGNALS:\n",
    "    result.append({\n",
    "        \"ra_expl\": ra_value(expl_signals),\n",
    "        \"ra_yh_expl\": ra_value([\"user_label\"] + expl_signals),\n",
    "    })\n",
    "\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b0a4090",
   "metadata": {},
   "source": [
    "# Behavioral agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6c4723",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple\n",
    "\n",
    "from sklearn.base import clone\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ATEResult:\n",
    "    ate: float\n",
    "    se: float\n",
    "    ci95: Tuple[float, float]\n",
    "    details: Dict[str, Any]\n",
    "\n",
    "\n",
    "def _clip_ps(ps: np.ndarray, eps: float = 1e-3) -> np.ndarray:\n",
    "    return np.clip(ps, eps, 1 - eps)\n",
    "\n",
    "\n",
    "def _robust_se(phi: np.ndarray) -> float:\n",
    "    n = phi.shape[0]\n",
    "    return float(np.sqrt(np.var(phi, ddof=1) / n))\n",
    "\n",
    "\n",
    "def _ci95(est: float, se: float) -> Tuple[float, float]:\n",
    "    z = 1.96\n",
    "    return (est - z * se, est + z * se)\n",
    "\n",
    "\n",
    "def aipw_ate_text_binary(\n",
    "    y: np.ndarray,\n",
    "    t: np.ndarray,\n",
    "    texts: List[str],\n",
    "    *,\n",
    "    outcome_model: Optional[Any] = None,\n",
    "    ps_model: Optional[Any] = None,\n",
    "    n_splits: int = 5,\n",
    "    clip_eps: float = 1e-3,\n",
    "    random_state: int = 0,\n",
    ") -> ATEResult:\n",
    "    \"\"\"\n",
    "    Doubly robust AIPW estimate of ATE for binary outcome Y and binary treatment T,\n",
    "    with text covariates X (raw texts). Uses cross-fitting by default.\n",
    "\n",
    "    Returns ATE on the probability scale: E[Y(1) - Y(0)].\n",
    "    \"\"\"\n",
    "    y = np.asarray(y).reshape(-1).astype(int)\n",
    "    t = np.asarray(t).reshape(-1).astype(int)\n",
    "\n",
    "    if len(texts) != len(y):\n",
    "        raise ValueError(\"texts must have the same length as y.\")\n",
    "    if not set(np.unique(y)).issubset({0, 1}):\n",
    "        raise ValueError(\"y must be binary in {0,1}.\")\n",
    "    if not set(np.unique(t)).issubset({0, 1}):\n",
    "        raise ValueError(\"t must be binary in {0,1}.\")\n",
    "\n",
    "    # Default: TF-IDF -> LogisticRegression for both nuisance models\n",
    "    if outcome_model is None:\n",
    "        outcome_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "    if ps_model is None:\n",
    "        ps_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "\n",
    "    n = len(y)\n",
    "    mu1_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=1, X)\n",
    "    mu0_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=0, X)\n",
    "    e_hat = np.zeros(n, dtype=float)    # P(T=1 | X)\n",
    "\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "    idx = np.arange(n)\n",
    "\n",
    "    for tr, te in kf.split(idx):\n",
    "        texts_tr = [texts[i] for i in tr]\n",
    "        texts_te = [texts[i] for i in te]\n",
    "        y_tr, t_tr = y[tr], t[tr]\n",
    "\n",
    "        # 1) propensity model e(x)\n",
    "        ps = clone(ps_model).fit(texts_tr, t_tr)\n",
    "        e_hat[te] = ps.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "        # 2) outcome models mu1(x), mu0(x) fitted on their respective groups\n",
    "        tr1 = tr[t_tr == 1]\n",
    "        tr0 = tr[t_tr == 0]\n",
    "        if len(tr1) == 0 or len(tr0) == 0:\n",
    "            raise ValueError(\"A fold has no treated or no control samples. Reduce n_splits or check data balance.\")\n",
    "\n",
    "        om1 = clone(outcome_model).fit([texts[i] for i in tr1], y[tr1])\n",
    "        om0 = clone(outcome_model).fit([texts[i] for i in tr0], y[tr0])\n",
    "\n",
    "        mu1_hat[te] = om1.predict_proba(texts_te)[:, 1]\n",
    "        mu0_hat[te] = om0.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "    e_hat = _clip_ps(e_hat, clip_eps)\n",
    "\n",
    "    # AIPW influence function\n",
    "    phi = (mu1_hat - mu0_hat) + (t * (y - mu1_hat) / e_hat) - ((1 - t) * (y - mu0_hat) / (1 - e_hat))\n",
    "    ate = float(np.mean(phi))\n",
    "    se = _robust_se(phi - ate)\n",
    "    return ATEResult(ate=ate, se=se, ci95=_ci95(ate, se),\n",
    "                     details={\"e_hat\": e_hat, \"mu1_hat\": mu1_hat, \"mu0_hat\": mu0_hat, \"phi\": phi})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f3ff54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experimental conditions present in the label data.\n",
    "pd.unique(df[\"experiment\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "343e7660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each explanation condition (treatment, t=1) is compared with the plain \"machine\"\n",
    "# condition (control, t=0); the outcome is whether the participant's label was correct.\n",
    "conditions = ['machine_and_heatmap',\n",
    "       'machine_and_examples', 'machine_and_random_heatmap']\n",
    "\n",
    "ate_result = {}\n",
    "for condition in conditions:\n",
    "    pair_df = df.loc[df[\"experiment\"].isin([condition, \"machine\"])]\n",
    "    x = pair_df[\"review_text\"].tolist()\n",
    "    y = pair_df[\"actual_label\"].eq(pair_df[\"user_label\"]).tolist()\n",
    "    t = pair_df[\"experiment\"].eq(condition).astype(int).tolist()\n",
    "    ate_result[condition] = aipw_ate_text_binary(y, t, x, n_splits=5)\n",
    "ate_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55469158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Behavioral-agent (BA) accuracy: per-trial correctness of participants in each\n",
    "# condition, stored next to the RA values at the matching `result` index.\n",
    "ba_conditions = {\n",
    "    0: (\"ba_no_explanation\", \"machine\"),\n",
    "    1: (\"ba_explanation\", \"machine_and_heatmap\"),\n",
    "    2: (\"ba_explanation\", \"machine_and_random_heatmap\"),\n",
    "    3: (\"ba_explanation\", \"machine_and_examples\"),\n",
    "}\n",
    "for idx, (key, experiment) in ba_conditions.items():\n",
    "    trials = df.loc[df[\"experiment\"] == experiment]\n",
    "    result[idx][key] = (trials[\"actual_label\"] == trials[\"user_label\"]).to_numpy()\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54495033",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import bootstrap\n",
    "\n",
    "# Extract data for plotting and compute 95% CI with bootstrap\n",
    "\n",
    "def mean_and_95ci(arr):\n",
    "    # Convert boolean to int if needed\n",
    "    arr = np.array(arr).astype(float)\n",
    "    res = bootstrap((arr,), np.mean, confidence_level=0.95, n_resamples=1000, method=\"basic\")\n",
    "    lower = res.confidence_interval.low\n",
    "    upper = res.confidence_interval.high\n",
    "    mean = np.mean(arr)\n",
    "    # Return mean, lower, upper\n",
    "    return mean, lower, upper\n",
    "\n",
    "# Baseline (no-explanation) quantities exist only in result[0].\n",
    "ra_x = [mean_and_95ci(result[0][\"ra_x\"])]\n",
    "ra_yh = [mean_and_95ci(result[0][\"ra_yh\"])]\n",
    "# Explanation quantities exist only for result[1:]; slot 0 is None so that\n",
    "# list position i lines up with condition i.\n",
    "ra_yh_expl = [mean_and_95ci(entry[\"ra_yh_expl\"]) if pos else None for pos, entry in enumerate(result)]\n",
    "ra_expl = [mean_and_95ci(entry[\"ra_expl\"]) if pos else None for pos, entry in enumerate(result)]\n",
    "ba_no_explanation = [mean_and_95ci(result[0][\"ba_no_explanation\"])]\n",
    "ba_explanation = [mean_and_95ci(entry[\"ba_explanation\"]) if pos else None for pos, entry in enumerate(result)]\n",
    "\n",
    "ra_null = [mean_and_95ci(result[0][\"ra_null\"])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92394e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "import numpy as np\n",
    "\n",
    "# Prepare data for plotting.\n",
    "# Every series below is a list of (mean, lower, upper) tuples indexed by\n",
    "# experiment condition (row i of each panel = CONDITION_LABELS[i]); a None\n",
    "# entry, or a list shorter than the number of conditions, is not plotted.\n",
    "CONDITION_LABELS = [\"W/o expl.\", \"Heatmap\", \"Random\\nheatmap\", \"Example\"]\n",
    "num_conditions = len(CONDITION_LABELS)\n",
    "\n",
    "# Top panel: rational-agent (RA) benchmarks vs. behavioral (BA) accuracy.\n",
    "values_for_plot = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^AI + expl)\", ra_expl),\n",
    "    (\"RA (w/ prior)\", ra_null),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation),\n",
    "    (\"BA (w/ expl)\", ba_explanation),\n",
    "]\n",
    "\n",
    "# Bottom panel: RA that observes the human label y^H (with and without the\n",
    "# explanation signals), with RA (w/ x) for reference.\n",
    "values_for_plot2 = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^H)\", ra_yh),\n",
    "    (\"RA (w/ y^H + y^AI + expl)\", ra_yh_expl),\n",
    "]\n",
    "\n",
    "# Colors from the ColorBrewer Set1 qualitative palette, plus olive for BA (w/ expl).\n",
    "color_map = {\n",
    "    \"RA (w/ x)\": \"#e41a1c\",                  # red\n",
    "    \"RA (w/ x+y^H)\": \"#999999\",              # grey\n",
    "    \"RA (w/ y^H)\": \"#4daf4a\",                # green\n",
    "    \"RA (w/ y^H + y^AI)\": \"#377eb8\",         # blue\n",
    "    \"RA (w/ y^H + y^AI + expl)\": \"#ff7f00\",  # orange\n",
    "    \"RA (w/ prior)\": \"#a65628\",              # brown\n",
    "    \"RA (w/ y^AI)\": \"#f781bf\",               # pink\n",
    "    \"RA (w/ y^AI + expl)\": \"#984ea3\",        # purple\n",
    "    \"BA (w/o expl)\": \"#999999\",              # grey\n",
    "    \"BA (w/ expl)\": \"#999900\",               # olive\n",
    "}\n",
    "\n",
    "# Optional legend (off by default): entries are de-duplicated and shown in\n",
    "# this order; labels not listed here are omitted.\n",
    "SHOW_LEGEND = False\n",
    "legend_order = [\n",
    "    \"RA (w/ prior)\",\n",
    "    \"RA (w/ y^H)\",\n",
    "    \"RA (w/ y^AI)\",\n",
    "    \"RA (w/ y^H + y^AI)\",\n",
    "    \"RA (w/ y^AI + expl)\",\n",
    "    \"RA (w/ y^H + y^AI + expl)\",\n",
    "    \"RA (w/ x)\",\n",
    "]\n",
    "\n",
    "GROUP_WIDTH = 0.7  # vertical room per condition row; density sketches use GROUP_WIDTH / 2.5\n",
    "\n",
    "\n",
    "def plot_metric_rows(ax, series):\n",
    "    \"\"\"Plot each (name, per-condition stats) series on `ax`, one row per condition.\n",
    "\n",
    "    Each available (mean, lower, upper) gets a point with a horizontal CI bar and,\n",
    "    when the CI is wider than 1e-3, a normal density sketch with std = CI width / 3.92.\n",
    "    \"\"\"\n",
    "    for name, stats_by_condition in series:\n",
    "        color = color_map[name]\n",
    "        for row in range(num_conditions):\n",
    "            if row >= len(stats_by_condition) or stats_by_condition[row] is None:\n",
    "                continue\n",
    "            mean, lower, upper = stats_by_condition[row]\n",
    "            ci_width = upper - lower\n",
    "            if ci_width > 1e-3:\n",
    "                std = ci_width / (2 * 1.96)  # approximate std assuming a 95% normal CI\n",
    "                x_range = np.linspace(max(0, mean - 4 * std), min(1.0, mean + 4 * std), 200)\n",
    "                density = norm.pdf(x_range, loc=mean, scale=std)\n",
    "                density = density / density.max() * (GROUP_WIDTH / 2.5)  # scale height for visual clarity\n",
    "                ax.fill_between(x_range, row, row + density, alpha=0.7, color=color, label=None)\n",
    "            ax.errorbar([mean], [row], xerr=[[mean - lower], [upper - mean]], fmt='o', label=name,\n",
    "                        capsize=0, color=color, elinewidth=3)\n",
    "\n",
    "\n",
    "def add_ordered_legend(ax):\n",
    "    \"\"\"Add a de-duplicated legend to `ax`, ordered by `legend_order`; no-op if nothing matches.\"\"\"\n",
    "    label_to_handle = {}\n",
    "    for handle, label in zip(*ax.get_legend_handles_labels()):\n",
    "        label_to_handle.setdefault(label, handle)\n",
    "    ordered = [(label_to_handle[label], label) for label in legend_order if label in label_to_handle]\n",
    "    if not ordered:\n",
    "        return\n",
    "    handles, labels = zip(*ordered)\n",
    "    ax.legend(handles, labels, loc=\"upper center\", bbox_to_anchor=(0.5, 1.1), ncol=6,\n",
    "              frameon=False, borderaxespad=0.01, handletextpad=0.2, labelspacing=0.01)\n",
    "\n",
    "\n",
    "# sharex=True already links the two x-axes, so the limits set on ax2 below apply to\n",
    "# ax1 as well; an extra ax1.sharex(ax2) call would only create a circular share.\n",
    "fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(5, 5))\n",
    "\n",
    "for ax in (ax1, ax2):\n",
    "    for row in range(num_conditions):\n",
    "        ax.axhline(y=row, color='#101010', linestyle='-', linewidth=1, alpha=0.5)\n",
    "\n",
    "# Dashed reference lines at the no-explanation point estimates. [0][0] picks the\n",
    "# mean out of the (mean, lower, upper) tuple; averaging the whole tuple would\n",
    "# pull the line toward the CI midpoint.\n",
    "ax1.axvline(x=ra_null[0][0], color=color_map[\"RA (w/ prior)\"], linestyle='--', linewidth=1)\n",
    "ax1.axvline(x=ba_no_explanation[0][0], color=color_map[\"BA (w/o expl)\"], linestyle='--', linewidth=1)\n",
    "ax1.axvline(x=ra_x[0][0], color=color_map[\"RA (w/ x)\"], linestyle='--', linewidth=1)\n",
    "\n",
    "ax2.axvline(x=ba_no_explanation[0][0], color=color_map[\"BA (w/o expl)\"], linestyle='--', linewidth=1)\n",
    "ax2.axvline(x=ra_yh[0][0], color=color_map[\"RA (w/ y^H)\"], linestyle='--', linewidth=1)\n",
    "ax2.axvline(x=ra_x[0][0], color=color_map[\"RA (w/ x)\"], linestyle='--', linewidth=1)\n",
    "\n",
    "plot_metric_rows(ax1, values_for_plot)\n",
    "plot_metric_rows(ax2, values_for_plot2)\n",
    "\n",
    "for ax in (ax1, ax2):\n",
    "    ax.set_yticks(range(num_conditions))\n",
    "    ax.set_yticklabels(CONDITION_LABELS)\n",
    "    for side in (\"top\", \"left\", \"right\"):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    if SHOW_LEGEND:\n",
    "        add_ordered_legend(ax)\n",
    "\n",
    "ax2.set_xlim(0.45, 1)\n",
    "ax2.set_xlabel(\"Accuracy\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('laiandtan_explanation_means.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
