{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd621939-e73b-440c-8b86-e8734c7f9be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pyro\n",
    "from tqdm import trange\n",
    "\n",
    "\n",
    "from epidemic import Epidemic\n",
    "from neural.modules import Mlp, LazyFn\n",
    "from neural.aggregators import LSTMImplicitDAD\n",
    "from neural.baselines import DesignBaseline, BatchDesignBaseline\n",
    "from neural.critics import CriticDotProd\n",
    "\n",
    "from estimators.bb_mi import InfoNCE\n",
    "from oed.design import OED"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9316f2b-3363-42a2-9ec4-d35ae0477146",
   "metadata": {},
   "source": [
    "## SIR model\n",
    "\n",
    "We are going to study the SIR model, an SDE-based model from epidemiology (for details see e.g. the [Wikipedia article](https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology)). The model is governed by two parameters, the infection rate and the recovery rate, which we wish to learn. What we control is the time $\\tau \\in (0, 100)$ at which we measure the number of infected people in the population.\n",
    "\n",
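    "For reference, here is a sketch of the deterministic SIR equations, which form the drift of the standard SDE formulation. The symbols are assumptions introduced here: $S$ and $I$ are the numbers of susceptible and infected people, $N$ is the population size, $\\beta$ is the infection rate and $\\gamma$ is the recovery rate. The noise terms used by the simulator are not shown in this notebook.\n",
    "\n",
    "$$\n",
    "\\frac{dS}{dt} = -\\beta \\frac{S I}{N}, \\qquad \\frac{dI}{dt} = \\beta \\frac{S I}{N} - \\gamma I.\n",
    "$$\n",
    "\n",
    "The two parameters we wish to learn are then $\\beta$ and $\\gamma$.\n",
    "\n",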
    "Before we begin, you have to generate some training data (if you haven't done so already). Note that this may take some time to run, but you only need to do it once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ac42031b-262a-4887-bc4d-820423a07be2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data.\n"
     ]
    }
   ],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "if os.path.exists(\"data/sir_sde_data.pt\"):\n",
    "    print(\"Loading data.\")\n",
    "    simdata = torch.load(\"data/sir_sde_data.pt\", map_location=device)\n",
    "else: \n",
    "    from epidemic_simulate_data import solve_sir_sdes\n",
    "    simdata = solve_sir_sdes(\n",
    "        num_samples=100000,\n",
    "        device=device,\n",
    "        grid=10000,\n",
    "        save=True, \n",
    "        filename=\"sir_sde_data.pt\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f21bcc9b-7c68-4684-b585-59fc18cfa815",
   "metadata": {},
   "source": [
    "Below we define some constants required for the model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "08393dff-fd68-477c-959e-605832bdff3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 2  # we wish to perform T=2 experiments\n",
    "design_dim = (1, 1)  # one observation per design; design is 1d\n",
    "observation_dim = 1  # observation is 1-dimensional too: the number of infected people\n",
    "latent_dim = 2  # the dimension of the parameter is 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32f72f82-1587-4616-919c-c3c955505ebd",
   "metadata": {},
   "source": [
    "### Constant designs\n",
    "\n",
    "Let's start very simple: suppose we choose two constants, say $t_1=25$ and $t_2=75$. The cells below set up a constant design network and show a few realisations of running such a policy (i.e. regardless of the underlying parameter, our design strategy is to always query at the same times)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c4cc097a-a985-4180-b21b-022719dcb157",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleDesignNet2Constants(DesignBaseline):\n",
    "    \"\"\" \n",
    "    Design Network which returns a pre-defined constant \n",
    "    \n",
    "    design1, design2: two constants\n",
    "    \n",
    "    The transformed design (corresponding to time) equals Sigmoid(design)*100, \n",
    "    which is a number between 0 and 100.\n",
    "    \"\"\"\n",
    "    def __init__(self, design1, design2, design_dim=design_dim):\n",
    "        super().__init__(design_dim=design_dim)\n",
    "        self.design = [torch.zeros(design_dim) + design1, torch.zeros(design_dim) + design2]\n",
    "\n",
    "    def forward(self, *design_obs_pairs):\n",
    "        return self.design[len(design_obs_pairs)]\n",
    "\n",
    "## Initialize the design net with two constants, which after transformation \n",
    "## (pass through sigmoid, multiply by 100) correspond to ~25 and ~75\n",
    "design_net_const = SimpleDesignNet2Constants(-1.1, 1.1)"
   ]
  },
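  {
   "cell_type": "markdown",
   "id": "sigmoid-transform-check",
   "metadata": {},
   "source": [
    "A quick check of the two constants chosen above, assuming (as the docstring states) that the query time is obtained as $100\\,\\sigma(\\cdot)$, where $\\sigma$ is the logistic sigmoid:\n",
    "\n",
    "$$\n",
    "100\\,\\sigma(-1.1) = \\frac{100}{1 + e^{1.1}} \\approx 25.0, \\qquad 100\\,\\sigma(1.1) = \\frac{100}{1 + e^{-1.1}} \\approx 75.0.\n",
    "$$\n",
    "\n",
    "Because $\\sigma(-x) = 1 - \\sigma(x)$, the two query times are symmetric around $50$."
   ]
  },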
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a167b587-52e4-437d-88ae-42729ddec6da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example run\n",
      "*True Theta: tensor([0.6063, 0.0813])*\n",
      "xi1: 24.978992462158203  y1: 170.554931640625\n",
      "xi2: 75.02100372314453  y2: 0.8645527362823486\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.2852, 0.1718])*\n",
      "xi1: 24.978992462158203  y1: 25.199087142944336\n",
      "xi2: 75.02100372314453  y2: 11.668272972106934\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.6901, 0.3115])*\n",
      "xi1: 24.978992462158203  y1: 45.527217864990234\n",
      "xi2: 75.02100372314453  y2: 0.0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sir_model2 = Epidemic(design_net=design_net_const, T=2, simdata=simdata)\n",
    "for i in range(3):\n",
    "    _ = sir_model2.eval(verbose=True)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52237204-cef9-4b11-9f7c-b77cc4cfac93",
   "metadata": {},
   "source": [
    "### Optimized designs\n",
    "\n",
    "Are these the best constants we could choose? Probably not! \n",
    "\n",
    "We can optimize the two constants to obtain designs that are optimal according to the Expected Information Gain (EIG) objective."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b69b0763-a1ef-4b10-b464-3145904afbc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Initial designs ---\n",
      "Example run\n",
      "*True Theta: tensor([0.7385, 0.0712])*\n",
      "xi1: 38.99824905395508  y1: 54.18731689453125\n",
      "xi2: 91.85362243652344  y2: 0.639757513999939\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([1.4754, 0.0524])*\n",
      "xi1: 38.99824905395508  y1: 86.43742370605469\n",
      "xi2: 91.85362243652344  y2: 4.93392276763916\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.2476, 0.0391])*\n",
      "xi1: 38.99824905395508  y1: 306.08258056640625\n",
      "xi2: 91.85362243652344  y2: 47.065521240234375\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pyro.clear_param_store()\n",
    "\n",
    "# fix seed as the initial designs are sampled from uniform(-5, 5)\n",
    "torch.manual_seed(20211101)\n",
    "\n",
    "# <BatchDesignBaseline> (from neural.baselines) is a very simple extension \n",
    "# of the <SimpleDesignNet2Constants> class above\n",
    "design_net_optimized = BatchDesignBaseline(\n",
    "    T=2, \n",
    "    design_dim=design_dim, \n",
    "    design_init=torch.distributions.Uniform(torch.tensor(-5.0, device=device), torch.tensor(5.0, device=device))\n",
    ")\n",
    "sir_model_const_optimized = Epidemic(design_net=design_net_optimized, T=2, simdata=simdata)\n",
    "\n",
    "print(\"--- Initial designs ---\")\n",
    "for i in range(3):\n",
    "    _ = sir_model_const_optimized.eval(verbose=True)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f469a39a-ef5d-4130-ae5e-6e884e3e61e2",
   "metadata": {},
   "source": [
    "#### Critic network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5f92b0e4-f128-45f5-9af9-9ff88fd86643",
   "metadata": {},
   "outputs": [],
   "source": [
    "### We need to define a critic network -- we will train a tiny one\n",
    "encoding_dim = 16\n",
    "hidden_dim = 64\n",
    "\n",
    "critic_history_encoder = LSTMImplicitDAD(\n",
    "    # encoder network (MLP): encodes individual design-outcome pairs, \n",
    "    # which are then stacked and passed through an LSTM aggregator to \n",
    "    # get a vector of size <encoding_dim>\n",
    "    encoder_network=Mlp(input_dim=[*design_dim, observation_dim], hidden_dim=hidden_dim, output_dim=encoding_dim),\n",
    "    # emission network (MLP): takes the final representation and passes\n",
    "    # it through the final (\"head\") layers\n",
    "    emission_network=Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=encoding_dim), \n",
    "    empty_value=torch.zeros(design_dim).to(device)\n",
    ").to(device)\n",
    "critic_latent_encoder = Mlp(input_dim=latent_dim, hidden_dim=[64, 128], output_dim=encoding_dim).to(device)\n",
    "\n",
    "# Critic takes experimental histories and parameters as inputs and returns a number\n",
    "# Optimal critic achieves tight bounds\n",
    "critic_net = CriticDotProd(\n",
    "    history_encoder_network=critic_history_encoder, \n",
    "    latent_encoder_network=critic_latent_encoder\n",
    ").to(device)"
   ]
  },
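  {
   "cell_type": "markdown",
   "id": "infonce-bound-sketch",
   "metadata": {},
   "source": [
    "As a reference for what the critic is trained to do, here is a sketch of the standard InfoNCE lower bound on the mutual information between the parameters $\\theta$ and the full history $h_T$. It assumes that the `InfoNCE` class used below implements this standard form and that `CriticDotProd` scores a pair by the dot product of its two encodings; $U$, $f$, $g$ and $L$ are notation introduced here, not names from the code.\n",
    "\n",
    "$$\n",
    "U(h_T, \\theta) = f(h_T)^\\top g(\\theta), \\qquad I(\\theta; h_T) \\ge \\mathbb{E}\\left[ \\log \\frac{\\exp U(h_T, \\theta_0)}{\\frac{1}{L+1} \\sum_{\\ell=0}^{L} \\exp U(h_T, \\theta_\\ell)} \\right].\n",
    "$$\n",
    "\n",
    "Here $\\theta_0$ is the parameter that generated $h_T$, and $\\theta_1, \\dots, \\theta_L$ are independent draws from the prior (the negative samples). The bound can never exceed $\\log(L+1)$, which for $L = 255$ negative samples is $\\log 256 \\approx 5.55$."
   ]
  },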
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "951091e8-6d40-4cc1-9026-e41c6e163b89",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 3.718 : 100%|██████████| 5000/5000 [04:44<00:00, 17.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Final designs ---\n",
      "Example run\n",
      "*True Theta: tensor([0.2109, 0.1383])*\n",
      "xi1: 33.92506790161133  y1: 34.3079719543457\n",
      "xi2: 90.68833923339844  y2: 9.713905334472656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Let's optimize for a few steps\n",
    "pyro.clear_param_store()\n",
    "# First set up loss:\n",
    "mi_loss = InfoNCE(\n",
    "    model=sir_model_const_optimized.model, \n",
    "    critic=critic_net, \n",
    "    batch_size=256, \n",
    "    num_negative_samples=255\n",
    ")\n",
    "\n",
    "\n",
    "# and an optimizer\n",
    "optimizer = pyro.optim.Adam({\"lr\": 0.001})\n",
    "oed = OED(optim=optimizer, loss=mi_loss)\n",
    "\n",
    "num_steps=5000\n",
    "num_steps_range = trange(1, num_steps + 1, desc=\"Loss: 0.000 \")\n",
    "for i in num_steps_range:\n",
    "    sir_model_const_optimized.train()\n",
    "    loss = oed.step()  \n",
    "    num_steps_range.set_description(\"Loss: {:.3f} \".format(loss))\n",
    "    \n",
    "print(\"--- Final designs ---\")\n",
    "_ = sir_model_const_optimized.eval(verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c56742a-3a1e-4d82-89b7-3f588c48bff6",
   "metadata": {},
   "source": [
    "We would need to train for longer to be sure that the designs (and the critic parameters) have converged. Even if the two designs were the optimal constants, this design strategy would still not be adaptive, i.e. we are not using past information to make future design decisions.\n",
    "\n",
    "### Adaptive Designs\n",
    "\n",
    "We want our designs to be a function of the history. Specifically, given that we have $T=2$ designs here, we want the second design to be informed by the design-outcome pair from the first experiment. Here's how we can do that."
   ]
  },
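  {
   "cell_type": "markdown",
   "id": "adaptive-policy-notation",
   "metadata": {},
   "source": [
    "In symbols, an adaptive strategy for $T=2$ can be sketched as follows. The notation is introduced here for illustration: $\\pi_\\phi$ denotes the design network with trainable parameters $\\phi$, and $\\varnothing$ denotes the empty history.\n",
    "\n",
    "$$\n",
    "\\xi_1 = \\pi_\\phi(\\varnothing), \\qquad \\xi_2 = \\pi_\\phi\\big((\\xi_1, y_1)\\big).\n",
    "$$\n",
    "\n",
    "The first design is still a constant, because there is nothing to condition on yet. The second design can change with the observed outcome $y_1$, which is what the constant strategy above cannot do."
   ]
  },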
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "03068f60-d1c4-42a0-8cf2-efee63dab5ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Initial designs ---\n",
      "Example run\n",
      "*True Theta: tensor([0.8327, 0.0386])*\n",
      "xi1: 45.891273498535156  y1: 97.0364990234375\n",
      "xi2: 70.79920196533203  y2: 41.37067413330078\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.4836, 0.0941])*\n",
      "xi1: 45.891273498535156  y1: 20.98923110961914\n",
      "xi2: 70.77635192871094  y2: 1.7925260066986084\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.6650, 0.1082])*\n",
      "xi1: 45.891273498535156  y1: 12.45532512664795\n",
      "xi2: 70.7790756225586  y2: 0.0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pyro.clear_param_store()\n",
    "torch.manual_seed(20211101)\n",
    "### Set up a design network\n",
    "# We do this in (essentially) the same way as for the critic net. One big difference is that\n",
    "# the design net takes intermediate histories of variable length as inputs, while the critic\n",
    "# only takes the full history h_T; this is handled by the networks in neural.aggregators,\n",
    "# so the code looks pretty much the same\n",
    "design_net_adaptive = LSTMImplicitDAD(\n",
    "    encoder_network=Mlp(input_dim=[*design_dim, observation_dim], hidden_dim=hidden_dim, output_dim=encoding_dim), \n",
    "    # note that the \"head\" layer here outputs a design, i.e. something of size <design_dim>\n",
    "    emission_network=Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=design_dim), \n",
    "    empty_value=torch.zeros(design_dim).to(device)\n",
    ").to(device)  # keep the design net on the same device as the data and the critic\n",
    "\n",
    "sir_model_adaptive=Epidemic(\n",
    "    design_net=design_net_adaptive,\n",
    "    T=2, \n",
    "    simdata=simdata,\n",
    "    # note! we need to make sure the designs are increasing!\n",
    "    # to do that, simply select design_transform=\"ts\" (for time series) \n",
    "    design_transform=\"ts\"\n",
    ")\n",
    "print(\"--- Initial designs ---\")\n",
    "for i in range(3):\n",
    "    _ = sir_model_adaptive.eval(verbose=True)\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "81415223-2332-4bfb-8482-2044c86b08e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "### define a new critic network (same as before)\n",
    "encoding_dim = 16\n",
    "hidden_dim = 64\n",
    "\n",
    "critic_history_encoder2 = LSTMImplicitDAD(\n",
    "    encoder_network=Mlp(input_dim=[*design_dim, observation_dim], hidden_dim=hidden_dim, output_dim=encoding_dim),\n",
    "    emission_network=Mlp(input_dim=encoding_dim, hidden_dim=hidden_dim, output_dim=encoding_dim), \n",
    "    empty_value=torch.zeros(design_dim).to(device)\n",
    ").to(device)\n",
    "critic_latent_encoder2 = Mlp(input_dim=latent_dim, hidden_dim=[64, 128], output_dim=encoding_dim).to(device)\n",
    "\n",
    "critic_net2 = CriticDotProd(\n",
    "    history_encoder_network=critic_history_encoder2, \n",
    "    latent_encoder_network=critic_latent_encoder2\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "940d9806-e1aa-43d4-9914-98d6e354e6e7",
   "metadata": {},
   "source": [
    "Notice that the second design, $\\xi_2$, now varies slightly from run to run, because it depends on the first outcome $y_1$. Let's now optimize the design network (and the critic) for a few steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d73b260f-2b44-4240-ae22-f524af282ac0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 3.091 : 100%|██████████| 5000/5000 [05:07<00:00, 16.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Final designs ---\n",
      "Example run\n",
      "*True Theta: tensor([0.7548, 0.1081])*\n",
      "xi1: 10.930394172668457  y1: 274.59637451171875\n",
      "xi2: 33.50607681274414  y2: 36.46894454956055\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.4764, 0.1522])*\n",
      "xi1: 10.930394172668457  y1: 103.99957275390625\n",
      "xi2: 33.510799407958984  y2: 49.70225524902344\n",
      "\n",
      "\n",
      "Example run\n",
      "*True Theta: tensor([0.3431, 0.0438])*\n",
      "xi1: 10.930394172668457  y1: 50.67266082763672\n",
      "xi2: 34.003257751464844  y2: 251.98411560058594\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Let's optimize for a few steps\n",
    "pyro.clear_param_store()\n",
    "# First set up a new loss with the new model (though we could have used the old loss \n",
    "# and the old model, by just changing the design network):\n",
    "mi_loss_adaptive = InfoNCE(\n",
    "    model=sir_model_adaptive.model, \n",
    "    critic=critic_net2, \n",
    "    batch_size=256, \n",
    "    num_negative_samples=255\n",
    ")\n",
    "\n",
    "\n",
    "# and an optimizer\n",
    "optimizer = pyro.optim.Adam({\"lr\": 0.001})\n",
    "oed = OED(optim=optimizer, loss=mi_loss_adaptive)\n",
    "\n",
    "num_steps=5000\n",
    "num_steps_range = trange(1, num_steps + 1, desc=\"Loss: 0.000 \")\n",
    "for i in num_steps_range:\n",
    "    sir_model_adaptive.train()\n",
    "    loss = oed.step()  \n",
    "    num_steps_range.set_description(\"Loss: {:.3f} \".format(loss))\n",
    "    \n",
    "print(\"--- Final designs ---\")\n",
    "for i in range(3):\n",
    "    _ = sir_model_adaptive.eval(verbose=True)\n",
    "    print(\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
