{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "65eda12e-7ce7-4129-985a-ff2e508c9aed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import BinaryKernel, RBF, ColumnwiseRBF\n",
    "from causal_models.causal_learning import KernelCATE\n",
    "from causal_models.proxy_causal_learning import KernelProxyVariableCATE\n",
    "from utils.experimental_data_functions import generate_synthetic_CATE_data\n",
    "from utils.ml_utils import data_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83f0184a-a8a7-4fe8-b6f6-7c8040999072",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpu\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "\n",
    "# Report which platform (cpu / gpu / tpu) JAX selected as its default backend.\n",
    "# `jax.default_backend()` is the public API for this; the internal\n",
    "# `jax.lib.xla_bridge` module is deprecated and absent from newer JAX releases.\n",
    "print(jax.default_backend())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf507347-5bd4-4988-8058-6c71b1a9f627",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Synthetic data configuration ---\n",
    "# First positional argument of the generator; presumably the number of samples\n",
    "# -- confirm against generate_synthetic_CATE_data.\n",
    "n_samples = 1000\n",
    "sigma = 0.5\n",
    "# Keep the noise bounds as plain scalars: a trailing comma after the value\n",
    "# would silently turn each one into a 1-tuple such as (0.5,).\n",
    "uniform_noise_upper_bound = 0.5\n",
    "uniform_noise_lower_bound = -0.5\n",
    "# Fixed seed so that Restart & Run All regenerates the same data set;\n",
    "# change the value to draw a different sample.\n",
    "seed = 0\n",
    "\n",
    "U, W, Z, V, A, Y, covariate_v_test, do_A, EY_do_A_CATE = generate_synthetic_CATE_data(\n",
    "    n_samples,\n",
    "    sigma,\n",
    "    uniform_noise_upper_bound,\n",
    "    uniform_noise_lower_bound,\n",
    "    seed,\n",
    ")\n",
    "\n",
    "# Convert every generated quantity to a JAX array.\n",
    "U, Z, W, V, A, Y = (jnp.array(x) for x in (U, Z, W, V, A, Y))\n",
    "covariate_v_test, do_A, EY_do_A_CATE = (\n",
    "    jnp.array(x) for x in (covariate_v_test, do_A, EY_do_A_CATE)\n",
    ")\n",
    "\n",
    "# data_transform returns a (transformed data, transformer) pair. The transformers are\n",
    "# kept because later cells call .transform / .inverse_transform on them.\n",
    "A_transformed, A_transformer = data_transform(A)\n",
    "U_transformed, U_transformer = data_transform(U)\n",
    "Z_transformed, Z_transformer = data_transform(Z)\n",
    "W_transformed, W_transformer = data_transform(W)\n",
    "V_transformed, V_transformer = data_transform(V)\n",
    "Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "# Force every transformed variable into a 2-D JAX array with one row per sample;\n",
    "# the row count is taken from A and applied to all of them.\n",
    "data_size = A_transformed.shape[0]\n",
    "A_transformed, U_transformed, Z_transformed, W_transformed, V_transformed, Y_transformed = (\n",
    "    jnp.array(x).reshape(data_size, -1)\n",
    "    for x in (A_transformed, U_transformed, Z_transformed, W_transformed, V_transformed, Y_transformed)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bb1603d2-84c5-45ae-840e-446fe36d4873",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options shared by every kernel constructed below.\n",
    "kernel_options = dict(use_length_scale_heuristic = True, use_jit_call = True)\n",
    "\n",
    "# One kernel per model input. All names carry the `RBF_` prefix, but note that\n",
    "# the A kernel is a BinaryKernel and the W kernel is a ColumnwiseRBF.\n",
    "RBF_Kernel_Z = RBF(**kernel_options)\n",
    "RBF_Kernel_A = BinaryKernel(**kernel_options)\n",
    "RBF_Kernel_W = ColumnwiseRBF(**kernel_options)\n",
    "RBF_Kernel_V = RBF(**kernel_options)\n",
    "RBF_Kernel_X = RBF(**kernel_options)\n",
    "\n",
    "# Starting values for lambda1_, lambda2_ and zeta (presumably regularization\n",
    "# strengths -- confirm against KernelProxyVariableCATE).\n",
    "lambda1_ = 0.01\n",
    "lambda2_ = 1.2*1e-1\n",
    "zeta = 0.01\n",
    "\n",
    "# All three parameters are flagged for optimization, each over the same range.\n",
    "optimize_lambda1_parameter = optimize_zeta_parameter = optimize_lambda2_parameter = True\n",
    "lambda1_optimization_range = lambda2_optimization_range = zeta_optimization_range = (1e-5, 1.0)\n",
    "\n",
    "# Remaining constructor settings, passed through unchanged.\n",
    "stage1_perc = 0.5\n",
    "regularization_grid_points = 25\n",
    "make_psd_eps = 5*1e-9\n",
    "\n",
    "# Gather the constructor arguments in one place, then build the model.\n",
    "model_kwargs = dict(\n",
    "    kernel_A = RBF_Kernel_A,\n",
    "    kernel_W = RBF_Kernel_W,\n",
    "    kernel_Z = RBF_Kernel_Z,\n",
    "    kernel_V = RBF_Kernel_V,\n",
    "    kernel_X = RBF_Kernel_X,\n",
    "    lambda1_ = lambda1_,\n",
    "    lambda2_ = lambda2_,\n",
    "    zeta = zeta,\n",
    "    optimize_lambda1_parameter = optimize_lambda1_parameter,\n",
    "    optimize_lambda2_parameter = optimize_lambda2_parameter,\n",
    "    optimize_zeta_parameter = optimize_zeta_parameter,\n",
    "    lambda1_optimization_range = lambda1_optimization_range,\n",
    "    lambda2_optimization_range = lambda2_optimization_range,\n",
    "    zeta_optimization_range = zeta_optimization_range,\n",
    "    stage1_perc = stage1_perc,\n",
    "    regularization_grid_points = regularization_grid_points,\n",
    "    make_psd_eps = make_psd_eps,\n",
    ")\n",
    "model = KernelProxyVariableCATE(**model_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2155f1f7-910d-4447-a2a1-4689e18b2646",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 0.002683326766009367\n",
      "Structured function test set MAE: 0.04122952962062915\n"
     ]
    }
   ],
   "source": [
    "# Fit on the transformed training inputs (A, W, Z, V) with transformed Y as the target.\n",
    "training_inputs = (A_transformed, W_transformed, Z_transformed, V_transformed)\n",
    "model.fit(training_inputs, Y_transformed)\n",
    "\n",
    "# Map the test inputs through the transformers obtained from the training data,\n",
    "# keeping one row per test point.\n",
    "do_A_transformed = A_transformer.transform(do_A).reshape(do_A.shape[0], -1)\n",
    "covariate_v_transformed = V_transformer.transform(covariate_v_test).reshape(covariate_v_test.shape[0], -1)\n",
    "\n",
    "# Predict in the transformed space, then undo the Y transform.\n",
    "f_struct_pred_transformed = model.predict(do_A_transformed, covariate_v_transformed)\n",
    "f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(-1, 1))\n",
    "\n",
    "# Column-vector residual against the reference values, shared by both metrics.\n",
    "prediction_error = f_struct_pred.reshape(-1, 1) - EY_do_A_CATE.reshape(-1, 1)\n",
    "structured_pred_mse = np.mean(prediction_error ** 2)\n",
    "structured_pred_mae = np.mean(np.abs(prediction_error))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3ace490e-e5a8-43b0-957a-fcf29ca62c38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-0.03499967],\n",
       "        [-0.09023152],\n",
       "        [-0.04884291],\n",
       "        [ 0.20320549],\n",
       "        [ 0.55910208]]),\n",
       " Array([[-0.03136],\n",
       "        [-0.10368],\n",
       "        [ 0.     ],\n",
       "        [ 0.25088],\n",
       "        [ 0.46656]], dtype=float64))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the inverse-transformed predictions next to the reference values\n",
    "# (EY_do_A_CATE) they were scored against.\n",
    "f_struct_pred, EY_do_A_CATE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dcadddd-6e2b-4c82-8f84-6c2deac3a76a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
