{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "368957b9-8e53-48a4-beb4-59d046340db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import ColumnwiseRBF, RBF, BinaryKernel, StackOfKernels, MaternKernel, ColumnwiseMaternKernel\n",
    "from utils.linalg_utils import cartesian_product, make_psd\n",
    "from causal_models.approximate_proxy_causal_learning import DoublyRobustKernelProxyATE_Nystorm\n",
    "from utils.ml_utils import data_transform\n",
    "from utils.linalg_utils import pairwise_squared_distance\n",
    "from utils.experimental_data_functions import dSprite_ProxyVariable_DatasetV2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "daea0636-142c-4bc9-ae31-3177f934ac3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../../data/dsprite'\n",
    "\n",
    "# The seed is drawn at random on every run; print it so that a given run\n",
    "# (and the metrics reported further down) can be reproduced later.\n",
    "seed = np.random.randint(1000000)\n",
    "np.random.seed(seed)\n",
    "print('Data generation seed: {}'.format(seed))\n",
    "\n",
    "dsprite_data_generator = dSprite_ProxyVariable_DatasetV2()\n",
    "A, Y, Z, W, do_A, EY_do_A = dsprite_data_generator.generate_dsprite_pv(data_path, n_sample = 2000, generate_test = True, rand_seed = seed)\n",
    "## Alternatively, one can use this data class generator as follows to generate the test data separately\n",
    "# A, Y, Z, W, _, _ = dsprite_data_generator.generate_dsprite_pv(data_path, n_sample = 1000, generate_test = False, rand_seed = seed)\n",
    "# do_A, EY_do_A = dsprite_data_generator.generate_test_dsprite(data_path)\n",
    "\n",
    "# Later cells read this flag's results: A_transformer and Y_transformer only exist when it is True.\n",
    "transform_data = True\n",
    "\n",
    "if transform_data:\n",
    "    # data_transform returns the transformed data together with the fitted transformer;\n",
    "    # the A and Y transformers are reused at prediction time.\n",
    "    A_transformed, A_transformer = data_transform(A)\n",
    "    Z_transformed, Z_transformer = data_transform(Z)\n",
    "    W_transformed, W_transformer = data_transform(W)\n",
    "    Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "    # Flatten every variable to a 2-D (n_samples, n_features) JAX array.\n",
    "    data_size = A_transformed.shape[0]\n",
    "    A_transformed = jnp.array(A_transformed).reshape(data_size, -1)\n",
    "    Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)\n",
    "    W_transformed = jnp.array(W_transformed).reshape(data_size, -1)\n",
    "    Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)\n",
    "\n",
    "else:\n",
    "    # No scaling: only convert the raw arrays (including the test set) to JAX arrays.\n",
    "    W_transformed = jnp.array(W)\n",
    "    Z_transformed = jnp.array(Z)\n",
    "    A_transformed = jnp.array(A)\n",
    "    Y_transformed = jnp.array(Y)\n",
    "    do_A = jnp.array(do_A)\n",
    "    EY_do_A = jnp.array(EY_do_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f226abad-d8b1-4cfe-8ec4-75c8f2d7aa1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       ...,\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0],\n",
       "       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Open the raw dSprites archive stored under data_path and peek at the first image.\n",
    "# NOTE(review): allow_pickle=True permits unpickling object arrays stored in the archive,\n",
    "# which can execute arbitrary code -- only load a trusted copy of this file.\n",
    "dataset_zip = np.load(os.path.join(data_path, \"dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz\"),\n",
    "                      allow_pickle=True, encoding=\"bytes\")\n",
    "dataset_zip['imgs'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dad98d76-2fc7-4972-b866-82e17a90b1cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(737280, 64, 64)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the full image array stored in the archive\n",
    "# (presumably (n_images, height, width) -- confirm against the dataset description).\n",
    "np.shape(dataset_zip['imgs'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c72f44f-498c-45f3-a4e4-5e3e6a21aa51",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8e6a1cea-e9fb-43a0-af91-b1c1e32e056a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _heuristic_rbf():\n",
    "    \"\"\"Build a new RBF kernel with use_length_scale_heuristic=True and use_jit_call=True.\n",
    "\n",
    "    A separate instance is created on every call so that no kernel object is\n",
    "    shared between dictionary entries.\n",
    "    \"\"\"\n",
    "    return RBF(use_length_scale_heuristic = True, use_jit_call = True)\n",
    "\n",
    "\n",
    "# Settings handed to the treatment bridge part of the doubly robust estimator.\n",
    "treatment_bridge_algo_param_dict_default = dict(\n",
    "    kernel_A = _heuristic_rbf(),\n",
    "    kernel_W = _heuristic_rbf(),\n",
    "    kernel_Z = _heuristic_rbf(),\n",
    "    kernel_V = _heuristic_rbf(), # Only required for CATE algorithm\n",
    "    kernel_X = _heuristic_rbf(),\n",
    "    lambda_ = 1e-3,\n",
    "    eta = 1e-3,\n",
    "    lambda2_ = 1e-3,\n",
    "    zeta = 1e-3, # Only required for ATT or CATE algorithm\n",
    "    nystrom_first_stage_m = 500,\n",
    "    nystrom_third_stage_m = 500,\n",
    "    stage1_perc = 0.5,\n",
    "    model_seed = 0,\n",
    "    make_psd_eps = 1e-9,\n",
    ")\n",
    "\n",
    "# Settings handed to the outcome bridge (kernel proxy variable) part of the estimator.\n",
    "outcome_bridge_kpv_algo_param_dict_default = dict(\n",
    "    algorithm_name = \"Kernel_Proxy_Variable\",\n",
    "    kernel_A = _heuristic_rbf(),\n",
    "    kernel_W = _heuristic_rbf(),\n",
    "    kernel_Z = _heuristic_rbf(),\n",
    "    kernel_V = _heuristic_rbf(), # Only required for CATE algorithm\n",
    "    kernel_X = _heuristic_rbf(),\n",
    "    lambda1_ = 1e-3,\n",
    "    lambda2_ = 1e-3,\n",
    "    zeta = 1e-3, # Only required for ATT or CATE algorithm\n",
    "    nystrom_first_stage_m = 500,\n",
    "    nystrom_second_stage_m = 500,\n",
    "    stage1_perc = 0.5,\n",
    "    model_seed = 0,\n",
    "    make_psd_eps = 1e-9,\n",
    ")\n",
    "\n",
    "# Doubly robust ATE model assembled from the two bridge configurations above.\n",
    "model_DR = DoublyRobustKernelProxyATE_Nystorm(\n",
    "    treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict_default,\n",
    "    outcome_bridge_algo_param_dict = outcome_bridge_kpv_algo_param_dict_default,\n",
    "    lambda_DR = 1e-5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bd9179d2-c79c-4591-9485-735d5d8c335f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 23.68571215241677\n",
      "Structured function test set MAE: 2.9631273200682684\n"
     ]
    }
   ],
   "source": [
    "# Fit the doubly robust model on the training sample, predict the structural function\n",
    "# at the test treatments do_A, and compare against the reference values EY_do_A.\n",
    "model_DR.fit((A_transformed, W_transformed, Z_transformed), Y_transformed)\n",
    "do_A_size = do_A.shape[0]\n",
    "\n",
    "if transform_data:\n",
    "    # Apply the transformer fitted on A to the test treatments, then map the\n",
    "    # predictions back to the original outcome scale with the Y transformer.\n",
    "    do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)\n",
    "    f_struct_pred_transformed = model_DR.predict(do_A_transformed)\n",
    "    f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(do_A_size, -1)).reshape(do_A_size, -1)\n",
    "else:\n",
    "    # Without transform_data no transformers were created, so the model was\n",
    "    # trained on raw values: feed do_A as-is and use the predictions directly.\n",
    "    do_A_transformed = do_A.reshape(do_A_size, -1)\n",
    "    f_struct_pred_transformed = model_DR.predict(do_A_transformed)\n",
    "    f_struct_pred = f_struct_pred_transformed.reshape(do_A_size, -1)\n",
    "\n",
    "# Column-vector residuals; reshaping both sides avoids accidental broadcasting.\n",
    "prediction_error = f_struct_pred.reshape(-1, 1) - EY_do_A.reshape(-1, 1)\n",
    "structured_pred_mse = np.mean(prediction_error ** 2)\n",
    "structured_pred_mae = np.mean(np.abs(prediction_error))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "45e3d8d2-991b-4cd2-a9dc-1fd2fac169b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.001, 0.001, 0.001)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the (lambda_, eta, lambda2_) values currently stored on the treatment bridge.\n",
    "treatment_bridge = model_DR.treatment_bridge_algo\n",
    "treatment_bridge.lambda_, treatment_bridge.eta, treatment_bridge.lambda2_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "899c544b-7757-4c5d-88e2-d0e0416108dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.001, 0.001)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the (lambda1_, lambda2_) values currently stored on the outcome bridge.\n",
    "outcome_bridge = model_DR.outcome_bridge_algo\n",
    "outcome_bridge.lambda1_, outcome_bridge.lambda2_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b69381-7192-4220-9ab8-7747e36614b1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
