{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finding interesting case studies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abstract_cf.bios.utils import load_dataset, load_learned_abstraction\n",
    "from transformers.modeling_outputs import SequenceClassifierOutput\n",
    "from abstract_cf.bios.profession_classifier import ProfessionClassifier\n",
    "import plotly.express as px\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "datasets, id_to_label = load_dataset()\n",
    "\n",
    "profession_classifier = ProfessionClassifier(\n",
    "    id_to_label=id_to_label,\n",
    "    device=device,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import ravfogel_lm_counterfactuals\n",
    "from ravfogel_lm_counterfactuals.utils import get_counterfactual_model, load_model\n",
    "\n",
    "factual_model_name = \"openai-community/gpt2-xl\"\n",
    "intervention_type = \"mimic_gender_gpt2_instruct\"\n",
    "\n",
    "factual_model = load_model(factual_model_name)\n",
    "counterfactual_model = get_counterfactual_model(intervention_type)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    factual_model_name, \n",
    "    model_max_length=512, \n",
    "    padding_side=\"right\", \n",
    "    use_fast=False,\n",
    "    trust_remote_code=True\n",
    ")\n",
    "\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "factual_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "counterfactual_model.config.pad_token_id = tokenizer.pad_token_id\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to find examples where the distribution of professions changes a lot between the factual and the counterfactual (interventional) continuations of a biography.\n",
    "For each biography we sample continuations from both models, estimate the two profession distributions using our learned abstraction,\n",
    "and rank biographies by the KL divergence $\\mathrm{KL}(P_{\\text{factual}} \\,\\|\\, P_{\\text{counterfactual}})$ between them.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>text</th>\n",
       "      <th>kl</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>37988</td>\n",
       "      <td>Jessica received her undergraduate degree in P...</td>\n",
       "      <td>0.545379</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>5726</td>\n",
       "      <td>His interest in other kinds of arts emerged si...</td>\n",
       "      <td>0.505530</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>21809</td>\n",
       "      <td>She practices in Aurora , Colorado and has the...</td>\n",
       "      <td>0.446309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54</th>\n",
       "      <td>6396</td>\n",
       "      <td>She received her B.A. in film production from ...</td>\n",
       "      <td>0.416268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>29583</td>\n",
       "      <td>After a brief stint in California , John retur...</td>\n",
       "      <td>0.402550</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    index                                               text        kl\n",
       "41  37988  Jessica received her undergraduate degree in P...  0.545379\n",
       "43   5726  His interest in other kinds of arts emerged si...  0.505530\n",
       "3   21809  She practices in Aurora , Colorado and has the...  0.446309\n",
       "54   6396  She received her B.A. in film production from ...  0.416268\n",
       "51  29583  After a brief stint in California , John retur...  0.402550"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd \n",
    "\n",
    "kl_results = pd.read_csv('../kl_results.csv').sort_values('kl', ascending=False)\n",
    "kl_results.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the biography with the largest KL divergence in the precomputed results.\n",
    "# The `index` column of kl_results is a positional index into datasets['dev'].\n",
    "top_case = kl_results.iloc[0]\n",
    "biography = datasets['dev'].iloc[int(top_case['index'])].hard_text\n",
    "\n",
    "# Only the first PROMPT_MAX_TOKENS tokens of the biography are used as the prompt for both models.\n",
    "# NOTE(review): presumably this must match the prefix length used to produce kl_results.csv - confirm.\n",
    "PROMPT_MAX_TOKENS = 8\n",
    "inputs = tokenizer(\n",
    "    biography,\n",
    "    return_tensors='pt',\n",
    "    truncation=True,\n",
    "    max_length=PROMPT_MAX_TOKENS\n",
    ").to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>text</th>\n",
       "      <th>kl</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>29583</td>\n",
       "      <td>After a brief stint in California , John retur...</td>\n",
       "      <td>0.40255</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    index                                               text       kl\n",
       "51  29583  After a brief stint in California , John retur...  0.40255"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kl_results.iloc[4:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Jessica received her undergraduate degree in Pre Med and shortly after she and her husband decided to become missionaries with the International Mission Board . During her time as a missionary , Jessica went to France and learned French and lived on Reunion Island near Madagascar .'"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "biography"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abstract_cf.bios.utils import sample_from_model\n",
    "from abstract_cf.bios.profession_classifier import analyze_career_distribution\n",
    "\n",
    "\n",
    "def get_abstraction_kl_divergence(\n",
    "    inputs,\n",
    "    factual_model,\n",
    "    counterfactual_model,\n",
    "    tokenizer,\n",
    "    profession_classifier,\n",
    "    n_samples_per_generation: int = 50,\n",
    "    max_length: int = 100,\n",
    "    return_distributions: bool = False\n",
    "):\n",
    "    \"\"\"Estimate KL(factual || counterfactual) between profession distributions.\n",
    "\n",
    "    Samples continuations of `inputs` from both models with `sample_from_model`,\n",
    "    turns each set of continuations into a distribution over professions with\n",
    "    `analyze_career_distribution`, and returns the KL divergence between the two.\n",
    "\n",
    "    Args:\n",
    "        inputs: tokenized prompt forwarded to `sample_from_model` (in this notebook,\n",
    "            the output of `tokenizer(..., return_tensors='pt')`, not a bare tensor).\n",
    "        factual_model: model sampled for the factual continuations.\n",
    "        counterfactual_model: model sampled for the counterfactual continuations.\n",
    "        tokenizer: tokenizer forwarded to `sample_from_model`.\n",
    "        profession_classifier: classifier forwarded to `analyze_career_distribution`.\n",
    "        n_samples_per_generation: number of continuations sampled from each model.\n",
    "        max_length: `max_length` forwarded to `sample_from_model`.\n",
    "        return_distributions: if True, also return both distribution DataFrames.\n",
    "\n",
    "    Returns:\n",
    "        A 0-d tensor holding KL(factual || counterfactual), or the tuple\n",
    "        `(kl, factual_df, counterfactual_df)` when `return_distributions` is True.\n",
    "    \"\"\"\n",
    "    factual_text, _ = sample_from_model(\n",
    "        factual_model, tokenizer, inputs, n_samples=n_samples_per_generation, max_length=max_length\n",
    "    )\n",
    "    cf_text, _ = sample_from_model(\n",
    "        counterfactual_model, tokenizer, inputs, n_samples=n_samples_per_generation, max_length=max_length\n",
    "    )\n",
    "    factual_df = analyze_career_distribution(factual_text, profession_classifier)\n",
    "    counterfactual_df = analyze_career_distribution(cf_text, profession_classifier)\n",
    "\n",
    "    # Align the two distributions by profession name instead of relying on both\n",
    "    # frames having the same row order; a profession missing from one side gets 0.\n",
    "    aligned = factual_df[['profession', 'probability']].merge(\n",
    "        counterfactual_df[['profession', 'probability']],\n",
    "        on='profession',\n",
    "        how='outer',\n",
    "        suffixes=('_factual', '_counterfactual'),\n",
    "    ).fillna(0.0)\n",
    "    factual_probs = torch.tensor(aligned['probability_factual'].to_numpy())\n",
    "    counterfactual_probs = torch.tensor(aligned['probability_counterfactual'].to_numpy())\n",
    "\n",
    "    # KL(p || q) = sum_i p_i * (log p_i - log q_i). torch.xlogy defines 0 * log(0) = 0,\n",
    "    # so professions with p_i == 0 contribute nothing instead of producing NaN\n",
    "    # (the naive p * log(p / q) evaluates 0 * -inf there).\n",
    "    kl = (torch.xlogy(factual_probs, factual_probs) - torch.xlogy(factual_probs, counterfactual_probs)).sum()\n",
    "    if return_distributions:\n",
    "        return kl, factual_df, counterfactual_df\n",
    "    return kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling is stochastic; seed torch so this KL estimate is reproducible across runs.\n",
    "# NOTE(review): assumes sample_from_model draws from torch's global RNG - confirm.\n",
    "torch.manual_seed(0)\n",
    "kl, factual_df, counterfactual_df = get_abstraction_kl_divergence(\n",
    "    inputs,\n",
    "    factual_model,\n",
    "    counterfactual_model,\n",
    "    tokenizer,\n",
    "    profession_classifier,\n",
    "    n_samples_per_generation=50,\n",
    "    max_length=100,\n",
    "    return_distributions=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3805)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "profession=%{x}<br>probability=%{y}<extra></extra>",
         "legendgroup": "",
         "marker": {
          "color": "#636efa",
          "pattern": {
           "shape": ""
          }
         },
         "name": "",
         "offsetgroup": "",
         "orientation": "v",
         "showlegend": false,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "teacher",
          "professor",
          "attorney",
          "surgeon",
          "photographer",
          "painter",
          "psychologist",
          "filmmaker",
          "physician",
          "interior_designer",
          "architect",
          "dietitian",
          "pastor",
          "rapper",
          "journalist",
          "dentist",
          "accountant",
          "poet",
          "model",
          "nurse",
          "chiropractor",
          "yoga_teacher",
          "software_engineer",
          "paralegal",
          "composer",
          "personal_trainer",
          "comedian",
          "dj"
         ],
         "xaxis": "x",
         "y": [
          0.016674518585205078,
          0.16868005692958832,
          0.002878899220377207,
          0.018809927627444267,
          0.0006990677793510258,
          0.00024508454953320324,
          0.03087838925421238,
          0.00023743911879137158,
          0.41323986649513245,
          0.0001135040947701782,
          0.0005183322937227786,
          0.014568183571100235,
          0.0006414071540348232,
          0.00008833643369143829,
          0.0016930830897763371,
          0.0005571285146288574,
          0.0011224771151319146,
          0.0004906946560367942,
          0.001537545002065599,
          0.2951102554798126,
          0.0013754138490185142,
          0.0007874607108533382,
          0.0008095065713860095,
          0.0008358677732758224,
          0.00009223704546457157,
          0.027078883722424507,
          0.00014283224300015718,
          0.0000935160496737808
         ],
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "relative",
        "legend": {
         "tracegroupgap": 0
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "heatmapgl": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmapgl"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Factual distribution"
        },
        "xaxis": {
         "anchor": "y",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "profession"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "probability"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "px.bar(factual_df, x='profession', y='probability', title='Factual distribution')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "profession=%{x}<br>probability=%{y}<extra></extra>",
         "legendgroup": "",
         "marker": {
          "color": "#636efa",
          "pattern": {
           "shape": ""
          }
         },
         "name": "",
         "offsetgroup": "",
         "orientation": "v",
         "showlegend": false,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "teacher",
          "professor",
          "attorney",
          "surgeon",
          "photographer",
          "painter",
          "psychologist",
          "filmmaker",
          "physician",
          "interior_designer",
          "architect",
          "dietitian",
          "pastor",
          "rapper",
          "journalist",
          "dentist",
          "accountant",
          "poet",
          "model",
          "nurse",
          "chiropractor",
          "yoga_teacher",
          "software_engineer",
          "paralegal",
          "composer",
          "personal_trainer",
          "comedian",
          "dj"
         ],
         "xaxis": "x",
         "y": [
          0.05144008249044418,
          0.06804841011762619,
          0.024368515238165855,
          0.0065572066232562065,
          0.0016494595911353827,
          0.0003638146445155144,
          0.020314287394285202,
          0.0005450155586004257,
          0.16647404432296753,
          0.00024705962277948856,
          0.0004215115332044661,
          0.01821943372488022,
          0.0010925292735919356,
          0.0001673226070124656,
          0.022281045094132423,
          0.016558554023504257,
          0.0024825481232255697,
          0.0010389570379629731,
          0.0033500248100608587,
          0.5713786482810974,
          0.015792207792401314,
          0.0008336161845363677,
          0.0007498269551433623,
          0.0013589647132903337,
          0.00021850863413419574,
          0.003414415055885911,
          0.00038946542190387845,
          0.000244500843109563
         ],
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "relative",
        "legend": {
         "tracegroupgap": 0
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "heatmapgl": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmapgl"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Factual distribution"
        },
        "xaxis": {
         "anchor": "y",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "profession"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "probability"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Only a cell's last expression is rendered; use display(counterfactual_df) to also show the table.\n",
    "px.bar(counterfactual_df, x='profession', y='probability', title='Counterfactual distribution')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Log-probability experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS_TOKEN = tokenizer.eos_token_id\n",
    "\n",
    "def get_sample_log_prob(\n",
    "    generated_tokens: torch.Tensor,\n",
    "    scores: torch.Tensor,\n",
    "    eos_token_id: int = EOS_TOKEN,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Log-probability of one generated continuation under its per-step scores.\n",
    "\n",
    "    Args:\n",
    "        generated_tokens: generated token ids with the prompt stripped, one per decoding step.\n",
    "        scores: per-step logits aligned with `generated_tokens` (presumably this sample's\n",
    "            `output_scores` from `generate` -- TODO confirm how cf_scores is built).\n",
    "        eos_token_id: the continuation ends at the first occurrence of this id.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor: sum over steps of log_softmax(score)[token], up to and including\n",
    "        the first EOS token. Positions after it are ignored.\n",
    "    \"\"\"\n",
    "    log_prob = 0\n",
    "    for token, score in zip(generated_tokens, scores):\n",
    "        log_probs = F.log_softmax(score, dim=-1)\n",
    "        # Emitting EOS is itself a sampled decision, so its probability belongs to the\n",
    "        # sequence probability; everything generated after it is padding.\n",
    "        log_prob += log_probs[token]\n",
    "        if token == eos_token_id:\n",
    "            break\n",
    "    if not torch.is_tensor(log_prob):\n",
    "        # Nothing was accumulated (empty input): still honour the tensor return type.\n",
    "        log_prob = torch.tensor(0.0)\n",
    "    return log_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log-probability of a single counterfactual sample; choose which one with `sample_index`.\n",
    "sample_index = 0\n",
    "input_length = inputs.input_ids[0].shape[0]\n",
    "generated_tokens = cf_token_ids[sample_index][input_length:]\n",
    "\n",
    "get_sample_log_prob(generated_tokens, cf_scores[sample_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gamma(Y' | A', x'): log-softmax of the profession classifier's logits on cf_text.\n",
    "all_abstraction_log_probs = F.log_softmax(\n",
    "    profession_classifier.predict(cf_text).logits,\n",
    "    dim=-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:03<00:00, 30.44it/s]\n"
     ]
    }
   ],
   "source": [
    "# Estimate the counterfactual action distribution in log space:\n",
    "#   P_cf(A) = sum over Y' of pi(A | x) * gamma(Y' | A, x') / gamma(Y' | x)\n",
    "# This works for low temperatures, because the samples then stay close to the argmax of the\n",
    "# distribution -- which greatly limits the usefulness of the method, since at temperature 0 it\n",
    "# is not needed in the first place. It might also work for a smaller top-k, but results were\n",
    "# even worse at top-k=50 than at top-k=200.\n",
    "\n",
    "input_length = inputs.input_ids[0].shape[0]\n",
    "n_actions = n_samples\n",
    "\n",
    "# log pi(A | x) for every sampled counterfactual continuation A.\n",
    "policy_probs = [get_sample_log_prob(cf_token_ids[i][input_length:], cf_scores[i]) for i in range(n_actions)]\n",
    "\n",
    "log_terms = [[] for _ in range(n_actions)]\n",
    "for action_id in tqdm.trange(n_actions):\n",
    "    # log pi(A | x) does not depend on Y', so look it up once per action.\n",
    "    pi_A_given_x = policy_probs[action_id]\n",
    "    for abstraction_value_index_cf in abstraction_value_index_cf_samples:\n",
    "        gamma_Y_given_A_x = all_abstraction_log_probs[action_id, abstraction_value_index_cf]\n",
    "        # Expected to be close to logsumexp_A(all_abstraction_log_probs[A, Y']) - log(n_actions).\n",
    "        gamma_Y_given_x = np.log(abstraction_probs_cf[abstraction_value_index_cf])\n",
    "        log_terms[action_id].append(pi_A_given_x + gamma_Y_given_A_x - gamma_Y_given_x)\n",
    "\n",
    "# NOTE(review): logsumexp *sums* the terms over abstraction_value_index_cf_samples; if these are\n",
    "# Monte Carlo samples, a mean (subtract log(len(...))) may be intended -- confirm.\n",
    "P_cf_A = torch.exp(torch.logsumexp(torch.tensor(log_terms), dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1091)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total probability mass assigned by P_cf_A.\n",
    "torch.sum(P_cf_A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x66d8be990>]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAujUlEQVR4nO3df3AU933/8df9QBLYlqhhkAwIRBwSwGAECGQRT0gmmoiG1paTUsxQo6oMHrcQQ9QSG8yPyZe6oolhIIZGQ+dLMpmYQJnahLiUhMjGjb/IyEhgl9jGTuIYCjkJ6kGHsQ3obr9/oF3dj907nRC7h3g+Zm6w9j53+tx6zL383vfnsz7DMAwBAABkMb/XEwAAAEiHwAIAALIegQUAAGQ9AgsAAMh6BBYAAJD1CCwAACDrEVgAAEDWI7AAAICsF/R6An0hGo3q7NmzuuOOO+Tz+byeDgAA6AHDMHTx4kUNHz5cfn/qGkq/CCxnz55VcXGx19MAAAC9cPr0aY0cOTLlmH4RWO644w5J1z5wfn6+x7MBAAA9EQ6HVVxcbH2Pp9IvAot5GSg/P5/AAgDATaYn7Rw03QIAgKxHYAEAAFmPwAIAALIegQUAAGQ9AgsAAMh6BBYAAJD1CCwAACDrEVgAAEDWI7AAAICsR2ABAABZj8ACAACyHoEFAABkPQILACCt0x9+rIZXfqeLn171eiq4RfWLuzUDAG6sfzn0W/20+bQKBg7Q/BmjvJ4ObkFUWAAAaYU/7ZQkfdT1J+A2AgsAIK1IxJAkdUYNj2eCWxWBBQCQVsS4FlSiBoEF3iCwAADSinRVVjojBBZ4g8ACAEjLDCwRKizwCIEFAJCWFViiUY9nglsVgQUAkFZ3YPF4IrhlEVgAAGlRYYHXCCwAgLTM3hUqLPAKgQUAkFYnFRZ4jMACAEgryioheIzAAgBIq7vCQmCBNwgsAIC0ogQWeIzAAgBIi6ZbeI3AAgBIi2XN8BqBBQCQVvfW/B5PBLcsAgsAIC0qLPAagQUAkFaEplt4jMACAEiLZc3wGoEFAJBW1CCwwFsEFgBAWp1d65k7CSzwCIEFAJCWmVOibM0PjxBYAABpdXatDupkXTM80qvAsm3bNpWUlCgvL0/l5eVqbm52HPub3/xG3/jGN1RSUiKfz6fNmzdf93sCANxlrmamwgKvZBxYdu/erbq6Oq1bt06tra2aPHmyqqqq1N7ebjv+448/1mc+8xlt2LBBRUVFffKeAAB3WRUWeljgkYwDy6ZNm7R48WLV1tZqwoQJamho0KBBg7Rjxw7b8dOnT9f3vvc9Pfzww8rNze2T9wQAuMcwjO4eFgILPJJRYLly5YpaWlpUWVnZ/QZ+vyorK9XU1NSrCdyI9wQA9J3YpcxUWOCVYCaDz58/r0gkosLCwrjjhYWFeuedd3o1gd685+XLl3X58mXr53A43KvfDQBILxLTt8I+LPDKTblKqL6+XgUFBdajuLjY6ykBQL8Ve/sgAgu8klFgGTp0qAKBgNra2uKOt7W1OTbU3oj3XLlypTo6OqzH6dOne/W7AQDpdcYklgirhOCRjAJLTk6Opk2bpsbGRutYNBpVY2OjKioqejWB3rxnbm6u8vPz4x4AgBuDCguyQUY9LJJUV1enmpoalZWVacaMGdq8ebMuXbqk2tpaSdLChQs1YsQI1dfXS7rWVPvWW29Z/3zmzBkdP35ct99+uz772c/26D0BAN6Jq7AQWOCRjAPLvHnzdO7cOa1du1ahUEilpaU6cOCA1TR76tQp+f3dhZuzZ89qypQp1s/PPPOMnnnmGc2aNUuHDh3q0XsCALwTexmIZc3wis8wbv4LkuFwWAUFBero6ODyEAD0sT92fKKK+pckScPuyFXzU5VpXgH0TCbf3zflKiEAgHtiLwOxNT+8QmABAKTExnHIBgQWAEBKsYGFplt4hcACAEiJwIJsQGABAKTE1vzIBgQWAEBKnRECC7xH
YAEApBS7Moit+eEVAgsAIKXYlUGGweZx8AaBBQCQUmJAocoCLxBYAAApJe69Qh8LvEBgAQCklFRhIbDAAwQWAEBKiZeA2O0WXiCwAABSSgwoNN3CCwQWAEBKiQGFCgu8QGABAKSUVGFhlRA8QGABAKRE0y2yAYEFAJASy5qRDQgsAICUEi8BEVjgBQILACCl2JsfSjTdwhsEFgBASon7sNB0Cy8QWAAAKSVeAkqsuABuILAAAFJKDCxUWOAFAgsAIKWkCgs9LPAAgQUAkFJiYGGVELxAYAEApERgQTYgsAAAUkpcJURggRcILACAlKiwIBsQWAAAKSUFFlYJwQMEFgBASskVlqhHM8GtjMACAEgpObB4NBHc0ggsAICUkptuSSxwH4EFAJASFRZkAwILACCl5J1uSSxwH4EFAJAS9xJCNiCwAABS4pIQsgGBBQCQUuLNDmm6hRcILACAlKJUWJAFCCwAgJSosCAbEFgAACklNtlyLyF4gcACAEgpscKS+DPgBgILACClxB4WljXDCwQWAEBKiRvFUWGBFwgsAICUElcFJVZcADcQWAAAKZmXgHKC174yqLDACwQWAEBKZkDJDVz7yqDCAi8QWAAAKZkBhQoLvERgAQCkZDbdmoElwioheIDAAgBIyVwkZAWWCIEF7iOwAABSsiosASos8A6BBQCQkllQsSos9LDAAwQWAEBKkcQeFgILPNCrwLJt2zaVlJQoLy9P5eXlam5uTjl+z549GjdunPLy8jRp0iTt378/7vmPPvpIS5cu1ciRIzVw4EBNmDBBDQ0NvZkaAKCPmRvHmZeE2JofXsg4sOzevVt1dXVat26dWltbNXnyZFVVVam9vd12/OHDhzV//nwtWrRIx44dU3V1taqrq3XixAlrTF1dnQ4cOKCf/OQnevvtt7V8+XItXbpU+/bt6/0nAwD0icQKSydNt/BAxoFl06ZNWrx4sWpra61KyKBBg7Rjxw7b8Vu2bNHs2bO1YsUKjR8/XuvXr9fUqVO1detWa8zhw4dVU1OjL33pSyopKdGjjz6qyZMnp63cAABuPPMSUC7LmuGhjALLlStX1NLSosrKyu438PtVWVmppqYm29c0NTXFjZekqqqquPEzZ87Uvn37dObMGRmGoZdfflnvvvuuvvrVr9q+5+XLlxUOh+MeAIAbwwwsAwL0sMA7GQWW8+fPKxKJqLCwMO54YWGhQqGQ7WtCoVDa8c8++6wmTJigkSNHKicnR7Nnz9a2bdv0xS9+0fY96+vrVVBQYD2Ki4sz+RgAgAxEEu4lRGCBF7JildCzzz6r1157Tfv27VNLS4s2btyoJUuW6Fe/+pXt+JUrV6qjo8N6nD592uUZA8Ctw9woLocKCzwUzGTw0KFDFQgE1NbWFne8ra1NRUVFtq8pKipKOf6TTz7RqlWr9MILL2jOnDmSpHvvvVfHjx/XM888k3Q5SZJyc3OVm5ubydQBAL1EhQXZIKMKS05OjqZNm6bGxkbrWDQaVWNjoyoqKmxfU1FRETdekg4ePGiNv3r1qq5evSq/P34qgUBAUXM/aACAZyKJW/MTWOCBjCos0rUlyDU1NSorK9OMGTO0efNmXbp0SbW1tZKkhQsXasSIEaqvr5ckLVu2TLNmzdLGjRs1Z84c7dq1S0ePHtX27dslSfn5+Zo1a5ZWrFihgQMHavTo0XrllVf04x//WJs2berDjwoA6I2kjeNYJQQPZBxY5s2bp3Pnzmnt2rUKhUIqLS3VgQMHrMbaU6dOxVVLZs6cqZ07d2r16tVatWqVxo4dq71792rixInWmF27dmnlypVasGCBPvzwQ40ePVpPP/20HnvssT74iACA62Eta6aHBR7yGcbNH5XD4bAKCgrU0dGh/Px8r6cDAP3KPWsP6NKViP7hq5/TM798VzPvHqKdi+/zelroBzL5/s6KVUIAgOyV2HTbSYUFHiCwAABSMi8BWfcSIrDAAwQWAEBKVmAJBiRRYYE3CCwA
AEeGYcjMJ+YlIe7WDC8QWAAAjmJXBHG3ZniJwAIAcBR7+cfqYaHCAg8QWAAAjmLDSS473cJDBBYAgKPYCssANo6DhwgsAABHUZseFrbmhxcILAAAR5003SJLEFgAAI7MCovfJwX9vmvHqLDAAwQWAIAjs8IS8Pvk9/nijgFuIrAAABxFYgJLMNBVYSGwwAMEFgCAI/PyT8BHhQXeIrAAABzFXhKyelgILPAAgQUA4CgaE1gCfios8A6BBQDgqLvC4rcCC/uwwAsEFgCAo+6mW3UHFios8ACBBQDgyAwnwdgKS9SQQZUFLiOwAAAcmZd//P5rK4VMFFngNgILAMBRXIUl4Es6DriFwAIAcBSJ2Zo/tsJCYIHbCCwAAEd2PSwSK4XgPgILAMCRVWGJ2YdFkiLcsRkuI7AAABx1V1h88ZeEqLDAZQQWAICj2AqL3++TmVnoYYHbCCwAAEedMRUWqbvxlsACtxFYAACOYu/WLInt+eEZAgsAwFHs3Zpj/6TpFm4jsAAAHEWdAgsVFriMwAIAcBTbdCvF3gAx6tmccGsisAAAHEUSmm6DVmDxbEq4RRFYAACOrJsfdjXdmn92UmGBywgsAABHicuazT/JK3AbgQUA4Cix6dbsZaHCArcRWAAAjhKXNVsVFlYJwWUEFgCAI8cKC/uwwGUEFgCAI6cKC/uwwG0EFgCAo8St+f3cSwgeIbAAAByZl34CgcSN4wgscBeBBQDgKJJQYQkSWOARAgsAwJG5BX9i0y2BBW4jsAAAHJlb8LOsGV4jsAAAHJkVFjOodG/NT2CBuwgsAABHZoXFvBQUDHBJCN4gsAAAHJmXfhIrLAQWuI3AAgBwZN4zyM8qIXiMwAIAcGReEjKDCvuwwCsEFgCAI7Pp1p8YWFglBJcRWAAAjhKXNVNhgVcILAAAR4nLmgN+f9dxAgvc1avAsm3bNpWUlCgvL0/l5eVqbm5OOX7Pnj0aN26c8vLyNGnSJO3fvz9pzNtvv60HHnhABQUFuu222zR9+nSdOnWqN9MDAPSRrlsJWU23XauaCSxwXcaBZffu3aqrq9O6devU2tqqyZMnq6qqSu3t7bbjDx8+rPnz52vRokU6duyYqqurVV1drRMnTlhjfve73+n+++/XuHHjdOjQIb355ptas2aN8vLyev/JAADXzaqwBKiwwFsZB5ZNmzZp8eLFqq2t1YQJE9TQ0KBBgwZpx44dtuO3bNmi2bNna8WKFRo/frzWr1+vqVOnauvWrdaYp556Sl/72tf03e9+V1OmTNHdd9+tBx54QMOGDev9JwMAXDczmFgVlq5vDXa6hdsyCixXrlxRS0uLKisru9/A71dlZaWamppsX9PU1BQ3XpKqqqqs8dFoVP/xH/+hz33uc6qqqtKwYcNUXl6uvXv3Os7j8uXLCofDcQ8AQN8zA0tiD0uUwAKXZRRYzp8/r0gkosLCwrjjhYWFCoVCtq8JhUIpx7e3t+ujjz7Shg0bNHv2bP3yl7/UQw89pK9//et65ZVXbN+zvr5eBQUF1qO4uDiTjwEA6CGrwuKnwgJveb5KKNp1ffTBBx/Ut771LZWWlurJJ5/Un/3Zn6mhocH2NStXrlRHR4f1OH36tJtTBoBbRmdihcXH3ZrhjWAmg4cOHapAIKC2tra4421tbSoqKrJ9TVFRUcrxQ4cOVTAY1IQJE+LGjB8/Xq+++qrte+bm5io3NzeTqQMAesEMJoGES0JUWOC2jCosOTk5mjZtmhobG61j0WhUjY2NqqiosH1NRUVF3HhJOnjwoDU+JydH06dP18mTJ+PGvPvuuxo9enQm0wMA9LHOSGJguXacHha4LaMKiyTV1dWppqZGZWVlmjFjhjZv3qxLly6ptrZWkrRw4UKNGDFC9fX1kqRly5Zp1qxZ2rhxo+bMmaNdu3bp6NGj2r59u/WeK1as0Lx58/TFL35RX/7yl3XgwAH9/Oc/16FDh/rmUwIAesWqsPiosMBbGQeWefPm6dy5
c1q7dq1CoZBKS0t14MABq7H21KlT8vu7CzczZ87Uzp07tXr1aq1atUpjx47V3r17NXHiRGvMQw89pIaGBtXX1+vxxx/X5z//ef37v/+77r///j74iACA3jKDSWKFhX1Y4DafYdz8nVPhcFgFBQXq6OhQfn6+19MBgH7jwa2v6o3/6dD/rSnTV8YXatPBd/X9xve0sGK0/s+DE9O/AZBCJt/fnq8SAgBkr0hi023XpSEuCcFtBBYAgKPEpltzi36abuE2AgsAwFHismY/FRZ4hMACAHBkNd12BRVzAzkqLHAbgQUA4CiasErI3KKfCgvcRmABADhKXNZsVlgiN/8CU9xkCCwAAEdOFZZIhMACdxFYAACOqLAgWxBYAACOkm5+2NV8y063cBuBBQDgyKywBBMvCRFY4DICCwDAkRlM/AnLmgkscBuBBQDgKGJVWK59XVBhgVcILAAAR1aFpevbggoLvEJgAQA4Sqqw+FglBG8QWAAAjsxgklhhYadbuI3AAgCwFY0aMgspZoUlwL2E4BECCwDAVuxlH3P/lQA9LPAIgQUAYCs2lAQCBBZ4i8ACALAVF1gSKyw03cJlBBYAgK24S0J+KizwFoEFAGAr9o7MBBZ4jcACALAVW2Hpyinc/BCeIbAAAGyZoSTg98nHKiF4jMACALBlBZausCJ1BxY2joPbCCwAAFuxFRaTudNtlFVCcBmBBQBgyy6wmHdr7oxEPZkTbl0EFgCArU6bwGJeHuKKENxGYAEA2DIv+8QFFquHhQoL3EVgAQDY6ow4BxbyCtxGYAEA2LIqLL7kplsqLHAbgQUAYMuuh8Xv7+5hMVgpBBcRWAAAtlIta459HnADgQUAYMu8JBS0qbBI3LEZ7iKwAABsmU23focKC20scBOBBQBgy7bCEtOAS+Mt3ERgAQDYMptu/TarhCQqLHAXgQUAYCvaFViCgeR9WCQqLHAXgQUAYMuuwuLz+WRmFppu4SYCCwDAlrlsOfYykNRdZWFZM9xEYAEA2DIDiZ/AgixAYAEA2IrYrBKSurfqJ7DATQQWAICtSFdTbSAhsPipsMADBBYAgK1I1yKg2KZbqbviQmCBmwgsAABbZoXFsemWVUJwEYEFAGDLqrA4BBZz637ADQQWAIAtxwpL1yWiKBUWuIjAAgCw5bisuWvn2056WOAiAgsAwJZ5xcexwkJggYsILAAAW9ayZp9DDwuBBS4isAAAbJlNt4n7sJg/U2GBm3oVWLZt26aSkhLl5eWpvLxczc3NKcfv2bNH48aNU15eniZNmqT9+/c7jn3sscfk8/m0efPm3kwNANBHnDaOC/ivfXVQYYGbMg4su3fvVl1dndatW6fW1lZNnjxZVVVVam9vtx1/+PBhzZ8/X4sWLdKxY8dUXV2t6upqnThxImnsCy+8oNdee03Dhw/P/JMAAPqUc4Wl63lWCcFFGQeWTZs2afHixaqtrdWECRPU0NCgQYMGaceOHbbjt2zZotmzZ2vFihUaP3681q9fr6lTp2rr1q1x486cOaNvfvObeu655zRgwIDefRoAQJ9JV2HhkhDclFFguXLlilpaWlRZWdn9Bn6/Kisr1dTUZPuapqamuPGSVFVVFTc+Go3qkUce0YoVK3TPPfekncfly5cVDofjHgCAvmVWUJICS9ePXBKCmzIKLOfPn1ckElFhYWHc8cLCQoVCIdvXhEKhtOP/+Z//WcFgUI8//niP5lFfX6+CggLrUVxcnMnHAAD0gBlIElcJBamwwAOerxJqaWnRli1b9KMf/Ui+hP8onKxcuVIdHR3W4/Tp0zd4lgBw6zEDiblRnKkrr1BhgasyCixDhw5VIBBQW1tb3PG2tjYVFRXZvqaoqCjl+F//+tdqb2/XqFGjFAwGFQwG9cEHH+jv//7vVVJSYvueubm5ys/Pj3sAAPpW2goLTbdwUUaBJScnR9OmTVNjY6N1LBqNqrGxURUVFbavqaioiBsvSQcPHrTGP/LII3rzzTd1
/Phx6zF8+HCtWLFCv/jFLzL9PACAPmJWWBJ3uvVz80N4IJjpC+rq6lRTU6OysjLNmDFDmzdv1qVLl1RbWytJWrhwoUaMGKH6+npJ0rJlyzRr1ixt3LhRc+bM0a5du3T06FFt375dkjRkyBANGTIk7ncMGDBARUVF+vznP3+9nw8A0EudTvcS6vqRZc1wU8aBZd68eTp37pzWrl2rUCik0tJSHThwwGqsPXXqlPz+7sLNzJkztXPnTq1evVqrVq3S2LFjtXfvXk2cOLHvPgUAoM+Zl3ySt+a/9nd8hB4WuCjjwCJJS5cu1dKlS22fO3ToUNKxuXPnau7cuT1+/z/84Q+9mRYAoA+Zl3wSm26tjeMILHCR56uEAADZKeJQYQlSYYEHCCwAAFtmIEncOM7saSGwwE0EFgCALafAEiSwwAMEFgCALbPpNmlZc9clIlYJwU0EFgCALbPpNnFZMxUWeIHAAgCw5VhhIbDAAwQWAIAta+O4pFVCvrjnATcQWAAAtswKSjBpH5ZrP3O3ZriJwAIAsBVxqLCYgYWmW7iJwAIAsGVVWPzxXxUBeljgAQILAMBW9z4s8ccJLPACgQUAYMvamj+xwuIjsMB9BBYAgC0qLMgmBBYAgK3uwGLfw8KyZriJwAIAsGUFFodVQixrhpsILAAAW043P6TCAi8QWAAAthwDS1fFJco+LHARgQUAYKt7lVD8cSos8AKBBQBgK13TLT0scBOBBQBgK13TbWc06vqccOsisAAAbKVruo2QV+AiAgsAwFb6wEJigXsILAAAW91Nt/arhCK0sMBFBBYAgK1IxD6wBANUWOA+AgsAwJZZYQkmBBY/Nz+EBwgsAABb5j4r/sQKCzc/hAcILAAAW+Y+K0kVFgILPEBgAQDYsiosPocKC3kFLiKwAACSxO5i61xhoekW7iGwAACSxN4nKLGHxVrWTF6BiwgsAIAksXdiTqywBKmwwAMEFgBAktgKS+I+LDTdwgsEFgBAkkiKwMKyZniBwAIASBIXWHwOFRaDwAL3EFgAAEnMwOLzpdg4jnXNcBGBBQCQxGy6TayuSDFb81NhgYsILACAJE7b8kuxNz8ksMA9BBYAQBKnbfml2H1YCCxwD4EFAJDErLDYXRIyVw11EljgIgILACCJWT0JBJwDS5TAAhcRWAAASSJUWJBlCCwAgCRWYLHrYTErLKwSgosILACAJD0JLFRY4CYCCwAgibnHim1g6bpMZBiSQZUFLiGwAACSmHditgssQb8/ZhyBBe4gsAAAkkSu5RXbwBKTV7gsBNcQWAAASTrNCkuKVUISjbdwD4EFAJAkmqLCEnuMCgvcQmABACTpTNHDElt1YfM4uIXAAgBIYl7qsb2XEBUWeIDAAgBI0hlxvluzz+eTeZgKC9zSq8Cybds2lZSUKC8vT+Xl5Wpubk45fs+ePRo3bpzy8vI0adIk7d+/33ru6tWreuKJJzRp0iTddtttGj58uBYuXKizZ8/2ZmoAgD6QqsJy7fi1rw8qLHBLxoFl9+7dqqur07p169Ta2qrJkyerqqpK7e3ttuMPHz6s+fPna9GiRTp27Jiqq6tVXV2tEydOSJI+/vhjtba2as2aNWptbdXzzz+vkydP6oEHHri+TwYA6DVzWbPfZpWQ1L20mX1Y4BafkeE2heXl5Zo+fbq2bt0qSYpGoyouLtY3v/lNPfnkk0nj582bp0uXLunFF1+0jt13330qLS1VQ0OD7e94/fXXNWPGDH3wwQcaNWpU2jmFw2EVFBSoo6ND+fn5mXwcAICNnx0/o2W7jmvm3UO0c/F9Sc9PXPcLfXS5U4f+4UsqGXqbBzNEf5DJ93dGFZYrV66opaVFlZWV3W/g96uyslJNTU22r2lqaoobL0lVVVWO4yWpo6NDPp9PgwcPtn3+8uXLCofDcQ8AQN+JptiaX5LVwxJhHxa4JKPAcv78eUUiERUWFsYdLywsVCgUsn1NKBTKaPynn36qJ554QvPnz3dM
W/X19SooKLAexcXFmXwMAEAaZtOtU2AJBq59fXBJCG7JqlVCV69e1V/+5V/KMAz94Ac/cBy3cuVKdXR0WI/Tp0+7OEsA6P/SNd2avS0EFrglmMngoUOHKhAIqK2tLe54W1ubioqKbF9TVFTUo/FmWPnggw/00ksvpbyWlZubq9zc3EymDgDIgLn6x6np1gwyBBa4JaMKS05OjqZNm6bGxkbrWDQaVWNjoyoqKmxfU1FRETdekg4ePBg33gwr7733nn71q19pyJAhmUwLANDHzP1VggH7wBIgsMBlGVVYJKmurk41NTUqKyvTjBkztHnzZl26dEm1tbWSpIULF2rEiBGqr6+XJC1btkyzZs3Sxo0bNWfOHO3atUtHjx7V9u3bJV0LK3/xF3+h1tZWvfjii4pEIlZ/y5133qmcnJy++qwAgB5KV2ExAwv7sMAtGQeWefPm6dy5c1q7dq1CoZBKS0t14MABq7H21KlT8sfce3zmzJnauXOnVq9erVWrVmns2LHau3evJk6cKEk6c+aM9u3bJ0kqLS2N+10vv/yyvvSlL/XyowEAesusnDj1sJiBhbs1wy0ZBxZJWrp0qZYuXWr73KFDh5KOzZ07V3PnzrUdX1JSogy3ggEA3GBmYLHbml+KWdZMhQUuyapVQgCA7BDp4db8BBa4hcACAEgSSbMPi5+mW7iMwAIASBJJs9Mty5rhNgILACCJGUQCjjc/JLDAXQQWAEASK7D47b8mgixrhssILACAJN2XhOyfNysvLGuGWwgsAIAk3U239l8TbBwHtxFYAABJ0lZYzI3jCCxwCYEFAJAkXQ8LFRa4jcACAEiSbpUQFRa4jcACAEjSXWGxf54KC9xGYAEAJEl7Sair8hJhlRBcQmABACRJW2EJdAWWSNStKeEWR2ABACTpXiWUrsLi2pRwiyOwAACSdFpNt/bPd99LiAoL3EFgAQAkMVf/BByuCXXfS8i1KeEWR2ABACTpTLes2UeFBe4isAAAkpgVFvPSTyKr6Za8ApcQWAAAScwKi98psLCsGS4jsAAAkph3YXassNB0C5cRWAAASTojaSosNN3CZQQWAECSSJoKC8ua4TYCCwAgidl063dYJcSyZriNwAIASNKZZpUQFRa4jcACAEgStbbmd6iwsEoILiOwAACSmE23ToGlu8JCYIE7CCwAgCRpKywEFriMwAIASGJtzZ+mwtJJYIFLCCwAgCTRNIHFPB4lsMAlBBYAQJLONMuaA1RY4DICCwAgSSTdzQ/NCgurhOASAgsAIEkkzSUhs/JiriYCbjQCCwAgSSTNKqEgFRa4jMACAEiStsJCDwtcRmABACRJF1jYOA5uI7AAAJJYgSXNKiEuCcEtBBYAQJJ0FRZrWTNNt3AJgQUAkCRd061ZeaHCArcQWAAASXq6DwtNt3ALgQUAEMcwDCuw+NmaH1mCwAIAiBObQaiwIFsQWAAAcWKXKqersLCsGW4hsAAA4sSGkHQVFgIL3EJgAQDEicSs/HG8W3PX8QirhOASAgsAIE4kkr7CEgxQYYG7CCwAgDixVZN0d2smsMAtBBYAQJzOaFSS5PNJPodLQkH/ta8PAgvcQmABAMTpyiuOl4MkqSuvEFjgGgILACCOWWFxariVWCUE9xFYAABxelJhMZ9jlRDc0qvAsm3bNpWUlCgvL0/l5eVqbm5OOX7Pnj0aN26c8vLyNGnSJO3fvz/uecMwtHbtWt11110aOHCgKisr9d577/VmagCA62RVWFJdEjKbbrlbM1yScWDZvXu36urqtG7dOrW2tmry5MmqqqpSe3u77fjDhw9r/vz5WrRokY4dO6bq6mpVV1frxIkT1pjvfve7+v73v6+GhgYdOXJEt912m6qqqvTpp5/2/pMBAHrFvANz6gpLV9MtFRa4JOPAsmnTJi1evFi1tbWaMGGCGhoaNGjQIO3YscN2/JYtWzR79mytWLFC48eP1/r16zV16lRt3bpV0rXqyubNm7V69Wo9
+OCDuvfee/XjH/9YZ8+e1d69e6/rwwEAMmfeH8hpSbPU3XTLvYTglmAmg69cuaKWlhatXLnSOub3+1VZWammpibb1zQ1Namuri7uWFVVlRVG3n//fYVCIVVWVlrPFxQUqLy8XE1NTXr44YeT3vPy5cu6fPmy9XM4HM7kY/RYZySqp/e/fUPeG7gZmP/zbBiGEr+WfOpe8pqiNxN9yPz3ETUMGUZ3JcTnu3aJxt9H/z4+vHRFUurAYlZYrkai+s7Pf3N9vxA3haDfp6fmTPDu92cy+Pz584pEIiosLIw7XlhYqHfeecf2NaFQyHZ8KBSynjePOY1JVF9fr+985zuZTL1Xoob0w//3hxv+ewAgG+XnDXB8blBuQAG/T5Gowd+Tt4icoP/mCSzZYuXKlXFVm3A4rOLi4j7/PX6ftOTLd/f5+wI3E5981zYQk7r/172r4mIYUnLtBTeS3+ezqlsx/zq6/n0YVtXlevnk01fvKXR8Pj9vgH6wYKre+J8LffL7kP0Cfm8XFmcUWIYOHapAIKC2tra4421tbSoqKrJ9TVFRUcrx5p9tbW2666674saUlpbavmdubq5yc3MzmXqvBAN+ragad8N/DwDcjL56T5G+eo/93/1AX8soLuXk5GjatGlqbGy0jkWjUTU2NqqiosL2NRUVFXHjJengwYPW+DFjxqioqChuTDgc1pEjRxzfEwAA3FoyviRUV1enmpoalZWVacaMGdq8ebMuXbqk2tpaSdLChQs1YsQI1dfXS5KWLVumWbNmaePGjZozZ4527dqlo0ePavv27ZKulTWXL1+uf/zHf9TYsWM1ZswYrVmzRsOHD1d1dXXffVIAAHDTyjiwzJs3T+fOndPatWsVCoVUWlqqAwcOWE2zp06dkj/mOtfMmTO1c+dOrV69WqtWrdLYsWO1d+9eTZw40Rrz7W9/W5cuXdKjjz6qCxcu6P7779eBAweUl5fXBx8RAADc7HyGcfPv+hMOh1VQUKCOjg7l5+d7PR0AANADmXx/cy8hAACQ9QgsAAAg6xFYAABA1iOwAACArEdgAQAAWY/AAgAAsh6BBQAAZD0CCwAAyHoEFgAAkPUy3po/G5mb9YbDYY9nAgAAesr83u7Jpvv9IrBcvHhRklRcXOzxTAAAQKYuXryogoKClGP6xb2EotGozp49qzvuuEM+n69P3zscDqu4uFinT5/mPkU3GOfaPZxr93Cu3cO5dk9fnWvDMHTx4kUNHz487sbJdvpFhcXv92vkyJE39Hfk5+fzH4BLONfu4Vy7h3PtHs61e/riXKerrJhougUAAFmPwAIAALIegSWN3NxcrVu3Trm5uV5Ppd/jXLuHc+0ezrV7ONfu8eJc94umWwAA0L9RYQEAAFmPwAIAALIegQUAAGQ9AgsAAMh6BJY0tm3bppKSEuXl5am8vFzNzc1eT+mmVl9fr+nTp+uOO+7QsGHDVF1drZMnT8aN+fTTT7VkyRINGTJEt99+u77xjW+ora3Noxn3Hxs2bJDP59Py5cutY5zrvnPmzBn91V/9lYYMGaKBAwdq0qRJOnr0qPW8YRhau3at7rrrLg0cOFCVlZV67733PJzxzSsSiWjNmjUaM2aMBg4cqLvvvlvr16+Pux8N57t3/uu//kt//ud/ruHDh8vn82nv3r1xz/fkvH744YdasGCB8vPzNXjwYC1atEgfffTR9U/OgKNdu3YZOTk5xo4dO4zf/OY3xuLFi43BgwcbbW1tXk/tplVVVWX88Ic/NE6cOGEcP37c+NrXvmaMGjXK+Oijj6wxjz32mFFcXGw0NjYaR48eNe677z5j5syZHs765tfc3GyUlJQY9957r7Fs2TLrOOe6b3z44YfG6NGjjb/+6782jhw5Yvz+9783fvGLXxi//e1vrTEbNmwwCgoKjL179xpvvPGG8cADDxhjxowxPvnkEw9nfnN6+umnjSFDhhgvvvii8f777xt79uwxbr/9dmPLli3WGM537+zfv994
6qmnjOeff96QZLzwwgtxz/fkvM6ePduYPHmy8dprrxm//vWvjc9+9rPG/Pnzr3tuBJYUZsyYYSxZssT6ORKJGMOHDzfq6+s9nFX/0t7ebkgyXnnlFcMwDOPChQvGgAEDjD179lhj3n77bUOS0dTU5NU0b2oXL140xo4daxw8eNCYNWuWFVg4133niSeeMO6//37H56PRqFFUVGR873vfs45duHDByM3NNX7605+6McV+Zc6cOcbf/M3fxB37+te/bixYsMAwDM53X0kMLD05r2+99ZYhyXj99detMf/5n/9p+Hw+48yZM9c1Hy4JObhy5YpaWlpUWVlpHfP7/aqsrFRTU5OHM+tfOjo6JEl33nmnJKmlpUVXr16NO+/jxo3TqFGjOO+9tGTJEs2ZMyfunEqc6760b98+lZWVae7cuRo2bJimTJmif/3Xf7Wef//99xUKheLOdUFBgcrLyznXvTBz5kw1Njbq3XfflSS98cYbevXVV/Wnf/qnkjjfN0pPzmtTU5MGDx6ssrIya0xlZaX8fr+OHDlyXb+/X9z88EY4f/68IpGICgsL444XFhbqnXfe8WhW/Us0GtXy5cv1hS98QRMnTpQkhUIh5eTkaPDgwXFjCwsLFQqFPJjlzW3Xrl1qbW3V66+/nvQc57rv/P73v9cPfvAD1dXVadWqVXr99df1+OOPKycnRzU1Ndb5tPv7hHOduSeffFLhcFjjxo1TIBBQJBLR008/rQULFkgS5/sG6cl5DYVCGjZsWNzzwWBQd95553WfewILPLNkyRKdOHFCr776qtdT6ZdOnz6tZcuW6eDBg8rLy/N6Ov1aNBpVWVmZ/umf/kmSNGXKFJ04cUINDQ2qqanxeHb9z7/927/pueee086dO3XPPffo+PHjWr58uYYPH8757se4JORg6NChCgQCSSsm2traVFRU5NGs+o+lS5fqxRdf1Msvv6yRI0dax4uKinTlyhVduHAhbjznPXMtLS1qb2/X1KlTFQwGFQwG9corr+j73/++gsGgCgsLOdd95K677tKECRPijo0fP16nTp2SJOt88vdJ31ixYoWefPJJPfzww5o0aZIeeeQRfetb31J9fb0kzveN0pPzWlRUpPb29rjnOzs79eGHH173uSewOMjJydG0adPU2NhoHYtGo2psbFRFRYWHM7u5GYahpUuX6oUXXtBLL72kMWPGxD0/bdo0DRgwIO68nzx5UqdOneK8Z+grX/mK/vu//1vHjx+3HmVlZVqwYIH1z5zrvvGFL3whaXn+u+++q9GjR0uSxowZo6KiorhzHQ6HdeTIEc51L3z88cfy++O/vgKBgKLRqCTO943Sk/NaUVGhCxcuqKWlxRrz0ksvKRqNqry8/PomcF0tu/3crl27jNzcXONHP/qR8dZbbxmPPvqoMXjwYCMUCnk9tZvW3/7t3xoFBQXGoUOHjD/+8Y/W4+OPP7bGPPbYY8aoUaOMl156yTh69KhRUVFhVFRUeDjr/iN2lZBhcK77SnNzsxEMBo2nn37aeO+994znnnvOGDRokPGTn/zEGrNhwwZj8ODBxs9+9jPjzTffNB588EGW2fZSTU2NMWLECGtZ8/PPP28MHTrU+Pa3v22N4Xz3zsWLF41jx44Zx44dMyQZmzZtMo4dO2Z88MEHhmH07LzOnj3bmDJlinHkyBHj1VdfNcaOHcuyZjc8++yzxqhRo4ycnBxjxowZxmuvveb1lG5qkmwfP/zhD60xn3zyifF3f/d3xp/8yZ8YgwYNMh566CHjj3/8o3eT7kcSAwvnuu/8/Oc/NyZOnGjk5uYa48aNM7Zv3x73fDQaNdasWWMUFhYaubm5xle+8hXj5MmTHs325hYOh41ly5YZo0aNMvLy8ozPfOYzxlNPPWVcvnzZGsP57p2XX37Z9u/ompoawzB6dl7/93//15g/f75x++23G/n5+UZtba1x8eLF656bzzBitgYEAADIQvSwAACArEdgAQAAWY/AAgAA
sh6BBQAAZD0CCwAAyHoEFgAAkPUILAAAIOsRWAAAQNYjsAAAgKxHYAEAAFmPwAIAALIegQUAAGS9/w8OkKGPRGB7rAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Line plot of P_cf_A against its position index.\n",
    "# Explicit Axes interface so the figure can be labelled and reused.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(P_cf_A)\n",
    "ax.set(title=\"P_cf_A\", xlabel=\"index\", ylabel=\"P_cf_A\")\n",
    "# plt.show() returns None, so the cell does not echo the Line2D / Text reprs.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: ask Nicola how to handle this numerical issue, and show him code snippets of the various implementations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO:\n",
    "# - Check for further tricks and numerical-stability issues.\n",
    "# - Re-read the code.\n",
    "# - Organize and commit the code, then ask Nicola.\n",
    "# - Design an experiment over a large dataset sample and run it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
