{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "671c2de2-e29f-4d1a-adf4-16185dabfe9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import ColumnwiseRBF, RBF, FourthOrderGaussianKernel, BinaryKernel\n",
    "from causal_models.proxy_causal_learning import KernelNegativeControlCATE\n",
    "from utils.ml_utils import data_transform\n",
    "from utils.experimental_data_functions import generate_synthetic_CATE_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2f257861-10fb-4236-a822-fb71b2160b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 1.\n",
    "uniform_noise_upper_bound = 1.\n",
    "uniform_noise_lower_bound = -1.\n",
    "seed = np.random.randint(1000000)\n",
    "\n",
    "U, W, Z, V, A, Y, covariate_v_test, do_A, EY_do_A_CATE = generate_synthetic_CATE_data(1000, \n",
    "                                                                                      sigma,\n",
    "                                                                                      uniform_noise_upper_bound,\n",
    "                                                                                      uniform_noise_lower_bound,\n",
    "                                                                                      seed = seed)\n",
    "\n",
    "U, W, Z, V, A, Y = jnp.array(U), jnp.array(W), jnp.array(Z), jnp.array(V), jnp.array(A), jnp.array(Y)\n",
    "covariate_v_test, do_A, EY_do_A_CATE = jnp.array(covariate_v_test), jnp.array(do_A), jnp.array(EY_do_A_CATE)\n",
    "\n",
    "A_transformed, A_transformer = data_transform(A)\n",
    "Z_transformed, Z_transformer = data_transform(Z)\n",
    "W_transformed, W_transformer = data_transform(W)\n",
    "V_transformed, V_transformer = data_transform(V)\n",
    "Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "data_size = A_transformed.shape[0]\n",
    "A_transformed = jnp.array(A_transformed).reshape(data_size, -1)\n",
    "Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)\n",
    "W_transformed = jnp.array(W_transformed).reshape(data_size, -1)\n",
    "V_transformed = jnp.array(V_transformed).reshape(data_size, -1)\n",
    "Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d57dd641-6e9c-4114-a87b-77ee393c788d",
   "metadata": {},
   "outputs": [],
   "source": [
    "RBF_Kernel_Z = RBF(use_length_scale_heuristic = True, use_jit_call = True)\n",
    "RBF_Kernel_W = RBF(use_length_scale_heuristic = True, use_jit_call = True)\n",
    "RBF_Kernel_A = BinaryKernel()\n",
    "# RBF_Kernel_A = RBF(length_scale = 0.08, use_median_length_scale_heuristic = False)\n",
    "RBF_Kernel_V = RBF(use_length_scale_heuristic = True, use_jit_call = True)\n",
    "\n",
    "lambda_ = 1e-3\n",
    "zeta = 1*1e-3\n",
    "lambda2_ = 1e-3\n",
    "optimize_regularization_parameters = True\n",
    "lambda_optimization_range = (1e-5, 1.0)\n",
    "zeta_optimization_range = (1e-5, 1.0)\n",
    "stage1_perc = 0.5\n",
    "regularization_grid_points = 25\n",
    "\n",
    "model = KernelNegativeControlCATE(\n",
    "                                 kernel_A = RBF_Kernel_A,\n",
    "                                 kernel_W = RBF_Kernel_W,\n",
    "                                 kernel_Z = RBF_Kernel_Z,\n",
    "                                 kernel_V = RBF_Kernel_V,\n",
    "                                 lambda_ = lambda_,\n",
    "                                 zeta = zeta, \n",
    "                                 lambda2_ = lambda2_,\n",
    "                                 optimize_regularization_parameters = optimize_regularization_parameters,\n",
    "                                 lambda_optimization_range = lambda_optimization_range,\n",
    "                                 zeta_optimization_range = zeta_optimization_range,\n",
    "                                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a2126811-1eeb-43d3-98c6-60be32e07722",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 0.0010774624006446947\n",
      "Structured function test set MAE: 0.02691996025730274\n"
     ]
    }
   ],
   "source": [
    "model.fit((A, W_transformed, Z_transformed, V_transformed), Y_transformed)\n",
    "\n",
    "do_A_size = do_A.shape[0]\n",
    "covariate_v_shape = covariate_v_test.shape[0]\n",
    "do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)\n",
    "covariate_v_transformed = (V_transformer.transform(covariate_v_test)).reshape(covariate_v_shape, -1)\n",
    "\n",
    "f_struct_pred_transformed = model.predict(do_A, covariate_v_transformed)\n",
    "f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(do_A_size, -1)).reshape(do_A_size, -1)\n",
    "\n",
    "structured_pred_mse = (np.mean((f_struct_pred.reshape(-1, 1) - EY_do_A_CATE.reshape(-1, 1)) ** 2))\n",
    "structured_pred_mae = (np.mean(np.abs(f_struct_pred.reshape(-1, 1) - EY_do_A_CATE.reshape(-1, 1))))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f8288b52-aa60-4a47-a774-8be31a92127b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array(0.00564518, dtype=float64),\n",
       " Array(0.00609867, dtype=float64),\n",
       " Array(0.00564518, dtype=float64))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.lambda2_, model.lambda_, model.zeta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0c841cd8-3738-4857-b54c-e158a19cea65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.06650785],\n",
       "       [ 0.06246303],\n",
       "       [ 0.05808673],\n",
       "       [ 0.05337154],\n",
       "       [ 0.04831491],\n",
       "       [ 0.04291961],\n",
       "       [ 0.03719424],\n",
       "       [ 0.0311535 ],\n",
       "       [ 0.02481836],\n",
       "       [ 0.01821606],\n",
       "       [ 0.01137998],\n",
       "       [ 0.00434929],\n",
       "       [-0.0028315 ],\n",
       "       [-0.01011317],\n",
       "       [-0.01744251],\n",
       "       [-0.02476319],\n",
       "       [-0.03201663],\n",
       "       [-0.03914307],\n",
       "       [-0.04608245],\n",
       "       [-0.05277554],\n",
       "       [-0.05916482],\n",
       "       [-0.06519544],\n",
       "       [-0.07081607],\n",
       "       [-0.07597968],\n",
       "       [-0.08064413],\n",
       "       [-0.08477275],\n",
       "       [-0.08833472],\n",
       "       [-0.09130532],\n",
       "       [-0.09366609],\n",
       "       [-0.09540479],\n",
       "       [-0.0965153 ],\n",
       "       [-0.09699732],\n",
       "       [-0.09685601],\n",
       "       [-0.09610154],\n",
       "       [-0.09474849],\n",
       "       [-0.09281522],\n",
       "       [-0.09032318],\n",
       "       [-0.08729617],\n",
       "       [-0.08375953],\n",
       "       [-0.07973946],\n",
       "       [-0.07526216],\n",
       "       [-0.07035319],\n",
       "       [-0.06503672],\n",
       "       [-0.05933494],\n",
       "       [-0.05326754],\n",
       "       [-0.0468512 ],\n",
       "       [-0.04009927],\n",
       "       [-0.03302149],\n",
       "       [-0.02562385],\n",
       "       [-0.01790854],\n",
       "       [-0.00987399],\n",
       "       [-0.00151501],\n",
       "       [ 0.00717694],\n",
       "       [ 0.01621344],\n",
       "       [ 0.02560878],\n",
       "       [ 0.03537941],\n",
       "       [ 0.04554356],\n",
       "       [ 0.05612061],\n",
       "       [ 0.06713054],\n",
       "       [ 0.07859335],\n",
       "       [ 0.09052837],\n",
       "       [ 0.10295364],\n",
       "       [ 0.11588512],\n",
       "       [ 0.12933602],\n",
       "       [ 0.14331591],\n",
       "       [ 0.15782994],\n",
       "       [ 0.17287794],\n",
       "       [ 0.18845341],\n",
       "       [ 0.20454262],\n",
       "       [ 0.22112356],\n",
       "       [ 0.23816498],\n",
       "       [ 0.25562543],\n",
       "       [ 0.27345237],\n",
       "       [ 0.29158142],\n",
       "       [ 0.3099358 ],\n",
       "       [ 0.32842593],\n",
       "       [ 0.3469494 ],\n",
       "       [ 0.36539121],\n",
       "       [ 0.38362442],\n",
       "       [ 0.40151117],\n",
       "       [ 0.41890419],\n",
       "       [ 0.43564867],\n",
       "       [ 0.45158463],\n",
       "       [ 0.46654966],\n",
       "       [ 0.48038192],\n",
       "       [ 0.4929235 ],\n",
       "       [ 0.50402389],\n",
       "       [ 0.51354355],\n",
       "       [ 0.52135738],\n",
       "       [ 0.52735807],\n",
       "       [ 0.53145917],\n",
       "       [ 0.53359763],\n",
       "       [ 0.53373601],\n",
       "       [ 0.53186395],\n",
       "       [ 0.52799901],\n",
       "       [ 0.52218685],\n",
       "       [ 0.51450058],\n",
       "       [ 0.50503951],\n",
       "       [ 0.49392712],\n",
       "       [ 0.48130848]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f_struct_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9331fe9d-b83f-4549-b9df-a0d8f7ac1c39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[-0.00000000e+00],\n",
       "       [-4.43822993e-04],\n",
       "       [-1.71519250e-03],\n",
       "       [-3.72653725e-03],\n",
       "       [-6.39368288e-03],\n",
       "       [-9.63580140e-03],\n",
       "       [-1.33753608e-02],\n",
       "       [-1.75380744e-02],\n",
       "       [-2.20528507e-02],\n",
       "       [-2.68517426e-02],\n",
       "       [-3.18698971e-02],\n",
       "       [-3.70455046e-02],\n",
       "       [-4.23197488e-02],\n",
       "       [-4.76367559e-02],\n",
       "       [-5.29435443e-02],\n",
       "       [-5.81899740e-02],\n",
       "       [-6.33286962e-02],\n",
       "       [-6.83151030e-02],\n",
       "       [-7.31072766e-02],\n",
       "       [-7.76659388e-02],\n",
       "       [-8.19544011e-02],\n",
       "       [-8.59385135e-02],\n",
       "       [-8.95866145e-02],\n",
       "       [-9.28694805e-02],\n",
       "       [-9.57602751e-02],\n",
       "       [-9.82344992e-02],\n",
       "       [-1.00269940e-01],\n",
       "       [-1.01846620e-01],\n",
       "       [-1.02946749e-01],\n",
       "       [-1.03554669e-01],\n",
       "       [-1.03656810e-01],\n",
       "       [-1.03241634e-01],\n",
       "       [-1.02299586e-01],\n",
       "       [-1.00823045e-01],\n",
       "       [-9.88062744e-02],\n",
       "       [-9.62453669e-02],\n",
       "       [-9.31381985e-02],\n",
       "       [-8.94843761e-02],\n",
       "       [-8.52851874e-02],\n",
       "       [-8.05435504e-02],\n",
       "       [-7.52639631e-02],\n",
       "       [-6.94524525e-02],\n",
       "       [-6.31165247e-02],\n",
       "       [-5.62651143e-02],\n",
       "       [-4.89085336e-02],\n",
       "       [-4.10584225e-02],\n",
       "       [-3.27276978e-02],\n",
       "       [-2.39305030e-02],\n",
       "       [-1.46821574e-02],\n",
       "       [-4.99910598e-03],\n",
       "       [ 5.10113118e-03],\n",
       "       [ 1.56000094e-02],\n",
       "       [ 2.64780099e-02],\n",
       "       [ 3.77146902e-02],\n",
       "       [ 4.92887346e-02],\n",
       "       [ 6.11780047e-02],\n",
       "       [ 7.33595898e-02],\n",
       "       [ 8.58098576e-02],\n",
       "       [ 9.85045042e-02],\n",
       "       [ 1.11418605e-01],\n",
       "       [ 1.24526666e-01],\n",
       "       [ 1.37802671e-01],\n",
       "       [ 1.51220136e-01],\n",
       "       [ 1.64752159e-01],\n",
       "       [ 1.78371468e-01],\n",
       "       [ 1.92050473e-01],\n",
       "       [ 2.05761317e-01],\n",
       "       [ 2.19475926e-01],\n",
       "       [ 2.33166059e-01],\n",
       "       [ 2.46803360e-01],\n",
       "       [ 2.60359407e-01],\n",
       "       [ 2.73805760e-01],\n",
       "       [ 2.87114020e-01],\n",
       "       [ 3.00255868e-01],\n",
       "       [ 3.13203126e-01],\n",
       "       [ 3.25927799e-01],\n",
       "       [ 3.38402132e-01],\n",
       "       [ 3.50598655e-01],\n",
       "       [ 3.62490240e-01],\n",
       "       [ 3.74050144e-01],\n",
       "       [ 3.85252064e-01],\n",
       "       [ 3.96070189e-01],\n",
       "       [ 4.06479244e-01],\n",
       "       [ 4.16454549e-01],\n",
       "       [ 4.25972061e-01],\n",
       "       [ 4.35008430e-01],\n",
       "       [ 4.43541050e-01],\n",
       "       [ 4.51548104e-01],\n",
       "       [ 4.59008620e-01],\n",
       "       [ 4.65902519e-01],\n",
       "       [ 4.72210666e-01],\n",
       "       [ 4.77914920e-01],\n",
       "       [ 4.82998185e-01],\n",
       "       [ 4.87444460e-01],\n",
       "       [ 4.91238890e-01],\n",
       "       [ 4.94367816e-01],\n",
       "       [ 4.96818826e-01],\n",
       "       [ 4.98580805e-01],\n",
       "       [ 4.99643986e-01],\n",
       "       [ 5.00000000e-01]], dtype=float64)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "EY_do_A_CATE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa54bcb0-192a-4b69-a8f7-b1beaa20e9e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
