{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# General settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of jobs; passed as `n_jobs` in the entropy estimator's `functional_params`.\n",
    "global_n_jobs = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Entropy estimator ###\n",
    "method = 'KL'  # 'KL' or 'KDE'\n",
    "\n",
    "# Options used when method == 'KDE'.\n",
    "bandwidth_algorithm = 'loo_ml'  # 'loo_ml' or 'loo_lsq'\n",
    "kernel = 'gaussian'\n",
    "\n",
    "# Option used when method == 'KL'.\n",
    "k_neighbours = 5\n",
    "\n",
    "# Base settings: the method name plus the functional parameters common to all methods.\n",
    "entropy_estimator_params = dict(\n",
    "    method=method,\n",
    "    functional_params=dict(n_jobs=global_n_jobs),\n",
    ")\n",
    "\n",
    "# Extend the functional parameters with the options of the selected method.\n",
    "if method == 'KL':\n",
    "    entropy_estimator_params['functional_params'].update(k_neighbours=k_neighbours)\n",
    "elif method == 'KDE':\n",
    "    entropy_estimator_params['functional_params'].update(\n",
    "        bandwidth_algorithm=bandwidth_algorithm,\n",
    "        kernel=kernel,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full estimator name: the method followed by its method-specific settings.\n",
    "if method == 'KL':\n",
    "    full_method_name = '_'.join([method, str(k_neighbours)])\n",
    "elif method == 'KDE':\n",
    "    full_method_name = '_'.join([method, kernel, bandwidth_algorithm])\n",
    "else:\n",
    "    # No extra settings are appended for any other method.\n",
    "    full_method_name = method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_estimated_MI(MI, estimated_MI, name):\n",
    "    \"\"\"\n",
    "    Save MI values next to their estimates as a text table.\n",
    "\n",
    "    The table is written into the `name` subdirectory of `experiments_path`,\n",
    "    which is created if it does not exist. The file name is assembled from\n",
    "    `full_method_name`, `n_samples`, `X_dimension`, `Y_dimension` and the\n",
    "    current local time. These names, together with `experiments_path`, are\n",
    "    read from the notebook's global namespace and must be defined before\n",
    "    the call.\n",
    "\n",
    "    Note that columns are separated by a single space even though the file\n",
    "    has a '.csv' extension.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    MI : array_like\n",
    "        Values stored in the leading column(s); presumably the reference\n",
    "        (ground-truth) mutual information -- confirm against callers.\n",
    "    estimated_MI : array_like\n",
    "        Values stored in the trailing column(s); converted with `np.asarray`.\n",
    "    name : str\n",
    "        Subdirectory of `experiments_path` to save into.\n",
    "    \"\"\"\n",
    "\n",
    "    # os.path.join adds a separator only where one is missing, so\n",
    "    # `experiments_path` works with or without a trailing slash.\n",
    "    directory = os.path.join(experiments_path, name)\n",
    "\n",
    "    # The time fields are joined with '-' rather than ':' because a colon\n",
    "    # is not a valid file name character on Windows.\n",
    "    timestamp = datetime.now().strftime('%d-%b-%Y_%H-%M-%S')\n",
    "    file_name = f'{full_method_name}__{n_samples}_{X_dimension}_{Y_dimension}__{timestamp}.csv'\n",
    "\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "    np.savetxt(\n",
    "        os.path.join(directory, file_name),\n",
    "        np.column_stack([MI, np.asarray(estimated_MI)]),\n",
    "        delimiter=' ',\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
