{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from datetime import datetime \n",
    "from tqdm import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "from data.linear.binary import BinaryLinearData\n",
    "\n",
    "from algorithm.permutation_based import PermutationBasedTest\n",
    "\n",
    "from experiment.utils import run\n",
    "from experiment.plot import plot_vary_sample_env, set_mpl_default_settings\n",
    "\n",
    "set_mpl_default_settings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment: Vary the number of samples and environments with the permutation-based procedure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: load previously saved results from disk; False: run the experiment\n",
    "load = True\n",
    "\n",
    "# Data generator class and the test procedure(s) to evaluate\n",
    "simulate_class = BinaryLinearData\n",
    "test_method_list = [PermutationBasedTest()]\n",
    "\n",
    "# Experiment grid: environments x samples, repeated `repetitions` times\n",
    "nbr_env = [25, 100, 200]\n",
    "nbr_samples = [2, 10, 25, 50, 100]\n",
    "repetitions = 50\n",
    "sign_level = 0.05\n",
    "\n",
    "# Fixed dataset settings (one confounding strength, same 'a'/'b' for X, Y and T)\n",
    "conf_strength = [5]\n",
    "dist_param = {\n",
    "    'X': {'a': 0.0, 'b': 1},\n",
    "    'Y': {'a': 0.0, 'b': 1},\n",
    "    'T': {'a': 0.0, 'b': 1},\n",
    "}\n",
    "\n",
    "# Timestamp (month, day, hour, minute) used in result and figure file names\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%m%d%H%M\")\n",
    "print('Timestamp:', timestamp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not load:\n",
    "    # Run the experiment for every test method and collect the results per\n",
    "    # algorithm name. The CSV path passed as `save_during_run` presumably lets\n",
    "    # `run` write intermediate results to disk -- confirm in experiment.utils.\n",
    "    experiment_results = {}\n",
    "    for alg in tqdm(test_method_list):\n",
    "\n",
    "        args = [dist_param, nbr_env, nbr_samples, conf_strength, simulate_class, alg, repetitions, sign_level]\n",
    "        alg_name = type(alg).__name__\n",
    "\n",
    "        res = run(args, save_during_run=f'results/vary_sample_env_{alg_name}_{timestamp}.csv')\n",
    "        experiment_results[alg_name] = pd.concat(res)\n",
    "\n",
    "else:\n",
    "\n",
    "    # Load results saved by an earlier run, identified by its timestamp\n",
    "    timestamp_str = \"11151107\"\n",
    "    test_method_list = ['PermutationBasedTest']\n",
    "    # Keep the timestamp as a string: it is only used to build file names, and\n",
    "    # converting it to int would drop a leading zero (e.g. a January run) so the\n",
    "    # figure name would no longer match the results file. This also keeps the\n",
    "    # type identical to the timestamp produced when the experiment is run.\n",
    "    timestamp = timestamp_str\n",
    "\n",
    "    experiment_results = {}\n",
    "    for alg in test_method_list:\n",
    "        path = f'results/vary_sample_env_{alg}_{timestamp_str}.csv'\n",
    "        # Sort without inplace=True so the statement is a plain reassignment\n",
    "        df = pd.read_csv(path).sort_values('nbr_samples')\n",
    "        experiment_results[alg] = df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_experiment(exp_dict : dict, conf_strength : float) -> dict:\n",
    "    '''\n",
    "    Select experimental results with a specific confounding strength.\n",
    "\n",
    "    Every value of exp_dict is indexed as a DataFrame with a\n",
    "    confounder_strength column. A row is kept when that column lies within\n",
    "    1e-3 of conf_strength. Returns a new dict with the same keys mapped to\n",
    "    the filtered frames; exp_dict itself is not modified.\n",
    "    '''\n",
    "    def _matches(frame):\n",
    "        # Tolerance comparison rather than == because the column may hold floats\n",
    "        return np.abs(frame.confounder_strength - conf_strength) < 1e-3\n",
    "\n",
    "    return {name: frame[_matches(frame)] for name, frame in exp_dict.items()}\n",
    "\n",
    "def plot_experiment(confounder_strength : float):\n",
    "    '''\n",
    "    Plot curves for each fixed number of environments and save the figure.\n",
    "\n",
    "    Reads the notebook-level variables experiment_results, nbr_env,\n",
    "    repetitions and timestamp. The results are filtered to the given\n",
    "    confounder_strength, plot_vary_sample_env is called once per entry of\n",
    "    nbr_env (labelled 'K=<value>'), and the figure is written as a PDF to\n",
    "    results/figures/.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    confounder_strength : float\n",
    "        Confounding strength whose results are plotted; also part of the\n",
    "        output file name.\n",
    "    '''\n",
    "    import os  # local import so this cell does not depend on edits to the import cell\n",
    "\n",
    "    # The filtered results do not depend on the number of environments,\n",
    "    # so compute them once instead of once per loop iteration.\n",
    "    filtered_results = filter_experiment(experiment_results, confounder_strength)\n",
    "\n",
    "    for e in nbr_env:\n",
    "        plot_vary_sample_env(filtered_results, x_axis='nbr_samples', fix_val=e, label=f'K={e}', iter=repetitions)\n",
    "\n",
    "    plt.ylabel('Detection rate')\n",
    "    plt.legend()\n",
    "\n",
    "    path = f'results/figures/vary_sample_env_cs{confounder_strength}_{timestamp}.pdf'\n",
    "    # savefig does not create missing directories; make sure the target exists\n",
    "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "    plt.savefig(path, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure (and one saved PDF) per configured confounding strength\n",
    "for strength in conf_strength:\n",
    "    plt.figure()\n",
    "    plot_experiment(strength)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('CI2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6 (main, Oct 24 2022, 11:04:34) [Clang 12.0.0 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1119f8550b2138f5de574d3adbc9d9c628b005f552be6d04a225ae36781ad7c3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
