{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8c30ce04-7328-41d7-93d3-a36b03f9eba8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x1090d0530>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import torch\n",
    "from torch.distributions import Binomial, Normal\n",
    "\n",
    "# Qualitative palette; the figure cells colour box j with colors[j].\n",
    "colors = px.colors.qualitative.Plotly\n",
    "\n",
    "# Work in float64 throughout and fix the global RNG seed for reproducibility.\n",
    "torch.set_default_dtype(torch.double)\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "61245c9b-8930-43da-9ae9-765c4cbf9264",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observation noise std and number of source observations.\n",
    "sigma = 1.\n",
    "n = 100\n",
    "\n",
    "# Prior over (theta, psi) = components (0, 1): zero mean, diagonal covariance 4*I.\n",
    "Mu = torch.zeros(2,)\n",
    "Cov = 4.*torch.diag(torch.ones(2,))\n",
    "\n",
    "# Conditional variances (all other entries zero):\n",
    "#   Cov_psi[0,0]   = Var(theta | psi), used when psi is held fixed (see zeta);\n",
    "#   Cov_theta[1,1] = Var(psi | theta), used when theta is held fixed.\n",
    "Cov_psi, Cov_theta = torch.zeros_like(Cov), torch.zeros_like(Cov)\n",
    "Cov_psi[0,0] += Cov[0,0] - Cov[0,1]**2. * Cov[1,1]**-1.\n",
    "Cov_theta[1,1] += Cov[1,1] - Cov[0,1]**2. * Cov[0,0]**-1.\n",
    "\n",
    "# Default true (theta, psi) for generated data; simulation() redraws the psi entry.\n",
    "tildef = torch.tensor((-1., -4.))\n",
    "\n",
    "# 10000 draws from the prior, one row per draw; only the diagonal of Cov is used.\n",
    "Theta = torch.normal(\n",
    "    mean=Mu[None,:].repeat((10000,1)),\n",
    "    std=torch.diag(Cov)[None,:].repeat((10000,1))**.5,\n",
    ")\n",
    "\n",
    "def zeta(x, y, f, Mu_theta=0., Cov_theta=Cov_psi):\n",
    "    \"\"\"Likelihood of each y relative to its maximum, for every row of f.\n",
    "\n",
    "    Uses y_i ~ Normal(Mu_theta*x_i0 + f_s1*x_i1, sigma^2 + x_i^T Cov_theta x_i),\n",
    "    i.e. component 0 marginalised with mean Mu_theta and component 1 fixed at f[s,1],\n",
    "    and returns exp(log p(y_i) - max_y log p(y)), a value in [0, 1], with shape\n",
    "    (f.shape[0], x.shape[0]). Assumes x is (n, 2), y is (n,) and f is (S, 2).\n",
    "    \"\"\"\n",
    "    # Per-observation predictive std, computed once. The row-wise quadratic form avoids\n",
    "    # building the n x n matrix x @ Cov_theta @ x.T only to take its diagonal.\n",
    "    scale = (sigma**2. + ((x @ Cov_theta) * x).sum(dim=-1))**.5\n",
    "    loc = Mu_theta*x[None,:,0] + f[:,None,1]*x[None,:,1]\n",
    "    logp = Normal(loc=loc, scale=scale).log_prob(y)\n",
    "    # The log-density peaks at y == loc, where it equals the log-density of 0 under N(0, scale).\n",
    "    max_logp = Normal(loc=0., scale=scale).log_prob(torch.tensor(0.))\n",
    "    return (logp - max_logp).exp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "926ab62c-599f-4a90-a34a-8b42a004ffcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_robust_eta(eta, X, Y, Theta, L_psi, reps=3):\n",
    "    \"\"\"Refine the per-observation weights eta by `reps` fixed-point passes.\n",
    "\n",
    "    Each pass scores every draw in Theta by its eta-tempered Gaussian\n",
    "    log-likelihood on (X, Y) plus L_psi, turns the scores of all draws but the\n",
    "    last into normalised weights, fits a weighted mean/variance for component 0\n",
    "    of Theta, and recomputes eta with zeta under that fit.\n",
    "    Returns the final eta (the input eta unchanged when reps == 0).\n",
    "    \"\"\"\n",
    "    for _ in range(reps):\n",
    "        L = (Normal(loc=Theta@X.T, scale=sigma).log_prob(Y) * eta).sum(dim=1) + L_psi\n",
    "        # softmax(L) == exp(L) / sum(exp(L)), but stays finite when every L is very\n",
    "        # negative; normalising via log(mean(exp(L))) underflows to log(0) there -> nan.\n",
    "        # NOTE(review): the last draw is excluded; the reason is not visible here.\n",
    "        wghts = torch.softmax(L[:-1], dim=0)\n",
    "        Mu1 = Theta[:-1,0]@wghts\n",
    "        Cov1 = torch.zeros((2,2))\n",
    "        Cov1[0,0] = ((Theta[:-1,0] - Mu1)**2.)@wghts\n",
    "        eta = zeta(X,Y,Theta,Mu1,Cov1)\n",
    "    return eta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ec614e7-bdeb-4b83-9b51-7c3913111948",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_source_data(sep=1., target_clust_size=25, tildef=tildef):\n",
    "    \"\"\"Simulate n source observations and their per-observation coefficients.\n",
    "\n",
    "    Covariates: a latent z ~ N(sep, .25^2) drives column 1 ~ N(z, .25^2) and\n",
    "    column 0 ~ N(-sep^2 / z, .25^2), so sep controls how strongly the columns are\n",
    "    related. Coefficients psi: column 0 is tildef[0] on every row; column 1 is\n",
    "    tildef[1] on the last target_clust_size rows and 0 on the rest, plus\n",
    "    Uniform(-2, 2) noise on every row. Responses: Y0 = rowwise X0 . psi + N(0, sigma^2).\n",
    "\n",
    "    Returns ((X0, Y0), psi).\n",
    "    \"\"\"\n",
    "    latent = torch.normal(sep*torch.ones((n,)), .25)\n",
    "    # Column 1 is drawn before column 0; the order of RNG calls fixes the simulated values.\n",
    "    col1 = torch.normal(latent, .25)\n",
    "    col0 = torch.normal(-sep**2./latent, .25)\n",
    "    X0 = torch.stack((col0, col1), dim=1)\n",
    "\n",
    "    # Every row starts from tildef; non-target rows get their psi coefficient zeroed.\n",
    "    psi = torch.ones_like(X0) * tildef[None,:]\n",
    "    n_other = n - target_clust_size\n",
    "    psi[:n_other,1] *= 0.\n",
    "    psi[:,1] += torch.rand(n)*4. - 2.\n",
    "\n",
    "    noise = torch.normal(torch.zeros(n,), sigma*torch.ones(n,))\n",
    "    Y0 = (X0*psi).sum(dim=1) + noise\n",
    "    return (X0, Y0), psi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a0de3ba9-8b93-4aa0-a792-d556e894211a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ig_theta(D0, Z, robust=True, reps=3, tildef=tildef):\n",
    "    \"\"\"Monte Carlo estimate of log p(D | theta = tildef[0]) - log p(D).\n",
    "\n",
    "    D0 = (X, Y) are the source data and Z the proxy counts (Binomial with N\n",
    "    trials), NaN where an observation has no proxy. Proxied observations only\n",
    "    score the prior draws in Theta (L_psi); the remaining observations form D.\n",
    "\n",
    "    robust=True: eta-tempered likelihoods (eta refined by get_robust_eta), with\n",
    "    the numerator averaged over draws weighted by softmax(L_psi).\n",
    "    robust=False: psi marginalised analytically via Cov_theta, uniform averages.\n",
    "    Returns a 0-dim tensor.\n",
    "    \"\"\"\n",
    "    def log_mean_exp(v):\n",
    "        # log(mean(exp(v))) via logsumexp, so it cannot underflow to log(0) = -inf.\n",
    "        return torch.logsumexp(v, dim=0) - torch.log(torch.tensor(float(v.shape[0])))\n",
    "\n",
    "    X, Y = D0\n",
    "    # Same prior draws, but with component 0 pinned to the true value tildef[0].\n",
    "    Theta_ = Theta.clone()\n",
    "    Theta_[:,0] = tildef[0]\n",
    "    eta = zeta(X,Y,Theta)\n",
    "\n",
    "    # Proxy log-likelihood of each draw, normalised so that mean(exp(L_psi)) == 1.\n",
    "    I = ~torch.isnan(Z)\n",
    "    L_psi = Binomial(N, eta[:,I]).log_prob(Z[None,I]).sum(dim=1)\n",
    "    L_psi = L_psi - log_mean_exp(L_psi)\n",
    "\n",
    "    # Only the observations without a proxy enter the information gain.\n",
    "    X, Y, eta = X[~I,:], Y[~I], eta[:,~I]\n",
    "\n",
    "    if robust:\n",
    "        eta = get_robust_eta(eta, X, Y, Theta, L_psi, reps=reps)\n",
    "        L = (Normal(loc=Theta@X.T, scale=sigma).log_prob(Y) * eta).sum(dim=1)\n",
    "        L_ = (Normal(loc=Theta_@X.T, scale=sigma).log_prob(Y) * eta).sum(dim=1)\n",
    "        # L_psi holds LOG-likelihoods, so the draw weights are softmax(L_psi). Dividing\n",
    "        # L_psi by its (typically negative) sum would instead give the least plausible\n",
    "        # draws the largest, possibly negative, weights. Weights and likelihoods both\n",
    "        # use all draws except the last, matching get_robust_eta.\n",
    "        log_wghts = torch.log_softmax(L_psi[:-1], dim=0)\n",
    "        L_ = torch.logsumexp(log_wghts + L_[:-1], dim=0)\n",
    "        L = log_mean_exp(L[:-1])\n",
    "    else:\n",
    "        # Predictive std with psi marginalised: sigma^2 + x_i^T Cov_theta x_i per row.\n",
    "        scale = (sigma**2. + ((X @ Cov_theta) * X).sum(dim=-1))**.5\n",
    "        L = Normal(loc=Theta[:,None,0]*X[None,:,0], scale=scale).log_prob(Y).sum(dim=1)\n",
    "        L_ = Normal(loc=Theta_[:,None,0]*X[None,:,0], scale=scale).log_prob(Y).sum(dim=1)\n",
    "        L = log_mean_exp(L)\n",
    "        L_ = log_mean_exp(L_)\n",
    "\n",
    "    return L_ - L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c415d9af-87a6-4951-837c-b75d471e9b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulation(sep, target_clust_size, reps=3, pflip=.5, n_labeled=25):\n",
    "    \"\"\"Run one replicate; return ig_theta(robust=True) - ig_theta(robust=False).\n",
    "\n",
    "    Draws a fresh psi = tildef[1] from its prior and simulates source data.\n",
    "    eta* is computed by zeta at the true (theta, psi). Each eta* is flipped to\n",
    "    1 - eta* with probability pflip (proxy contamination). Binomial(N, eta*)\n",
    "    proxy counts are then kept only for the first n_labeled observations; the\n",
    "    rest are NaN.\n",
    "    \"\"\"\n",
    "    tildef_ = tildef.clone()\n",
    "    tildef_[1] = torch.normal(mean=Mu[1], std=Cov[1,1]**.5)\n",
    "    (X, Y), _psi = generate_source_data(sep=sep, target_clust_size=target_clust_size, tildef=tildef_)\n",
    "    Eta_star = zeta(X,Y,tildef_[None,:])[0,:]\n",
    "    eta_mask = torch.bernoulli(pflip*torch.ones_like(Eta_star)).bool()\n",
    "    Eta_star[eta_mask] = 1. - Eta_star[eta_mask]\n",
    "    Z = Binomial(N, Eta_star).sample()\n",
    "    # Observations without a proxy are marked NaN; ig_theta splits the data on this.\n",
    "    Z[n_labeled:] = torch.nan\n",
    "    ig_robust = ig_theta((X,Y), Z=Z, reps=reps, tildef=tildef_)\n",
    "    ig = ig_theta((X,Y), Z=Z, robust=False, tildef=tildef_)\n",
    "    return ig_robust - ig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "70e1ce5c-b4ec-45e8-9033-3b319fbba6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte Carlo replicates per (rho, pct, pflip) configuration.\n",
    "nsims = 50\n",
    "# Number of Binomial trials behind each proxy count Z.\n",
    "N = 7\n",
    "# Refinement passes handed to simulation() and on to get_robust_eta.\n",
    "reps = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e1a3077d-04e3-4dbe-9cd4-1a9ff0a738d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "rhos = (0.,1.,2.)\n",
    "pcts = (0,25,50,75,100)\n",
    "pflips = (0.,.25,.5,.75)\n",
    "# Results indexed [rho, pct, pflip, replicate]. The shape is derived from the grids\n",
    "# above so that editing a grid cannot leave res mis-sized.\n",
    "# pct is passed as target_clust_size, a count of source tasks; it equals a\n",
    "# percentage only because n == 100.\n",
    "res = torch.empty((len(rhos), len(pcts), len(pflips), nsims))\n",
    "\n",
    "for i, rho in enumerate(rhos):\n",
    "    for j, pct in enumerate(pcts):\n",
    "        for k, pflip in enumerate(pflips):\n",
    "            res[i,j,k,:] = torch.stack([simulation(rho, pct, reps=reps, pflip=pflip) for _ in range(nsims)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64346dd6-ab3e-4091-a8f6-584623fa52f9",
   "metadata": {},
   "source": [
    "# Figure 2(a): robust minus non-robust information gain vs. degree of multicollinearity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92cc9cfd-7444-4b6a-b7cc-2d2edf1faf08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One box per (multicollinearity level, fraction of target-like source tasks),\n",
    "# all without proxy contamination (pflip index 0).\n",
    "box_labels = (\n",
    "    \"None\",\n",
    "    \"25% source tasks like target task\",\n",
    "    \"50% source tasks like target task\",\n",
    "    \"75% source tasks like target task\",\n",
    "    \"All source tasks like target task\",\n",
    ")\n",
    "\n",
    "fig = go.Figure()\n",
    "for i in range(len(rhos)):\n",
    "    for j, label in enumerate(box_labels):\n",
    "        # Group i spans x in [2i, 2i+1]; boxes within a group are .25 apart.\n",
    "        fig.add_trace(go.Box(\n",
    "            x=((i*2.)+.25*j)*torch.ones((nsims,)), y=res[i,j,0,:].ravel(), fillcolor=colors[j],\n",
    "            line=dict(color=colors[j]), notched=True, name=label, showlegend=i==0,\n",
    "        ))\n",
    "\n",
    "fig.add_hline(y=0.)\n",
    "fig.update_layout(\n",
    "    legend=dict(font=dict(size=15), x=.99, xanchor=\"right\", y=.99, yanchor=\"top\"), plot_bgcolor=\"white\",\n",
    "    xaxis=dict(title=dict(text=\"Degree of multicollinearity\", font=dict(size=20)), ticktext=[\"None\",\"Mild\",\"Extreme\"], tickvals=[.5,2.5,4.5]),\n",
    "    yaxis=dict(title=dict(text=r\"$\\mathrm{IG}^{\\mathcal{R}}\\left( \\boldsymbol{\\theta}^{\\star} \\right) - \\mathrm{IG}^c\\left( \\boldsymbol{\\theta}^{\\star} \\right)$\", font=dict(size=20)))\n",
    ")\n",
    "fig.show(renderer=\"browser\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0468488-1931-4acf-b3ba-a481424688d7",
   "metadata": {},
   "source": [
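    "# Figure 2(b): robust minus non-robust information gain vs. degree of proxy contamination (extreme multicollinearity)"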
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "029be2b9-baa4-4ae1-9dc7-d01ad7778fee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One box per (proxy contamination level, fraction of target-like source tasks),\n",
    "# all at the last multicollinearity level (rho index -1).\n",
    "box_labels = (\n",
    "    \"None\",\n",
    "    \"25% source tasks like target task\",\n",
    "    \"50% source tasks like target task\",\n",
    "    \"75% source tasks like target task\",\n",
    "    \"All source tasks like target task\",\n",
    ")\n",
    "\n",
    "fig = go.Figure()\n",
    "for i in range(len(pflips)):\n",
    "    for j, label in enumerate(box_labels):\n",
    "        # Group i spans x in [2i, 2i+1]; boxes within a group are .25 apart.\n",
    "        fig.add_trace(go.Box(\n",
    "            x=((i*2.)+.25*j)*torch.ones((nsims,)), y=res[-1,j,i,:].ravel(), fillcolor=colors[j],\n",
    "            line=dict(color=colors[j]), notched=True, name=label, showlegend=i==0,\n",
    "        ))\n",
    "\n",
    "fig.add_hline(y=0.)\n",
    "fig.update_layout(\n",
    "    legend=dict(font=dict(size=15), x=.99, xanchor=\"right\", y=.99, yanchor=\"top\"), plot_bgcolor=\"white\",\n",
    "    xaxis=dict(title=dict(text=\"Degree of proxy contamination\", font=dict(size=20)), ticktext=[\"None\",\"25%\",\"50%\",\"75%\"], tickvals=[.5,2.5,4.5,6.5]),\n",
    "    yaxis=dict(title=dict(text=r\"$\\mathrm{IG}^{\\mathcal{R}}\\left( \\boldsymbol{\\theta}^{\\star} \\right) - \\mathrm{IG}^c\\left( \\boldsymbol{\\theta}^{\\star} \\right)$\", font=dict(size=20)))\n",
    ")\n",
    "fig.show(renderer=\"browser\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pibt",
   "language": "python",
   "name": "pibt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
