{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ae3dcaf-1ace-4127-a0b7-0eef36ccc671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn\n",
    "from scipy import stats\n",
    "\n",
    "import numpy as np\n",
    "import river\n",
    "from river import neural_net\n",
    "from river import preprocessing as pp\n",
    "from river import optim, metrics\n",
    "from river.neural_net import activations as act\n",
    "\n",
    "\n",
    "# Load the raw data; the path is relative to the notebook's working directory.\n",
    "full_dat = pd.read_csv(\"behaghel.csv\")\n",
    "# Columns kept for the analysis: 'sw' (used later as resampling weights), the A_*\n",
    "# indicators ('A_public' is the one used as treatment), the outcome 'Y', and then\n",
    "# the names that also appear in Xbin and Xnum further down in this cell.\n",
    "relevant_columns = [\n",
    "                    'sw',\n",
    "                    'A_public',\n",
    "                    'A_private',\n",
    "                    'A_standard',\n",
    "                    'Y',\n",
    "                    'College_education',\n",
    "                    'nivetude2',\n",
    "                    'Vocational',\n",
    "                    'High_school_dropout',\n",
    "                    'Manager',\n",
    "                    'Technician',\n",
    "                    'Skilled_clerical_worker',\n",
    "                    'Unskilled_clerical_worker',\n",
    "                    'Skilled_blue_colar',\n",
    "                    'Unskilled_blue_colar',\n",
    "                    'Woman',\n",
    "                    'Married',\n",
    "                    'French',\n",
    "                    'African',\n",
    "                    'Other_Nationality',\n",
    "                    'Paris_region',\n",
    "                    'North',\n",
    "                    'Other_regions',\n",
    "                    'Employment_component_level_1',\n",
    "                    'Employment_component_level_2',\n",
    "                    'Employment_component_missing',\n",
    "                    'Economic_Layoff',\n",
    "                    'Personnal_Layoff',\n",
    "                    'End_of_Fixed_Term_Contract',\n",
    "                    'End_of_Temporary_Work',\n",
    "                    'Other_reasons_of_unemployment',\n",
    "                    'Statistical_risk_level_2',\n",
    "                    'Statistical_risk_level_3',\n",
    "                    'Other_Statistical_risk',\n",
    "                    'Search_for_a_full_time_position',\n",
    "                    'Sensitive_suburban_area',\n",
    "                    'Insertion',\n",
    "                    'Interim',\n",
    "                    'Conseil',\n",
    "                    'age',\n",
    "                    'Number_of_children',\n",
    "                    'exper',\n",
    "                    'salaire.num',\n",
    "                    'mois_saisie_occ',\n",
    "                    'ndem'\n",
    "                    ]\n",
    "# Keep only the selected columns, in the order listed above.\n",
    "full_dat = full_dat[relevant_columns]\n",
    "#print((full_dat.head()))\n",
    "\n",
    "# Split the covariates into two name lists. Feature matrices built later in the\n",
    "# notebook stack them in the order Xbin + Xnum.\n",
    "\n",
    "# numerical features (cast to float further down in this cell)\n",
    "Xnum = [\n",
    "  'age',\n",
    "  'Number_of_children',\n",
    "  'exper', # years experience on the job\n",
    "  'salaire.num', # salary target\n",
    "  'mois_saisie_occ', # when assigned\n",
    "  'ndem' # Num. unemployment spell\n",
    "]\n",
    "\n",
    "\n",
    "# categorical features (cast to pandas 'category' further down in this cell);\n",
    "# the binning helper later reads 'Woman' and 'College_education' as 0/1 flags\n",
    "Xbin = [\n",
    "  'College_education',\n",
    "  'nivetude2',\n",
    "  'Vocational',\n",
    "  'High_school_dropout',\n",
    "  'Manager',\n",
    "  'Technician',\n",
    "  'Skilled_clerical_worker',\n",
    "  'Unskilled_clerical_worker',\n",
    "  'Skilled_blue_colar',\n",
    "  'Unskilled_blue_colar',\n",
    "  'Woman',\n",
    "  'Married',\n",
    "  'French',\n",
    "  'African',\n",
    "  'Other_Nationality',\n",
    "  'Paris_region',\n",
    "  'North',\n",
    "  'Other_regions',\n",
    "  'Employment_component_level_1',\n",
    "  'Employment_component_level_2',\n",
    "  'Employment_component_missing',\n",
    "  'Economic_Layoff',\n",
    "  'Personnal_Layoff',\n",
    "  'End_of_Fixed_Term_Contract',\n",
    "  'End_of_Temporary_Work',\n",
    "  'Other_reasons_of_unemployment',\n",
    "  'Statistical_risk_level_2',\n",
    "  'Statistical_risk_level_3',\n",
    "  'Other_Statistical_risk',\n",
    "  'Search_for_a_full_time_position',\n",
    "  'Sensitive_suburban_area',\n",
    "  'Insertion',\n",
    "  'Interim',\n",
    "  'Conseil'\n",
    "]\n",
    "\n",
    "\n",
    "# Weight, treatment-arm and outcome columns; all of them are stored as floats.\n",
    "other_variables = [\"sw\", \"A_public\", \"A_private\", \"A_standard\", \"Y\"]\n",
    "\n",
    "# One target dtype per column: float for the numeric features and for\n",
    "# other_variables, pandas 'category' for every feature listed in Xbin.\n",
    "dtype_by_column = {name: float for name in Xnum + other_variables}\n",
    "dtype_by_column.update({name: \"category\" for name in Xbin})\n",
    "full_dat = full_dat.astype(dtype_by_column)\n",
    "\n",
    "#print(full_dat.dtypes)\n",
    "\n",
    "# Positions (within full_dat.columns) of the categorical features.\n",
    "categorical_indices = [\n",
    "    position\n",
    "    for position, name in enumerate(full_dat.columns)\n",
    "    if name in Xbin\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8753b698-c274-42ee-8d85-c42a186e9a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# construct bootstrap sampling function\n",
    "### here, we return pseudooutcomes + covariate matrix \n",
    "\n",
    "\n",
    "def row_to_dict(x_row: np.ndarray):\n",
    "    return {f\"x{i}\": float(x_row[i]) for i in range(x_row.shape[0])}\n",
    "\n",
    "def clip02(v: float, eps: float = 0.01) -> float:\n",
    "    return float(min(1.0 - eps, max(eps, v)))\n",
    "\n",
    "\n",
    "def bootstrap_sample(n,\n",
    "                     df,\n",
    "                     random_seed = 42,\n",
    "                     k = 4,\n",
    "                     noise_std=1/4,\n",
    "                     bound = 100,\n",
    "                     n_cols = None,\n",
    "                     Xbin = None,\n",
    "                     Xnum = None,\n",
    "                     categorical_indices = None):\n",
    "    \"\"\"\n",
    "    Draw a weighted bootstrap sample and compute doubly-robust pseudo-outcomes.\n",
    "\n",
    "    Rows are resampled with replacement using df[\"sw\"] as weights and truncated\n",
    "    Gaussian noise is added to the outcome. An online S-learner (one MLP that\n",
    "    receives the treatment as its first feature) predicts both potential\n",
    "    outcomes for each row before learning from that row, and the predictions\n",
    "    are plugged into the AIPW score with a constant propensity equal to the\n",
    "    treated share of the sample.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n : int\n",
    "        Number of rows to draw.\n",
    "    df : pd.DataFrame\n",
    "        Must contain 'sw', 'A_public', 'Y' and every name in Xbin and Xnum.\n",
    "    random_seed : int\n",
    "        Seed for the resampling, the noise and the optional column subset.\n",
    "    k : int\n",
    "        Unused; kept so existing calls keep working.\n",
    "    noise_std : float\n",
    "        Standard deviation of the noise added to Y and to the pseudo-outcome.\n",
    "    bound : float\n",
    "        The noise is truncated to [-bound, bound].\n",
    "    n_cols : int | None\n",
    "        If given, keep n_cols feature columns chosen at random without\n",
    "        replacement (the returned X then no longer follows Xbin + Xnum order).\n",
    "    Xbin, Xnum : sequence of str | None\n",
    "        Feature names; X is built in the order Xbin + Xnum. None means empty.\n",
    "    categorical_indices : list | None\n",
    "        Unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X : np.ndarray\n",
    "        Feature matrix of the sample.\n",
    "    X_stream : np.ndarray\n",
    "        X with the treatment prepended as column 0.\n",
    "    Y : pd.Series\n",
    "        Noisy outcomes.\n",
    "    Y_eif : pd.Series\n",
    "        Pseudo-outcomes (AIPW score plus the noise vector).\n",
    "    \"\"\"\n",
    "    feature_names = ([] if Xbin is None else list(Xbin)) + ([] if Xnum is None else list(Xnum))\n",
    "\n",
    "    # resample based on sample weights\n",
    "    bs_sample = df.sample(n = n, replace = True, weights = df[\"sw\"], random_state = random_seed)\n",
    "\n",
    "    A = np.array(bs_sample[\"A_public\"])\n",
    "\n",
    "    ## add noise to Y to preserve difficulty\n",
    "    # A private RandomState yields the same draws as np.random.seed(random_seed)\n",
    "    # followed by np.random.normal, without re-seeding the global generator.\n",
    "    draws = np.random.RandomState(random_seed).normal(size=len(A), loc=0, scale=noise_std)\n",
    "    noise = np.maximum(-bound, np.minimum(draws, bound))\n",
    "    Y = np.array(bs_sample[\"Y\"]) + noise\n",
    "\n",
    "    X = np.array(bs_sample[feature_names])\n",
    "\n",
    "    if n_cols is not None:\n",
    "        # Keep n_cols distinct columns drawn from ALL feature columns (the previous\n",
    "        # code ignored n_cols and only permuted the first 15 columns).\n",
    "        rng = np.random.default_rng(seed=random_seed)\n",
    "        cols = rng.choice(X.shape[1], size=n_cols, replace=False)\n",
    "        X = X[:, cols]\n",
    "\n",
    "    X_stream = np.column_stack((A, X))\n",
    "\n",
    "    ## fit neural network sequentially to get predictions for g_0, g_1\n",
    "    mlp = pp.StandardScaler() | neural_net.MLPRegressor(\n",
    "        hidden_dims=(64, 64, 64, 64),\n",
    "        activations=(act.ReLU(), act.ReLU(), act.ReLU(), act.ReLU(), act.Sigmoid()),\n",
    "        optimizer=optim.Adam(1e-3),\n",
    "        seed=0)\n",
    "    mu_1_est = np.full(n, np.nan)\n",
    "    mu_0_est = np.full(n, np.nan)\n",
    "\n",
    "    for t in range(n):\n",
    "        ## predict based on one unified neural network (S-Learner)\n",
    "        true_x = row_to_dict(X_stream[t])\n",
    "\n",
    "        ## counterfactual copies of the row with the treatment entry forced to 0 / 1\n",
    "        x_0 = X_stream[t].copy()\n",
    "        x_0[0] = 0\n",
    "        x_1 = X_stream[t].copy()\n",
    "        x_1[0] = 1\n",
    "\n",
    "        yhat_0 = mlp.predict_one(row_to_dict(x_0))\n",
    "        yhat_1 = mlp.predict_one(row_to_dict(x_1))\n",
    "        # NOTE(review): the prediction on the observed row is never used; the call is\n",
    "        # kept only so the sequence of model calls matches earlier runs. Confirm that\n",
    "        # predict_one has no side effects in the installed river version, then drop it.\n",
    "        mlp.predict_one(true_x)\n",
    "\n",
    "        if yhat_0 is not None:\n",
    "            mu_0_est[t] = yhat_0\n",
    "\n",
    "        if yhat_1 is not None:\n",
    "            mu_1_est[t] = yhat_1\n",
    "\n",
    "        # learn only after predicting, so every estimate is out-of-sample for row t\n",
    "        mlp.learn_one(true_x, float(Y[t]))\n",
    "\n",
    "    d = {\"mu_1_est\": mu_1_est,\n",
    "        \"mu_0_est\": mu_0_est,\n",
    "        \"outcome\": Y,\n",
    "        \"treatment\": A,\n",
    "        \"propensity\": np.mean(A) # assuming complete randomization with fixed probability of treatment identical for everyone\n",
    "       }\n",
    "    d = pd.DataFrame(data=d)\n",
    "\n",
    "    ###  construct EIF for each datapoint\n",
    "    pi_1 = d[\"propensity\"]\n",
    "    pi_0 = 1-pi_1\n",
    "    g_1 = d[\"mu_1_est\"]\n",
    "    g_0 = d[\"mu_0_est\"]\n",
    "    Y = d[\"outcome\"]\n",
    "    A = d[\"treatment\"]\n",
    "\n",
    "    # NOTE(review): the original re-seeded with random_seed before drawing this noise,\n",
    "    # so the pseudo-outcome receives the SAME noise vector that was added to Y. That\n",
    "    # is preserved here; confirm whether an independent draw was intended.\n",
    "    Y_eif = g_1 + (A==1)*(Y-g_1)/pi_1 - (g_0 + (A==0)*(Y-g_0)/pi_0) + noise ## eif\n",
    "\n",
    "    return X, X_stream, Y, Y_eif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "83067b74-ff8b-450a-b11e-aab09b3c510f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## assume complete randomization\n",
    "# Draw one weighted bootstrap sample of 10,000 rows with seed 0. bootstrap_sample\n",
    "# uses the treated share of the sample as a constant propensity for every row.\n",
    "X,X_w_A, Y, Y_eif = bootstrap_sample(n=10000, df=full_dat, Xbin=Xbin, Xnum = Xnum, random_seed = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "09453cfb-5aed-421b-9fb6-841e7ff7dfc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 40)\n",
      "(10000, 41)\n",
      "0.0923\n"
     ]
    }
   ],
   "source": [
    "# Sanity checks: X_w_A is X with the treatment prepended as column 0, so it has one\n",
    "# more column than X and the mean of its first column is the treated share.\n",
    "print(X.shape)\n",
    "print(X_w_A.shape)\n",
    "print(np.mean(X_w_A[:,0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88e8af7d-c1df-4ba5-9b96-006e56c5a965",
   "metadata": {},
   "source": [
    "## F Test for Treatment Effect and Interaction Terms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "c92d1f8b-2dae-4970-818b-f00ebaf064e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def _hatvalues(X, XtX_inv=None):\n",
    "    # h_i = x_i^T (X^T X)^{-1} x_i\n",
    "    if XtX_inv is None:\n",
    "        XtX_inv = np.linalg.inv(X.T @ X)\n",
    "    return np.sum((X @ XtX_inv) * X, axis=1), XtX_inv\n",
    "\n",
    "def _hc_weights(e, h, kind=\"HC0\", k=None):\n",
    "    n = e.shape[0]\n",
    "    if kind == \"HC0\":\n",
    "        return e**2\n",
    "    if kind == \"HC1\":\n",
    "        if k is None:\n",
    "            raise ValueError(\"k must be provided for HC1\")\n",
    "        return e**2 * (n / (n - k))\n",
    "    if kind == \"HC2\":\n",
    "        return e**2 / np.clip(1.0 - h, 1e-12, None)\n",
    "    if kind == \"HC3\":\n",
    "        return e**2 / np.clip((1.0 - h) ** 2, 1e-24, None)\n",
    "    raise ValueError(f\"Unsupported HC type: {kind}\")\n",
    "\n",
    "def _avlm_log_G_f(F_value, d, nu, n, g):\n",
    "    \"\"\"\n",
    "    AVLM log_G_f expects an F-statistic (NOT a Wald chi-square).\n",
    "    \"\"\"\n",
    "    r = g / (g + n)\n",
    "\n",
    "    F_value = np.asarray(F_value, dtype=float)\n",
    "    F_value = np.where(np.isfinite(F_value), F_value, 0.0)\n",
    "    F_value = np.maximum(F_value, 0.0)\n",
    "\n",
    "    a = (d / nu) * F_value\n",
    "    x1 = np.maximum(1.0 + a, 1e-300)\n",
    "    x2 = np.maximum(1.0 + r * a, 1e-300)\n",
    "\n",
    "    return 0.5 * d * np.log(r) + 0.5 * (nu + d) * (np.log(x1) - np.log(x2))\n",
    "\n",
    "def _avlm_p_G_f_from_F(F_value, d, nu, n, g):\n",
    "    \"\"\"Anytime-valid p-value for an F-statistic: exp(-log G), capped at 1.\"\"\"\n",
    "    inv_G = float(np.exp(-_avlm_log_G_f(F_value, d, nu, n, g)))\n",
    "    return inv_G if inv_G < 1.0 else 1.0\n",
    "\n",
    "def _normalize_cols_selector(test_cols, k):\n",
    "    \"\"\"\n",
    "    Returns a boolean mask of length k for which coefficients are being tested.\n",
    "    Accepts:\n",
    "      - None => all columns\n",
    "      - list/array of ints\n",
    "      - boolean mask of length k\n",
    "    \"\"\"\n",
    "    if test_cols is None:\n",
    "        return np.ones(k, dtype=bool)\n",
    "\n",
    "    test_cols = np.asarray(test_cols)\n",
    "    if test_cols.dtype == bool:\n",
    "        if test_cols.shape[0] != k:\n",
    "            raise ValueError(\"Boolean test_cols mask must have length k\")\n",
    "        return test_cols.copy()\n",
    "\n",
    "    mask = np.zeros(k, dtype=bool)\n",
    "    mask[test_cols.astype(int)] = True\n",
    "    return mask\n",
    "\n",
    "def avlm_global_F_pvalue_HC(\n",
    "    X, y, g=1.0, hc_type=\"HC0\", ridge=0.0, has_intercept=False, test_cols=None\n",
    "):\n",
    "    \"\"\"\n",
    "    Anytime-valid p-value for a joint Wald/F test on OLS coefficients, using a\n",
    "    heteroskedasticity-robust (HC0-HC3) sandwich covariance.\n",
    "\n",
    "    NOTE: No centering is done here; X is used exactly as given.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like, shape (n, k)\n",
    "        Design matrix.\n",
    "    y : array-like, shape (n,)\n",
    "        Response.\n",
    "    g : float\n",
    "        Mixture parameter forwarded to the AVLM p-value.\n",
    "    hc_type : str\n",
    "        One of \"HC0\", \"HC1\", \"HC2\", \"HC3\".\n",
    "    ridge : float\n",
    "        If > 0, ridge * I is added to X^T X and to the robust 'meat' matrix to\n",
    "        help with invertibility.\n",
    "    has_intercept : bool\n",
    "        If True, column 0 is assumed to be the intercept and is never tested.\n",
    "    test_cols : None | index list | boolean mask\n",
    "        Coefficients under test (None = all).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (p_anytime, F_value); p_anytime is NaN when n - k <= 0.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no column is left to test.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, float)\n",
    "    y = np.asarray(y, float)\n",
    "    n, k = X.shape\n",
    "    use_ridge = bool(ridge and ridge > 0)\n",
    "\n",
    "    # OLS\n",
    "    XtX = X.T @ X\n",
    "    ## to help with invertability issues\n",
    "    if use_ridge:\n",
    "        XtX = XtX + ridge * np.eye(k)\n",
    "\n",
    "    XtX_inv = np.linalg.inv(XtX)\n",
    "    beta = XtX_inv @ (X.T @ y)\n",
    "    e = y - X @ beta\n",
    "\n",
    "    # hatvalues and robust weights\n",
    "    h, _ = _hatvalues(X, XtX_inv=XtX_inv)\n",
    "    w = _hc_weights(e, h, kind=hc_type, k=k)\n",
    "\n",
    "    # meat = X^T diag(w) X  (implemented as X^T (X * w))\n",
    "    meat = X.T @ (X * w[:, None])\n",
    "    if use_ridge:\n",
    "        meat = meat + ridge * np.eye(k)\n",
    "\n",
    "    # Sandwich covariance of beta: (X^T X)^{-1} meat (X^T X)^{-1}\n",
    "    robust_cov = XtX_inv @ meat @ XtX_inv\n",
    "\n",
    "    # Choose which coefficients to test\n",
    "    test_mask = _normalize_cols_selector(test_cols, k)\n",
    "\n",
    "    if has_intercept:\n",
    "        # assumes intercept is column 0; do not test it by default\n",
    "        test_mask = test_mask & (np.arange(k) != 0)\n",
    "\n",
    "    idx = np.flatnonzero(test_mask)\n",
    "    d = int(idx.size)\n",
    "    if d <= 0:\n",
    "        raise ValueError(\"No columns selected for testing (d=0).\")\n",
    "\n",
    "    beta_use = beta[idx]\n",
    "    # The Wald statistic for a subset needs the inverse of the covariance of the\n",
    "    # tested coefficients. A sub-block of the joint precision matrix (what this\n",
    "    # function used before) is not that inverse: it treats the untested\n",
    "    # coefficients as known and can only make the statistic larger. The two\n",
    "    # agree when every column is tested.\n",
    "    cov_use = robust_cov[np.ix_(idx, idx)]\n",
    "\n",
    "    # Wald statistic (chi-square-like)\n",
    "    wald = float(beta_use @ np.linalg.solve(cov_use, beta_use))\n",
    "\n",
    "    # Convert to an F-statistic on (d, nu) df\n",
    "    F_value = wald / d\n",
    "\n",
    "    # Denominator df\n",
    "    nu = int(n - k)\n",
    "    if nu <= 0:\n",
    "        return np.nan, F_value\n",
    "\n",
    "    p_anytime = _avlm_p_G_f_from_F(F_value, d=d, nu=nu, n=n, g=g)\n",
    "    return p_anytime, F_value\n",
    "\n",
    "def stopping_times_F_testing_avlm(\n",
    "    X, Y, null, alpha=0.1, g=1.0, t_min=200, ridge_XtX=0.0, hc_type=\"HC0\",\n",
    "    test_cols=None, has_intercept=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Anytime-valid stopping time using AVLM-style global F p-values (robust HC0/1/2/3).\n",
    "\n",
    "    NOTE: No centering is done here. You should pass in the design matrix you want\n",
    "    (e.g., [A, X, A*X]) already constructed.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, float)\n",
    "    Y = np.asarray(Y, float)\n",
    "\n",
    "    y = (Y - null)\n",
    "    n = y.shape[0]\n",
    "\n",
    "    # Ensure we can form nu = t - k > 0\n",
    "    k = X.shape[1]\n",
    "    t_min_eff = max(int(t_min), k + 2)\n",
    "\n",
    "    p_t = np.zeros(n-t_min_eff+1)\n",
    "    F_t = np.zeros(n-t_min_eff+1)\n",
    "    \n",
    "\n",
    "    for t in range(t_min_eff, n + 1):\n",
    "        p_t[t-t_min_eff], F_t[t-t_min_eff] = avlm_global_F_pvalue_HC(\n",
    "            X[:t, :], y[:t],\n",
    "            g=g,\n",
    "            hc_type=hc_type,\n",
    "            ridge=ridge_XtX,\n",
    "            has_intercept=has_intercept,\n",
    "            test_cols=test_cols\n",
    "        )\n",
    "        if np.isfinite(p_t[t-t_min_eff]) and p_t[t-t_min_eff] <= alpha:\n",
    "            return {\"reject_time\": t, \"p_value\": p_t, \"F_stat\": F_t}\n",
    "\n",
    "    return {\"reject_time\": np.nan, \"p_value\": p_t, \"F_stat\": F_t}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ef7579d1-85b2-4916-bec2-58af3f4566eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh bootstrap sample (seed 580) used by the F-test cells below.\n",
    "X,X_w_A, Y, Y_eif = bootstrap_sample(n=10000, df=full_dat, Xbin=Xbin, Xnum = Xnum, random_seed = 580)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4fd4e226-45bd-43fb-9c48-ae7baccb9b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Treatment indicator: first column of X_w_A, reshaped to a column vector so it\n",
    "# broadcasts against X in the interaction term below.\n",
    "A = X_w_A[:,0]\n",
    "A =  A.reshape(-1, 1)\n",
    "# Design matrix [A, X, A*X, 1]: with the 40 feature columns printed above this is\n",
    "# 1 + 40 + 40 + 1 = 82 columns, and the intercept is the LAST column.\n",
    "Z = np.hstack([A, X, A * X, np.ones(X.shape[0]).reshape(-1,1)])                 # (n, 82)\n",
    "\n",
    "# Coefficients under test: the main effect of A (column 0) and the A*X interactions.\n",
    "test_cols = [0] + list(range(1 + X.shape[1], 1 + 2*X.shape[1]))  # [0] + 41..80\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c9c45c8a-32d8-487b-85ad-59004892c33e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reject_time': 1317,\n",
       " 'p_value': array([1., 1., 1., ..., 0., 0., 0.]),\n",
       " 'F_stat': array([1.08739183, 1.09092228, 1.09118155, ..., 0.        , 0.        ,\n",
       "        0.        ])}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sequential test of 'no treatment effect and no interactions' on the raw outcome Y.\n",
    "# has_intercept stays False because that flag only masks column 0, whereas the\n",
    "# intercept of Z is its last column (which test_cols already leaves out).\n",
    "out = stopping_times_F_testing_avlm(\n",
    "    X=Z,\n",
    "    Y=Y,\n",
    "    null=0,\n",
    "    alpha=0.1,\n",
    "    g=1501,\n",
    "    t_min=1000,\n",
    "    hc_type=\"HC0\",\n",
    "    test_cols=test_cols,\n",
    "    ridge_XtX = 1e-5,\n",
    "    has_intercept=False\n",
    ")\n",
    "out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aee9ea63-e81d-453c-83f0-5860a4020d1a",
   "metadata": {},
   "source": [
    "## Interpretable Bins for Testing Binned Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "eafee9e3-eca3-453f-b534-00d2885f5214",
   "metadata": {},
   "outputs": [],
   "source": [
    "### get bins in our data for testing procedure\n",
    "def bin_index_woman_age_college(\n",
    "    X: np.ndarray,\n",
    "    Xbin: list,\n",
    "    Xnum: list,\n",
    "    n_age_bins: int = 6,\n",
    "    age_bins: np.ndarray | None = None,\n",
    "    age_strategy: str = \"quantile\",  # \"quantile\" or \"uniform\"\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Construct bin indices based on (Woman, AgeBin, College_education).\n",
    "\n",
    "    Assumes X columns are ordered as: Xbin + Xnum.\n",
    "\n",
    "    Bin encoding:\n",
    "        idx = woman + 2 * (age_bin + n_age_bins * college)\n",
    "\n",
    "    where:\n",
    "        woman  in {0,1}\n",
    "        college in {0,1}\n",
    "        age_bin in {0,...,n_age_bins-1}\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Design matrix with columns ordered as Xbin + Xnum.\n",
    "    Xbin, Xnum : list\n",
    "        Feature name lists used to build X.\n",
    "    n_age_bins : int\n",
    "        Number of age bins (>= 4) if age_bins not provided.\n",
    "    age_bins : np.ndarray | None\n",
    "        Optional explicit bin edges (monotone). If provided, n_age_bins ignored.\n",
    "        Example: np.array([18,25,35,45,55,65,100]) -> 6 bins.\n",
    "    age_strategy : str\n",
    "        If age_bins is None, choose how to create bins:\n",
    "        - \"quantile\": equal-count bins (robust to skew)\n",
    "        - \"uniform\": equal-width bins\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Vector of integer bin indices, shape (n_samples,).\n",
    "    \"\"\"\n",
    "    if X.ndim != 2:\n",
    "        raise ValueError(\"X must be a 2D array.\")\n",
    "    if age_bins is None and n_age_bins < 1:\n",
    "        raise ValueError(\"n_age_bins must be positive.\")\n",
    "\n",
    "    # Column positions (since X was built from Xbin + Xnum)\n",
    "    colnames = list(Xbin) + list(Xnum)\n",
    "    try:\n",
    "        i_woman = colnames.index(\"Woman\")\n",
    "        i_college = colnames.index(\"College_education\")\n",
    "        i_age = colnames.index(\"age\")\n",
    "    except ValueError as e:\n",
    "        raise ValueError(f\"Missing required column in Xbin/Xnum lists: {e}\")\n",
    "\n",
    "    woman = X[:, i_woman].astype(int)\n",
    "    college = X[:, i_college].astype(int)\n",
    "    age = X[:, i_age].astype(float)\n",
    "\n",
    "    # Basic sanity: clamp binary-like columns to {0,1}\n",
    "    woman = (woman != 0).astype(int)\n",
    "    college = (college != 0).astype(int)\n",
    "\n",
    "    # Build age bins\n",
    "    if age_bins is not None:\n",
    "        age_bins = np.asarray(age_bins, dtype=float)\n",
    "        if age_bins.ndim != 1 or age_bins.size < 0:\n",
    "            raise ValueError(\"age_bins should be positive\")\n",
    "        if not np.all(np.diff(age_bins) > 0):\n",
    "            raise ValueError(\"age_bins must be strictly increasing.\")\n",
    "        n_age_bins_eff = age_bins.size - 1\n",
    "        # digitize returns 1..n_bins; convert to 0..n_bins-1\n",
    "        age_bin = np.digitize(age, age_bins[1:-1], right=False)\n",
    "    else:\n",
    "        if age_strategy not in {\"quantile\", \"uniform\"}:\n",
    "            raise ValueError(\"age_strategy must be 'quantile' or 'uniform'.\")\n",
    "        n_age_bins_eff = int(n_age_bins)\n",
    "\n",
    "        if age_strategy == \"quantile\":\n",
    "            # Use unique quantile edges; if ties collapse bins, we fall back to uniform.\n",
    "            qs = np.linspace(0, 1, n_age_bins_eff + 1)\n",
    "            edges = np.quantile(age[~np.isnan(age)], qs)\n",
    "            edges = np.unique(edges)\n",
    "            if edges.size < 5:  # fewer than 4 bins possible\n",
    "                edges = np.linspace(np.nanmin(age), np.nanmax(age), n_age_bins_eff + 1)\n",
    "            # Map to bins\n",
    "            age_bin = np.digitize(age, edges[1:-1], right=False)\n",
    "            n_age_bins_eff = edges.size - 1\n",
    "        else:  # uniform\n",
    "            edges = np.linspace(np.nanmin(age), np.nanmax(age), n_age_bins_eff + 1)\n",
    "            age_bin = np.digitize(age, edges[1:-1], right=False)\n",
    "\n",
    "    # Ensure age_bin is within range (handles NaNs or out-of-range gracefully)\n",
    "    age_bin = np.nan_to_num(age_bin, nan=0.0).astype(int)\n",
    "    age_bin = np.clip(age_bin, 0, n_age_bins_eff - 1)\n",
    "\n",
    "    # Combine into single index\n",
    "    idx = woman + 2 * (age_bin + n_age_bins_eff * college)\n",
    "    return idx.astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ef4b72b4-023b-4516-860b-de96a59a0486",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Last method we use will be running two-sided confidence intervals using asymptotic AV inference. \n",
    "\n",
    "def upper_bound_for_ci(means, vars, alpha=0.1, rho=0.06):\n",
    "    t_vector = np.arange(len(means)) + 1\n",
    "    term = t_vector * vars *(rho)**2 + 1\n",
    "    return means + np.sqrt( 2*term/t_vector**2/rho**2 * np.log(1+np.sqrt(term)/alpha) )\n",
    "\n",
    "## implementation of lower bound function\n",
    "def lower_bound_for_ci(means, vars, alpha=0.1, rho=0.06):\n",
    "    t_vector = np.arange(len(means)) + 1\n",
    "    term = t_vector * vars *(rho)**2 + 1\n",
    "    return means - np.sqrt( 2*term/t_vector**2/rho**2 * np.log(1+np.sqrt(term)/alpha) )\n",
    "\n",
    "\n",
    "def conditional_mean_CIs_real_world(X, Y, Xbin, Xnum, alpha=0.1, rho = 0.06, null = 0., t_0 = 200 , n_age_bins = 2):\n",
    "    \"\"\"\n",
    "    Binned baseline: one two-sided confidence sequence per (woman, age bin,\n",
    "    college) cell, with alpha split evenly across the cells.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Feature matrix with columns ordered as Xbin + Xnum.\n",
    "    Y : np.ndarray\n",
    "        Outcomes (or pseudo-outcomes), one per row of X, in stream order.\n",
    "    Xbin, Xnum : list\n",
    "        Feature name lists used to locate 'Woman', 'College_education', 'age'.\n",
    "    alpha : float\n",
    "        Overall level; each cell is tested at alpha / n_bins.\n",
    "    rho : float\n",
    "        Tuning parameter of the confidence-sequence boundary.\n",
    "    null : float\n",
    "        Value that every cell mean is tested against.\n",
    "    t_0 : int\n",
    "        Rejections at stream positions before t_0 are ignored.\n",
    "    n_age_bins : int\n",
    "        Number of age bins; there are 4 * n_age_bins cells in total.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The first stream position >= t_0 at which the cell observed at that\n",
    "    position has a confidence interval excluding `null`, or None if there is\n",
    "    no such position.\n",
    "    \"\"\"\n",
    "    T = X.shape[0]\n",
    "    rejection_at_t = np.zeros(T)\n",
    "    bin_id = bin_index_woman_age_college(X=X, Xbin=Xbin, Xnum=Xnum, n_age_bins=n_age_bins, age_strategy=\"quantile\")\n",
    "    n_bins = n_age_bins * 4\n",
    "\n",
    "    for n in range(n_bins):\n",
    "        ## compute the CI for bin n, using only the observations that fall in it\n",
    "        in_bin = bin_id == n\n",
    "        Y_rel = Y[in_bin]\n",
    "        counts = np.arange(Y_rel.shape[0]) + 1\n",
    "\n",
    "        running_mean = np.cumsum(Y_rel)/counts\n",
    "        running_var = np.cumsum((Y_rel - running_mean)**2)/counts\n",
    "        upper_bounds_n = upper_bound_for_ci(running_mean, running_var, alpha = alpha/n_bins, rho = rho)\n",
    "        lower_bounds_n = lower_bound_for_ci(running_mean, running_var, alpha = alpha/n_bins, rho = rho)\n",
    "        reject_at_t_n = (upper_bounds_n < null) | (lower_bounds_n > null)\n",
    "\n",
    "        # write the per-bin decisions back to their positions in the full stream\n",
    "        rejection_at_t[in_bin] += reject_at_t_n\n",
    "\n",
    "    # Search only from t_0 onwards. The previous check looked at the whole stream,\n",
    "    # so rejections that all happened before t_0 made np.min fail on an empty array.\n",
    "    hits = np.flatnonzero(rejection_at_t[int(t_0):] != 0)\n",
    "    if hits.size == 0:\n",
    "        return None\n",
    "    return hits[0] + t_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "36d5d53c-c695-4636-ad72-0dcf12152bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same arguments (seed 580) as the draw made before the F-test section; repeated\n",
    "# here so the binned baseline below can be run on its own.\n",
    "X,X_w_A, Y, Y_eif = bootstrap_sample(n=10000, df=full_dat, Xbin=Xbin, Xnum = Xnum, random_seed = 580)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0a751a35-fff8-4425-8834-946dc98ff250",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binned baseline on the pseudo-outcomes with a single age bin, i.e. 4 cells.\n",
    "# NOTE(review): the name `time` would shadow the stdlib module if it is ever imported.\n",
    "time = conditional_mean_CIs_real_world(X=np.array(X), Y=np.array(Y_eif), Xbin=Xbin, Xnum = Xnum, null = 0., n_age_bins = 1, t_0 =1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "15bb8c08-70b7-4bc5-8f1d-727ed34609de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    }
   ],
   "source": [
    "print(time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2755b043-d373-4dff-88eb-f5eceb04a5a5",
   "metadata": {},
   "source": [
    "## Code for our Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "57e4b367-2015-4d40-84ba-3d4dbc974491",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "river version: 0.22.0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import river\n",
    "from river import neural_net\n",
    "from river import preprocessing as pp\n",
    "from river import optim, metrics\n",
    "from river.neural_net import activations as act\n",
    "\n",
    "print(\"river version:\", river.__version__)\n",
    "\n",
    "def row_to_dict(x_row: np.ndarray):\n",
    "    return {f\"x{i}\": float(x_row[i]) for i in range(x_row.shape[0])}\n",
    "\n",
    "\n",
    "## implementation of lower bound function\n",
    "def lower_bound(means, vars, alpha=0.1, rho=0.06):\n",
    "    t_vector = np.arange(len(means)) + 1\n",
    "    term = t_vector * vars *(rho)**2 + 1\n",
    "    return means - np.sqrt( 2*term/t_vector**2/rho**2 * np.log(1+np.sqrt(term)/2/alpha) )\n",
    "\n",
    "\n",
    "def our_approach(X_stream, Y_stream, null = 0.5,alpha=0.1, rho = 0.06, t_0 = 200, l = 0.01, weight_decay = 0.49):\n",
    "    \"\"\"\n",
    "    Sequential one-sided test driven by online plug-in weights.\n",
    "\n",
    "    Two MLPs are trained on the stream, each predicting a row before learning\n",
    "    from it: one for the conditional mean of Y_stream and one for the squared\n",
    "    residual of that prediction. Their outputs define per-step weights; the\n",
    "    weighted scores are averaged and the test rejects once the lower\n",
    "    confidence bound of that running mean is positive.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_stream : np.ndarray\n",
    "        Features, one row per time step.\n",
    "    Y_stream : np.ndarray\n",
    "        Outcomes (or pseudo-outcomes), one per row of X_stream.\n",
    "    null : float\n",
    "        Null value subtracted from predictions and outcomes.\n",
    "    alpha : float\n",
    "        Level passed to lower_bound.\n",
    "    rho : float\n",
    "        Boundary tuning parameter passed to lower_bound.\n",
    "    t_0 : int\n",
    "        Rejections before this index are ignored.\n",
    "    l : float\n",
    "        Predicted variances are clipped to [l, 1 - l].\n",
    "    weight_decay : float\n",
    "        Exponent of the decaying floor 0.05 * t**(-weight_decay) on |weight|.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The first index >= t_0 at which the lower bound exceeds 0, or None if\n",
    "    there is no such index.\n",
    "    \"\"\"\n",
    "    T = X_stream.shape[0]\n",
    "\n",
    "    def _new_regressor():\n",
    "        # identical architecture and seed for the mean model and the variance model\n",
    "        return pp.StandardScaler() | neural_net.MLPRegressor(\n",
    "            hidden_dims=(64, 64, 64, 64, 32),\n",
    "            activations=(act.ReLU(), act.ReLU(), act.ReLU(), act.ReLU(), act.ReLU(), act.Identity()),\n",
    "            optimizer=optim.Adam(1e-3),\n",
    "            seed=0)\n",
    "\n",
    "    def _predict_then_learn(model, targets, clip_eps):\n",
    "        # One pass over the stream: store the clipped prediction for row t, then\n",
    "        # update the model with row t, so every stored value is out-of-sample.\n",
    "        out = np.full(T, np.nan)\n",
    "        for t in range(T):\n",
    "            x = row_to_dict(X_stream[t])\n",
    "            y = float(targets[t])\n",
    "\n",
    "            yhat = model.predict_one(x)\n",
    "            if yhat is not None:\n",
    "                out[t] = clip02(yhat, clip_eps)\n",
    "\n",
    "            model.learn_one(x, y)\n",
    "        return out\n",
    "\n",
    "    ## conditional means (predictions are clipped to [0, 1])\n",
    "    pred = _predict_then_learn(_new_regressor(), Y_stream, 0)\n",
    "\n",
    "    ## conditional variances, learned from the squared residuals (clipped to [l, 1 - l])\n",
    "    resid = (Y_stream - pred)**2\n",
    "    pred_vars = _predict_then_learn(_new_regressor(), resid, l)\n",
    "\n",
    "    ## compute sequential weights for our procedure\n",
    "    steps = np.arange(T) + 1\n",
    "    weights = (pred-null)/pred_vars\n",
    "    # NOTE(review): np.sign(0) is 0, so a weight that is exactly zero stays zero and\n",
    "    # the floor below never applies to it. Confirm that this is intended.\n",
    "    weights = np.sign(weights) * np.maximum(np.abs(weights), 0.05*(steps**(-1*weight_decay)) )\n",
    "\n",
    "    scores = weights * (Y_stream - null)\n",
    "    running_mean = np.cumsum(scores)/steps #\\\\psi_t\n",
    "    var = np.cumsum(weights**2 * resid)/steps # running variance estimate\n",
    "\n",
    "    # rho is now forwarded; before, lower_bound silently used its own default\n",
    "    lbs = lower_bound(running_mean, var, alpha = alpha, rho = rho)\n",
    "\n",
    "    ## get rejection time for our procedure\n",
    "    # flatnonzero copes with an empty tail and with NaN bounds, both of which made\n",
    "    # the previous np.max / np.min lookups raise\n",
    "    crossings = np.flatnonzero(lbs[int(t_0):] > 0)\n",
    "    if crossings.size == 0:\n",
    "        return None\n",
    "    return crossings[0] + t_0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "33ddd5e9-83d2-448b-b330-2cf9a47d9819",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1327\n"
     ]
    }
   ],
   "source": [
    "# One bootstrap stream with seed 70 and no added noise (noise_std = 0).\n",
    "X_stream, X_w_A_stream, Y_stream, Y_stream_eif = bootstrap_sample(n=10000, \n",
    "                                                                      df=full_dat, \n",
    "                                                                      Xbin=Xbin, \n",
    "                                                                      Xnum = Xnum, \n",
    "                                                                      random_seed = 70, \n",
    "                                                                      noise_std = 0)\n",
    "\n",
    "# Run the procedure on the features and pseudo-outcomes; x is the first index >= t_0\n",
    "# at which the lower confidence bound is positive, or None if that never happens.\n",
    "x=our_approach(np.array(X_stream), np.array(Y_stream_eif), null = 0,alpha=0.1, rho = 0.06, t_0 = 800, l = 0.01, weight_decay = 0.24)\n",
    "print(x)\n",
    "#plot(x)\n",
    "#plt.axhline(0, color=\"red\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa654c5a-13bf-45c3-9783-9f334376abda",
   "metadata": {},
   "source": [
    "## Testing our Approaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "5aec8426-81b4-4493-92f3-6cced62d6e51",
   "metadata": {},
   "outputs": [],
   "source": [
    "null = 0.\n",
    "alpha = 0.1\n",
    "rho = 0.06\n",
    "t_0 = 2000\n",
    "num_experiments = 100\n",
    "num_methods = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "e48532d0-1ad1-4e8d-aa40-f006cf03ff85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5776.   nan   nan   nan]\n",
      "[2450. 2000.   nan   nan]\n",
      "[3858.   nan   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[5741. 2000.   nan   nan]\n",
      "[3123. 2000.   nan   nan]\n",
      "[4243. 2134.   nan   nan]\n",
      "[5375. 2000.   nan   nan]\n",
      "[2865. 2000.   nan   nan]\n",
      "[5613. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[9028. 2000.   nan   nan]\n",
      "[7814. 2000.   nan   nan]\n",
      "[2446.   nan   nan   nan]\n",
      "[8921. 2000.   nan   nan]\n",
      "[8784. 2000.   nan   nan]\n",
      "[8335. 2000.   nan   nan]\n",
      "[8386. 2000.   nan   nan]\n",
      "[4602.   nan   nan   nan]\n",
      "[4635. 2000.   nan 5357.]\n",
      "[9397. 2000.   nan   nan]\n",
      "[5133.   nan   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[2265. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[9578. 2000.   nan   nan]\n",
      "[2000. 2000.   nan   nan]\n",
      "[3845. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[3623. 2000.   nan 6173.]\n",
      "[  nan 2000.   nan   nan]\n",
      "[9044. 2000.   nan   nan]\n",
      "[6771. 2000.   nan 5134.]\n",
      "[3669. 2000.   nan   nan]\n",
      "[4633.   nan   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[7618. 2141.   nan   nan]\n",
      "[4538. 2000.   nan   nan]\n",
      "[3319. 6127.   nan   nan]\n",
      "[2336. 2000.   nan   nan]\n",
      "[3251. 2247.   nan 5738.]\n",
      "[3104. 2000.   nan   nan]\n",
      "[6250. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[2318. 2255.   nan   nan]\n",
      "[3853. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[4270.   nan   nan   nan]\n",
      "[7460. 2000.   nan   nan]\n",
      "[3647. 2000.   nan   nan]\n",
      "[4129. 2000.   nan   nan]\n",
      "[2843. 2000.   nan   nan]\n",
      "[6013. 2095.   nan   nan]\n",
      "[4289. 2000.   nan   nan]\n",
      "[7410. 2000.   nan   nan]\n",
      "[5784. 2000.   nan   nan]\n",
      "[3547. 2000.   nan   nan]\n",
      "[5982. 2018.   nan   nan]\n",
      "[6964. 2000.   nan   nan]\n",
      "[3320.   nan   nan 7672.]\n",
      "[nan nan nan nan]\n",
      "[9775. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[7253. 2000.   nan   nan]\n",
      "[6016. 2000.   nan   nan]\n",
      "[2007. 2000.   nan   nan]\n",
      "[3333. 2000.   nan   nan]\n",
      "[6883. 2000.   nan   nan]\n",
      "[6312. 2000.   nan   nan]\n",
      "[4518. 2000.   nan   nan]\n",
      "[3836. 2000.   nan   nan]\n",
      "[9440. 2000.   nan   nan]\n",
      "[3263.   nan   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[8164.   nan   nan   nan]\n",
      "[6457. 2000.   nan   nan]\n",
      "[2692. 2000.   nan   nan]\n",
      "[7587. 2000.   nan   nan]\n",
      "[5213. 2000.   nan   nan]\n",
      "[7440. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[2748.   nan   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[4137. 2000.   nan   nan]\n",
      "[4777. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[4161. 2000.   nan   nan]\n",
      "[5734. 2000.   nan   nan]\n",
      "[9672.   nan   nan   nan]\n",
      "[nan nan nan nan]\n",
      "[4053. 2000.   nan   nan]\n",
      "[4846. 2000.   nan   nan]\n",
      "[3034. 2000.   nan   nan]\n",
      "[  nan 2000.   nan   nan]\n",
      "[3094. 2000.   nan   nan]\n",
      "[6689. 2169.   nan   nan]\n"
     ]
    }
   ],
   "source": [
    "## Testing Methods\n",
    "\n",
    "stopping_times_jobs = np.full((num_experiments, num_methods), np.nan)\n",
    "\n",
    "for i in range(num_experiments):\n",
    "    X_stream, X_w_A_stream, Y_stream, Y_stream_eif = bootstrap_sample(n=10000, \n",
    "                                                                      df=full_dat, \n",
    "                                                                      Xbin=Xbin, \n",
    "                                                                      Xnum = Xnum, \n",
    "                                                                      random_seed = 600+i, \n",
    "                                                                      noise_std = 0)\n",
    "    \n",
    "    #X_stream = np.array(X_stream)\n",
    "\n",
    "    #X_w_A_stream = np.array(X_w_A_stream)\n",
    "    \n",
    "    #Y_stream_eif = np.array(Y_stream_eif)\n",
    "    \n",
    "    #Y_stream = np.array(Y_stream)\n",
    "\n",
    "    ### our approach\n",
    "    stopping_times_jobs[i, 0] = our_approach(X_stream, \n",
    "                                             Y_stream_eif, \n",
    "                                             null = null,\n",
    "                                             alpha=alpha, \n",
    "                                             rho = rho, \n",
    "                                             t_0 = t_0, \n",
    "                                             l = 0.01, \n",
    "                                             weight_decay = 0.24)\n",
    "\n",
    "    ### preprocessing for F_test\n",
    "    A = X_w_A_stream[:,0]\n",
    "    A =  A.reshape(-1, 1)\n",
    "    Z = np.hstack([A, X_stream, A * X_stream, np.ones(X_stream.shape[0]).reshape(-1,1)])                 # (n, 82)\n",
    "\n",
    "    test_cols = [0] + list(range(1 + X_stream.shape[1], 1 + 2*X_stream.shape[1]))  # [0] + 41..80\n",
    "    #print(test_cols)\n",
    "\n",
    "    \n",
    "    stopping_times_jobs[i, 1] = stopping_times_F_testing_avlm(X=Z,\n",
    "                                                                    Y=Y_stream,\n",
    "                                                                    null=null,\n",
    "                                                                    alpha=alpha,\n",
    "                                                                    g=1501,\n",
    "                                                                    t_min=2000,\n",
    "                                                                    hc_type=\"HC0\",\n",
    "                                                                    test_cols=test_cols,\n",
    "                                                                    ridge_XtX = 1e-5,\n",
    "                                                                    has_intercept=False\n",
    "                                                                   )[\"reject_time\"]\n",
    "\n",
    "    \n",
    "    stopping_times_jobs[i, 2] = conditional_mean_CIs_real_world(X=(X_stream), \n",
    "                                                                Y=(Y_stream_eif), \n",
    "                                                                Xbin = Xbin,\n",
    "                                                                Xnum = Xnum,\n",
    "                                                                alpha = alpha/10,\n",
    "                                                                rho = rho,\n",
    "                                                                null = null,\n",
    "                                                                t_0 = t_0,\n",
    "                                                                n_age_bins = 2)\n",
    "\n",
    "    stopping_times_jobs[i, 3] = conditional_mean_CIs_real_world(X=(X_stream), \n",
    "                                                                Y=(Y_stream_eif), \n",
    "                                                                Xbin = Xbin,\n",
    "                                                                Xnum = Xnum,\n",
    "                                                                alpha = alpha,\n",
    "                                                                rho = rho,\n",
    "                                                                null = null,\n",
    "                                                                t_0 = t_0,\n",
    "                                                                n_age_bins = 4)\n",
    "                                                                \n",
    "\n",
    "    print(stopping_times_jobs[i, :])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "7b16d3f9-d3a1-41c5-b16f-074e996cf881",
   "metadata": {},
   "outputs": [],
   "source": [
    "### save matrix\n",
    "np.savez(\"stopping_times_real_world.npz\",\n",
    "         jobs=stopping_times_jobs,)\n",
    "#print(stopping_times_jobs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
