{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fb1ea1b2-6e38-4dd8-b355-cbd875a10033",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import BinaryKernel, RBF, ColumnwiseRBF\n",
    "from causal_models.causal_learning import KernelCATE\n",
    "from causal_models.proxy_causal_learning import KernelAlternativeProxyCATE\n",
    "from utils.experimental_data_functions import generate_synthetic_CATE_data\n",
    "from utils.ml_utils import data_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3f34a5f7-9c6e-4df1-b0ab-0ddd787adc0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpu\n"
     ]
    }
   ],
   "source": [
    "# Report which backend JAX will run on (e.g. cpu / gpu / tpu).\n",
    "# `jax.default_backend()` is the public replacement for the deprecated\n",
    "# `jax.lib.xla_bridge.get_backend().platform`.\n",
    "print(jax.default_backend())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dfc2f8d6-1afa-444e-96f6-979941777d5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: a fixed seed makes Restart & Run All regenerate the same data\n",
    "# (and therefore the same metrics). Set SEED = None to draw a fresh random seed per run.\n",
    "SEED = 42\n",
    "N_SAMPLES = 1000\n",
    "\n",
    "sigma = 0.5\n",
    "# No trailing commas here: `x = 0.5,` would bind the tuple (0.5,) instead of a float.\n",
    "uniform_noise_upper_bound = 0.5\n",
    "uniform_noise_lower_bound = -0.5\n",
    "seed = SEED if SEED is not None else np.random.randint(1000000)\n",
    "U, W, Z, V, A, Y, covariate_v_test, do_A, EY_do_A_CATE = generate_synthetic_CATE_data(\n",
    "    N_SAMPLES,\n",
    "    sigma,\n",
    "    uniform_noise_upper_bound,\n",
    "    uniform_noise_lower_bound,\n",
    "    seed,\n",
    ")\n",
    "\n",
    "U, Z, W, V, A, Y = jnp.array(U), jnp.array(Z), jnp.array(W), jnp.array(V), jnp.array(A), jnp.array(Y)\n",
    "covariate_v_test, do_A, EY_do_A_CATE = jnp.array(covariate_v_test), jnp.array(do_A), jnp.array(EY_do_A_CATE)\n",
    "\n",
    "# `data_transform` (project helper) returns (transformed array, fitted transformer);\n",
    "# the A, V and Y transformers are reused when evaluating on the test points.\n",
    "A_transformed, A_transformer = data_transform(A)\n",
    "U_transformed, U_transformer = data_transform(U)\n",
    "Z_transformed, Z_transformer = data_transform(Z)\n",
    "W_transformed, W_transformer = data_transform(W)\n",
    "V_transformed, V_transformer = data_transform(V)\n",
    "Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "# Reshape every transformed variable to a 2-D array with `data_size` rows.\n",
    "data_size = A_transformed.shape[0]\n",
    "A_transformed = jnp.array(A_transformed).reshape(data_size, -1)\n",
    "U_transformed = jnp.array(U_transformed).reshape(data_size, -1)\n",
    "Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)\n",
    "W_transformed = jnp.array(W_transformed).reshape(data_size, -1)\n",
    "V_transformed = jnp.array(V_transformed).reshape(data_size, -1)\n",
    "Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e392638-f0b8-452f-9553-48535e07bfc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One kernel per input of the proxy CATE model.\n",
    "RBF_Kernel_Z = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "RBF_Kernel_V = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "RBF_Kernel_W = ColumnwiseRBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "RBF_Kernel_A = BinaryKernel()\n",
    "\n",
    "# Initial regularization values passed to the model.\n",
    "lambda_ = 1e-2\n",
    "eta = 2e-3\n",
    "lambda2_ = 1e-3\n",
    "zeta = 0.1\n",
    "\n",
    "# Hyperparameter-optimization switches and their search ranges.\n",
    "optimize_lambda_parameters = True\n",
    "optimize_eta_parameter = True\n",
    "optimize_zeta_parameter = True\n",
    "lambda_optimization_range = (1e-5, 1.0)\n",
    "zeta_optimization_range = (1e-5, 1.0)\n",
    "eta_optimization_range = (1e-5, 1.0)\n",
    "\n",
    "# Remaining model settings (see KernelAlternativeProxyCATE for their meaning).\n",
    "stage1_perc = 0.5\n",
    "regularization_grid_points = 25\n",
    "\n",
    "model = KernelAlternativeProxyCATE(\n",
    "    kernel_A=RBF_Kernel_A,\n",
    "    kernel_W=RBF_Kernel_W,\n",
    "    kernel_Z=RBF_Kernel_Z,\n",
    "    kernel_V=RBF_Kernel_V,\n",
    "    lambda_=lambda_,\n",
    "    eta=eta,\n",
    "    lambda2_=lambda2_,\n",
    "    zeta=zeta,\n",
    "    optimize_lambda_parameters=optimize_lambda_parameters,\n",
    "    optimize_eta_parameter=optimize_eta_parameter,\n",
    "    optimize_zeta_parameter=optimize_zeta_parameter,\n",
    "    lambda_optimization_range=lambda_optimization_range,\n",
    "    zeta_optimization_range=zeta_optimization_range,\n",
    "    eta_optimization_range=eta_optimization_range,\n",
    "    stage1_perc=stage1_perc,\n",
    "    regularization_grid_points=regularization_grid_points,\n",
    "    label_variance_in_eta_opt=1.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e8079aec-08d1-4d10-bd21-f35a667f19bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 0.0020104863002525773\n",
      "Structured function test set MAE: 0.04202315863354388\n"
     ]
    }
   ],
   "source": [
    "model.fit((A_transformed, W_transformed, Z_transformed, V_transformed), Y_transformed)\n",
    "\n",
    "# Map the test interventions and covariates through the transformers fitted on the training data.\n",
    "do_A_size = do_A.shape[0]\n",
    "covariate_v_size = covariate_v_test.shape[0]\n",
    "do_A_transformed = A_transformer.transform(do_A).reshape(do_A_size, -1)\n",
    "covariate_v_transformed = V_transformer.transform(covariate_v_test).reshape(covariate_v_size, -1)\n",
    "\n",
    "# The model was trained on transformed Y, so invert the Y transform before comparing to ground truth.\n",
    "f_struct_pred_transformed = model.predict(do_A_transformed, covariate_v_transformed)\n",
    "f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(-1, 1))\n",
    "\n",
    "# Guard against silent broadcasting: predictions and targets must have the same number of entries.\n",
    "assert f_struct_pred.size == EY_do_A_CATE.size, (f_struct_pred.shape, EY_do_A_CATE.shape)\n",
    "residual = f_struct_pred.reshape(-1, 1) - EY_do_A_CATE.reshape(-1, 1)\n",
    "structured_pred_mse = np.mean(residual ** 2)\n",
    "structured_pred_mae = np.mean(np.abs(residual))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "76c95d2c-96a8-47ca-9f44-bfb3223f4ff2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 0.02034653],\n",
       "        [-0.06939048],\n",
       "        [ 0.02836105],\n",
       "        [ 0.27858656],\n",
       "        [ 0.39850787]]),\n",
       " Array([[-0.03136],\n",
       "        [-0.10368],\n",
       "        [ 0.     ],\n",
       "        [ 0.25088],\n",
       "        [ 0.46656]], dtype=float64))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted CATE (original Y scale) next to the ground-truth E[Y | do(A), V].\n",
    "(f_struct_pred, EY_do_A_CATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0108a3c-4001-4653-a10b-a023308ca29c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
