{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def plot_categorical(dist: torch.distributions.Categorical):\n",
    "    x = np.arange(dist.probs.shape[0])\n",
    "    y = dist.probs.numpy()\n",
    "    plt.bar(x, y)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def gumbel_max(logits, gumbel=None):\n",
    "    if gumbel is None:\n",
    "        gumbel = -torch.empty_like(logits).exponential_().log()\n",
    "    return (logits + gumbel).argmax(), gumbel\n",
    "\n",
    "\n",
    "def gumbel_max_rejection_sampling(probs, observation, max_iterations: int=10000) -> tuple[int, torch.Tensor]:\n",
    "    out_ = None\n",
    "    steps = 0\n",
    "    while out_ != observation:\n",
    "        out_, g = gumbel_max(torch.log(probs))\n",
    "        print(out_, g)\n",
    "        steps += 1\n",
    "        if steps > max_iterations:\n",
    "            print(\"Max iterations reached\")\n",
    "            return None\n",
    "    return out_, g \n",
    "\n",
    "\n",
    "def plot_PA(P_A):\n",
    "    # plot P(A)\n",
    "    pa_values = [p.cpu() for p in list(P_A.values())]\n",
    "    plt.bar(range(len(P_A)), pa_values) \n",
    "    plt.xticks(range(len(P_A)), list(P_A.keys()))\n",
    "    print(pa_values)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def estimate_arm_probs(control_probs, env):\n",
    "    P_A = {\n",
    "        k: 0.0 for k in env.arm_keys\n",
    "    }\n",
    "    for arm_index, arm in enumerate(env.arm_keys):\n",
    "        for c_index, c_prob in enumerate(control_probs):\n",
    "            # estimate p_arm_given_index (we actually have it in this case)\n",
    "            P_A[arm] += c_prob * env.arm_given_index[c_index].probs[arm_index]\n",
    "    return P_A\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Action Shift\n",
    "\n",
    "Each control is associated with an arm, but this relationship changes depending on the state. \n",
    "An analogy for this is that controls here are directions (up right down left), while abstractions are cardinal directions, or locations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAIjCAYAAAAUdENlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUEUlEQVR4nO3deXhTddr/8c9Jl3RfaWmBsqPsgiCI4AACsoiCjsrgOAg66ozihjMKXiMu8wiOgz64jaI8A+LPBXDfUBFEHBQVEFQEBhQospVCaUv3Nuf3R2wgXWjSpjlJ+35dVy7IyUly0wU+3L3P92uYpmkKAAAACCI2qwsAAAAAvEWIBQAAQNAhxAIAACDoEGIBAAAQdAixAAAACDqEWAAAAAQdQiwAAACCDiEWAAAAQYcQCwAAgKBDiAXgsalTp6p9+/ZWlwHIMAzdf//9VpcBwEKEWCDILF68WIZh1Hpbv3691SU2K3PmzNFbb71ldRl+V1hYqPvvv19r1qyxuhQAzVSo1QUAqJ8HH3xQHTp0qHa8c+fOjfaezz//vBwOR6O9fjCaM2eOLr/8ck2cONHqUvyqsLBQDzzwgCRp2LBhfn//oqIihYbyTxjQnPE3ABCkxo4dq/79+/v1PcPCwuo8p7y8XA6HQ+Hh4X6oCM1VRESE1SU0moKCAkVHR1tdBhDwGCcAmqg9e/bIMAzNmzdPzz33nDp16iS73a5zzjlH33zzjeu8efPmyTAM7d27t9przJo1S+Hh4crJyZFUfSb21PeYP3++6z1+/PFHSdLq1at1/vnnKzo6WgkJCZowYYK2bdvm9h7333+/DMPQrl27NHXqVCUkJCg+Pl7Tpk1TYWGh27mGYWj69Olavny5unfvrsjISA0aNEjff/+9JGnBggXq3LmzIiIiNGzYMO3Zs6fan+mrr77SmDFjFB8fr6ioKA0dOlTr1q2rV02GYaigoEAvvPCCa5xj6tSptX5O1qxZI8MwtHTpUt1zzz1KS0tTdHS0LrnkEu3bt8/t3M8//1xXXHGF2rZtK7vdroyMDN1xxx0qKipynbNo0SIZhqFvv/222nvNmTNHISEh2r9/v+vY8uXL1a9fP0VGRqpFixa6+uqr3R6XnF3Vmjqrp37u9+zZo5SUFEnSAw884PqzV86oHjp0SNOmTVObNm1kt9uVnp6uCRMm1Pj5qKrycxsREaGePXvqzTffrHEW+9T3e+2112QYhj777LNqr7dgwQIZhqEffvjBdWz79u26/PLLlZSUpIiICPXv31/vvPOO2/Mqx3bWrVunGTNmKCUlRdHR0br00kt15MiROv8c3333naZOnaqOHTsqIiJCaWlpuvbaa3X06FG38yq/1n788UddddVVSkxM1JAhQyRJ7du31/jx47VmzRr1799fkZGR6tWrl2uE44033lCvXr0UERGhfv361fh1ADRldGKBIJWbm6vs7Gy3Y4ZhKDk52e3Yyy+/rPz8fN14440yDEOPPPKILrvsMv38888KCwvTlVdeqbvuukvLli3TX//6V7fnLlu2TBdeeKESExNPW8uiRYtUXFysG264QXa7XUlJSfrkk080duxYdezYUffff7+Kior05JNPavDgwdq0aVO1UHLllVeqQ4cOmjt3rjZt2qSFCxcqNTVV//jHP9zO+/zzz/XOO+/o5ptvliTNnTtX48eP11133aV//etfuummm5STk6NHHnlE1157rVavXu167urVqzV27Fj169dP9913n2w2mxYtWqQLLrhAn3/+uQYMGOBVTS+++KL++Mc/asCAAbrhhhskSZ06dTrtx0qSHnroIRmGobvvvltZWVmaP3++Ro4cqc2bNysyMlKSM8wVFhbqz3/+s5KTk/X111/rySef1C+//KLly5dLki6//HLdfPPNeumll9S3b1+393jppZc0bNgwtW7dWpIzlE2bNk3nnHOO5s6dq8OHD+vxxx/XunXr9O233yohIaHO
uiulpKTomWee0Z///GddeumluuyyyyRJvXv3liT99re/1datW3XLLbeoffv2ysrK0sqVK5WZmXnaCwPff/99TZo0Sb169dLcuXOVk5Oj6667zvVnqM1FF12kmJgYLVu2TEOHDnV7bOnSperRo4d69uwpSdq6dasGDx6s1q1ba+bMmYqOjtayZcs0ceJEvf7667r00kvdnn/LLbcoMTFR9913n/bs2aP58+dr+vTpWrp06WlrWrlypX7++WdNmzZNaWlp2rp1q5577jlt3bpV69evl2EYbudfccUV6tKli+bMmSPTNF3Hd+3apauuuko33nijrr76as2bN08XX3yxnn32Wd1zzz266aabJDm/D6688krt2LFDNhv9KTQTJoCgsmjRIlNSjTe73e46b/fu3aYkMzk52Tx27Jjr+Ntvv21KMt99913XsUGDBpn9+vVze5+vv/7alGQuWbLEdeyaa64x27VrV+094uLizKysLLfn9+nTx0xNTTWPHj3qOrZlyxbTZrOZU6ZMcR277777TEnmtdde6/b8Sy+91ExOTnY7Vvln3L17t+vYggULTElmWlqamZeX5zo+a9YsU5LrXIfDYXbp0sUcPXq06XA4XOcVFhaaHTp0MEeNGlWvmqKjo81rrrnG9MSnn35qSjJbt27tVuuyZctMSebjjz/uVldVc+fONQ3DMPfu3es6NnnyZLNVq1ZmRUWF69imTZtMSeaiRYtM0zTN0tJSMzU11ezZs6dZVFTkOu+9994zJZmzZ892HRs6dKg5dOjQau9d9XN/5MgRU5J53333uZ2Xk5NjSjL/+c9/1vnxqKpXr15mmzZtzPz8fNexNWvWmJLc3ts0zWrvPXnyZDM1NdUsLy93HTt48KBps9nMBx980HVsxIgRZq9evczi4mLXMYfDYZ533nlmly5dXMcqv89Gjhzp9vVyxx13mCEhIebx48dP+2ep6fP3yiuvmJLMtWvXuo5Vfq1Nnjy52vnt2rUzJZlffPGF69hHH31kSjIjIyPdvg4qvw8+/fTT09YFNCX8dw0IUk8//bRWrlzpdluxYkW18yZNmuTWST3//PMlST///LPbORs3btRPP/3kOrZ06VLZ7XZNmDChzlp++9vfun68LEkHDx7U5s2bNXXqVCUlJbmO9+7dW6NGjdIHH3xQ7TX+9Kc/ud0///zzdfToUeXl5bkdHzFihFs3b+DAga4aYmNjqx2v/HNu3rxZO3fu1FVXXaWjR48qOztb2dnZKigo0IgRI7R27dpqF615WpO3pkyZ4lbr5ZdfrvT0dLePS2VHVnLOSGZnZ+u8886TaZpuPzaeMmWKDhw4oE8//dR17KWXXlJkZKR++9vfSpI2bNigrKws3XTTTW6zpBdddJG6du2q999/v0F/nlNFRkYqPDxca9ascY2heOLAgQP6/vvvNWXKFMXExLiODx06VL169arz+ZMmTVJWVpbbagmvvfaaHA6HJk2aJEk6duyYVq9erSuvvFL5+fmur4GjR49q9OjR2rlzZ7XxihtuuMGta3r++eeroqKixvGbU536+SsuLlZ2drbOPfdcSdKmTZuqnV/1a61S9+7dNWjQINf9yq/rCy64QG3btq12/NTva6CpI8QCQWrAgAEaOXKk22348OHVzjv1HzpJrkB7asC44oorZLPZXD8iNU1Ty5cv19ixYxUXF1dnLVVXSaj8B/7MM8+sdm63bt1c4dHbOms6Lz4+XpKUkZFR4/HK5+/cuVOSdM011yglJcXttnDhQpWUlCg3N7deNXmrS5cubvcNw1Dnzp3dZkYzMzNd/wmIiYlRSkqK60flp9Y5atQopaen66WXXpIkORwOvfLKK5owYYIrKJ/u89G1a9c6A5k37Ha7/vGPf2jFihVq2bKlfvOb3+iRRx7RoUOHTvu8yhpqWl3DkxU3KuecT/0x/9KlS9WnTx+dccYZkpw/mjdNU/fee2+1r4H77rtPkpSVleX2uvX9Gjh27Jhuu+02
tWzZUpGRkUpJSXF9n1T9OpOqfw/V9v6efr0DzQEzsUATFxISUuNx85S5u1atWun888/XsmXLdM8992j9+vXKzMysNo9am1O7To1Z5+nOq+v5lV3Wf/7zn+rTp0+N557aAfSmJl+rqKjQqFGjdOzYMd19993q2rWroqOjtX//fk2dOtWtYxwSEqKrrrpKzz//vP71r39p3bp1OnDggK6++up6vbdhGDX++SoqKjx+jdtvv10XX3yx3nrrLX300Ue69957NXfuXK1evbra7K6v2O12TZw4UW+++ab+9a9/6fDhw1q3bp3mzJnjOqfy4/aXv/xFo0ePrvF1qgbm+n4NXHnllfriiy/017/+VX369FFMTIwcDofGjBlT4zJ1tX0P1ffrHWgOCLEAJDl/HHvTTTdpx44dWrp0qaKionTxxRfX67XatWsnSdqxY0e1x7Zv364WLVr4fQmhyguu4uLiNHLkSJ+9btULdDxR2RWuZJqmdu3a5bow6vvvv9d///tfvfDCC5oyZYrrvJUrV9b4elOmTNGjjz6qd999VytWrFBKSopbSDv183HBBRe4PXfHjh2uxyVnp7GmH0lX7dbW9efu1KmT7rzzTt15553auXOn+vTpo0cffVT/7//9vxrPr6xh165d1R6r6VhNJk2apBdeeEGrVq3Stm3bZJqma5RAkjp27CjJuVScL78GqsrJydGqVav0wAMPaPbs2a7jVT/vABqGcQIAkpwzpSEhIXrllVe0fPlyjR8/vt5BMz09XX369NELL7yg48ePu47/8MMP+vjjjzVu3DgfVe25fv36qVOnTpo3b55OnDhR7XFPlk2qSXR0tNuf0RNLlixRfn6+6/5rr72mgwcPauzYsZJOdtlO7aqZpqnHH3+8xtfr3bu3evfurYULF+r111/X7373O7eNAPr376/U1FQ9++yzKikpcR1fsWKFtm3bposuush1rFOnTtq+fbvbx2PLli3VliGLioqSpGp/9sLCQhUXF7sd69Spk2JjY93eu6pWrVqpZ8+eWrJkidvn57PPPnMtoVaXkSNHKikpSUuXLtXSpUs1YMAAtx/Tp6amatiwYVqwYIEOHjxY7fn1/RqoqqbPnyTNnz/fJ68PwIlOLBCkVqxYoe3bt1c7ft5557k6Tt5ITU3V8OHD9dhjjyk/P9+tg1Uf//znPzV27FgNGjRI1113nWuJrfj4eEv2vLfZbFq4cKHGjh2rHj16aNq0aWrdurX279+vTz/9VHFxcXr33Xe9ft1+/frpk08+0WOPPaZWrVqpQ4cOrotsapOUlKQhQ4Zo2rRpOnz4sObPn6/OnTvr+uuvl+ScU+3UqZP+8pe/aP/+/YqLi9Prr79+2nnHKVOm6C9/+YskVRslCAsL0z/+8Q9NmzZNQ4cO1eTJk11LbLVv31533HGH69xrr71Wjz32mEaPHq3rrrtOWVlZevbZZ9WjRw+3C9oiIyPVvXt3LV26VGeccYaSkpLUs2dPlZeXa8SIEbryyivVvXt3hYaG6s0339Thw4f1u9/97rQflzlz5mjChAkaPHiwpk2bppycHD311FPq2bNnjf/xqCosLEyXXXaZXn31VRUUFGjevHnVznn66ac1ZMgQ9erVS9dff706duyow4cP68svv9Qvv/yiLVu21Pk+dYmLi3PNApeVlal169b6+OOPtXv37ga/NoCT6MQCQWr27Nn6wx/+UO22du3aer/mpEmTlJ+fr9jY2AZ3S0eOHKkPP/xQycnJmj17tubNm6dzzz1X69atq/UilsY2bNgwffnll+rfv7+eeuop3XLLLVq8eLHS0tLcgpw3HnvsMfXr109/+9vfNHnyZD3zzDN1Pueee+7RRRddpLlz5+rxxx/XiBEjtGrVKld3MywsTO+++6769OmjuXPn6oEHHlCXLl20ZMmSWl/z97//vUJCQnTGGWdUW+9Wcm5WsHTpUpWWluruu+/WggULdOmll+o///mP2xqx3bp105IlS5Sb
m6sZM2bonXfe0Ysvvqizzz672msuXLhQrVu31h133KHJkyfrtddeU0ZGhiZPnqw1a9Zo1qxZmjVrlvLy8rRs2TLXagm1ufjii/XKK6+otLRUM2fO1BtvvKHFixfrzDPP9HiHrkmTJrkC75VXXlnt8e7du2vDhg266KKLtHjxYt1888169tlnZbPZ3H7031Avv/yyRo8eraefflqzZs1SWFhYjauHAKg/w2QKHAD8Ys2aNRo+fLiWL1+uyy+/3KevnZ2drfT0dM2ePVv33nuvT1/ban369FFKSkqtM8EAmic6sQDQBCxevFgVFRX6wx/+YHUp9VZWVqby8nK3Y2vWrNGWLVtq3AoXQPPGTCwABLHVq1frxx9/1EMPPaSJEyeedlvXQLd//36NHDlSV199tVq1aqXt27fr2WefVVpaWq2bAQBovgixABDEHnzwQX3xxRcaPHiwnnzySavLaZDExET169dPCxcu1JEjRxQdHa2LLrpIDz/8sJKTk60uD0CAYSYWAAAAQYeZWAAAAAQdQiwAAACCTrOaiXU4HDpw4IBiY2PrtVUkAAAAGpdpmsrPz1erVq1ks9Xeb21WIfbAgQPKyMiwugwAAADUYd++fWrTpk2tjzerEBsbGyvJ+UGJi4uzuBoAAABUlZeXp4yMDFduq02zCrGVIwRxcXGEWAAAgABW1+gnF3YBAAAg6BBiAQAAEHQIsQAAAAg6hFgAAAAEHUIsAAAAgg4hFgAAAEGHEAsAAICgQ4gFAABA0CHEAgAAIOgQYgEAABB0CLEAAAAIOoRYAAAABB1CLAAAAIIOIRYAAABBhxALAACAoEOIBQAAQNAhxAIAACDoEGIDQHFZhdUlAAAABBVCbAA4VlBqdQkAAABBhRAbAEyrCwAAAAgyhNgAYJrEWAAAAG8QYgMAGRYAAMA7hNgAQIgFAADwDiE2ADhIsQAAAF4hxAYAIiwAAIB3CLEBgE4sAACAdwixAYAMCwAA4B1CbABgiS0AAADvEGIDgIMMCwAA4BVCbAAwubQLAADAK4TYAOBwWF0BAABAcCHEBgBWJwAAAPAOITYAkGEBAAC8Q4gNAHRiAQAAvEOIDQCEWAAAAO8QYgMAIRYAAMA7hNgAwDqxAAAA3iHEBgAHKRYAAMArhNgAQIYFAADwTtCE2GeeeUa9e/dWXFyc4uLiNGjQIK1YscLqsnyCmVgAAADvBE2IbdOmjR5++GFt3LhRGzZs0AUXXKAJEyZo69atVpfWYBW0YgEAALwSanUBnrr44ovd7j/00EN65plntH79evXo0cOiqnyDTiwAAIB3gibEnqqiokLLly9XQUGBBg0aVOt5JSUlKikpcd3Py8vzR3leI8QCAAB4J2jGCSTp+++/V0xMjOx2u/70pz/pzTffVPfu3Ws9f+7cuYqPj3fdMjIy/Fit5yocVlcAAAAQXAzTDJ42YGlpqTIzM5Wbm6vXXntNCxcu1GeffVZrkK2pE5uRkaHc3FzFxcX5q+w6ffjDIY3pmWZ1GQAAAJbLy8tTfHx8nXktqMYJwsPD1blzZ0lSv3799M033+jxxx/XggULajzfbrfLbrf7s8R6YZwAAADAO0E1TlCVw+Fw67QGK1YnAAAA8E7QdGJnzZqlsWPHqm3btsrPz9fLL7+sNWvW6KOPPrK6tAajEwsAAOCdoAmxWVlZmjJlig4ePKj4+Hj17t1bH330kUaNGmV1aQ1WXkGIBQAA8EbQhNj/+7//s7qERlNhmjJNU4ZhWF0KAABAUAjqmdimwjRNMVEAAADgOUJsAHCYzMUCAAB4gxAbABymKRYoAAAA8BwhNgDQiQUAAPAOITYAMBMLAADgHUJsAHA4TDqxAAAAXiDEBgDGCQAAALxDiA0AXNgFAADgHUJsADBN50gBAAAAPEOIDQAVJjOxAAAA3iDEBgDGCQAAALxDiA0ArE4AAADgHUJsAHCYUgWtWAAAAI8RYi1m/joPSycWAADAc4RYi5lm5eoEVlcC
AAAQPAixFqv4tQNbQScWAADAY4RYi1XOwjJOAAAA4DlCrMUqwyubHQAAAHiOEGuxyk4s4wQAAACeI8RarPKCrvIKQiwAAICnCLEWq+zAMhMLAADgOUKsxVzjBMzEAgAAeIwQazEHnVgAAACvEWItVv5rB5aZWAAAAM8RYi3mYHUCAAAArxFiLVbZiWXbWQAAAM8RYi1W8Wt6LSfFAgAAeIwQa7EKR+WvjBMAAAB4ihBrsXJXJ5YQCwAA4ClCrMVYJxYAAMB7hFiLuZbYIsQCAAB4jBBrsZOdWC7sAgAA8BQh1mKVmxyw2QEAAIDnCLEWq7ywi5lYAAAAzxFiLVY5C1tGiAUAAPAYIdZilWMEFRXMxAIAAHiKEGuxCtaJBQAA8Boh1mJlFSyxBQAA4C1CrMVcO3YxTgAAAOAxQqzFKjuxZRWmTJNuLAAAgCcIsRY7dX3YMtaKBQAA8Agh1mLlp+zUVc6uXQAAAB4hxFqsjE4sAACA1wixFis75YIuLu4CAADwDCHWYqcGVzqxAAAAniHEWsg0zSrjBHRiAQAAPEGItVDVzishFgCAurEkJSRCrKWqhlZ27QIAAPAMIdZCVUNsaTmdWAAAAE8QYi1UWiXEMk4AAADgGUKsharOxNKJBQAA8Awh1kJl5VU7sczEAgBQF1P8ewlCrKUYJwAAwHsOk38vQYi1VNXxgaqhFgAAVMcSW5AIsZaqGlqZiQUAoG4VZoXVJSAAEGItVHUmlhALAEDdyhxlVpeAAECItVC1TizjBEBg4UeWQEAqLi+2ugQEAEKshap2XrmwCwCAuuWV5lldAgIAIdZCVUNsSRkhFgCA03GYDh0tOmp1GQgAhFgLMU4AAIB38krydKz4mNVlIAAQYi1UUsOFXSwbAgQQvh+BgHO48LAOFx7m30sET4idO3euzjnnHMXGxio1NVUTJ07Ujh07rC6rQaqG2NqOAbAK/0gCgean4z+poKxAhwoOWV0KLBY0Ifazzz7TzTffrPXr12vlypUqKyvThRdeqIKCAqtLq7ealtQixAIBhE4PEFCyi7K1N2+vJGlj1ka6sc1cqNUFeOrDDz90u7948WKlpqZq48aN+s1vfmNRVQ1TU4hlrVgAAKorc5RpdeZqmb/+hCQzL1Pbj21Xt+RuFlcGqwRNiK0qNzdXkpSUlFTrOSUlJSopKXHdz8sLnCU5TNNUSXn1HUdqOgbAIuzPDgSE0opSfbjnQ2UXZbsdX/PLGoXZwtQ5sbNFlcFKQTNOcCqHw6Hbb79dgwcPVs+ePWs9b+7cuYqPj3fdMjIy/Fjl6ZVWOGr8SSWdWCCAEGIByx0uOKzXd76uX/J/qfaYaZr6eO/H+mL/F+zi1QwFZSf25ptv1g8//KD//Oc/pz1v1qxZmjFjhut+Xl5ewATZ2mZfmYkFAgghFrBMXmmeNh7aqO3HtrtGCGqz+chm/ZT7k85JO0dnJJ4hmxGUPTp4KehC7PTp0/Xee+9p7dq1atOmzWnPtdvtstvtfqrMO7VtbFBcxjgBEDBMvh8BfzJNUwcKDmhr9lb9lPuTVxdu5Zfma3Xman1z6Bv1SO6hrkldFRUW1YjVwmpBE2JN09Qtt9yiN998U2vWrFGHDh2sLqlBapt9ZZwACCCOcqsrAJo8h+lQVmGWfj7+s3Yd36UTZSca9Hr5pflaf3C9vjr0ldrEtFHnhM5qF9eOQNsEBU2Ivfnmm/Xyyy/r7bffVmxsrA4dcq4PFx8fr8jISIur815tYwPFhFggcFQwYwf4mmmayi/L14ETB7Qvf5/25e9TcXlxo7xP5etLUkpUitrGtlXrmNZqGd1SYbYwn78n/CtoQuwzzzwjSRo2bJjb8UWLFmnq1Kn+L6iBahsnKGGcAAgcFaVWVwAEvXJHubKLspVVmKXDhYd18MTBBndb6+NI4REdKTyijYc3ymbYlBqVqrToNLWMaqnUqFTFhMXIMAy/14X6C5oQ29QWNC6uZZyAC7uAAFLm++4Q0JSV
VpTqaNFRZRdlO2/F2TpadFSOALtI0mE6dKjgkNuuX1GhUUqJSlFyZLJaRLRQi8gWirfHE2wDWNCE2KaGC7uAIFDq/24REAzKKsqUU5KjY8XHdKz4mHKKnb/PL823urR6Kywv1N68va4dwSQp1BaqxIhEJUUkKSkiSYn2RCVFJik2LJZwGwAIsRaprRPLTCwQQEqC9x9koKEcpkP5pfnKLclVTkmO89fiHB0vOa6CsuDd8t0b5Y5y1xjCqUJtoYoPj1d8RLwS7YlKsCco3h6vBHuCIkIjLKq2+SHEWqS22VdmYoEAUnzc6gqARlXuKFd+ab7ySvOUW5LrvJU6f80rzWtyo3y+Uu4o19HiozpafLTaY/YQu+Lt8c5beLzr93HhcYoMjaSD60OEWIuw2QEQ4ExTKjhS93lAADNNU0XlRcorzXPeSvJcgTWvNE+FZYV1biQA75RUlCirMEtZhVnVHgu1hSouPE5x4XGuYBsbHqs4u/NXVkzwDiHWIkWlta8TW+EwFWLjf2qApcpLpMJjzjBL5wQBrKyi7GRILc1zdlZLTt4vZ73jgFHuKHfNEdckOizaFXIrw23l/eiwaLq4VRBiLXK6C7iKyyoUbedTA1iqONe5TmxZoRQebXU1aMZM01RheaHySvKUW5rrFlBzS3JVVF5kdYnwkYKyAhWUFehgwcFqj9kMm2LDY10d3LjwOMXZ4xQfHq84e1yz7OKSlCxyugu4CLFAACj8ddat8BghFo2uMqgeLzl+cjb1lBlVuqlwmA7X10RNKru4lReZnXprqgGXpGQB0zRrXWJLYoUCICCc+HX9yBOHpYQMa2tBk2Gapk6UnVB2UbZyinOUU5Kj48XHlVOSo1I210AD1NbFNWQoNjxWCREJSrInOX+NSFJyRLLCQoI73BJiLVBS7pDjNFd81jYvC8CPcvf/+usvUpv+1taCoGSapvJK83So4JCyCrOUXZSto8VHCavwK1Oma/wkU5mu44YMxdnjlByZrJTIFLWMaqmWUS2DKtgSYi1Q14YGbHgAWKysWMo74Px9zh7J4ZBsNktLQnAoLi/W3ry92p23W4dOHFJheaHVJQE1MmW6xhN+Pv6zJMkwDLWIbKHWMa3VMb6jWka1DOiLyQixFig+zSiB83FCLGCpozulym0yy4qk3Ewpsb2lJSGw5Zbkav3B9fo592fWVkXQMk3TtbnD5qzNig6LVr+W/dQ9ubtsRuD9R54Qa4G6O7HMxAKWOvhd9fuEWNTCNE29sfMNVglAk1NQVqC1v6xVuaNcfVL7WF1ONYEXq5uBojpCbF2PA2hEeQel45nux45sdy65BdTAMAwNSB+g8JBwq0sBfC49Ol1dErtYXUaN6MRagBALBCjTlHZ/Vv24o0La8x+p60X+rwlBoUdyD52RcIZ2Hd+ln3J/0sETB1XmKLO6LKBe4u3xyojNULekbkqJSrG6nFoRYi1QXMfqA3U9DqCRHNkuHdtd82MHv5PSerPcFmoVFhKmbsnd1C25myocFcoqzNKBggM6UnhEWYVZOlF2wuoSgWpCjBAlRyYrNSpVqVGpah3TWrHhsVaX5RFCrAXoxAIBqDhX+u9Hpz9n+3tSv2lSWIR/akLQCrGFKD0mXekx6a5jhWWFyi7Kdm07eqz4mHKKc+jYwm9iw2OVGJHoWic2KSJJSRFJCrGFWF1avRBiLVBXSGV1AsDPKsqkH95wrkRwOkXHpW3vSD0vZ8kteC0qLEptw9qqbVxb17HKzQ8qd+rKKc5x/T6/NF+mWOkA3gkPCVe8PV4J9gTX7l2J9kTF2+Ob3Nw2IdYCdW1mUFzmkGmaAb02G9BkVJQ7A2z+Ic/OP/qTtOMD53ws36NoIMNw7qYUGx6rjFj3UZUKR4XyS/OVW5qrvJI85ZY61/TMK81TXkmeKkwaHs1VVGiU4uxxig+PV5w9TnHhcYq3xysuPE6RoZHNJj8QYi1QV6fVYZoqKXcoIiw42/tA0Cgvlba+KR372bvn
HfpekimdOU4K0h/DIfCF2EKUEJGghIiEao+ZpqnC8kLlleS5dmM69fcFZQX+Lxg+E2oLVWx4rOLCnQG1MqhW3oJpV63GRIi1gCczr0WlFYRYoDEV50k/vO55B7aqQz9IpQVS9wlSWKRvawPqYBiGosOiFR0WrXSlV3u83FGu/NJ8V7h1/b40T7kluczhWsyQ8/NXGU5jw2NdndTm1k1tCEKsn5mmqaLSujczKCyrUKIf6gGapZw90o/vOENoQxzbLW18QeoxUYpN80VlgE+E2kKVGJGoxIjq/5KYpqniimLXiMKpv9LF9Z0QI8TtR/6n/hoTHqNQGxGsofgI+llJuUMOD7YkrGtuFkA9VJRLe9ZK+752rgnrC0U50qYlUvvzpYyBXPCFgGcYhiJDIxUZGqmW0S2rPV7mKHPO3pbk6XjJcdeFZsdLjrMrWRU2w6a48LiTF1FFxCs+3HlRVXRYNN3URkaI9TNPwykrFAA+djxT2vGhVHjU96/tqJB+XiNl75DOGCvFVg8GQLAIs4WpRWQLtYhsUe2xkooS1woKx4qP6XjxceWU5CivJK9Jr6QQHhKuRHuiq7udYE9Qoj1RcfY42Qz+42oVQqyfeboGbCGdWMA3ivOcAfPw1sZ/r7yD0sZFUquzpQ7nMyuLJsceYldadJrSot3HZ8ocZcopztGx4mM6WnRU2UXZyi7KVklFiUWV1o8hQ/H2eKVEpZxcRzUySbFhsXRVAxAh1s88DadseAA0UFmxtO8r6ZevnWME/mKa0v6N0uEfpHaDpdZnS1xJjCYuzBbm2vGpkmmayivNU3ZRtrIKs3S48LCyCrNU7vDj92MdYsNjlRad5qw9MlUtIltw5X8QIcT6madjAkWlgfNNDgSV8lLpwCYp80tnkLWsjhLpp9XSL99I7c6T0s9iOS40K4bh7GrG2+PVKaGTJOfat0eLj+rAiQP65cQvOnjioF9XSkiwJ6hNbBu1jmmttOg0RYdF++294XuEWD+jEws0krJiZ3j95RuptNDqak4qyXduZ5v5pZRxrpTem84smq0QW4irY9sntY8qHBU6XHhYu3N3a3fubuWV5vn0/WyGTa1jWqtjfEe1jWur2PBYn74+rEWI9TNmYgEfKznhDK4HvnV2PwNVcZ6082Np73+kNudIrfoyM4tmL8QWolYxrdQqppXOa3WejhQd0fZj2/XfnP+qtKK03q/bMqqluiZ1VefEzrKH2H1YMQIJIdbPPB0TYIktoA4nspzh9fBW5+oAwaK0UPr5M2nvF84Rg9b9pKgkq6sCLGcYhqtLOyh9kLYd26Zvs771at3aDvEd1De1b7ULz9A0EWL9zONxgtIKmabJ1ZDAqUxTOvqTM7zm7LG6moapKJN+2eC8CCy5s7M7m9BW4nseUFhImHqn9Fa3pG7alLVJm7I2yTzN2s4J9gQNyximVjGt/FglrEaI9TNPxwnKHaZKKxyyh3IhCKCyYunQd9L+Tc7NBZoS05SydzpvMSlS6/5Syx7MzQJyhtmB6QPVNratVuxZoeLy6hdrdknsomEZwxRm43umuSHE+pk3YwLFpYRYNHMnjjgv1jr0vbNz2dSdOCLtWCH9/Klz1KBVXymSDaiB9Jh0XdLpEr2580231Qy6JHbRyLYj+allM0WI9SPTNL0KsYVl5YoX/7NEM+NwSMd+cv6oPdhHBuqrrFjK/Mq5PW5yZ6lNfymhHaMGaNZaRLbQ+W3O1+rM1ZKkuPA4DcsYRoBtxgixflRa4VC5w/Nt+VihAM2Ka2Rgo1R03OpqAsOpowbRLZxhtmVPRg3QbJ2ZeKa+zfpWOcU56p/WnxGCZo4Q60ferjjACgVoFoqOS/s3SAe3ODcqQM0KsqUdHzpXNmh9tnNVg3AWakfzYhiGzkw8UxsOb3BtoIDmixDrR952VtnwAE3aiSPODQCytkmmw+pqgkdZkbRnnXPcIK2X1PZcKTLB6qoAv2kd01r78vfRhQUh1p+8DaWME6BJOnFE
2vO5dGSH1ZUEN0e5c4OHg1uktJ7OrW25CAzNQFJkkpIiWFsZhFi/8n6cwLONEYCgUJwr7V7r3JzgNOs9wkumQzr4nXToB+dqBu0HM2aAJi3MFqYWkS2sLgMBgBDrR4wToFlyVEj7vpL2rpMq+I9ZozEdzoviDv8gdRwqtTqb1QzQZCXYE6wuAQGAEOtHhV52VhknQNAryJZ+fNu5RSz8o7xE+u/HUtZ2qdt4KSLe6ooAn4sKi7K6BAQAm9UFNCfFXnZWWZ0AQS17p7RxMQHWKsczpQ2LpOP7rK4E8Dku6oJEiPUrr8cJSitOu1c0ELCO/Ff64Y3msctWICsrkr57Vcr9xepKAJ8KtfGDZBBi/crbEFvuMFVawdJDCDLFudK2d1g2K1BUlEtb33SOGQBNhM0gvoAQ61fejhNIjBQgCP3yDR3YQFNywrkUF9BEEGIhEWL9xjTNegVSVihA0Mk/bHUFqEn+IasrAHzGECtvgBDrN6UVDpU7vJ9vZYUCBB17jNUVoCZ8XgA0MYRYP6nvWADjBAg6ab2trgBVGTapZS+rqwAAnyLE+kl9xwIYJ0DQSerg3DkKgaP9ECkmxeoqAMCnCLF+QicWzUqXC6W0nlZXAUlqO1Bqd57VVQCAz7HQmp/Ud7aVTiyCks0mdR0vRbWQdq9luS0rhIQ6/zORfpbVlQA+Z7ClMkSI9Zv6LK8l0YlFEDMMqd0gKbG9tON96cQRqytqPhIypDPHSVFJVlcCAI2GEOsnzMSi2YpLl/pdKx3YJO35j3MXKTSOiDipw1CpZQ/nfyIAoAkjxPoJM7Fo1mw2qU1/qWVPad9XbIjga2GRUttBUuuzpRD2lAfQPBBi/YROLCApLELqOFRqc460f4P0ywa2Q22I8GgpY4BzNYhQu9XVAIBfEWL9pL4zsaXlDlU4TIXY+NEgmpDwKKnDb6SMgdKBzc7ObEm+1VUFj6gkZ3ht2ct5ARcANEP87ecnDRkLKCqrUIydTxWaoFC7cwmoNv2lI9udYTbvoNVVBa7EdlKbAVJyJ2ZeATR7JCM/KSqr/xJDxYRYNHW2EOfFSKndpbz9zjGDIztYmkuSbKHOj02b/lJMqtXVAEDAIBn5gcNhqqS8AZ1YLu5Cc2EYUnwb5604Tzq4WTrwrVRaaHVl/hcR77xQK623c/wCAOCm3iG2tLRUu3fvVqdOnRQaShY+nZJyh0yzIc8nxKIZiohzzs22Pc85arB/Q/MYNUhsJ7XuLyV3dq7qAACokdd/QxYWFuq6665TVFSUevTooczMTEnSLbfcoocfftjnBTYF9b2oq1JRKT9SRTMWEurcwrbfVOnsKVLL7pLRxMKdLVRq1Uc6549Sn6uklDMIsABQB6//lpw1a5a2bNmiNWvWKCIiwnV85MiRWrp0qU+LayoaukxWMZ1YwCm+tdR9gnTun527gYVF1P2cQGaPcS45Nuhm6cyxUkyK1RUBQNDwOsS+9dZbeuqppzRkyBC3vYt79Oihn376yafFVbV27VpdfPHFatWqlQzD0FtvvdWo7+crDe3ENvT5QJMTESd1HCade7PU5UIpMsHqirwT3ULqNl469yap3XnMvAJAPXg9zHrkyBGlpla/QragoMAt1DaGgoICnXXWWbr22mt12WWXNep7+VJxA1YmkLiwC6hVaLjUpp9zsf8j26XML6QTR6yuqnbxbZyhNakjS2QBQAN5HWL79++v999/X7fccoskuYLrwoULNWjQIN9WV8XYsWM1duxYj88vKSlRScnJ3YDy8vIao6w6NXQcoKScmVjgtGw256xsajcpe6e09z9S/mGrqzopIUNqP0RKaEd4BQAf8TrEzpkzR2PHjtWPP/6o8vJyPf744/rxxx/1xRdf6LPPPmuMGutt7ty5euCBB6wug3ECwF8Mw3lRVIsuznVmd6+VCo9aV09smnPmNbED4RUAfMzrmdghQ4Zo8+bNKi8vV69evfTxxx8rNTVVX375pfr169cYNdbbrFmz
lJub67rt27fPkjpKGjhOUEwnFvCOYUipXaVzrpPOuFAKi/Tv+0fESd0vca6owOgAADSKei3w2qlTJz3//PO+rsXn7Ha77Ha71WU0eJ3XEjqxQP3YQqTW/Zw7ge1e69w4oSGLNtfFsDm30W03WAoJa7z3AQB4H2I/+OADhYSEaPTo0W7HP/roIzkcDq9mVpuLhs60MhMLNFBYpHTGaOf2rdvflwqP+f49YtOkruNZJgsA/MTrcYKZM2eqoqJ6Z9A0Tc2cOdMnRTU1DZ1pLS13qMLRiN0joLmIbyP1v9a5sYCvGIZzxYGzpxBgAcCPvO7E7ty5U927d692vGvXrtq1a5dPiqrNiRMn3N5j9+7d2rx5s5KSktS2bdtGfe+G8EUntaS8QlHhbO8LNFhImHNjgfgMaccKyVFe/9cKtTs3X0ju5Lv6AAAe8boTGx8fr59//rna8V27dik6OtonRdVmw4YN6tu3r/r27StJmjFjhvr27avZs2c36vs2VEMv7JKc3VgAPpTWU+ozuf4XfUXES2dfQ4AFAIt4HWInTJig22+/3W13rl27dunOO+/UJZdc4tPiqho2bJhM06x2W7x4caO+b0M19MIu52sQYgGfi28j9fm99ztmRSZKfX8vRSc3Tl0AgDp5HWIfeeQRRUdHq2vXrurQoYM6dOigbt26KTk5WfPmzWuMGoOaw2GqrKLh86x0YoFGEpMi9Z7k+WoC4dHSWb9zdmIBAJbxesgyPj5eX3zxhVauXKktW7YoMjJSvXv31m9+85vGqC/olVb4Jnz6opsLoBaxac7Z1u9fO/15thCp52+lyAS/lAUAqF29rhQyDEMXXnihLrzwQl/X0+T4agyAcQKgkbXoIrXpL/2yofZzOvxGim/tv5oAALWqV4gtKCjQZ599pszMTJWWlro9duutt/qksKbCVx1UxgkAP+gw1LldbUl+9cdiUqU2A/xfEwCgRl6H2G+//Vbjxo1TYWGhCgoKlJSUpOzsbEVFRSk1NZUQW4WvwichFvCD0HBnt3X7+9Uf63SBZPP6MgIAQCPx+m/kO+64QxdffLFycnIUGRmp9evXa+/everXrx8XdtXAFxd1Sb6brQVQh5Y9pIg492Nx6VJie0vKAQDUzOsQu3nzZt15552y2WwKCQlRSUmJMjIy9Mgjj+iee+5pjBqDmq86qGWEWMA/bCFSWm/3Y636OnfmAgAEDK9DbFhYmGy//kgtNTVVmZmZkpyrFuzbt8+31TUBvgqfjBMAftSyx8nf20KkFmdaVwsAoEZez8T27dtX33zzjbp06aKhQ4dq9uzZys7O1osvvqiePXs2Ro1BzVdjAKU+GksA4IGoJOeGBkU5UlxrKSzC6ooAAFV43YmdM2eO0tPTJUkPPfSQEhMT9ec//1lHjhzRc8895/MCg12Zjzqo5YwTAP6VkPHrr22trQMAUCOvO7H9+/d3/T41NVUffvihTwtqanx1YRczsYCfxaZJB7+TYtOtrgQAUIN6rRMrSVlZWdqxY4ckqWvXrkpJSfFZUU1JmcNXF3YxTgD4VfSvf6dFJ1tbBwCgRl6PE+Tn5+sPf/iDWrduraFDh2ro0KFq1aqVrr76auXm5jZGjUGtnE4sEJwiEiTDJtnjra4EAFADr0PsH//4R3311Vd67733dPz4cR0/flzvvfeeNmzYoBtvvLExagxqvppl9VUYBuAhe6wUEc8GBwAQoLweJ3jvvff00UcfaciQIa5jo0eP1vPPP68xY8b4tLimoMzho06sj8YSAHjIMKToFlZXAQCohdcthuTkZMXHV//xWnx8vBITE31SVFPiq05sBZ1YwP+ikqyuAABQC69D7N/+9jfNmDFDhw4dch07dOiQ/vrXv+ree+/1aXFNQbmPOrHlDlOmSZAF/Ip5WAAIWF6PEzzzzDPatWuX2rZtq7ZtnesnZmZmym6368iRI1qwYIHr3E2bNvmu0iBV4aMQW/laoSFsfQn4TXi01RUAAGrhdYid
OHFiI5TRdPmqE1v5WqEhPns5AHVhpy4ACFheh9j77ruvMeposip8uDSWL7u6ADwQSogFgEBV780OJKm4uFhLly5VQUGBRo0apS5duviqribDl9djVTATC/hXSLjVFQAAauFxiJ0xY4bKysr05JNPSpJKS0t17rnn6scff1RUVJTuuusuffzxxzrvvPMardhg5PDlTCwrFAD+ZWN+BwAClcerE3z88ccaNWqU6/5LL72kzMxM7dy5Uzk5Obriiiv00EMPNUqRwcyX3VM6sYCf2Rr0wyoAQCPyOMRmZmaqe/furvsff/yxLr/8crVr106GYei2227Tt99+2yhFBjNfzrE6CLGAfxns1gUAgcrjv6FtNpvbOqXr16/Xueee67qfkJCgnJwc31bXBPhybVc27QL8zGCcAAAClcchtlu3bnr33XclSVu3blVmZqaGDx/uenzv3r1q2bKl7ysMcr5cUIBOLOBnBusyA0Cg8njg66677tLvfvc7vf/++9q6davGjRunDh06uB7/4IMPNGDAgEYpMpj5MngSYgF/I8QCQKDyuBN76aWX6oMPPlDv3r11xx13aOnSpW6PR0VF6aabbvJ5gcHMNE35MneSYQE/oxMLAAHLq0tvR4wYoREjRtT4GJsgVOfr0EmIBfyNEAsAgYpLbxuRr3/8b4oUCwAAIBFiG5WvIyedWMDPGCcAgIBFiA0iZFgAAAAnQiwAAACCTr1CbHl5uT755BMtWLBA+fn5kqQDBw7oxIkTPi0OACzFOAEABCyvNwbfu3evxowZo8zMTJWUlGjUqFGKjY3VP/7xD5WUlOjZZ59tjDoBAAAAF687sbfddpv69++vnJwcRUZGuo5feumlWrVqlU+LAwAAAGridSf2888/1xdffKHw8HC34+3bt9f+/ft9Vhiq4webAAAATl53Yh0OhyoqKqod/+WXXxQbG+uTopoKX4dOxvMAAACcvA6xF154oebPn++6bxiGTpw4ofvuu0/jxo3zZW1Bz/Bx6jToxQIAAEiqxzjBo48+qtGjR6t79+4qLi7WVVddpZ07d6pFixZ65ZVXGqPGoEUnFgAAoHF4HWLbtGmjLVu26NVXX9V3332nEydO6LrrrtPvf/97twu9QOgEAABoLF6HWEkKDQ3V1Vdf7etamhzDMGQzDDl8tF+szUYqBgAAkDwMse+8847HL3jJJZfUu5imyGZIDh/tFxtCaxcAAECShyF24sSJHr2YYRg1rlzQnNl8mGJpxAIAADh5FGIdDkdj19Fk2XzYPfX1agcAAADByusltuCdEB9+hENoxQIAAEiqZ4hdtWqVxo8fr06dOqlTp04aP368PvnkE1/X1iT4shPLTCwAAICT1yH2X//6l8aMGaPY2Fjddtttuu222xQXF6dx48bp6aefbowag5ovu6c2+uYAAACS6rHE1pw5c/S///u/mj59uuvYrbfeqsGDB2vOnDm6+eabfVpgsPNliA0lxQIAAEiqRyf2+PHjGjNmTLXjF154oXJzc31SVFPiyxDLTCwAAICT1yH2kksu0Ztvvlnt+Ntvv63x48f7pKimJNSnnVhCLAAAgFSPcYLu3bvroYce0po1azRo0CBJ0vr167Vu3TrdeeedeuKJJ1zn3nrrrb6rNEj5agTAZhjs2AUAAPArwzS92xO1Q4cOnr2wYejnn3+uV1GNJS8vT/Hx8crNzVVcXJxf3vPtzfv185GCBr9OeKhNNw/v7IOKAAAAApenec3rTuzu3bsbVFhz46tOLKMEAAAAJ3G5eyMLDfFN+Az15a4JAAAAQc7rTqxpmnrttdf06aefKisrq9qWtG+88YbPimsKwnwUYn31OgAAAE2B1yH29ttv14IFCzR8+HC1bNlSBrtInZavxgnC6MQCAAC4eB1iX3zxRb3xxhsaN25cY9TT5PgqfDITCwAAcJLXCSs+Pl4dO3ZsjFqaJF+NAYSH0okFAACo5HUyuv/++/XAAw+oqKioMeppcnzViWWc
AAAA4CSvk9GVV16pnJwcpaamqlevXjr77LPdbo3t6aefVvv27RUREaGBAwfq66+/bvT3bAhfdVAJsQAAACd5PRN7zTXXaOPGjbr66qv9fmHX0qVLNWPGDD377LMaOHCg5s+fr9GjR2vHjh1KTU31Wx3e8F0nlplYAACASl7v2BUdHa2PPvpIQ4YMaayaajVw4ECdc845euqppyRJDodDGRkZuuWWWzRz5sw6n2/Fjl37jhXqtY2/NPh1BnZM0nmdWvigIgAAgMDlaV7zuk2YkZHhtwB4qtLSUm3cuFEjR450HbPZbBo5cqS+/PLLGp9TUlKivLw8t5u/+aoTa+fCLgAAABevk9Gjjz6qu+66S3v27GmEcmqXnZ2tiooKtWzZ0u14y5YtdejQoRqfM3fuXMXHx7tuGRkZ/ijVja9mYsNDQnzyOgAAAE2B1zOxV199tQoLC9WpUydFRUUpLCzM7fFjx475rLiGmjVrlmbMmOG6n5eX5/cg67MQSycWAADAxesQO3/+/EYoo24tWrRQSEiIDh8+7Hb88OHDSktLq/E5drtddrvdH+XVyldjAIRYAACAk+q1OoEVwsPD1a9fP61atUoTJ06U5Lywa9WqVZo+fbolNXki1GbIZhhyeHf9XDXMxAIAAJzkdYg9VXFxsUpLS92ONeZFXzNmzNA111yj/v37a8CAAZo/f74KCgo0bdq0RnvPhjIMQ/Ywm4pKKxr0OnRiAQAATvI6xBYUFOjuu+/WsmXLdPTo0WqPV1Q0LKydzqRJk3TkyBHNnj1bhw4dUp8+ffThhx9Wu9gr0NhDGx5i6cQCAACc5HUyuuuuu7R69Wo988wzstvtWrhwoR544AG1atVKS5YsaYwa3UyfPl179+5VSUmJvvrqKw0cOLDR37Oh7KENX1nAF68BAADQVHjdiX333Xe1ZMkSDRs2TNOmTdP555+vzp07q127dnrppZf0+9//vjHqDGoN7aLaDIMduwAAAE7hdbo6duyYOnbsKMk5/1q5pNaQIUO0du1a31bXRNjDGhZiI8Jsft3eFwAAINB5na46duyo3bt3S5K6du2qZcuWSXJ2aBMSEnxaXFMR0cBRAOZhAQAA3HmdjqZNm6YtW7ZIkmbOnKmnn35aERERuuOOO/TXv/7V5wU2BQ3vxDIPCwAAcCqvZ2LvuOMO1+9Hjhypbdu2adOmTercubN69+7t0+KaioaGUEIsAACAuwatEytJ7du3V/v27X1QStPFOAEAAIBveZyOvvzyS7333ntux5YsWaIOHTooNTVVN9xwg0pKSnxeYFMQwTgBAACAT3mcrh588EFt3brVdf/777/Xddddp5EjR2rmzJl69913NXfu3EYpMtgxTgAAAOBbHofYzZs3a8SIEa77r776qgYOHKjnn39eM2bM0BNPPOFaqQDufLHEFgAAAE7yOB3l5OS4be/62WefaezYsa7755xzjvbt2+fb6pqISDqxAAAAPuVxiG3ZsqVrfdjS0lJt2rRJ5557ruvx/Px8hYWF+b7CJqChIbShIRgAAKCp8TjEjhs3TjNnztTnn3+uWbNmKSoqSueff77r8e+++06dOnVqlCKDXViIrUHbxtKJBQAAcOfxElt///vfddlll2no0KGKiYnRCy+8oPDwcNfj//73v3XhhRc2SpFNQURYiMoqyuv1XDqxAAAA7jwOsS1atNDatWuVm5urmJgYhYS4B6vly5crJibG5wU2FRFhIcovrl+IjQjnwi4AAIBTeb3ZQXx8fI3Hk5KSGlxMU1bfbmqIzVB4CCEWAADgVKQjP4kMr1+IjQwLkWHUf54WAACgKSLE+kl9O7ER9Qy/AAAATRkh1k/qu8IAF3UBAABUR4j1k6gGjBMAAADAHSHWT+o9E8vKBAAAANWQkPykvh3VyDCvF5AAAABo8gixflL/TizjBAAAAFURYv2kvp3Y+s7SAgAANGWEWD9xrvdav+cBAADAHSHWT2w2Q/ZQ7wMp4wQAAADV
EWL9qD6jAXRiAQAAqiPE+pG3gdQwCLEAAAA1IcT6kbejARFhIbLZ6jFICwAA0MQRYv3I23ECurAAAAA1I8T6kbehlIu6AAAAakaI9SNvQylrxAIAANSMEOtHUeHebSFLiAUAAKgZIdaPvB0niGAmFgAAoEaEWD/yfpzAu84tAABAc0GI9SNvxwMYJwAAAKgZIdaPvF6dgHECAACAGhFi/chmM7yac2WJLQAAgJoRYv3MmxEBxgkAAABqRoj1M0+7q4YhRYQSYgEAAGpCiPUzT+dcI8JCZLMZjVwNAABAcCLE+pmnIwJc1AUAAFA7QqyfeTpOwEVdAAAAtSPE+pmnHVYu6gIAAKgdIdbPPN2Fi3ECAACA2hFi/czTcMo4AQAAQO0IsX4WEe7Zh5xOLAAAQO0IsX7m8TgBnVgAAIBaEWL9zONxAjqxAAAAtSLE+lmIzVB4aN0fdjqxAAAAtSPEWsCT5bPoxAIAANSOEGsBTwIqIRYAAKB2hFgL1DUqEB5qU2gInxoAAIDakJQsEFFHl7WuxwEAAJo7QqwF6hoVYJQAAADg9AixFqhrnCDSww0RAAAAmivSkgUiQunEAgAANAQh1gJ1dVqZiQUAADg9QqwFuLALAACgYQixFqgrpDJOAAAAcHqEWAvUFVLpxAIAAJweIdYCdGIBAAAaJmhC7EMPPaTzzjtPUVFRSkhIsLqcBgmxGQoPrf1DHxEWNJ8WAAAASwRNWiotLdUVV1yhP//5z1aX4hP204RYO51YAACA0wq1ugBPPfDAA5KkxYsXW1uIj0SEhSi/uLyWx4Lm/xYAAACWCJoQWx8lJSUqKSlx3c/Ly7OwGne1zb3aDEPhIYRYAACA02nSaWnu3LmKj4933TIyMqwuycVeS7c1IswmwzD8XA0AAEBwsTTEzpw5U4ZhnPa2ffv2er/+rFmzlJub67rt27fPh9U3TG1bz7K8FgAAQN0sHSe48847NXXq1NOe07Fjx3q/vt1ul91ur/fzG1NtYfV0F3wBAADAydIQm5KSopSUFCtLsEzt4wR0YgEAAOoSNBd2ZWZm6tixY8rMzFRFRYU2b94sSercubNiYmKsLa4eahsnoBMLAABQt6AJsbNnz9YLL7zgut+3b19J0qeffqphw4ZZVFX91daJre04AAAATgqaxLR48WKZplntFowBVjrNhV21HAcAAMBJQRNimxo6sQAAAPVHYrJIbbOvdjqxAAAAdSLEWqS2sMqFXQAAAHUjMVkknE4sAABAvRFiLRJiMxQWUn17WWZiAQAA6kZislBNXdfwED4lAAAAdSExWaimkQI6sQAAAHUjMVmophBLJxYAAKBuJCYLVV2JIMRmKJQQCwAAUCcSk4WqdmJrW7EAAAAA7khNFqo6OsAoAQAAgGdITRaiEwsAAFA/pCYLEWIBAADqh9RkoarjA2w5CwAA4BlSk4WqdWKZiQUAAPAIqclCYVVCa9X7AAAAqBmpyUJVO7FhjBMAAAB4hNRkIZbYAgAAqB9Sk4Wqr05gWFQJAABAcCHEWijU5h5amYkFAADwDKnJQlVnYAmxAAAAniE1WajqDCwhFgAAwDOkJgtVX2KLmVgAAABPEGItFGIzZDNOBlc6sQAAAJ4hNVks7JQVCULpxAIAAHiEEGuxMNvJTwHrxAIAAHiG1GSxU+dgQwmxAAAAHiE1WezU4Fp13VgAAADUjBBrsVM7sVzYBQAA4BlSk8VCf52JtRmGQujEAgAAeIQQa7HKFQlYmQAAAMBzhFiLVXZi2egAAADAc4RYi7k6sTY+FQAAAJ4iOVksjHECAAAArxFiLRbyaweWTiwAAIDnSE4WC7NVjhPQiQUAAPAUIdZilctqsbwWAACA5wixFmOJLQAAAO8RYi1WORNLJxYAAMBzhFiLhTITCwAA4DVCrMVOzsTyqQAAAPAUyclidGIBAAC8R4i1mI3VCQAAALxGiLUYnVgAAADvEWItZjOc4dVGiAUAAPAYIdZilevDMk4A
AADgOUKsxUIMQiwAAIC3CLEWcy2xZRBiAQAAPEWItVgIqxMAAAB4jRBrscoLumx0YgEAADxGiLUYM7EAAADeI8Ra7OQ4gcWFAAAABBGik8Vc68QyTgAAAOAxQqzFuLALAADAe4RYi1VmVzqxAAAAniPEWswwDNkMg21nAQAAvECIDQA2g80OAAAAvEGIDQA2myEasQAAAJ4jxAYAxgkAAAC8Q4gNADaDC7sAAAC8QYgNACE2g5lYAAAALwRFiN2zZ4+uu+46dejQQZGRkerUqZPuu+8+lZaWWl2aTxiGISMoPhMAAACBIdTqAjyxfft2ORwOLViwQJ07d9YPP/yg66+/XgUFBZo3b57V5TUY4wQAAADeCYoQO2bMGI0ZM8Z1v2PHjtqxY4eeeeaZJhJiWZ0AAADAG0ERYmuSm5urpKSk055TUlKikpIS1/28vLzGLqte6MQCAAB4JygnMXft2qUnn3xSN95442nPmzt3ruLj4123jIwMP1XoHZvNEBkWAADAc5aG2JkzZzovajrNbfv27W7P2b9/v8aMGaMrrrhC119//Wlff9asWcrNzXXd9u3b15h/nHozZNCJBQAA8IKl4wR33nmnpk6detpzOnbs6Pr9gQMHNHz4cJ133nl67rnn6nx9u90uu93e0DIbHeMEAAAA3rE0xKakpCglJcWjc/fv36/hw4erX79+WrRokWy2oJyEqBEXdgEAAHgnKC7s2r9/v4YNG6Z27dpp3rx5OnLkiOuxtLQ0CyvzDedMLCkWAADAU0ERYleuXKldu3Zp165datOmjdtjpmlaVJXvhDSdpjIAAIBfBEV8mjp1qkzTrPHWFDAPCwAA4J2gCLFNXWgTmu8FAADwB9JTAGCcAAAAwDvEpwDAOAEAAIB3CLEBIIT1tQAAALxCiA0AdGIBAAC8Q4gNADY6sQAAAF4hxAaAEDqxAAAAXiHEBgAasQAAAN4hxAYAxgkAAAC8Q4gNAFzYBQAA4B1CbACgEQsAAOAdQmwAYJwAAADAO4TYAMA4AQAAgHcIsQGARiwAAIB3CLEBgE4sAACAdwixAYAMCwAA4B1CbACgEwsAAOAdQmwAIMQCAAB4hxAbAMiwAAAA3iHEBgBWJwAAAPAOITYAGLRiAQAAvEKIDQBEWAAAAO8QYgMAF3YBAAB4hxAbAMiwAAAA3iHEBgA6sQAAAN4hxAIAACDoEGIDAI1YAAAA7xBiAwDjBAAAAN4hxAYAeyifBgAAAG+QngJAcozd6hIAAACCCiEWAAAAQYcQCwAAgKBDiAUAAEDQIcQCAAAg6BBiAQAAEHQIsQAAAAg6hFgAAAAEHUIsAAAAgg4hFgAAAEGHEAsAAICgQ4gFAABA0CHEAgAAIOgQYgEAABB0CLEAAAAIOoRYAAAABB1CLAAAAIIOIRYAAABBhxALAACAoEOIBQAAQNAJtboAfzJNU5KUl5dncSUAAACoSWVOq8xttWlWITY/P1+SlJGRYXElAAAAOJ38/HzFx8fX+rhh1hVzmxCHw6EDBw4oNjZWhmFYXY5LXl6eMjIytG/fPsXFxVldDgDxfQkEKr43mz7TNJWfn69WrVrJZqt98rVZdWJtNpvatGljdRm1iouL4xsSCDB8XwKBie/Npu10HdhKXNgFAACAoEOIBQAAQNAhxAYAu92u++67T3a73epSAPyK70sgMPG9iUrN6sIuAAAANA10YgEAABB0CLEAAAAIOoRYAAAABB1CLACcYtiwYbr99tutLgMAUAdCLAAAAIIOIRYAAABBhxBroQ8//FBDhgxRQkKCkpOTNX78eP30009WlwU0e+Xl5Zo+fbri4+PVokUL3XvvvWI1QsB6DodDjzzyiDp37iy73a62bdvqoYcesrosWIQQa6GCggLNmDFDGzZs0KpVq2Sz2XTppZfK4XBYXRrQrL3wwgsKDQ3V119/rccff1yPPfaYFi5caHVZQLM3a9YsPfzww7r33nv1448/6uWXX1bLli2tLgsW
YbODAJKdna2UlBR9//336tmzp9XlAM3SsGHDlJWVpa1bt8owDEnSzJkz9c477+jHH3+0uDqg+crPz1dKSoqeeuop/fGPf7S6HAQAOrEW2rlzpyZPnqyOHTsqLi5O7du3lyRlZmZaWxjQzJ177rmuACtJgwYN0s6dO1VRUWFhVUDztm3bNpWUlGjEiBFWl4IAEWp1Ac3ZxRdfrHbt2un5559Xq1at5HA41LNnT5WWllpdGgAAASUyMtLqEhBg6MRa5OjRo9qxY4f+9re/acSIEerWrZtycnKsLguApK+++srt/vr169WlSxeFhIRYVBGALl26KDIyUqtWrbK6FAQIOrEWSUxMVHJysp577jmlp6crMzNTM2fOtLosAHKO9MyYMUM33nijNm3apCeffFKPPvqo1WUBzVpERITuvvtu3XXXXQoPD9fgwYN15MgRbd26Vdddd53V5cEChFiL2Gw2vfrqq7r11lvVs2dPnXnmmXriiSc0bNgwq0sDmr0pU6aoqKhIAwYMUEhIiG677TbdcMMNVpcFNHv33nuvQkNDNXv2bB04cEDp6en605/+ZHVZsAirEwAAACDoMBMLAACAoEOIBQAAQNAhxAIAACDoEGIBAAAQdAixAAAACDqEWAAAAAQdQiwAAACCDiEWAAAAQYcQCwAeMgxDb731ls9fd9iwYbr99ttd99u3b6/58+f7/H1qei8ACFaEWADN2tSpU2UYhgzDUFhYmFq2bKlRo0bp3//+txwOh9u5Bw8e1NixYz16XW8C7xtvvKG///3v3pZ+WmvWrJFhGDp+/Hijv1dVe/bskWEY2rx5s9v9yltsbKx69Oihm2++WTt37mzUWgA0XYRYAM3emDFjdPDgQe3Zs0crVqzQ8OHDddttt2n8+PEqLy93nZeWlia73e6z9y0tLZUkJSUlKTY21mevezr+fK+qPvnkEx08eFBbtmzRnDlztG3bNp111llatWqVJfUACG6EWADNnt1uV1pamlq3bq2zzz5b99xzj95++22tWLFCixcvdp13ane1tLRU06dPV3p6uiIiItSuXTvNnTtXknMcQJIuvfRSGYbhun///ferT58+WrhwoTp06KCIiAhJNf+IPz8/X5MnT1Z0dLRat26tp59+2vVY1U6nJB0/flyGYWjNmjXas2ePhg8fLklKTEyUYRiaOnVqje+Vk5OjKVOmKDExUVFRURo7dqxbd3Tx4sVKSEjQRx99pG7duikmJsYV+r2VnJystLQ0dezYURMmTNAnn3yigQMH6rrrrlNFRYXXrwegeSPEAkANLrjgAp111ll64403anz8iSee0DvvvKNly5Zpx44deumll1xh9ZtvvpEkLVq0SAcPHnTdl6Rdu3bp9ddf1xtvvOEWQqv65z//qbPOOkvffvutZs6cqdtuu00rV670qPaMjAy9/vrrkqQdO3bo4MGDevzxx2s8d+rUqdqwYYPeeecdffnllzJNU+PGjVNZWZnrnMLCQs2bN08vvvii1q5dq8zMTP3lL3/xqJbTsdlsuu2227R3715t3Lixwa8HoHkJtboAAAhUXbt21XfffVfjY5mZmerSpYuGDBkiwzDUrl0712MpKSmSpISEBKWlpbk9r7S0VEuWLHGdU5vBgwdr5syZkqQzzjhD69at0//+7/9q1KhRddYdEhKipKQkSVJqaqoSEhJqPG/nzp165513tG7dOp133nmSpJdeekkZGRl66623dMUVV0iSysrK9Oyzz6pTp06SpOnTp+vBBx+ssw5PdO3aVZKzuzxgwACfvCaA5oFOLADUwjRNGYZR42NTp07V5s2bdeaZZ+rWW2/Vxx9/7NFrtmvXrs4AK0mDBg2qdn/btm0evYentm3bptDQUA0cONB1LDk5WWeeeabbe0VFRbkCrCSlp6crKyvLJzWYpilJtX6cAaA2hFgAqMW2bdvUoUOHGh87++yztXv3bv39739XUVGRrrzySl1++eV1vmZ0dHSD67LZnH91
VwZASW4//ve1sLAwt/uGYbi9d0NUhuXaPs4AUBtCLADUYPXq1fr+++/129/+ttZz4uLiNGnSJD3//PNaunSpXn/9dR07dkySM/g15GKl9evXV7vfrVs3SSfHFU69uKrqfG14eLgknbaGbt26qby8XF999ZXr2NGjR7Vjxw5179693rV7yuFw6IknnlCHDh3Ut2/fRn8/AE0LM7EAmr2SkhIdOnRIFRUVOnz4sD788EPNnTtX48eP15QpU2p8zmOPPab09HT17dtXNptNy5cvV1pammv+tH379lq1apUGDx4su92uxMREr2pat26dHnnkEU2cOFErV67U8uXL9f7770uSIiMjde655+rhhx9Whw4dlJWVpb/97W9uz2/Xrp0Mw9B7772ncePGKTIyUjExMW7ndOnSRRMmTND111+vBQsWKDY2VjNnzlTr1q01YcIEr+r1xNGjR3Xo0CEVFhbqhx9+0Pz58/X111/r/fffV0hIiM/fD0DTRicWQLP34YcfKj09Xe3bt9eYMWP06aef6oknntDbb79da7iKjY3VI488ov79++ucc87Rnj179MEHH7h+1P/oo49q5cqVysjIqFeX8c4779SGDRvUt29f/c///I8ee+wxjR492vX4v//9b5WXl6tfv366/fbb9T//8z9uz2/durUeeOABzZw5Uy1bttT06dNrfJ9FixapX79+Gj9+vAYNGiTTNPXBBx9UGyHwhZEjRyo9PV29evXSzJkz1a1bN3333Xeu5cAAwBuG6avBJgAAAMBP6MQCAAAg6BBiAQAAEHQIsQAAAAg6hFgAAAAEHUIsAAAAgg4hFgAAAEGHEAsAAICgQ4gFAABA0CHEAgAAIOgQYgEAABB0CLEAAAAIOv8fk5kB/PgD8uMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 800x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from abstract_cf.bandit import BanditEnv\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "env = BanditEnv(3, device=device)\n",
    "env.display_arms()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0900, 0.2447, 0.6652])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgO0lEQVR4nO3df2xV9f3H8Vdb6L0yaJFVbqHe2QEKItpia7tiFJZdrRkxkmxZZc52N1g3hQV3Nyedrh2yeFERa2Zn1a2ygIbO38tgJXgFjVLtbCFDRBwoFNR7S4f2Qlla13u+f/j1skqLPf11P719PpKb0MPnnPu+Jze3z5ze2yZYlmUJAADAIImxHgAAAODLCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxhkT6wH6IhKJ6KOPPtKECROUkJAQ63EAAEAfWJal48ePa+rUqUpMtHdNZEQEykcffSS32x3rMQAAQD8cPnxY5557rq19RkSgTJgwQdLnDzAlJSXG0wAAgL4Ih8Nyu93R7+N2jIhA+eLHOikpKQQKAAAjTH/ensGbZAEAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYJwxsR4AABBbmSs2xXoExNjB1QtjPcJpuIICAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOP0K1CqqqqUmZkpp9Op/Px8NTQ0nHH9p59+qqVLl2rKlClyOBy64IILtHnz5n4NDAAA4t8YuzvU1tbK5/Opurpa+fn5qqysVGFhofbt26fJkyeftr6zs1NXXXWVJk+erGeeeUYZGRk6dOiQJk6cOBjzAwCAOGQ7UNauXavS0lJ5vV5JUnV1tTZt2qSamhqtWLHitPU1NTU6duyYduzYobFjx0qSMjMzBzY1AACIa7Z+xNPZ2anGxkZ5PJ5TB0hMlMfjUX19fY/7/PWvf1VBQYGWLl0ql8ulOXPm6J577lFXV1ev99PR0aFwONztBgAARg9bgdLa2qquri65XK5u210ul4LBYI/7vP/++3rmmWfU1dWlzZs36ze/+Y0eeOAB/e53v+v1fvx+v1JTU6M3t9ttZ0wAADDCDfmneCKRiCZPnqzHHntMOTk5Kioq0p133qnq6upe9ykrK1NbW1v0dvjw4aEeEwAAGMTWe1DS0tKUlJSkUCjUbXsoFFJ6enqP+0yZMkVjx45VUlJSdNuFF16oYDCozs5OJScnn7aPw+GQw+GwMxoAAIgjtq6gJCcnKycnR4FAILotEokoEAiooKCgx30uv/xy7d+/X5FIJLrtvffe05QpU3qMEwAAANs/4vH5fHr88cf15z//WXv37tUtt9yi9vb26Kd6iouLVVZWFl1/yy236NixY1q+fLnee+89bdq0Sffcc4+WLl06eI8CAADEFdsfMy4qKtLRo0dVXl6uYDCo7Oxs1dXVRd8429zcrMTEU93jdru1ZcsW/fznP9cll1yijIwMLV++XHfcccfgPQoAABBXEizLsmI9xFcJh8NKTU1VW1ubUlJSYj0OAMSVzBWbYj0CYuzg6oVDctyBfP/mb/EAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA
4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDj9CpSqqiplZmbK6XQqPz9fDQ0Nva5dt26dEhISut2cTme/BwYAAPHPdqDU1tbK5/OpoqJCTU1NysrKUmFhoVpaWnrdJyUlRR9//HH0dujQoQENDQAA4pvtQFm7dq1KS0vl9Xo1e/ZsVVdXa9y4caqpqel1n4SEBKWnp0dvLpdrQEMDAID4ZitQOjs71djYKI/Hc+oAiYnyeDyqr6/vdb8TJ07ovPPOk9vt1nXXXac9e/ac8X46OjoUDoe73QAAwOhhK1BaW1vV1dV12hUQl8ulYDDY4z4zZ85UTU2NXnzxRW3YsEGRSETz5s3TkSNHer0fv9+v1NTU6M3tdtsZEwAAjHBD/imegoICFRcXKzs7W/Pnz9dzzz2nc845R48++miv+5SVlamtrS16O3z48FCPCQAADDLGzuK0tDQlJSUpFAp12x4KhZSent6nY4wdO1Zz587V/v37e13jcDjkcDjsjAYAAOKIrSsoycnJysnJUSAQiG6LRCIKBAIqKCjo0zG6urq0e/duTZkyxd6kAABg1LB1BUWSfD6fSkpKlJubq7y8PFVWVqq9vV1er1eSVFxcrIyMDPn9fknS3XffrW9961uaMWOGPv30U91///06dOiQbrrppsF9JAAAIG7YDpSioiIdPXpU5eXlCgaDys7OVl1dXfSNs83NzUpMPHVh5pNPPlFpaamCwaDOPvts5eTkaMeOHZo9e/bgPQoAABBXEizLsmI9xFcJh8NKTU1VW1ubUlJSYj0OAMSVzBWbYj0CYuzg6oVDctyBfP/mb/EAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOP0K1CqqqqUmZkpp9Op/Px8NTQ09Gm/jRs3KiEhQYsWLerP3QIAgFHCdqDU1tbK5/OpoqJCTU1NysrKUmFhoVpaWs6438GDB/XLX/5SV1xxRb+HBQAAo4PtQFm7dq1KS0vl9Xo1e/ZsVVdXa9y4caqpqel1n66uLt1www1auXKlpk2bNqCBAQBA/LMVKJ2dnWpsbJTH4zl1gMREeTwe1dfX97rf3XffrcmTJ2vJkiX9nxQAAIwaY+wsbm1tVVdXl1wuV7ftLpdL7777bo/7vPbaa/rTn/6kXbt29fl+Ojo61NHREf06HA7bGRMAAIxwQ/opnuPHj+vGG2/U448/rrS0tD7v5/f7lZqaGr253e4hnBIAAJjG1hWUtLQ0JSUlKRQKddseCoWUnp5+2voDBw7o4MGDuvbaa6PbIpHI53c8Zoz27dun6dOnn7ZfWVmZfD5f9OtwOEykAAAwitgKlOTkZOXk5CgQCEQ/KhyJRBQIBLRs2bLT1s+aNUu7d+/utu2uu+7S8ePH9dBDD/UaHQ6HQw6Hw85oAAAgjtgKFEny+XwqKSlRbm6u8vLyVFlZqfb2dnm9XklScXGxMjIy5Pf75XQ6NWfOnG77T5w4UZJO2w4AAPAF24FSVFSko0ePqry8XMFgUNnZ2aqrq4u+cba5uVmJifyCWgAA0H8JlmVZsR7iq4TDYaWmpqqtrU0pKSmxHgcA4krmik2xHgExdnD1wiE57kC+f3Op
AwAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADG6VegVFVVKTMzU06nU/n5+WpoaOh17XPPPafc3FxNnDhRX/va15Sdna3169f3e2AAABD/bAdKbW2tfD6fKioq1NTUpKysLBUWFqqlpaXH9ZMmTdKdd96p+vp6/fOf/5TX65XX69WWLVsGPDwAAIhPCZZlWXZ2yM/P12WXXaaHH35YkhSJROR2u/Wzn/1MK1as6NMxLr30Ui1cuFCrVq3q0/pwOKzU1FS1tbUpJSXFzrgAgK+QuWJTrEdAjB1cvXBIjjuQ79+2rqB0dnaqsbFRHo/n1AESE+XxeFRfX/+V+1uWpUAgoH379unKK6+0NSgAABg9xthZ3Nraqq6uLrlcrm7bXS6X3n333V73a2trU0ZGhjo6OpSUlKQ//OEPuuqqq3pd39HRoY6OjujX4XDYzpgAAGCEsxUo/TVhwgTt2rVLJ06cUCAQkM/n07Rp07RgwYIe1/v9fq1cuXI4RgMAAAayFShpaWlKSkpSKBTqtj0UCik9Pb3X/RITEzVjxgxJUnZ2tvbu3Su/399roJSVlcnn80W/DofDcrvddkYFAAAjmK33oCQnJysnJ0eBQCC6LRKJKBAIqKCgoM/HiUQi3X6E82UOh0MpKSndbgAAYPSw/SMen8+nkpIS5ebmKi8vT5WVlWpvb5fX65UkFRcXKyMjQ36/X9LnP67Jzc3V9OnT1dHRoc2bN2v9+vV65JFHBveRAACAuGE7UIqKinT06FGVl5crGAwqOztbdXV10TfONjc3KzHx1IWZ9vZ23XrrrTpy5IjOOusszZo1Sxs2bFBRUdHgPQoAABBXbP8elFjg96AAwNDh96BgxP8eFAAAgOFAoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4Y2I9ADDaZa7YFOsREGMHVy+M9QiAcbiCAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADj9CtQqqqqlJmZKafTqfz8fDU0NPS69vHHH9cVV1yhs88+W2effbY8Hs8Z1wMAANgOlNraWvl8PlVUVKipqUlZWVkqLCxUS0tLj+u3b9+uxYsXa9u2baqvr5fb7dbVV1+tDz/8cMDDAwCA+GQ7UNauXavS0lJ5vV7Nnj1b1dXVGjdunGpqanpc/+STT+rWW29Vdna2Zs2apT/+8Y+KRCIKBAIDHh4AAMQnW4HS2dmpxsZGeTyeUwdITJTH41F9fX2fjnHy5El99tlnmjRpUq9rOjo6FA6Hu90AAMDoYStQWltb1dXVJZfL1W27y+VSMBjs0zHuuOMOTZ06tVvkfJnf71dqamr05na77YwJAABGuGH9FM/q1au1ceNGPf/883I6nb2uKysrU1tbW/R2+PDhYZwSAADE2hg7i9PS0pSUlKRQKNRteygUUnp6+hn3XbNmjVavXq2XXnpJl1xyyRnXOhwOORwOO6MBAIA4YusKSnJysnJycrq9wfWLN7wW
FBT0ut99992nVatWqa6uTrm5uf2fFgAAjAq2rqBIks/nU0lJiXJzc5WXl6fKykq1t7fL6/VKkoqLi5WRkSG/3y9Juvfee1VeXq6nnnpKmZmZ0feqjB8/XuPHjx/EhwIAAOKF7UApKirS0aNHVV5ermAwqOzsbNXV1UXfONvc3KzExFMXZh555BF1dnbq+9//frfjVFRU6Le//e3ApgcAAHHJdqBI0rJly7Rs2bIe/2/79u3dvj548GB/7gIAAIxi/C0eAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABinX4FSVVWlzMxMOZ1O5efnq6Ghode1e/bs0fe+9z1lZmYqISFBlZWV/Z0VAACMErYDpba2Vj6fTxUVFWpqalJWVpYKCwvV0tLS4/qTJ09q2rRpWr16tdLT0wc8MAAAiH+2A2Xt2rUqLS2V1+vV7NmzVV1drXHjxqmmpqbH9Zdddpnuv/9+XX/99XI4HAMeGAAAxD9bgdLZ2anGxkZ5PJ5TB0hMlMfjUX19/aAN1dHRoXA43O0GAABGD1uB0traqq6uLrlcrm7bXS6XgsHgoA3l9/uVmpoavbnd7kE7NgAAMJ+Rn+IpKytTW1tb9Hb48OFYjwQAAIbRGDuL09LSlJSUpFAo1G17KBQa1DfAOhwO3q8CAMAoZusKSnJysnJychQIBKLbIpGIAoGACgoKBn04AAAwOtm6giJJPp9PJSUlys3NVV5eniorK9Xe3i6v1ytJKi4uVkZGhvx+v6TP31j7zjvvRP/94YcfateuXRo/frxmzJgxiA8FAADEC9uBUlRUpKNHj6q8vFzBYFDZ2dmqq6uLvnG2ublZiYmnLsx89NFHmjt3bvTrNWvWaM2aNZo/f762b98+8EcAAADiju1AkaRly5Zp2bJlPf7fl6MjMzNTlmX1526GReaKTbEeATF2cPXCWI8AAPgSIz/FAwAARjcCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYp1+BUlVVpczMTDmdTuXn56uhoeGM659++mnNmjVLTqdTF198sTZv3tyvYQEAwOhgO1Bqa2vl8/lUUVGhpqYmZWVlqbCwUC0tLT2u37FjhxYvXqwlS5Zo586dWrRokRYtWqS33357wMMDAID4ZDtQ1q5dq9LSUnm9Xs2ePVvV1dUaN26campqelz/0EMP6ZprrtHtt9+uCy+8UKtWrdKll16qhx9+eMDDAwCA+DTGzuLOzk41NjaqrKwsui0xMVEej0f19fU97lNfXy+fz9dtW2FhoV544YVe76ejo0MdHR3Rr9va2iRJ4XDYzrh9Euk4OejHxMgyFM8rO3gOgucgYm2onoNfHNeyLNv72gqU1tZWdXV1yeVyddvucrn07rvv9rhPMBjscX0wGOz1fvx+v1auXHnadrfbbWdcoE9SK2M9AUY7noOItaF+Dh4/
flypqam29rEVKMOlrKys21WXSCSiY8eO6etf/7oSEhKi28PhsNxutw4fPqyUlJRYjDricQ4HhvM3cJzDgeH8DRzncGDOdP4sy9Lx48c1depU28e1FShpaWlKSkpSKBTqtj0UCik9Pb3HfdLT022tlySHwyGHw9Ft28SJE3tdn5KSwpNqgDiHA8P5GzjO4cBw/gaOczgwvZ0/u1dOvmDrTbLJycnKyclRIBCIbotEIgoEAiooKOhxn4KCgm7rJWnr1q29rgcAALD9Ix6fz6eSkhLl5uYqLy9PlZWVam9vl9frlSQVFxcrIyNDfr9fkrR8+XLNnz9fDzzwgBYuXKiNGzfqrbfe0mOPPTa4jwQAAMQN24FSVFSko0ePqry8XMFgUNnZ2aqrq4u+Eba5uVmJiacuzMybN09PPfWU7rrrLv3617/W+eefrxdeeEFz5swZ8PAOh0MVFRWn/TgIfcc5HBjO38BxDgeG8zdwnMOBGarzl2D157M/AAAAQ4i/xQMAAIxDoAAAAOMQKAAAwDgECgAAMM6IC5Rjx47phhtuUEpKiiZOnKglS5boxIkTZ9xnwYIFSkhI6Hb76U9/OkwTx15VVZUyMzPldDqVn5+vhoaGM65/+umnNWvWLDmdTl188cXavHnzME1qJjvnb926dac915xO5zBOa5ZXX31V1157raZOnaqEhIQz/g2uL2zfvl2XXnqpHA6HZsyYoXXr1g35nCazew63b99+2nMwISHhjH9eJJ75/X5ddtllmjBhgiZPnqxFixZp3759X7kfr4Of68/5G6zXwREXKDfccIP27NmjrVu36m9/+5teffVV3XzzzV+5X2lpqT7++OPo7b777huGaWOvtrZWPp9PFRUVampqUlZWlgoLC9XS0tLj+h07dmjx4sVasmSJdu7cqUWLFmnRokV6++23h3lyM9g9f9Lnv03xf59rhw4dGsaJzdLe3q6srCxVVVX1af0HH3yghQsX6tvf/rZ27dql2267TTfddJO2bNkyxJOay+45/MK+ffu6PQ8nT548RBOa7ZVXXtHSpUv1xhtvaOvWrfrss8909dVXq729vdd9eB08pT/nTxqk10FrBHnnnXcsSdY//vGP6La///3vVkJCgvXhhx/2ut/8+fOt5cuXD8OE5snLy7OWLl0a/bqrq8uaOnWq5ff7e1z/gx/8wFq4cGG3bfn5+dZPfvKTIZ3TVHbP3xNPPGGlpqYO03QjiyTr+eefP+OaX/3qV9ZFF13UbVtRUZFVWFg4hJONHH05h9u2bbMkWZ988smwzDTStLS0WJKsV155pdc1vA72ri/nb7BeB0fUFZT6+npNnDhRubm50W0ej0eJiYl68803z7jvk08+qbS0NM2ZM0dlZWU6eTL+/7x4Z2enGhsb5fF4otsSExPl8XhUX1/f4z719fXd1ktSYWFhr+vjWX/OnySdOHFC5513ntxut6677jrt2bNnOMaNCzz/Bk92dramTJmiq666Sq+//nqsxzFGW1ubJGnSpEm9ruF52Lu+nD9pcF4HR1SgBIPB0y5TjhkzRpMmTTrjz1d/+MMfasOGDdq2bZvKysq0fv16/ehHPxrqcWOutbVVXV1d0d/y+wWXy9Xr+QoGg7bWx7P+nL+ZM2eqpqZGL774ojZs2KBIJKJ58+bpyJEjwzHyiNfb8y8cDus///lPjKYaWaZMmaLq6mo9++yzevbZZ+V2u7VgwQI1NTXFerSYi0Qiuu2223T55Zef8beZ8zrYs76ev8F6HbT9q+6HwooVK3Tvvfeecc3evXv7ffz/fY/KxRdfrClTpug73/mODhw4oOnTp/f7uMCXFRQUdPtDmPPmzdOFF16oRx99VKtWrYrhZBgtZs6cqZkzZ0a/njdvng4cOKAHH3xQ69evj+Fksbd06VK9/fbbeu2112I9yojU1/M3WK+DRgTKL37xC/34xz8+45pp06YpPT39tDcn/ve//9WxY8eUnp7e5/vLz8+XJO3fvz+u
AyUtLU1JSUkKhULdtodCoV7PV3p6uq318aw/5+/Lxo4dq7lz52r//v1DMWLc6e35l5KSorPOOitGU418eXl5o/6b8rJly6IfrDj33HPPuJbXwdPZOX9f1t/XQSN+xHPOOedo1qxZZ7wlJyeroKBAn376qRobG6P7vvzyy4pEItHo6Itdu3ZJ+vxSaDxLTk5WTk6OAoFAdFskElEgEOhWt/+roKCg23pJ2rp1a6/r41l/zt+XdXV1affu3XH/XBssPP+Gxq5du0btc9CyLC1btkzPP/+8Xn75ZX3zm9/8yn14Hp7Sn/P3Zf1+HRzw22yH2TXXXGPNnTvXevPNN63XXnvNOv/8863FixdH///IkSPWzJkzrTfffNOyLMvav3+/dffdd1tvvfWW9cEHH1gvvviiNW3aNOvKK6+M1UMYVhs3brQcDoe1bt0665133rFuvvlma+LEiVYwGLQsy7JuvPFGa8WKFdH1r7/+ujVmzBhrzZo11t69e62Kigpr7Nix1u7du2P1EGLK7vlbuXKltWXLFuvAgQNWY2Ojdf3111tOp9Pas2dPrB5CTB0/ftzauXOntXPnTkuStXbtWmvnzp3WoUOHLMuyrBUrVlg33nhjdP37779vjRs3zrr99tutvXv3WlVVVVZSUpJVV1cXq4cQc3bP4YMPPmi98MIL1r/+9S9r9+7d1vLly63ExETrpZdeitVDiKlbbrnFSk1NtbZv3259/PHH0dvJkyeja3gd7F1/zt9gvQ6OuED597//bS1evNgaP368lZKSYnm9Xuv48ePR///ggw8sSda2bdssy7Ks5uZm68orr7QmTZpkORwOa8aMGdbtt99utbW1xegRDL/f//731je+8Q0rOTnZysvLs954443o/82fP98qKSnptv4vf/mLdcEFF1jJycnWRRddZG3atGmYJzaLnfN32223Rde6XC7ru9/9rtXU1BSDqc3wxUdev3z74pyVlJRY8+fPP22f7OxsKzk52Zo2bZr1xBNPDPvcJrF7Du+9915r+vTpltPptCZNmmQtWLDAevnll2MzvAF6OneSuj2veB3sXX/O32C9Dib8/wAAAADGMOI9KAAAAP+LQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGCc/wMwjjc/C4EZVwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 tensor([1.7214])\n",
      "arm c\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# observing an interaction with the environment \n",
    "logits = torch.tensor([1, 2, 3], dtype=torch.float32).to(device)\n",
    "control_probs = torch.softmax(logits, dim=0)\n",
    "\n",
    "print(control_probs)\n",
    "plt.bar(range(len(control_probs)), control_probs.cpu().numpy())\n",
    "plt.show()\n",
    "# control = torch.multinomial(control_probs, 1).item()\n",
    "control = 2\n",
    "\n",
    "outcome = env.pull(control)\n",
    "print(control, outcome)\n",
    "arm = env.sample_arm(control) \n",
    "print('arm', arm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0, device='cuda:0') tensor([ 2.6304,  0.5015, -0.1748], device='cuda:0')\n",
      "tensor(2, device='cuda:0') tensor([-0.2699, -0.0621,  2.2259], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# inferring the noise distribution on the action probabilities \n",
    "\n",
    "control_, control_g = gumbel_max_rejection_sampling(\n",
    "    control_probs, control\n",
    ")\n",
    "# control_g is our posterior noise term, a sample from P(g | a, o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.0900), tensor(0.2447), tensor(0.6652)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAajUlEQVR4nO3df2xd913/8Vfs1HZDa6drWjsNZlbXsjYKizd78bIxWmkeQSoTRSCFCpHI6gJsDRSugMasTegKc9jaKB0NC+2ImDaqRkxjTGqVCSz6B6ohWkJF6a/9QKmzFTsJ3ezMRTb4+vvHvnNlGre5adLPbD8e0pHqk885932lW/mpc8/1XTYzMzMTAIBC6koPAAAsbWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKWl56gLNRrVbz4osv5tJLL82yZctKjwMAnIWZmZmcPn06V111Verq5r/+sSBi5MUXX0x7e3vpMQCAc3D8+PH8+I//+Lz/viBi5NJLL03ygyfT3NxceBoA4GyMj4+nvb199vf4fBZEjPzwrZnm5mYxAgALzOvdYuEGVgCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUctLDwBAWR07Hi09AoUd231T0cd3ZQQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1DnFyL59+9LR0ZGmpqb09PTk8OHDr7n+e9/7Xm677basXr06jY2N+cmf/Mk89thj5zQwALC4LK/1gIMHD6ZSqWT//v3p6enJ3r17s2nTpjz//PO58sorX7V+amoqH/zgB3PllVfmi1/8YtasWZMXXnghK1euPB/zAwALXM0xsmfPnmzbti19fX1Jkv379+fRRx/NgQMHsmPHjletP3DgQF566aU88cQTueiii5IkHR0db2xqAGDRqOltmqmpqRw5ciS9vb2vnKCuLr29vRkaGjrjMV/5yleycePG3HbbbWltbc26devyiU98ItPT0/M+zuTkZMbHx+dsAMDiVFOMnDp1KtPT02ltbZ2zv7W1NSMjI2c85j/+4z/yxS9+MdPT03nsscdy11135b777ssf//Efz/s4AwMDaWlpmd3a29trGRMAWEAu+KdpqtVqrrzyyjz44IPp6urK5s2b87GPfSz79++f95j+/v6MjY3NbsePH7/QYwIAhdR0z8iqVatSX1+f0dHROftHR0fT1tZ2xmNWr16diy66KPX19bP7rr/++oyMjGRqaioNDQ2vOqaxsTGNjY21jAYALFA1XRlpaGhIV1dXBgcHZ/dVq9UMDg5m48aNZzzmfe97X775zW+mWq3O7vv617+e1atXnzFEAIClpea3aSqVSh566KF87nOfy7PPPpuPfOQjmZiYmP10zZYtW9Lf3z+7/iMf+Uheeuml3H777fn617+eRx99NJ/4xCdy2223nb9nAQAsWDV/tHfz5s05efJkdu7cmZGRkXR2dubQoUOzN7UODw+nru6Vxmlvb89Xv/rV/O7v/m7e8Y53ZM2aNbn99ttzxx13nL9nAQAsWMtmZmZmSg/xesbHx9PS0pKxsbE0NzeXHgdgUenY8WjpESjs2O6bLsh5z/b3t++mAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCA
osQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUecUI/v27UtHR0eamprS09OTw4cPz7v2r/7qr7Js2bI5W1NT0zkPDAAsLjXHyMGDB1OpVLJr164cPXo069evz6ZNm3LixIl5j2lubs5//ud/zm4vvPDCGxoaAFg8ao6RPXv2ZNu2benr68vatWuzf//+rFixIgcOHJj3mGXLlqWtrW12a21tfUNDAwCLR00xMjU1lSNHjqS3t/eVE9TVpbe3N0NDQ/Me9/3vfz9vfetb097enl/4hV/I008//ZqPMzk5mfHx8TkbALA41RQjp06dyvT09KuubLS2tmZkZOSMx7z97W/PgQMH8nd/93f5whe+kGq1mve+97359re/Pe/jDAwMpKWlZXZrb2+vZUwAYAG54J+m2bhxY7Zs2ZLOzs7ccMMN+dKXvpQrrrgif/EXfzHvMf39/RkbG5vdjh8/fqHHBAAKWV7L4lWrVqW+vj6jo6Nz9o+Ojqatre2sznHRRRflne98Z775zW/Ou6axsTGNjY21jAYALFA1XRlpaGhIV1dXBgcHZ/dVq9UMDg5m48aNZ3WO6enpPPXUU1m9enVtkwIAi1JNV0aSpFKpZOvWrenu7s6GDRuyd+/eTExMpK+vL0myZcuWrFmzJgMDA0mSj3/843nPe96Ta665Jt/73vfyqU99Ki+88EI+/OEPn99nAgAsSDXHyObNm3Py5Mns3LkzIyMj6ezszKFDh2Zvah0eHk5d3SsXXL773e9m27ZtGRkZyWWXXZaurq488cQTWbt27fl7FgDAgrVsZmZmpvQQr2d8fDwtLS0ZGxtLc3Nz6XEAFpWOHY+WHoHCju2+6YKc92x/f/tuGgCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKCoc4qRffv2paOjI01NTenp6cnhw4fP6rhHHnkky5Yty80333wuDwsALEI1x8jBgwdTqVSya9euHD16NOvXr8+mTZty4sSJ1zzu2LFj+b3f+728//3vP+dhAYDFp+YY2bNnT7Zt25a+vr6sXbs2+/fvz4oVK3LgwIF5j5mens6v/uqv5u67787VV1/9hgYGABaXmmJkamoqR44cSW9v7ysnqKtLb29vhoaG5j3u4x//eK688srceuut5z4pALAoLa9l8alTpzI9PZ3W1tY5+1tbW/Pcc8+d8Zh/+qd/yl/+5V/mySefPOvHmZyczOTk5OzP4+PjtYwJACwgF/TTNKdPn86v/dqv5aGHHsqqVavO+riBgYG0tLTMbu3t7RdwSgCgpJqujKxatSr19fUZHR2ds390dDRtbW2vWv+tb30rx44dy4c+9KHZfdVq9QcPvHx5nn/++bztbW971XH9/f2pVCqzP4+PjwsSAFikaoqRhoaGdHV1ZXBwcPbjudVqNYODg9m+ffur1l933XV56qmn5uy78847c/r06dx///3zBkZjY2MaGxtrGQ0AWKBqipEkqVQq2bp1a7q7u7Nhw4bs3bs3ExMT6evrS5Js2bIla9asycDAQJqamrJu3bo5x69cuTJJXrUfAFiaao6RzZs35+TJk9m5c2dGRkbS2dmZQ4cOzd7UOjw8nLo6f9gVADg7y2ZmZmZKD/F6xsfH09LSkrGxsTQ3N5ceB2BR6djxaOkRKOzY7psuyHnP9ve3SxgAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLE
CABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJk37596ejoSFNTU3p6enL48OF5137pS19Kd3d3Vq5cmR/7sR9LZ2dnPv/5z5/zwADA4lJzjBw8eDCVSiW7du3K0aNHs379+mzatCknTpw44/q3vOUt+djHPpahoaH827/9W/r6+tLX15evfvWrb3h4AGDhWzYzMzNTywE9PT1597vfnQceeCBJUq1W097ent/6rd/Kjh07zuoc73rXu3LTTTflnnvuOav14+PjaWlpydjYWJqbm2sZF4DX0bHj0dIjUNix3TddkPOe7e/vmq6MTE1N5ciRI+nt7X3lBHV16e3tzdDQ0OsePzMzk8HBwTz//PP5mZ/5mVoeGgBYpJbXsvjUqVOZnp5Oa2vrnP2tra157rnn5j1ubGwsa9asyeTkZOrr6/Pnf/7n+eAHPzjv+snJyUxOTs7+PD4+XsuYAMACUlOMnKtLL700Tz75ZL7//e9ncHAwlUolV199dW688cYzrh8YGMjdd9/9ZowGABRWU4ysWrUq9fX1GR0dnbN/dHQ0bW1t8x5XV1eXa665JknS2dmZZ599NgMDA/PGSH9/fyqVyuzP4+PjaW9vr2VUAGCBqOmekYaGhnR1dWVwcHB2X7VazeDgYDZu3HjW56lWq3Pehvm/Ghsb09zcPGcDABanmt+mqVQq2bp1a7q7u7Nhw4bs3bs3ExMT6evrS5Js2bIla9asycDAQJIfvOXS3d2dt73tbZmcnMxjjz2Wz3/+8/nMZz5zfp8JALAg1RwjmzdvzsmTJ7Nz586MjIyks7Mzhw4dmr2pdXh4OHV1r1xwmZiYyEc/+tF8+9vfzsUXX5zrrrsuX/jCF7J58+bz9ywAgAWr5r8zUoK/MwJw4fg7IyyovzMCAHC+iREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRy0sPAEtdx45HS49AYcd231R6BCjKlREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABR1TjGyb9++dHR0pKmpKT09PTl8+PC8ax966KG8//3vz2WXXZbLLrssvb29r7keAFhaao6RgwcPplKpZNeuXTl69GjWr1+fTZs25cSJE2dc//jjj+eWW27JP/7jP2ZoaCjt7e352Z/92XznO995w8MDAAtfzTGyZ8+ebNu2LX19fVm7dm3279+fFStW5MCBA2dc/9d//df56Ec/ms7Ozlx33XX57Gc/m2q1msHBwTc8PACw8NUUI1NTUzly5Eh6e3tfOUFdXXp7ezM0NHRW53j55ZfzP//zP3nLW94y75rJycmMj4/P2QCAxammGDl16lSmp6fT2to6Z39ra2tGRkbO6hx33HFHrrrqqjlB838NDAykpaVldmtvb69lTABgAXlTP02ze/fuPPLII/nbv/3bNDU1zbuuv78/Y2Njs9vx48ffxCkBgDfT8loWr1q1KvX19RkdHZ2zf3R0NG1tba957L333pvdu3fnH/7hH/KOd7zjNdc2NjamsbGxltEAgAWqpisjDQ0N6erqmnPz6Q9vRt24ceO8x33yk5/MPffck0OHDqW7u/vcpwUAFp2arowkSaVSydatW9Pd3Z0NGzZk7969mZiYSF9fX5Jky5YtWbNmTQYGBpIkf/qnf5qdO3fm4YcfTkdHx+y9JZdcckkuueSS8/hUAICFqOYY2bx5
c06ePJmdO3dmZGQknZ2dOXTo0OxNrcPDw6mre+WCy2c+85lMTU3ll3/5l+ecZ9euXfmjP/qjNzY9ALDg1RwjSbJ9+/Zs3779jP/2+OOPz/n52LFj5/IQAMAS4btpAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1DnFyL59+9LR0ZGmpqb09PTk8OHD8659+umn80u/9Evp6OjIsmXLsnfv3nOdFQBYhGqOkYMHD6ZSqWTXrl05evRo1q9fn02bNuXEiRNnXP/yyy/n6quvzu7du9PW1vaGBwYAFpeaY2TPnj3Ztm1b+vr6snbt2uzfvz8rVqzIgQMHzrj+3e9+dz71qU/lV37lV9LY2PiGBwYAFpeaYmRqaipHjhxJb2/vKyeoq0tvb2+GhobO21CTk5MZHx+fswEAi1NNMXLq1KlMT0+ntbV1zv7W1taMjIyct6EGBgbS0tIyu7W3t5+3cwMAP1p+JD9N09/fn7Gxsdnt+PHjpUcCAC6Q5bUsXrVqVerr6zM6Ojpn/+jo6Hm9ObWxsdH9JQCwRNR0ZaShoSFdXV0ZHByc3VetVjM4OJiNGzee9+EAgMWvpisjSVKpVLJ169Z0d3dnw4YN2bt3byYmJtLX15ck2bJlS9asWZOBgYEkP7jp9Zlnnpn97+985zt58sknc8kll+Saa645j08FAFiIao6RzZs35+TJk9m5c2dGRkbS2dmZQ4cOzd7UOjw8nLq6Vy64vPjii3nnO985+/O9996be++9NzfccEMef/zxN/4MAIAFreYYSZLt27dn+/btZ/y3/xsYHR0dmZmZOZeHeVN07Hi09AgUdmz3TaVHAFjSfiQ/TQMALB1iBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJk37596ejoSFNTU3p6enL48OHXXP83f/M3ue6669LU1JSf+qmfymOPPXZOwwIAi0/NMXLw4MFUKpXs2rUrR48ezfr167Np06acOHHijOufeOKJ3HLLLbn11lvzr//6r7n55ptz880359///d/f8PAAwMJXc4zs2bMn27ZtS19fX9auXZv9+/dnxYoVOXDgwBnX33///fm5n/u5/P7v/36uv/763HPPPXnXu96VBx544A0PDwAsfMtrWTw1NZUjR46kv79/dl9dXV16e3szNDR0xmOGhoZSqVTm7Nu0aVO+/OUvz/s4k5OTmZycnP15bGwsSTI+Pl7LuGelOvnyeT8nC8uFeF3VwmsQr0FKu1CvwR+ed2Zm5jXX1RQjp06dyvT0dFpbW+fsb21tzXPPPXfGY0ZGRs64fmRkZN7HGRgYyN133/2q/e3t7bWMC2elZW/pCVjqvAYp7UK/Bk+fPp2WlpZ5/72mGHmz9Pf3z7maUq1W89JLL+Xyyy/PsmXLCk62+IyPj6e9vT3Hjx9Pc3Nz6XFYgrwGKc1r8MKZmZnJ6dOnc9VVV73muppiZNWqVamvr8/o6Oic/aOjo2lrazvjMW1tbTWtT5LGxsY0NjbO2bdy5cpaRqVGzc3N/iekKK9BSvMavDBe64rID9V0A2tDQ0O6uroyODg4
u69arWZwcDAbN2484zEbN26csz5J/v7v/37e9QDA0lLz2zSVSiVbt25Nd3d3NmzYkL1792ZiYiJ9fX1Jki1btmTNmjUZGBhIktx+++254YYbct999+Wmm27KI488kq997Wt58MEHz+8zAQAWpJpjZPPmzTl58mR27tyZkZGRdHZ25tChQ7M3qQ4PD6eu7pULLu9973vz8MMP584778wf/uEf5tprr82Xv/zlrFu37vw9C85ZY2Njdu3a9aq3xeDN4jVIaV6D5S2beb3P2wAAXEC+mwYAKEqMAABFiREAoCgxAhRx44035nd+53dKjwH8CBAjAEBRYgQAKEqMLFGHDh3KT//0T2flypW5/PLL8/M///P51re+VXoslpj//d//zfbt29PS0pJVq1blrrvuet1v94TzqVqt5pOf/GSuueaaNDY25id+4ifyJ3/yJ6XHWnLEyBI1MTGRSqWSr33taxkcHExdXV1+8Rd/MdVqtfRoLCGf+9znsnz58hw+fDj3339/9uzZk89+9rOlx2IJ6e/vz+7du3PXXXflmWeeycMPP/yqb5rnwvNHz0iSnDp1KldccUWeeuopfx2XN8WNN96YEydO5Omnn579Nu4dO3bkK1/5Sp555pnC07EUnD59OldccUUeeOCBfPjDHy49zpLmysgS9Y1vfCO33HJLrr766jQ3N6ejoyPJD/6cP7xZ3vOe98yGSPKDL9b8xje+kenp6YJTsVQ8++yzmZyczAc+8IHSoyx5NX83DYvDhz70obz1rW/NQw89lKuuuirVajXr1q3L1NRU6dEA3hQXX3xx6RH4/1wZWYL+67/+K88//3zuvPPOfOADH8j111+f7373u6XHYgn6l3/5lzk///M//3Ouvfba1NfXF5qIpeTaa6/NxRdfnMHBwdKjLHmujCxBl112WS6//PI8+OCDWb16dYaHh7Njx47SY7EEDQ8Pp1Kp5Dd+4zdy9OjR/Nmf/Vnuu+++0mOxRDQ1NeWOO+7IH/zBH6ShoSHve9/7cvLkyTz99NO59dZbS4+3pIiRJaiuri6PPPJIfvu3fzvr1q3L29/+9nz605/OjTfeWHo0lpgtW7bkv//7v7Nhw4bU19fn9ttvz6//+q+XHosl5K677sry5cuzc+fOvPjii1m9enV+8zd/s/RYS45P0wAARblnBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAU9f8A5CVuX60XFPsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# estimating P(A)    \n",
    "P_A = estimate_arm_probs(control_probs, env)\n",
    "\n",
    "plot_PA(P_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# interventions here (if only reordering the arms) will only shift the probabilities of the arms\n",
    "from copy import deepcopy\n",
    "\n",
    "env_ = deepcopy(env)\n",
    "env_.intervene(\n",
    "    {\n",
    "        0: torch.distributions.Categorical(torch.tensor([0, 0, 1], dtype=torch.float32)),\n",
    "        1: torch.distributions.Categorical(torch.tensor([0, 1, 0], dtype=torch.float32)),\n",
    "        2: torch.distributions.Categorical(torch.tensor([1, 0, 0], dtype=torch.float32)),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.0900), tensor(0.2447), tensor(0.6652)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAajUlEQVR4nO3df2xd913/8Vfs1HZDa6drWjsNZlbXsjYKizd78bIxWmkeQSoTRSCFCpHI6gJsDRSugMasTegKc9jaKB0NC+2ImDaqRkxjTGqVCSz6B6ohWkJF6a/9QKmzFTsJ3ezMRTb4+vvHvnNlGre5adLPbD8e0pHqk885932lW/mpc8/1XTYzMzMTAIBC6koPAAAsbWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKWl56gLNRrVbz4osv5tJLL82yZctKjwMAnIWZmZmcPn06V111Verq5r/+sSBi5MUXX0x7e3vpMQCAc3D8+PH8+I//+Lz/viBi5NJLL03ygyfT3NxceBoA4GyMj4+nvb199vf4fBZEjPzwrZnm5mYxAgALzOvdYuEGVgCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUctLDwBAWR07Hi09AoUd231T0cd3ZQQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1DnFyL59+9LR0ZGmpqb09PTk8OHDr7n+e9/7Xm677basXr06jY2N+cmf/Mk89thj5zQwALC4LK/1gIMHD6ZSqWT//v3p6enJ3r17s2nTpjz//PO58sorX7V+amoqH/zgB3PllVfmi1/8YtasWZMXXnghK1euPB/zAwALXM0xsmfPnmzbti19fX1Jkv379+fRRx/NgQMHsmPHjletP3DgQF566aU88cQTueiii5IkHR0db2xqAGDRqOltmqmpqRw5ciS9vb2vnKCuLr29vRkaGjrjMV/5yleycePG3HbbbWltbc26devyiU98ItPT0/M+zuTkZMbHx+dsAMDiVFOMnDp1KtPT02ltbZ2zv7W1NSMjI2c85j/+4z/yxS9+MdPT03nsscdy11135b777ssf//Efz/s4AwMDaWlpmd3a29trGRMAWEAu+KdpqtVqrrzyyjz44IPp6urK5s2b87GPfSz79++f95j+/v6MjY3NbsePH7/QYwIAhdR0z8iqVatSX1+f0dHROftHR0fT1tZ2xmNWr16diy66KPX19bP7rr/++oyMjGRqaioNDQ2vOqaxsTGNjY21jAYALFA1XRlpaGhIV1dXBgcHZ/dVq9UMDg5m48aNZzzmfe97X775zW+mWq3O7vv617+e1atXnzFEAIClpea3aSqVSh566KF87nOfy7PPPpuPfOQjmZiYmP10zZYtW9Lf3z+7/iMf+Uheeuml3H777fn617+eRx99NJ/4xCdy2223nb9nAQAsWDV/tHfz5s05efJkdu7cmZGRkXR2dubQoUOzN7UODw+nru6Vxmlvb89Xv/rV/O7v/m7e8Y53ZM2aNbn99ttzxx13nL9nAQAsWMtmZmZmSg/xesbHx9PS0pKxsbE0NzeXHgdgUenY8WjpESjs2O6bLsh5z/b3t++mAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCA
osQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUecUI/v27UtHR0eamprS09OTw4cPz7v2r/7qr7Js2bI5W1NT0zkPDAAsLjXHyMGDB1OpVLJr164cPXo069evz6ZNm3LixIl5j2lubs5//ud/zm4vvPDCGxoaAFg8ao6RPXv2ZNu2benr68vatWuzf//+rFixIgcOHJj3mGXLlqWtrW12a21tfUNDAwCLR00xMjU1lSNHjqS3t/eVE9TVpbe3N0NDQ/Me9/3vfz9vfetb097enl/4hV/I008//ZqPMzk5mfHx8TkbALA41RQjp06dyvT09KuubLS2tmZkZOSMx7z97W/PgQMH8nd/93f5whe+kGq1mve+97359re/Pe/jDAwMpKWlZXZrb2+vZUwAYAG54J+m2bhxY7Zs2ZLOzs7ccMMN+dKXvpQrrrgif/EXfzHvMf39/RkbG5vdjh8/fqHHBAAKWV7L4lWrVqW+vj6jo6Nz9o+Ojqatre2sznHRRRflne98Z775zW/Ou6axsTGNjY21jAYALFA1XRlpaGhIV1dXBgcHZ/dVq9UMDg5m48aNZ3WO6enpPPXUU1m9enVtkwIAi1JNV0aSpFKpZOvWrenu7s6GDRuyd+/eTExMpK+vL0myZcuWrFmzJgMDA0mSj3/843nPe96Ta665Jt/73vfyqU99Ki+88EI+/OEPn99nAgAsSDXHyObNm3Py5Mns3LkzIyMj6ezszKFDh2Zvah0eHk5d3SsXXL773e9m27ZtGRkZyWWXXZaurq488cQTWbt27fl7FgDAgrVsZmZmpvQQr2d8fDwtLS0ZGxtLc3Nz6XEAFpWOHY+WHoHCju2+6YKc92x/f/tuGgCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKCoc4qRffv2paOjI01NTenp6cnhw4fP6rhHHnkky5Yty80333wuDwsALEI1x8jBgwdTqVSya9euHD16NOvXr8+mTZty4sSJ1zzu2LFj+b3f+728//3vP+dhAYDFp+YY2bNnT7Zt25a+vr6sXbs2+/fvz4oVK3LgwIF5j5mens6v/uqv5u67787VV1/9hgYGABaXmmJkamoqR44cSW9v7ysnqKtLb29vhoaG5j3u4x//eK688srceuut5z4pALAoLa9l8alTpzI9PZ3W1tY5+1tbW/Pcc8+d8Zh/+qd/yl/+5V/mySefPOvHmZyczOTk5OzP4+PjtYwJACwgF/TTNKdPn86v/dqv5aGHHsqqVavO+riBgYG0tLTMbu3t7RdwSgCgpJqujKxatSr19fUZHR2ds390dDRtbW2vWv+tb30rx44dy4c+9KHZfdVq9QcPvHx5nn/++bztbW971XH9/f2pVCqzP4+PjwsSAFikaoqRhoaGdHV1ZXBwcPbjudVqNYODg9m+ffur1l933XV56qmn5uy78847c/r06dx///3zBkZjY2MaGxtrGQ0AWKBqipEkqVQq2bp1a7q7u7Nhw4bs3bs3ExMT6evrS5Js2bIla9asycDAQJqamrJu3bo5x69cuTJJXrUfAFiaao6RzZs35+TJk9m5c2dGRkbS2dmZQ4cOzd7UOjw8nLo6f9gVADg7y2ZmZmZKD/F6xsfH09LSkrGxsTQ3N5ceB2BR6djxaOkRKOzY7psuyHnP9ve3SxgAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLE
CABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJk37596ejoSFNTU3p6enL48OF5137pS19Kd3d3Vq5cmR/7sR9LZ2dnPv/5z5/zwADA4lJzjBw8eDCVSiW7du3K0aNHs379+mzatCknTpw44/q3vOUt+djHPpahoaH827/9W/r6+tLX15evfvWrb3h4AGDhWzYzMzNTywE9PT1597vfnQceeCBJUq1W097ent/6rd/Kjh07zuoc73rXu3LTTTflnnvuOav14+PjaWlpydjYWJqbm2sZF4DX0bHj0dIjUNix3TddkPOe7e/vmq6MTE1N5ciRI+nt7X3lBHV16e3tzdDQ0OsePzMzk8HBwTz//PP5mZ/5mVoeGgBYpJbXsvjUqVOZnp5Oa2vrnP2tra157rnn5j1ubGwsa9asyeTkZOrr6/Pnf/7n+eAHPzjv+snJyUxOTs7+PD4+XsuYAMACUlOMnKtLL700Tz75ZL7//e9ncHAwlUolV199dW688cYzrh8YGMjdd9/9ZowGABRWU4ysWrUq9fX1GR0dnbN/dHQ0bW1t8x5XV1eXa665JknS2dmZZ599NgMDA/PGSH9/fyqVyuzP4+PjaW9vr2VUAGCBqOmekYaGhnR1dWVwcHB2X7VazeDgYDZu3HjW56lWq3Pehvm/Ghsb09zcPGcDABanmt+mqVQq2bp1a7q7u7Nhw4bs3bs3ExMT6evrS5Js2bIla9asycDAQJIfvOXS3d2dt73tbZmcnMxjjz2Wz3/+8/nMZz5zfp8JALAg1RwjmzdvzsmTJ7Nz586MjIyks7Mzhw4dmr2pdXh4OHV1r1xwmZiYyEc/+tF8+9vfzsUXX5zrrrsuX/jCF7J58+bz9ywAgAWr5r8zUoK/MwJw4fg7IyyovzMCAHC+iREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRy0sPAEtdx45HS49AYcd231R6BCjKlREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABR1TjGyb9++dHR0pKmpKT09PTl8+PC8ax966KG8//3vz2WXXZbLLrssvb29r7keAFhaao6RgwcPplKpZNeuXTl69GjWr1+fTZs25cSJE2dc//jjj+eWW27JP/7jP2ZoaCjt7e352Z/92XznO995w8MDAAtfzTGyZ8+ebNu2LX19fVm7dm3279+fFStW5MCBA2dc/9d//df56Ec/ms7Ozlx33XX57Gc/m2q1msHBwTc8PACw8NUUI1NTUzly5Eh6e3tfOUFdXXp7ezM0NHRW53j55ZfzP//zP3nLW94y75rJycmMj4/P2QCAxammGDl16lSmp6fT2to6Z39ra2tGRkbO6hx33HFHrrrqqjlB838NDAykpaVldmtvb69lTABgAXlTP02ze/fuPPLII/nbv/3bNDU1zbuuv78/Y2Njs9vx48ffxCkBgDfT8loWr1q1KvX19RkdHZ2zf3R0NG1tba957L333pvdu3fnH/7hH/KOd7zjNdc2NjamsbGxltEAgAWqpisjDQ0N6erqmnPz6Q9vRt24ceO8x33yk5/MPffck0OHDqW7u/vcpwUAFp2arowkSaVSydatW9Pd3Z0NGzZk7969mZiYSF9fX5Jky5YtWbNmTQYGBpIkf/qnf5qdO3fm4YcfTkdHx+y9JZdcckkuueSS8/hUAICFqOYY2bx5
c06ePJmdO3dmZGQknZ2dOXTo0OxNrcPDw6mre+WCy2c+85lMTU3ll3/5l+ecZ9euXfmjP/qjNzY9ALDg1RwjSbJ9+/Zs3779jP/2+OOPz/n52LFj5/IQAMAS4btpAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1DnFyL59+9LR0ZGmpqb09PTk8OHD8659+umn80u/9Evp6OjIsmXLsnfv3nOdFQBYhGqOkYMHD6ZSqWTXrl05evRo1q9fn02bNuXEiRNnXP/yyy/n6quvzu7du9PW1vaGBwYAFpeaY2TPnj3Ztm1b+vr6snbt2uzfvz8rVqzIgQMHzrj+3e9+dz71qU/lV37lV9LY2PiGBwYAFpeaYmRqaipHjhxJb2/vKyeoq0tvb2+GhobO21CTk5MZHx+fswEAi1NNMXLq1KlMT0+ntbV1zv7W1taMjIyct6EGBgbS0tIyu7W3t5+3cwMAP1p+JD9N09/fn7Gxsdnt+PHjpUcCAC6Q5bUsXrVqVerr6zM6Ojpn/+jo6Hm9ObWxsdH9JQCwRNR0ZaShoSFdXV0ZHByc3VetVjM4OJiNGzee9+EAgMWvpisjSVKpVLJ169Z0d3dnw4YN2bt3byYmJtLX15ck2bJlS9asWZOBgYEkP7jp9Zlnnpn97+985zt58sknc8kll+Saa645j08FAFiIao6RzZs35+TJk9m5c2dGRkbS2dmZQ4cOzd7UOjw8nLq6Vy64vPjii3nnO985+/O9996be++9NzfccEMef/zxN/4MAIAFreYYSZLt27dn+/btZ/y3/xsYHR0dmZmZOZeHeVN07Hi09AgUdmz3TaVHAFjSfiQ/TQMALB1iBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJk37596ejoSFNTU3p6enL48OHXXP83f/M3ue6669LU1JSf+qmfymOPPXZOwwIAi0/NMXLw4MFUKpXs2rUrR48ezfr167Np06acOHHijOufeOKJ3HLLLbn11lvzr//6r7n55ptz880359///d/f8PAAwMJXc4zs2bMn27ZtS19fX9auXZv9+/dnxYoVOXDgwBnX33///fm5n/u5/P7v/36uv/763HPPPXnXu96VBx544A0PDwAsfMtrWTw1NZUjR46kv79/dl9dXV16e3szNDR0xmOGhoZSqVTm7Nu0aVO+/OUvz/s4k5OTmZycnP15bGwsSTI+Pl7LuGelOvnyeT8nC8uFeF3VwmsQr0FKu1CvwR+ed2Zm5jXX1RQjp06dyvT0dFpbW+fsb21tzXPPPXfGY0ZGRs64fmRkZN7HGRgYyN133/2q/e3t7bWMC2elZW/pCVjqvAYp7UK/Bk+fPp2WlpZ5/72mGHmz9Pf3z7maUq1W89JLL+Xyyy/PsmXLCk62+IyPj6e9vT3Hjx9Pc3Nz6XFYgrwGKc1r8MKZmZnJ6dOnc9VVV73muppiZNWqVamvr8/o6Oic/aOjo2lrazvjMW1tbTWtT5LGxsY0NjbO2bdy5cpaRqVGzc3N/iekKK9BSvMavDBe64rID9V0A2tDQ0O6uroyODg4
u69arWZwcDAbN2484zEbN26csz5J/v7v/37e9QDA0lLz2zSVSiVbt25Nd3d3NmzYkL1792ZiYiJ9fX1Jki1btmTNmjUZGBhIktx+++254YYbct999+Wmm27KI488kq997Wt58MEHz+8zAQAWpJpjZPPmzTl58mR27tyZkZGRdHZ25tChQ7M3qQ4PD6eu7pULLu9973vz8MMP584778wf/uEf5tprr82Xv/zlrFu37vw9C85ZY2Njdu3a9aq3xeDN4jVIaV6D5S2beb3P2wAAXEC+mwYAKEqMAABFiREAoCgxAhRx44035nd+53dKjwH8CBAjAEBRYgQAKEqMLFGHDh3KT//0T2flypW5/PLL8/M///P51re+VXoslpj//d//zfbt29PS0pJVq1blrrvuet1v94TzqVqt5pOf/GSuueaaNDY25id+4ifyJ3/yJ6XHWnLEyBI1MTGRSqWSr33taxkcHExdXV1+8Rd/MdVqtfRoLCGf+9znsnz58hw+fDj3339/9uzZk89+9rOlx2IJ6e/vz+7du3PXXXflmWeeycMPP/yqb5rnwvNHz0iSnDp1KldccUWeeuopfx2XN8WNN96YEydO5Omnn579Nu4dO3bkK1/5Sp555pnC07EUnD59OldccUUeeOCBfPjDHy49zpLmysgS9Y1vfCO33HJLrr766jQ3N6ejoyPJD/6cP7xZ3vOe98yGSPKDL9b8xje+kenp6YJTsVQ8++yzmZyczAc+8IHSoyx5NX83DYvDhz70obz1rW/NQw89lKuuuirVajXr1q3L1NRU6dEA3hQXX3xx6RH4/1wZWYL+67/+K88//3zuvPPOfOADH8j111+f7373u6XHYgn6l3/5lzk///M//3Ouvfba1NfXF5qIpeTaa6/NxRdfnMHBwdKjLHmujCxBl112WS6//PI8+OCDWb16dYaHh7Njx47SY7EEDQ8Pp1Kp5Dd+4zdy9OjR/Nmf/Vnuu+++0mOxRDQ1NeWOO+7IH/zBH6ShoSHve9/7cvLkyTz99NO59dZbS4+3pIiRJaiuri6PPPJIfvu3fzvr1q3L29/+9nz605/OjTfeWHo0lpgtW7bkv//7v7Nhw4bU19fn9ttvz6//+q+XHosl5K677sry5cuzc+fOvPjii1m9enV+8zd/s/RYS45P0wAARblnBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAU9f8A5CVuX60XFPsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# step 1\n",
    "P_A = estimate_arm_probs(control_probs, env)\n",
    "plot_PA(P_A)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "tensor(1) tensor([ 0.9892,  0.6118, -1.1675])\n",
      "tensor(2) tensor([ 0.2595,  0.1317, -0.7224])\n"
     ]
    }
   ],
   "source": [
    "# step 2 \n",
    "\n",
    "# abduction on noise of the composition function P_arms\n",
    "arm_id = env_.arm_keys.index(arm)\n",
    "print(arm_id)\n",
    "_, arm_g = gumbel_max_rejection_sampling(\n",
    "    torch.tensor(list(P_A.values())), arm_id, max_iterations=100000\n",
    ") \n",
    "# noise that gets us abstraction 'c' given the control '2' and the original environment (x)\n",
    "# this is the noise for the 'combined mechanism' here "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.6652), tensor(0.2447), tensor(0.0900)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAalElEQVR4nO3df2xd913/8Vfs1HZDa6erWycNZlbXsjYKi7dk8bwxWmkeQSoTRSCZCuHI6gJsDRSugMasdegKc9jaKB0NC+2ImDaqRkxjTGqVCSz6B6ohWkJF6a/9QGmyFTsJ3ewsRTb4+vvHvnNlGre5adLPbD8e0pHqk885932lW/mpc8/1XTYzMzMTAIBC6koPAAAsbWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKWl56gLNRrVbz4osv5tJLL82yZctKjwMAnIWZmZmcOnUqV111Verq5r/+sSBi5MUXX0x7e3vpMQCAc3Ds2LH8+I//+Lz/viBi5NJLL03ygyfT3NxceBoA4GxMTEykvb199vf4fBZEjPzwrZnm5mYxAgALzOvdYuEGVgCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUctLD1Bax/ZHS49AYUd23lR6BIAlzZURAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFHnFCN79uxJR0dHmpqa0tXVlYMHD77m+u9973u57bbbsnr16jQ2NuYnf/In89hjj53TwADA4rK81gP279+fSqWSvXv3pqurK7t3787mzZvz/PPP58orr3zV+qmpqXzwgx/MlVdemS9+8YtZs2ZNXnjhhaxcufJ8zA8ALHA1x8iuXbuydevW9Pf3J0n27t2bRx99NPv27cv27dtftX7fvn156aWX8sQTT+Siiy5KknR0dLyxqQGARaOmt2mmpqZy6NCh9PT0vHKCurr09PRkZGTkjMd85StfSXd3d2677ba0tbVl3bp1+cQnPpHp6el5H2dycjITExNzNgBgcaopRk6ePJnp6em0tbXN2d/W1pbR0dEzHvMf//Ef+eIXv5jp6ek89thjueuuu3Lfffflj//4j+d9nKGhobS0tMxu7e3ttYwJACwgF/zTNNVqNVdeeWUefPDBbNiwIb29vfnYxz6WvXv3znvMwMBAxsfHZ7djx45d6DEBgEJqumektbU19fX1GRsbm7N/bGwsq1atOuMxq1evzkUXXZT6+vrZfddff31GR0czNTWVhoaGVx3T2NiYxsbGWkYDABaomq6MNDQ0ZMOGDRkeHp7dV61WMzw8nO7u7jMe8773vS/f/OY3U61WZ/d9/etfz+rVq88YIgDA0lLz2zSVSiUPPfRQPve5z+XZZ5/NRz7ykZw+fXr20zV9fX0ZGBiYXf+Rj3wkL730Um6//fZ8/etfz6OPPppPfOITue22287fswAAFqyaP9rb29ubEydOZHBwMKOjo+ns7MyBAwdmb2o9evRo6upeaZz29vZ89atfze/+7u/mHe94R9asWZPbb789d9xxx/l7FgDAgrVsZmZmpvQQr2diYiItLS0ZHx9Pc3PzeT13x/ZHz+v5WHiO7Lyp9AgAi9LZ/v723TQAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQI
AFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJkz5496ejoSFNTU7q6unLw4MF51/7VX/1Vli1bNmdramo654EBgMWl5hjZv39/KpVKduzYkcOHD2f9+vXZvHlzjh8/Pu8xzc3N+c///M/Z7YUXXnhDQwMAi0fNMbJr165s3bo1/f39Wbt2bfbu3ZsVK1Zk37598x6zbNmyrFq1anZra2t7Q0MDAItHTTEyNTWVQ4cOpaen55UT1NWlp6cnIyMj8x73/e9/P29961vT3t6eX/iFX8jTTz/9mo8zOTmZiYmJORsAsDjVFCMnT57M9PT0q65stLW1ZXR09IzHvP3tb8++ffvyd3/3d/nCF76QarWa9773vfn2t7897+MMDQ2lpaVldmtvb69lTABgAbngn6bp7u5OX19fOjs7c8MNN+RLX/pSrrjiivzFX/zFvMcMDAxkfHx8djt27NiFHhMAKGR5LYtbW1tTX1+fsbGxOfvHxsayatWqszrHRRddlHe+85355je/Oe+axsbGNDY21jIaALBA1XRlpKGhIRs2bMjw8PDsvmq1muHh4XR3d5/VOaanp/PUU09l9erVtU0KACxKNV0ZSZJKpZItW7Zk48aN2bRpU3bv3p3Tp0+nv78/SdLX15c1a9ZkaGgoSfLxj38873nPe3LNNdfke9/7Xj71qU/lhRdeyIc//OHz+0wAgAWp5hjp7e3NiRMnMjg4mNHR0XR2dubAgQOzN7UePXo0dXWvXHD57ne/m61bt2Z0dDSXXXZZNmzYkCeeeCJr1649f88CAFiwls3MzMyUHuL1TExMpKWlJePj42lubj6v5+7Y/uh5PR8Lz5GdN5UeAWBROtvf376bBgAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEACjqnGJkz5496ejoSFNTU7q6unLw4MGzOu6RRx7JsmXLcvPNN5/LwwIAi1DNMbJ///5UKpXs2LEjhw8fzvr167N58+YcP378NY87cuRIfu/3fi/vf//7z3lYAGDxqTlGdu3ala1bt6a/vz9r167N3r17s2LFiuzbt2/eY6anp/Orv/qrufvuu3P11Ve/oYEBgMWlphiZmprKoUOH0tPT88oJ6urS09OTkZGReY/7+Mc/niuvvDK33nrruU8KACxKy2tZfPLkyUxPT6etrW3O/ra2tjz33HNnPOaf/umf8pd/+Zd58sknz/pxJicnMzk5OfvzxMRELWMCAAvIBf00zalTp/Jrv/Zreeihh9La2nrWxw0NDaWlpWV2a29vv4BTAgAl1XRlpLW1NfX19RkbG5uzf2xsLKtWrXrV+m9961s5cuRIPvShD83uq1arP3jg5cvz/PPP521ve9urjhsYGEilUpn9eWJiQpAAwCJVU4w0NDRkw4YNGR4env14brVazfDwcLZt2/aq9dddd12eeuqpOfvuvPPOnDp1Kvfff/+8gdHY2JjGxsZaRgMAFqiaYiRJKpVKtmzZko0bN2bTpk3ZvXt3Tp8+nf7+/iRJX19f1qxZk6GhoTQ1NWXdunVzjl+5cmWSvGo/ALA01Rwjvb29OXHiRAYHBzM6OprOzs4cOHBg9qbWo0ePpq7OH3YFAM7OspmZmZnSQ7yeiYmJtLS0ZHx8PM3Nzef13B3bHz2v52PhObLzptIjACxKZ/v72yUMAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIj
AEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUdU4xsmfPnnR0dKSpqSldXV05ePDgvGu/9KUvZePGjVm5cmV+7Md+LJ2dnfn85z9/zgMDAItLzTGyf//+VCqV7NixI4cPH8769euzefPmHD9+/Izr3/KWt+RjH/tYRkZG8m//9m/p7+9Pf39/vvrVr77h4QGAhW/ZzMzMTC0HdHV15d3vfnceeOCBJEm1Wk17e3t+67d+K9u3bz+rc7zrXe/KTTfdlHvuuees1k9MTKSlpSXj4+Npbm6uZdzX1bH90fN6PhaeIztvKj0CwKJ0tr+/a7oyMjU1lUOHDqWnp+eVE9TVpaenJyMjI697/MzMTIaHh/P888/nZ37mZ2p5aABgkVpey+KTJ09meno6bW1tc/a3tbXlueeem/e48fHxrFmzJpOTk6mvr8+f//mf54Mf/OC86ycnJzM5OTn788TERC1jAgALSE0xcq4uvfTSPPnkk/n+97+f4eHhVCqVXH311bnxxhvPuH5oaCh33333mzEaAFBYTTHS2tqa+vr6jI2Nzdk/NjaWVatWzXtcXV1drrnmmiRJZ2dnnn322QwNDc0bIwMDA6lUKrM/T0xMpL29vZZRAYAFoqZ7RhoaGrJhw4YMDw/P7qtWqxkeHk53d/dZn6darc55G+b/amxsTHNz85wNAFican6bplKpZMuWLdm4cWM2bdqU3bt35/Tp0+nv70+S9PX1Zc2aNRkaGkryg7dcNm7cmLe97W2ZnJzMY489ls9//vP5zGc+c36fCQCwINUcI729vTlx4kQGBwczOjqazs7OHDhwYPam1qNHj6au7pULLqdPn85HP/rRfPvb387FF1+c6667Ll/4whfS29t7/p4FALBg1fx3Rkrwd0a4kPydEYAL44L8nREAgPNNjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIpaXnoAWOo6tj9aegQKO7LzptIjQFGujAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoKhzipE9e/ako6MjTU1N6erqysGDB+dd+9BDD+X9739/Lrvsslx22WXp6el5zfUAwNJSc4zs378/lUolO3bsyOHDh7N+/fps3rw5x48fP+P6xx9/PLfcckv+8R//MSMjI2lvb8/P/uzP5jvf+c4bHh4AWPhqjpFdu3Zl69at6e/vz9q1a7N3796sWLEi+/btO+P6v/7rv85HP/rRdHZ25rrrrstnP/vZVKvVDA8Pv+HhAYCFr6YYmZqayqFDh9LT0/PKCerq0tPTk5GRkbM6x8svv5z/+Z//yVve8pZ510xOTmZiYmLOBgAsTjXFyMmTJzM9PZ22trY5+9va2jI6OnpW57jjjjty1VVXzQma/2toaCgtLS2zW3t7ey1jAgALyJv6aZqdO3fmkUceyd/+7d+mqalp3nUDAwMZHx+f3Y4dO/YmTgkAvJmW17K4tbU19fX1GRsbm7N/bGwsq1ates1j77333uzcuTP/8A//kHe84x2vubaxsTGNjY21jAYALFA1XRlpaGjIhg0b5tx8+sObUbu7u+c97pOf/GTuueeeHDhwIBs3bjz3aQGARaemKyNJUqlUsmXLlmzcuDGbNm3K7t27c/r06fT39ydJ+vr6smbNmgwNDSVJ/vRP/zSDg4N5+OGH09HRMXtvySWXXJJLLrnkPD4V
AGAhqjlGent7c+LEiQwODmZ0dDSdnZ05cODA7E2tR48eTV3dKxdcPvOZz2Rqaiq//Mu/POc8O3bsyB/90R+9sekBgAWv5hhJkm3btmXbtm1n/LfHH398zs9Hjhw5l4cAAJYI300DABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICizilG9uzZk46OjjQ1NaWrqysHDx6cd+3TTz+dX/qlX0pHR0eWLVuW3bt3n+usAMAiVHOM7N+/P5VKJTt27Mjhw4ezfv36bN68OcePHz/j+pdffjlXX311du7cmVWrVr3hgQGAxaXmGNm1a1e2bt2a/v7+rF27Nnv37s2KFSuyb9++M65/97vfnU996lP5lV/5lTQ2Nr7hgQGAxaWmGJmamsqhQ4fS09Pzygnq6tLT05ORkZHzNtTk5GQmJibmbADA4lRTjJw8eTLT09Npa2ubs7+trS2jo6PnbaihoaG0tLTMbu3t7eft3ADAj5YfyU/TDAwMZHx8fHY7duxY6ZEAgAtkeS2LW1tbU19fn7GxsTn7x8bGzuvNqY2Nje4vAYAloqYrIw0NDdmwYUOGh4dn91Wr1QwPD6e7u/u8DwcALH41XRlJkkqlki1btmTjxo3ZtGlTdu/endOnT6e/vz9J0tfXlzVr1mRoaCjJD256feaZZ2b/+zvf+U6efPLJXHLJJbnmmmvO41MBABaimmOkt7c3J06cyODgYEZHR9PZ2ZkDBw7M3tR69OjR1NW9csHlxRdfzDvf+c7Zn++9997ce++9ueGGG/L444+/8WcAACxoNcdIkmzbti3btm0747/938Do6OjIzMzMuTwMAG+Cju2Plh6Bwo7svKno4/9IfpoGAFg6xAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1DnFyJ49e9LR0ZGmpqZ0dXXl4MGDr7n+b/7mb3LdddelqakpP/VTP5XHHnvsnIYFABafmmNk//79qVQq2bFjRw4fPpz169dn8+bNOX78+BnXP/HEE7nlllty66235l//9V9z88035+abb86///u/v+HhAYCFr+YY2bVrV7Zu3Zr+/v6sXbs2e/fuzYoVK7Jv374zrr///vvzcz/3c/n93//9XH/99bnnnnvyrne9Kw888MAbHh4AWPiW17J4amoqhw4dysDAwOy+urq69PT0ZGRk5IzHjIyMpFKpzNm3efPmfPnLX573cSYnJzM5OTn78/j4eJJkYmKilnHPSnXy5fN+ThaWC/G6qoXXIF6DlHahXoM/PO/MzMxrrqspRk6ePJnp6em0tbXN2d/W1pbnnnvujMeMjo6ecf3o6Oi8jzM0NJS77777Vfvb29trGRfOSsvu0hOw1HkNUtqFfg2eOnUqLS0t8/57TTHyZhkYGJhzNaVareall17K5ZdfnmXLlhWcbPGZmJhIe3t7jh07lubm5tLjsAR5DVKa1+CFMzMzk1OnTuWqq656zXU1xUhra2vq6+szNjY2Z//Y2FhWrVp1xmNWrVpV0/okaWxsTGNj45x9K1eurGVUatTc3Ox/QoryGqQ0r8EL47WuiPxQTTewNjQ0
ZMOGDRkeHp7dV61WMzw8nO7u7jMe093dPWd9kvz93//9vOsBgKWl5rdpKpVKtmzZko0bN2bTpk3ZvXt3Tp8+nf7+/iRJX19f1qxZk6GhoSTJ7bffnhtuuCH33XdfbrrppjzyyCP52te+lgcffPD8PhMAYEGqOUZ6e3tz4sSJDA4OZnR0NJ2dnTlw4MDsTapHjx5NXd0rF1ze+9735uGHH86dd96ZP/zDP8y1116bL3/5y1m3bt35exacs8bGxuzYseNVb4vBm8VrkNK8BstbNvN6n7cBALiAfDcNAFCUGAEAihIjAEBRYgQo4sYbb8zv/M7vlB4D+BEgRgCAosQIAFCUGFmiDhw4kJ/+6Z/OypUrc/nll+fnf/7n861vfav0WCwx//u//5tt27alpaUlra2tueuuu1732z3hfKpWq/nkJz+Za665Jo2NjfmJn/iJ/Mmf/EnpsZYcMbJEnT59OpVKJV/72tcyPDycurq6/OIv/mKq1Wrp0VhCPve5z2X58uU5ePBg7r///uzatSuf/exnS4/FEjIwMJCdO3fmrrvuyjPPPJOHH374Vd80z4Xnj56RJDl58mSuuOKKPPXUU/46Lm+KG2+8McePH8/TTz89+23c27dvz1e+8pU888wzhadjKTh16lSuuOKKPPDAA/nwhz9cepwlzZWRJeob3/hGbrnlllx99dVpbm5OR0dHkh/8OX94s7znPe+ZDZHkB1+s+Y1vfCPT09MFp2KpePbZZzM5OZkPfOADpUdZ8mr+bhoWhw996EN561vfmoceeihXXXVVqtVq1q1bl6mpqdKjAbwpLr744tIj8P+5MrIE/dd//Veef/753HnnnfnABz6Q66+/Pt/97ndLj8US9C//8i9zfv7nf/7nXHvttamvry80EUvJtddem4svvjjDw8OlR1nyXBlZgi677LJcfvnlefDBB7N69eocPXo027dvLz0WS9DRo0dTqVTyG7/xGzl8+HD+7M/+LPfdd1/psVgimpqacscdd+QP/uAP0tDQkPe97305ceJEnn766dx6662lx1tSxMgSVFdXl0ceeSS//du/nXXr1uXtb397Pv3pT+fGG28sPRpLTF9fX/77v/87mzZtSn19fW6//fb8+q//eumxWELuuuuuLF++PIODg3nxxRezevXq/OZv/mbpsZYcn6YBAIpyzwgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKOr/AWCZbl9m4kKQAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "cf_logits = deepcopy(logits)    # simplified case, the policy outputs the same logits \n",
    "cf_control_probs = torch.softmax(cf_logits, dim=0)\n",
    "\n",
    "# in this counterfactual state, the distribution over abstractions (arms) is different \n",
    "# because the abstraction function maps the controls differently\n",
    "P_A_cf = estimate_arm_probs(cf_control_probs, env_)\n",
    "plot_PA(P_A_cf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'a'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compute the counterfactual value of the arm using the original noise term and P_A_cf\n",
    "\n",
    "cf_arm_probs = torch.tensor(list(P_A_cf.values()))\n",
    "# we can now sample from this using the original noise term and gumbel_max trick\n",
    "cf_arm_id, _ = gumbel_max(torch.log(cf_arm_probs), arm_g)\n",
    "cf_arm = env_.arm_keys[cf_arm_id]\n",
    "cf_arm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('c', 'a')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# these are now different \n",
    "arm, cf_arm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: tensor(0., device='cuda:0'),\n",
       " 1: tensor(0., device='cuda:0'),\n",
       " 2: tensor(1., device='cuda:0')}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# now we want the counterfactual probability of controls, given the counterfactual arm\n",
    "cf_control_posterior = {\n",
    "    i: cf_control_probs[i] * env_.arm_given_index[i].probs[cf_arm_id] / P_A_cf[env_.arm_keys[cf_arm_id]]\n",
    "    for i in range(len(logits))\n",
    "}\n",
    "\n",
    "cf_control_posterior"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Control Equivalence Class\n",
    "Here controls are partitioned by the abstraction function. \n",
    "We should observe that, conditional on observing one member of a class, the posterior of the other members of that class increases, even when their prior was low. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abstract_cf.bandit import BanditEnv\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "env = BanditEnv(4, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.1749, 0.1749, 0.1749, 0.4754], device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcs0lEQVR4nO3df2zdVf348Ve72VYc7ZiTlo1+qGO6icAKK2uKkaFWalzUJRrnJGxWxF9gIFWkU7OKxHToxBlZAFEkQckmRjARHGJhGLQy6LYwBi6C4CbYbhNtR9FO2/f3D78UC+vo7dqd3e7xSG7i3j3ve8/J8aZP3n33tiDLsiwAABIpTD0BAODoJkYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACCpyaknMBIDAwPx7LPPxrHHHhsFBQWppwMAjECWZbFv376YMWNGFBYOf/0jL2Lk2WefjcrKytTTAABGYdeuXXHiiScO+/W8iJFjjz02Iv67mNLS0sSzAQBGoqenJyorKwe/jw8nL2LkxR/NlJaWihEAyDOvdouFG1gBgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAElNTj0BAIiIqGq+M/UUjlpPr1qU9PVdGQEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASGpUMbJ27dqoqqqKkpKSqK2tjU2bNo3ovHXr1kVBQUEsXrx4NC8LAExAOcfI+vXro6mpKVpaWmLz5s0xb968aGhoiN27dx/0vKeffjq+8IUvxNvf/vZRTxYAmHhyjpFrrrkmLrroomhsbIxTTjklrr/++jjmmGPipptuGvac/v7+OP/88+PKK6+MWbNmHdKEAYCJJacY2b9/f3R0dER9ff1LT1BYGPX19dHe3j7seV/72tfi+OOPjwsvvHBEr9PX1xc9PT1DHgDAxJRTjOzduzf6+/ujvLx8yPHy8vLo7Ow84DkPPPBA/OAHP4gbb7xxxK/T2toaZWVlg4/KyspcpgkA5JFx/W2affv2xQUXXBA33nhjTJ8+fcTnrVixIrq7uwcfu3btGsdZAgApTc5l8PTp02PSpEnR1dU15HhXV1dUVFS8YvyTTz4ZTz/9dLzvfe8bPDYwMPDfF548OXbs2BEnn3zyK84rLi6O4uLiXKYGAOSpnK6MFBUVxfz586OtrW3w2MDAQLS1tUVdXd0rxs+dOze2bdsWW7duHXy8//3vj3e84x2xdetWP34BAHK7MhIR0dTUFMuXL4+amppYsGBBrFmzJnp7e6OxsTEiIpYtWxYzZ86M1tbWKCkpiVNPPXXI+VOnTo2IeMVxAODolHOMLFmyJPbs2RMrV66Mzs7OqK6ujg0bNgze1Lpz584oLPTBrgDAyBRkWZalnsSr6enpibKysuju7o7S0tLU0wFgHFQ135l6Cketp1ctGpfnHen3b5cwAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQA
SEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAIKlRxcjatWujqqoqSkpKora2NjZt2jTs2J/97GdRU1MTU6dOjde97nVRXV0dt9xyy6gnDABMLDnHyPr166OpqSlaWlpi8+bNMW/evGhoaIjdu3cfcPy0adPiy1/+crS3t8cjjzwSjY2N0djYGHffffchTx4AyH8FWZZluZxQW1sbZ511Vlx77bURETEwMBCVlZXxuc99Lpqbm0f0HGeeeWYsWrQorrrqqhGN7+npibKysuju7o7S0tJcpgtAnqhqvjP1FI5aT69aNC7PO9Lv3zldGdm/f390dHREfX39S09QWBj19fXR3t7+qudnWRZtbW2xY8eOOOecc4Yd19fXFz09PUMeAMDElFOM7N27N/r7+6O8vHzI8fLy8ujs7Bz2vO7u7pgyZUoUFRXFokWL4rvf/W68+93vHnZ8a2trlJWVDT4qKytzmSYAkEcOy2/THHvssbF169Z46KGH4utf/3o0NTXFxo0bhx2/YsWK6O7uHnzs2rXrcEwTAEhgci6Dp0+fHpMmTYqurq4hx7u6uqKiomLY8woLC2P27NkREVFdXR2PP/54tLa2xrnnnnvA8cXFxVFcXJzL1ACAPJXTlZGioqKYP39+tLW1DR4bGBiItra2qKurG/HzDAwMRF9fXy4vDQBMUDldGYmIaGpqiuXLl0dNTU0sWLAg1qxZE729vdHY2BgREcuWLYuZM2dGa2trRPz3/o+ampo4+eSTo6+vL+6666645ZZb4rrrrhvblQAAeSnnGFmyZEns2bMnVq5cGZ2dnVFdXR0bNmwYvKl1586dUVj40gWX3t7e+OxnPxt/+ctf4rWvfW3MnTs3fvSjH8WSJUvGbhUAQN7K+XNGUvA5IwATn88ZSSevPmcEAGCsiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFKjipG1a9dGVVVVlJSURG1tbWzatGnYsTfeeGO8/e1vj+OOOy6OO+64qK+vP+h4AODoknOMrF+/PpqamqKlpSU2b94c8+bNi4aGhti9e/cBx2/cuDGWLl0a9913X7S3t0dlZWWcd9558cwzzxzy5AGA/FeQZVmWywm1tbVx1llnxbXXXhsREQMDA1FZWRmf+9znorm5+VXP7+/vj+OOOy6uvfbaWLZs2Yhes6enJ8rKyqK7uztKS0tzmS4AeaKq+c7UUzhqPb1q0bg870i/f+d0ZWT//v3R0dER9fX1Lz1BYWHU19dHe3v7iJ7jhRdeiH//+98xbdq0Ycf09fVFT0/PkAcAMDHlFCN79+6N/v7+KC8vH3K8vLw8Ojs7R/QcV1xxRcyYMWNI0Lxca2trlJWVDT4qKytzmSYAkEcO62/TrFq1KtatWxe33357lJSUDDtuxYoV0d3dPfjYtWvXYZwlAHA4Tc5l8PTp02PSpEnR1dU15HhXV1dUVFQc9NzVq1fHqlWr4te//nWcfvrpBx1bXFwcxcXFuUwNAMhTOV0ZKSoqivnz50dbW9vgsYGBgWhr
a4u6urphz/vGN74RV111VWzYsCFqampGP1sAYMLJ6cpIRERTU1MsX748ampqYsGCBbFmzZro7e2NxsbGiIhYtmxZzJw5M1pbWyMi4uqrr46VK1fGrbfeGlVVVYP3lkyZMiWmTJkyhksBAPJRzjGyZMmS2LNnT6xcuTI6Ozujuro6NmzYMHhT686dO6Ow8KULLtddd13s378/PvShDw15npaWlvjqV796aLMHAPJezp8zkoLPGQGY+HzOSDp59TkjAABjTYwAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAklfNf7Z1o/GGmdMbrDzO9yN6mY28nrvHeW45OrowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgqVHFyNq1a6OqqipKSkqitrY2Nm3aNOzY7du3xwc/+MGoqqqKgoKCWLNmzWjnCgBMQDnHyPr166OpqSlaWlpi8+bNMW/evGhoaIjdu3cfcPwLL7wQs2bNilWrVkVFRcUhTxgAmFhyjpFrrrkmLrroomhsbIxTTjklrr/++jjmmGPipptuOuD4s846K775zW/GRz7ykSguLj7kCQMAE0tOMbJ///7o6OiI+vr6l56gsDDq6+ujvb19zCbV19cXPT09Qx4AwMSUU4zs3bs3+vv7o7y8fMjx8vLy6OzsHLNJtba2RllZ2eCjsrJyzJ4bADiyHJG/TbNixYro7u4efOzatSv1lACAcTI5l8HTp0+PSZMmRVdX15DjXV1dY3pzanFxsftLAOAokdOVkaKiopg/f360tbUNHhsYGIi2traoq6sb88kBABNfTldGIiKamppi+fLlUVNTEwsWLIg1a9ZEb29vNDY2RkTEsmXLYubMmdHa2hoR/73p9bHHHhv8388880xs3bo1pkyZErNnzx7DpQAA+SjnGFmyZEns2bMnVq5cGZ2dnVFdXR0bNmwYvKl1586dUVj40gWXZ599Ns4444zBf69evTpWr14dCxcujI0bNx76CgCAvJZzjEREXHLJJXHJJZcc8GsvD4yqqqrIsmw0LwMAHAWOyN+mAQCOHmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJDUqGJk7dq1UVVVFSUlJVFbWxubNm066Pjbbrst5s6dGyUlJXHaaafFXXfdNarJAgATT84xsn79+mhqaoqWlpbYvHlzzJs3LxoaGmL37t0HHP+73/0uli5dGhdeeGFs2bIlFi9eHIsXL45HH330kCcPAOS/nGPkmmuuiYsuuigaGxvjlFNOieuvvz6OOeaYuOmmmw44/jvf+U685z3vicsvvzze8pa3xFVXXRVnnnlmXHvttYc8eQAg/03OZfD+/fujo6MjVqxY
MXissLAw6uvro729/YDntLe3R1NT05BjDQ0Ncccddwz7On19fdHX1zf47+7u7oiI6OnpyWW6IzLQ98KYPycjMx77+b/sbTr2duIaz721r+mM176++LxZlh10XE4xsnfv3ujv74/y8vIhx8vLy+MPf/jDAc/p7Ow84PjOzs5hX6e1tTWuvPLKVxyvrKzMZboc4crWpJ4B48XeTlz2dmIa733dt29flJWVDfv1nGLkcFmxYsWQqykDAwPx3HPPxetf//ooKCgY9ryenp6orKyMXbt2RWlp6eGYalJH03qtdeI6mtZrrRPX0bTeXNaaZVns27cvZsyYcdBxOcXI9OnTY9KkSdHV1TXkeFdXV1RUVBzwnIqKipzGR0QUFxdHcXHxkGNTp04d8TxLS0sn/P8Z/tfRtF5rnbiOpvVa68R1NK13pGs92BWRF+V0A2tRUVHMnz8/2traBo8NDAxEW1tb1NXVHfCcurq6IeMjIu65555hxwMAR5ecf0zT1NQUy5cvj5qamliwYEGsWbMment7o7GxMSIili1bFjNnzozW1taIiLj00ktj4cKF8a1vfSsWLVoU69ati4cffji+973vje1KAIC8lHOMLFmyJPbs2RMrV66Mzs7OqK6ujg0bNgzepLpz584oLHzpgsvZZ58dt956a3zlK1+JL33pS/GmN70p7rjjjjj11FPHbhX/X3FxcbS0tLziRzwT1dG0XmuduI6m9VrrxHU0rXc81lqQvdrv2wAAjCN/mwYASEqMAABJiREAICkxAgAklfcx8txzz8X5558fpaWlMXXq1Ljwwgvj+eefP+g55557bhQUFAx5fPrTnz5MM87N2rVro6qqKkpKSqK2tjY2bdp00PG33XZbzJ07N0pKSuK0006Lu+666zDN9NDlstabb775FXtYUlJyGGc7er/5zW/ife97X8yYMSMKCgoO+neaXrRx48Y488wzo7i4OGbPnh0333zzuM9zLOS61o0bN75iXwsKCg765yOOFK2trXHWWWfFscceG8cff3wsXrw4duzY8arn5eN7djRrzef37HXXXRenn3764Id81dXVxS9/+cuDnpOP+xqR+1rHal/zPkbOP//82L59e9xzzz3xi1/8In7zm9/EJz/5yVc976KLLoq//vWvg49vfOMbh2G2uVm/fn00NTVFS0tLbN68OebNmxcNDQ2xe/fuA47/3e9+F0uXLo0LL7wwtmzZEosXL47FixfHo48+ephnnrtc1xrx30//+989/POf/3wYZzx6vb29MW/evFi7du2Ixj/11FOxaNGieMc73hFbt26Nyy67LD7xiU/E3XffPc4zPXS5rvVFO3bsGLK3xx9//DjNcOzcf//9cfHFF8fvf//7uOeee+Lf//53nHfeedHb2zvsOfn6nh3NWiPy9z174oknxqpVq6KjoyMefvjheOc73xkf+MAHYvv27Qccn6/7GpH7WiPGaF+zPPbYY49lEZE99NBDg8d++ctfZgUFBdkzzzwz7HkLFy7MLr300sMww0OzYMGC7OKLLx78d39/fzZjxoystbX1gOM//OEPZ4sWLRpyrLa2NvvUpz41rvMcC7mu9Yc//GFWVlZ2mGY3fiIiu/322w865otf/GL21re+dcixJUuWZA0NDeM4s7E3krXed999WURkf//73w/LnMbT7t27s4jI7r///mHH5PN79n+NZK0T5T37ouOOOy77/ve/f8CvTZR9fdHB1jpW+5rXV0ba29tj6tSpUVNTM3isvr4+CgsL48EHHzzouT/+8Y9j+vTpceqpp8aKFSvihReOrD9dvX///ujo6Ij6+vrBY4WFhVFfXx/t7e0HPKe9vX3I+IiIhoaGYccfKUaz1oiI559/Pk466aSorKx81XLPZ/m6r4eiuro6TjjhhHj3u98dv/3tb1NPZ1S6u7sjImLatGnDjpkoezuStUZMjPdsf39/rFu3Lnp7e4f9syYTZV9HstaIsdnXI/Kv9o5UZ2fnKy7f
Tp48OaZNm3bQnzF/9KMfjZNOOilmzJgRjzzySFxxxRWxY8eO+NnPfjbeUx6xvXv3Rn9//+An276ovLw8/vCHPxzwnM7OzgOOP9J/3j6atc6ZMyduuummOP3006O7uztWr14dZ599dmzfvj1OPPHEwzHtw2a4fe3p6Yl//vOf8drXvjbRzMbeCSecENdff33U1NREX19ffP/7349zzz03HnzwwTjzzDNTT2/EBgYG4rLLLou3ve1tB/206Xx9z/6vka4139+z27Zti7q6uvjXv/4VU6ZMidtvvz1OOeWUA47N933NZa1jta9HZIw0NzfH1VdffdAxjz/++Kif/3/vKTnttNPihBNOiHe9613x5JNPxsknnzzq5+XwqaurG1LqZ599drzlLW+JG264Ia666qqEM+NQzJkzJ+bMmTP477PPPjuefPLJ+Pa3vx233HJLwpnl5uKLL45HH300HnjggdRTGXcjXWu+v2fnzJkTW7duje7u7vjpT38ay5cvj/vvv3/Yb9L5LJe1jtW+HpEx8vnPfz4+9rGPHXTMrFmzoqKi4hU3OP7nP/+J5557LioqKkb8erW1tRER8cQTTxwxMTJ9+vSYNGlSdHV1DTne1dU17NoqKipyGn+kGM1aX+41r3lNnHHGGfHEE0+MxxSTGm5fS0tLJ9RVkeEsWLAgr76pX3LJJYM307/afxnm63v2Rbms9eXy7T1bVFQUs2fPjoiI+fPnx0MPPRTf+c534oYbbnjF2Hzf11zW+nKj3dcj8p6RN7zhDTF37tyDPoqKiqKuri7+8Y9/REdHx+C59957bwwMDAwGxkhs3bo1Iv57ifhIUVRUFPPnz4+2trbBYwMDA9HW1jbsz+7q6uqGjI+IuOeeew76s74jwWjW+nL9/f2xbdu2I2oPx0q+7utY2bp1a17sa5Zlcckll8Ttt98e9957b7zxjW981XPydW9Hs9aXy/f37MDAQPT19R3wa/m6r8M52FpfbtT7esi3wCb2nve8JzvjjDOyBx98MHvggQeyN73pTdnSpUsHv/6Xv/wlmzNnTvbggw9mWZZlTzzxRPa1r30te/jhh7Onnnoq+/nPf57NmjUrO+ecc1ItYVjr1q3LiouLs5tvvjl77LHHsk9+8pPZ1KlTs87OzizLsuyCCy7ImpubB8f/9re/zSZPnpytXr06e/zxx7OWlpbsNa95TbZt27ZUSxixXNd65ZVXZnfffXf25JNPZh0dHdlHPvKRrKSkJNu+fXuqJYzYvn37si1btmRbtmzJIiK75pprsi1btmR//vOfsyzLsubm5uyCCy4YHP+nP/0pO+aYY7LLL788e/zxx7O1a9dmkyZNyjZs2JBqCSOW61q//e1vZ3fccUf2xz/+Mdu2bVt26aWXZoWFhdmvf/3rVEsYsc985jNZWVlZtnHjxuyvf/3r4OOFF14YHDNR3rOjWWs+v2ebm5uz+++/P3vqqaeyRx55JGtubs4KCgqyX/3qV1mWTZx9zbLc1zpW+5r3MfK3v/0tW7p0aTZlypSstLQ0a2xszPbt2zf49aeeeiqLiOy+++7LsizLdu7cmZ1zzjnZtGnTsuLi4mz27NnZ5ZdfnnV3dydawcF997vfzf7v//4vKyoqyhYsWJD9/ve/H/zawoULs+XLlw8Z/5Of/CR785vfnBUVFWVvfetbszvvvPMwz3j0clnrZZddNji2vLw8e+9735tt3rw5waxz9+Kvr7788eL6li9fni1cuPAV51RXV2dFRUXZrFmzsh/+8IeHfd6jketar7766uzkk0/OSkpKsmnTpmXnnntudu+996aZfI4OtM6IGLJXE+U9O5q15vN79uMf/3h20kknZUVFRdkb3vCG7F3vetfgN+csmzj7mmW5r3Ws9rUgy7Ist2spAABj54i8ZwQAOHqIEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKT+HyIx5VGeGgiSAAAA
AElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.1749), tensor(0.1749), tensor(0.1749), tensor(0.4754)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWkklEQVR4nO3dfWxddf3A8c/a2Rbc2o0VOjaqDQ/hIdNVVlanIjNUZzIxGDWTGLs0EzQGxFSRVaAVCenkYSmyBcKUxEDIFhNFE0l9aJyJsTpZXURAJCSkk6XdJqQdxbTa298fxPKrbLC7dXxo93olN2Fn33Pu5+aw9J3Tc++dMzExMREAAElKsgcAAE5uYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASDU3e4CjUSgUYt++fTF//vyYM2dO9jgAwFGYmJiIQ4cOxZIlS6Kk5MjXP2ZEjOzbty9qa2uzxwAAjsHevXvjrLPOOuLfz4gYmT9/fkS8+mIqKyuTpwEAjsbw8HDU1tZO/hw/khkRI//91UxlZaUYAYAZ5s1usXADKwCQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnmZg8AABERdRt/nj3CSev5TWtTn9+VEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAg1THFyNatW6Ouri4qKiqisbExdu3adVT7bd++PebMmRNXXnnlsTwtADALFR0jO3bsiNbW1ujo6Ii+vr5Yvnx5rFmzJvbv3/+G+z3//PPxjW98Iy699NJjHhYAmH2KjpHNmzfH1VdfHS0tLXHRRRfF/fffH6eeemo8+OCDR9xnfHw8Pv/5z8ett94aZ5999nENDADMLkXFyNjYWOzevTuamppeO0BJSTQ1NUVvb+8R9/vOd74TZ5xxRmzYsOGonmd0dDSGh4enPACA2amoGDl48GCMj49HTU3NlO01NTUxMDBw2H1+97vfxQ9+8IPYtm3bUT9PZ2dnVFVVTT5qa2uLGRMAmEFO6LtpDh06FF/4whdi27ZtUV1dfdT7tbW1xdDQ0ORj7969J3BKACDT3GIWV1dXR2lpaQwODk7ZPjg4GIsXL37d+ueeey6ef/75uOKKKya3FQqFV5947tx45pln4pxzznndfuXl5VFeXl7MaADADFXUlZGysrJYsWJF9PT0TG4rFArR09MTq1atet36Cy64IJ544onYs2fP5OOTn/xkfOQjH4k9e/b49QsAUNyVkYiI1tbWWL9+fTQ0NMTKlSujq6srRkZGoqWlJSIimpubY+nSpdHZ2RkVFRWxbNmyKfsvWLAgIuJ12wGAk1PRMbJu3bo4cOBAtLe3x8DAQNTX10d3d/fkTa39/f1RUuKDXQGAozNnYmJiInuINzM8PBxVVVUxNDQUlZWV2eMAcALUbfx59ggnrec3rT0hxz3an98uYQAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQI
AJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqY4pRrZu3Rp1dXVRUVERjY2NsWvXriOu/fGPfxwNDQ2xYMGCeOc73xn19fXx0EMPHfPAAMDsUnSM7NixI1pbW6OjoyP6+vpi+fLlsWbNmti/f/9h15922mlx0003RW9vb/zlL3+JlpaWaGlpiV/84hfHPTwAMPPNmZiYmChmh8bGxrjkkktiy5YtERFRKBSitrY2rrvuuti4ceNRHePiiy+OtWvXxm233XZU64eHh6OqqiqGhoaisrKymHEBmCHqNv48e4ST1vOb1p6Q4x7tz++iroyMjY3F7t27o6mp6bUDlJREU1NT9Pb2vun+ExMT0dPTE88880x8+MMfPuK60dHRGB4envIAAGanomLk4MGDMT4+HjU1NVO219TUxMDAwBH3Gxoainnz5kVZWVmsXbs27r333vjoRz96xPWdnZ1RVVU1+aitrS1mTABgBnlL3k0zf/782LNnT/zpT3+K22+/PVpbW2Pnzp1HXN/W1hZDQ0OTj717974VYwIACeYWs7i6ujpKS0tjcHBwyvbBwcFYvHjxEfcrKSmJc889NyIi6uvr4+mnn47Ozs5YvXr1YdeXl5dHeXl5MaMBADNUUVdGysrKYsWKFdHT0zO5rVAoRE9PT6xateqoj1MoFGJ0dLSYpwYAZqmiroxERLS2tsb69eujoaEhVq5cGV1dXTEyMhItLS0REdHc3BxLly6Nzs7OiHj1/o+GhoY455xzYnR0NB577LF46KGH4r777pveVwIAzEhFx8i6deviwIED0d7eHgMDA1FfXx/d3d2TN7X29/dHSclrF1xGRkbiK1/5SvzjH/+IU045JS644IJ4+OGHY926ddP3KgCAGavozxnJ4HNGAGY/nzOSZ0Z9zggAwHQTIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQ6phjZunVr1NXVRUVFRTQ2NsauXbuOuHbbtm1x6aWXxsKFC2PhwoXR1NT0husBgJNL0TGyY8eOaG1tjY6Ojujr64vly5fHmjVrYv/+/Yddv3PnzrjqqqviN7/5TfT29kZtbW187GMfixdeeOG4hwcAZr45ExMTE8Xs0NjYGJdcckls2bIlIiIKhULU1tbGddddFxs3bnzT/cfHx2PhwoWxZcuWaG5uPqrnHB4ejqqqqhgaGorKyspixgVghqjb+PPsEU5az29ae0KOe7Q/v4u6MjI2Nha7d++Opqam1w5QUhJNTU3R29t7VMd45ZVX4t///necdtppR1wzOjoaw8PDUx4AwOxUVIwcPHgwxsfHo6amZsr2mpqaGBgYOKpj3HjjjbFkyZIpQfO/Ojs7o6qqavJRW1tbzJgAwAzylr6bZtOmTbF9+/b4yU9+EhUVFUdc19bWFkNDQ5OPvXv3voVTAgBvpbnFLK6uro7S0tIYHBycsn1wcDAWL178hvveddddsWnTpvj1r38d733ve99wbXl5eZSXlxczGgAwQxV1ZaSsrCxWrFgRPT09k9sKhUL09PTEqlWrjrjfHXfcEbfddlt0d3dHQ0PD
sU8LAMw6RV0ZiYhobW2N9evXR0NDQ6xcuTK6urpiZGQkWlpaIiKiubk5li5dGp2dnRER8d3vfjfa29vjkUceibq6usl7S+bNmxfz5s2bxpcCAMxERcfIunXr4sCBA9He3h4DAwNRX18f3d3dkze19vf3R0nJaxdc7rvvvhgbG4vPfOYzU47T0dER3/72t49vegBgxiv6c0Yy+JwRgNnP54zkmVGfMwIAMN3ECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQquhv7Z1tfDFTnhP1xUz/5dzmcW5nrxN9bjk5uTICAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqmOKka1bt0ZdXV1UVFREY2Nj7Nq164hrn3zyyfj0pz8ddXV1MWfOnOjq6jrWWQGAWajoGNmxY0e0trZGR0dH9PX1xfLly2PNmjWxf//+w65/5ZVX4uyzz45NmzbF4sWLj3tgAGB2KTpGNm/eHFdffXW0tLTERRddFPfff3+ceuqp8eCDDx52/SWXXBJ33nlnfO5zn4vy8vLjHhgAmF2KipGxsbHYvXt3NDU1vXaAkpJoamqK3t7eaRtqdHQ0hoeHpzwAgNmpqBg5ePBgjI+PR01NzZTtNTU1MTAwMG1DdXZ2RlVV1eSjtrZ22o4NALy9vC3fTdPW1hZDQ0OTj71792aPBACcIHOLWVxdXR2lpaUxODg4Zfvg4OC03pxaXl7u/hIAOEkUdWWkrKwsVqxYET09PZPbCoVC9PT0xKpVq6Z9OABg9ivqykhERGtra6xfvz4aGhpi5cqV0dXVFSMjI9HS0hIREc3NzbF06dLo7OyMiFdven3qqacm//uFF16IPXv2xLx58+Lcc8+dxpcCAMxERcfIunXr4sCBA9He3h4DAwNRX18f3d3dkze19vf3R0nJaxdc9u3bF+973/sm/3zXXXfFXXfdFZdddlns3Lnz+F8BADCjFR0jERHXXnttXHvttYf9u/8NjLq6upiYmDiWpwEATgJvy3fTAAAnDzECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqmOKka1bt0ZdXV1UVFREY2Nj7Nq16w3X/+hHP4oLLrggKioq4j3veU889thjxzQsADD7FB0jO3bsiNbW1ujo6Ii+vr5Yvnx5rFmzJvbv33/Y9b///e/jqquuig0bNsSf//znuPLKK+PKK6+Mv/71r8c9PAAw8xUdI5s3b46rr746Wlpa4qKLLor7778/Tj311HjwwQcPu/6ee+6Jj3/843HDDTfEhRdeGLfddltcfPHFsWXLluMeHgCY+eYWs3hsbCx2794dbW1tk9tKSkqiqakpent7D7tPb29vtLa2Ttm2Zs2aePTRR4/4PKOjozE6
Ojr556GhoYiIGB4eLmbco1IYfWXaj8nRORHn8/9zbvM4t7PXiTy3zmueE3Ve/3vciYmJN1xXVIwcPHgwxsfHo6amZsr2mpqa+Nvf/nbYfQYGBg67fmBg4IjP09nZGbfeeuvrttfW1hYzLm9zVV3ZE3CiOLezl3M7O53o83ro0KGoqqo64t8XFSNvlba2tilXUwqFQrz44ouxaNGimDNnTuJkby/Dw8NRW1sbe/fujcrKyuxxmCbO6+zl3M5ezu3hTUxMxKFDh2LJkiVvuK6oGKmuro7S0tIYHBycsn1wcDAWL1582H0WL15c1PqIiPLy8igvL5+ybcGCBcWMelKprKz0P/8s5LzOXs7t7OXcvt4bXRH5r6JuYC0rK4sVK1ZET0/P5LZCoRA9PT2xatWqw+6zatWqKesjIn71q18dcT0AcHIp+tc0ra2tsX79+mhoaIiVK1dGV1dXjIyMREtLS0RENDc3x9KlS6OzszMiIq6//vq47LLL4u677461a9fG9u3b4/HHH48HHnhgel8JADAjFR0j69atiwMHDkR7e3sMDAxEfX19dHd3T96k2t/fHyUlr11w+cAHPhCPPPJI3HzzzfGtb30rzjvvvHj00Udj2bJl0/cqTlLl5eXR0dHxul9pMbM5r7OXczt7ObfHZ87Em73fBgDgBPLdNABAKjECAKQSIwBAKjECbxOrV6+Or33ta9ljAMfJv+XiiREAIJUYAQBSiZEZqLu7Oz70oQ/FggULYtGiRfGJT3winnvuueyxmAb/+c9/4tprr42qqqqorq6OW2655U2/7ZKZoVAoxB133BHnnntulJeXx7ve9a64/fbbs8fiOI2MjERzc3PMmzcvzjzzzLj77ruzR5qRxMgMNDIyEq2trfH4449HT09PlJSUxKc+9akoFArZo3GcfvjDH8bcuXNj165dcc8998TmzZvj+9//fvZYTIO2trbYtGlT3HLLLfHUU0/FI4888rpvNGfmueGGG+K3v/1t/PSnP41f/vKXsXPnzujr68sea8bxoWezwMGDB+P000+PJ554wifbzmCrV6+O/fv3x5NPPjn57dQbN26Mn/3sZ/HUU08lT8fxOHToUJx++umxZcuW+OIXv5g9DtPk5ZdfjkWLFsXDDz8cn/3sZyMi4sUXX4yzzjorrrnmmujq6sodcAZxZWQGevbZZ+Oqq66Ks88+OyorK6Ouri4iXv0ofma297///ZMhEvHqF00+++yzMT4+njgVx+vpp5+O0dHRuPzyy7NHYRo999xzMTY2Fo2NjZPbTjvttDj//PMTp5qZiv5uGvJdccUV8e53vzu2bdsWS5YsiUKhEMuWLYuxsbHs0YDDOOWUU7JHgLc1V0ZmmH/+85/xzDPPxM033xyXX355XHjhhfHSSy9lj8U0+eMf/zjlz3/4wx/ivPPOi9LS0qSJmA7nnXdenHLKKdHT05M9CtPonHPOiXe84x1T/t2+9NJL8fe//z1xqpnJlZEZZuHChbFo0aJ44IEH4swzz4z+/v7YuHFj9lhMk/7+/mhtbY0vfelL0dfXF/fee6+782eBioqKuPHGG+Ob3/xmlJWVxQc/+ME4cOBAPPnkk7Fhw4bs8ThG8+bNiw0bNsQNN9wQixYtijPOOCNuuummKd9cz9ERIzNMSUlJbN++Pb761a/GsmXL4vzzz4/vfe97sXr16uzRmAbNzc3xr3/9K1auXBmlpaVx/fXXxzXXXJM9FtPglltuiblz50Z7e3vs27cvzjzzzPjyl7+cPRbH6c4774yXX345rrjiipg/f358/etfj6GhoeyxZhzvpgEAUrmWBACkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQKr/A59/PjUAoeRPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 tensor([-0.0030], device='cuda:0')\n",
      "arm d\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# observing an interaction with the environment \n",
    "logits = torch.tensor([1, 1, 1, 2], dtype=torch.float32).to(device)\n",
    "control_probs = torch.softmax(logits, dim=0)\n",
    "\n",
    "print(control_probs)\n",
    "plt.bar(range(len(control_probs)), control_probs.cpu().numpy())\n",
    "plt.show()\n",
    "# control = torch.multinomial(control_probs, 1).item()\n",
    "control = 3\n",
    "# estimating P(A)    \n",
    "P_A = estimate_arm_probs(control_probs, env)\n",
    "\n",
    "plot_PA(P_A)\n",
    "outcome = env.pull(control)\n",
    "print(control, outcome)\n",
    "arm = env.sample_arm(control) \n",
    "print('arm', arm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.1749), tensor(0.1749), tensor(0.1749), tensor(0.4754)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWkklEQVR4nO3dfWxddf3A8c/a2Rbc2o0VOjaqDQ/hIdNVVlanIjNUZzIxGDWTGLs0EzQGxFSRVaAVCenkYSmyBcKUxEDIFhNFE0l9aJyJsTpZXURAJCSkk6XdJqQdxbTa298fxPKrbLC7dXxo93olN2Fn33Pu5+aw9J3Tc++dMzExMREAAElKsgcAAE5uYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASDU3e4CjUSgUYt++fTF//vyYM2dO9jgAwFGYmJiIQ4cOxZIlS6Kk5MjXP2ZEjOzbty9qa2uzxwAAjsHevXvjrLPOOuLfz4gYmT9/fkS8+mIqKyuTpwEAjsbw8HDU1tZO/hw/khkRI//91UxlZaUYAYAZ5s1usXADKwCQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnmZg8AABERdRt/nj3CSev5TWtTn9+VEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAg1THFyNatW6Ouri4qKiqisbExdu3adVT7bd++PebMmRNXXnnlsTwtADALFR0jO3bsiNbW1ujo6Ii+vr5Yvnx5rFmzJvbv3/+G+z3//PPxjW98Iy699NJjHhYAmH2KjpHNmzfH1VdfHS0tLXHRRRfF/fffH6eeemo8+OCDR9xnfHw8Pv/5z8ett94aZ5999nENDADMLkXFyNjYWOzevTuamppeO0BJSTQ1NUVvb+8R9/vOd74TZ5xxRmzYsOGonmd0dDSGh4enPACA2amoGDl48GCMj49HTU3NlO01NTUxMDBw2H1+97vfxQ9+8IPYtm3bUT9PZ2dnVFVVTT5qa2uLGRMAmEFO6LtpDh06FF/4whdi27ZtUV1dfdT7tbW1xdDQ0ORj7969J3BKACDT3GIWV1dXR2lpaQwODk7ZPjg4GIsXL37d+ueeey6ef/75uOKKKya3FQqFV5947tx45pln4pxzznndfuXl5VFeXl7MaADADFXUlZGysrJYsWJF9PT0TG4rFArR09MTq1atet36Cy64IJ544onYs2fP5OOTn/xkfOQjH4k9e/b49QsAUNyVkYiI1tbWWL9+fTQ0NMTKlSujq6srRkZGoqWlJSIimpubY+nSpdHZ2RkVFRWxbNmyKfsvWLAgIuJ12wGAk1PRMbJu3bo4cOBAtLe3x8DAQNTX10d3d/fkTa39/f1RUuKDXQGAozNnYmJiInuINzM8PBxVVVUxNDQUlZWV2eMAcALUbfx59ggnrec3rT0hxz3an98uYQAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQI
AJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqcQIAJBKjAAAqY4pRrZu3Rp1dXVRUVERjY2NsWvXriOu/fGPfxwNDQ2xYMGCeOc73xn19fXx0EMPHfPAAMDsUnSM7NixI1pbW6OjoyP6+vpi+fLlsWbNmti/f/9h15922mlx0003RW9vb/zlL3+JlpaWaGlpiV/84hfHPTwAMPPNmZiYmChmh8bGxrjkkktiy5YtERFRKBSitrY2rrvuuti4ceNRHePiiy+OtWvXxm233XZU64eHh6OqqiqGhoaisrKymHEBmCHqNv48e4ST1vOb1p6Q4x7tz++iroyMjY3F7t27o6mp6bUDlJREU1NT9Pb2vun+ExMT0dPTE88880x8+MMfPuK60dHRGB4envIAAGanomLk4MGDMT4+HjU1NVO219TUxMDAwBH3Gxoainnz5kVZWVmsXbs27r333vjoRz96xPWdnZ1RVVU1+aitrS1mTABgBnlL3k0zf/782LNnT/zpT3+K22+/PVpbW2Pnzp1HXN/W1hZDQ0OTj717974VYwIACeYWs7i6ujpKS0tjcHBwyvbBwcFYvHjxEfcrKSmJc889NyIi6uvr4+mnn47Ozs5YvXr1YdeXl5dHeXl5MaMBADNUUVdGysrKYsWKFdHT0zO5rVAoRE9PT6xateqoj1MoFGJ0dLSYpwYAZqmiroxERLS2tsb69eujoaEhVq5cGV1dXTEyMhItLS0REdHc3BxLly6Nzs7OiHj1/o+GhoY455xzYnR0NB577LF46KGH4r777pveVwIAzEhFx8i6deviwIED0d7eHgMDA1FfXx/d3d2TN7X29/dHSclrF1xGRkbiK1/5SvzjH/+IU045JS644IJ4+OGHY926ddP3KgCAGavozxnJ4HNGAGY/nzOSZ0Z9zggAwHQTIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQ6phjZunVr1NXVRUVFRTQ2NsauXbuOuHbbtm1x6aWXxsKFC2PhwoXR1NT0husBgJNL0TGyY8eOaG1tjY6Ojujr64vly5fHmjVrYv/+/Yddv3PnzrjqqqviN7/5TfT29kZtbW187GMfixdeeOG4hwcAZr45ExMTE8Xs0NjYGJdcckls2bIlIiIKhULU1tbGddddFxs3bnzT/cfHx2PhwoWxZcuWaG5uPqrnHB4ejqqqqhgaGorKyspixgVghqjb+PPsEU5az29ae0KOe7Q/v4u6MjI2Nha7d++Opqam1w5QUhJNTU3R29t7VMd45ZVX4t///necdtppR1wzOjoaw8PDUx4AwOxUVIwcPHgwxsfHo6amZsr2mpqaGBgYOKpj3HjjjbFkyZIpQfO/Ojs7o6qqavJRW1tbzJgAwAzylr6bZtOmTbF9+/b4yU9+EhUVFUdc19bWFkNDQ5OPvXv3voVTAgBvpbnFLK6uro7S0tIYHBycsn1wcDAWL178hvveddddsWnTpvj1r38d733ve99wbXl5eZSXlxczGgAwQxV1ZaSsrCxWrFgRPT09k9sKhUL09PTEqlWrjrjfHXfcEbfddlt0d3dHQ0PD
sU8LAMw6RV0ZiYhobW2N9evXR0NDQ6xcuTK6urpiZGQkWlpaIiKiubk5li5dGp2dnRER8d3vfjfa29vjkUceibq6usl7S+bNmxfz5s2bxpcCAMxERcfIunXr4sCBA9He3h4DAwNRX18f3d3dkze19vf3R0nJaxdc7rvvvhgbG4vPfOYzU47T0dER3/72t49vegBgxiv6c0Yy+JwRgNnP54zkmVGfMwIAMN3ECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQSowAAKnECACQquhv7Z1tfDFTnhP1xUz/5dzmcW5nrxN9bjk5uTICAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqmOKka1bt0ZdXV1UVFREY2Nj7Nq164hrn3zyyfj0pz8ddXV1MWfOnOjq6jrWWQGAWajoGNmxY0e0trZGR0dH9PX1xfLly2PNmjWxf//+w65/5ZVX4uyzz45NmzbF4sWLj3tgAGB2KTpGNm/eHFdffXW0tLTERRddFPfff3+ceuqp8eCDDx52/SWXXBJ33nlnfO5zn4vy8vLjHhgAmF2KipGxsbHYvXt3NDU1vXaAkpJoamqK3t7eaRtqdHQ0hoeHpzwAgNmpqBg5ePBgjI+PR01NzZTtNTU1MTAwMG1DdXZ2RlVV1eSjtrZ22o4NALy9vC3fTdPW1hZDQ0OTj71792aPBACcIHOLWVxdXR2lpaUxODg4Zfvg4OC03pxaXl7u/hIAOEkUdWWkrKwsVqxYET09PZPbCoVC9PT0xKpVq6Z9OABg9ivqykhERGtra6xfvz4aGhpi5cqV0dXVFSMjI9HS0hIREc3NzbF06dLo7OyMiFdven3qqacm//uFF16IPXv2xLx58+Lcc8+dxpcCAMxERcfIunXr4sCBA9He3h4DAwNRX18f3d3dkze19vf3R0nJaxdc9u3bF+973/sm/3zXXXfFXXfdFZdddlns3Lnz+F8BADCjFR0jERHXXnttXHvttYf9u/8NjLq6upiYmDiWpwEATgJvy3fTAAAnDzECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqmOKka1bt0ZdXV1UVFREY2Nj7Nq16w3X/+hHP4oLLrggKioq4j3veU889thjxzQsADD7FB0jO3bsiNbW1ujo6Ii+vr5Yvnx5rFmzJvbv33/Y9b///e/jqquuig0bNsSf//znuPLKK+PKK6+Mv/71r8c9PAAw8xUdI5s3b46rr746Wlpa4qKLLor7778/Tj311HjwwQcPu/6ee+6Jj3/843HDDTfEhRdeGLfddltcfPHFsWXLluMeHgCY+eYWs3hsbCx2794dbW1tk9tKSkqiqakpent7D7tPb29vtLa2Ttm2Zs2aePTRR4/4PKOjozE6
Ojr556GhoYiIGB4eLmbco1IYfWXaj8nRORHn8/9zbvM4t7PXiTy3zmueE3Ve/3vciYmJN1xXVIwcPHgwxsfHo6amZsr2mpqa+Nvf/nbYfQYGBg67fmBg4IjP09nZGbfeeuvrttfW1hYzLm9zVV3ZE3CiOLezl3M7O53o83ro0KGoqqo64t8XFSNvlba2tilXUwqFQrz44ouxaNGimDNnTuJkby/Dw8NRW1sbe/fujcrKyuxxmCbO6+zl3M5ezu3hTUxMxKFDh2LJkiVvuK6oGKmuro7S0tIYHBycsn1wcDAWL1582H0WL15c1PqIiPLy8igvL5+ybcGCBcWMelKprKz0P/8s5LzOXs7t7OXcvt4bXRH5r6JuYC0rK4sVK1ZET0/P5LZCoRA9PT2xatWqw+6zatWqKesjIn71q18dcT0AcHIp+tc0ra2tsX79+mhoaIiVK1dGV1dXjIyMREtLS0RENDc3x9KlS6OzszMiIq6//vq47LLL4u677461a9fG9u3b4/HHH48HHnhgel8JADAjFR0j69atiwMHDkR7e3sMDAxEfX19dHd3T96k2t/fHyUlr11w+cAHPhCPPPJI3HzzzfGtb30rzjvvvHj00Udj2bJl0/cqTlLl5eXR0dHxul9pMbM5r7OXczt7ObfHZ87Em73fBgDgBPLdNABAKjECAKQSIwBAKjECbxOrV6+Or33ta9ljAMfJv+XiiREAIJUYAQBSiZEZqLu7Oz70oQ/FggULYtGiRfGJT3winnvuueyxmAb/+c9/4tprr42qqqqorq6OW2655U2/7ZKZoVAoxB133BHnnntulJeXx7ve9a64/fbbs8fiOI2MjERzc3PMmzcvzjzzzLj77ruzR5qRxMgMNDIyEq2trfH4449HT09PlJSUxKc+9akoFArZo3GcfvjDH8bcuXNj165dcc8998TmzZvj+9//fvZYTIO2trbYtGlT3HLLLfHUU0/FI4888rpvNGfmueGGG+K3v/1t/PSnP41f/vKXsXPnzujr68sea8bxoWezwMGDB+P000+PJ554wifbzmCrV6+O/fv3x5NPPjn57dQbN26Mn/3sZ/HUU08lT8fxOHToUJx++umxZcuW+OIXv5g9DtPk5ZdfjkWLFsXDDz8cn/3sZyMi4sUXX4yzzjorrrnmmujq6sodcAZxZWQGevbZZ+Oqq66Ks88+OyorK6Ouri4iXv0ofma297///ZMhEvHqF00+++yzMT4+njgVx+vpp5+O0dHRuPzyy7NHYRo999xzMTY2Fo2NjZPbTjvttDj//PMTp5qZiv5uGvJdccUV8e53vzu2bdsWS5YsiUKhEMuWLYuxsbHs0YDDOOWUU7JHgLc1V0ZmmH/+85/xzDPPxM033xyXX355XHjhhfHSSy9lj8U0+eMf/zjlz3/4wx/ivPPOi9LS0qSJmA7nnXdenHLKKdHT05M9CtPonHPOiXe84x1T/t2+9NJL8fe//z1xqpnJlZEZZuHChbFo0aJ44IEH4swzz4z+/v7YuHFj9lhMk/7+/mhtbY0vfelL0dfXF/fee6+782eBioqKuPHGG+Ob3/xmlJWVxQc/+ME4cOBAPPnkk7Fhw4bs8ThG8+bNiw0bNsQNN9wQixYtijPOOCNuuummKd9cz9ERIzNMSUlJbN++Pb761a/GsmXL4vzzz4/vfe97sXr16uzRmAbNzc3xr3/9K1auXBmlpaVx/fXXxzXXXJM9FtPglltuiblz50Z7e3vs27cvzjzzzPjyl7+cPRbH6c4774yXX345rrjiipg/f358/etfj6GhoeyxZhzvpgEAUrmWBACkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQCoxAgCkEiMAQKr/A59/PjUAoeRPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# estimating P(A)    \n",
    "P_A = estimate_arm_probs(control_probs, env)\n",
    "\n",
    "plot_PA(P_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# interventions here (if only reordering the arms) will only shift the probabilities of the arms\n",
    "from copy import deepcopy\n",
    "\n",
    "env_ = deepcopy(env)\n",
    "env_.intervene(\n",
    "    {\n",
    "        0: torch.distributions.Categorical(torch.tensor([1, 0, 0, 0], dtype=torch.float32)),\n",
    "        1: torch.distributions.Categorical(torch.tensor([1, 0, 0, 0], dtype=torch.float32)),\n",
    "        2: torch.distributions.Categorical(torch.tensor([0, 1, 0, 0], dtype=torch.float32)),\n",
    "        3: torch.distributions.Categorical(torch.tensor([0, 1, 0, 0], dtype=torch.float32)),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "tensor(3) tensor([0.2666, 1.1745, 1.8901, 1.1171])\n",
      "[tensor(0.3498), tensor(0.6502), tensor(0.), tensor(0.)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbVElEQVR4nO3df2xd913H4Xfs1HZDaqdNGifNzKz+oG0UGo+EZN4YDZpHkEqhE6BQIRxZXQZbIwKG0piuCdvYHNY2ytiihWZETCtVI6axTWqVARZBmmoIS6gobWm3ojTZip2EdnbqIhts88c0F9O4zU2cfWvneaQj1cffc+7n6rb1S8fn+s4ZHx8fDwBAIVWlBwAALm5iBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAippbeoCzMTY2lhdffDGXXXZZ5syZU3ocAOAsjI+P5/Tp07nqqqtSVTX19Y8ZESMvvvhimpqaSo8BAJyD48eP521ve9uU358RMXLZZZcl+f6Tqa+vLzwNAHA2BgcH09TUNPFzfCozIkZ+8KuZ+vp6MQIAM8yb3WLhBlYAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFzSw8AF0rz1kdLj3DROrrjltIjADOIKyMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICizilGdu/enebm5tTV1WXt2rU5dOjQG67/3ve+lzvvvDNLly5NbW1tfuzHfiyPPfbYOQ0MAMwucys9YP/+/ens7MyePXuydu3a7Nq1K+vXr8+zzz6bxYsXv279yMhI3ve+92Xx4sX50pe+lGXLluWFF17IggULpmN+AGCGqzhGdu7cmU2bNqWjoyNJsmfPnjz66KPZt29ftm7d+rr1+/bty0svvZTHH388l1xySZKkubn5/KYGAGaNin5NMzIyksOHD6etre21E1RVpa2tLb29vWc85mtf+1paW1tz5513prGxMStWrMgnP/nJjI6OTvk4w8PDGRwcnLQBALNTRTFy6tSpjI6OprGxcdL+xsbG9PX1nfGYf//3f8+XvvSljI6O5rHHHsu9996bBx54IH/0R3805eN0d3enoaFhYmtqaqpkTABgBrng76YZGxvL4sWL8+CDD2bVqlXZsGFD7rnnnuzZs2fKY7q6ujIwMDCxHT9+/EKPCQAUUtE9I4sWLUp1dXX6+/sn7e/v78+SJUvOeMzSpUtzySWXpLq6emLfjTfemL6+voyMjKSmpuZ1x9TW1qa2traS0QCAGaqiKyM1NTVZtWpVenp6JvaNjY2lp6cnra2tZzzm3e9+d7797W9nbGxsYt9zzz2XpUuXnjFEAICLS8W/puns7MzevXvzhS98Ic8880w+9KEPZWhoaOLdNe3t7enq6ppY/6EPfSgvvfRStmzZkueeey6PPvpoPvnJT+bOO++cvmcBAMxYFb+1d8OGDTl58mS2bduWvr6+tLS05MCBAxM3tR47dixVVa81TlNTU77+9a/nd37nd3LTTTdl2bJl2bJlS+6+++7pexYAwIw1Z3x8fLz0EG9mcHAwDQ0NGRgYSH19felxmCGatz5aeoSL1tEdt5QeAXgLONuf3z6bBgAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgB
AIoSIwBAUWIEAChKjAAARYkRAKCoc4qR3bt3p7m5OXV1dVm7dm0OHTo05do///M/z5w5cyZtdXV15zwwADC7VBwj+/fvT2dnZ7Zv354jR45k5cqVWb9+fU6cODHlMfX19fmP//iPie2FF144r6EBgNmj4hjZuXNnNm3alI6Ojixfvjx79uzJvHnzsm/fvimPmTNnTpYsWTKxNTY2ntfQAMDsUVGMjIyM5PDhw2lra3vtBFVVaWtrS29v75THvfLKK3n729+epqam/OIv/mKeeuqpc58YAJhVKoqRU6dOZXR09HVXNhobG9PX13fGY66//vrs27cvX/3qV/PQQw9lbGws73rXu/Kd73xnyscZHh7O4ODgpA0AmJ0u+LtpWltb097enpaWltx888358pe/nCuvvDJ/+qd/OuUx3d3daWhomNiampou9JgAQCEVxciiRYtSXV2d/v7+Sfv7+/uzZMmSszrHJZdckne84x359re/PeWarq6uDAwMTGzHjx+vZEwAYAapKEZqamqyatWq9PT0TOwbGxtLT09PWltbz+oco6OjefLJJ7N06dIp19TW1qa+vn7SBgDMTnMrPaCzszMbN27M6tWrs2bNmuzatStDQ0Pp6OhIkrS3t2fZsmXp7u5OknzsYx/LO9/5zlx77bX53ve+l/vuuy8vvPBCPvCBD0zvMwEAZqSKY2TDhg05efJktm3blr6+vrS0tOTAgQMTN7UeO3YsVVWvXXB5+eWXs2nTpvT19eXyyy/PqlWr8vjjj2f58uXT9ywAgBlrzvj4+HjpId7M4OBgGhoaMjAw4Fc2nLXmrY+WHuGidXTHLaVHAN4Czvbnt8+mAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARZ1TjOzevTvNzc2pq6vL2rVrc+jQobM67pFHHsmcOXNy2223ncvDAgCzUMUxsn///nR2dmb79u05cuRIVq5cmfXr1+fEiRNveNzRo0fze7/3e3nPe95zzsMCALNPxTGyc+fObNq0KR0dHVm+fHn27NmTefPmZd++fVMeMzo6ml/7tV/LRz/60Vx99dXnNTAAMLtUFCMjIyM5fPhw2traXjtBVVXa2trS29s75XEf+9jHsnjx4txxxx1n9TjDw8MZHByctAEAs1NFMXLq1KmMjo6msbFx0v7Gxsb09fWd8ZhvfOMb+bM/+7Ps3bv3rB+nu7s7DQ0NE1tTU1MlYwIAM8gFfTfN6dOn8+u//uvZu3dvFi1adNbHdXV1ZWBgYGI7fvz4BZwSAChpbiWLFy1alOrq6vT390/a39/fnyVLlrxu/fPPP5+jR4/m1ltvndg3Njb2/QeeOzfPPvtsrrnmmtcdV1tbm9ra2kpGAwBmqIqujNTU1GTVqlXp6emZ2Dc2Npaenp60tra+bv0NN9yQJ598Mk888cTE9gu/8Av5mZ/5mTzxxBN+/QIAVHZlJEk6OzuzcePGrF69OmvWrMmuXbsyNDSUjo6OJEl7e3uWLVuW7u7u1NXVZcWKFZOOX7BgQZK8bj8AcHGqOEY2bNiQkydPZtu2benr60tLS0sOHDgwcVPrsWPHUlXlD7sCAGdnzvj4+HjpId7M4OBgGhoaMjAwkPr6+tLjMEM0b3209AgXraM7bik9AvAWcLY/v13CAACKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChK
jAAARc0tPUBpzVsfLT3CRevojltKjwDAW4ArIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLOKUZ2796d5ubm1NXVZe3atTl06NCUa7/85S9n9erVWbBgQX7kR34kLS0t+eIXv3jOAwMAs0vFMbJ///50dnZm+/btOXLkSFauXJn169fnxIkTZ1x/xRVX5J577klvb2/+5V/+JR0dHeno6MjXv/718x4eAJj5Ko6RnTt3ZtOmTeno6Mjy5cuzZ8+ezJs3L/v27Tvj+nXr1uX9739/brzxxlxzzTXZsmVLbrrppnzjG9847+EBgJmvohgZGRnJ4cOH09bW9toJqqrS1taW3t7eNz1+fHw8PT09efbZZ/PTP/3TU64bHh7O4ODgpA0AmJ0qipFTp05ldHQ0jY2Nk/Y3Njamr69vyuMGBgYyf/781NTU5JZbbslnPvOZvO9975tyfXd3dxoaGia2pqamSsYEAGaQH8q7aS677LI88cQT+ad/+qd84hOfSGdnZw4ePDjl+q6urgwMDExsx48f/2GMCQAUMLeSxYsWLUp1dXX6+/sn7e/v78+SJUumPK6qqirXXnttkqSlpSXPPPNMuru7s27dujOur62tTW1tbSWjAQAzVEVXRmpqarJq1ar09PRM7BsbG0tPT09aW1vP+jxjY2MZHh6u5KEBgFmqoisjSdLZ2ZmNGzdm9erVWbNmTXbt2pWhoaF0dHQkSdrb27Ns2bJ0d3cn+f79H6tXr84111yT4eHhPPbYY/niF7+Yz33uc9P7TACAGaniGNmwYUNOnjyZbdu2pa+vLy0tLTlw4MDETa3Hjh1LVdVrF1yGhoby4Q9/ON/5zndy6aWX5oYbbshDDz2UDRs2TN+zAABmrDnj4+PjpYd4M4ODg2loaMjAwEDq6+un9dzNWx+d1vNx9o7uuOWCnt9rW86Ffm2BmeFsf377bBoAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAos4pRnbv3p3m5ubU1dVl7dq1OXTo0JRr9+7dm/e85z25/PLLc/nll6etre0N1wMAF5eKY2T//v3p7OzM9u3bc+TIkaxcuTLr16/PiRMnzrj+4MGDuf322/N3f/d36e3tTVNTU372Z3823/3ud897eABg5qs4Rnbu3JlNmzalo6Mjy5cvz549ezJv3rzs27fvjOv/4i/+Ih/+8IfT0tKSG264IZ///OczNjaWnp6e8x4eAJj5KoqRkZGRHD58OG1tba+doKoqbW1t6e3tPatzvPrqq/nv//7vXHHFFVOuGR4ezuDg4KQNAJidKoqRU6dOZXR0NI2NjZP2NzY2pq+v76zOcffdd+eqq66aFDT/X3d3dxoaGia2pqamSsYEAGaQH+q7aXbs2JFHHnkkf/VXf5W6urop13V1dWVgYGBiO378+A9xSgDgh2luJYsXLVqU6urq9Pf3T9rf39+fJUuWvOGx999/f3bs2JG//du/zU033fSGa2tra1NbW1vJaADADFXRlZGampqsWrVq0s2nP7gZtbW1dcrjPvWpT+XjH/94Dhw4kNWrV5/7tADArFPRlZEk6ezszMaNG7N69eqsWbMmu3btytDQUDo6OpIk7e3tWbZsWbq7u5Mkf/zHf5xt27bl4YcfTnNz88S9JfPnz8/8+fOn8akAADNRxTGyYcOG
nDx5Mtu2bUtfX19aWlpy4MCBiZtajx07lqqq1y64fO5zn8vIyEh++Zd/edJ5tm/fnj/8wz88v+kBgBmv4hhJks2bN2fz5s1n/N7BgwcnfX306NFzeQgA4CLhs2kAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAijqnGNm9e3eam5tTV1eXtWvX5tChQ1Oufeqpp/JLv/RLaW5uzpw5c7Jr165znRUAmIUqjpH9+/ens7Mz27dvz5EjR7Jy5cqsX78+J06cOOP6V199NVdffXV27NiRJUuWnPfAAMDsUnGM7Ny5M5s2bUpHR0eWL1+ePXv2ZN68edm3b98Z1//kT/5k7rvvvvzqr/5qamtrz3tgAGB2qShGRkZGcvjw4bS1tb12gqqqtLW1pbe3d9qGGh4ezuDg4KQNAJidKoqRU6dOZXR0NI2NjZP2NzY2pq+vb9qG6u7uTkNDw8TW1NQ0becGAN5a3pLvpunq6srAwMDEdvz48dIjAQAXyNxKFi9atCjV1dXp7++ftL+/v39ab06tra11fwkAXCQqujJSU1OTVatWpaenZ2Lf2NhYenp60traOu3DAQCzX0VXRpKks7MzGzduzOrVq7NmzZrs2rUrQ0ND6ejoSJK0t7dn2bJl6e7uTvL9m16ffvrpiX/+7ne/myeeeCLz58/PtddeO41PBQCYiSqOkQ0bNuTkyZPZtm1b+vr60tLSkgMHDkzc1Hrs2LFUVb12weXFF1/MO97xjomv77///tx///25+eabc/DgwfN/BgDAjFZxjCTJ5s2bs3nz5jN+7/8HRnNzc8bHx8/lYQCAi8Bb8t00AMDFQ4wAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKCoc4qR3bt3p7m5OXV1dVm7dm0OHTr0huv/8i//MjfccEPq6ury4z/+43nsscfOaVgAYPapOEb279+fzs7ObN++PUeOHMnKlSuzfv36nDhx4ozrH3/88dx+++2544478s///M+57bbbctttt+Vf//Vfz3t4AGDmqzhGdu7cmU2bNqWjoyPLly/Pnj17Mm/evOzbt++M6z/96U/n537u53LXXXflxhtvzMc//vH8xE/8RD772c+e9/AAwMw3t5LFIyMjOXz4cLq6uib2VVVVpa2tLb29vWc8pre3N52dnZP2rV+/Pl/5ylemfJzh4eEMDw9PfD0wMJAkGRwcrGTcszI2/Oq0n5OzcyFez//La1vOhX5tgZnhB/8vGB8ff8N1FcXIqVOnMjo6msbGxkn7Gxsb82//9m9nPKavr++M6/v6+qZ8nO7u7nz0ox993f6mpqZKxuUtrmFX6Qm4ULy2wP91+vTpNDQ0TPn9imLkh6Wrq2vS1ZSxsbG89NJLWbhwYebMmVNwsreWwcHBNDU15fjx46mvry89DtPE6zp7eW1nL6/tmY2Pj+f06dO56qqr3nBdRTGyaNGiVFdXp7+/f9L+/v7+LFmy5IzHLFmypKL1SVJbW5va2tpJ+xYsWFDJqBeV+vp6//LPQl7X2ctrO3t5bV/vja6I/EBFN7DW1NRk1apV
6enpmdg3NjaWnp6etLa2nvGY1tbWSeuT5G/+5m+mXA8AXFwq/jVNZ2dnNm7cmNWrV2fNmjXZtWtXhoaG0tHRkSRpb2/PsmXL0t3dnSTZsmVLbr755jzwwAO55ZZb8sgjj+Sb3/xmHnzwwel9JgDAjFRxjGzYsCEnT57Mtm3b0tfXl5aWlhw4cGDiJtVjx46lquq1Cy7vete78vDDD+cjH/lI/uAP/iDXXXddvvKVr2TFihXT9ywuUrW1tdm+ffvrfqXFzOZ1nb28trOX1/b8zBl/s/fbAABcQD6bBgAoSowAAEWJEQCgKDECbxHr1q3Lb//2b5ceAzhP/luunBgBAIoSIwBAUWJkBjpw4EB+6qd+KgsWLMjChQvz8z//83n++edLj8U0+J//+Z9s3rw5DQ0NWbRoUe699943/bRLZoaxsbF86lOfyrXXXpva2tr86I/+aD7xiU+UHovzNDQ0lPb29syfPz9Lly7NAw88UHqkGUmMzEBDQ0Pp7OzMN7/5zfT09KSqqirvf//7MzY2Vno0ztMXvvCFzJ07N4cOHcqnP/3p7Ny5M5///OdLj8U06Orqyo4dO3Lvvffm6aefzsMPP/y6TzRn5rnrrrvy93//9/nqV7+av/7rv87Bgwdz5MiR0mPNOP7o2Sxw6tSpXHnllXnyySf9ZdsZbN26dTlx4kSeeuqpiU+n3rp1a772ta/l6aefLjwd5+P06dO58sor89nPfjYf+MAHSo/DNHnllVeycOHCPPTQQ/mVX/mVJMlLL72Ut73tbfngBz+YXbt2lR1wBnFlZAb61re+ldtvvz1XX3116uvr09zcnOT7f4qfme2d73znRIgk3/+gyW9961sZHR0tOBXn65lnnsnw8HDe+973lh6FafT8889nZGQka9eundh3xRVX5Prrry841cxU8WfTUN6tt96at7/97dm7d2+uuuqqjI2NZcWKFRkZGSk9GnAGl156aekR4C3NlZEZ5j//8z/z7LPP5iMf+Uje+9735sYbb8zLL79ceiymyT/+4z9O+vof/uEfct1116W6urrQREyH6667Lpdeeml6enpKj8I0uuaaa3LJJZdM+u/25ZdfznPPPVdwqpnJlZEZ5vLLL8/ChQvz4IMPZunSpTl27Fi2bt1aeiymybFjx9LZ2Znf+I3fyJEjR/KZz3zG3fmzQF1dXe6+++78/u//fmpqavLud787J0+ezFNPPZU77rij9Hico/nz5+eOO+7IXXfdlYULF2bx4sW55557Jn1yPWdHjMwwVVVVeeSRR/Jbv/VbWbFiRa6//vr8yZ/8SdatW1d6NKZBe3t7/uu//itr1qxJdXV1tmzZkg9+8IOlx2Ia3HvvvZk7d262bduWF198MUuXLs1v/uZvlh6L83TffffllVdeya233prLLrssv/u7v5uBgYHSY8043k0DABTlWhIAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKOp/AZfdtxeXPbBUAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# abduction on noise of the composition function P_arms\n",
    "arm_id = env_.arm_keys.index(arm)\n",
    "print(arm_id)\n",
    "_, arm_g = gumbel_max_rejection_sampling(\n",
    "    torch.tensor(list(P_A.values())), arm_id, max_iterations=100000\n",
    ") \n",
    "# noise that gets us abstraction 'c' given the control '2' and the original environment (x)\n",
    "# this is the noise for the 'combined mechanism' here \n",
    "\n",
    "cf_logits = deepcopy(logits)    # simplified case, the policy outputs the same logits \n",
    "cf_control_probs = torch.softmax(cf_logits, dim=0)\n",
    "\n",
    "# in this counterfactual state, the distribution over abstractions (arms) is different \n",
    "# because the abstraction function maps the controls differently\n",
    "P_A_cf = estimate_arm_probs(cf_control_probs, env_)\n",
    "plot_PA(P_A_cf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'b'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compute the counterfactual value of the arm using the original noise term and P_A_cf\n",
    "cf_arm_probs = torch.tensor(list(P_A_cf.values()))\n",
    "# we can now sample from this using the original noise term and gumbel_max trick\n",
    "cf_arm_id, _ = gumbel_max(torch.log(cf_arm_probs), arm_g)\n",
    "cf_arm = env_.arm_keys[cf_arm_id]\n",
    "cf_arm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor(0.), tensor(0.), tensor(0.2689), tensor(0.7311)]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdOElEQVR4nO3df2xddf3H8VdvR28pW+8YdbdbuXqFKaOBtdCutRB+mQvVLMgSNYWorTejf8BG0BsNq2groNzpoCm6hsJcxYDLGggD42YRrw5CqCm0LjKEEdTR8uPetkHuHZfYknvv9w/iXfpdO3r6Y29u93wkJ7Fnn3Pu++YS+8y5597mZTKZjAAAAIy4rAcAAACnNmIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYWmI9wEyk02m9/fbbWrZsmfLy8qzHAQAAM5DJZHT06FGtXr1aLtf01z9yIkbefvtt+Xw+6zEAAMAsDA8P6+yzz57233MiRpYtWybpoydTXFxsPA0AAJiJRCIhn8+X/T0+nZyIkf+9NVNcXEyMAACQYz7uFgtuYAUAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYGqJ9QAAAEiSf+s+6xFOWUe2bTB9fK6MAAAAU8QIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAwRYwAAABTxAgAADA1qxjp7OyU3+9XYWGhamtr1d/fP+3aK6+8Unl5ecdtGzbYfvUsAAD4ZHAcIz09PQqFQmpra9Pg4KAqKipUX1+vkZGRKdc//vjjeuedd7LboUOHlJ+fr69//etzHh4AAOQ+xzHS3t6u5uZmBYNBlZeXq6urS0VFReru7p5y/YoVK1RaWprdnn76aRUVFREjAABAksMYmZiY0MDAgAKBwLETuFwKBALq6+ub0Tl27dql66+/Xmeccca0a8bHx5VIJCZtAABgcXIUI2NjY0qlUvJ6vZP2e71eRaPRjz2+v79fhw4d0o033njCdeFwWB6PJ7v5fD4nYwIAgBxyUj9Ns2vXLl144YWqqak54bqWlhbF4/HsNjw8fJImBAAAJ9sSJ4tLSkqUn5+vWCw2aX8sFlNpaekJj00mk9qzZ4/uvPPOj30ct9stt9vtZDQAAJCjHF0ZKSgoUFVVlSKRSHZfOp1WJBJRXV3dCY999NFHNT4+rm9+85uzmxQAACxKjq6MSFIoFFJTU5Oqq6tVU1Ojjo4OJZNJBYNBSVJjY6PKysoUDocnHbdr1y5t3LhRZ5111vxMDgAAFgXHMdLQ0KDR0VG1trYqGo2qsrJSvb292Ztah4aG5HJNvuBy+PBhPffcc/rjH/84P1MDAIBFIy+TyWSsh/g4iURCHo9H8XhcxcXF1uMAABaAf+s+6xFOWUe2Lcy3os/09zd/mwYAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJiaVYx0dnbK7/ersLBQtbW16u/vP+H69957T5s3b9aqVavkdrv1+c9/Xvv375/VwAAAYHFZ4vSAnp4ehUIhdXV1qba2Vh0dHaqvr9fhw4e1cuXK49ZPTEzo6quv1sqVK/XYY4+prKxMb7zxhpYvXz4f8wMAgBznOEba29vV3NysYDAoSerq6tK+ffvU3d2trVu3Hre+u7tb7777rp5//nmddtppkiS/
3z+3qQEAwKLh6G2aiYkJDQwMKBAIHDuBy6VAIKC+vr4pj/nd736nuro6bd68WV6vVxdccIHuvvtupVKpaR9nfHxciURi0gYAABYnRzEyNjamVColr9c7ab/X61U0Gp3ymH/961967LHHlEqltH//fv3oRz/Svffeq5/85CfTPk44HJbH48luPp/PyZgAACCHLPinadLptFauXKkHH3xQVVVVamho0O23366urq5pj2lpaVE8Hs9uw8PDCz0mAAAw4uiekZKSEuXn5ysWi03aH4vFVFpaOuUxq1at0mmnnab8/PzsvvPPP1/RaFQTExMqKCg47hi32y232+1kNAAAkKMcXRkpKChQVVWVIpFIdl86nVYkElFdXd2Ux1x66aV6/fXXlU6ns/tee+01rVq1asoQAQAApxbHb9OEQiHt3LlTv/nNb/TKK6/opptuUjKZzH66prGxUS0tLdn1N910k959913deuuteu2117Rv3z7dfffd2rx58/w9CwAAkLMcf7S3oaFBo6Ojam1tVTQaVWVlpXp7e7M3tQ4NDcnlOtY4Pp9PTz31lL773e9q3bp1Kisr06233qrbbrtt/p4FAADIWXmZTCZjPcTHSSQS8ng8isfjKi4uth4HALAA/Fv3WY9wyjqybcOCnHemv7/52zQAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMDUrGKks7NTfr9fhYWFqq2tVX9//7RrH3roIeXl5U3aCgsLZz0wAABYXBzHSE9Pj0KhkNra2jQ4OKiKigrV19drZGRk2mOKi4v1zjvvZLc33nhjTkMDAIDFw3GMtLe3q7m5WcFgUOXl5erq6lJRUZG6u7unPSYvL0+lpaXZzev1zmloAACweDiKkYmJCQ0MDCgQCBw7gculQCCgvr6+aY97//339ZnPfEY+n0/XXXedXn755dlPDAAAFhVHMTI2NqZUKnXclQ2v16toNDrlMeedd566u7v15JNP6pFHHlE6ndYll1yiN998c9rHGR8fVyKRmLQBAIDFacE/TVNXV6fGxkZVVlbqiiuu0OOPP65PfepTeuCBB6Y9JhwOy+PxZDefz7fQYwIAACOOYqSkpET5+fmKxWKT9sdiMZWWls7oHKeddpouuugivf7669OuaWlpUTwez27Dw8NOxgQAADnEUYwUFBSoqqpKkUgkuy+dTisSiaiurm5G50ilUnrppZe0atWqade43W4VFxdP2gAAwOK0xOkBoVBITU1Nqq6uVk1NjTo6OpRMJhUMBiVJjY2NKisrUzgcliTdeeed+sIXvqA1a9bovffe0/bt2/XGG2/oxhtvnN9nAgAAcpLjGGloaNDo6KhaW1sVjUZVWVmp3t7e7E2tQ0NDcrmOXXD5z3/+o+bmZkWjUZ155pmqqqrS888/r/Ly8vl7FgAAIGflZTKZjPUQHyeRSMjj8Sgej/OWDQAsUv6t+6xHOGUd2bZhQc4709/f/G0aAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmZhUjnZ2d8vv9
KiwsVG1trfr7+2d03J49e5SXl6eNGzfO5mEBAMAi5DhGenp6FAqF1NbWpsHBQVVUVKi+vl4jIyMnPO7IkSP63ve+p8suu2zWwwIAgMXHcYy0t7erublZwWBQ5eXl6urqUlFRkbq7u6c9JpVK6Rvf+IbuuOMOnXPOOXMaGAAALC6OYmRiYkIDAwMKBALHTuByKRAIqK+vb9rj7rzzTq1cuVKbNm2a0eOMj48rkUhM2gAAwOLkKEbGxsaUSqXk9Xon7fd6vYpGo1Me89xzz2nXrl3auXPnjB8nHA7L4/FkN5/P52RMAACQQxb00zRHjx7Vt771Le3cuVMlJSUzPq6lpUXxeDy7DQ8PL+CUAADA0hIni0tKSpSfn69YLDZpfywWU2lp6XHr//nPf+rIkSO69tprs/vS6fRHD7xkiQ4fPqxzzz33uOPcbrfcbreT0QAAQI5ydGWkoKBAVVVVikQi2X3pdFqRSER1dXXHrV+7dq1eeuklHTx4MLt95Stf0VVXXaWDBw/y9gsAAHB2ZUSSQqGQmpqaVF1drZqaGnV0dCiZTCoYDEqSGhsbVVZWpnA4rMLCQl1wwQWTjl++fLkkHbcfAACcmhzHSENDg0ZHR9Xa2qpoNKrKykr19vZmb2odGhqSy8UXuwIAgJnJy2QyGeshPk4ikZDH41E8HldxcbH1OACABeDfus96hFPWkW0bFuS8M/39zSUMAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgalYx0tnZKb/fr8LCQtXW1qq/v3/atY8//riqq6u1fPlynXHGGaqsrNTDDz8864EBAMDi4jhGenp6FAqF1NbWpsHBQVVUVKi+vl4jIyNTrl+xYoVuv/129fX16e9//7uCwaCCwaCeeuqpOQ8PAAByX14mk8k4OaC2tlbr16/Xjh07JEnpdFo+n0+33HKLtm7dOqNzXHzxxdqwYYPuuuuuGa1PJBLyeDyKx+MqLi52Mi4AIEf4t+6zHuGUdWTbhgU570x/fzu6MjIxMaGBgQEFAoFjJ3C5FAgE1NfX97HHZzIZRSIRHT58WJdffvm068bHx5VIJCZtAABgcXIUI2NjY0qlUvJ6vZP2e71eRaPRaY+Lx+NaunSpCgoKtGHDBv3yl7/U1VdfPe36cDgsj8eT3Xw+n5MxAQBADjkpn6ZZtmyZDh48qBdeeEE//elPFQqFdODAgWnXt7S0KB6PZ7fh4eGTMSYAADCwxMnikpIS5efnKxaLTdofi8VUWlo67XEul0tr1qyRJFVWVuqVV15ROBzWlVdeOeV6t9stt9vtZDQAAJCjHF0ZKSgoUFVVlSKRSHZfOp1WJBJRXV3djM+TTqc1Pj7u5KEBAMAi5ejKiCSFQiE1NTWpurpaNTU16ujoUDKZVDAYlCQ1NjaqrKxM4XBY0kf3f1RXV+vcc8/V+Pi49u/fr4cfflj333///D4TAACQkxzHSENDg0ZHR9Xa2qpoNKrKykr19vZmb2odGhqSy3XsgksymdTNN9+sN998U6effrrWrl2rRx55RA0NDfP3LAAAQM5y/D0jFvieEQBY/PieETs59T0jAAAA840YAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAAphx/AysAWOPLsews1Jdj4dTGlREAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoY
AQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApogRAABgihgBAACmiBEAAGCKGAEAAKaIEQAAYIoYAQAApmYVI52dnfL7/SosLFRtba36+/unXbtz505ddtllOvPMM3XmmWcqEAiccD0AADi1OI6Rnp4ehUIhtbW1aXBwUBUVFaqvr9fIyMiU6w8cOKAbbrhBf/nLX9TX1yefz6drrrlGb7311pyHBwAAuc9xjLS3t6u5uVnBYFDl5eXq6upSUVGRuru7p1z/29/+VjfffLMqKyu1du1a/epXv1I6nVYkEpnz8AAAIPc5ipGJiQkNDAwoEAgcO4HLpUAgoL6+vhmd44MPPtCHH36oFStWTLtmfHxciURi0gYAABYnRzEyNjamVColr9c7ab/X61U0Gp3ROW677TatXr16UtD8f+FwWB6PJ7v5fD4nYwIAgBxyUj9Ns23bNu3Zs0d79+5VYWHhtOtaWloUj8ez2/Dw8EmcEgAAnExLnCwuKSlRfn6+YrHYpP2xWEylpaUnPPaee+7Rtm3b9Kc//Unr1q074Vq32y232+1kNAAAkKMcXRkpKChQVVXVpJtP/3czal1d3bTH/fznP9ddd92l3t5eVVdXz35aAACw6Di6MiJJoVBITU1Nqq6uVk1NjTo6OpRMJhUMBiVJjY2NKisrUzgcliT97Gc/U2trq3bv3i2/35+9t2Tp0qVaunTpPD4VAACQixzHSENDg0ZHR9Xa2qpoNKrKykr19vZmb2odGhqSy3Xsgsv999+viYkJfe1rX5t0nra2Nv34xz+e2/QAACDnOY4RSdqyZYu2bNky5b8dOHBg0s9HjhyZzUMAAIBTBH+bBgAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmCJGAACAKWIEAACYIkYAAIApYgQAAJgiRgAAgCliBAAAmJpVjHR2dsrv96uwsFC1tbXq7++fdu3LL7+sr371q/L7/crLy1NHR8dsZwUAAIuQ4xjp6elRKBRSW1ubBgcHVVFRofr6eo2MjEy5/oMPPtA555yjbdu2qbS0dM4DAwCAxcVxjLS3t6u5uVnBYFDl5eXq6upSUVGRuru7p1y/fv16bd++Xddff73cbvecBwYAAIuLoxiZmJjQwMCAAoHAsRO4XAoEAurr65u3ocbHx5VIJCZtAABgcXIUI2NjY0qlUvJ6vZP2e71eRaPReRsqHA7L4/FkN5/PN2/nBgAAnyyfyE/TtLS0KB6PZ7fh4WHrkQAAwAJZ4mRxSUmJ8vPzFYvFJu2PxWLzenOq2+3m/hIAAE4Rjq6MFBQUqKqqSpFIJLsvnU4rEomorq5u3ocDAACLn6MrI5IUCoXU1NSk6upq1dTUqKOjQ8lkUsFgUJLU2NiosrIyhcNhSR/d9PqPf/wj+7/feustHTx4UEuXLtWaNWvm8akAAIBc5DhGGhoaNDo6qtbWVkWjUVVWVqq3tzd7U+vQ0JBcrmMXXN5++21ddNFF2Z/vuece3XPPPbriiit04MCBuT8DAACQ0xzHiCRt2bJFW7ZsmfLf/n9g+P1+ZTKZ2TwMAAA4BXwiP00DAABOHcQIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAwRYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU8QIAAAw
RYwAAABTxAgAADBFjAAAAFPECAAAMEWMAAAAU7OKkc7OTvn9fhUWFqq2tlb9/f0nXP/oo49q7dq1Kiws1IUXXqj9+/fPalgAALD4OI6Rnp4ehUIhtbW1aXBwUBUVFaqvr9fIyMiU659//nndcMMN2rRpk/72t79p48aN2rhxow4dOjTn4QEAQO5zHCPt7e1qbm5WMBhUeXm5urq6VFRUpO7u7inX33ffffrSl76k73//+zr//PN111136eKLL9aOHTvmPDwAAMh9S5wsnpiY0MDAgFpaWrL7XC6XAoGA+vr6pjymr69PoVBo0r76+no98cQT0z7O+Pi4xsfHsz/H43FJUiKRcDIugEUqPf6B9QinrIX8/2FeVzsL9br+77yZTOaE6xzFyNjYmFKplLxe76T9Xq9Xr7766pTHRKPRKddHo9FpHyccDuuOO+44br/P53MyLgBgnnk6rCfAQljo1/Xo0aPyeDzT/rujGDlZWlpaJl1NSafTevfdd3XWWWcpLy/PcLJPlkQiIZ/Pp+HhYRUXF1uPg3nC67p48douXry2U8tkMjp69KhWr159wnWOYqSkpET5+fmKxWKT9sdiMZWWlk55TGlpqaP1kuR2u+V2uyftW758uZNRTynFxcX8x78I8bouXry2ixev7fFOdEXkfxzdwFpQUKCqqipFIpHsvnQ6rUgkorq6uimPqaurm7Rekp5++ulp1wMAgFOL47dpQqGQmpqaVF1drZqaGnV0dCiZTCoYDEqSGhsbVVZWpnA4LEm69dZbdcUVV+jee+/Vhg0btGfPHr344ot68MEH5/eZAACAnOQ4RhoaGjQ6OqrW1lZFo1FVVlaqt7c3e5Pq0NCQXK5jF1wuueQS7d69Wz/84Q/1gx/8QJ/73Of0xBNP6IILLpi/Z3GKcrvdamtrO+4tLeQ2XtfFi9d28eK1nZu8zMd93gYAAGAB8bdpAACAKWIEAACYIkYAAIApYgQAAJgiRnJUZ2en/H6/CgsLVVtbq/7+fuuRMA+effZZXXvttVq9erXy8vJO+DeckDvC4bDWr1+vZcuWaeXKldq4caMOHz5sPRbm6P7779e6deuyX3RWV1enP/zhD9Zj5SRiJAf19PQoFAqpra1Ng4ODqqioUH19vUZGRqxHwxwlk0lVVFSos7PTehTMo2eeeUabN2/WX//6Vz399NP68MMPdc011yiZTFqPhjk4++yztW3bNg0MDOjFF1/UF7/4RV133XV6+eWXrUfLOXy0NwfV1tZq/fr12rFjh6SPvgXX5/Pplltu0datW42nw3zJy8vT3r17tXHjRutRMM9GR0e1cuVKPfPMM7r88sutx8E8WrFihbZv365NmzZZj5JTuDKSYyYmJjQwMKBAIJDd53K5FAgE1NfXZzgZgJmKx+OSPvrFhcUhlUppz549SiaT/LmTWfhE/tVeTG9sbEypVCr7jbf/4/V69eqrrxpNBWCm0um0vvOd7+jSSy/lm6gXgZdeekl1dXX673//q6VLl2rv3r0qLy+3HivnECMAcBJt3rxZhw4d0nPPPWc9CubBeeedp4MHDyoej+uxxx5TU1OTnnnmGYLEIWIkx5SUlCg/P1+xWGzS/lgsptLSUqOpAMzEli1b9Pvf/17PPvuszj77bOtxMA8KCgq0Zs0aSVJVVZVeeOEF3XfffXrggQeMJ8st3DOSYwoKClRVVaVIJJLdl06nFYlEeJ8S+ITKZDLasmWL9u7dqz//+c/67Gc/az0SFkg6ndb4+Lj1GDmHKyM5KBQKqampSdXV1aqpqVFHR4eSyaSCwaD1aJij999/X6+//nr253//+986ePCgVqxYoU9/+tOGk2EuNm/erN27d+vJJ5/UsmXLFI1GJUkej0enn3668XSYrZaWFn35y1/Wpz/9aR09elS7d+/WgQMH9NRTT1mPlnP4aG+O2rFjh7Zv365oNKrKykr94he/UG1trfVYmKMDBw7oqquuOm5/U1OTHnrooZM/EOZF
Xl7elPt//etf69vf/vbJHQbzZtOmTYpEInrnnXfk8Xi0bt063Xbbbbr66qutR8s5xAgAADDFPSMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABMESMAAMAUMQIAAEwRIwAAwBQxAgAATBEjAADAFDECAABM/R8gAZOlnNlHTAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# now we want the counterfactual probability of controls, given the counterfactual arm\n",
    "cf_control_posterior = {\n",
    "    i: cf_control_probs[i] * env_.arm_given_index[i].probs[cf_arm_id] / P_A_cf[env_.arm_keys[cf_arm_id]]\n",
    "    for i in range(len(logits))\n",
    "}\n",
    "\n",
    "# plot the counterfactual control posterior\n",
    "plot_PA(\n",
    "    cf_control_posterior\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A, B, C, D\n",
    "#       C, D, E, F\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
