{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f610e7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from xgboost import XGBRegressor\n",
    "import xgboost as xgb\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error, r2_score\n",
    "from tqdm import tqdm\n",
    "from itertools import combinations\n",
    "import shap\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "from sklearn.linear_model import LinearRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc834ee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the parent directory to the import path (presumably where decision_infovalue lives).\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import decision_infovalue\n",
    "import importlib\n",
    "# Re-execute the module so edits to its source are picked up without restarting the kernel.\n",
    "importlib.reload(decision_infovalue)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "186358ff",
   "metadata": {},
   "source": [
    "# Amazon Books Review"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a6929c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Amazon book reviews and tag each row with its 0-based position;\n",
    "# the decision log is joined to this id (via `questionId`) in the next cell.\n",
    "review_df = pd.read_csv(\"save_df_amazon.csv\")\n",
    "n_reviews = len(review_df)\n",
    "review_df.loc[:, \"review_id\"] = range(n_reviews)\n",
    "review_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "715ae966",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Review-level columns to attach to every decision row. The review table's own\n",
    "# \"pred\"/\"conf\" are dropped so the decision log's same-named columns stay un-suffixed.\n",
    "review_features = review_df.drop(columns=[\"pred\", \"conf\"])\n",
    "\n",
    "df = pd.read_csv(\"bansal-et-al-decision-result-filter.csv\")\n",
    "df = df.loc[df[\"task\"] == \"amzbook\"]\n",
    "\n",
    "# human_df: only the \"Conf.\" condition; df: every condition of the amzbook task.\n",
    "human_df = df.loc[df[\"condition\"] == \"Conf.\"]\n",
    "human_df = pd.merge(human_df, review_features, left_on=\"questionId\", right_on=\"review_id\", how=\"left\")\n",
    "df = pd.merge(df, review_features, left_on=\"questionId\", right_on=\"review_id\", how=\"left\")\n",
    "\n",
    "# Replace whole columns with df[col] = ... rather than df.loc[:, col] = ...:\n",
    "# the .loc form writes into the existing column in place and can keep the old\n",
    "# dtype, which would silently undo the astype(int) casts.\n",
    "human_df[\"choice\"] = human_df[\"choice\"].astype(int)\n",
    "# Truncate confidence to one decimal place (e.g. 0.87 -> 0.8).\n",
    "human_df[\"conf\"] = human_df[\"conf\"].apply(lambda x: int(x * 10) / 10)\n",
    "human_df[\"y\"] = human_df[\"y\"].astype(int)\n",
    "human_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6aac0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "info_model = decision_infovalue.DecisionInfoModel(\n",
    "    human_df,\n",
    "    \"y\",\n",
    "    signals=[\n",
    "        \"x_cluster_id\", \"choice\", \"conf\",\n",
    "        \"explanation_top1_cluster_id\", \"explanation_top2_cluster_id\",\n",
    "        \"explanation_adaptive_cluster_id\", \"explanation_expert_cluster_id\",\n",
    "        \"pred\",\n",
    "    ],\n",
    "    scoring_rule=\"threshold_accuracy\",\n",
    "    binning_method=None,\n",
    ")\n",
    "\n",
    "\n",
    "def signal_value(signals):\n",
    "    \"\"\"Call info_model.complement_info_value(signals, None, ret_orig_value=True).\"\"\"\n",
    "    return info_model.complement_info_value(signals, None, ret_orig_value=True)\n",
    "\n",
    "\n",
    "# result[0]: values that do not depend on an explanation type.\n",
    "result = [{\n",
    "    \"ra_null\": signal_value([]),\n",
    "    \"ra_x\": signal_value([\"x_cluster_id\"]),\n",
    "    \"ra_yh\": signal_value([\"choice\"]),\n",
    "}]\n",
    "\n",
    "# result[1..4]: one entry per explanation column, in the order top-1, top-2, adaptive, expert.\n",
    "explanation_columns = [\n",
    "    \"explanation_top1_cluster_id\",\n",
    "    \"explanation_top2_cluster_id\",\n",
    "    \"explanation_adaptive_cluster_id\",\n",
    "    \"explanation_expert_cluster_id\",\n",
    "]\n",
    "for explanation_column in explanation_columns:\n",
    "    result.append({\n",
    "        \"ra_expl\": signal_value([explanation_column]),\n",
    "        \"ra_yh_expl\": signal_value([\"choice\", explanation_column]),\n",
    "    })\n",
    "\n",
    "result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db28e1bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple\n",
    "\n",
    "from sklearn.base import clone\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ATEResult:\n",
    "    \"\"\"Point estimate, standard error, 95% interval and diagnostics of an ATE fit.\"\"\"\n",
    "    ate: float                 # mean of the per-sample AIPW scores\n",
    "    se: float                  # standard error of that mean\n",
    "    ci95: Tuple[float, float]  # (ate - 1.96 * se, ate + 1.96 * se)\n",
    "    details: Dict[str, Any]    # per-sample arrays: e_hat, mu1_hat, mu0_hat, phi\n",
    "\n",
    "\n",
    "def _clip_ps(ps: np.ndarray, eps: float = 1e-3) -> np.ndarray:\n",
    "    \"\"\"Clamp propensity scores to [eps, 1 - eps] so the inverse weights stay bounded.\"\"\"\n",
    "    return np.clip(ps, eps, 1 - eps)\n",
    "\n",
    "\n",
    "def _robust_se(phi: np.ndarray) -> float:\n",
    "    \"\"\"Standard error of the mean of `phi`: sqrt(sample variance (ddof=1) / n).\"\"\"\n",
    "    n = phi.shape[0]\n",
    "    return float(np.sqrt(np.var(phi, ddof=1) / n))\n",
    "\n",
    "\n",
    "def _ci95(est: float, se: float) -> Tuple[float, float]:\n",
    "    \"\"\"Normal-approximation 95% interval: est +/- 1.96 * se.\"\"\"\n",
    "    z = 1.96\n",
    "    return (est - z * se, est + z * se)\n",
    "\n",
    "\n",
    "def _default_text_classifier() -> Pipeline:\n",
    "    \"\"\"TF-IDF (uni- and bigrams) -> logistic regression; default for both nuisance models.\"\"\"\n",
    "    return Pipeline([\n",
    "        (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "        (\"clf\", LogisticRegression(max_iter=5000))\n",
    "    ])\n",
    "\n",
    "\n",
    "def aipw_ate_text_binary(\n",
    "    y: np.ndarray,\n",
    "    t: np.ndarray,\n",
    "    texts: List[str],\n",
    "    *,\n",
    "    outcome_model: Optional[Any] = None,\n",
    "    ps_model: Optional[Any] = None,\n",
    "    n_splits: int = 5,\n",
    "    clip_eps: float = 1e-3,\n",
    "    random_state: int = 0,\n",
    ") -> ATEResult:\n",
    "    \"\"\"\n",
    "    Doubly robust AIPW estimate of ATE for binary outcome Y and binary treatment T,\n",
    "    with text covariates X (raw texts). Nuisance models are cross-fitted with KFold.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    y, t : array-likes of 0/1 (or bool) values, one entry per sample.\n",
    "    texts : raw text per sample, same length as y.\n",
    "    outcome_model, ps_model : unfitted estimators exposing fit / predict_proba on\n",
    "        lists of texts; each is cloned per fold. Default: TF-IDF -> LogisticRegression.\n",
    "    n_splits : number of cross-fitting folds.\n",
    "    clip_eps : propensity scores are clipped to [clip_eps, 1 - clip_eps].\n",
    "    random_state : seed for the shuffled KFold split.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ATEResult with the ATE on the probability scale, E[Y(1) - Y(0)].\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if lengths disagree, y or t is not binary, clip_eps is outside\n",
    "    [0, 0.5), or a training fold lacks a treatment arm or an outcome class within an arm.\n",
    "    \"\"\"\n",
    "    y = np.asarray(y).reshape(-1).astype(int)\n",
    "    t = np.asarray(t).reshape(-1).astype(int)\n",
    "\n",
    "    if len(texts) != len(y):\n",
    "        raise ValueError(\"texts must have the same length as y.\")\n",
    "    # Without this check a mismatched t only fails later with an indexing/broadcast error.\n",
    "    if len(t) != len(y):\n",
    "        raise ValueError(\"t must have the same length as y.\")\n",
    "    if not set(np.unique(y)).issubset({0, 1}):\n",
    "        raise ValueError(\"y must be binary in {0,1}.\")\n",
    "    if not set(np.unique(t)).issubset({0, 1}):\n",
    "        raise ValueError(\"t must be binary in {0,1}.\")\n",
    "    if not 0 <= clip_eps < 0.5:\n",
    "        raise ValueError(\"clip_eps must satisfy 0 <= clip_eps < 0.5.\")\n",
    "\n",
    "    if outcome_model is None:\n",
    "        outcome_model = _default_text_classifier()\n",
    "    if ps_model is None:\n",
    "        ps_model = _default_text_classifier()\n",
    "\n",
    "    n = len(y)\n",
    "    mu1_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=1, X)\n",
    "    mu0_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=0, X)\n",
    "    e_hat = np.zeros(n, dtype=float)    # P(T=1 | X)\n",
    "\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "    idx = np.arange(n)\n",
    "\n",
    "    for tr, te in kf.split(idx):\n",
    "        texts_tr = [texts[i] for i in tr]\n",
    "        texts_te = [texts[i] for i in te]\n",
    "        t_tr = t[tr]\n",
    "\n",
    "        # Check arm coverage before fitting anything, so the caller gets this\n",
    "        # message instead of an error raised from inside a classifier.\n",
    "        tr1 = tr[t_tr == 1]\n",
    "        tr0 = tr[t_tr == 0]\n",
    "        if len(tr1) == 0 or len(tr0) == 0:\n",
    "            raise ValueError(\"A fold has no treated or no control samples. Reduce n_splits or check data balance.\")\n",
    "        # predict_proba(...)[:, 1] below needs both outcome classes in each arm.\n",
    "        if np.unique(y[tr1]).size < 2 or np.unique(y[tr0]).size < 2:\n",
    "            raise ValueError(\"A fold has a treatment arm with a single outcome class. Reduce n_splits or check data balance.\")\n",
    "\n",
    "        # 1) propensity model e(x), predicted on the held-out fold\n",
    "        ps = clone(ps_model).fit(texts_tr, t_tr)\n",
    "        e_hat[te] = ps.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "        # 2) outcome models mu1(x), mu0(x), each fitted on its own treatment arm\n",
    "        om1 = clone(outcome_model).fit([texts[i] for i in tr1], y[tr1])\n",
    "        om0 = clone(outcome_model).fit([texts[i] for i in tr0], y[tr0])\n",
    "\n",
    "        mu1_hat[te] = om1.predict_proba(texts_te)[:, 1]\n",
    "        mu0_hat[te] = om0.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "    e_hat = _clip_ps(e_hat, clip_eps)\n",
    "\n",
    "    # AIPW score per sample: outcome-model contrast plus inverse-propensity-weighted residuals\n",
    "    phi = (mu1_hat - mu0_hat) + (t * (y - mu1_hat) / e_hat) - ((1 - t) * (y - mu0_hat) / (1 - e_hat))\n",
    "    ate = float(np.mean(phi))\n",
    "    se = _robust_se(phi - ate)\n",
    "    return ATEResult(ate=ate, se=se, ci95=_ci95(ate, se),\n",
    "                     details={\"e_hat\": e_hat, \"mu1_hat\": mu1_hat, \"mu0_hat\": mu0_hat, \"phi\": phi})\n",
    "\n",
    "conditions = ['Conf.+Double', 'Conf.+Single', 'Conf.+Adaptive',\n",
    "       'Conf.+Adaptive (Expert)']\n",
    "\n",
    "ate_result = {}\n",
    "\n",
    "for condition in conditions:\n",
    "    # Treated = rows of this condition; control = rows of the \"Conf.\" condition.\n",
    "    df_condition = df.loc[df[\"condition\"].isin([condition, \"Conf.\"]), :]\n",
    "    x = df_condition[\"text\"].tolist()\n",
    "    # Outcome: whether the recorded choice equals y.\n",
    "    y = (df_condition[\"y\"] == df_condition[\"choice\"]).tolist()\n",
    "    t = (df_condition[\"condition\"] == condition).astype(int).tolist()\n",
    "    res = aipw_ate_text_binary(y, t, x, n_splits=5)\n",
    "    ate_result[condition] = res\n",
    "ate_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8acf8c94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decision_correct(condition_name):\n",
    "    \"\"\"Boolean array over the rows of `condition_name` in df: True where choice equals y.\"\"\"\n",
    "    rows = df.loc[df[\"condition\"] == condition_name]\n",
    "    return (rows[\"y\"] == rows[\"choice\"]).to_numpy()\n",
    "\n",
    "\n",
    "# result[0] holds the no-explanation condition; result[1..4] pair with the\n",
    "# explanation conditions in the order listed here.\n",
    "result[0][\"ba_no_explanation\"] = decision_correct(\"Conf.\")\n",
    "explanation_conditions = [\"Conf.+Single\", \"Conf.+Double\", \"Conf.+Adaptive\", \"Conf.+Adaptive (Expert)\"]\n",
    "for slot, condition_name in enumerate(explanation_conditions, start=1):\n",
    "    result[slot][\"ba_explanation\"] = decision_correct(condition_name)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d232e80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import bootstrap\n",
    "\n",
    "\n",
    "def mean_and_95ci(arr, random_state=0):\n",
    "    \"\"\"Return (mean, lower, upper): the sample mean and a 95% basic-bootstrap CI for it.\n",
    "\n",
    "    `arr` is converted to a float array first (so booleans become 0.0/1.0).\n",
    "    `random_state` seeds the resampling; with the fixed default, re-running the\n",
    "    notebook reproduces the same intervals instead of drawing new ones each time.\n",
    "    \"\"\"\n",
    "    arr = np.array(arr).astype(float)\n",
    "    res = bootstrap((arr,), np.mean, confidence_level=0.95, n_resamples=1000,\n",
    "                    method=\"basic\", random_state=random_state)\n",
    "    lower = res.confidence_interval.low\n",
    "    upper = res.confidence_interval.high\n",
    "    return np.mean(arr), lower, upper\n",
    "\n",
    "\n",
    "# Metrics stored only in result[0] become one-element lists; metrics stored in\n",
    "# result[1:] get None at index 0 so that list position matches the result index.\n",
    "ra_x = [mean_and_95ci(result[0][\"ra_x\"])]\n",
    "ra_yh = [mean_and_95ci(result[0][\"ra_yh\"])]\n",
    "ra_yh_expl = [mean_and_95ci(entry[\"ra_yh_expl\"]) if i > 0 else None for i, entry in enumerate(result)]\n",
    "ra_expl = [mean_and_95ci(entry[\"ra_expl\"]) if i > 0 else None for i, entry in enumerate(result)]\n",
    "ba_explanation = [mean_and_95ci(entry[\"ba_explanation\"]) if i > 0 else None for i, entry in enumerate(result)]\n",
    "ba_no_explanation = [mean_and_95ci(entry[\"ba_no_explanation\"]) if i == 0 else None for i, entry in enumerate(result)]\n",
    "\n",
    "ra_null = [mean_and_95ci(result[0][\"ra_null\"])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a9237a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "import numpy as np\n",
    "\n",
    "# Each series is (label, list indexed by row); an entry is (mean, lower, upper) or None.\n",
    "# Upper panel.\n",
    "values_for_plot = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^AI + expl)\", ra_expl),\n",
    "    (\"RA (w/ prior)\", ra_null),\n",
    "    (\"BA (w/ expl)\", ba_explanation),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation)\n",
    "]\n",
    "\n",
    "# Lower panel.\n",
    "values_for_plot2 = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^H)\", ra_yh),\n",
    "    (\"RA (w/ y^H + y^AI + expl)\", ra_yh_expl)\n",
    "]\n",
    "\n",
    "# Order of legend entries when SHOW_LEGEND is enabled; labels not listed are omitted.\n",
    "legend_order = [\n",
    "    \"RA (w/ prior)\",\n",
    "    \"RA (w/ y^H)\",\n",
    "    \"RA (w/ y^AI)\",\n",
    "    \"RA (w/ y^H + y^AI)\",\n",
    "    \"RA (w/ y^AI + expl)\",\n",
    "    \"RA (w/ y^H + y^AI + expl)\",\n",
    "    \"RA (w/ x)\",\n",
    "]\n",
    "SHOW_LEGEND = False\n",
    "\n",
    "color_map = {\n",
    "    \"RA (w/ x)\": \"#e41a1c\",                  # red\n",
    "    \"RA (w/ x+y^H)\": \"#999999\",              # grey\n",
    "    \"RA (w/ y^H)\": \"#4daf4a\",                # green\n",
    "    \"RA (w/ y^H + y^AI)\": \"#377eb8\",         # blue\n",
    "    \"RA (w/ y^H + y^AI + expl)\": \"#ff7f00\",  # orange\n",
    "    \"RA (w/ prior)\": \"#a65628\",              # brown\n",
    "    \"RA (w/ y^AI)\": \"#f781bf\",               # pink\n",
    "    \"RA (w/ y^AI + expl)\": \"#984ea3\",        # purple\n",
    "    \"BA (w/o expl)\": \"#999999\",              # grey\n",
    "    \"BA (w/ expl)\": \"#999900\"                # olive\n",
    "}\n",
    "\n",
    "# One row per entry of `result`, in the same order.\n",
    "condition_labels = [\"No expl.\", \"Explain-top-1\", \"Explain-top-2\", \"Adaptive\", \"Expert\"]\n",
    "num_conditions = 5\n",
    "group_width = 0.7  # controls the height of the density bumps\n",
    "\n",
    "\n",
    "def draw_reference_lines(ax, references):\n",
    "    \"\"\"Draw a dashed vertical line at the mean of the first entry of each (label, stats) pair.\"\"\"\n",
    "    for label, stats in references:\n",
    "        # stats[0] is (mean, lower, upper); index 0 is the mean itself. Averaging the\n",
    "        # whole tuple would shift the line whenever the interval is asymmetric.\n",
    "        ax.axvline(x=stats[0][0], color=color_map[label], linestyle='--', linewidth=1)\n",
    "\n",
    "\n",
    "def draw_metric_rows(ax, series):\n",
    "    \"\"\"For every non-None entry draw a point with its CI bar and a scaled normal bump on its row.\"\"\"\n",
    "    for name, data in series:\n",
    "        for row, entry in enumerate(data[:num_conditions]):\n",
    "            if entry is None:\n",
    "                continue\n",
    "            mean, lower, upper = entry\n",
    "\n",
    "            ci_width = upper - lower\n",
    "            if ci_width > 1e-3:\n",
    "                # Normal curve whose std is backed out of the CI width (95% CI ~ +/- 1.96 std),\n",
    "                # restricted to [0, 1] and rescaled to a fixed peak height.\n",
    "                std = ci_width / (2 * 1.96)\n",
    "                x_range = np.linspace(max(0, mean - 4 * std), min(1.0, mean + 4 * std), 200)\n",
    "                y_norm = norm.pdf(x_range, loc=mean, scale=std)\n",
    "                y_norm = y_norm / y_norm.max() * (group_width / 2.5)\n",
    "                ax.fill_between(x_range, row, row + y_norm, alpha=0.7, color=color_map[name], label=None)\n",
    "\n",
    "            # errorbar rejects negative half-widths, so clamp them at 0.\n",
    "            half_widths = [[max(mean - lower, 0.0)], [max(upper - mean, 0.0)]]\n",
    "            ax.errorbar([mean], [row], xerr=half_widths, fmt='o', label=name, capsize=0, color=color_map[name], elinewidth=3)\n",
    "    ax.set_yticks(range(num_conditions))\n",
    "\n",
    "\n",
    "def add_ordered_legend(ax):\n",
    "    \"\"\"Attach a de-duplicated legend ordered by `legend_order`; do nothing if no label matches.\"\"\"\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    label_to_handle = {}\n",
    "    for handle, label in zip(handles, labels):\n",
    "        label_to_handle.setdefault(label, handle)  # keep the first handle per label\n",
    "    ordered = [(label_to_handle[label], label) for label in legend_order if label in label_to_handle]\n",
    "    if not ordered:\n",
    "        return\n",
    "    ordered_handles, ordered_labels = zip(*ordered)\n",
    "    ax.legend(\n",
    "        ordered_handles,\n",
    "        ordered_labels,\n",
    "        loc=\"upper center\",\n",
    "        bbox_to_anchor=(0.5, 1.1),\n",
    "        ncol=6,\n",
    "        frameon=False,\n",
    "        borderaxespad=0.01,\n",
    "        handletextpad=0.2,\n",
    "        labelspacing=0.01\n",
    "    )\n",
    "\n",
    "\n",
    "# sharex=True already links the two x-axes, so no separate sharex() call is needed.\n",
    "fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(5, 5))\n",
    "ax1, ax2 = axes\n",
    "\n",
    "# (axes, series to draw, reference lines to draw)\n",
    "panels = [\n",
    "    (ax1, values_for_plot, [(\"RA (w/ prior)\", ra_null), (\"RA (w/ x)\", ra_x), (\"BA (w/o expl)\", ba_no_explanation)]),\n",
    "    (ax2, values_for_plot2, [(\"RA (w/ y^H)\", ra_yh), (\"RA (w/ x)\", ra_x)]),\n",
    "]\n",
    "\n",
    "for ax, series, references in panels:\n",
    "    # Baseline for each row.\n",
    "    for row in range(num_conditions):\n",
    "        ax.axhline(y=row, color='#101010', linestyle='-', linewidth=1, alpha=0.5)\n",
    "    draw_reference_lines(ax, references)\n",
    "    draw_metric_rows(ax, series)\n",
    "    if SHOW_LEGEND:\n",
    "        add_ordered_legend(ax)\n",
    "    for side in ('top', 'left', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    ax.set_yticklabels(condition_labels)\n",
    "\n",
    "ax2.set_xlim(0.45, 1)\n",
    "ax2.set_xlabel(\"Accuracy\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('bansal_amzbook_explanation_means.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ef6ade",
   "metadata": {},
   "source": [
    "# Beer Review"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73cf1e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "review_df = pd.read_csv(\"save_df_beer.csv\")\n",
    "review_df.loc[:, \"review_id\"] = range(len(review_df))\n",
    "\n",
    "df = pd.read_csv(\"bansal-et-al-decision-result-filter.csv\")\n",
    "df = df.loc[df[\"task\"] == \"beer\"]\n",
    "human_df = df.loc[df[\"condition\"] == \"Conf.\"]\n",
    "df = pd.merge(df, review_df.drop(columns=[\"pred\", \"conf\"]), left_on=\"questionId\", right_on=\"review_id\", how=\"left\")\n",
    "human_df = pd.merge(human_df, review_df.drop(columns=[\"pred\", \"conf\"]), left_on=\"questionId\", right_on=\"review_id\", how=\"left\")\n",
    "human_df.loc[:, \"choice\"] = human_df[\"choice\"].astype(int)\n",
    "human_df.loc[:, \"conf\"] = human_df[\"conf\"].apply(lambda x: int(x * 10) / 10)\n",
    "human_df[\"y\"] = human_df[\"y\"].astype(int)\n",
    "\n",
    "result = []\n",
    "\n",
    "info_model = decision_infovalue.DecisionInfoModel(human_df, \"y\", \n",
    "                                    signals=[\"x_cluster_id\", \"choice\", \"conf\", \n",
    "                                    \"explanation_top1_cluster_id\", \"explanation_top2_cluster_id\", \"explanation_adaptive_cluster_id\", \"explanation_expert_cluster_id\",\n",
    "                                    \"pred\"], \n",
    "                                    scoring_rule=\"threshold_accuracy\",\n",
    "                                    binning_method=None)\n",
    "\n",
    "result.append({\n",
    "    \"ra_null\": info_model.complement_info_value([], None, ret_orig_value=True),\n",
    "    \"ra_x\": info_model.complement_info_value([\"x_cluster_id\"], None, ret_orig_value=True),\n",
    "    # \"ra_yai\": info_model.complement_info_value([\"pred\"], None, ret_orig_value=True),\n",
    "    # \"ra_expl_yai\": info_model.complement_info_value([\"explanation_col\", \"predicted_label\"], None, ret_orig_value=True),\n",
    "    \n",
    "    # \"ra_x_yh\": info_model.complement_info_value([\"x_cluster_id\", \"user_label\"], None, ret_orig_value=True),\n",
    "    \"ra_yh\": info_model.complement_info_value([\"choice\"], None, ret_orig_value=True),\n",
    "    # \"ra_yh_yai\": info_model.complement_info_value([\"choice\", \"pred\"], None, ret_orig_value=True),\n",
    "    # \"ra_yh_expl_yai\": info_model.complement_info_value([\"user_label\", \"explanation_col\", \"predicted_label\"], None, ret_orig_value=True),\n",
    "    # \"ra_x_yh\": info_model.complement_info_value([\"x_cluster_id\", \"choice\"], None, ret_orig_value=True)\n",
    "    # \"ra_expl\": info_model.complement_info_value([\"explanation_col\"], None, ret_orig_value=True)\n",
    "})\n",
    "\n",
    "# result.append({\n",
    "#     \"ra_expl_yai\": info_model.complement_info_value([\"conf\", \"pred\"], None, ret_orig_value=True),\n",
    "#     \"ra_yh_expl_yai\": info_model.complement_info_value([\"choice\", \"conf\", \"pred\"], None, ret_orig_value=True)\n",
    "# })\n",
    "\n",
    "result.append({\n",
    "    \"ra_expl\": info_model.complement_info_value([\"explanation_top1_cluster_id\"], None, ret_orig_value=True),\n",
    "    \"ra_yh_expl\": info_model.complement_info_value([\"choice\", \"explanation_top1_cluster_id\"], None, ret_orig_value=True)\n",
    "})\n",
    "result.append({\n",
    "    \"ra_expl\": info_model.complement_info_value([\"explanation_top2_cluster_id\"], None, ret_orig_value=True),\n",
    "    \"ra_yh_expl\": info_model.complement_info_value([\"choice\", \"explanation_top2_cluster_id\"], None, ret_orig_value=True)\n",
    "})\n",
    "result.append({\n",
    "    \"ra_expl\": info_model.complement_info_value([\"explanation_adaptive_cluster_id\"], None, ret_orig_value=True),\n",
    "    \"ra_yh_expl\": info_model.complement_info_value([\"choice\", \"explanation_adaptive_cluster_id\"], None, ret_orig_value=True)\n",
    "})\n",
    "result.append({\n",
    "    \"ra_expl\": info_model.complement_info_value([\"explanation_expert_cluster_id\"], None, ret_orig_value=True),\n",
    "    \"ra_yh_expl\": info_model.complement_info_value([\"choice\", \"explanation_expert_cluster_id\"], None, ret_orig_value=True)\n",
    "})\n",
    "\n",
    "result[0][\"ba_no_explanation\"] = (df.loc[(df[\"condition\"]==\"Conf.\"), \"y\"] == df.loc[(df[\"condition\"]==\"Conf.\"), \"choice\"]).to_numpy()\n",
    "result[1][\"ba_explanation\"] = (df.loc[(df[\"condition\"]==\"Conf.+Single\"), \"y\"] == df.loc[(df[\"condition\"]==\"Conf.+Single\"), \"choice\"]).to_numpy()\n",
    "result[2][\"ba_explanation\"] = (df.loc[(df[\"condition\"]==\"Conf.+Double\"), \"y\"] == df.loc[(df[\"condition\"]==\"Conf.+Double\"), \"choice\"]).to_numpy()\n",
    "result[3][\"ba_explanation\"] = (df.loc[(df[\"condition\"]==\"Conf.+Adaptive\"), \"y\"] == df.loc[(df[\"condition\"]==\"Conf.+Adaptive\"), \"choice\"]).to_numpy()\n",
    "result[4][\"ba_explanation\"] = (df.loc[(df[\"condition\"]==\"Conf.+Adaptive (Expert)\"), \"y\"] == df.loc[(df[\"condition\"]==\"Conf.+Adaptive (Expert)\"), \"choice\"]).to_numpy()\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import bootstrap\n",
    "\n",
    "# Extract data for plotting and compute 95% CI with bootstrap\n",
    "\n",
    "def mean_and_95ci(arr):\n",
    "    # Convert boolean to int if needed\n",
    "    arr = np.array(arr).astype(float)\n",
    "    res = bootstrap((arr,), np.mean, confidence_level=0.95, n_resamples=1000, method=\"basic\")\n",
    "    lower = res.confidence_interval.low\n",
    "    upper = res.confidence_interval.high\n",
    "    mean = np.mean(arr)\n",
    "    # Return mean, lower, upper\n",
    "    return mean, lower, upper\n",
    "\n",
    "# conditions = [\"control\", \"machine\", \"machine_and_random_heatmap\", \"machine_and_examples\", \"machine_and_heatmap\", \"machine_with_accuracy\"] # list(result.keys())\n",
    "# ra_x_yh = [mean_and_95ci(result[0][\"ra_x_yh\"])]\n",
    "ra_x = [mean_and_95ci(result[0][\"ra_x\"])]\n",
    "ra_yh = [mean_and_95ci(result[0][\"ra_yh\"])]\n",
    "# ra_yh_yai = [mean_and_95ci(result[0][\"ra_yh_yai\"])]\n",
    "ra_yh_expl = [mean_and_95ci(result[i][\"ra_yh_expl\"]) if i > 0 else None for i in range(len(result))]\n",
    "ra_expl = [mean_and_95ci(result[i][\"ra_expl\"]) if i > 0 else None for i in range(len(result))]\n",
    "# ra_yai = [mean_and_95ci(result[0][\"ra_yai\"])]\n",
    "ba_explanation = [mean_and_95ci(result[i][\"ba_explanation\"]) if i > 0 else None for i in range(len(result))]\n",
    "ba_no_explanation = [mean_and_95ci(result[i][\"ba_no_explanation\"]) if i == 0 else None for i in range(len(result))]\n",
    "# ba = [mean_and_95ci(result[condition][\"ba\"]) for condition in conditions]\n",
    "\n",
    "ra_null = [mean_and_95ci(result[0][\"ra_null\"])]\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "import numpy as np\n",
    "\n",
    "# Prepare data for plotting\n",
    "# We want to plot: ra_x_yh, ra_x, ra_yh, ba_explanation, ba_no_explanation, ra_null\n",
    "legend_order = [\n",
    "    # \"Human\",\n",
    "    \"RA (w/ prior)\",\n",
    "    \"RA (w/ y^H)\",\n",
    "    \"RA (w/ y^AI)\",\n",
    "    \"RA (w/ y^H + y^AI)\",\n",
    "    \"RA (w/ y^AI + expl)\",\n",
    "    \"RA (w/ y^H + y^AI + expl)\",\n",
    "    \"RA (w/ x)\",\n",
    "    # \"RA (w/ x+y^H)\"\n",
    "]\n",
    "\n",
    "values_for_plot = [\n",
    "    # (\"RA (w/ x+y^H)\", ra_x_yh),\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    # (\"RA (w/ y^H)\", ra_yh),\n",
    "    # (\"RA (w/ y^AI)\", ra_yai),\n",
    "    (\"RA (w/ y^AI + expl)\", ra_expl),\n",
    "    # (\"RA (w/ y^H + expl)\", ra_yh_expl),\n",
    "    # (\"Human\", ba),\n",
    "    (\"RA (w/ prior)\", ra_null),\n",
    "    (\"BA (w/ expl)\", ba_explanation),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation),\n",
    "    # (\"RA (w/ y^H + y^AI)\", ra_yai_yh),\n",
    "    # (\"RA (w/ y^H + y^AI + expl)\", ra_yh_expl_yai)\n",
    "]\n",
    "\n",
    "\n",
    "values_for_plot2 = [\n",
    "    # (\"RA (w/ x+y^H)\", ra_x_yh),\n",
    "    # (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^H)\", ra_yh),\n",
    "    # (\"RA (w/ y^AI)\", ra_yai),\n",
    "    # (\"RA (w/ y^AI + expl)\", ra_expl_yai),\n",
    "    # (\"RA (w/ y^H + expl)\", ra_yh_expl),\n",
    "    # (\"Human\", ba),\n",
    "    # (\"RA (w/ prior)\", ra_null),\n",
    "    # (\"RA (w/ y^H + y^AI)\", ra_yh),\n",
    "    (\"RA (w/ y^H + y^AI + expl)\", ra_yh_expl),\n",
    "    # (\"BA (w/ expl)\", ba_explanation),\n",
    "    # (\"BA (w/o expl)\", ba_no_explanation),\n",
    "]\n",
    "\n",
    "# Use ColorBrewer colors for consistency and accessibility.\n",
    "# From ColorBrewer: Set1 (colorblind safe/sequential)\n",
    "color_map = {\n",
    "    # \"RA (w/ x+y^H)\": \"#377eb8\",          # blue\n",
    "    \"RA (w/ x)\": \"#e41a1c\",              # red\n",
    "    \"RA (w/ x+y^H)\": \"#999999\",              # grey\n",
    "    \"RA (w/ y^H)\": \"#4daf4a\",            # green\n",
    "    \"RA (w/ y^H + y^AI)\": \"#377eb8\",     # blue\n",
    "    \"RA (w/ y^H + y^AI + expl)\": \"#ff7f00\",     # orange\n",
    "    # \"Human\": \"#984ea3\",            # purple\n",
    "    \"RA (w/ prior)\": \"#a65628\",      # brown\n",
    "    \"RA (w/ y^AI)\": \"#f781bf\",     # pink\n",
    "    \"RA (w/ y^AI + expl)\": \"#984ea3\",     # grey\n",
    "    \"BA (w/o expl)\": \"#999999\",              # grey\n",
    "    \"BA (w/ expl)\": \"#999900\"              # yellow\n",
    "}\n",
    "num_conditions = 5\n",
    "num_rows = len(values_for_plot)\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(5, 5))\n",
    "\n",
    "# axes is a 1D array of the two Axes objects\n",
    "ax1 = axes[0]\n",
    "ax2 = axes[1]\n",
    "\n",
    "# fig, ax = plt.subplots(figsize=(6,3))\n",
    "\n",
    "y_labels = []\n",
    "y_ticks = []\n",
    "y_offset = 0\n",
    "\n",
    "# Offset for points and errorbars within each group along y-axis\n",
    "group_width = 0.7\n",
    "single_offset = group_width / 2\n",
    "for y_tick in range(num_conditions):\n",
    "    ax1.axhline(y=y_tick, color='#101010', linestyle='-', linewidth=1, alpha=0.5)\n",
    "\n",
    "ax1.axvline(x=np.mean(ra_null[0]), color=color_map[\"RA (w/ prior)\"], linestyle='--', linewidth=1)\n",
    "# ax.axvline(x=np.mean(ra_yh[0]), color=color_map[\"RA (w/ y^H)\"], linestyle='--', linewidth=1)\n",
    "# ax.axvline(x=np.mean(ra_yai_yh[0]), color=color_map[\"RA (w/ y^H + y^AI)\"], linestyle='--', linewidth=1)\n",
    "ax1.axvline(x=np.mean(ra_x[0]), color=color_map[\"RA (w/ x)\"], linestyle='--', linewidth=1)\n",
    "# ax1.axvline(x=np.mean(ra_yai[0]), color=color_map[\"RA (w/ y^AI)\"], linestyle='--', linewidth=1)\n",
    "ax1.axvline(x=np.mean(ba_no_explanation[0]), color=color_map[\"BA (w/o expl)\"], linestyle='--', linewidth=1)\n",
    "\n",
    "for y_tick in range(num_conditions):\n",
    "    ax2.axhline(y=y_tick, color='#101010', linestyle='-', linewidth=1, alpha=0.5)\n",
    "\n",
    "# ax2.axvline(x=np.mean(ra_null[0]), color=color_map[\"RA (w/ prior)\"], linestyle='--', linewidth=1)\n",
    "ax2.axvline(x=np.mean(ra_yh[0]), color=color_map[\"RA (w/ y^H)\"], linestyle='--', linewidth=1)\n",
    "# ax2.axvline(x=np.mean(ra_yh_yai[0]), color=color_map[\"RA (w/ y^H + y^AI)\"], linestyle='--', linewidth=1)\n",
    "# ax2.axvline(x=np.mean(ra_x[0]), color=color_map[\"RA (w/ x)\"], linestyle='--', linewidth=1)\n",
    "ax2.axvline(x=np.mean(ra_x[0]), color=color_map[\"RA (w/ x)\"], linestyle='--', linewidth=1)\n",
    "# ax2.axvline(x=np.mean(ba_no_explanation[0]), color=color_map[\"BA (w/o expl)\"], linestyle='--', linewidth=1)\n",
    "\n",
    "for i, (name, data) in enumerate(values_for_plot):\n",
    "    # for each metric to plot\n",
    "    y = []\n",
    "    y_lowers = []\n",
    "    y_uppers = []\n",
    "    x = []\n",
    "    # is_ci = len(data[0]) == 3 if not name.startswith(\"RA (prior only)\") else False\n",
    "\n",
    "    for j in range(num_conditions):\n",
    "        if j >= len(data) or data[j] is None:\n",
    "            continue\n",
    "        mean, lower, upper = data[j]\n",
    "        y_val = j\n",
    "        x.append(mean)\n",
    "        y.append(y_val)\n",
    "        y_labels.append(f\"{ name }\")\n",
    "        y_ticks.append(y_val)\n",
    "        # Plot the normal distribution centered at mean with std estimated from CI width\n",
    "        \n",
    "\n",
    "        ci_width = upper - lower\n",
    "        if ci_width > 1e-3:\n",
    "            std = ci_width / (2 * 1.96)  # approximate std assuming 95% CI\n",
    "\n",
    "            x_range = np.linspace(max(0, mean - 4*std), min(1.0, mean + 4*std), 200)\n",
    "            y_norm = norm.pdf(x_range, loc=mean, scale=std)\n",
    "            y_norm = y_norm / y_norm.max() * (group_width / 2.5)  # scale height for visual clarity\n",
    "            ax1.fill_between(x_range, y_val, y_val + y_norm, alpha=0.7, color=color_map[name], label=None)\n",
    "        \n",
    "\n",
    "        ax1.errorbar([mean], [y_val], xerr=[[mean - lower], [upper - mean]], fmt='o', label=name, capsize=0, color=color_map[name], elinewidth=2, markersize=4)\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax1.set_yticks(range(num_conditions))\n",
    "\n",
    "\n",
    "\n",
    "for i, (name, data) in enumerate(values_for_plot2):\n",
    "    # for each metric to plot\n",
    "    y = []\n",
    "    y_lowers = []\n",
    "    y_uppers = []\n",
    "    x = []\n",
    "    # is_ci = len(data[0]) == 3 if not name.startswith(\"RA (prior only)\") else False\n",
    "\n",
    "    for j in range(num_conditions):\n",
    "        if j >= len(data) or data[j] is None:\n",
    "            continue\n",
    "        mean, lower, upper = data[j]\n",
    "        y_val = j\n",
    "        x.append(mean)\n",
    "        y.append(y_val)\n",
    "        y_labels.append(f\"{ name }\")\n",
    "        y_ticks.append(y_val)\n",
    "        # Plot the normal distribution centered at mean with std estimated from CI width\n",
    "        \n",
    "\n",
    "        ci_width = upper - lower\n",
    "        if ci_width > 1e-3:\n",
    "            std = ci_width / (2 * 1.96)  # approximate std assuming 95% CI\n",
    "\n",
    "            x_range = np.linspace(max(0, mean - 4*std), min(1.0, mean + 4*std), 200)\n",
    "            y_norm = norm.pdf(x_range, loc=mean, scale=std)\n",
    "            y_norm = y_norm / y_norm.max() * (group_width / 2.5)  # scale height for visual clarity\n",
    "            ax2.fill_between(x_range, y_val, y_val + y_norm, alpha=0.7, color=color_map[name], label=None)\n",
    "        \n",
    "        ax2.errorbar([mean], [y_val], xerr=[[mean - lower], [upper - mean]], fmt='o', label=name, capsize=0, color=color_map[name], elinewidth=2, markersize=4)\n",
    "\n",
    "# Set y-ticks and labels\n",
    "ax2.set_yticks(range(num_conditions))\n",
    "\n",
    "# Create a legend (avoid duplicate labels)\n",
    "handles, labels = ax1.get_legend_handles_labels()\n",
    "seen = set()\n",
    "unique = [(h,l) for h,l in zip(handles, labels) if not (l in seen or seen.add(l))]\n",
    "# Desired legend order\n",
    "\n",
    "# Reorder unique to match legend_order\n",
    "label_to_handle = {l: h for h, l in unique}\n",
    "ordered_handles_labels = [(label_to_handle[l], l) for l in legend_order if l in label_to_handle]\n",
    "ordered_handles, ordered_labels = zip(*ordered_handles_labels)\n",
    "\n",
    "# ax1.legend(\n",
    "#     ordered_handles,\n",
    "#     ordered_labels,\n",
    "#     loc=\"upper center\",\n",
    "#     bbox_to_anchor=(0.5, 1.1),\n",
    "#     ncol=6,\n",
    "#     frameon=False,\n",
    "#     borderaxespad=0.01,\n",
    "#     handletextpad=0.2,\n",
    "#     labelspacing=0.01\n",
    "# )\n",
    "\n",
    "\n",
    "# Create a legend (avoid duplicate labels)\n",
    "handles, labels = ax2.get_legend_handles_labels()\n",
    "seen = set()\n",
    "unique = [(h,l) for h,l in zip(handles, labels) if not (l in seen or seen.add(l))]\n",
    "# Desired legend order\n",
    "\n",
    "# Reorder unique to match legend_order\n",
    "label_to_handle = {l: h for h, l in unique}\n",
    "ordered_handles_labels = [(label_to_handle[l], l) for l in legend_order if l in label_to_handle]\n",
    "ordered_handles, ordered_labels = zip(*ordered_handles_labels)\n",
    "# ax2.legend(\n",
    "#     ordered_handles,\n",
    "#     ordered_labels,\n",
    "#     loc=\"upper center\",\n",
    "#     bbox_to_anchor=(0.5, 1.1),\n",
    "#     ncol=6,\n",
    "#     frameon=False,\n",
    "#     borderaxespad=0.01,\n",
    "#     handletextpad=0.2,\n",
    "#     labelspacing=0.01\n",
    "# )\n",
    "\n",
    "# Remove the top, left, and right border lines (spines) of the plot\n",
    "ax1.spines['top'].set_visible(False)\n",
    "ax1.spines['left'].set_visible(False)\n",
    "ax1.spines['right'].set_visible(False)\n",
    "ax2.spines['top'].set_visible(False)\n",
    "ax2.spines['left'].set_visible(False)\n",
    "ax2.spines['right'].set_visible(False)\n",
    "\n",
    "ax2.set_xlim(0.45, 1)\n",
    "\n",
    "ax2.set_xlabel(\"Accuracy\")\n",
    "ax2.set_yticklabels([\"No expl.\", \"Explain-top-1\", \"Explain-top-2\", \"Adaptive\", \"Expert\"])\n",
    "# ax.set_title(\"Comparison of Metrics Across Experiment Conditions (with 95% CI)\")\n",
    "\n",
    "ax1.sharex(ax2)\n",
    "ax1.set_yticklabels([\"No expl.\", \"Explain-top-1\", \"Explain-top-2\", \"Adaptive\", \"Expert\"])\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('bansal_beer_explanation_means.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4511fd29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple\n",
    "\n",
    "from sklearn.base import clone\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ATEResult:\n",
    "    ate: float\n",
    "    se: float\n",
    "    ci95: Tuple[float, float]\n",
    "    details: Dict[str, Any]\n",
    "\n",
    "\n",
    "def _clip_ps(ps: np.ndarray, eps: float = 1e-3) -> np.ndarray:\n",
    "    return np.clip(ps, eps, 1 - eps)\n",
    "\n",
    "\n",
    "def _robust_se(phi: np.ndarray) -> float:\n",
    "    n = phi.shape[0]\n",
    "    return float(np.sqrt(np.var(phi, ddof=1) / n))\n",
    "\n",
    "\n",
    "def _ci95(est: float, se: float) -> Tuple[float, float]:\n",
    "    z = 1.96\n",
    "    return (est - z * se, est + z * se)\n",
    "\n",
    "\n",
    "def aipw_ate_text_binary(\n",
    "    y: np.ndarray,\n",
    "    t: np.ndarray,\n",
    "    texts: List[str],\n",
    "    *,\n",
    "    outcome_model: Optional[Any] = None,\n",
    "    ps_model: Optional[Any] = None,\n",
    "    n_splits: int = 5,\n",
    "    clip_eps: float = 1e-3,\n",
    "    random_state: int = 0,\n",
    ") -> ATEResult:\n",
    "    \"\"\"\n",
    "    Doubly robust AIPW estimate of ATE for binary outcome Y and binary treatment T,\n",
    "    with text covariates X (raw texts). Uses cross-fitting by default.\n",
    "\n",
    "    Returns ATE on the probability scale: E[Y(1) - Y(0)].\n",
    "    \"\"\"\n",
    "    y = np.asarray(y).reshape(-1).astype(int)\n",
    "    t = np.asarray(t).reshape(-1).astype(int)\n",
    "\n",
    "    if len(texts) != len(y):\n",
    "        raise ValueError(\"texts must have the same length as y.\")\n",
    "    if not set(np.unique(y)).issubset({0, 1}):\n",
    "        raise ValueError(\"y must be binary in {0,1}.\")\n",
    "    if not set(np.unique(t)).issubset({0, 1}):\n",
    "        raise ValueError(\"t must be binary in {0,1}.\")\n",
    "\n",
    "    # Default: TF-IDF -> LogisticRegression for both nuisance models\n",
    "    if outcome_model is None:\n",
    "        outcome_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "    if ps_model is None:\n",
    "        ps_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "\n",
    "    n = len(y)\n",
    "    mu1_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=1, X)\n",
    "    mu0_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=0, X)\n",
    "    e_hat = np.zeros(n, dtype=float)    # P(T=1 | X)\n",
    "\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "    idx = np.arange(n)\n",
    "\n",
    "    for tr, te in kf.split(idx):\n",
    "        texts_tr = [texts[i] for i in tr]\n",
    "        texts_te = [texts[i] for i in te]\n",
    "        y_tr, t_tr = y[tr], t[tr]\n",
    "\n",
    "        # 1) propensity model e(x)\n",
    "        ps = clone(ps_model).fit(texts_tr, t_tr)\n",
    "        e_hat[te] = ps.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "        # 2) outcome models mu1(x), mu0(x) fitted on their respective groups\n",
    "        tr1 = tr[t_tr == 1]\n",
    "        tr0 = tr[t_tr == 0]\n",
    "        if len(tr1) == 0 or len(tr0) == 0:\n",
    "            raise ValueError(\"A fold has no treated or no control samples. Reduce n_splits or check data balance.\")\n",
    "\n",
    "        om1 = clone(outcome_model).fit([texts[i] for i in tr1], y[tr1])\n",
    "        om0 = clone(outcome_model).fit([texts[i] for i in tr0], y[tr0])\n",
    "\n",
    "        mu1_hat[te] = om1.predict_proba(texts_te)[:, 1]\n",
    "        mu0_hat[te] = om0.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "    e_hat = _clip_ps(e_hat, clip_eps)\n",
    "\n",
    "    # AIPW influence function\n",
    "    phi = (mu1_hat - mu0_hat) + (t * (y - mu1_hat) / e_hat) - ((1 - t) * (y - mu0_hat) / (1 - e_hat))\n",
    "    ate = float(np.mean(phi))\n",
    "    se = _robust_se(phi - ate)\n",
    "    return ATEResult(ate=ate, se=se, ci95=_ci95(ate, se),\n",
    "                     details={\"e_hat\": e_hat, \"mu1_hat\": mu1_hat, \"mu0_hat\": mu0_hat, \"phi\": phi})\n",
    "\n",
    "conditions = ['Conf.+Double', 'Conf.+Single', 'Conf.+Adaptive',\n",
    "       'Conf.+Adaptive (Expert)']\n",
    "\n",
    "ate_result = {}\n",
    "\n",
    "for condition in conditions:\n",
    "    df_condition = df.loc[(df[\"condition\"]==condition) | (df[\"condition\"]==\"Conf.\"), :]\n",
    "    x = df_condition[\"text\"].tolist()\n",
    "    y = (df_condition[\"y\"] == df_condition[\"choice\"]).tolist()\n",
    "    t = df_condition[\"condition\"].apply(lambda x: 1 if x == condition else 0).tolist()\n",
    "    res = aipw_ate_text_binary(y, t, x, n_splits=5)\n",
    "    ate_result[condition] = res\n",
    "ate_result"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
