{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of a terminal call\n",
    "\n",
    "This notebook shows an example terminal call to `generate_result_sim.py` for this repository, running the JSE method on the Toy dataset. Edit the variables in the first code cell to change the method, dataset, or training settings.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: edit these values, then re-run this cell to rebuild `terminal_call`.\n",
    "method = 'JSE' # Select the method - in this case, JSE\n",
    "dataset = 'Toy' # Select the dataset - in this case, Toy - other options include 'Waterbirds', 'celebA' or 'multiNLI'\n",
    "dataset_setting = 'default' # Select the dataset setting - in this case, default\n",
    "demean = True # Select whether to demean the data - in this case, True\n",
    "pca = False # Select whether to use PCA - in this case, False\n",
    "k_components = 20 # If pca, select the number of PCA components\n",
    "alpha = 0.05 # Select the alpha level for the hypothesis test of JSE\n",
    "batch_size = 128 # Select the batch size for training\n",
    "solver = 'SGD' # Select the solver for training - in this case, SGD\n",
    "lr = 0.01 # Select the learning rate for training\n",
    "weight_decay = 0.0 # Select the weight decay for training\n",
    "early_stopping = True # Select whether to use early stopping for training\n",
    "epochs = 50 # Select the (max) number of epochs for training\n",
    "null_is_concept = False # Value passed to --null_is_concept (previously hard-coded in the call)\n",
    "per_step = 5 # Select the number of epochs for printing loss\n",
    "device_type = 'cpu' # Select the device type for training - in this case, cpu\n",
    "save_results = True # Select whether to save the results\n",
    "baseline_adjust = False # Select whether to use the heuristic for Delta, otherwise Delta = 0\n",
    "eval_balanced = True # Select whether to use the balanced average for tests\n",
    "use_standard_settings = True # Select whether to use the standard settings for ERM trained on transformed data\n",
    "sims = 1 # Select the number of simulations\n",
    "spurious_ratio = 0.0 # Select the spurious ratio - for the Toy dataset this is \\rho, for others this is p(y_{mt} = y | y_{sp} = y)\n",
    "run_seed = 0 # Select the run seed\n",
    "concept_first = True # Select whether to use the outer loop or inner loop for the concept\n",
    "remove_concept = True # Select whether to remove the concept from the data or to project onto main-task subspace\n",
    "\n",
    "\n",
    "# Command-line flags for generate_result_sim.py. Dict insertion order fixes the order of the\n",
    "# flags in the call, and every value is rendered with str() (booleans become 'True'/'False'),\n",
    "# so the resulting command is the same as building each '--flag value' string by hand.\n",
    "cli_args = {\n",
    "    'method': method,\n",
    "    'dataset': dataset,\n",
    "    'dataset_setting': dataset_setting,\n",
    "    'demean': demean,\n",
    "    'pca': pca,\n",
    "    'k_components': k_components,\n",
    "    'alpha': alpha,\n",
    "    'batch_size': batch_size,\n",
    "    'solver': solver,\n",
    "    'lr': lr,\n",
    "    'weight_decay': weight_decay,\n",
    "    'early_stopping': early_stopping,\n",
    "    'epochs': epochs,\n",
    "    'null_is_concept': null_is_concept,\n",
    "    'per_step': per_step,\n",
    "    'device_type': device_type,\n",
    "    'save_results': save_results,\n",
    "    'baseline_adjust': baseline_adjust,\n",
    "    'eval_balanced': eval_balanced,\n",
    "    'use_standard_settings': use_standard_settings,\n",
    "    'sims': sims,\n",
    "    'spurious_ratio': spurious_ratio,\n",
    "    'run_seed': run_seed,\n",
    "    'concept_first': concept_first,\n",
    "    'remove_concept': remove_concept,\n",
    "}\n",
    "\n",
    "# Script name followed by '--flag value' pairs separated by single spaces; used by `%run` below.\n",
    "terminal_call = 'generate_result_sim.py ' + ' '.join(f'--{flag} {value!s}' for flag, value in cli_args.items())\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell is only needed if you have not yet installed the JSE folder as a local (editable) package. Without this installation, the imports from JSE inside `generate_result_sim.py` will fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -e ."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we change the working directory to the `JSE` folder, which contains `generate_result_sim.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Enter the JSE folder only if generate_result_sim.py is not already in the working directory,\n",
    "# so re-running this cell (or Run All after a partial run) does not try to enter JSE/JSE.\n",
    "if not os.path.isfile('generate_result_sim.py'):\n",
    "    os.chdir('JSE')\n",
    "os.getcwd()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we run the script with the arguments assembled in the first code cell. `%run` executes it inside the current kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run generate_result_sim.py in this kernel; {terminal_call} is expanded from the variable built in the first code cell.\n",
    "%run {terminal_call}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Afterwards, the results will appear in the `JSE/results` folder (provided `save_results` is `True`)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "JSE-replicate-code",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
