{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import utils\n",
    "# from functools import partial\n",
    "import itertools\n",
    "import math\n",
    "from typing import Any\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch as t\n",
    "import torch.nn.functional as F\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# parameters for top-k distribution\n",
    "# rng is a fixed-seed NumPy generator, so the random draws made from it repeat on every fresh run\n",
    "rng = np.random.RandomState(1234)\n",
    "# each state is a length-n 0/1 vector with exactly k ones\n",
    "n = 10\n",
    "k = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.47143516 -1.19097569  1.43270697 -0.3126519  -0.72058873  0.88716294\n",
      "  0.85958841 -0.6365235   0.01569637 -2.24268495]\n"
     ]
    }
   ],
   "source": [
    "# generate random parameters for the distribution: these are the initial weights\n",
    "# (n independent standard-normal draws from the seeded rng)\n",
    "theta = rng.randn(n)\n",
    "print(theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of possible states: 10\n",
      "[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]\n"
     ]
    }
   ],
   "source": [
    "# create all possible_states: one row per k-subset of the n items,\n",
    "# with ones at the chosen positions and zeros everywhere else\n",
    "combs = list(itertools.combinations(range(n), k))\n",
    "n_states = len(combs)\n",
    "# math.comb is exact integer arithmetic; np.math.factorial no longer exists in NumPy 2.0\n",
    "assert n_states == math.comb(n, k)\n",
    "print('Number of possible states:', n_states)\n",
    "mat_x = np.zeros((n_states, n))\n",
    "for i, comb in enumerate(combs):\n",
    "    mat_x[i, comb] = 1.\n",
    "\n",
    "print(mat_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# put all of this in pytorch\n",
    "# theta_t: float32 version of theta that tracks gradients (the parameter we differentiate w.r.t.)\n",
    "theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "# states_t: float32 matrix with one row per state, built from mat_x\n",
    "states_t = t.from_numpy(mat_x).float()\n",
    "\n",
    "def tow_t(_theta):\n",
    "    \"\"\"Unnormalized log-weight of every state: entry i is the dot product of state i with _theta.\"\"\"\n",
    "    return states_t @ _theta\n",
    "\n",
    "def Z_t(_theta):\n",
    "    \"\"\"Return the log-partition function, log(sum_i exp(tow_t(_theta)[i])), as a scalar tensor.\n",
    "\n",
    "    Despite the name this is log Z, not Z. It is computed with t.logsumexp instead of\n",
    "    log(sum(exp(.))) so that large scores cannot overflow exp() and turn the result\n",
    "    (and the pmf built from it) into inf/nan.\n",
    "    \"\"\"\n",
    "    # reshape(-1) keeps the reduction over every entry, matching the original t.sum\n",
    "    return t.logsumexp(tow_t(_theta).reshape(-1), dim=0)\n",
    "\n",
    "def pmf_t(_theta):\n",
    "    \"\"\"Probability of every state: exp(log-weight minus log-normalizer), one entry per state.\"\"\"\n",
    "    log_weights = tow_t(_theta)\n",
    "    log_normalizer = Z_t(_theta)\n",
    "    return t.exp(log_weights - log_normalizer)\n",
    "\n",
    "def sample_state_from_pdf(_theta):\n",
    "    \"\"\"Draw one state index from the pmf induced by _theta, using the shared NumPy rng.\"\"\"\n",
    "    probabilities = pmf_t(_theta).detach().numpy()\n",
    "    state_indices = list(range(n_states))\n",
    "    return rng.choice(state_indices, p=probabilities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1., dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the pmf over all states must sum to 1. Pass the torch parameter theta_t\n",
    "# (not the NumPy array theta) so the check runs the same float32 autograd path as the rest.\n",
    "print(t.sum(pmf_t(theta_t)))  # so far so good"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#groundtruth distribution\n",
    "# b_t holds the absolute values of n standard-normal draws as float32;\n",
    "# objective() measures the squared distance of a state to this vector\n",
    "b_t = t.abs(t.from_numpy(rng.randn(n)).float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.1500, 0.9919, 0.9533, 2.0213, 0.3341, 0.0021, 0.4055, 0.2891, 1.3212,\n",
      "        1.5469])\n"
     ]
    }
   ],
   "source": [
    "# show the target vector b_t\n",
    "print(b_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.7361908\n"
     ]
    }
   ],
   "source": [
    "# Smallest achievable objective: the best state puts its k ones on the k largest entries\n",
    "# of b_t, so the n-k smallest entries contribute b**2 and the k largest contribute (b-1)**2.\n",
    "sorted_bt = np.sort(b_t.detach().numpy())\n",
    "n_zeros = n - k  # previously hard-coded as 5, which only matches n=10, k=5\n",
    "min_value_of_exp = np.sum((sorted_bt[:n_zeros])**2) + np.sum((sorted_bt[n_zeros:] - 1)**2)\n",
    "print(min_value_of_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(10.8149)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def objective(index):\n",
    "    \"\"\"Squared Euclidean distance between state number `index` and the target b_t.\"\"\"\n",
    "    residual = states_t[index] - b_t\n",
    "    return (residual ** 2).sum()\n",
    "\n",
    "# example: squared distance between the state in row 1 of states_t and b_t\n",
    "objective(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.1500, 0.9919, 0.9533, 2.0213, 0.3341, 0.0021, 0.4055, 0.2891, 1.3212,\n",
       "        1.5469])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# display the target vector b_t\n",
    "b_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of entries in the pmf: one per state\n",
    "len(pmf_t(theta_t))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# writing explicitly the expectation of this objective summing over\n",
    "# all possible states:\n",
    "def expectation_t(_theta):\n",
    "    \"\"\"Exact expected objective: sum over all states of (state probability * objective).\"\"\"\n",
    "    state_probs = pmf_t(_theta)\n",
    "    weighted_terms = []\n",
    "    for state_idx in range(n_states):\n",
    "        weighted_terms.append(state_probs[state_idx] * objective(state_idx))\n",
    "    return t.stack(weighted_terms).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(11.2864, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# expected objective at the current parameters theta_t\n",
    "expectation_t(theta_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#Ground truth gradient\n",
    "# autograd differentiates the exact expectation w.r.t. theta_t; t.autograd.grad returns a\n",
    "# tuple with one gradient tensor per input, so the gradient itself is exact_gradient[0]\n",
    "exact_gradient = t.autograd.grad(expectation_t(theta_t), theta_t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Essentially, we are now explicitly solving\n",
    "$\\min_{\\theta} \\mathbb{E}_{z\\sim p(z, \\theta)} \\lVert z - b \\rVert_2^2$,\n",
    "where $p(z, \\theta)$ is the top-k distribution and $\\lVert z - b \\rVert_2^2$ is the squared-error objective implemented by `objective`.\n",
    "\n",
    "With full enumeration of the states we simply write $\\mathbb{E}_{z\\sim p(z, \\theta)} \\lVert z - b \\rVert_2^2 = \\sum_{i=1}^{N} p(z_i, \\theta) \\lVert z_i - b \\rVert_2^2$,\n",
    "summing over all possible states, where $N=\\binom{n}{k}$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def return_grad(strategy, reinitialize=True):\n",
    "    \"\"\"Draw one sample with `strategy` and return the gradient of its loss w.r.t. theta_t.\n",
    "\n",
    "    Args:\n",
    "        strategy: callable that maps the parameter tensor to a sample; a 2-D result\n",
    "            is averaged over dim 0 before the squared error is summed.\n",
    "        reinitialize: when True, rebuild the global theta_t from the NumPy array theta\n",
    "            first; when False, differentiate at the current theta_t.\n",
    "\n",
    "    Returns:\n",
    "        A copy of theta_t.grad for this single backward pass (None if no gradient\n",
    "        reached theta_t).\n",
    "    \"\"\"\n",
    "    global theta_t\n",
    "    if reinitialize:\n",
    "        theta_t = t.from_numpy(theta).float().requires_grad_(True)\n",
    "    else:\n",
    "        # backward() adds into .grad, so clear what earlier calls left behind;\n",
    "        # otherwise the result would be a running sum rather than this sample's gradient.\n",
    "        theta_t.grad = None\n",
    "\n",
    "    # redefine objective with given strategy\n",
    "    def objective_(_theta):\n",
    "        sample = strategy(_theta)\n",
    "        squared_error = (sample - b_t) ** 2\n",
    "        if sample.dim() == 2:\n",
    "            return squared_error.mean(dim=0).sum()\n",
    "        return squared_error.sum()\n",
    "\n",
    "    objective_(theta_t).backward()\n",
    "    grad = theta_t.grad\n",
    "    # Hand out a copy so a later call cannot modify a gradient the caller already stored.\n",
    "    return None if grad is None else grad.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "def sample_gumbel(shape, eps=1e-20):\n",
    "    U = torch.rand(shape)\n",
    "    return -Variable(torch.log(-torch.log(U + eps) + eps))\n",
    "\n",
    "def gumbel_softmax_sample(logits, temperature):\n",
    "    \"\"\"Relaxed sample: softmax over the last dim of (logits + Gumbel noise) / temperature.\"\"\"\n",
    "    noisy_logits = logits + sample_gumbel(logits.size())\n",
    "    return F.softmax(noisy_logits / temperature, dim=-1)\n",
    "\n",
    "def gumbel_softmax(logits, temperature=1.0):\n",
    "    \"\"\"\n",
    "    Straight-through Gumbel-softmax.\n",
    "    input: [*, n_class]\n",
    "    return: [*, n_class] one-hot in the forward pass; gradients follow the relaxed sample\n",
    "    \"\"\"\n",
    "    soft_sample = gumbel_softmax_sample(logits, temperature)\n",
    "    winner = soft_sample.max(dim=-1).indices\n",
    "    # one-hot encode the arg-max along the class (last) dimension\n",
    "    hard_sample = torch.zeros_like(soft_sample).scatter_(-1, winner.unsqueeze(-1), 1)\n",
    "    # forward value is the hard sample; backward sees only soft_sample\n",
    "    return (hard_sample - soft_sample).detach() + soft_sample\n",
    "\n",
    "# Monte-Carlo estimate: one straight-through Gumbel-softmax gradient per draw.\n",
    "# Seed torch's RNG so the sampled gradients in this cell are the same on every re-run.\n",
    "torch.manual_seed(1234)\n",
    "grads = [return_grad(gumbel_softmax) for _ in range(10000)]\n",
    "STGS_grads = torch.stack(grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2149684/1349740020.py:13: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  y_perturbed = F.softmax(logits + sample_gumbel(logits.size()))\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "def SIMPLE(logits):\n",
    "    \"\"\"\n",
    "    input: [*, n_class]\n",
    "    return: [*, n_class] a one-hot vector in the forward pass; gradients follow softmax(logits)\n",
    "    \"\"\"\n",
    "    y = F.softmax(logits, dim=-1)\n",
    "    # dim must be explicit: the implicit choice is deprecated and does not pick the\n",
    "    # class (last) dimension for inputs with more than 2 dims.\n",
    "    y_perturbed = F.softmax(logits + sample_gumbel(logits.size()), dim=-1)\n",
    "    shape = y.size()\n",
    "    # the sample is the arg-max of the Gumbel-perturbed logits\n",
    "    _, ind = y_perturbed.max(dim=-1)\n",
    "    y_hard = torch.zeros_like(y_perturbed).view(-1, shape[-1])\n",
    "    y_hard.scatter_(1, ind.view(-1, 1), 1)\n",
    "    y_hard = y_hard.view(*shape)\n",
    "    # forward value is y_hard; backward sees only the unperturbed softmax y\n",
    "    return (y_hard - y).detach() + y\n",
    "\n",
    "# Monte-Carlo estimate: one SIMPLE gradient per draw.\n",
    "# Seed torch's RNG so the sampled gradients in this cell are the same on every re-run.\n",
    "torch.manual_seed(1234)\n",
    "grads = [return_grad(SIMPLE) for _ in range(10000)]\n",
    "SIMPLE_grads = torch.stack(grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the exact gradient once per sampled gradient so it lines up with the estimator\n",
    "# tensors; the shape is taken from STGS_grads rather than hard-coded as (10000, 10).\n",
    "exact_gradient = exact_gradient[0].expand_as(STGS_grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATgAAACzCAYAAAAdfwDFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAhOUlEQVR4nO3de1RU5f4/8DcDDCBI6AkQgrxQgHrkeME8LiUVEEwuw00xwlBUKiuPaBn2KxHRVOJYAVp5SjDTo3lJFFELM4U64lE5YglkOEqQMShNgMQMDM/3D36zFyPMBdiIbD+vtVwL9uXZz8zo27337Of5GDHGGAghRIBEfd0BQgjpLRRwhBDBooAjhAgWBRwhRLAo4AghgmXS1x3QRy6XIycnBy4uLjA1Ne3r7hDy0GhubkZ5eTkCAwNhY2PT193plgc+4HJycpCcnNzX3SDkoRYdHd3XXeiWBz7gRowYAQB4++234e7u3se9IeThUVpaiuTkZO7fYH/0wAecWCwGALi7u8PT07OPe0PIw0f9b7A/MuhLBqlUisjISPj7+yMyMhI3btzosI1KpUJSUhJ8fX0xc+ZM7N+/X2N9bm4ugoKCEBgYiKCgINy+fZuXF0AIIdoYdAaXmJiIqKgoSCQSZGdnY82aNfjss880tjl69CgqKirw1VdfQS6XIyQkBJMnT4aTkxOuXLmCjIwM7Ny5E7a2tqivr+/X/ysQQvoHvWdwd+7cwdWrVxEYGAgACAwMxNWrV1FbW6uxXW5uLubMmQORSITBgwfD19cXJ06cAABkZWUhNjYWtra2AICBAwfCzMyM79dCCCEa9J7B3bp1C/b29jA2NgYAGBsbw87ODrdu3cLgwYM1tnN0dOR+d3BwwG+//QYAKC8vh5OTE5577jk0NjZi5syZeOmll2BkZKRxrLq6OtTV1Wksq6mp6f6rI3plZmZCKpXy3u7w4cOxcOFCXttsbW3F7du3IZfLoVKpeG37YWdubg4nJyfBPYp1X75kUKlUKCsrQ2ZmJpRKJRYvXgxHR0eEhIRobLdz505kZGR02kZZWVmHQCQ9V1hYCNn1G3jM1o63NqtqZJDJZPDw8OCtTQAQiUQYOHAg7OzsYGJiQn8feMIYwx9//IEffvgBra2t3PKysrI+7BU/9Aacg4MDqquroVKpYGxsDJVKBZlMBgcHhw7b/frrr9xf6vZndI6Ojpg1axbEYjHEYjF8fHxQXFzcIeBiYmIQGhqqsezKlStYvnw53NzcMGHChJ68VtKJ7OxsDFA0Y8VzMby1uWX3TljZ2fH+eZWUlGDEiBEQiWgADt8sLS1RV1eHkSNHcsuEMJOa3oD7y1/+gpEjRyInJwcSiQQ5OTkYOXKkxuUpAMyaNQv79++Hn58f5HI58vLysHv3bgBt9+3OnDkDiUSClpYWnDt3Dv7+/h2OZW1tDWtra41l6stcQgBohFtvXV4DvXOJ/SAT6tmwQZeoa9euRUJCArZt2wZra2ts3rwZALBkyRIsW7YMY8aMgUQiweXLl+Hn5wcAePnll+Hs7AwACAgIwA8//IDZs2dDJBJh6tSpiIiI6KWXRB4WUqkUpRcuwcnentd2K6urDd7W29sbYrEYZmZmUCgU8PT0RGJiIg4cOACFQoEFCxbw2jfSNQYFnIuLS4fn2gDgX//6F/ezsbExkpKSOt1fJBJh9erVWL16dTe7SUjnnOzteb28BtousbsiLS0Nrq6uUKlUeO655/D111/j2Wef5bVPpHvoZgYhPFEoFFAoFLC2tkZ6ejp3pVNWVoaoqCiEhoZi9uzZyMrK4vbZt28fnnnmGUgkEgQFBaG8vLyPes8PQwYFFBQUICwsDH/961+590ht69atCAgIQHBwMMLCwpCfn8+tS09Px+TJkyGRSCCRSLSeULX3wA/VIuRBt2zZMpiZmaGiogJTp07F1KlTUVRUxK1/7LHHkJWVBbFYjLt372LOnDnw8vKCi4sL
UlJSkJOTAwcHByiVyn7/+IshgwKcnZ2xfv16nDx5EkqlUmOdh4cHYmNjYWFhgdLSUkRHR6OgoADm5uYAgJCQELzxxhsG94fO4AjpobS0NGRnZ+PcuXNQKBQaZ2gA0NTUhDfffBNBQUF49tlnIZPJUFpaCgD4+9//jtWrV2PXrl2orq6GhYVFH7wCfhg6KGDo0KEYNWoUTEw6nl95eXlx74GbmxsYY5DL5d3uEwUcITwxMzPD9OnT8f3332ss37JlC2xtbfHll1/iyJEj8PDwgEKhAABkZGRgxYoV+PPPP/H888/jzJkzfdF1nWpqalBZWanx594H8gHdgwK64/Dhw3j88ccxZMgQbtmxY8cQFBSE2NhYjbNkbegSlRCetLa24r///S+GDRumsby+vh5ubm4wMTHBTz/9hAsXLiAwMBAtLS3cs6MeHh6oqKhASUkJpk2b1jcvQIvly5d3WPbKK6/g1Vdf7bVjnj9/Hh988AF27NjBLZs3bx5efPFFmJqa4rvvvsPSpUuRm5uLQYMGaW2HAo70a5XV1V3+1tOQNt2dHzN4e/U9uObmZjz55JN4+eWXNe47vfTSS1i1ahWOHDmCxx9/HBMnTgTQFogJCQmor6+HkZERHBwcsHLlSl5fCx/ef/99jBkzRmPZvc+rAoYPCtCnqKgIr7/+OrZt26YxF516LDsATJkyBQ4ODrh27RqeeuoprW1RwJF+a/jw4b3SrrvzYwa3/c0333S6vP3ZzahRo5CTk9Ppdnv27Ol6B+8zW1tbODk56d3O0EEBuhQXFyM+Ph5paWkYPXq0xrrq6mrY//9nHktKSlBVVaX3c6KAI/3WwzTSoL8wZFDAhQsXsGLFCjQ0NIAxhmPHjmHDhg3w8vJCUlISmpqasGbNGq7NlJQUuLm5YcuWLfjxxx8hEolgamqKlJQUjbO6zlDAEUJ4Y8igAE9PT5w9e7bT/Q8ePKi17XufmTMEfYtKCBEsCjhCiGBRwBFCBIsCjhAiWPQlA+m3aD44og8FHOm3pFIpTuWcho2VDa/tyhvk8AnUv92iRYswc+ZMzJs3j1vGGIOPjw82b97MPdCrj0Qiwb59+7gB5YQ/FHCkX7OxssE0D29e2zxT3PnDu/cKDw9HVlaWRsAVFhbCxMTEoHBraWmBiYkJsrOzu91XohsFHCHd5Ovri6SkJPz888944oknAACHDh1CcHAwoqKi8Oeff0KhUGDu3LnczL4JCQmwtLTEjRs38Pvvv+PQoUNwc3PDpUuXYGlpic2bN+P8+fNobm7GoEGD8M477+Cxxx5DZWUlwsPDMW/ePJw5cwZ//vknNmzYAE9PTwDA6dOnkZ6ejpaWFohEImzatAnu7u64fPkyUlNTcffuXQBtw8qmT5/eF29Xn6CAI6SbxGIxgoKCcOjQIaxatQoNDQ3Iy8vD0aNHERcX1+n8b0DbWMvPP/8cAwYM6NDmkiVLuPnO9u/fj9TUVLz33nsAALlcjrFjxyI+Ph5HjhxBamoq9u7dC6lUirfeegu7d+/GsGHDoFQqoVQqUVdXh8TERGzfvh12dnaQyWSIiIhATk5Op2NJhYgCjpAeiIiIwOLFi7FixQocP34cEyZMgJmZGd58802u1KV6/jd1wM2aNavTcAOAs2fPYs+ePWhsbERLS4vGugEDBmDGjBkAgLFjx3JP9n///fd4+umnuVlM1NXrzpw5g8rKSixZsoRrw8jICDdv3uwweF6oKOAI6QF3d3fY2toiPz8fBw8exIIFC7j53zZt2gQTExPExsZy878B0BpuVVVV2LhxIw4cOABnZ2dcunQJr732GrdeLBZzP4tEIi4AtZX3Y4zBzc2Nq273MKLn4AjpofDwcKSnp+PGjRvw9vZGfX09hgwZojH/myEaGhpgamoKW1tbtLa2Yu/evQbtN3XqVJw9e5arf6BUKtHQ0IBx48bh5s2bOHfuHLdtcXGxIOqdGorO4Ei/Jm+QG/ytZ1fa7IqgoCCkpKQgMjIS
YrFY6/xv+ri5uWHWrFkICAiAo6MjJk6caFA4Dhs2DMnJyYiPj+fmYtu0aRPc3Nywbds2vPvuu3jnnXfQ3NwMZ2dnfPTRR4Ktg3ovCjjSbw0fPtyg59W627ahHnnkERQXF3O/65r/bdOmTR2WlZWVcT+/9dZbeOutt7jfly1bBgBwcnJCYWEht/ze3729veHt3fFxGQ8PD+zatcvg1yI0FHCk36KRBkQfugdHCBEsCjhCiGBRwJF+pbW1ta+7IEh8fbPa08r2KpUKSUlJ8PX1xcyZMzVmB9a1ThsKONJvWFpaoqqqCkql8qF61KG3McZw584dXgb7qyvbnzx5ElFRURq1FdTUle0XLVrUYd3Ro0dRUVGBr776Cvv27UN6ejoqKyv1rtOGvmQg/YaTkxNu376NmzdvdnjKn/SMubm5QZWzdFFXts/MzATQVtk+OTkZtbW1GpW1hg4dCgA4deoUlEqlRhu5ubmYM2cORCIRBg8eDF9fX5w4cQKLFy/WuU4bCjjSb4hEItjZ2cHOzq6vu/JQUVe2b8/a2rrDeFZdle0NLR1469YtODo6cr87ODjgt99+07tOGwo4QohOfVHZni8UcIQQne5nZXsHBwf8+uuv8PDwAKB51qZrnTb0JQMhRCd1Zfv2fzoLuPaV7QF0q7L9rFmzsH//frS2tqK2thZ5eXnw9/fXu04bOoMjhPCmp5XtJRIJLl++DD8/PwDAyy+/DGdnZwDQuU4bCjhCCG96Wtne2NgYSUlJXV6nDV2iEkIEiwKOECJYBgWcIcMvDBlGcf36dfztb3/rMDyDEEJ6g0EBZ8jwC33DKFQqFRITE+Hr68tf7wkhRAe9AacefhEY2DazYGBgIK5evYra2lqN7bQNo1Dbvn07pk+fzhXGIISQ3qY34HQNv7h3O23DKEpLS1FQUMDVhtSmrq4OlZWVGn9qamq6+poIIQTAfXhMpLm5GW+//TY2btzIhaQ2O3fuREZGRqfr1CXYCL9kMhlMlc2or6/nrc1mZTNkMhkuXrzIW5vk/ms/lXp/pTfgDB1+oW0YRU1NDSoqKhAXFweg7SyNMYaGhgYkJydrtBETE4PQ0FCNZVeuXMHy5cvh5uaGCRMm9OjFko6ys7PRoKjCwIEDeWvTVGyKQXZ29Hn1c0KYkkpvwLUffiGRSLQOv1APo/Dz84NcLkdeXh52794NR0dHjeIY6enpaGxs5Kp3t9fZDAX6ZgsghBBtDPoWde3atfj888/h7++Pzz//nHuaeMmSJbhy5QqAtmEUTk5O8PPzw9y5cw0aRkEIIb3JoHtwhgy/MHQYRX+YYoUQIgw0koEQIlgUcIQQwaKAI4QIFgUcIUSwKOAIIYJFAUcIESya0ZcQwhupVIqEhATI5XLY2Nhg8+bNHSbYUKlUWL9+PfLz82FkZIS4uDjMmTMHALBq1SqNIWJlZWXYunUrfHx8kJ6ejj179nBlI8ePH4/ExESd/aGAI4TwRj21mkQiQXZ2NtasWYPPPvtMY5v2U6vJ5XKEhIRg8uTJcHJyQkpKCrddaWkpYmJi4OXlxS0LCQnpdBSUNnSJSgjhBV9Tq6kdOHAAQUFBEIvF3e4TncERQnTiu7K9IRXqlUoljh49iqysLI3lx44dQ0FBAWxtbfHqq69i3LhxOvtOAUcI0akvKtvn5eXB0dERI0eO5JbNmzcPL774IkxNTfHdd99h6dKlyM3NxaBBg7S2QwFHCNGJ78r2hlSoP3jwIMLDwzWW2dracj9PmTIFDg4OuHbtGp566imtfad7cIQQnfiubK+vQv1vv/2Gixcvcvfy1Kqrq7mfS0pKUFVVheHDh+vsO53BEUJ4Y0hle30V6r/88kvMmDEDNjY2Gm1v2bIFP/74I0QiEUxNTZGSkqJxVtcZCjhCCG/4mFrtpZde6nR5d8qN0iUqIUSwKOAIIYJFAUcIESwKOEKIYFHAEUIEiwKOECJY9JgIuS8yMzMhlUp5b3f48OFYuHAh7+0SYaCA
I/eFVCrFqZzTsLGy4a1NeYMcPoH6tyMPLwo4ct/YWNlgmoc3b+2dKf6Gt7aIMNE9OEKIYFHAEUIEiwKOECJYFHCEEMGigCOECBZ9i0pIN9Bzff0DBRwh3SCVSlF64RKc7O15a7Oy3Yy1hB8UcIR0k5O9PVY8F8Nbe1t27+StLdKG7sERQngjlUoRGRkJf39/REZG4saNGx22UalUSEpKgq+vL2bOnKkxA3B6ejomT54MiUQCiUSiMfOvrv20oTM4QghvelrZHtBevV7ffp2hMzhCiE7qws/t/9TV1XXYju/K9vfqzn50BkcI0cnQws98VbbXVr1e336doYAjhOhkaOFnPnSner0uBl2i9vTG4datWxEQEIDg4GCEhYUhPz+/W50lhNx/hhZ+bl/ZHoDeyvZqt27dwpAhQ7hjmZqaAtCsXq9vP20MCjj1jcOTJ08iKioKa9as6bBN+xuA+/btQ3p6OiorKwEAHh4eOHDgAI4cOYJ33nkH8fHxaGpqMuTQhJB+go/K9rqq1+vaTxu9l6jqG4eZmZkA2m4cJicno7a2VqPj2m4ALl68GF5eXtx2bm5uYIxBLpfrTV9CSP/S08r2uqrX69pPG70Bx9eNQ7XDhw/j8ccf7zTc6urqOnw7U1NTo6+LhJAHRE8r2+uqXq9rP23u65cM58+fxwcffIAdO3Z0un7nzp3IyMjodF1ZWRmMjIx6s3sPJZlMBlNlM+rr63lrs1nZDJlMhosXL2ocR6lsRn19A2/HUXZynPvlfr1vfamsrKyvu9BjegOu/Y1DY2NjvTcOPTw8AHQ8oysqKsLrr7+Obdu2YcSIEZ0eKyYmBqGhoRrLrly5guXLl8PNzQ0TJkzo8gskumVnZ6NBUYWBAwfy1qap2BSD7Ow0Pq/s7Gz8Iq7CwIFWvB1HLDaF3T3HuV/u1/vWlxhjfd2FHtP7JQMfNw6Li4sRHx+PtLQ0jB49WuuxrK2tO3xbo77+JoSQrjLoErWnNw6TkpLQ1NSk8e1rSkoK3Nzc+H49hBDCMSjgenrj8ODBg93sHiGEdB+NRSWECBYFHCFEsCjgCCGCRQFHCBEsCjhCiGBRwBFCBIsCjhAiWBRwhBDBooAjhAgWBRwhRLAo4AghgkVFZwghvJFKpUhISIBcLoeNjQ02b96MYcOGaWyjUqmwfv165Ofnw8jICHFxcZgzZw6Atvotubm5MDY2homJCeLj47kZwdPT07Fnzx7Y2dkBAMaPH4/ExESd/aGAI4TwpqeFnz08PBAbGwsLCwuUlpYiOjoaBQUFMDc3B6C9KLQ2dIlKCOEFH4Wfvby8YGFhAUCzfkt30RkcIUQndWX79qytrTuUDrwf9Vu0FYXWhgKOEKKToZXt+dRZ/ZbuFIWmgCPkAZaZmQmpVMp7u8OHD8fChQsN2tbQyva9Xb+lffmC9kWhn3rqKa19p4Aj5AEmlUpxKuc0bKxseGtT3iCHT6Dh26sr2+vTvn6LRCLRW7/Fz88PcrkceXl52L17NwDd9Vuqq6thb28PoGNRaG0o4Ah5wNlY2WCahzdv7Z0p/oa3tu7Vm/VbdBWF1oYCjhDCm96s36KrKLQ29JgIIUSwKOAIIYJFAUcIESwKOEKIYFHAEUIEiwKOECJYFHCEEMGigCOECBYFHCFEsCjgCCGCRQFHCBEsCjhCiGBRwBFCBIsCjhAiWBRwhBDBooAjhAgWBRwhRLAMCjipVIrIyEj4+/sjMjISN27c6LCNSqVCUlISfH19MXPmTI1ZPXWtI4QIR29mRXdyxKApy3tarVrXOkKIcPRmVnQnR/SewfFRrVrXOkKIMPR2VnQnR/SewfFRrdrQStZ1dXWoq6vTWFZVVQUAKC0t1ddVfPrpp3q36apFixZ1WPbJJ5/wfpzFixfz3qYhampqcPOGFP9vWxp/bf5ei6HmYly4cIFbJpPJUCOvwYkLubwdp6GxHjKZTOM498vD8L6p/81VVVVpVJcHeq+yPR850t4DVVVr586d
yMjI6HRdcnLyfe5Nm2++6b0Sa+2dPn36vhxHm/Jfq3hvr7P37g/F77we5/Tp03363j0M79uqVas6LOvtyvZ80RtwfFSr1lfJWi0mJgahoaEayxoaGlBQUIDRo0fD1NS0e6+SENJlzc3N+PHHHzFlyhQMHDhQY11vVbbnI0c0MANER0ezw4cPM8YYO3z4MIuOju6wzcGDB1lsbCxTqVTszp07zMvLi1VUVOhdRwgRjt7Miu7kiEEB9/PPP7OIiAjm5+fHIiIiWHl5OWOMscWLF7Pi4mLGGGMtLS1szZo1zMfHh/n4+LC9e/dy++taRwgRjt7Miu7kiBFjjHX95JUQQh58NJKBECJYFHCEEMGigCOECBYFHCFEsCjgCCGC9UCNZLifRo4cCVdXV+73gIAAxMXF8dJ2SUkJZDIZpk2b1uO2PvzwQ+Tk5EAkEkEkEmHdunXYvn07Kisr0djYiNraWm6wcWJiIsaPH6+xf2ZmJvbt2wdTU1MYGRlh8uTJeO2113h5aHrcuHEoKioyePv09HQMGDCgw/C369evIzExEXV1dVAqlfD09NQ7cmXFihW4du0awsPDAQCRkZGwsLDo+ovoRZ19dqmpqVi1ahXGjBkDb29vDBkyBHv27OH2kUgkUKlUyMnJQWFhIZYuXQpnZ2coFAoEBATglVdeQWFhIXbs2IGPP/5Y43jz58+HTCaDubk5AGDo0KFIS+NvKFl/9NAGnLm5ObKzs3ul7ZKSEvzwww89DriioiJ8++23+PLLLyEWi1FbW4vm5mZs3boVALT+RVf797//jYKCAnzxxRewtraGUqlEVlYWFArFAzUqZMOGDYiJiYGvry8AoKysTOf2NTU1KCoq4oYaeXt7Izg4+IEKOG2f3b3u3r2LW7duwcHBAeXl5R3We3p64uOPP0ZjYyNCQkIwffp0ncdNTU3FmDFj+HoZ/R5dorZTX18Pf39/XL9+HUDbWcIXX3wBoO3sKCwsDAEBARr/KxYXF2PevHkIDg5GREQE6uvrkZaWhtzcXEgkEuTmdn+QdE1NDQYNGgSxWAwAGDx4MOzt7Q3e/6OPPsLatWu5YTVisRhxcXGwsrIC0HYGpnbixAkkJCQAABISEpCYmIj58+fDx8cH58+fx+rVq/HMM89w26ht2rQJoaGhiImJ4WaNqKiowKJFixAWFoaoqKhO/+G2J5PJNAZzu7m5AQAUCgVWr16NoKAghISE4Ny5cwCA2NhY3LlzBxKJBBkZGZDJZIiJicH8+fO51/Xuu+8iLCwMCxYsQHFxMfdaTp06BQCorKxEVFQUQkNDERoaikuXLgEAvv76ayxYsACMMchkMvj7+6Ompsbg91zN0M/umWee4f6O5OTkICAgoNP2BgwYgNGjR6OioqLLfXmo8fkUc3/i7u7OgoODuT/Hjh1jjDFWUFDA5s6dy3JyclhsbCy3/e+//84Ya3uaOjo6mpWUlDCFQsG8vb3Z5cuXGWOM1dfXs+bmZnbw4EGWlJTU4z42NDSw4OBg5ufnxxITE1lhYaHG+nPnzrG4uLhO962vr2eenp462x87diz38/Hjx9kbb7zBGGPsjTfeYMuXL2etra3s66+/ZuPGjWOlpaVMpVKx0NBQdvXqVcYYY66uriw7O5sxxlh6ejr3mp9//nkmlUoZY4z973//Y/Pnz2eMMZaWlsY++eSTDv04cOAAGz9+PFu0aBHLzMxkf/zxB2OMsU8//ZQlJCQwxtqekJ82bRprampiv/zyCwsICOD2nzFjBrtz5w73u6urK/v2228ZY4wtXbqULVy4kCmVSlZSUsKCg4MZY4w1NjaypqYmxhhjUqmUhYaGcvuvXLmS7dq1i8XFxbGjR4/qfA+10fbZRUdHc0/0z5gxg12/fp1FRkYyxhiTSCTs2rVr3Gtr//nW1tayGTNmsJ9++knr5x4dHc38/Py4v9ObNm3qVt+FhC5R7zFlyhScOHEC69at01h//PhxfPHFF2hpaUFNTQ3Ky8th
ZGQEW1tbbvCv+syIL5aWljh06BAuXLiAwsJCxMfHY+XKlQgLCzNofyMjI+7n/Px8pKamor6+HqmpqR3u1d1rxowZMDIygpubGx599FHurOqJJ55AVVUVRo4cCZFIhNmzZwNou3f0yiuv4O7duygqKsI//vEPri2lUqnzWOHh4Zg6dSry8/Nx6tQp7N27F0eOHMHFixcRHR0NAHBxcYGjoyOkUqne99nU1BRPP/00AMDV1RVisRimpqZwdXXlpt9qaWnBunXrUFpaCpFIpDHz7Ntvv43AwECMHTuWm9usq7R9dvd65JFHYG1tjWPHjsHFxYW7f6Z24cIFhISEQCQSYcmSJXjyySdRWFio9bh0iarpoQ04bVpbW1FeXg4zMzPI5XIMGTIEv/zyC3bs2IEDBw7gkUceQUJCAhQKBRhjGiHSG4yNjTFp0iRMmjQJrq6uOHz4sEEBZ2VlBQsLC/zyyy9wdnaGl5cXvLy88MILL3R6L0ihUGj8rr60MjIy4n4GAJFIhJaWlk6PaWRkBMYYrK2tu3x/097eHhEREYiIiEBgYCB++uknsG6OIlR/oaLur7r/IpEIKpUKAJCVlYVHH30U2dnZaG1t5f6TAoDq6mqIRCLcvn0bra2tEIm6dyens8+uM7Nnz8a6deuwcePGDuvU9+BI99A9uHtkZWXBxcUFW7ZswZtvvonm5mbcvXsXFhYWGDhwIG7fvo2zZ88CAEaMGAGZTIbi4mIAbVM7tbS0wNLSEnfv3u1xX65fv65xZlFSUqJ/eph24uLisHbtWm4SUcaYRpA9+uijKC8vR2trK/Ly8rrcv9bWVpw8eRJA2zTUEyZMgJWVFZycnHD8+HHumPomKz179iwXujU1NZDL5bC3t8fEiRNx9OhRAG1z/d+6dQsjRozosH933u/6+nrY2tpCJBIhOzubC76WlhasXr0a//znP+Hi4oLMzMwutavWlc/O19cXixYtwtSpU7t1LKLdQ3sG19TUBIlEwv3u5eWF8PBw7N+/H/v374eVlRUmTpyIDz/8EMuWLcOoUaMQEBAAZ2dn7vJOLBbjvffew/r169HU1ARzc3NkZmZi0qRJ2L59OyQSCV544QXuMq6rGhsbsX79etTV1cHY2BhDhw7FunXrDN4/KioKTU1NmDNnDsRiMSwtLTFu3DiMGjUKALBy5Uq88MILcHBwwJNPPonGxsYu9W/AgAG4du0awsLCYGVlhffffx8A8O6772Lt2rX48MMP0dLSgtmzZ8Pd3V1rO9999x02bNgAMzMzAMDrr78OW1tbREVFITExEUFBQTA2NsbGjRs1zibV5s6diyVLlsDW1ha7du0y+L159dVXceLECUyaNAkDBgwA0PbFjKenJzw9PeHu7o6IiAhMnz4dLi4uXXpvtH127S/d1aysrLr8iNJ//vMf7jIcAD744AMAwGuvvcZd5g4aNAhZWVldaldoaDYRQohg0SUqIUSwKOAIIYJFAUcIESwKOEKIYFHAEUIEiwKOECJYFHCEEMGigCOECNb/AeFJml5oytgFAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x192.24 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import math\n",
    "import seaborn as sns\n",
    "sns.set_theme(style=\"white\")\n",
    "\n",
    "palette = sns.cubehelix_palette(as_cmap=False)  # build the palette once instead of once per color\n",
    "sanae_colors = [palette[i] for i in range(6)]\n",
    "\n",
    "gradients = {'Exact': exact_gradient, 'ST Gumbel Softmax': STGS_grads, 'SIMPLE':SIMPLE_grads}\n",
    "x = ['Exact', 'ST Gumbel Softmax', 'SIMPLE']\n",
    "\n",
    "# Bias: cosine distance between each estimator's mean gradient and the exact mean gradient.\n",
    "# .item() turns the 0-dim tensors into plain floats before they are handed to matplotlib.\n",
    "exact_mean = exact_gradient.mean(axis=0)\n",
    "bias = [(1.0 - F.cosine_similarity(exact_mean, gradients[estimator].mean(axis=0), dim=0)).item() for estimator in x]\n",
    "\n",
    "fig, ax1 = plt.subplots(figsize=(4, 2.67))\n",
    "lns1 = ax1.bar(x=[0, 3, 6], height=bias, alpha=0.75, edgecolor=\"k\", lw=1.5, color=sanae_colors[1], label='Bias')\n",
    "ax1.grid(axis=\"y\")\n",
    "# ticks sit between each estimator's bias bar and its variance bar\n",
    "ax1.set_xticks([0.5, 3.5, 6.5])\n",
    "ax1.set_xticklabels(x, fontsize=10)\n",
    "ax1.set_ylabel('Bias')\n",
    "\n",
    "# second y-axis: variance is on a different scale from bias\n",
    "ax2 = ax1.twinx()\n",
    "\n",
    "# Variance: spread, across samples, of each gradient's cosine distance to its estimator's mean.\n",
    "variance = []\n",
    "for estimator in x:\n",
    "    mu = gradients[estimator].mean(axis=0)\n",
    "    variance += [(1.0 - F.cosine_similarity(gradients[estimator], mu)).var().item()]\n",
    "\n",
    "lns3 = ax2.bar(x=[1, 4, 7], height=variance, alpha=0.75, edgecolor=\"k\", lw=1.5, color=sanae_colors[4], label='Variance')\n",
    "\n",
    "ax2.set_ylim([0.0, 0.14])\n",
    "ax2.set_ylabel('Variance')\n",
    "ax2.legend([lns1, lns3], ['Bias', 'Variance'],loc=1)\n",
    "\n",
    "fig.savefig('bias_variance_SIMPLE_STGS.pdf',bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
