{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f27749-00ed-47e0-9912-80aedc092dcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the project sources importable. Relative entries are resolved against the\n",
    "# kernel's current working directory, so this assumes the kernel was started in\n",
    "# this notebook's folder. The membership check keeps re-runs from adding duplicates.\n",
    "for _path in (\"../../src\", \"..\"):\n",
    "    if _path not in sys.path:\n",
    "        sys.path.append(_path)\n",
    "\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "from jax import config\n",
    "\n",
    "# Enable 64-bit JAX before importing the project modules below, so any JAX arrays\n",
    "# they may create at import time are not built at 32-bit precision first.\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "from utils.kernel_utils import RBF, ColumnwiseRBF\n",
    "from causal_models.causal_learning import KernelATE, KernelATT\n",
    "from generate_experiment_data import generate_train_jobcorp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb77d92-0d97-4035-a337-adb91d6254b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../../data/JCdata.csv'\n",
    "\n",
    "# generate_train_jobcorp returns (U, Y, A) in that order; cast each to float64\n",
    "# explicitly so the kernel models receive 64-bit inputs.\n",
    "U, Y, A = generate_train_jobcorp(data_path)\n",
    "U = jnp.array(U, dtype = jnp.float64)\n",
    "Y = jnp.array(Y, dtype = jnp.float64)\n",
    "A = jnp.array(A, dtype = jnp.float64)\n",
    "\n",
    "# The models below are fit on (A, U) against Y, so all three arrays must have\n",
    "# the same number of rows.\n",
    "assert U.shape[0] == Y.shape[0] == A.shape[0], (U.shape, Y.shape, A.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adefa0a9-1445-4d78-b05c-26e94c9785b6",
   "metadata": {},
   "source": [
    "# Average Treatment Effect Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee13bce2-f887-44ed-bb73-8fe321490884",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularization search settings for the ATE model.\n",
    "optimize_regularization_parameters = True\n",
    "lambda_optimization_range = (1e-9, 1.0)\n",
    "regularization_grid_points = 150\n",
    "\n",
    "# RBF kernels with the length-scale heuristic enabled on both inputs.\n",
    "# NOTE(review): kernel_A is the only RBF kernel in this notebook built without\n",
    "# use_jit_call=True. Confirm this is intentional.\n",
    "kernel_X = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "kernel_A = RBF(use_length_scale_heuristic=True)\n",
    "\n",
    "KernelATE_model = KernelATE(\n",
    "    kernel_X=kernel_X,\n",
    "    kernel_A=kernel_A,\n",
    "    optimize_regularization_parameters=optimize_regularization_parameters,\n",
    "    lambda_optimization_range=lambda_optimization_range,\n",
    "    regularization_grid_points=regularization_grid_points,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d38e93-67d5-4595-afea-7b0585e9d38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the ATE model on the (treatment A, U) pair against outcome Y.\n",
    "# With optimize_regularization_parameters=True this is presumably where the\n",
    "# regularization search configured above runs (TODO confirm), so it may be slow.\n",
    "KernelATE_model.fit((A, U), Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d4edfd-5092-4257-bd93-26ee163c69da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of treatment values at which to evaluate the model, as an (n_points, 1) column.\n",
    "# do_A is also used by the ATT section below, so keep the name.\n",
    "a_grid_min, a_grid_max, a_grid_points = 40, 2000, 1000\n",
    "do_A = jnp.linspace(a_grid_min, a_grid_max, a_grid_points)[:, jnp.newaxis]\n",
    "\n",
    "# do_A is already 2-D, so it can be passed to predict without reshaping.\n",
    "f_struct_pred = KernelATE_model.predict(do_A)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(do_A, f_struct_pred, label=\"Pred\")\n",
    "ax.grid()\n",
    "ax.set_title(\"KernelATE prediction over the treatment grid\")\n",
    "ax.set_xlabel(\"a (treatment)\")\n",
    "ax.set_ylabel(r\"$Y^{(a)}$ (treatment effect)\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6efe120-e1e3-4527-b7a4-c2fe1c67cf2b",
   "metadata": {},
   "source": [
    "# Average Treatment Effect on the Treated (ATT) Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed51b52-6de7-4701-9fc4-d43eff5ab1a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularization search settings for the ATT model (same values as the ATE model).\n",
    "optimize_regularization_parameters = True\n",
    "lambda_optimization_range = (1e-9, 1.0)\n",
    "regularization_grid_points = 150\n",
    "\n",
    "kernel_A_att = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "kernel_X_att = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "\n",
    "model_KernelATT = KernelATT(\n",
    "    kernel_A=kernel_A_att,\n",
    "    kernel_X=kernel_X_att,\n",
    "    optimize_regularization_parameters=optimize_regularization_parameters,\n",
    "    lambda_optimization_range=lambda_optimization_range,\n",
    "    regularization_grid_points=regularization_grid_points,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e3c1d9-4f2a-4e6b-8c05-1a9d7f3e2b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit on the same (treatment A, U) pair and outcome Y as the ATE model.\n",
    "model_KernelATT.fit((A, U), Y);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f8a4c6-9d1b-4a73-b5e0-6c3f9a2d8e17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conditioning treatment level a'. Stored as float64 like A, U and Y\n",
    "# (2000 is also the upper end of the do_A grid).\n",
    "a_prime_value = 2000.0\n",
    "a_prime = jnp.array([a_prime_value], dtype=jnp.float64)\n",
    "\n",
    "# Evaluate on the do_A treatment grid defined in the ATE section.\n",
    "f_struct_pred_katt = model_KernelATT.predict(do_A, a_prime)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(do_A, f_struct_pred_katt, label=\"Pred\")\n",
    "ax.grid()\n",
    "ax.set_title(f\"KernelATT prediction over the treatment grid, a' = {a_prime_value:g}\")\n",
    "ax.set_xlabel(\"a (treatment)\")\n",
    "ax.set_ylabel(r\"$Y^{(a)}$ (treatment effect)\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
