{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ae9d5567-874b-416b-bac8-ce912122ae32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Extend the module search path before the project-local imports below;\n",
    "# presumably `utils` and `causal_models` live under one of these directories.\n",
    "sys.path.append(\"../../src\")\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "# NOTE(review): `os` and `plt` are not referenced by any cell in this notebook.\n",
    "import os\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.kernel_utils import BinaryKernel, RBF\n",
    "from causal_models.causal_learning import KernelCATE\n",
    "from utils.experimental_data_functions import generate_synthetic_CATE_data\n",
    "from utils.ml_utils import data_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2268a158-e8a1-489d-9fc2-12333f9f393c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings forwarded positionally to generate_synthetic_CATE_data.\n",
    "# The noise bounds must be plain floats: a trailing comma after the literal\n",
    "# (e.g. `x = 0.5,`) silently turns the value into the 1-tuple (0.5,).\n",
    "sigma = 0.5\n",
    "uniform_noise_upper_bound = 0.5\n",
    "uniform_noise_lower_bound = -0.5\n",
    "# NOTE: the seed is drawn afresh on every run, so results differ between runs.\n",
    "seed = np.random.randint(1000000)\n",
    "U, _, _, V, A, Y, covariate_v_test, do_A, EY_do_A_CATE = generate_synthetic_CATE_data(\n",
    "    1000,  # presumably the number of samples -- confirm against the generator's signature\n",
    "    sigma,\n",
    "    uniform_noise_upper_bound,\n",
    "    uniform_noise_lower_bound,\n",
    "    seed,\n",
    ")\n",
    "\n",
    "# Convert everything kept from the generator's return value to JAX arrays.\n",
    "U, V, A, Y = jnp.array(U), jnp.array(V), jnp.array(A), jnp.array(Y)\n",
    "covariate_v_test, do_A, EY_do_A_CATE = jnp.array(covariate_v_test), jnp.array(do_A), jnp.array(EY_do_A_CATE)\n",
    "\n",
    "# data_transform returns the transformed data together with a transformer object;\n",
    "# the A, V and Y transformers are reused later on the test inputs and predictions.\n",
    "A_transformed, A_transformer = data_transform(A)\n",
    "U_transformed, U_transformer = data_transform(U)\n",
    "V_transformed, V_transformer = data_transform(V)\n",
    "Y_transformed, Y_transformer = data_transform(Y)\n",
    "\n",
    "# Force every transformed variable into a 2-D array with data_size rows.\n",
    "data_size = A_transformed.shape[0]\n",
    "A_transformed = jnp.array(A_transformed).reshape(data_size, -1)\n",
    "U_transformed = jnp.array(U_transformed).reshape(data_size, -1)\n",
    "V_transformed = jnp.array(V_transformed).reshape(data_size, -1)\n",
    "Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b03e2871-120b-4887-af66-8fbe5be1c1e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularisation settings handed to KernelCATE.\n",
    "lambda_ = 1e-3\n",
    "lambda2 = 1e-3\n",
    "optimize_regularization_parameters = True\n",
    "lambda_optimization_range = (1e-9, 1.0)\n",
    "regularization_grid_points = 150\n",
    "\n",
    "# Kernels, one per model input. Despite its name, RBF_Kernel_A is a BinaryKernel;\n",
    "# the V and X kernels are RBFs with use_length_scale_heuristic and use_jit_call enabled.\n",
    "RBF_Kernel_A = BinaryKernel()\n",
    "RBF_Kernel_V = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "RBF_Kernel_X = RBF(use_length_scale_heuristic=True, use_jit_call=True)\n",
    "\n",
    "model = KernelCATE(\n",
    "    kernel_A=RBF_Kernel_A,\n",
    "    kernel_V=RBF_Kernel_V,\n",
    "    kernel_X=RBF_Kernel_X,\n",
    "    lambda_=lambda_,\n",
    "    lambda2=lambda2,\n",
    "    optimize_regularization_parameters=optimize_regularization_parameters,\n",
    "    lambda_optimization_range=lambda_optimization_range,\n",
    "    regularization_grid_points=regularization_grid_points,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "04d37abd-b877-458d-a79b-13af89818c91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Structured function test set MSE: 0.0001908795888432188\n",
      "Structured function test set MAE: 0.011762109776415633\n"
     ]
    }
   ],
   "source": [
    "model.fit((A_transformed, V_transformed, U_transformed), Y_transformed)\n",
    "\n",
    "# Push the test-time inputs through the transformers obtained for A and V.\n",
    "do_A_size, covariate_v_shape = do_A.shape[0], covariate_v_test.shape[0]\n",
    "do_A_transformed = A_transformer.transform(do_A).reshape(do_A_size, -1)\n",
    "covariate_v_transformed = V_transformer.transform(covariate_v_test).reshape(covariate_v_shape, -1)\n",
    "\n",
    "# Predict in the transformed outcome space, then map back with Y_transformer.\n",
    "f_struct_pred_transformed = model.predict(do_A_transformed, covariate_v_transformed)\n",
    "f_struct_pred = Y_transformer.inverse_transform(\n",
    "    f_struct_pred_transformed.reshape(do_A_size, -1)\n",
    ").reshape(do_A_size, -1)\n",
    "\n",
    "# Column-vector residuals against EY_do_A_CATE, shared by both error metrics.\n",
    "prediction_error = f_struct_pred.reshape(-1, 1) - EY_do_A_CATE.reshape(-1, 1)\n",
    "structured_pred_mse = np.mean(prediction_error ** 2)\n",
    "structured_pred_mae = np.mean(np.abs(prediction_error))\n",
    "print(\"Structured function test set MSE: {}\".format(structured_pred_mse))\n",
    "print(\"Structured function test set MAE: {}\".format(structured_pred_mae))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bb2c4626-20cc-461f-9c88-e77e376f5788",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.03755157],\n",
       "       [-0.12243415],\n",
       "       [ 0.02204725],\n",
       "       [ 0.2470249 ],\n",
       "       [ 0.47452248]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model predictions after Y_transformer.inverse_transform, one row per entry of do_A.\n",
    "f_struct_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e002b827-e4f9-4f45-9ef2-bb17c410e179",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[-0.03136],\n",
       "       [-0.10368],\n",
       "       [ 0.     ],\n",
       "       [ 0.25088],\n",
       "       [ 0.46656]], dtype=float64)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference values from the data generator that the MSE/MAE were computed against.\n",
    "EY_do_A_CATE"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
