{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c2e730d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "from run import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0a037bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep configurations (`args`: one row per run, one column per swept parameter,\n",
    "# in the order pcts, resols, reps, clust_sizes, thetas, pflips) and per-run results\n",
    "# (`res`, row-aligned with `args`).\n",
    "# NOTE(review): torch.load unpickles its input -- only load files you produced yourself.\n",
    "# Passing the path (not an open(...) handle) lets torch.load close the file deterministically.\n",
    "args = torch.load(\"args.pkl\")\n",
    "res = torch.load(\"results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "844d9f8d-9f6c-45b8-90f4-3c8f59a9352b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round the first parameter column (matched against `pcts` with `==` when plotting)\n",
    "# to one decimal, presumably so float noise does not break exact equality -- confirm\n",
    "# that `pcts` values carry at most one decimal place.\n",
    "args[:, 0] = torch.round(args[:, 0], decimals=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "baa48f24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Axis title for each swept parameter, in the same order as the columns of `args`.\n",
    "AXIS_TITLES = (\n",
    "    r\"$m_s$\", \"Covariate resolution\",\n",
    "    r\"$T$\", r\"$m_t$\",\n",
    "    r\"$\\boldsymbol{\\theta}^{\\star}$\", \"Amount of proxy contamination\",\n",
    ")\n",
    "Y_TITLE = r\"$\\mathrm{IG}^{\\mathcal{R}}\\left( \\boldsymbol{\\theta}^{\\star} \\right) - \\mathrm{IG}^c\\left( \\boldsymbol{\\theta}^{\\star} \\right)$\"\n",
    "# Column 0 (`pcts`) is displayed on the m_s axis after multiplying by this factor.\n",
    "# NOTE(review): presumably m_s = pct * 80 in run.py -- confirm, and import the constant from there if it exists.\n",
    "MS_SCALE = 80.\n",
    "\n",
    "# One box-plot figure per swept parameter: for each value of that parameter, the\n",
    "# distribution of `res` over every run (row of `args`) in which it took that value.\n",
    "for i, (arr, x_title) in enumerate(zip(\n",
    "    (pcts, resols, reps, clust_sizes, thetas, pflips), AXIS_TITLES\n",
    ")):\n",
    "    scale = MS_SCALE if i == 0 else 1.\n",
    "    fig = go.Figure()\n",
    "    for val in arr:\n",
    "        mask = args[:, i] == val\n",
    "        y = res[mask]\n",
    "        # Size x from the mask rather than res.shape[0] // len(arr), so x and y\n",
    "        # always have equal length even if the sweep is not perfectly balanced.\n",
    "        x = torch.full((int(mask.sum()),), float(val) * scale)\n",
    "        fig.add_trace(go.Box(\n",
    "            x=x, y=y, notched=True,\n",
    "            line=dict(color=\"blue\"), showlegend=False\n",
    "        ))\n",
    "    # Reference line at zero difference.\n",
    "    fig.add_hline(y=0.)\n",
    "    # Ticks exactly at the swept values, on the same scale as the boxes.\n",
    "    tickvals = arr * MS_SCALE if i == 0 else arr.clone()\n",
    "    fig.update_layout(\n",
    "        plot_bgcolor=\"white\",\n",
    "        xaxis=dict(\n",
    "            title=dict(text=x_title, font=dict(size=20)),\n",
    "            tickvals=tickvals,\n",
    "        ),\n",
    "        yaxis=dict(title=dict(text=Y_TITLE, font=dict(size=20))),\n",
    "    )\n",
    "    fig.show(renderer=\"browser\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pibt",
   "language": "python",
   "name": "pibt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
