{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "14d2f594-03a8-4878-9454-42cc5ea3a81f",
   "metadata": {},
   "source": [
    "### Import necessary packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f57ecd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pyreadstat\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "import xgboost as xgb\n",
    "import scipy\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import os, sys\n",
    "from scipy.stats import norm, bernoulli\n",
    "from ppi_py.datasets import load_dataset\n",
    "import matplotlib.patheffects as pe\n",
    "from utils import make_width_coverage_plot, make_budget_plot\n",
    "import warnings; warnings.simplefilter('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cf90ae6",
   "metadata": {},
   "source": [
    "### Import the AlphaFold data set\n",
    "\n",
    "Load the data. The data set contains true indicators of disorder (```Y```), predicted indicators of disorder (```Yhat```), and indicators of a PTM (```phosphorylated```, ```ubiquitinated```, or ```acetylated```). Predictions of disorder are made based on AlphaFold predictions of structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6da3138",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the AlphaFold data set and pull out the outcome, the prediction and the PTM indicator\n",
    "dataset_folder = \"./alphafold/data/\"\n",
    "data = load_dataset(dataset_folder, \"alphafold\")\n",
    "Y_total, Yhat_total = data[\"Y\"], data[\"Yhat\"]\n",
    "Z = data[\"phosphorylated\"].astype(bool)  # boolean mask used to split the data into subgroups"
   ]
  },
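  {
   "cell_type": "markdown",
   "id": "8969f9db",
   "metadata": {},
   "source": [
    "### Problem setup\n",
    "\n",
    "Compute the ground-truth value of the odds ratio. Define the effort-control quantities (minimum effort level, correction probability, effort cost, per-instance error probabilities and payments). Specify the range of budgets, the error level $\\alpha$, and the number of trials."
   ]
  },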
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b3c8f29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partition outcomes and predictions by phosphorylation status (Z is a boolean mask)\n",
    "Y0, Yhat0 = Y_total[~Z], Yhat_total[~Z]\n",
    "Y1, Yhat1 = Y_total[Z], Yhat_total[Z]\n",
    "n0, n1 = len(Y0), len(Y1)\n",
    "n = len(Y_total)\n",
    "\n",
    "# Ground-truth odds ratio between the two subgroups, computed from all true labels\n",
    "mu0, mu1 = Y0.mean(), Y1.mean()\n",
    "odds0 = mu0 / (1 - mu0)\n",
    "odds1 = mu1 / (1 - mu1)\n",
    "odds_ratio = odds1 / odds0\n",
    "\n",
    "# Effort control parameters\n",
    "e_min = 0.8  # Target minimum effort level\n",
    "\n",
    "\n",
    "def q_e(e):\n",
    "    \"\"\"Correction probability as a function of effort: q(e) = e.\"\"\"\n",
    "    return e\n",
    "\n",
    "\n",
    "def c_e(e):\n",
    "    \"\"\"Effort cost as a function of effort: c(e) = e^2 / 2.\"\"\"\n",
    "    return 0.5 * e**2\n",
    "\n",
    "\n",
    "# Per-instance error probability: one minus the weight the prediction puts on the true label,\n",
    "# i.e. p_i = 1 - (Yhat * Y + (1 - Yhat) * (1 - Y)), clipped to [0.01, 0.99] to avoid extreme values\n",
    "Y_binary = Y_total.astype(int)  # integer copy of the labels used in the formula below\n",
    "p_correct = Yhat_total * Y_binary + (1 - Yhat_total) * (1 - Y_binary)\n",
    "p_i_total = np.clip(1.0 - p_correct, 0.01, 0.99)\n",
    "\n",
    "# Per-instance payment needed to sustain e_min effort: w_i = e_min / p_i\n",
    "# (based on linear utility assumption)\n",
    "w_i_total = e_min / p_i_total\n",
    "\n",
    "# Subgroup-specific difficulties and payments\n",
    "p_i0, p_i1 = p_i_total[~Z], p_i_total[Z]\n",
    "w_i0, w_i1 = w_i_total[~Z], w_i_total[Z]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5419048-6d75-46fd-9166-89a2c811cf6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 20 evenly spaced budget levels between 0.01 and 0.2 (inclusive)\n",
    "budgets = np.linspace(start=0.01, stop=0.2, num=20)\n",
    "# Error level alpha: intervals are built with the normal quantile at 1 - alpha / 2\n",
    "alpha = 0.1\n",
    "# Number of independent simulation trials run for each budget\n",
    "num_trials = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51177a3c",
   "metadata": {},
   "source": [
    "### Main experiment\n",
    "\n",
    "Forms the dataframe `df` with the experiment results. The columns in the dataframe are:\n",
    "\n",
    "- `lb` - interval lower bound\n",
    "- `ub` - interval upper bound\n",
    "- `interval width` - equal to `ub - lb`\n",
    "- `coverage` - 0/1 indicator of whether or not the interval covered the target\n",
    "- `estimator` - one of `classical`, `uniform`, `active`, or `incentive_robust`\n",
    "- `budget` - budget size, stored as `int(budget * n)` where `n` is the total number of data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb6cb6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Reproducibility: every Bernoulli draw below comes from this seeded generator,\n",
    "# so re-running the cell (or Restart & Run All) regenerates the same intervals.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "# Ensure data is numpy array to avoid pandas indexing issues in loops\n",
    "Y0, Y1 = np.array(Y0), np.array(Y1)\n",
    "p_i0, p_i1 = np.array(p_i0), np.array(p_i1)\n",
    "w_i0, w_i1 = np.array(w_i0), np.array(w_i1)\n",
    "Yhat0, Yhat1 = np.array(Yhat0), np.array(Yhat1)\n",
    "\n",
    "columns = [\"lb\", \"ub\", \"interval width\", \"coverage\", \"estimator\", \"budget\"]\n",
    "\n",
    "# Robust Parameters\n",
    "rho_sentinel = 0.1\n",
    "k_sentinel = 1.0\n",
    "b_optimal = e_min / rho_sentinel\n",
    "w_0_sentinel = (b_optimal * rho_sentinel) ** 2\n",
    "cost_per_sample_robust = w_0_sentinel + rho_sentinel * k_sentinel\n",
    "e_robust = np.sqrt(w_0_sentinel)\n",
    "\n",
    "# Numerical guards\n",
    "DIVISOR_FLOOR = 1e-8  # lower bound applied to sampling probabilities / mean uncertainty before dividing\n",
    "# Lower bound on (2 * accuracy - 1) when inverting label noise for the classical estimator.\n",
    "# The two subgroups previously used different floors (0.01 and 0.1); they now share one value.\n",
    "DEBIAS_DENOM_FLOOR = 0.1\n",
    "DEBIAS_CLIP = 10  # debiased labels are clipped to [-DEBIAS_CLIP, DEBIAS_CLIP]\n",
    "\n",
    "# Mixing parameter between uncertainty-proportional and uniform sampling\n",
    "tau = 0.5\n",
    "\n",
    "\n",
    "def simulate_labels(Y, p_err, sampled, q, rng):\n",
    "    \"\"\"Simulate noisy annotator labels for the sampled units.\n",
    "\n",
    "    A sampled unit i keeps its true label Y[i] with probability\n",
    "    1 - p_err[i] * (1 - q) and otherwise receives the flipped label 1 - Y[i].\n",
    "    Unsampled units hold the placeholder 0.0; the estimators multiply them by a\n",
    "    zero sampling indicator, so the placeholder never contributes.\n",
    "\n",
    "    Parameters: Y (true labels), p_err (per-unit error probabilities),\n",
    "    sampled (0/1 indicators), q (correction probability), rng (numpy Generator).\n",
    "    Returns a float array with the same length as Y.\n",
    "    \"\"\"\n",
    "    labels = np.zeros_like(Y, dtype=float)\n",
    "    idx = np.flatnonzero(sampled)\n",
    "    if idx.size > 0:\n",
    "        acc = 1 - p_err[idx] * (1 - q)\n",
    "        is_correct = bernoulli.rvs(acc, random_state=rng)\n",
    "        labels[idx] = np.where(is_correct, Y[idx], 1 - Y[idx])\n",
    "    return labels\n",
    "\n",
    "\n",
    "def debias_labels(labels, p_err, sampled, q):\n",
    "    \"\"\"Invert the label noise on the sampled units.\n",
    "\n",
    "    With accuracy acc = 1 - p_err * (1 - q), a noisy label has expectation\n",
    "    (2 * acc - 1) * Y + (1 - acc); solving for Y gives the debiased value.\n",
    "    The divisor is floored and the result clipped to keep values bounded.\n",
    "    Unsampled units are returned as 0.\n",
    "    \"\"\"\n",
    "    debiased = np.zeros(len(labels))\n",
    "    idx = np.flatnonzero(sampled)\n",
    "    if idx.size > 0:\n",
    "        acc = 1 - p_err[idx] * (1 - q)\n",
    "        denom = np.maximum(2 * acc - 1, DEBIAS_DENOM_FLOOR)\n",
    "        debiased[idx] = np.clip((labels[idx] - (1 - acc)) / denom, -DEBIAS_CLIP, DEBIAS_CLIP)\n",
    "    return debiased\n",
    "\n",
    "\n",
    "def mean_and_variance(values):\n",
    "    \"\"\"Return the sample mean of `values` and the estimated variance of that mean.\"\"\"\n",
    "    return np.mean(values), np.var(values) / len(values)\n",
    "\n",
    "\n",
    "def ppi_estimate(labels, Yhat, sampled, weight):\n",
    "    \"\"\"Prediction plus inverse-weighted correction: Yhat + (labels - Yhat) * sampled / weight.\n",
    "\n",
    "    Returns the mean of the corrected values and the estimated variance of that mean.\n",
    "    \"\"\"\n",
    "    return mean_and_variance(Yhat + (labels - Yhat) * sampled / weight)\n",
    "\n",
    "\n",
    "def odds_ratio_row(est0, est1, estimator, budget_size, target, alpha):\n",
    "    \"\"\"Build one result row for the odds ratio of group 1 versus group 0.\n",
    "\n",
    "    est0 and est1 are (mean, variance-of-mean) pairs. The interval is formed on the\n",
    "    log-odds scale with the delta method and exponentiated back.\n",
    "    Returns [lb, ub, width, coverage, estimator, budget_size].\n",
    "    \"\"\"\n",
    "    (mu0_hat, var_mu0_hat), (mu1_hat, var_mu1_hat) = est0, est1\n",
    "    pointest_log = np.log(mu1_hat / (1 - mu1_hat)) - np.log(mu0_hat / (1 - mu0_hat))\n",
    "    var0_term = var_mu0_hat / ((mu0_hat * (1 - mu0_hat)) ** 2)\n",
    "    var1_term = var_mu1_hat / ((mu1_hat * (1 - mu1_hat)) ** 2)\n",
    "    width_log = norm.ppf(1 - alpha / 2) * np.sqrt(var0_term + var1_term)\n",
    "    l, u = np.exp(pointest_log - width_log), np.exp(pointest_log + width_log)\n",
    "    cov = (target >= l) * (target <= u)\n",
    "    return [l, u, u - l, cov, estimator, budget_size]\n",
    "\n",
    "\n",
    "# ---- Quantities that do not depend on the budget (computed once) ----\n",
    "\n",
    "# Baseline setup: Standard Active Learning (Mixed with Uniform)\n",
    "uncertainty0 = np.minimum(Yhat0, 1 - Yhat0)\n",
    "uncertainty1 = np.minimum(Yhat1, 1 - Yhat1)\n",
    "mean_uncertainty0 = np.maximum(np.mean(uncertainty0), DIVISOR_FLOOR)\n",
    "mean_uncertainty1 = np.maximum(np.mean(uncertainty1), DIVISOR_FLOOR)\n",
    "\n",
    "# Robust setup: Variance-based sampling weights\n",
    "ell_hat0 = Yhat0 * (1 - Yhat0)\n",
    "ell_hat1 = Yhat1 * (1 - Yhat1)\n",
    "pi_robust_unscaled0 = np.sqrt(ell_hat0)\n",
    "pi_robust_unscaled1 = np.sqrt(ell_hat1)\n",
    "expected_cost_robust_base = (\n",
    "    np.sum(pi_robust_unscaled0 * cost_per_sample_robust)\n",
    "    + np.sum(pi_robust_unscaled1 * cost_per_sample_robust)\n",
    ")\n",
    "\n",
    "# Total payment if every unit were labelled (denominator of the uniform sampling rate)\n",
    "total_payment = np.sum(w_i0) + np.sum(w_i1)\n",
    "\n",
    "# Correction probabilities at the two effort levels\n",
    "q_val = q_e(e_min)\n",
    "q_rob = q_e(e_robust)\n",
    "\n",
    "results = []  # one row (list) per estimator, trial and budget\n",
    "\n",
    "for budget in tqdm(budgets, desc=\"Budgets\"):\n",
    "    budget_size = int(budget * n)\n",
    "\n",
    "    # Active sampling probabilities: uncertainty-proportional, mixed with uniform\n",
    "    eta0 = budget / mean_uncertainty0\n",
    "    probs0_active_base = eta0 * uncertainty0\n",
    "    eta1 = budget / mean_uncertainty1\n",
    "    probs1_active_base = eta1 * uncertainty1\n",
    "    probs0_active = np.clip((1 - tau) * probs0_active_base + tau * budget, 0, 1)\n",
    "    probs1_active = np.clip((1 - tau) * probs1_active_base + tau * budget, 0, 1)\n",
    "\n",
    "    # Calculate target spending for fairness (based on active sampling cost)\n",
    "    target_spending = np.sum(probs0_active * w_i0) + np.sum(probs1_active * w_i1)\n",
    "\n",
    "    # Robust sampling probabilities, scaled to the same target spending\n",
    "    if expected_cost_robust_base > 0:\n",
    "        scaling_factor_robust = target_spending / expected_cost_robust_base\n",
    "    else:\n",
    "        scaling_factor_robust = 1.0\n",
    "    probs_robust0 = np.clip(pi_robust_unscaled0 * scaling_factor_robust, 0, 1)\n",
    "    probs_robust1 = np.clip(pi_robust_unscaled1 * scaling_factor_robust, 0, 1)\n",
    "\n",
    "    # Uniform sampling probability, scaled to the same target spending\n",
    "    prob_unif_val = np.clip(target_spending / total_payment, 0, 1)\n",
    "    prob_unif0 = np.ones(n0) * prob_unif_val\n",
    "    prob_unif1 = np.ones(n1) * prob_unif_val\n",
    "\n",
    "    # Inverse weights used by the estimators (fixed within a budget)\n",
    "    weight0_active = np.maximum(probs0_active, DIVISOR_FLOOR) * q_val\n",
    "    weight1_active = np.maximum(probs1_active, DIVISOR_FLOOR) * q_val\n",
    "    weight0_unif = np.maximum(prob_unif0, DIVISOR_FLOOR) * q_val\n",
    "    weight1_unif = np.maximum(prob_unif1, DIVISOR_FLOOR) * q_val\n",
    "    weight0_rob = np.maximum(probs_robust0, DIVISOR_FLOOR) * (1 - rho_sentinel) * q_rob\n",
    "    weight1_rob = np.maximum(probs_robust1, DIVISOR_FLOOR) * (1 - rho_sentinel) * q_rob\n",
    "\n",
    "    for _ in range(num_trials):\n",
    "        # 1. Active Estimator\n",
    "        xi0 = bernoulli.rvs(probs0_active, random_state=rng)\n",
    "        xi1 = bernoulli.rvs(probs1_active, random_state=rng)\n",
    "        Y_label0 = simulate_labels(Y0, p_i0, xi0, q_val, rng)\n",
    "        Y_label1 = simulate_labels(Y1, p_i1, xi1, q_val, rng)\n",
    "        results.append(odds_ratio_row(\n",
    "            ppi_estimate(Y_label0, Yhat0, xi0, weight0_active),\n",
    "            ppi_estimate(Y_label1, Yhat1, xi1, weight1_active),\n",
    "            \"active\", budget_size, odds_ratio, alpha,\n",
    "        ))\n",
    "\n",
    "        # 2. Uniform Estimator\n",
    "        xi0_unif = bernoulli.rvs(prob_unif0, random_state=rng)\n",
    "        xi1_unif = bernoulli.rvs(prob_unif1, random_state=rng)\n",
    "        Y_label0_unif = simulate_labels(Y0, p_i0, xi0_unif, q_val, rng)\n",
    "        Y_label1_unif = simulate_labels(Y1, p_i1, xi1_unif, q_val, rng)\n",
    "        results.append(odds_ratio_row(\n",
    "            ppi_estimate(Y_label0_unif, Yhat0, xi0_unif, weight0_unif),\n",
    "            ppi_estimate(Y_label1_unif, Yhat1, xi1_unif, weight1_unif),\n",
    "            \"uniform\", budget_size, odds_ratio, alpha,\n",
    "        ))\n",
    "\n",
    "        # 3. Classical Estimator (Horvitz-Thompson without control variate)\n",
    "        # Reuses the uniform sample and its noisy labels, debiased for label noise\n",
    "        Y0_class_debiased = debias_labels(Y_label0_unif, p_i0, xi0_unif, q_val)\n",
    "        Y1_class_debiased = debias_labels(Y_label1_unif, p_i1, xi1_unif, q_val)\n",
    "        results.append(odds_ratio_row(\n",
    "            mean_and_variance(Y0_class_debiased * xi0_unif / np.maximum(prob_unif0, DIVISOR_FLOOR)),\n",
    "            mean_and_variance(Y1_class_debiased * xi1_unif / np.maximum(prob_unif1, DIVISOR_FLOOR)),\n",
    "            \"classical\", budget_size, odds_ratio, alpha,\n",
    "        ))\n",
    "\n",
    "        # 4. Incentive Robust\n",
    "        xi0_rob = bernoulli.rvs(probs_robust0, random_state=rng)\n",
    "        xi1_rob = bernoulli.rvs(probs_robust1, random_state=rng)\n",
    "        zeta0 = 1 - bernoulli.rvs(rho_sentinel, size=n0, random_state=rng)\n",
    "        zeta1 = 1 - bernoulli.rvs(rho_sentinel, size=n1, random_state=rng)\n",
    "        # A unit contributes only when it is sampled (xi = 1) and zeta = 1\n",
    "        used0 = xi0_rob * zeta0\n",
    "        used1 = xi1_rob * zeta1\n",
    "        Y_label0_rob = simulate_labels(Y0, p_i0, used0, q_rob, rng)\n",
    "        Y_label1_rob = simulate_labels(Y1, p_i1, used1, q_rob, rng)\n",
    "        results.append(odds_ratio_row(\n",
    "            ppi_estimate(Y_label0_rob, Yhat0, used0, weight0_rob),\n",
    "            ppi_estimate(Y_label1_rob, Yhat1, used1, weight1_rob),\n",
    "            \"incentive_robust\", budget_size, odds_ratio, alpha,\n",
    "        ))\n",
    "\n",
    "# Build the results frame once from all collected rows\n",
    "df = pd.DataFrame(results, columns=columns)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c37901e6-9751-4302-aa99-47351481b880",
   "metadata": {},
   "source": [
    "### Plot coverage and interval width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d30aa8d-2b4c-40a6-9150-aa96c79fe15b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width/coverage figure for the odds ratio, using the experiment results in `df`\n",
    "# NOTE(review): the third argument looks like the output file name - confirm in utils\n",
    "make_width_coverage_plot(\n",
    "    df,\n",
    "    \"odds ratio\",\n",
    "    \"widths_and_coverage_alphafold.pdf\",\n",
    "    odds_ratio,\n",
    "    n_example_ind=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc3568a-8b42-4af2-9be3-3f094a443ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Budget figure built from the experiment results in `df`\n",
    "# NOTE(review): the last argument looks like the output file name - confirm in utils\n",
    "make_budget_plot(\n",
    "    df,\n",
    "    \"AlphaFold\",\n",
    "    \"budget_alphafold.pdf\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "active-inference-experiments",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
