{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from data.linear.binary import BinaryLinearData\n",
    "from data.linear.continuous import GaussianLinearData\n",
    "\n",
    "from algorithm.general import GConfounderTest, PearsonConfounderTest\n",
    "from algorithm.baseline import GEnvironmentTest, PearsonEnvironmentTest\n",
    "\n",
    "from experiment.utils import run\n",
    "from experiment.plot import set_mpl_default_settings, marker_dict, name_dict\n",
    "\n",
    "set_mpl_default_settings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment: Comparing our procedure to the $Y \\perp E \\mid T$ baseline test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to True to read previously saved results from disk instead of re-running.\n",
    "load = False\n",
    "\n",
    "# --- Data generator and the tests under comparison ---\n",
    "simulate_class = BinaryLinearData\n",
    "test_method_list = [GConfounderTest(), GEnvironmentTest()]\n",
    "\n",
    "# --- Experiment settings (handed to experiment.utils.run as they are) ---\n",
    "nbr_env = [500]\n",
    "nbr_samples = [100]\n",
    "repetitions = 50\n",
    "sign_level = 0.05\n",
    "\n",
    "# --- Dataset parameters ---\n",
    "conf_strength = [0, 10]\n",
    "# One parameter set per value of Y's 'b' entry, swept over 10 evenly spaced\n",
    "# points in [0, 0.25]; the X and T entries are the same in every set.\n",
    "dist_param_list = [\n",
    "    {\n",
    "        'X': {'a': 0.0, 'b': 1.0},\n",
    "        'Y': {'a': 0.0, 'b': k},\n",
    "        'T': {'a': 0.0, 'b': 1.0},\n",
    "    }\n",
    "    for k in np.linspace(0, 0.25, 10)\n",
    "]\n",
    "\n",
    "# Month-day-hour-minute tag used in the names of the output files of this run.\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%m%d%H%M\")\n",
    "print('Timestamp:', timestamp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nbr_proc = 4  # NOTE(review): not passed to run() below -- confirm it is still needed\n",
    "\n",
    "\n",
    "def experiment_vary_dist(test_method, path):\n",
    "    '''\n",
    "    Run the experiment for one test method over every entry of dist_param_list.\n",
    "\n",
    "    Each parameter set is handed to `run` together with the notebook-level\n",
    "    experiment settings. Whatever `run` returns for one parameter set is\n",
    "    concatenated with pd.concat (presumably a list of DataFrames -- confirm\n",
    "    against experiment.utils.run), and the per-set results are concatenated again.\n",
    "\n",
    "    test_method : test instance forwarded to `run`\n",
    "    path        : forwarded to `run` as `save_during_run`\n",
    "\n",
    "    Returns the concatenation of the results for all parameter sets.\n",
    "    '''\n",
    "\n",
    "    def run_single(dist_param):\n",
    "        # Positional argument list in the order `run` is called with here.\n",
    "        run_args = [dist_param, nbr_env, nbr_samples, conf_strength,\n",
    "                    simulate_class, test_method, repetitions, sign_level]\n",
    "        return pd.concat(run(run_args, save_during_run=path))\n",
    "\n",
    "    return pd.concat([run_single(dist_param) for dist_param in dist_param_list])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not load:\n",
    "\n",
    "    # Run every configured test and keep its results keyed by the test's class name.\n",
    "    experiment_results = {}\n",
    "    for alg in tqdm(test_method_list):\n",
    "        alg_name = type(alg).__name__\n",
    "        res = experiment_vary_dist(alg, f'results/comparison_{alg_name}_{timestamp}.csv')\n",
    "        experiment_results[alg_name] = res\n",
    "\n",
    "else:\n",
    "\n",
    "    # Load results written by an earlier run, identified by its timestamp tag.\n",
    "    timestamp_str = \"11151428\"\n",
    "    # Kept separate from test_method_list so the configured test instances are\n",
    "    # not replaced by strings, which would break a later run with load = False\n",
    "    # in the same kernel.\n",
    "    loaded_alg_names = ['PearsonConfounderTest', 'PearsonEnvironmentTest']\n",
    "    # Keep the tag as a string, as in the non-load branch: int() would drop a\n",
    "    # leading zero and the figure names would no longer match the CSV names.\n",
    "    timestamp = timestamp_str\n",
    "\n",
    "    experiment_results = {}\n",
    "    for alg in loaded_alg_names:\n",
    "        path = f'results/comparison_{alg}_{timestamp_str}.csv'\n",
    "        df = pd.read_csv(path)\n",
    "        experiment_results[alg] = df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_experiment(confounder_strength : float):\n",
    "    '''\n",
    "    Plot reject rate against Y_b for every algorithm at one confounder strength.\n",
    "\n",
    "    For each entry of experiment_results, the rows whose confounder_strength is\n",
    "    within 1e-3 of the requested value are drawn as a curve of reject_rate over\n",
    "    Y_b, with a shaded band of +/- sqrt(p*(1-p)/repetitions) around it. Drawing\n",
    "    happens on the current matplotlib figure, which is then saved as a PDF under\n",
    "    results/figures/ (the directory is created if it does not exist).\n",
    "\n",
    "    confounder_strength : value used to select the rows to plot. For 0 the\n",
    "        y-axis is labelled as a false detection rate and a dashed line at\n",
    "        sign_level plus the legend are added; otherwise the y-axis is labelled\n",
    "        as a detection rate.\n",
    "    '''\n",
    "    import os  # local import: only needed to create the figure directory\n",
    "\n",
    "    is_null_case = confounder_strength == 0\n",
    "\n",
    "    for alg in experiment_results:\n",
    "        alg_df = experiment_results[alg]\n",
    "        # Match within a tolerance rather than by exact float equality.\n",
    "        alg_df = alg_df[np.abs(alg_df.confounder_strength - confounder_strength) < 1e-3]\n",
    "\n",
    "        p = alg_df.reject_rate\n",
    "        plt.plot(alg_df.Y_b, p, label=name_dict[alg], marker=marker_dict[alg])\n",
    "        # Binomial standard error of the estimated rate.\n",
    "        # NOTE(review): uses the current `repetitions` setting -- confirm it\n",
    "        # matches the run that produced the results when they are loaded from disk.\n",
    "        std = np.sqrt(p*(1-p)/repetitions)\n",
    "        plt.fill_between(alg_df.Y_b, p-std, p+std, alpha=0.5)\n",
    "\n",
    "    plt.ylabel('False detection rate' if is_null_case else 'Detection rate')\n",
    "    # Raw string: in a normal string literal the backslash-t at the start of the\n",
    "    # theta command is read as a TAB character, which corrupts the axis label.\n",
    "    plt.xlabel(r'$\\sigma_{\\theta_Y}^2$')\n",
    "    plt.ylim([-.1,1.1])\n",
    "\n",
    "    if is_null_case:\n",
    "        # Reference line at the configured significance level.\n",
    "        plt.axhline(y=sign_level, color='black', linestyle='--')\n",
    "        plt.legend()\n",
    "\n",
    "    path = f'results/figures/comparison_cs{confounder_strength}_{timestamp}.pdf'\n",
    "    # savefig does not create missing directories.\n",
    "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "    plt.savefig(path, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per confounder strength; plot_experiment draws on the current figure and saves it.\n",
    "for strength in conf_strength:\n",
    "    plt.figure()\n",
    "    plot_experiment(strength)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('CI2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1119f8550b2138f5de574d3adbc9d9c628b005f552be6d04a225ae36781ad7c3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
