{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49786f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from cauml.utils import CausalDataset\n",
    "import torch\n",
    "import os\n",
    "import torch.nn as nn\n",
    "from scipy.stats import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2dbc036",
   "metadata": {},
   "outputs": [],
   "source": [
    "def caculate_error(tau1, tau2, t, y, mu0, mu1, e):\n",
    "    \"\"\"Per-sample relative error of ``tau1`` versus ``tau2``.\n",
    "\n",
    "    Evaluates ``tau1**2 - tau2**2 - 2 * (tau1 - tau2) * phi`` where ``phi`` is the\n",
    "    doubly-robust (AIPW) pseudo-outcome built from the treatment ``t``, the\n",
    "    outcome ``y``, the outcome predictions ``mu0`` / ``mu1`` and the propensity ``e``.\n",
    "    All arithmetic is element-wise, so scalars and broadcastable arrays both work.\n",
    "\n",
    "    Returns:\n",
    "        The value of the expression above, with the broadcast shape of the inputs.\n",
    "    \"\"\"\n",
    "    # phi = t * (y - mu1) / e + mu1 - (1 - t) * (y - mu0) / (1 - e) - mu0\n",
    "    treated_part = t * (y - mu1) / e + mu1\n",
    "    control_part = (1 - t) * (y - mu0) / (1 - e)\n",
    "    pseudo_outcome = treated_part - control_part - mu0\n",
    "    gap = tau1 ** 2 - tau2 ** 2 - 2 * (tau1 - tau2) * pseudo_outcome\n",
    "    return gap\n",
    "\n",
    "def set_seed(seed=926):\n",
    "    \"\"\"Seed the random number generators used in this notebook.\n",
    "\n",
    "    Seeds Python's ``random`` module, NumPy's legacy global generator and\n",
    "    PyTorch, and switches PyTorch to deterministic algorithms.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed shared by all generators.\n",
    "    \"\"\"\n",
    "    import random  # local import: keeps this cell self-contained\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    # NOTE: the running interpreter reads PYTHONHASHSEED at start-up only, so this\n",
    "    # assignment affects child processes started afterwards, not this kernel.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "    # Ask cuDNN for deterministic kernels and disable its auto-tuner.\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "set_seed()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a46c42df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Our network\n",
    "class SharkNet(nn.Module):\n",
    "    \"\"\"Shared-representation network with two outcome heads and a propensity head.\n",
    "\n",
    "    A three-layer ELU trunk (``shared``) maps the covariates to a 30-dimensional\n",
    "    representation. Three bias-free linear heads read it: ``y0_head`` and\n",
    "    ``y1_head`` (one output each) and ``t_head`` (one output followed by a\n",
    "    sigmoid). ``xi`` and ``eta`` are learnable 30-dimensional slack vectors that\n",
    "    are only used by ``constraint_loss``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim):\n",
    "        \"\"\"Create the trunk, the three heads and the slack parameters.\n",
    "\n",
    "        Args:\n",
    "            input_dim: number of covariates per sample.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.shared = nn.Sequential(\n",
    "            nn.Linear(input_dim, 60),\n",
    "            nn.ELU(),\n",
    "            nn.Linear(60, 30),\n",
    "            nn.ELU(),\n",
    "            nn.Linear(30, 30),\n",
    "            nn.ELU()\n",
    "        )\n",
    "        self.y0_head = nn.Sequential(\n",
    "            nn.Linear(30, 1, bias=False)\n",
    "        )\n",
    "        self.y1_head = nn.Sequential(\n",
    "            nn.Linear(30, 1, bias=False)\n",
    "        )\n",
    "        self.t_head = nn.Sequential(\n",
    "            nn.Linear(30, 1, bias=False),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "        # Slack variables of the constraints, both initialised to 0.1.\n",
    "        self.xi = nn.Parameter(torch.ones(30) * 0.1)\n",
    "        self.eta = nn.Parameter(torch.ones(30) * 0.1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the trunk and all three heads.\n",
    "\n",
    "        Returns:\n",
    "            A pair ``(output, gamma)``. ``output`` concatenates, along dim 1, the\n",
    "            columns ``[y0, y1, t, rep]`` (column 0, column 1, column 2 and the 30\n",
    "            representation columns). ``gamma`` is the weight vector of the linear\n",
    "            layer inside ``t_head`` (still attached to the graph).\n",
    "        \"\"\"\n",
    "        rep = self.shared(x)\n",
    "        y0 = self.y0_head(rep)\n",
    "        y1 = self.y1_head(rep)\n",
    "        t = self.t_head(rep)\n",
    "        gamma = self.t_head[0].weight[0]\n",
    "        output = torch.cat([y0, y1, t, rep], dim=1)\n",
    "        return output, gamma\n",
    "\n",
    "    def wls_loss(self, output, Y, A, tau_diff, gamma):\n",
    "        \"\"\"Weighted squared-error loss of the two outcome heads.\n",
    "\n",
    "        Control rows are weighted by ``e / (1 - e)`` and treated rows by\n",
    "        ``(1 - e) / e``; every row is additionally scaled by ``|tau_diff|``.\n",
    "\n",
    "        Args:\n",
    "            output: first element returned by ``forward``.\n",
    "            Y: outcomes; expected to be 1-D with one entry per row of ``output``.\n",
    "            A: treatment indicators, same layout as ``Y``.\n",
    "            tau_diff: per-row weights, same layout as ``Y``.\n",
    "            gamma: propensity weight vector; detached here.\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor with the mean weighted loss.\n",
    "        \"\"\"\n",
    "        gamma = gamma.detach().squeeze()\n",
    "        Phi = output[:, 3:]\n",
    "        # Same value as the t-head output, but gamma is detached, so gradients\n",
    "        # reach the trunk through Phi only.\n",
    "        e_pred = torch.sigmoid(Phi @ gamma)\n",
    "        # Keep both weight ratios bounded.\n",
    "        e_pred = torch.clamp(e_pred, 0.3, 0.7)\n",
    "        mu0_pred = output[:, 0:1].squeeze()\n",
    "        mu1_pred = output[:, 1:2].squeeze()\n",
    "        term1 = (1 - A) * e_pred * (Y - mu0_pred)**2 / (1 - e_pred)\n",
    "        term2 = A * (1 - e_pred) * (Y - mu1_pred)**2 / e_pred\n",
    "        loss = (torch.abs(tau_diff) * (term1 + term2)).mean()\n",
    "        return loss\n",
    "\n",
    "    def cross_entropy_loss(self, output, A):\n",
    "        \"\"\"Binary cross-entropy between the t-head output (column 2) and ``A``.\n",
    "\n",
    "        The prediction is clamped to ``[1e-3, 1 - 1e-3]`` before the logarithms.\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor: the negative mean log-likelihood.\n",
    "        \"\"\"\n",
    "        e_pred = output[:, 2:3]\n",
    "        e_pred = torch.clamp(e_pred, 1e-3, 1 - 1e-3)\n",
    "        e_pred = e_pred.squeeze()\n",
    "        log_likelihood = (A * torch.log(e_pred) + (1 - A) * torch.log(1 - e_pred)).mean()\n",
    "        return -log_likelihood\n",
    "\n",
    "    def constraint_loss(self, output, A, tau_diff, gamma, c = 1, penalty_weight=100.0):\n",
    "        \"\"\"Penalty that keeps two weighted moments of ``Phi`` within the slacks.\n",
    "\n",
    "        With ``e = sigmoid(Phi @ gamma)`` clamped to ``[1e-3, 1 - 1e-3]`` it forms\n",
    "        ``term_xi  = mean(tau_diff * (1 - A / e) * Phi)`` and\n",
    "        ``term_eta = mean(tau_diff * (1 - (1 - A) / (1 - e)) * Phi)`` over rows and\n",
    "        penalises ``relu(|term| - slack)`` per dimension, negative slacks, and the\n",
    "        sum of the slacks.\n",
    "\n",
    "        Args:\n",
    "            output: first element returned by ``forward``; ``Phi`` is detached.\n",
    "            A: treatment indicators, 1-D.\n",
    "            tau_diff: per-row weights, 1-D.\n",
    "            gamma: propensity weight vector, used as passed (not detached here).\n",
    "            c: weight of the slack-size term.\n",
    "            penalty_weight: weight of the violation and negativity terms.\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor ``penalty_weight * (violation + negativity) + c * slack``.\n",
    "        \"\"\"\n",
    "        A = A.unsqueeze(1)\n",
    "        tau_diff = tau_diff.unsqueeze(1)\n",
    "        Phi = output[:, 3:].detach()\n",
    "        e_pred = torch.sigmoid(Phi @ gamma)\n",
    "        e_pred = torch.clamp(e_pred, 1e-3, 1 - 1e-3)\n",
    "        e_pred = e_pred.unsqueeze(1)\n",
    "        term_xi = (tau_diff * (1 - A / (e_pred)) * Phi).mean(dim=0)\n",
    "        term_eta = (tau_diff * (1 - (1 - A) / (1 - e_pred)) * Phi).mean(dim=0)\n",
    "        penalty_xi = torch.relu(torch.abs(term_xi) - self.xi)\n",
    "        penalty_eta = torch.relu(torch.abs(term_eta) - self.eta)\n",
    "        constraint_penalty = (penalty_xi).sum() + (penalty_eta).sum()\n",
    "        slack_penalty = self.xi.sum() + self.eta.sum()\n",
    "        neg_penalty = torch.relu(-self.xi).sum() + torch.relu(-self.eta).sum()\n",
    "        loss = penalty_weight * (constraint_penalty + neg_penalty) + c * slack_penalty\n",
    "        return loss\n",
    "\n",
    "    def fit(self, data, tau_diff, epochs=700, verbose=True, sigma = 0.1):\n",
    "        \"\"\"Train on the full data set, one optimiser step per epoch.\n",
    "\n",
    "        Args:\n",
    "            data: object with ``x``, ``y`` and ``t`` attributes accepted by\n",
    "                ``torch.tensor``; ``y`` and ``t`` are squeezed first.\n",
    "            tau_diff: per-sample weights passed to the WLS and constraint losses.\n",
    "            epochs: number of full-batch steps (also the OneCycleLR length).\n",
    "            verbose: if true, print the three loss terms after the last epoch.\n",
    "            sigma: unused; kept so existing callers keep working.\n",
    "        \"\"\"\n",
    "        x = torch.tensor(data.x, dtype=torch.float32)\n",
    "        y = torch.tensor(data.y.squeeze(), dtype=torch.float32)\n",
    "        t = torch.tensor(data.t.squeeze(), dtype=torch.float32)\n",
    "        tau_diff = torch.tensor(tau_diff, dtype=torch.float32)\n",
    "        # predict() leaves the module in eval mode; switch back before training.\n",
    "        self.train()\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3,\n",
    "                                    weight_decay=0)\n",
    "        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-3,\n",
    "                                                      total_steps=epochs)\n",
    "        for epoch in range(epochs):\n",
    "            optimizer.zero_grad()\n",
    "            output, gamma1 = self.forward(x)\n",
    "            loss1 = self.wls_loss(output, y, t, tau_diff, gamma1)\n",
    "            loss2 = self.cross_entropy_loss(output, t)\n",
    "            loss3 = self.constraint_loss(output, t, tau_diff, gamma1)\n",
    "            total_loss = 1 * loss1 + 0.5 * loss2 + 1 * loss3\n",
    "            total_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            if verbose and (epoch == epochs - 1):\n",
    "                print(f\"Epoch {epoch}: loss1 = {loss1.item():.4f}, loss2 = {loss2.item():.4f}, loss3 = {loss3.item():.4f}\")\n",
    "\n",
    "    def predict(self, data):\n",
    "        \"\"\"Return ``(y0_pred, y1_pred, e_pred)`` for ``data.x`` without gradients.\n",
    "\n",
    "        Each returned tensor is one column of the forward output (columns 0, 1\n",
    "        and 2). The module is left in eval mode.\n",
    "        \"\"\"\n",
    "        x = torch.tensor(data.x, dtype=torch.float32)\n",
    "        self.eval()\n",
    "        with torch.no_grad():\n",
    "            output, _ = self.forward(x)\n",
    "            y0_pred = output[:, 0]\n",
    "            y1_pred = output[:, 1]\n",
    "            e_pred = output[:, 2]\n",
    "        return y0_pred, y1_pred, e_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb47e0ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "#re-implementation of Causal Forest and the config of TARNet\n",
    "from econml.dml import CausalForestDML\n",
    "from sklearn.linear_model import LassoCV\n",
    "from cauml.inference import TarNet\n",
    "class CausalForest():\n",
    "    \"\"\"Wrapper that exposes econml's ``CausalForestDML`` through ``fit`` / ``ITE``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def fit(self, data):\n",
    "        \"\"\"Fit the forest on ``data.x``, ``data.y`` and ``data.t``.\n",
    "\n",
    "        If ``data.x`` is not a NumPy array, all three fields are converted with\n",
    "        ``.cpu().numpy()``. ``y`` and ``t`` are squeezed in both cases. The arrays\n",
    "        are kept on the instance as ``X``, ``Y`` and ``T``.\n",
    "        \"\"\"\n",
    "        if isinstance(data.x, np.ndarray):\n",
    "            covariates = data.x\n",
    "            outcome = data.y.squeeze()\n",
    "            treatment = data.t.squeeze()\n",
    "        else:\n",
    "            covariates = data.x.cpu().numpy()\n",
    "            outcome = data.y.cpu().numpy().squeeze()\n",
    "            treatment = data.t.cpu().numpy().squeeze()\n",
    "        self.X, self.Y, self.T = covariates, outcome, treatment\n",
    "\n",
    "        forest_options = {\n",
    "            'criterion': 'het',\n",
    "            'n_estimators': 60,\n",
    "            'min_samples_leaf': 10,\n",
    "            'max_depth': 5,\n",
    "            'max_samples': 0.5,\n",
    "            'discrete_treatment': False,\n",
    "            'model_t': LassoCV(),\n",
    "            'model_y': LassoCV(),\n",
    "            'random_state': 0,\n",
    "        }\n",
    "        self.est = CausalForestDML(**forest_options)\n",
    "        self.est.fit(self.Y, self.T, X=self.X)\n",
    "\n",
    "    def ITE(self, data):\n",
    "        \"\"\"Return ``(zeros, effect, zeros)`` for the rows of ``data.x``.\n",
    "\n",
    "        ``effect`` is ``self.est.effect`` evaluated on the covariates; the two\n",
    "        zero vectors have one entry per row, so ``second - first`` is the effect.\n",
    "        The covariates are stored on the instance as ``X_test``.\n",
    "        \"\"\"\n",
    "        covariates = data.x\n",
    "        if not isinstance(covariates, np.ndarray):\n",
    "            covariates = covariates.cpu().numpy()\n",
    "        self.X_test = covariates\n",
    "        effect = self.est.effect(self.X_test)\n",
    "        n_rows = len(self.X_test)\n",
    "        return np.zeros(n_rows), effect, np.zeros(n_rows)\n",
    "    \n",
    "# Settings handed to TarNet.set_Configuration in the experiment loop.\n",
    "config1 = {\n",
    "    'methodName': 'TarNet',\n",
    "    'device': 'cpu',\n",
    "    'epochs': 25,\n",
    "    'verbose': 10,\n",
    "    'batch_size': 500,\n",
    "    'shuffle': 1,\n",
    "    'wd': 5e-3,\n",
    "    'tr_wd': 5e-3,\n",
    "    'momentum': 0.9,\n",
    "    'cfg_density': [(25, 50, 1, 'relu'), (50, 50, 1, 'relu')],\n",
    "    'num_grid': 10,\n",
    "    'cfg': [(50, 50, 1, 'relu'), (50, 1, 1, 'id')],\n",
    "    'isenhance': 0,\n",
    "    'isTargetReg': 1,  # 1\n",
    "    'init_lr': 0.01,  # 0.01\n",
    "    'alpha': 0.5,\n",
    "    'tr_init_lr': 0.01,\n",
    "    'beta': 1.0,\n",
    "    'tr_knots': list(np.arange(0.1, 1, 0.1)),\n",
    "    'tr_degree': 2,\n",
    "    'seed': 2022,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b545f4cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pehe_and_ate(mu0_pred, mu1_pred, mu0_true, mu1_true):\n",
    "    \"\"\"Return ``(PEHE, |ATE error|)`` of predicted vs. true potential outcomes.\"\"\"\n",
    "    ite_error = mu1_pred - mu0_pred - mu1_true + mu0_true\n",
    "    return np.sqrt(np.mean(ite_error ** 2)), np.abs(np.mean(ite_error))\n",
    "\n",
    "\n",
    "pehe_test = []\n",
    "pehe_train = []\n",
    "ate_test = []\n",
    "ate_train = []\n",
    "cover1 = 0    # replications whose interval contains the real delta\n",
    "georgeen = 0  # replications whose interval excludes 0 on the same side as the real delta\n",
    "\n",
    "# Two-sided normal critical value for the (1 - alpha) interval; loop-invariant.\n",
    "alpha = 0.1\n",
    "z = norm.ppf(1 - alpha / 2)\n",
    "\n",
    "datas = [CausalDataset(path=f'~/code/IHDP/66_34_0/{i}/') for i in range(10)]  # load data\n",
    "\n",
    "for s, data in enumerate(datas):\n",
    "    # Read the arrays needed later before data.tensor() is called below\n",
    "    # (presumably so they stay NumPy arrays -- confirm against CausalDataset).\n",
    "    A_test = data.test.t.squeeze()\n",
    "    Y_test = data.test.y.squeeze()\n",
    "    mu0_test_true = data.test.m[:, 0].squeeze()\n",
    "    mu1_test_true = data.test.m[:, 1].squeeze()\n",
    "    mu0_train_true = data.train.m[:, 0].squeeze()\n",
    "    mu1_train_true = data.train.m[:, 1].squeeze()\n",
    "    data.tensor()\n",
    "    data.to('cpu')\n",
    "\n",
    "    # First effect estimator: TarNet.\n",
    "    model_HTE1 = TarNet()\n",
    "    model_HTE1.set_Configuration(config1)\n",
    "    model_HTE1.fit(data.train)\n",
    "    out0_train, out1_train, _ = model_HTE1.ITE(data.train)\n",
    "    out0_test, out1_test, _ = model_HTE1.ITE(data.test)\n",
    "    tau1_train = (out1_train - out0_train).squeeze()\n",
    "    tau1_test = (out1_test - out0_test).squeeze()\n",
    "\n",
    "    # Second effect estimator: causal forest.\n",
    "    model_HTE2 = CausalForest()\n",
    "    model_HTE2.fit(data.train)\n",
    "    out0_train, out1_train, _ = model_HTE2.ITE(data.train)\n",
    "    out0_test, out1_test, _ = model_HTE2.ITE(data.test)\n",
    "    tau2_test = (out1_test - out0_test).squeeze()\n",
    "    tau2_train = (out1_train - out0_train).squeeze()\n",
    "    tau_diff_train = tau1_train - tau2_train\n",
    "    data.to('cpu')\n",
    "    data.numpy()\n",
    "\n",
    "    # Clip the training difference to its 5%-95% range and rescale to |.| <= 0.25.\n",
    "    q_low = np.quantile(tau_diff_train, 0.05)\n",
    "    q_high = np.quantile(tau_diff_train, 0.95)\n",
    "    tau_diff_clipped = np.clip(tau_diff_train, q_low, q_high)\n",
    "    max_abs = np.max(np.abs(tau_diff_clipped))\n",
    "    if max_abs < 1e-6:\n",
    "        tau_diff_train_scaled = np.zeros_like(tau_diff_clipped)\n",
    "    else:\n",
    "        tau_diff_train_scaled = 0.25 * tau_diff_clipped / max_abs\n",
    "\n",
    "    # True per-sample difference of squared errors of the two estimators.\n",
    "    # NOTE(review): the truth is averaged over test AND train rows, while the\n",
    "    # estimate and its interval below use test rows only -- confirm this is intended.\n",
    "    reals_test = (tau1_test - mu1_test_true + mu0_test_true)**2 - (tau2_test - mu1_test_true + mu0_test_true)**2\n",
    "    reals_train = (tau1_train - mu1_train_true + mu0_train_true)**2 - (tau2_train - mu1_train_true + mu0_train_true)**2\n",
    "    reals = np.concatenate([reals_test, reals_train])\n",
    "    real_delta = np.mean(reals)\n",
    "\n",
    "    model = SharkNet(data.train.x.shape[1])\n",
    "    model.fit(data.train, tau_diff_train_scaled)\n",
    "\n",
    "    mu0_test, mu1_test, e_test = model.predict(data.test)\n",
    "    mu0_test = mu0_test.cpu().numpy().squeeze()\n",
    "    mu1_test = mu1_test.cpu().numpy().squeeze()\n",
    "    e_test = e_test.cpu().numpy().squeeze()\n",
    "    print(s)\n",
    "    pehe, ate = _pehe_and_ate(mu0_test, mu1_test, mu0_test_true, mu1_test_true)\n",
    "    print('PEHE test:', pehe)\n",
    "    print('ATE test:', ate)\n",
    "    pehe_test.append(pehe)\n",
    "    ate_test.append(ate)\n",
    "\n",
    "    mu0_train, mu1_train, _ = model.predict(data.train)\n",
    "    mu0_train = mu0_train.cpu().numpy().squeeze()\n",
    "    mu1_train = mu1_train.cpu().numpy().squeeze()\n",
    "    pehe, ate = _pehe_and_ate(mu0_train, mu1_train, mu0_train_true, mu1_train_true)\n",
    "    print('PEHE train:', pehe)\n",
    "    print('ATE train:', ate)\n",
    "    pehe_train.append(pehe)\n",
    "    ate_train.append(ate)\n",
    "\n",
    "    # Estimated delta on the test rows, with propensities kept inside [0.1, 0.9].\n",
    "    e_test = np.clip(e_test, 0.1, 1- 0.1)\n",
    "    delta_test = caculate_error(tau1_test, tau2_test, A_test, Y_test, mu0_test, mu1_test, e_test)\n",
    "    print('the real delta:', real_delta)\n",
    "    print('the estimated delta:', np.mean(delta_test))\n",
    "    # NOTE(review): for a NumPy array .var() uses ddof=0 although the label below\n",
    "    # says 'sample variance' -- confirm which variance is wanted.\n",
    "    half_width = z * (delta_test.var() / len(delta_test)) ** 0.5\n",
    "    lower = delta_test.mean() - half_width\n",
    "    upper = delta_test.mean() + half_width\n",
    "    print('delta--sample variance:', delta_test.var())\n",
    "    print('confidence interval:', lower, upper)\n",
    "    if lower <= real_delta and upper >= real_delta:\n",
    "        cover1 += 1\n",
    "    # A selection counts as correct when the interval lies entirely on the same\n",
    "    # side of zero as the real delta.\n",
    "    if lower > 0 and real_delta > 0:\n",
    "        georgeen += 1\n",
    "    if upper < 0 and real_delta < 0:\n",
    "        georgeen += 1\n",
    "\n",
    "pehe_test = np.array(pehe_test)\n",
    "pehe_train = np.array(pehe_train)\n",
    "ate_test = np.array(ate_test)\n",
    "ate_train = np.array(ate_train)\n",
    "\n",
    "for arr, name in zip([pehe_test, pehe_train, ate_test, ate_train],\n",
    "                     ['pehe_test', 'pehe_train', 'ate_test', 'ate_train']):\n",
    "    print(f'{name} - Mean: {arr.mean()}, Standard Deviation: {arr.std()}')\n",
    "\n",
    "print('coverage:', cover1 / len(datas), 'selection accuracy: ', georgeen / len(datas))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-metal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
