{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dc20743",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "from run_method import run_method_multiple_times\n",
    "from configs import Subcon_Config\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "545c05d0",
   "metadata": {},
   "source": [
    "## IHDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf17f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use local files OR replace with raw URLs; this path expects local CSVs:\n",
    "base_url = \"data/IHDP/csv/ihdp_npci_\"\n",
    "N_REALIZATIONS = 10  # set to 100 if applicable\n",
    "\n",
    "all_results = []\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "mse_in = []\n",
    "mse_out = []\n",
    "sg_masks = []\n",
    "sg_masks_ablation = []\n",
    "\n",
    "for i in range(1, N_REALIZATIONS + 1):\n",
    "    url = f\"{base_url}{i}.csv\"\n",
    "    data_df = pd.read_csv(url, header=None, sep=\",\")\n",
    "    col = ['treatment', 'y_factual', 'y_cfactual', 'mu0', 'mu1'] + [f'x{j}' for j in range(1, 26)]\n",
    "    data_df.columns = col\n",
    "\n",
    "    treatment = data_df['treatment'].values.astype(int)\n",
    "    y = data_df['y_factual'].values.astype(float)\n",
    "    X = data_df.filter(regex=r'^x\\d+$').values.astype(float)\n",
    "    tau = (data_df['mu1'] - data_df['mu0']).values.astype(float)\n",
    "\n",
    "    feature_names = [f\"X{k}\" for k in range(X.shape[1])]\n",
    "    scaler_X = StandardScaler()\n",
    "    X_scaled = scaler_X.fit_transform(X)\n",
    "    scaler_Y = StandardScaler()\n",
    "    y_scaled = scaler_Y.fit_transform(y.reshape(-1, 1))[:,0]\n",
    "    is_discrete = [False] * X.shape[1]\n",
    "\n",
    "    our_config = Subcon_Config().get_setting_config('observational')\n",
    "    subgroups, _, _ = run_method_multiple_times(\n",
    "        X_scaled, X_scaled[treatment == 0], X_scaled[treatment == 1],\n",
    "        y_scaled, y_scaled[treatment == 0], y_scaled[treatment == 1],\n",
    "        scaler_X, scaler_Y, feature_names, is_discrete, our_config,\n",
    "        discrete_target=False, maximize=True, plot=False, max_reps=1\n",
    "    )\n",
    "    s0_mask, s1_mask = subgroups[0]\n",
    "    subgroup_mask = np.zeros(X.shape[0], dtype=bool)\n",
    "    subgroup_mask[np.where(treatment == 0)[0][s0_mask]] = True\n",
    "    subgroup_mask[np.where(treatment == 1)[0][s1_mask]] = True\n",
    "    sg_masks.append(subgroup_mask)\n",
    "    cate_pred_subcon = np.zeros_like(tau, dtype=float)\n",
    "    mask_in, mask_out = subgroup_mask, ~subgroup_mask\n",
    "    ate_in  = np.mean(y[mask_in & (treatment == 1)]) - np.mean(y[mask_in & (treatment == 0)])\n",
    "    ate_out = np.mean(y[mask_out & (treatment == 1)]) - np.mean(y[mask_out & (treatment == 0)])\n",
    "    cate_pred_subcon[mask_in]  = ate_in\n",
    "    cate_pred_subcon[mask_out] = ate_out\n",
    "\n",
    "    # compute mse in subgroup and out of subgroup\n",
    "    m_in  = np.mean((tau[mask_in] - cate_pred_subcon[mask_in])**2)\n",
    "    m_out = np.mean((tau[mask_out] - cate_pred_subcon[mask_out])**2)\n",
    "    mse_in.append(m_in)\n",
    "    mse_out.append(m_out)\n",
    "\n",
    "    our_config[\"lambd\"] = 0.0  # ablation: no regularization\n",
    "    subgroups, _, _ = run_method_multiple_times(\n",
    "        X_scaled, X_scaled[treatment == 0], X_scaled[treatment == 1],\n",
    "        y_scaled, y_scaled[treatment == 0], y_scaled[treatment == 1],\n",
    "        scaler_X, scaler_Y, feature_names, is_discrete, our_config,\n",
    "        discrete_target=False, maximize=True, plot=True, max_reps=1\n",
    "    )\n",
    "    s0_mask, s1_mask = subgroups[0]\n",
    "    subgroup_mask = np.zeros(X.shape[0], dtype=bool)\n",
    "    subgroup_mask[np.where(treatment == 0)[0][s0_mask]] = True\n",
    "    subgroup_mask[np.where(treatment == 1)[0][s1_mask]] = True\n",
    "    sg_masks_ablation.append(subgroup_mask)\n",
    "print(\"MSE in subgroup: \", np.mean(mse_in), np.std(mse_in))\n",
    "print(\"MSE out of subgroup: \", np.mean(mse_out), np.std(mse_out))\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71f02b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# now repeat with Causal Forest and Causal Tree\n",
    "from causalml.inference.tree import CausalTreeRegressor\n",
    "from econml.grf import CausalForest\n",
    "from econml.metalearners import XLearner\n",
    "from sklearn.ensemble import HistGradientBoostingRegressor\n",
    "from configs import HonestTree_Config\n",
    "\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "mse_in = []\n",
    "mse_out = []\n",
    "results = pd.DataFrame(columns=['realization', \"mse_in_subcon\", \"mse_total_subcon\",\"mse_in_ablation\", \"mse_total_ablation\", \"mse_in_ct\", \"mse_total_ct\", \"mse_in_cf\", \"mse_total_cf\", \"mse_in_xl\", \"mse_total_xl\"])\n",
    "\n",
    "for i in range(1, N_REALIZATIONS + 1):\n",
    "    url = f\"{base_url}{i}.csv\"\n",
    "    data_df = pd.read_csv(url, header=None, sep=\",\")\n",
    "    col = ['treatment', 'y_factual', 'y_cfactual', 'mu0', 'mu1'] + [f'x{j}' for j in range(1, 26)]\n",
    "    data_df.columns = col\n",
    "\n",
    "    treatment = data_df['treatment'].values.astype(int)\n",
    "    y = data_df['y_factual'].values.astype(float)\n",
    "    X = data_df.filter(regex=r'^x\\d+$').values.astype(float)\n",
    "    tau = (data_df['mu1'] - data_df['mu0']).values.astype(float)\n",
    "\n",
    "    # fit causal tree\n",
    "    tree_config = HonestTree_Config().get_setting_config('observational')\n",
    "    ct = CausalTreeRegressor(**tree_config)\n",
    "    ct.fit(X, treatment, y)\n",
    "    cate_pred_ct = ct.predict(X)\n",
    "    # fit causal forest\n",
    "    cf = CausalForest()\n",
    "    cf.fit(X, treatment, y)\n",
    "    cate_pred_cf = cf.predict(X)    \n",
    "    xlearner = XLearner(models=HistGradientBoostingRegressor())\n",
    "    xlearner.fit(y, treatment, X=X)\n",
    "    cate_pred_xl = xlearner.effect(X)\n",
    "    # compute mse in subgroup and out of subgroup\n",
    "    subgroup_mask = sg_masks[i-1]\n",
    "    cate_pred_subcon = np.zeros_like(tau, dtype=float)\n",
    "    mask_in, mask_out = subgroup_mask, ~subgroup_mask\n",
    "    ate_in  = np.mean(y[mask_in & (treatment == 1)]) - np.mean(y[mask_in & (treatment == 0)])\n",
    "    ate_out = np.mean(y[mask_out & (treatment == 1)]) - np.mean(y[mask_out & (treatment == 0)])\n",
    "    cate_pred_subcon[mask_in]  = ate_in\n",
    "    cate_pred_subcon[mask_out] = ate_out\n",
    "    mse_in_subcon  = np.sqrt(np.mean((tau[mask_in] - cate_pred_subcon[mask_in])**2))\n",
    "    mse_total_subcon = np.sqrt(np.mean((tau - cate_pred_subcon)**2))\n",
    "    mse_in_ct  = np.sqrt(np.mean((tau[mask_in] - cate_pred_ct[mask_in])**2))\n",
    "    mse_total_ct = np.sqrt(np.mean((tau - cate_pred_ct)**2))\n",
    "    mse_in_cf  = np.sqrt(np.mean((tau[mask_in] - cate_pred_cf[mask_in])**2))\n",
    "    mse_total_cf = np.sqrt(np.mean((tau - cate_pred_cf)**2))\n",
    "    mse_in_xl  = np.sqrt(np.mean((tau[mask_in] - cate_pred_xl[mask_in])**2))\n",
    "    mse_total_xl = np.sqrt(np.mean((tau - cate_pred_xl)**2))\n",
    "\n",
    "\n",
    "    # check without regularizer\n",
    "    subgroup_mask = sg_masks_ablation[i-1]\n",
    "    cate_pred_subcon = np.zeros_like(tau, dtype=float)\n",
    "    mask_in, mask_out = subgroup_mask, ~subgroup_mask\n",
    "    ate_in  = np.mean(y[mask_in & (treatment == 1)]) - np.mean(y[mask_in & (treatment == 0)])\n",
    "    ate_out = np.mean(y[mask_out & (treatment == 1)]) - np.mean(y[mask_out & (treatment == 0)])\n",
    "    cate_pred_subcon[mask_in]  = ate_in\n",
    "    cate_pred_subcon[mask_out] = ate_out\n",
    "    mse_in_subcon_ablation  = np.sqrt(np.mean((tau[mask_in] - cate_pred_subcon[mask_in])**2))\n",
    "    mse_total_subcon_ablation = np.sqrt(np.mean((tau - cate_pred_subcon)**2))\n",
    "\n",
    "\n",
    "    results.loc[len(results)] = {'realization': i, \"mse_in_subcon\": mse_in_subcon, \"mse_total_subcon\": mse_total_subcon,\"mse_in_ablation\": mse_in_subcon_ablation, \"mse_total_ablation\": mse_total_subcon_ablation, \"mse_in_ct\": mse_in_ct, \"mse_total_ct\": mse_total_ct, \"mse_in_cf\": mse_in_cf, \"mse_total_cf\": mse_total_cf, \"mse_in_xl\": mse_in_xl, \"mse_total_xl\": mse_total_xl}\n",
    "\n",
    "# compute average\n",
    "results.loc[len(results)] = {'realization': 'mean', \"mse_in_subcon\": results['mse_in_subcon'].mean(), \"mse_total_subcon\": results['mse_total_subcon'].mean(), \"mse_in_ablation\": results['mse_in_ablation'].mean(), \"mse_total_ablation\": results['mse_total_ablation'].mean(),  \"mse_in_ct\": results['mse_in_ct'].mean(), \"mse_total_ct\": results['mse_total_ct'].mean(), \"mse_in_cf\": results['mse_in_cf'].mean(), \"mse_total_cf\": results['mse_total_cf'].mean(), \"mse_in_xl\": results['mse_in_xl'].mean(), \"mse_total_xl\": results['mse_total_xl'].mean()}\n",
    "results.to_csv(\"results/results_ihdp.csv\", index=False)\n",
    "supports = []\n",
    "\n",
    "suports_ablation = []\n",
    "for i in range(10):\n",
    "    mask_subcon = sg_masks[i]\n",
    "    mask_ablation = sg_masks_ablation[i]\n",
    "    support_subcon = np.mean(mask_subcon)\n",
    "    support_ablation = np.mean(mask_ablation)\n",
    "    supports.append(support_subcon)\n",
    "    suports_ablation.append(support_ablation)\n",
    "print(\"Support subcon: \", np.mean(supports), np.std(supports))\n",
    "print(\"Support ablation: \", np.mean(suports_ablation), np.std(suports_ablation))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "subcon",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
