{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "368957b9-8e53-48a4-beb4-59d046340db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import ColumnwiseRBF, RBF, BinaryKernel, StackOfKernels, MaternKernel, ColumnwiseMaternKernel\n",
    "from utils.linalg_utils import cartesian_product, make_psd\n",
    "from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyATE\n",
    "from utils.ml_utils import data_transform\n",
    "from utils.linalg_utils import pairwise_squared_distance\n",
    "from utils.experimental_data_functions import dSprite_ProxyVariable_DatasetV2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "daea0636-142c-4bc9-ae31-3177f934ac3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../../data/dsprite'\n",
    "\n",
    "# Set SEED to an int to reproduce an earlier run; with None a fresh seed is drawn.\n",
    "# The seed actually used is printed so that any run can be repeated later.\n",
    "SEED = None\n",
    "seed = SEED if SEED is not None else np.random.randint(1000000)\n",
    "np.random.seed(seed)\n",
    "print(\"Random seed: {}\".format(seed))\n",
    "\n",
    "dsprite_data_generator = dSprite_ProxyVariable_DatasetV2()\n",
    "A, Y, Z, W, do_A, EY_do_A = dsprite_data_generator.generate_dsprite_pv(data_path, n_sample = 2000, generate_test = True, rand_seed = seed)\n",
    "## Alternatively, generate the training data and the test data separately:\n",
    "# A, Y, Z, W, _, _ = dsprite_data_generator.generate_dsprite_pv(data_path, n_sample = 2000, generate_test = False, rand_seed = seed)\n",
    "# do_A, EY_do_A = dsprite_data_generator.generate_test_dsprite(data_path)\n",
    "\n",
    "# When True, A, Z, W and Y are passed through data_transform and the returned transformers are kept;\n",
    "# A_transformer / Y_transformer are needed later to map test treatments and predictions.\n",
    "transform_data = True\n",
    "\n",
    "if transform_data:\n",
    "    A_transformed, A_transformer = data_transform(A)\n",
    "    Z_transformed, Z_transformer = data_transform(Z)\n",
    "    W_transformed, W_transformer = data_transform(W)\n",
    "    Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "    # Force every variable to a 2-D (n_samples, n_features) jax array.\n",
    "    data_size = A_transformed.shape[0]\n",
    "    A_transformed = jnp.array(A_transformed).reshape(data_size, -1)\n",
    "    Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)\n",
    "    W_transformed = jnp.array(W_transformed).reshape(data_size, -1)\n",
    "    Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)\n",
    "\n",
    "else:\n",
    "    W_transformed, Z_transformed, A_transformed, Y_transformed, do_A, EY_do_A = jnp.array(W), jnp.array(Z), jnp.array(A), jnp.array(Y), jnp.array(do_A), jnp.array(EY_do_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8e6a1cea-e9fb-43a0-af91-b1c1e32e056a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the two bridge models of the doubly robust estimator.\n",
    "# Each entry builds its own RBF kernel instance (length-scale heuristic on, jitted call).\n",
    "# kernel_X is omitted: the dSprite generator above returns no covariates X.\n",
    "treatment_bridge_algo_param_dict_default = dict(\n",
    "    kernel_A = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    kernel_W = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    kernel_Z = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    # Starting regularization values; with the optimize_* flags set, the fitted values\n",
    "    # can differ (see the inspection cells below).\n",
    "    lambda_ = 1e-3,\n",
    "    eta = 1e-3,\n",
    "    lambda2_ = 1e-3,\n",
    "    optimize_lambda_parameters = True,\n",
    "    optimize_eta_parameter = True,\n",
    "    lambda_optimization_range = (1e-6, 1.0),\n",
    "    eta_optimization_range = (1e-6, 1.0),\n",
    "    stage1_perc = 0.5,\n",
    "    regularization_grid_points = 50,\n",
    "    make_psd_eps = 1e-9,\n",
    "    label_variance_in_lambda_opt = 0.0,\n",
    "    label_variance_in_eta_opt = 3.0,\n",
    ")\n",
    "\n",
    "outcome_bridge_kpv_algo_param_dict_default = dict(\n",
    "    algorithm_name = \"Kernel_Proxy_Variable\",\n",
    "    kernel_A = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    kernel_W = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    kernel_Z = RBF(use_length_scale_heuristic = True, use_jit_call = True),\n",
    "    lambda1_ = 0.1,\n",
    "    lambda2_ = 0.1,\n",
    "    optimize_lambda1_parameter = True,\n",
    "    optimize_lambda2_parameter = True,\n",
    "    lambda1_optimization_range = (1e-5, 1.0),\n",
    "    lambda2_optimization_range = (1e-5, 1.0),\n",
    "    stage1_perc = 0.5,\n",
    "    regularization_grid_points = 50,\n",
    "    make_psd_eps = 1e-9,\n",
    ")\n",
    "\n",
    "model_DR = DoublyRobustKernelProxyATE(\n",
    "    treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict_default,\n",
    "    outcome_bridge_algo_param_dict = outcome_bridge_kpv_algo_param_dict_default,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bd9179d2-c79c-4591-9485-735d5d8c335f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 23.270760136140602\n",
      "Structured function test set MAE: 3.1542561161084532\n"
     ]
    }
   ],
   "source": [
    "# Fit the doubly robust model, predict the structural function at the test treatments do_A,\n",
    "# and score the predictions against EY_do_A on the original outcome scale.\n",
    "model_DR.fit((A_transformed, W_transformed, Z_transformed), Y_transformed)\n",
    "do_A_size = do_A.shape[0]\n",
    "\n",
    "# A_transformer / Y_transformer only exist when transform_data is True.\n",
    "if transform_data:\n",
    "    do_A_transformed = A_transformer.transform(do_A).reshape(do_A_size, -1)\n",
    "else:\n",
    "    do_A_transformed = do_A.reshape(do_A_size, -1)\n",
    "\n",
    "f_struct_pred_transformed = model_DR.predict(do_A_transformed).reshape(do_A_size, -1)\n",
    "if transform_data:\n",
    "    # Map predictions back to the original outcome scale before comparing with EY_do_A.\n",
    "    f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed).reshape(do_A_size, -1)\n",
    "else:\n",
    "    f_struct_pred = f_struct_pred_transformed\n",
    "\n",
    "# Flatten both sides to 1-D numpy arrays so numpy and jax inputs are handled identically.\n",
    "residual = np.asarray(f_struct_pred).reshape(-1) - np.asarray(EY_do_A).reshape(-1)\n",
    "structured_pred_mse = np.mean(residual ** 2)\n",
    "structured_pred_mae = np.mean(np.abs(residual))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "45e3d8d2-991b-4cd2-a9dc-1fd2fac169b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array(1.e-06, dtype=float64),\n",
       " 0.03393221771895326,\n",
       " Array(1.e-06, dtype=float64))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regularization values held by the fitted treatment bridge, in the order (lambda_, eta, lambda2_).\n",
    "tuple(getattr(model_DR.treatment_bridge_algo, attr) for attr in (\"lambda_\", \"eta\", \"lambda2_\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "899c544b-7757-4c5d-88e2-d0e0416108dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array(0.0013895, dtype=float64), Array(1.e-05, dtype=float64))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regularization values held by the fitted outcome bridge, in the order (lambda1_, lambda2_).\n",
    "tuple(getattr(model_DR.outcome_bridge_algo, attr) for attr in (\"lambda1_\", \"lambda2_\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
