{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "35588030-5964-46a4-8b32-9912bb8375f3",
   "metadata": {},
   "source": [
    "# Logistic Sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "545806b6-5c21-4e2d-9e86-741f560ecd21",
   "metadata": {},
   "source": [
    "We present an example showing how to sample exactly from a model whose outcome follows a logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab2ef343-7884-44f8-b0f4-6410d716e7b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "import sys\n",
    "import os\n",
    "sys.path.append(\"../\") # go to parent dir\n",
    "\n",
    "import jax\n",
    "import jax.random as jr\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import numpy as np\n",
    "from scipy.stats import rankdata\n",
    "from jax.scipy.special import expit\n",
    "import scipy.stats as ss\n",
    "import seaborn as sns\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "jnp.set_printoptions(precision=2)\n",
    "\n",
    "# from data.create_sim_data import *\n",
    "import data.template_causl_simulations as causl_py\n",
    "from data.run_all_simulations import plot_simulation_results\n",
    "import data.hyperparam_and_bootstrapping as hb\n",
    "from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow, train_copula_flow\n",
    "from frugal_flows.bijections import UnivariateNormalCDF\n",
    "from frugal_flows.benchmarking import FrugalFlowModel\n",
    "\n",
    "import rpy2.robjects as ro\n",
    "from rpy2.robjects.packages import importr\n",
    "from rpy2.robjects import pandas2ri\n",
    "from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage\n",
    "import wandb\n",
    "\n",
    "# Activate automatic conversion of rpy2 objects to pandas objects\n",
    "pandas2ri.activate()\n",
    "\n",
    "# Import the R library causl, attempting an installation first if the import fails.\n",
    "try:\n",
    "    causl = importr('causl')\n",
    "except Exception:\n",
    "    # `utils` and `StrVector` are not defined anywhere else in this notebook, so\n",
    "    # they are brought into scope here, on the only path that needs them.\n",
    "    from rpy2.robjects.vectors import StrVector\n",
    "\n",
    "    utils = importr('utils')\n",
    "    # Select a CRAN mirror up front so install.packages() does not need to prompt.\n",
    "    utils.chooseCRANmirror(ind=1)\n",
    "    package_names = ('causl',)  # trailing comma: a tuple of names, not a bare string\n",
    "    # NOTE(review): this installs from the configured R repositories; confirm that\n",
    "    # causl is available there (it may need to be installed from GitHub instead).\n",
    "    utils.install_packages(StrVector(package_names))\n",
    "    # Bind `causl` on the fallback path too, so both branches leave the same state.\n",
    "    causl = importr('causl')\n",
    "\n",
    "\n",
    "# One hyperparameter set, reused for the marginal, frugal and propensity flows\n",
    "# when the benchmark model is trained further down.\n",
    "hyperparams_dict = dict(\n",
    "    learning_rate=5e-3,\n",
    "    RQS_knots=8,\n",
    "    flow_layers=5,\n",
    "    nn_width=30,\n",
    "    nn_depth=4,\n",
    "    max_patience=100,\n",
    "    max_epochs=10000,\n",
    ")\n",
    "\n",
    "# Switch JAX to 64-bit precision.\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "06d04d48-ddd4-4632-9035-eef2294a0d9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "R[write to console]: Inversion method selected: using pair-copula parameterization\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8.86 ms, sys: 1.7 ms, total: 10.6 ms\n",
      "Wall time: 10.2 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# R script that simulates `data_samples` with causl's rfrugalParam (seeded in R).\n",
    "causl_rscript = \"\"\"\n",
    "library(causl)\n",
    "forms <- list(list(Z1 ~ 1), X ~ Z1, Y ~ X, ~ 1)\n",
    "fams <- list(1, 5, 1, 1)\n",
    "pars <- list(Z1 = list(beta=0, phi=2),\n",
    "             X = list(beta=c(0,2)),\n",
    "             Y = list(beta=c(0,2), phi=1),\n",
    "             cop = list(beta=matrix(c(0.8), nrow=1)))\n",
    "\n",
    "\n",
    "\n",
    "set.seed(1234)\n",
    "n <- 1e3\n",
    "\n",
    "data_samples <- rfrugalParam(n, formulas = forms, family = fams, pars = pars)\n",
    "\"\"\"\n",
    "# Evaluate the script as an anonymous R package and pull the simulated data out of it.\n",
    "simulation_pkg = SignatureTranslatedAnonymousPackage(causl_rscript, \"powerpack\")\n",
    "df = simulation_pkg.data_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d15f6d3-612f-49b1-8b1a-45a78f61c2ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "111418a4-f5cd-42e6-b752-ac492d0609f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _column(name):\n",
    "    \"\"\"Return column `name` of `df` as a JAX array with a trailing axis appended.\"\"\"\n",
    "    return jnp.array(df[name].values)[:, None]\n",
    "\n",
    "\n",
    "# Build the benchmark model from the Y, X and Z1 columns of the simulated data.\n",
    "logistic_flow = FrugalFlowModel(\n",
    "    Y=_column('Y'),\n",
    "    X=_column('X'),\n",
    "    Z_cont=_column('Z1'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a8f597f4-bbef-4c00-887b-c1fa545216c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|██▍                                                                                                 | 239/10000 [00:12<08:23, 19.38it/s, train=1.7414300345333242, val=1.8217514738657479 (Max patience reached)]\n",
      "  2%|██                                                                                                  | 209/10000 [00:10<07:56, 20.54it/s, train=1.2421812820501008, val=1.4462013794093111 (Max patience reached)]\n",
      "  1%|█▎                                                                                               | 130/10000 [00:06<08:04, 20.38it/s, train=-0.41174612436140146, val=0.06925319850852496 (Max patience reached)]\n"
     ]
    }
   ],
   "source": [
    "# The marginal, frugal and propensity flows all receive the same hyperparameters.\n",
    "training_kwargs = dict(\n",
    "    training_seed=jr.PRNGKey(0),\n",
    "    marginal_hyperparam_dict=hyperparams_dict,\n",
    "    frugal_hyperparam_dict=hyperparams_dict,\n",
    "    prop_flow_hyperparam_dict=hyperparams_dict,\n",
    "    causal_model='gaussian',\n",
    "    causal_model_args={'ate': 0, 'const': 0, 'scale': 1},\n",
    ")\n",
    "logistic_flow.train_benchmark_model(**training_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5cfa8339-54f9-44a2-a63e-4bb83043c05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from the trained model, requesting a logistic-regression outcome model\n",
    "# with const = -1 and ate = 2.\n",
    "sampling_kwargs = dict(\n",
    "    key=jr.PRNGKey(1),\n",
    "    sampling_size=1000,\n",
    "    copula_param=0,\n",
    "    outcome_causal_model='logistic_regression',\n",
    "    outcome_causal_args={'ate': 2., 'const': -1.},\n",
    "    with_confounding=True,\n",
    ")\n",
    "synthetic_samples = logistic_flow.generate_samples(**sampling_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2def3aa8-4034-4b2f-a736-6297a11e6e8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.895027361408518\n"
     ]
    }
   ],
   "source": [
    "# Mean of Y within each X group; groupby sorts its keys, so the X = 0 group comes first.\n",
    "group_means = synthetic_samples.groupby('X')['Y'].mean()\n",
    "Y0, Y1 = group_means.values\n",
    "print(Y1 / Y0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7126b4d6-5b9f-4a86-9f89-6d1083a03274",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(2.72, dtype=float64, weak_type=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# expit(const + ate) / expit(const) for const = -1 and ate = 2; presumably the\n",
    "# reference value for the ratio printed above (TODO confirm).\n",
    "reference_ratio = expit(1) / expit(-1)\n",
    "reference_ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ab573923-07fe-423f-9226-5117a54d3a7a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Y</th>\n",
       "      <th>X</th>\n",
       "      <th>Z_1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.404187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.907262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.891973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.298569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-1.167730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.115957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>996</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-2.530417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>997</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.671380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>998</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.597466</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>-0.634764</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1000 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       Y    X       Z_1\n",
       "0    1.0  1.0  0.404187\n",
       "1    1.0  1.0  1.907262\n",
       "2    1.0  1.0  0.891973\n",
       "3    1.0  1.0 -1.298569\n",
       "4    0.0  0.0 -1.167730\n",
       "..   ...  ...       ...\n",
       "995  0.0  1.0 -1.115957\n",
       "996  0.0  0.0 -2.530417\n",
       "997  0.0  0.0 -0.671380\n",
       "998  1.0  0.0  0.597466\n",
       "999  0.0  0.0 -0.634764\n",
       "\n",
       "[1000 rows x 3 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synthetic_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "56938530-4d29-48e4-a010-164dd24b9308",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `ro`, `pandas2ri` and `importr` are imported in the setup cell at the top of the\n",
    "# notebook, which also activates the pandas2ri conversion, so neither is repeated here.\n",
    "\n",
    "# Import necessary R libraries\n",
    "base = importr('base')\n",
    "stats = importr('stats')\n",
    "survey = importr('survey')\n",
    "\n",
    "# Convert the synthetic samples to an R object and expose it to R as `dat`.\n",
    "r_df = pandas2ri.py2rpy(synthetic_samples)\n",
    "ro.globalenv['dat'] = r_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5bef3bb1-969d-48cb-9f37-7535ff9b945e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# R script, in order:\n",
    "#   1. glmX    - binomial GLM of X on Z_1; its fitted probabilities are `ps`.\n",
    "#   2. wt      - weights X/ps + (1-X)/(1-ps).\n",
    "#   3. glmY    - svyglm of Y on X (quasibinomial) using those weights.\n",
    "#   4. glmY_OR - plain binomial GLM of Y on X, without weights.\n",
    "# Its last expression is a list holding the three coefficient tables.\n",
    "ipw_rscript = \"\"\"\n",
    "library(survey)\n",
    "\n",
    "glmX <- glm(X ~ Z_1, family=binomial, data=dat)\n",
    "glmX_coefficients <- summary(glmX)$coefficients\n",
    "\n",
    "ps <- predict(glmX, type=\"response\")\n",
    "wt <- dat$X/ps + (1-dat$X)/(1-ps)\n",
    "glmY <- svyglm(Y ~ X, family=quasibinomial(), design = svydesign(~1, weights=wt, data=dat))\n",
    "glmY_coefficients <- summary(glmY)$coefficients\n",
    "\n",
    "glmY_OR <- glm(Y ~ X, family=binomial, data=dat)\n",
    "glmY_OR_coefficients <- summary(glmY_OR)$coefficients\n",
    "\n",
    "list(glmX_coefficients = glmX_coefficients, glmY_coefficients = glmY_coefficients, glmY_OR_coefficients = glmY_OR_coefficients)\n",
    "\"\"\"\n",
    "\n",
    "# Evaluate the script in R and keep the list it returns.\n",
    "result = ro.r(ipw_rscript)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6d7e7517-b629-4aa9-8a6b-ff27c36dd4a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.59e+00,  1.16e-01, -1.38e+01,  3.15e-43],\n",
       "       [ 3.16e+00,  1.68e-01,  1.88e+01,  1.77e-78]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Full coefficient table of the unweighted binomial GLM of Y on X.\n",
    "unweighted_table = result.rx2('glmY_OR_coefficients')\n",
    "unweighted_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e819907-4a32-4e9d-888f-a2eb9a2888ee",
   "metadata": {},
   "source": [
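    "In the coefficient tables, the first column contains the point estimates and the second column the standard errors; the rows are the intercept and the coefficient of X. The true values are -1 and +2. The weighted GLM recovers the right values to within sampling error, whereas the unweighted GLM does not!"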
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "79db6fdd-2771-461d-bdaf-b509708e74f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.88,  0.19],\n",
       "       [ 1.6 ,  0.29]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First two columns (estimate, standard error) of the weighted svyglm fit of Y on X.\n",
    "weighted_table = result.rx2('glmY_coefficients')\n",
    "weighted_table[:, :2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "73adf0b2-f617-438d-96cd-510dd24b5bba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.59,  0.12],\n",
       "       [ 3.16,  0.17]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First two columns (estimate, standard error) of the unweighted GLM of Y on X.\n",
    "unweighted_table = result.rx2('glmY_OR_coefficients')\n",
    "unweighted_table[:, :2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29b49cda-40f4-4df5-a0b9-6bd40ecba2d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
