{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71519e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from itertools import combinations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1343cf89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the parent directory to the import path (presumably where decision_infovalue lives).\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import decision_infovalue\n",
    "import importlib\n",
    "# Reload so edits made to the module are picked up without restarting the kernel.\n",
    "importlib.reload(decision_infovalue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c519cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the clustered evaluation table and show it as the cell output.\n",
    "df = pd.read_csv(filepath_or_buffer=\"toxic_training_evaluation_clustered.csv\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c7ad28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the table to one row per id, keeping the first value seen for each retained column.\n",
    "comment_df = (\n",
    "    df.groupby(\"id\")\n",
    "    .agg({\n",
    "        column: \"first\"\n",
    "        for column in (\n",
    "            \"is_toxic_training_data_without_sae\",\n",
    "            \"toxic_model\",\n",
    "            \"generated_text_cluster_id\",\n",
    "            \"input_text_cluster_id\",\n",
    "            \"features_with_interpretations_cluster_id\",\n",
    "        )\n",
    "    })\n",
    "    .reset_index()\n",
    ")\n",
    "comment_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e863545e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the quantities that the later cells summarise and plot.\n",
    "#   result[0]   -> computed on the full table (the \"no explanation\" row of the figure)\n",
    "#   result[k>0] -> computed on the rows sharing one value of number_of_features,\n",
    "#                  in ascending order of that value\n",
    "# NOTE(review): complement_info_value(..., ret_orig_value=True) is assumed to return\n",
    "# per-row scores that can be averaged/bootstrapped downstream; confirm in decision_infovalue.\n",
    "result = []\n",
    "\n",
    "info_model = decision_infovalue.DecisionInfoModel(df, \"toxic_model\",\n",
    "                                    signals=[\"input_text_cluster_id\", \"is_toxic_training_data_without_sae\", \"features_with_interpretations_cluster_id\",\n",
    "                                    \"generated_text_cluster_id\"],\n",
    "                                    scoring_rule=\"threshold_accuracy\",\n",
    "                                    binning_method=None)\n",
    "\n",
    "result.append({\n",
    "    \"ra_null\": info_model.complement_info_value([], None, ret_orig_value=True),\n",
    "    \"ra_x\": info_model.complement_info_value([\"input_text_cluster_id\"], None, ret_orig_value=True),\n",
    "    # \"ra_yai\": info_model.complement_info_value([\"generated_text_cluster_id\"], None, ret_orig_value=True),\n",
    "    \"ra_yh\": info_model.complement_info_value([\"is_toxic_training_data_without_sae\"], None, ret_orig_value=True),\n",
    "    # \"ra_yh_yai\": info_model.complement_info_value([\"is_toxic_training_data_without_sae\", \"generated_text_cluster_id\"], None, ret_orig_value=True),\n",
    "    # \"ra_x_yh\": info_model.complement_info_value([\"input_text_cluster_id\", \"is_toxic_training_data_without_sae\"], None, ret_orig_value=True),\n",
    "})\n",
    "# Per-id agreement between the label given without SAE features and the model label.\n",
    "result[0][\"ba_no_explanation\"] = (comment_df[\"is_toxic_training_data_without_sae\"] == comment_df[\"toxic_model\"]).to_numpy()\n",
    "\n",
    "# Sort the feature counts: the position of each entry in `result` is used later as the\n",
    "# y-axis row of the figure, and a bare set() gives no ordering guarantee. Missing values\n",
    "# are dropped because `== NaN` would select an empty subset.\n",
    "unique_number_of_features = sorted(set(df[\"number_of_features\"].dropna()))\n",
    "for i in unique_number_of_features:\n",
    "    number_df = df[df[\"number_of_features\"] == i]\n",
    "    info_model = decision_infovalue.DecisionInfoModel(number_df, \"toxic_model\",\n",
    "                                    signals=[\"features_with_interpretations_cluster_id\",\n",
    "                                    \"is_toxic_training_data_without_sae\",\n",
    "                                    \"generated_text_cluster_id\"],\n",
    "                                    scoring_rule=\"threshold_accuracy\",\n",
    "                                    binning_method=None)\n",
    "    result.append({\n",
    "        \"ra_expl\": info_model.complement_info_value([\"features_with_interpretations_cluster_id\"], None, ret_orig_value=True),\n",
    "        \"ra_yh_expl\": info_model.complement_info_value([\"is_toxic_training_data_without_sae\", \"features_with_interpretations_cluster_id\"], None, ret_orig_value=True)\n",
    "    })\n",
    "    # Row-level agreement between the label given with SAE features and the model label.\n",
    "    result[-1][\"ba_explanation\"] = (number_df[\"is_toxic_training_data_with_sae\"] == number_df[\"toxic_model\"]).to_numpy()\n",
    "\n",
    "\n",
    "result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bafd77",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.stats import bootstrap\n",
    "\n",
    "# Extract data for plotting and compute 95% CI with bootstrap\n",
    "\n",
    "def mean_and_95ci(arr, n_resamples=1000, random_state=0):\n",
    "    \"\"\"Return the sample mean of `arr` with a bootstrap 95% confidence interval.\n",
    "\n",
    "    Args:\n",
    "        arr: 1-D sequence of numbers or booleans (values are cast to float).\n",
    "        n_resamples: number of bootstrap resamples.\n",
    "        random_state: seed (or generator) forwarded to scipy.stats.bootstrap so that\n",
    "            re-running the notebook reproduces the same interval.\n",
    "\n",
    "    Returns:\n",
    "        (mean, lower, upper) as floats; the bounds come from the \"basic\" bootstrap method.\n",
    "    \"\"\"\n",
    "    arr = np.asarray(arr, dtype=float)\n",
    "    res = bootstrap(\n",
    "        (arr,),\n",
    "        np.mean,\n",
    "        confidence_level=0.95,\n",
    "        n_resamples=n_resamples,\n",
    "        method=\"basic\",\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    interval = res.confidence_interval\n",
    "    return float(np.mean(arr)), float(interval.low), float(interval.high)\n",
    "\n",
    "# Summarise every series as (mean, ci_low, ci_high) tuples.\n",
    "# Series taken from result[0] are one-element lists; series taken from the later entries\n",
    "# get a None placeholder in slot 0 so that list position matches the index into `result`.\n",
    "ra_x = [mean_and_95ci(result[0][\"ra_x\"])]\n",
    "ra_yh = [mean_and_95ci(result[0][\"ra_yh\"])]\n",
    "ra_yh_expl = [None] + [mean_and_95ci(entry[\"ra_yh_expl\"]) for entry in result[1:]]\n",
    "ra_expl = [None] + [mean_and_95ci(entry[\"ra_expl\"]) for entry in result[1:]]\n",
    "ba_explanation = [None] + [mean_and_95ci(entry[\"ba_explanation\"]) for entry in result[1:]]\n",
    "ba_no_explanation = [mean_and_95ci(result[0][\"ba_no_explanation\"])] + [None] * (len(result) - 1)\n",
    "\n",
    "ra_null = [mean_and_95ci(result[0][\"ra_null\"])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddef1e0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm\n",
    "import numpy as np\n",
    "\n",
    "# Prepare data for plotting.\n",
    "# Every series is a list indexed by figure row (row 0 = \"No expl.\", row k = k-th entry of\n",
    "# `result`); each slot is a (mean, ci_low, ci_high) tuple, or None when the series has no\n",
    "# value for that row.\n",
    "\n",
    "# Series drawn in the top panel (ax1).\n",
    "values_for_plot = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^AI + expl)\", ra_expl),\n",
    "    (\"RA (w/ prior)\", ra_null),\n",
    "    (\"BA (w/ expl)\", ba_explanation),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation),\n",
    "]\n",
    "\n",
    "# Series drawn in the bottom panel (ax2).\n",
    "values_for_plot2 = [\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"RA (w/ y^H)\", ra_yh),\n",
    "    (\"RA (w/ y^H + y^AI + expl)\", ra_yh_expl),\n",
    "]\n",
    "\n",
    "# Series whose row-0 mean is drawn as a dashed vertical reference line in each panel.\n",
    "baselines_for_plot = [\n",
    "    (\"RA (w/ prior)\", ra_null),\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation),\n",
    "]\n",
    "baselines_for_plot2 = [\n",
    "    (\"RA (w/ y^H)\", ra_yh),\n",
    "    (\"RA (w/ x)\", ra_x),\n",
    "    (\"BA (w/o expl)\", ba_no_explanation),\n",
    "]\n",
    "\n",
    "# Colours taken mostly from the ColorBrewer Set1 (qualitative) palette.\n",
    "color_map = {\n",
    "    \"RA (w/ x)\": \"#e41a1c\",                  # red\n",
    "    \"RA (w/ x+y^H)\": \"#999999\",              # grey\n",
    "    \"RA (w/ y^H)\": \"#4daf4a\",                # green\n",
    "    \"RA (w/ y^H + y^AI)\": \"#377eb8\",         # blue\n",
    "    \"RA (w/ y^H + y^AI + expl)\": \"#ff7f00\",  # orange\n",
    "    \"RA (w/ prior)\": \"#a65628\",              # brown\n",
    "    \"RA (w/ y^AI)\": \"#f781bf\",               # pink\n",
    "    \"RA (w/ y^AI + expl)\": \"#984ea3\",        # purple\n",
    "    \"BA (w/o expl)\": \"#999999\",              # grey\n",
    "    \"BA (w/ expl)\": \"#999900\"                # olive\n",
    "}\n",
    "\n",
    "# One figure row per condition; rows beyond this list are not drawn.\n",
    "condition_labels = [\"No expl.\", \"SAE-top-1\", \"SAE-top-2\", \"SAE-top-3\", \"SAE-top-4\", \"SAE-top-5\"]\n",
    "num_conditions = len(condition_labels)\n",
    "\n",
    "# Vertical room available to each row; the density bumps are scaled relative to it.\n",
    "group_width = 0.7\n",
    "\n",
    "\n",
    "def _plot_condition_panel(ax, series, baselines):\n",
    "    \"\"\"Draw one panel: row guide lines, dashed baseline means and per-row estimates.\n",
    "\n",
    "    Args:\n",
    "        ax: matplotlib Axes to draw on.\n",
    "        series: list of (name, data) pairs; data[row] is (mean, ci_low, ci_high) or None.\n",
    "        baselines: list of (name, data) pairs whose data[0] mean becomes a vertical line.\n",
    "    \"\"\"\n",
    "    for row in range(num_conditions):\n",
    "        ax.axhline(y=row, color='#101010', linestyle='-', linewidth=1, alpha=0.5)\n",
    "\n",
    "    for name, data in baselines:\n",
    "        # data[0] is a (mean, ci_low, ci_high) tuple: place the line at the mean itself,\n",
    "        # not at the average of the three numbers.\n",
    "        ax.axvline(x=data[0][0], color=color_map[name], linestyle='--', linewidth=1)\n",
    "\n",
    "    for name, data in series:\n",
    "        for row, estimate in enumerate(data[:num_conditions]):\n",
    "            if estimate is None:\n",
    "                continue\n",
    "            mean, lower, upper = estimate\n",
    "\n",
    "            # Sketch a normal curve whose std is backed out of the 95% CI width; skipped\n",
    "            # when the interval is (almost) degenerate.\n",
    "            ci_width = upper - lower\n",
    "            if ci_width > 1e-3:\n",
    "                std = ci_width / (2 * 1.96)  # approximate std assuming 95% CI\n",
    "                x_range = np.linspace(max(0, mean - 4*std), min(1.0, mean + 4*std), 200)\n",
    "                density = norm.pdf(x_range, loc=mean, scale=std)\n",
    "                density = density / density.max() * (group_width / 2.5)  # scale height for visual clarity\n",
    "                ax.fill_between(x_range, row, row + density, alpha=0.7, color=color_map[name], label=None)\n",
    "\n",
    "            # Artists carry label=name so a legend can be added with ax.legend() if wanted.\n",
    "            ax.errorbar([mean], [row], xerr=[[mean - lower], [upper - mean]], fmt='o', label=name, capsize=0, color=color_map[name], elinewidth=2, ms=3)\n",
    "\n",
    "    ax.set_yticks(range(num_conditions))\n",
    "    ax.set_yticklabels(condition_labels)\n",
    "\n",
    "    # Remove the top, left, and right border lines (spines) of the plot\n",
    "    for side in ('top', 'left', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "\n",
    "\n",
    "# sharex=True already links the x-axes of the two panels.\n",
    "fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6, 4))\n",
    "ax1, ax2 = axes\n",
    "\n",
    "_plot_condition_panel(ax1, values_for_plot, baselines_for_plot)\n",
    "_plot_condition_panel(ax2, values_for_plot2, baselines_for_plot2)\n",
    "\n",
    "ax2.set_xlim(0.495, 1)\n",
    "ax2.set_xlabel(\"Accuracy\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('alignment_audit_explanation_means.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b78db72",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Optional, Tuple\n",
    "\n",
    "from sklearn.base import clone\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ATEResult:\n",
    "    \"\"\"Container for the output of aipw_ate_text_binary.\"\"\"\n",
    "    # Point estimate: mean of the per-sample AIPW scores.\n",
    "    ate: float\n",
    "    # Standard error of the estimate, computed from the spread of the per-sample scores.\n",
    "    se: float\n",
    "    # (lower, upper) bounds of the normal-approximation interval ate -/+ 1.96 * se.\n",
    "    ci95: Tuple[float, float]\n",
    "    # Per-sample arrays kept for diagnostics: \"e_hat\", \"mu1_hat\", \"mu0_hat\", \"phi\".\n",
    "    details: Dict[str, Any]\n",
    "\n",
    "\n",
    "def _clip_ps(ps: np.ndarray, eps: float = 1e-3) -> np.ndarray:\n",
    "    \"\"\"Bound propensity scores to the closed interval [eps, 1 - eps].\"\"\"\n",
    "    lower_bound, upper_bound = eps, 1 - eps\n",
    "    return np.clip(ps, lower_bound, upper_bound)\n",
    "\n",
    "\n",
    "def _robust_se(phi: np.ndarray) -> float:\n",
    "    \"\"\"Standard error of the mean of `phi`, using the unbiased (ddof=1) sample variance.\"\"\"\n",
    "    sample_count = phi.shape[0]\n",
    "    variance_of_mean = np.var(phi, ddof=1) / sample_count\n",
    "    return float(np.sqrt(variance_of_mean))\n",
    "\n",
    "\n",
    "def _ci95(est: float, se: float) -> Tuple[float, float]:\n",
    "    \"\"\"Normal-approximation 95% interval: est minus/plus 1.96 standard errors.\"\"\"\n",
    "    half_width = 1.96 * se\n",
    "    return (est - half_width, est + half_width)\n",
    "\n",
    "\n",
    "def aipw_ate_text_binary(\n",
    "    y: np.ndarray,\n",
    "    t: np.ndarray,\n",
    "    texts: List[str],\n",
    "    *,\n",
    "    outcome_model: Optional[Any] = None,\n",
    "    ps_model: Optional[Any] = None,\n",
    "    n_splits: int = 5,\n",
    "    clip_eps: float = 1e-3,\n",
    "    random_state: int = 0,\n",
    ") -> ATEResult:\n",
    "    \"\"\"\n",
    "    Doubly robust AIPW estimate of ATE for binary outcome Y and binary treatment T,\n",
    "    with text covariates X (raw texts). Nuisance models are cross-fitted: every sample is\n",
    "    scored by models trained on the other folds.\n",
    "\n",
    "    Args:\n",
    "        y: binary outcomes in {0, 1}, one per sample.\n",
    "        t: binary treatment indicators in {0, 1}, one per sample.\n",
    "        texts: raw text covariate of each sample, same length as `y`.\n",
    "        outcome_model: unfitted estimator with fit/predict_proba taking raw texts; cloned\n",
    "            and fitted separately on treated and control rows. Defaults to\n",
    "            TF-IDF -> LogisticRegression.\n",
    "        ps_model: unfitted estimator with fit/predict_proba used for P(T=1 | X).\n",
    "            Defaults to TF-IDF -> LogisticRegression.\n",
    "        n_splits: number of cross-fitting folds.\n",
    "        clip_eps: propensities are clipped to [clip_eps, 1 - clip_eps].\n",
    "        random_state: seed for the fold shuffling.\n",
    "\n",
    "    Returns:\n",
    "        ATEResult with the ATE on the probability scale, E[Y(1) - Y(0)], its standard\n",
    "        error, a 95% interval and the per-sample nuisance predictions / scores.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: on length mismatches, non-binary `y` or `t`, or a training fold that\n",
    "            lacks treated or control samples.\n",
    "    \"\"\"\n",
    "    y = np.asarray(y).reshape(-1).astype(int)\n",
    "    t = np.asarray(t).reshape(-1).astype(int)\n",
    "\n",
    "    if len(texts) != len(y):\n",
    "        raise ValueError(\"texts must have the same length as y.\")\n",
    "    # Without this check a mismatch only surfaces later as an IndexError or a broadcasting\n",
    "    # error, possibly after all the models have been fitted.\n",
    "    if len(t) != len(y):\n",
    "        raise ValueError(\"t must have the same length as y.\")\n",
    "    if not set(np.unique(y)).issubset({0, 1}):\n",
    "        raise ValueError(\"y must be binary in {0,1}.\")\n",
    "    if not set(np.unique(t)).issubset({0, 1}):\n",
    "        raise ValueError(\"t must be binary in {0,1}.\")\n",
    "\n",
    "    # Default: TF-IDF -> LogisticRegression for both nuisance models\n",
    "    if outcome_model is None:\n",
    "        outcome_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "    if ps_model is None:\n",
    "        ps_model = Pipeline([\n",
    "            (\"tfidf\", TfidfVectorizer(ngram_range=(1, 2), min_df=2, max_df=0.95)),\n",
    "            (\"clf\", LogisticRegression(max_iter=5000))\n",
    "        ])\n",
    "\n",
    "    n = len(y)\n",
    "    mu1_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=1, X)\n",
    "    mu0_hat = np.zeros(n, dtype=float)  # P(Y=1 | T=0, X)\n",
    "    e_hat = np.zeros(n, dtype=float)    # P(T=1 | X)\n",
    "\n",
    "    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)\n",
    "    idx = np.arange(n)\n",
    "\n",
    "    for tr, te in kf.split(idx):\n",
    "        texts_tr = [texts[i] for i in tr]\n",
    "        texts_te = [texts[i] for i in te]\n",
    "        t_tr = t[tr]\n",
    "\n",
    "        # 1) propensity model e(x), fitted on the training folds only\n",
    "        ps = clone(ps_model).fit(texts_tr, t_tr)\n",
    "        e_hat[te] = ps.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "        # 2) outcome models mu1(x), mu0(x) fitted on their respective groups\n",
    "        tr1 = tr[t_tr == 1]\n",
    "        tr0 = tr[t_tr == 0]\n",
    "        if len(tr1) == 0 or len(tr0) == 0:\n",
    "            raise ValueError(\"A fold has no treated or no control samples. Reduce n_splits or check data balance.\")\n",
    "\n",
    "        om1 = clone(outcome_model).fit([texts[i] for i in tr1], y[tr1])\n",
    "        om0 = clone(outcome_model).fit([texts[i] for i in tr0], y[tr0])\n",
    "\n",
    "        # Both arms are predicted for every held-out sample, whatever its own treatment.\n",
    "        mu1_hat[te] = om1.predict_proba(texts_te)[:, 1]\n",
    "        mu0_hat[te] = om0.predict_proba(texts_te)[:, 1]\n",
    "\n",
    "    # Keep the inverse-propensity weights bounded.\n",
    "    e_hat = _clip_ps(e_hat, clip_eps)\n",
    "\n",
    "    # AIPW score: outcome-model contrast plus inverse-propensity-weighted residual corrections\n",
    "    phi = (mu1_hat - mu0_hat) + (t * (y - mu1_hat) / e_hat) - ((1 - t) * (y - mu0_hat) / (1 - e_hat))\n",
    "    ate = float(np.mean(phi))\n",
    "    se = _robust_se(phi - ate)\n",
    "    return ATEResult(ate=ate, se=se, ci95=_ci95(ate, se),\n",
    "                     details={\"e_hat\": e_hat, \"mu1_hat\": mu1_hat, \"mu0_hat\": mu0_hat, \"phi\": phi})\n",
    "\n",
    "\n",
    "\n",
    "# Estimate, for each feature count, the effect of showing SAE features on whether the\n",
    "# label matches toxic_model. Every input text is stacked twice: once as a control row\n",
    "# (t=0, label given without SAE) and once as a treated row (t=1, label given with SAE).\n",
    "ate_result = {}\n",
    "\n",
    "for i in unique_number_of_features:\n",
    "    number_df = df[df[\"number_of_features\"] == i]\n",
    "    n_rows = len(number_df)\n",
    "    model_label = number_df[\"toxic_model\"]\n",
    "    match_without_sae = (model_label == number_df[\"is_toxic_training_data_without_sae\"]).tolist()\n",
    "    match_with_sae = (model_label == number_df[\"is_toxic_training_data_with_sae\"]).tolist()\n",
    "\n",
    "    x = number_df[\"input_text\"].tolist() * 2\n",
    "    y = match_without_sae + match_with_sae\n",
    "    t = [0] * n_rows + [1] * n_rows\n",
    "    ate_result[i] = aipw_ate_text_binary(y, t, x, n_splits=5)\n",
    "ate_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b284b8e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the evaluation table again for reference.\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
