{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conditional Neural Processes (CNP) for 1D regression.\n",
    "[Conditional Neural Processes](https://arxiv.org/pdf/1807.01613.pdf) (CNPs) were\n",
    "introduced as a continuation of\n",
    "[Generative Query Networks](https://deepmind.com/blog/neural-scene-representation-and-rendering/)\n",
    "(GQN) to extend its training regime to tasks beyond scene rendering, e.g. to\n",
    "regression and classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch  # torch==2.1.2\n",
    "import torch.nn as nn\n",
    "import numpy as np  # numpy==1.24.3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conditional Neural Processes\n",
    "\n",
    "We can visualise a forward pass in a CNP as follows:\n",
    "\n",
    "<img src=\"https://bit.ly/2OFb6ZK\" alt=\"drawing\" width=\"400\"/>\n",
    "\n",
    "As shown in the diagram, CNPs take in pairs **(x, y)<sub>i</sub>** of context\n",
    "points, pass them through an **encoder** to obtain\n",
    "individual representations **r<sub>i</sub>** which are combined using an **aggregator**. The resulting representation **r**\n",
    "is then combined with the locations of the targets **x<sub>T</sub>** and passed\n",
    "through a **decoder** that returns a mean estimate\n",
    "of the **y** value at that target location together with a measure of the\n",
    "uncertainty over said prediction. Implementing CNPs therefore involves coding up\n",
    "the three main building blocks:\n",
    "\n",
    "*   Encoder\n",
    "*   Aggregator\n",
    "*   Decoder\n",
    "\n",
    "A more detailed description of these three parts is presented in the following\n",
    "sections alongside the code."
   ]
  },
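  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the data flow concrete, the cell below traces tensor shapes through the\n",
    "three blocks. It is a minimal sketch, not the implementation used later: the\n",
    "throwaway `nn.Sequential` MLPs, the batch size, the numbers of points and the layer\n",
    "widths are all arbitrary assumptions, and the decoder simply emits two raw numbers\n",
    "per target (one for the mean, one for the spread). The classes further down flatten\n",
    "the batch and point axes before their MLPs instead of relying on `nn.Linear`\n",
    "acting on the last axis of a 3D tensor; both give the same result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Toy sizes, chosen only for illustration\n",
    "batch_size, num_context, num_target = 4, 10, 25\n",
    "x_dim, y_dim, r_dim = 1, 1, 8\n",
    "\n",
    "context_x = torch.rand(batch_size, num_context, x_dim)\n",
    "context_y = torch.rand(batch_size, num_context, y_dim)\n",
    "target_x = torch.rand(batch_size, num_target, x_dim)\n",
    "\n",
    "# Encoder e: maps each (x, y)_i pair to r_i independently\n",
    "encoder = nn.Sequential(nn.Linear(x_dim + y_dim, 16), nn.ReLU(), nn.Linear(16, r_dim))\n",
    "# Decoder d: maps each (r, x_t) pair to two raw outputs\n",
    "decoder = nn.Sequential(nn.Linear(r_dim + x_dim, 16), nn.ReLU(), nn.Linear(16, 2 * y_dim))\n",
    "\n",
    "r_i = encoder(torch.cat((context_x, context_y), dim=-1))  # (batch, num_context, r_dim)\n",
    "r = r_i.mean(dim=1)                                       # aggregator a: (batch, r_dim)\n",
    "r_rep = r.unsqueeze(1).expand(-1, num_target, -1)         # (batch, num_target, r_dim)\n",
    "out = decoder(torch.cat((r_rep, target_x), dim=-1))       # (batch, num_target, 2)\n",
    "\n",
    "print(r_i.shape, r.shape, out.shape)"
   ]
  },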
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encoder\n",
    "\n",
    "The encoder **e** is shared between all the context points and consists of an\n",
    "MLP with a handful of layers. For this experiment four layers are enough, but the\n",
    "number and size of the layers can be changed when the model is built later on via\n",
    "the **`encoder_sizes`** argument of `DeterministicModel`. The first entry of this\n",
    "list is the input width (the size of x plus the size of y) and the last entry is\n",
    "the size of the representation. Each of the context pairs **(x, y)<sub>i</sub>**\n",
    "results in an individual representation **r<sub>i</sub>** after encoding. These\n",
    "representations are then combined across context points to form a single\n",
    "representation **r** using the aggregator **a**.\n",
    "\n",
    "In this implementation the aggregator **a** is included in the encoder because it\n",
    "simply takes the mean across all context points. The mean is invariant to the\n",
    "order of the context points and works for any number of them. The resulting\n",
    "representation **r** summarises the information about the underlying unknown\n",
    "function **f** that is provided by all the context points."
   ]
  },
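  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the aggregator is a mean, **r** does not change when the context points\n",
    "are reordered. The sketch below checks this numerically, using an arbitrary small\n",
    "MLP (the layer sizes are assumptions) as a stand-in for **e**; `encode_and_aggregate`\n",
    "is a hypothetical helper defined only for this check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Arbitrary stand-in for the encoder e (sizes are illustrative assumptions)\n",
    "mlp = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 8))\n",
    "\n",
    "def encode_and_aggregate(mlp, context_x, context_y):\n",
    "    # Encode every (x, y)_i pair, then average over the context axis (dim=1)\n",
    "    r_i = mlp(torch.cat((context_x, context_y), dim=-1))\n",
    "    return r_i.mean(dim=1)\n",
    "\n",
    "context_x = torch.rand(3, 12, 1)\n",
    "context_y = torch.rand(3, 12, 1)\n",
    "\n",
    "perm = torch.randperm(12)  # shuffle the order of the context points\n",
    "r = encode_and_aggregate(mlp, context_x, context_y)\n",
    "r_shuffled = encode_and_aggregate(mlp, context_x[:, perm], context_y[:, perm])\n",
    "\n",
    "# Equal up to floating-point summation order\n",
    "print(torch.allclose(r, r_shuffled, atol=1e-6))"
   ]
  },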
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeterministicEncoder(nn.Module):\n",
    "    def __init__(self, output_sizes):\n",
    "        \"\"\"CNP encoder.\n",
    "\n",
    "        Args:\n",
    "            output_sizes: An iterable containing the layer sizes of the encoder MLP.\n",
    "                The first entry is the input width (x_dim + y_dim) and the last\n",
    "                one is the size of the representation r.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.linears = nn.ModuleList()\n",
    "        for i in range(len(output_sizes) - 1):\n",
    "            self.linears.append(nn.Linear(output_sizes[i], output_sizes[i + 1]))\n",
    "\n",
    "    def forward(self, context_x, context_y):\n",
    "        \"\"\"Encodes the context points into one representation.\n",
    "\n",
    "        Args:\n",
    "            context_x: Tensor of shape batch_size x num_context x x_dim. For this\n",
    "                1D regression task these are the x-values.\n",
    "            context_y: Tensor of shape batch_size x num_context x y_dim. For this\n",
    "                1D regression task these are the y-values.\n",
    "\n",
    "        Returns:\n",
    "            representation: Tensor of shape batch_size x r_dim, the encoded\n",
    "                representation averaged over all context points.\n",
    "        \"\"\"\n",
    "\n",
    "        # Concatenate x and y along the last (feature) axis\n",
    "        encoder_input = torch.cat((context_x, context_y), dim=-1)\n",
    "\n",
    "        # Flatten batch and points so the MLP processes all observations in parallel\n",
    "        batch_size, num_context_points, _ = encoder_input.shape\n",
    "        hidden = encoder_input.view(batch_size * num_context_points, -1)\n",
    "\n",
    "        # Pass through the MLP, with a ReLU after every layer except the last\n",
    "        for linear in self.linears[:-1]:\n",
    "            hidden = torch.relu(linear(hidden))\n",
    "        hidden = self.linears[-1](hidden)\n",
    "\n",
    "        # Restore the (batch_size, num_context_points, r_dim) shape\n",
    "        hidden = hidden.view(batch_size, num_context_points, -1)\n",
    "\n",
    "        # Aggregator: take the mean over all context points\n",
    "        representation = hidden.mean(dim=1)\n",
    "        return representation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder\n",
    "\n",
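    "Once we have obtained our representation **r** we concatenate it with each of\n",
    "the targets **x<sub>t</sub>** and pass the result through the decoder **d**. As\n",
    "with the encoder **e**, the decoder **d** is shared between all the target points\n",
    "and consists of a small MLP with layer sizes defined in **`decoder_sizes`** (the\n",
    "first entry is the input width, r_dim + x_dim, and the last one is 2: one raw\n",
    "output for the mean and one for the standard deviation).\n",
    "The decoder outputs a mean **&mu;<sub>t</sub>** and a standard deviation\n",
    "**&sigma;<sub>t</sub>** for each of the targets **x<sub>t</sub>**. To train our\n",
    "CNP we maximise the log likelihood of the ground truth value **y<sub>t</sub>**\n",
    "under a Gaussian parametrised by these predicted **&mu;<sub>t</sub>** and\n",
    "**&sigma;<sub>t</sub>**.\n",
    "\n",
    "In this implementation the raw standard deviation is bounded below at 0.1 (via\n",
    "`sigma = 0.1 + 0.9 * softplus(sigma)` in `sigmoid_expectation`) to keep it from\n",
    "collapsing to zero. The raw decoder outputs are then treated as the mean and\n",
    "standard deviation of a Gaussian over a logit, and `sigmoid_expectation`\n",
    "approximates the mean and variance of the sigmoid of that Gaussian, so the\n",
    "predicted mean always lies between 0 and 1."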
   ]
  },
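  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check of the training objective, the sketch below compares\n",
    "`torch.distributions.Normal.log_prob` with the closed-form Gaussian log likelihood\n",
    "\n",
    "log N(y; &mu;, &sigma;<sup>2</sup>) = &minus;&frac12; log(2&pi;&sigma;<sup>2</sup>) &minus; (y &minus; &mu;)<sup>2</sup> / (2&sigma;<sup>2</sup>)\n",
    "\n",
    "and confirms that the `0.1 + 0.9 * softplus(sigma)` transform used in\n",
    "`sigmoid_expectation` keeps the standard deviation at or above 0.1. All input\n",
    "values are arbitrary choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "raw = torch.tensor([-20.0, -1.0, 0.0, 3.0])  # arbitrary raw decoder outputs\n",
    "sigma = 0.1 + 0.9 * F.softplus(raw)          # lower-bounded standard deviation\n",
    "mu = torch.tensor([0.2, 0.5, 0.7, 0.9])\n",
    "y = torch.tensor([0.25, 0.4, 0.7, 0.1])\n",
    "\n",
    "# Library log-density versus the closed-form expression\n",
    "log_p = torch.distributions.Normal(loc=mu, scale=sigma).log_prob(y)\n",
    "closed_form = -0.5 * torch.log(2 * math.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)\n",
    "\n",
    "print(sigma.min().item() >= 0.1, torch.allclose(log_p, closed_form, atol=1e-5))"
   ]
  },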
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
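    "def sigmoid_expectation(mu, sigma):\n",
    "    \"\"\"Approximates the mean and variance of sigmoid(X) for X ~ N(mu, sigma^2).\n",
    "\n",
    "    Args:\n",
    "        mu: Raw decoder output used as the mean of the Gaussian over the logit.\n",
    "        sigma: Raw decoder output, mapped to a standard deviation of at least 0.1.\n",
    "\n",
    "    Returns:\n",
    "        expectation: Approximate mean of sigmoid(X), between 0 and 1.\n",
    "        var: Approximate variance of sigmoid(X).\n",
    "    \"\"\"\n",
    "    # Bound the standard deviation below at 0.1\n",
    "    sigma = 0.1 + 0.9 * torch.nn.functional.softplus(sigma)\n",
    "\n",
    "    # Stay in torch (no .detach()/.numpy()) so gradients flow back into sigma.\n",
    "    # y > 1 because sigma > 0, so the divisions below are always safe.\n",
    "    y = torch.sqrt(1 + 3 / np.pi**2 * sigma**2)\n",
    "\n",
    "    expectation = torch.sigmoid(mu / y)\n",
    "    var = expectation * (1 - expectation) * (1 - 1 / y)\n",
    "    # Keep the variance strictly positive when the sigmoid saturates\n",
    "    var = torch.clamp(var, min=1e-6)\n",
    "\n",
    "    return expectation, var"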
   ]
  },
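  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`sigmoid_expectation` relies on the approximation\n",
    "E[sigmoid(X)] &approx; sigmoid(&mu; / y) with y = &radic;(1 + 3&sigma;<sup>2</sup>/&pi;<sup>2</sup>)\n",
    "for X &sim; N(&mu;, &sigma;<sup>2</sup>). The cell below is a quick Monte Carlo sanity check\n",
    "of that mean approximation for a few arbitrary, moderate (&mu;, &sigma;) pairs. The\n",
    "sample size and the 0.02 tolerance are assumptions; accuracy for larger &sigma; and the\n",
    "variance formula `expectation * (1 - expectation) * (1 - 1 / y)` are not examined here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "def approx_sigmoid_mean(mu, sigma):\n",
    "    # Same closed-form approximation as in sigmoid_expectation\n",
    "    return torch.sigmoid(mu / torch.sqrt(1 + 3 / math.pi**2 * sigma**2))\n",
    "\n",
    "mu = torch.tensor([-2.0, 0.0, 1.0, 2.0])\n",
    "sigma = torch.tensor([0.5, 1.0, 1.0, 1.0])\n",
    "\n",
    "# Monte Carlo estimate of E[sigmoid(X)] with X ~ N(mu, sigma^2)\n",
    "samples = mu + sigma * torch.randn(200_000, 4)\n",
    "mc_mean = torch.sigmoid(samples).mean(dim=0)\n",
    "\n",
    "max_err = (approx_sigmoid_mean(mu, sigma) - mc_mean).abs().max().item()\n",
    "print(f'max abs error: {max_err:.4f}')"
   ]
  },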
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeterministicDecoder(nn.Module):\n",
    "    def __init__(self, output_sizes):\n",
    "        \"\"\"CNP decoder.\n",
    "\n",
    "        Args:\n",
    "            output_sizes: An iterable containing the layer sizes of the decoder MLP.\n",
    "                The first entry is the input width (r_dim + x_dim) and the last must\n",
    "                be 2 (one raw output for the mean, one for the standard deviation).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.linears = nn.ModuleList()\n",
    "        for i in range(len(output_sizes) - 1):\n",
    "            self.linears.append(nn.Linear(output_sizes[i], output_sizes[i + 1]))\n",
    "\n",
    "    def forward(self, representation, target_x):\n",
    "        \"\"\"Decodes the individual targets.\n",
    "\n",
    "        Args:\n",
    "            representation: The encoded context representation, batch_size x r_dim.\n",
    "            target_x: The x locations of the target queries,\n",
    "                batch_size x num_targets x x_dim.\n",
    "\n",
    "        Returns:\n",
    "            dist: Independent Normal distributions, one per target point.\n",
    "            mu: The predicted mean at each target point.\n",
    "            sigma: The predicted standard deviation at each target point.\n",
    "        \"\"\"\n",
    "\n",
    "        # Repeat the representation once per target point\n",
    "        batch_size, num_total_points, _ = target_x.shape\n",
    "        representation = representation.unsqueeze(1).repeat([1, num_total_points, 1])\n",
    "\n",
    "        # Concatenate the representation and target_x, then flatten for the MLP\n",
    "        decoder_input = torch.cat((representation, target_x), dim=-1)\n",
    "        hidden = decoder_input.view(batch_size * num_total_points, -1)\n",
    "\n",
    "        # Pass through the MLP, with a ReLU after every layer except the last\n",
    "        for linear in self.linears[:-1]:\n",
    "            hidden = torch.relu(linear(hidden))\n",
    "        hidden = self.linears[-1](hidden)\n",
    "\n",
    "        # Bring back into original shape\n",
    "        hidden = hidden.view(batch_size, num_total_points, -1)\n",
    "\n",
    "        # Split the raw outputs into a mean part and a standard-deviation part\n",
    "        mu, sigma = torch.split(hidden, 1, dim=-1)\n",
    "\n",
    "        # Map mu to a value between 0 and 1 and get the approximate mean and variance\n",
    "        mu, var = sigmoid_expectation(mu, sigma)\n",
    "        # Normal expects a standard deviation as its scale, not a variance\n",
    "        sigma = torch.sqrt(var)\n",
    "\n",
    "        # Get the distribution\n",
    "        dist = torch.distributions.normal.Normal(loc=mu, scale=sigma)\n",
    "        return dist, mu, sigma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Now that the main building blocks (encoder, aggregator and decoder) of the CNP\n",
    "are defined we can put everything together into one model. Fundamentally this\n",
    "model needs to provide two things:\n",
    "\n",
    "1. The log likelihood of the targets' ground truth values under the predicted\n",
    "   distribution. This is used during training as our loss function (we minimise\n",
    "   its negative).\n",
    "2. The predicted mean and standard deviation at the target locations, used to\n",
    "   evaluate or query the CNP at test time. Unlike the first, this must not depend\n",
    "   on the ground truth target values.\n",
    "\n",
    "In this implementation both are handled by a single `forward` method: when\n",
    "`target_y` is passed it also returns the log likelihood, and when it is omitted\n",
    "(as at test time) the returned `log_p` is `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeterministicModel(nn.Module):\n",
    "    def __init__(self, encoder_sizes, decoder_sizes):\n",
    "        \"\"\"Initialises the model.\n",
    "\n",
    "        Args:\n",
    "            encoder_sizes: An iterable containing the layer sizes of the encoder.\n",
    "                The first entry is x_dim + y_dim and the last one is the size of\n",
    "                the representation r.\n",
    "            decoder_sizes: An iterable containing the layer sizes of the decoder.\n",
    "                The first entry is r_dim + x_dim and the last one should be twice\n",
    "                the dimension of y (the raw mean and standard deviation outputs).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self._encoder = DeterministicEncoder(encoder_sizes)\n",
    "        self._decoder = DeterministicDecoder(decoder_sizes)\n",
    "\n",
    "    def forward(self, query, target_y=None):\n",
    "        \"\"\"Returns the predicted mean and standard deviation at the target points.\n",
    "\n",
    "        Args:\n",
    "            query: Tuple ((context_x, context_y), target_x) where:\n",
    "                context_x: Tensor of shape batch_size x num_context x 1 containing\n",
    "                    the x values of the context points.\n",
    "                context_y: Tensor of shape batch_size x num_context x 1 containing\n",
    "                    the y values of the context points.\n",
    "                target_x: Tensor of shape batch_size x num_targets x 1 containing\n",
    "                    the x values of the target points.\n",
    "            target_y: Optional ground truth y values at the target points, a tensor\n",
    "                of shape batch_size x num_targets x 1.\n",
    "\n",
    "        Returns:\n",
    "            log_p: The log probability of target_y under the predicted\n",
    "                distribution, or None if target_y is not given.\n",
    "            mu: The mean of the predicted distribution.\n",
    "            sigma: The standard deviation of the predicted distribution.\n",
    "        \"\"\"\n",
    "        (context_x, context_y), target_x = query\n",
    "\n",
    "        # Pass the query through the encoder and the decoder\n",
    "        representation = self._encoder(context_x, context_y)\n",
    "        dist, mu, sigma = self._decoder(representation, target_x)\n",
    "\n",
    "        # During training, use target_y to compute the log probability.\n",
    "        # At test time target_y is not available, so log_p is None.\n",
    "        log_p = None if target_y is None else dist.log_prob(target_y)\n",
    "        return log_p, mu, sigma"
   ]
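  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how `log_p` serves as the loss, the cell below trains a\n",
    "deliberately tiny, self-contained stand-in CNP on one fixed toy batch by minimising\n",
    "the negative mean log likelihood with Adam. Everything here is a hypothetical\n",
    "illustration: the `TinyCNP` class (it mirrors the encoder, mean aggregator and\n",
    "decoder above but uses a plain Gaussian head without `sigmoid_expectation`), the\n",
    "sine-shaped toy targets squashed into (0, 1), the context/target split (the targets\n",
    "include the context points), the learning rate and the number of steps. With the\n",
    "classes defined in this notebook one would instead build\n",
    "`DeterministicModel(encoder_sizes, decoder_sizes)` with `encoder_sizes[0] == 2`,\n",
    "`decoder_sizes[0] == r_dim + 1` and `decoder_sizes[-1] == 2` for 1D x and y. At\n",
    "test time `target_y` is omitted and `log_p` comes back as `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical minimal CNP: MLP encoder, mean aggregator, MLP decoder\n",
    "class TinyCNP(nn.Module):\n",
    "    def __init__(self, r_dim=16):\n",
    "        super().__init__()\n",
    "        self.enc = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, r_dim))\n",
    "        self.dec = nn.Sequential(nn.Linear(r_dim + 1, 32), nn.ReLU(), nn.Linear(32, 2))\n",
    "\n",
    "    def forward(self, query, target_y=None):\n",
    "        (context_x, context_y), target_x = query\n",
    "        r = self.enc(torch.cat((context_x, context_y), dim=-1)).mean(dim=1)\n",
    "        r = r.unsqueeze(1).expand(-1, target_x.shape[1], -1)\n",
    "        mu, raw_sigma = self.dec(torch.cat((r, target_x), dim=-1)).split(1, dim=-1)\n",
    "        sigma = 0.1 + 0.9 * nn.functional.softplus(raw_sigma)  # same lower bound\n",
    "        dist = torch.distributions.Normal(mu, sigma)\n",
    "        log_p = None if target_y is None else dist.log_prob(target_y)\n",
    "        return log_p, mu, sigma\n",
    "\n",
    "# One fixed toy batch: targets are all 30 points, context is the first 10\n",
    "x = torch.rand(8, 30, 1) * 4 - 2\n",
    "y = torch.sigmoid(2 * torch.sin(math.pi * x))\n",
    "query = ((x[:, :10], y[:, :10]), x)\n",
    "\n",
    "model = TinyCNP()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "losses = []\n",
    "for _ in range(200):\n",
    "    log_p, _, _ = model(query, target_y=y)\n",
    "    loss = -log_p.mean()  # negative log likelihood\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    losses.append(loss.item())\n",
    "\n",
    "# At test time no target_y is passed, so log_p is None\n",
    "log_p_test, mu_test, sigma_test = model(query)\n",
    "print(losses[0], losses[-1], log_p_test)"
   ]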
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
