{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy as sp\n",
    "from scipy.optimize import brentq, minimize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def func(x,m):\n",
    "    \"\"\"Stationarity condition for the Bernoulli sub-Gaussian variance proxy.\n",
    "\n",
    "    With f(x) = m*e^{(1-m)x} + (1-m)*e^{-mx}, the moment-generating function of\n",
    "    X - m for X ~ Bernoulli(m), return f'(x)/f(x) - (2/x)*log f(x). Its roots are the\n",
    "    stationary points of 2*log f(x)/x^2, the quantity maximised to get sigma^2.\n",
    "    x must be non-zero; evaluates elementwise on NumPy arrays.\n",
    "    \"\"\"\n",
    "    grow = np.exp((1-m)*x)    # e^{(1-m)x}\n",
    "    shrink = np.exp(-m*x)     # e^{-mx}\n",
    "    mgf = m*grow + (1-m)*shrink\n",
    "    dmgf = m*(1-m)*(grow - shrink)   # f'(x)\n",
    "    return dmgf/mgf - (2/x)*np.log(mgf)\n",
    "\n",
    "# Sweep Bernoulli means m over the grid [0, 0.5), solve func(x, m) = 0 for the optimal x,\n",
    "# and record the squared sub-Gaussian factor 2*log f(x)/x^2 next to the variance m(1-m).\n",
    "M_STEP = 0.0005\n",
    "sg_arr = []\n",
    "var_arr = []\n",
    "for m in np.arange(0, 0.5, M_STEP):\n",
    "    # brentq takes extra arguments as a tuple; the bracket starts above 0 because func divides by x.\n",
    "    x_best = brentq(func, 0.0005, 100, args=(m,))\n",
    "    sg_arr.append((2/x_best**2)*np.log(m*np.exp((1-m)*x_best) + (1-m)*np.exp(-m*x_best)))\n",
    "    var_arr.append(m*(1-m))\n",
    "# At m = 0.5 func has no sign change on the bracket, so use the exact value 1/4 for both curves.\n",
    "sg_arr.append(0.25)\n",
    "var_arr.append(0.25)\n",
    "# Both quantities are symmetric under m -> 1 - m: mirror [0, 0.5) onto (0.5, 1].\n",
    "sg_arr += sg_arr[:-1][::-1]\n",
    "var_arr += var_arr[:-1][::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Apply the style before the figure exists; font.family is set afterwards so the style cannot override it.\n",
    "plt.style.use('tableau-colorblind10')\n",
    "plt.rcParams['font.family'] = 'sans-serif'\n",
    "assert len(sg_arr) == len(var_arr), 'curves must share one m-grid'\n",
    "# The curves sample m evenly over [0, 1] including both endpoints. Derive the grid from their\n",
    "# length instead of a float arange whose element count depends on floating-point rounding.\n",
    "m_vals = np.linspace(0, 1, len(sg_arr))\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "# Raw string: '\\s' is not a valid Python escape, so a plain literal triggers a DeprecationWarning.\n",
    "ax.plot(m_vals, sg_arr, label=r'Square of Improved Subgaussian Factor $\\sigma^2$', linewidth=3.0)\n",
    "ax.plot(m_vals, var_arr, label='Variance of Bernoulli', linewidth=3.0)\n",
    "ax.plot(m_vals, np.full_like(m_vals, 0.25), label='y = 1/4', linewidth=3.0)\n",
    "ax.grid()\n",
    "ax.set_xlabel('m', fontsize=15)\n",
    "ax.set_ylabel('Value', fontsize=15)\n",
    "ax.tick_params(labelsize=12)\n",
    "ax.set_title('Comparing Subgaussian factors and Bernoulli Variance', fontsize=15)\n",
    "ax.legend(fontsize=15)\n",
    "fig.savefig('Sg vs Var plot', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def get_rew_vec_clippedgaussian(avg, var = 0.01):\n",
    "    \"\"\"Draw one reward per arm: its mean plus Gaussian noise, clipped to [0, 1].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    avg : array-like\n",
    "        Per-arm means (lists are accepted as well as NumPy arrays).\n",
    "    var : float or array-like, default 0.01\n",
    "        Noise variance; the noise standard deviation is sqrt(var). Must be >= 0.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Float rewards with the shape of ``avg``, each in [0, 1]. Because of the clipping,\n",
    "        the expected reward differs from ``avg`` for means close to 0 or 1.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``var`` is negative or NaN.\n",
    "    \"\"\"\n",
    "    # Reject invalid variances: sqrt(var) would be NaN and the rewards would be meaningless.\n",
    "    if not np.all(np.asarray(var) >= 0):\n",
    "        raise ValueError('var must be non-negative, got {!r}'.format(var))\n",
    "    avg = np.asarray(avg, dtype=float)\n",
    "    # randn(*shape) consumes the global RNG like one randn() call per arm, so seeded runs match.\n",
    "    noise = np.sqrt(var)*np.random.randn(*avg.shape)\n",
    "    return np.clip(avg + noise, 0, 1)\n",
    "\n",
    "def get_rew_vec_clippedunif(avg):\n",
    "    \"\"\"Draw one reward per arm uniformly from a window centred on its mean.\n",
    "\n",
    "    The half-width starts at 0.1 and is halved until [avg - bound, avg + bound] fits\n",
    "    inside [0, 1], so each reward lies in [0, 1] and has expectation ``avg``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    avg : 1-D array-like\n",
    "        Per-arm means, each in [0, 1].\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Float array with one reward per arm.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If any mean is outside [0, 1] (or NaN). For an out-of-range mean no window\n",
    "        fits, so the halving loop below would never terminate.\n",
    "    \"\"\"\n",
    "    avg = np.asarray(avg, dtype=float)\n",
    "    if not np.all((avg >= 0) & (avg <= 1)):\n",
    "        raise ValueError('every entry of avg must lie in [0, 1]')\n",
    "    rew = np.zeros(avg.size)\n",
    "    bound = 0.1*np.ones(avg.size)\n",
    "    for i in range(avg.size):\n",
    "        # Terminates for avg in [0, 1]: at worst bound underflows to 0 and the window is the point avg.\n",
    "        while avg[i]+bound[i]>1 or avg[i]-bound[i]<0:\n",
    "            bound[i] /= 2.0\n",
    "        rew[i] = np.random.uniform(avg[i]-bound[i], avg[i]+bound[i])\n",
    "    return rew\n",
    "\n",
    "# Sanity check: for avg = 0.02 the half-width shrinks from 0.1 to 0.0125, so the rewards are\n",
    "# U(0.0075, 0.0325) with mean 0.02 and standard deviation 0.0125/sqrt(3) ~= 0.0072.\n",
    "np.random.seed(0)  # reproducible output on Restart & Run All\n",
    "samples = [get_rew_vec_clippedunif(np.asarray([0.02])) for _ in range(100000)]\n",
    "print(np.mean(samples))\n",
    "print(np.std(samples))\n",
    "assert abs(np.mean(samples) - 0.02) < 1e-3\n",
    "assert abs(np.std(samples) - 0.0125/np.sqrt(3)) < 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
