{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/marvin/miniforge3/envs/cmpe/lib/python3.10/site-packages/bayesflow/trainers.py:27: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import bayesflow as bf\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "from cmdstanpy import CmdStanModel\n",
    "from matplotlib import cm\n",
    "from scipy import stats\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "cmap = plt.get_cmap(\"viridis\", 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GMM(tfp.distributions.MixtureSameFamily):\n",
    "    \"\"\"Two-component Gaussian mixture whose component means are theta and -theta.\n",
    "\n",
    "    Both components have mixture weight 0.5 and a diagonal scale of 0.5.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, theta):\n",
    "        \"\"\"Build the mixture for the component locations in ``theta``.\n",
    "\n",
    "        ``theta`` and ``-theta`` are stacked along axis 1, so ``theta`` is\n",
    "        presumably a (batch, dim) array -- TODO confirm against callers.\n",
    "        \"\"\"\n",
    "        # Side effect: sets the root logger level to ERROR on every construction.\n",
    "        logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "        mirrored_locs = tf.stack([theta, -1.0 * theta], axis=1)\n",
    "        super().__init__(\n",
    "            mixture_distribution=tfp.distributions.Categorical(probs=[0.5, 0.5]),\n",
    "            components_distribution=tfp.distributions.MultivariateNormalDiag(\n",
    "                loc=mirrored_locs, scale_diag=[[0.5, 0.5]]\n",
    "            ),\n",
    "        )\n",
    "\n",
    "\n",
    "class GMMPrior:\n",
    "    \"\"\"Diagonal-Gaussian prior over the mixture location ``theta``.\"\"\"\n",
    "\n",
    "    def __init__(self, prior_location=[0, 0], prior_scale_diag=[1, 1]):\n",
    "        \"\"\"Create the prior distribution.\n",
    "\n",
    "        Args:\n",
    "            prior_location: location passed to ``MultivariateNormalDiag``.\n",
    "            prior_scale_diag: diagonal scale passed to ``MultivariateNormalDiag``.\n",
    "        \"\"\"\n",
    "        # The list defaults are only read here, never mutated.\n",
    "        self.dist = tfp.distributions.MultivariateNormalDiag(loc=prior_location, scale_diag=prior_scale_diag)\n",
    "\n",
    "    def __call__(self, batch_size=None):\n",
    "        \"\"\"Draw from the prior.\n",
    "\n",
    "        Args:\n",
    "            batch_size: number of draws; ``None`` requests a single unbatched draw.\n",
    "\n",
    "        Returns:\n",
    "            ``self.dist.sample([batch_size])`` when ``batch_size`` is given,\n",
    "            otherwise ``self.dist.sample()``.\n",
    "        \"\"\"\n",
    "        # Compare against None instead of relying on truthiness, so that\n",
    "        # batch_size=0 gives an empty batch rather than one unbatched draw.\n",
    "        if batch_size is None:\n",
    "            return self.dist.sample()\n",
    "        return self.dist.sample([batch_size])\n",
    "\n",
    "\n",
    "class GMMSimulator:\n",
    "    \"\"\"Callable that draws observations from ``dist(theta)``.\"\"\"\n",
    "\n",
    "    def __init__(self, dist, n_obs=10):\n",
    "        \"\"\"Store the distribution factory ``dist`` and the default number of observations.\"\"\"\n",
    "        self.n_obs = n_obs\n",
    "        self.dist = dist\n",
    "\n",
    "    def __call__(self, theta, n_obs=None):\n",
    "        \"\"\"Sample observations for ``theta``.\n",
    "\n",
    "        Uses ``self.n_obs`` when ``n_obs`` is ``None``. The draws come back\n",
    "        with the sample axis first; the ``[1, 0, 2]`` permutation swaps the\n",
    "        first two axes before returning.\n",
    "        \"\"\"\n",
    "        num_draws = self.n_obs if n_obs is None else n_obs\n",
    "        draws = self.dist(theta).sample([num_draws])\n",
    "        return tf.transpose(draws, perm=[1, 0, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Performing 2 pilot runs with the anonymous model...\n",
      "INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)\n",
      "INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 10, 2)\n",
      "INFO:root:No optional prior non-batchable context provided.\n",
      "INFO:root:No optional prior batchable context provided.\n",
      "INFO:root:No optional simulation non-batchable context provided.\n",
      "INFO:root:No optional simulation batchable context provided.\n"
     ]
    }
   ],
   "source": [
    "# Simulator that samples from GMM(theta), and the prior that supplies theta.\n",
    "simulator = GMMSimulator(GMM)\n",
    "prior = GMMPrior(prior_location=[0, 0])\n",
    "\n",
    "# Prior and simulator are both declared as batched callables.\n",
    "generative_model = bf.simulation.GenerativeModel(\n",
    "    prior=prior,\n",
    "    simulator=simulator,\n",
    "    prior_is_batched=True,\n",
    "    simulator_is_batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 2000 (parameter, data) pairs to inspect the simulator output.\n",
    "simulations = generative_model(2000)\n",
    "theta, y = simulations[\"prior_draws\"], simulations[\"sim_data\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "theta shape: (2000, 2)\n",
      "y shape: (2000, 10, 2)\n"
     ]
    }
   ],
   "source": [
    "print(\"theta shape:\", theta.shape)\n",
    "print(\"y shape:\", y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first five simulated data sets, each in its own figure.\n",
    "for i in range(5):\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(y[i, :, 0], y[i, :, 1], \"o\")\n",
    "\n",
    "    # Mark theta^* and draw a reference circle of radius 1 around it.\n",
    "    # With the component scale of 0.5 used in GMM this is 2 SD, not 1 SD.\n",
    "    ax.plot(theta[i, 0], theta[i, 1], \"*\", markersize=10, color=\"gray\")\n",
    "    ax.add_patch(plt.Circle((theta[i, 0], theta[i, 1]), radius=1, color=\"gray\", fill=False))\n",
    "\n",
    "    # Same circle around the mirrored mode at -theta^*.\n",
    "    ax.add_patch(\n",
    "        plt.Circle((-theta[i, 0], -theta[i, 1]), radius=1, color=\"gray\", linestyle=\"dotted\", fill=False)\n",
    "    )\n",
    "\n",
    "    ax.axis(\"equal\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training set: reuse the pickled simulations when the file exists, otherwise\n",
    "# simulate SIMULATION_BUDGET data sets and cache them on disk.\n",
    "filename = \"gmm_train_data.pkl\"\n",
    "SIMULATION_BUDGET = 1024\n",
    "\n",
    "if not os.path.exists(filename):\n",
    "    logging.getLogger().setLevel(logging.ERROR)\n",
    "    print(\"generating new train data and saving it\")\n",
    "    train_data = generative_model(SIMULATION_BUDGET)\n",
    "    with open(filename, \"wb\") as f:\n",
    "        pickle.dump(train_data, f)\n",
    "else:\n",
    "    # NOTE(review): the cache is loaded as-is and is not checked against SIMULATION_BUDGET.\n",
    "    # pickle executes code on load: only read caches written by this notebook.\n",
    "    with open(filename, \"rb\") as f:\n",
    "        train_data = pickle.load(f)\n",
    "    print(\"loaded train data from file\")\n",
    "\n",
    "# Parameter dimensionality, read from the second axis of the prior draws.\n",
    "num_params = train_data[\"prior_draws\"].shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out test set: reuse the pickled simulations when the file exists,\n",
    "# otherwise simulate NUM_TEST data sets and cache them on disk.\n",
    "filename = \"gmm_test_data.pkl\"\n",
    "NUM_TEST = 1000\n",
    "\n",
    "if not os.path.exists(filename):\n",
    "    logging.getLogger().setLevel(logging.ERROR)\n",
    "    print(\"generating new test data and saving it\")\n",
    "    test_data = generative_model(NUM_TEST)\n",
    "    with open(filename, \"wb\") as f:\n",
    "        pickle.dump(test_data, f)\n",
    "else:\n",
    "    # pickle executes code on load: only read caches written by this notebook.\n",
    "    with open(filename, \"rb\") as f:\n",
    "        test_data = pickle.load(f)\n",
    "    print(\"loaded test data from file\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate reference posterior on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "HMC running ...: 100%|██████████| 1/1 [00:00<00:00,  2.63it/s]\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this replaces the NUM_TEST-sized `test_data` from the previous cell (and `y`)\n",
    "# with one fresh simulation, so the reference posterior below covers that single data set only.\n",
    "test_data = generative_model(1)\n",
    "\n",
    "# Silence cmdstanpy's log output.\n",
    "logger = logging.getLogger(\"cmdstanpy\")\n",
    "logger.addHandler(logging.NullHandler())\n",
    "logger.propagate = False\n",
    "logger.setLevel(logging.CRITICAL)\n",
    "\n",
    "num_posterior_samples = 5000\n",
    "iter_warmup = 2000\n",
    "\n",
    "y = test_data[\"sim_data\"].numpy()\n",
    "theta_true = test_data[\"prior_draws\"]\n",
    "n_sim, n_obs, data_dim = y.shape\n",
    "_, param_dim = theta_true.shape\n",
    "\n",
    "# Every chain draw is stored twice (original and sign-flipped), so HMC only\n",
    "# needs to produce half of the requested number of samples.\n",
    "iter_sampling = num_posterior_samples // 2\n",
    "\n",
    "posterior_samples = np.zeros((n_sim, num_posterior_samples, param_dim))\n",
    "trace_plots = []\n",
    "\n",
    "# Build the Stan model once: it does not depend on the data set, so there is\n",
    "# no reason to re-create it on every loop iteration.\n",
    "model = CmdStanModel(stan_file=\"gmm.stan\")\n",
    "\n",
    "for i in tqdm(range(n_sim), desc=\"HMC running ...\"):\n",
    "    stan_data = {\"n_obs\": n_obs, \"data_dim\": data_dim, \"x\": y[i]}\n",
    "    fit = model.sample(\n",
    "        data=stan_data,\n",
    "        iter_warmup=iter_warmup,\n",
    "        iter_sampling=iter_sampling,\n",
    "        chains=1,\n",
    "        inits={\"theta\": theta_true[i].tolist()},  # start the chain at the true parameter\n",
    "        show_progress=False,\n",
    "    )\n",
    "    posterior_samples_chain = fit.stan_variable(\"theta\")\n",
    "    # Append the sign-flipped draws (the GMM components sit at theta and -theta).\n",
    "    posterior_samples[i] = np.concatenate([posterior_samples_chain, -1.0 * posterior_samples_chain], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARkAAAESCAYAAADJ16HUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3hklEQVR4nO3daWxc553v+e/Za6/iTkqiJMqSJTuKHe9JOwnstDuJMbgNYy6CHjQGiIPAyAs7QMN5Y/eLGHnR8IsE3QHSQZIGuu0GBpkbYO4kuTeZztyMk3i7cZx4SUzHkkVro0iqWNxqPfs586JYJZIitdg+XKT/ByAokkWeQ4r14/P8n02J4zhGCCESom71DQghrm0SMkKIREnICCESJSEjhEiUhIwQIlESMkKIREnICCESpW/1DVxKFEVMT0+Tz+dRFGWrb0cIsSyOY+r1Ort27UJVL91W2dYhMz09zejo6FbfhhBiA5OTk+zZs+eSj9nWIZPP54H2N1IoFLb4boQQHbVajdHR0e5z9FK2dch0ukiFQkFCRoht6ErKGFL4FUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkKtGQefrpp7nrrrvI5/MMDg7y0EMPcfz48SQvKYTYZhINmeeff55HH32UV155hV/+8pf4vs9nP/tZms1mkpcVQmwjShzH8WZdrFKpMDg4yPPPP8+nP/3piz7uui6u63bf7hwgVa1W5dwlIbaRWq1GsVi8oufmptZkqtUqAL29vet+/Omnn6ZYLHZf5IhaIXa+TWvJRFHEX//1X7O0tMRLL7207mOkJSPEznA1LZlNO6b20UcfZXx8fMOAAbAsC8uyNuuWhBCbYFNC5rHHHuNnP/sZL7zwAnv27NmMSwohtolEQyaOY7761a/y4x//mN/85jeMjY0leTkhxDaUaMg8+uij/PCHP+SnP/0p+Xye8+fPA1AsFkmn00leWgixTSRa+FUUZd33P/PMMzz88MOX/fyrKS4JITbPtin8buIUHCHENiVrl4QQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQEUIkSkJGCJEoCRkhRKIkZIQQiZKQeR/mGy7jU1XmG+7lHyzEdW7TzsK+lsxUHU6UGwD05eTsbiEuRULmfRgppla9vpT5hstM1WGkmJJAEtclCZn3oS9nXXFgSKtHXO8kZBJ2Na0eIa5FEjIfsk73KKWrlOsuNdujkDYlZMR1S0LmQ9bpHilKzOS8TbluM5RPU0wbl+wuSe1GXKskZK7CekGwsuXiBBEpXeXQUI6UrjKQT1Gzs6taMhuFidRuxLVKQuYqrBcEK1sucaxwaCjH0d1FAA4O5buhcqmvMVGu88fJRXqz1mXDSIidRkJmhcs9sUeKKaq2R9X2mW+49OUuhMLKlsz4VLX7/l+Mz3Bu0eaTh/q59+DAuoXg352a56UT83zyUB99OYv5hsuLJyrUnQD29kjIiB1NQmaF9VoZfzg1z6+Ol+nPWRweLlKzfSbKTQCODHNRKI1PVXnjzBITaY2BvMWJcpNyzeaWPSUAFpseb55dZDJncvdYO1R6syZ9OZPerNm9j7odkk/rUjAWO56EzArrtTJ+e3Kel04skDM1mk5Eb86gFQScnmsAMW+dq1KuOgwXLf7ypmFGiin+qMa8eXYJXYeZJZtYAYgBGJ+u8tLEHP05k9HeLACFtMmDHx3myHDhwvX3laSrJK4JiYbMCy+8wDe/+U1ee+01ZmZm+PGPf8xDDz2U5CU/kLWT7CbKdYIw5CMjeQ4N5zg8XGQob9F4
5zyvnJxntmbjRzFvTVUZnwEvjHnw6AhxDA03YGHBY7bmoCoKdafdxbJ0jdv3lhjtzZDS1XZ3asnhkwf7ute+msl+Qmx3iS6QbDab3HrrrXz3u99N8jIfyKUWO/7u1Dy/O7lIpeFwy+4S9x7spydrcnquxetnFvnF22X8IObeg30c7M8ytWjzX1+f5PRcE4i5bbTIWH+WtKVxer7JsfN1zsy1cIOIQ4N5nCDi96cWePndCsfP1ze8v5cnKrw8MScLMsWOlGhL5sEHH+TBBx9M8hIfWKcOU7V9iukLQ9EjxRS9WYv5lst7FY9//tUJ/uojdQxV4dRcA8cLUBQ4X7X5+A39ZC2dd8t1ppccppZsNBQODeb51I0DvPjuLNOLDr85VuZkpUHY7jnxiQN9NFwfJ4iYazir7qtThK7aHq+fWYKYi+bayAiU2Am2VU3GdV1c98Jf61qtlvg1O/WXqu2tGooGODSYY3cpTc32OTnb4JnqKT46UkRVVAxDZVchxdhgjhPlBg0vZFcpxfSSQxxBMwh4eWKO3qxJyw8hDqm7PotNHycIeG/WwNQUcpbB3l64cTDfHZXqy1nd8BssWNy+rwQoFxWBZW6N2Am2Vcg8/fTTfOMb39jUa3bqH/MNd1VLJqWr/O7UPKamcnioQMv1ObfkcmK2Tn/ewtR0/DAmZ+pkDY2ZqkOl7nF6roUbhhRSOlUnQNdUBvPtoW/HC/HCgErdJW222NeXoeYEZEyDhhuuCoz2cLkPxBwZLmw4pL7ytRDb0bYKmSeffJLHH3+8+3atVmN0dDTRa843XI6drwEKlqbwx8k6vVmLQtrg3IKNF0Tomspf3TzMc8fLzNY9CkHE/r40th/z0ok5bttbxNBUXni3wkLTJW0ZWLpCEEacqtTxg5DZhstSyycIQ5peyMxSi7emq0wv2ZQyJqamMFgwqdo+E+U6ThABMbM1j2LaWTdkpEAsdoJtFTKWZWFZyT5p2qFSp9NCmKk6/PLtMuWaw3DB4nzdI2uq3DiUp5QxyKd0zi46nKg0uHGoQKU+x2zNZU9PBl0NeGemxlvnFinlLOpOgIqCEoecr7rMNVyCsN0V68laBGFEFMUQgRvEHJuuEaOgqgoNr93COV91eD2MKKUtDgxmOTSUk5aK2NG2Vch8UBsVQleuLxqfrnJ8pk7G0ikurylKmypuEFFIGxwcytNwAhabPhlL58ahPPMNn4WmR8sL8YIQP4R3Zqq0PJ+m176GH9n0Zi1MXScIoWZ7qIAC5NMm+/pyRFGDJdsjk9JJGyopQ6fa8iDSqdRdDFXHDyMMTUVV29VhKeqKnS7RkGk0GkxMTHTfPnXqFG+++Sa9vb3s3bv3Q7/e2kLoyhGa2ZqHosTUnYDdPWnGBrLdJ/B/vn2U8ekqR3cVOTiU5w+n5vn5W9MEcUwQhsS0R5FQFKIYojhkqRkSrLi2qevkUwa2FxLHEYrSLh6nLAVT05hearJke0RxRErXGcynsYOQSFEJ4hhieK9SY29vhpFShoWmy8nZJsW00f3eJHDETpRoyPzhD3/g/vvv777dqbd88Ytf5Nlnn/3Qr7e2ELpyhObQUA7HCzi3ZHfDpKPz71dOznNits6xmRqvnV0irWukTY1T83W8IGa0lGZXKd2epLfm2gutgKbdIFZBjcGLIALUMGah6dByQ7wQssuhU8oalDBw/ZCUrvGHM4sQKzS8kLmmy/SSg64q3DSSWxWene9r5fco4SO2s0RD5r777iOO4yQvscraQujKEZqUrjIxa1O3Q5wgYqJcX9V6ee6dMj/70wyWqTBf97DdkFJWh9iCGOIYxgaz9GRMTlbWnzjnxkB48fujKMaP2h8KopiWHzBRqVFMW+RSGg0nwAsi0qaG6wX8adLhfLVFCPRkDO47MsRgoT3itDZwZAhbbHfXVE1mrb6cRTHdflJW6i51J+guOnzxRIXfn1qkXHUYn64y13CJ45hTsw1aXoSuwaCeQlchbepoSsj5qsP5
qouqqOStCFPVqLkhfrTxPRgK9GQt/Mil4UUoCgRByKwd0HRCerMpmm6ArcCSrbDY9MimdXRNIQhizlddZmsegwWzW1daWwyWwrDYzq6ZkFnbMoF2wXdyoUnL8zk8lCdl6t2uxdFd7T1fylWb359aZG9vmpt25VloOrihQimjYapwcr7JQtMjDOF8zVvVUMmZEYbKJUPGi6HmuLhehAZEIQRqjBNBZIf4YRNdU0npBq4f4AYhqqZw6+4iThBzz1gvgwWTU5UmMU1u29vT3a8GpAUjtr9rJmTGp6v8/tQicKHGMlN12tsyKJAy9VVPzp6sycHBPHtKabJWnb6ciReGvJO1GMgpDJXSnJtvcb7msVGHr+FdvisYAwvNiE4OBRHEy0nlxeC5YKoR2YLCDUM5ylUXYgjimN6shaIqVOouDSdkuGRJq0XsONdMyHRaJp3X0O5G3L6/B4i7T85Oi8fSNebqHvm0Rl/OZGK2ycnZJs3lQqzvhyza7oYBczVWNnT8uD2s3ZFSwTAUgijEUFRMTSWMY/b0ZDg8XMALQo6X6+wppfnUoQFpuYgd55oJmYND+VUjRp3h6yPD+VVPzE6L5+ZdefJpnboT0ABabsC5xRbVpotrGtRtn2rLx1AhXh4piuGqQkfZ4PEx7eXvmgIDhRSWpuBHMecWW8w3XFRN5eWJOUxNpZgxqbcCenfL7F6xM10zIbPWRosHj+4q0nACerMmpqby/PEyS62A3pzBfMsjVlQ0YoKoXT+JYsinFaJQwQsinBXNkrQGfshFw9kdhtLuEnWsDJ1OyFRtl4xp0JMxqTntrpnnR5yZa/HfnGn292YIY+ibMbh7rFeCRuw412zIbLR48OBQnnLd5aV3K7xXqfM/36vgeDBStChlTeI4xtA0FpsOEe0giEJQVQVvTYHXXme4eqW1JRtTBXf5awS06zOuGxOGHqauUEwZmLqKoaqkTA1dVZmpOe3ZyCmDmer6a5iE2M6u2ZBZeZrAyrfbYuquT80JCIJ2VyiIIoYKFkXfYLbu0PJD/PZEXDwvRiH8wPWZtSHVvpN2a6nphHhBTM7SGO3N8rmjwyw2XX51vIIC3DJalKKv2JGu2ZCBdsB0NvVeWTTt7KU7tdhCBd6bazCQt/BDmK25qKpK2tBpeRc6Qh9GAXijrxFGEMYhBgopoz3L2A8j9vbluGnYI20qfObIENDeqHzlxlqd71Nm/Yrt6poMmc5K65p9Yb3Syq7GYtPjVKWJqSs8dNtuTs03OT3X5MxCC10DU1cppAzmmhtVWz4YQ2mPMnV4MfhOTCGlcmSkQBBFvPLePLftK5E2FRabPq+cnG+PgpWb9GQNspbR/XyZ9Su2s2syZGaqDq+fXgQFbt9Xopg2V52H1DkxoOkGmLrK2ECGIIqJophi2qCYtrD98KIw6FBZPSx9tRQuDpoYaLoBTS9karHF9KLN+arNUDFNteUTxHBoIAMK9C2fdCCzfsVOcE2GzMr5MZ1d5V6emOOldyvs6U1zeCjPJw/28/ZUlYlKg1OzTTRVw9Q1RoppRkopMqaK7fucXfAu+vrvJ2B0LoxCeTHkLYW0ApqiYHsRUdRuidy1r4fZmk3TCzhZadKTsbhhMEva0Nnbl+uectDe1EqI7e+aDJm+nMW9B9d2HWIqDZdK02V/f47//RP7mSjX+d2peQxNYWrRxtQVBvIWhq7SkzUoWCY63oZD1Fdj7dfw/ZhsSqMva1FzfOpOSMbU6c1ZfHRXidPzLbKmzq6eFPcc6KOQNhnKWzhBRLnuMFu7EH7SXRLb2TUZMus5Mlzgtr1N3jpX5aUTs1iaQsrUuWesj3Ld4d3ZOrYXcnymyjvlBqoCfhCiftC+0Qb8CJpuiK751Fo+bgiV5dMlq7ZPMWUwWLC4eaSA48eM9ho4QbRq6wrpLomd4LoJmb6cxeePjrDQ9Dg+0+Anb05xeLiAqsA7MzX+PFOj1gpQlJillt9d9KgrkDWg5V96hOlqsyikvYYp
CGLSpooWgBJHTC05KEqMrqmYmkbDC2nWXUZ7UowN5ICLd8uTFozYzq6bkIH2k/HBoyPADH4YkU/pWHp70lsQhHhhxOGhHGEcc2bObq+4jiFt6BAHXGqwydLBvsp+laGCaSqMFgs0XZ+qE9BwfA4M5oiimP5cirmGQ7UVcm7J5s6xPgkUseNcVyED7Rm/f5tt783ieAHHy3VGiimGixn8sEV/zuITB/v5H+MzvHl2CT9qHzl7uYky0QYfV5ZfouXXKQOCANIphcFciiCE6VoLxw/xg4j5pk5myWZsIMf+viympuAV4+7CTznQTew0113IwIUd9H7yxjl+c7xCHLePjS2kdWw/5M9TNQbyKTRdwfFiwhC0NV+j87ZOu+sTrekrKYCpQD6j4foRfhgvr0tqf9BQNIopg9PzTVpee3PyrKkCMQ0vwvFD+nImjh9z8+7cqu0rpNArdpLrLmRWnlxg6RpxHDHXCBguBowU07wzU2XJ9ilXXeIoRgcUpR0aneNlNZaXAyy/NvXlkFnRmtEATW1vvRlFMZoKfkC3C+YEAWcWbGwvxA/ANCBradw8XKDlh0wu2LxzvkYpZfDzP1VxvIA7x/q6Bd6V834kbMR2dt2FTGepwZLtYqgaY/1ZNNUmJqbpBqQMHaUVtJcUxJC1VPIpA1VVqNQdgrAdOJ3FjxEQLm8QnlNVaq2QCNA1iBSo2xHx8vLrVT2qCDKGRhiFmIaKrsTomo4bxpi6xsn5GvlZnbSpcvx8E0WhW5Ppy1mMT1WlRSN2hOsuZEaKKSbSGk1PxQ9DdF1DVRQWGgG7e3VKGYMb+jMs2S7nvABNV+nPm4z15Xh3ts7kQosoVtCCEDtqt1gUBdKWSV/WxPbr2D6kdAVFVXCVCCWGIKYbUJ2wUVTIpwwypkZvxkJXFe69oY+mF6LEcHR3EUNXaDgh/csTCjsTDNeuMpdajdiurruQ6ctZfOrQQLfLVK47jPVlWGx6nJ5vMVtzyad0enMp5ho+KU3BMnQ+uqfI3r4ML56Yw/YCJpda4MaEtGsvKU1hrC9DpW7j+QFuGKOE7ZZOxlCJvAhNg6wGpqHRn7MgVqi7IVGk4McxB/pzfGxfLyPFFPceGugGyN1j/VRtv7tUopg2Obq7uCpMpFYjtqvrLmTgQuF3vuHiBBFHhgu8emqBl9+dBxXG+rLUnYCphSbh8uhSpeFyy54SZxZanCjXKaVNotAlXB42qjRcXjxRIWfp5NIqvg9O0N7bNwgiUgbkTJ3RviwpTWNvXwYnCPjzdJ1iSufgQJbRnjQpXV11tMt8wwVgKG9dtJVo5+MbnWIgxHZwXYZMx8q//gtNF93U2NeTIp820FWFjGXgRxEjBQtDUzk2U+XP00vMNwNGChYZQ2Om2kJTVWw/ounFRJHPcClN3fFRnAjXa3enUqbOx/b20p81OLdoM1g0aXk6utpgT0+Gsf4cJytN/o9XTnPPgT7uXq6/HDtf5/XTi9y+v4d7D/Zv+D0cGsqt2ihdiO3iug6ZlSM1vVmTzxwe6G4qTtyujwwX09yxr4fj5Tq/OlYmihQKlk5PxuDd2ToRKsW0TsqMWGwE6JpKEEEpa6GpCnXFR9fa3aMwinhrukoYwfSiQynb3glvruFystLk7GILPwwJIhjtzS63Zjo7j68/EWejHQBF8qQOdmWu65DpdEtenqgwMdvk9n0ljgwXKKQNQFm1CbkbxuStRfpzAaWsyVLLI5+y2FXUGCml+ONkFU0N0FWFnKkTxiEpQydGwVAVShmDcs2hXPMwNZhcbFHKlujLmiy1AlKGykO37mLJ9hntzXRD48hwgWLa3DBE1p6aKTaP1MGuzHUdMtD+a/T2VJUz800ODmaXV3APXPQYiBkpWdQdHy+IKaRN9vRk+N/u2sufzi3x9lSNUsbkpl0FdFVhesmhkFIY7UlTdXzmGj5KHLG3N0XGNHCDkKYTkLcMbC+iL2tiGhqf2de76tQFCZHtS1qRV+a6D5mZqsOi
HZDSNQppc93HPPfOef7ft8v0pHUKaZ1Kw6XlReztzeCGEZOLLVRVIasbLDU9vBA0VWFXMUMupROGMXMND1C4cyhPT85kct4mjttnLt04kqflhxcdTie2N/kDcGWu65CZb7hUbY9bdhcopE2ODK//5D411+TMnM18SkXXVKYXW+i6ysHBEFCIUejPWlRdj3OLNilT45bdJT51Yz+n55ucr6kULJ1symCwkCJt6uzphZ6MSRTDcDHFnlKac0s2R3cV1z1yV4id6poMmSstyM1U25s/XW5k5s59vVTqLpahMlvz8IKQUtrk/sODHBnOU7M9RgopvCBkaskGRWEwb2HqGp85PETNDjA0lVv2FHnw6Ajj01Wyls6BgWy33tKXs7hz+Xo/eeMcvz+1SMMJuhuGy19MsVNdkyFzpQW59frU6wXUUDHNraM9KMAnD5osND2iSGGomKYvZ1FIG3hhzO37e/nqwQHml+fM1O2Q1KDOw38xtqpl0rO8Cnyj8OisuLZ0VQqLYse7JkPmSgty6/Wp1wuokWKKiZRO3Q4Z7c1y91hfNyTalOURZqUbUkd3FVe1Qg4O5Zko1/nJG+fYU0qTMjf+0XeO3F0ZeELsVNdkyFxNQW5ty2W9gFq5FKHzuJVf/8hwnmLaYKSYumhy3HzD7Z6V9H+/McmxmSZHRrIcGix2v/aH8X2I7Unm0rR3jUzcd7/7Xfbv308qleKee+7h1Vdf3YzLXpFOKKw8aXLtuqC17+8ER2fK/8qPjRRTq6b3d77++HSVlhth6Qr7+7KyBOA6sfb363qUeEvmRz/6EY8//jjf//73ueeee/j2t7/N5z73OY4fP87g4GDSl7+s97M/y9ou1dq/Vis/P6WrKErMnlKagbzF2kl+4tp2JV33a721k3hL5h//8R955JFH+NKXvsTNN9/M97//fTKZDP/2b/+W9KWvSKcV0jkJ4Er+4mzUWul87sqWjhNExHH7ZIR7Dw5w78H+DX+R1raQxM63Uct4pcu1dnb670WiLRnP83jttdd48sknu+9TVZUHHniA3/72txc93nVdXPfCD7JWqyV5e6tcabF4vb86az93ZUvnamaFyjT169Plfkd2+u9FoiEzNzdHGIYMDQ2tev/Q0BDHjh276PFPP/003/jGN5K8pQ1daZF1vf/wjbpIa7dtuByZpr6zvZ9uz5V8zk7/vdhWo0tPPvkkjz/+ePftWq3G6OjoFt7Rxa7kP7zTRbrao2RlNGlnu5oWRydcqrbPbM1d9Tnr1fiu5Pdiu9Z2Eg2Z/v5+NE2jXC6ven+5XGZ4ePiix1uWhWVtnx/Oeq7kP3yn/+UR78/VdLlfPFGh7gQcGLh4pPFyAwsb2a7dqkRDxjRN7rjjDp577jkeeughAKIo4rnnnuOxxx5L8tJb6nJBtF3/4ogP5lL/7/MNl2Pn67RnbSrU7ZB8WufIcAFg1RSKS9X4rnQG+3b6HUu8u/T444/zxS9+kTvvvJO7776bb3/72zSbTb70pS8lfelta7v+xREfvs6T/c/TS/z6WIWhQoq/+sgQt+0rdQNg7ckTa8PqSqdZrPy87XSaReIh8zd/8zdUKhW+/vWvc/78eT72sY/xi1/84qJi8PVEulPXj84flNPzTdwgJmOpHBkurJrUuXJ/5pUtns7jruQYnCuZub5VNqXw+9hjj13T3aOrJQXea8PKTdw3Wi3feZKP9qS4cai9lUfnMevtzzw+VV11KgXQDR1LU2l5PpMLzYuutbZ1vJ2KxdtqdEmInWSm6vDG2UWWWh6ljAl7ezZ8oo4N5LhzrA+AiXKdV07OY2oK+bTJ5EKTqu11z9NaeSrFsfN1/s9XztDyQz66u8DUkgMxFNIm9x68uEt1tedwbUbXXUJGiPepszq/6YTkU/q6W4ZUbY+3zlXxw4jRnjSTizblmsOfpqrs7cnwiRv6eGViniXbZ19/liAMGSqkugtya7bHXMNhuuoSRxHZlEHD9fnJG5NYmtINrpXH
/IxPVZlcaHb3rV67neza72Hl6yRIyAjxPq1dnQ/w8sQcNdtr7zkUg6Ur/HFyiSCKefPsItM1F00BP4iotTzcIGS26XJuscXUUouGF6IrcLLS4CO7esinNcb6c1TtAENX+fiBPv54bpFTcza/PTnfDZmOzhE6lnFh+5HLfQ9Jd90lZLap7TQEKTa2dkTn9dOLnFloUrV9iun20TkLDY9cSuPG4Ty2F5EyVWpOgBfFBGHMSMHi3HyTQlYnZ+kstHxqTkDVcTlfixgomNyxv4dqyyNnady1v4+BfINPHOhb547aR+iM9mYY7c1eP4VfcfVkmHt722gN28GhHNPVJnN1B9szKKQMcmkd2w84OdtkbDBDFMbMVm3KNRddU8hbBqahYegqcRyjKrDU9JitO8xWXUoZE0WJObPg8Nw7ZW4cKnLvwcFuK2blvQzlU4z2pjk0mN82+0NLyGxT22kI8npzJa3IY+drvH5mqVvzmCjX+e9/PMd7lSa6qlJMG4RRTE/GIJ8q8B/j55loNdnVkyZv6lSaLk0/YvzcEr1Zi6FciliJOb1g03J9Gl7AfMtFVVTCKCZtagRRiGVo3cP+OhvOe0FIueZx+74SxbT5vpa0JElCZpuSYe6tc2WtyNU1j/HpKr9+d4667XPHvh7G+rOMz9Q4s2izu2ihKgqmrpA1VOwggFhBI8L2oBK36zQhMUtND00FI4poeTEKETXHZ7HlUbV9IGa0N02l5vCvL56i6QXcNJwnbej85vgsfhBx80hhW836lZARYo2NWpErn7Qrt1yF9ubv99/Y3z79YjjHzJKNpigEYcjZeZsohmLGwNBUKg0XL4ywfQiAwI2Y9G38CELAiMHUFcIgJAjBzkR4UUgQxczWPMbPVXn11ALlmo2qqRwayFIqGvxifIYl2ydtapedvLeZNmX7TSF2gs7wL7DuRlPHztf49bFZjp2vXbQdqxNE/Kdb93B0T5E/T9c5M28TRDELDZ+T8w0ajkdv2mKh6VOpuShEqFr762ZNUFTQNShaCvv606RNFScAN4KlpkvBNChlDAbyKaaWbOpOgB+GaEDVDshaOgP5FP15i4yh8ZM3zuF4wbbY5lVaMkIsu3w3qd1FqtkBL0/MATFD+RTPvVPuztxdcnxQoC9nMt90ma07zDccak5IFFfxgoiqE2GqoKmQ0RUypoYXRkQR3L6vl2La4Pl3y4Rxu0fWcELOY7O7N0vd8XC8CEWNKaRMgjim4fiMFFN88d4xIObtqSr//c1pDo/k+F9vG12eMVzrLlPYbBIyQiy73DlcnS7Sn6eX+G9vTjFUaG/D+urpeWpOwO4ei0rNZnyqhqEpLDQ9DE0jjMANoVwP2mUcwI+ACFQlxvZjojgCRWOm2mKmqlJzIjqlWzcGK45ZbDpMLbYAyKd03CCk4Qacmm/xx6klDvTnOLqryOvBAmfmm5i6wvh0lcl5m1YQUKm7fOrQwKYHjYSMEMsudw5Xp3v0x8lFGm5IzvNxg5C0qbO3N8ORkSL/MX6eU3Mt4rhdiyhmVHqzFi3Pxg8hAjSWy8UKEIPjh4QRBIScXWqRMww0BYL4wn0YKswseQSAAcRxRMuNcCJQWw4vHK/wznSdhhNwdqGFqqrkUjpHdxUZyKc4PdegbofMVB0JGSG2k5FiiqrtU7U95hsufTmLe8b6llspKn7Q3mbVD2P+x/h53CBCV9otlQBoeTGDxfYku6VWQEprd5NiFBRidF3BD2Lc5WZL0wXX9UkZCqoaYwftLKq2IoLle/KBJftCSycK25P6XD/gvdkGpq5xZDjHruUW2b0H+zkynN+ygwIlZIS4hL6cRTHdbs0U0+1WQE/W5PZ9vUBMzfaZXGxxfKbOTN0hCEJ68yZBENF0Q9KmihJHNL2AAIhCIISIGAMIw7jddVohAJr+hWZMvPw+6DZ+iGm3lGLaiyXvHutFUxX+NFUla+ocGMjw7myL352a5+BQfkunREjICHEZa2s1nYl4BwezFNIG
9x7sZ6iQ4vXTC5RrDhEKxZLObN2h6YacmmviLafEyjzx4aKA6YjXfzcx7SdtRDtwIqBc83jtzAKHhvLUbR/fjyibKjoKDSdgolzfcCuKzSAhI8RlrG0F1OyActXB1BQabsBM1aYna9KXT7HY8phcaDFbi1FUjbrtE0TtOkzM6pDZiE57vsxGQRPSrsugQrQ8t+bYTJ1yzaGUMSlmDLKWQW/OIGPp/O7UPOcWW+wuZXjwoyNSkxFiq200U7bzfogZKqbY25fhV8fKvHZ2iVLKYHdPmpYf0vDak+hUJUJRLrQ4rnSif3CZj3e7TxEULAVdU3GDEDeIaHkhpqbQkzH56O4So70ZJheavDNTY/xclcWWx4NHRzZ1XZOEjBBrbHRaQOf4ksGCxf1HsjheQN0JCaOY3qzBR3cXgZiFhocdBFiahqIoaCrUnQD7culxFTqBZegaHxst0XBCYiXGUjR6sikOj+QZyJlMzNbZU0pz78F+fnWszFvn2pMN/zZrblqLRkJGiDU2Oi0gZSgoSoylKVRtj//v7TInK008P2ByscVIKQ2ApikELlg6FFI6ThDih8nca9rQGCqmuGt/loODOf50bolyzWFyocXLExVsN+aBjwxycDDPDQN5zi22MFStO5Qt228KsQVW1mDmGy5V22ewYAIKJysN/p+3ZjhRbtByA1K6QlNVqDsBb01V0VWFasvHDcFthdTtdroEGxVYPqCZJZeXT8xhHtboyRo8/26FhZZHb7qOZepYqsKpSoM9pTSfPDTQ3Y84pau8PDHH6bkGUaTAvpKEjBCbZeVf95mqw2zN5dBQjpSu8vqZBSaXbOZbLn1Zi+GiRaVmoakwNpDlraka6orN6PyEwqUjAGbrDuNTNc4uNJlrePhhwMhIgbGBLJMLNsfLDcYGcjx0257u53U22GoFAYeH8rL9phCbaWVNZmXXaabaHr25a18Pd+3roZQ2QFHww4jhQopfHy9Ta3mkTJ1W8CEWYC7D1FS80KdcDxjKm/TlCxwaznPTcIH+rEWl4bJnuSvXkdJVerIGB3NZ7h7rS7Q+IyEjxBorg2Xt8HV1IEvNDoCYs/Mtllo+t4yW+MOZBX51rILtBsTEGEryrZiOMIoJIojjmJtGS+iqwk9fP8ezLY8b+nN87qO7SJmrn+pOEJG1DEZ7s7LHrxCbbaPZse3ZvyYvvTvH8XKdqu3i+BEnK3XSpkYpbRDFEbqqEkUeih/jrQialAKmoVLzPtiudXkTdFXF8SL8CPww5txCk1LWYmbJZqHpMVN18aL2xuIf2dMipa/e1WUzd16UkBHiKowUU5QyBi3PZ77h0XQDqrbPnlIaVYFay8fUVXKWThD5RCEocXt2rxtD8AEDRl+edxPESvcggiAEiFlsebw9VcU0NIoZnXorYKSUZk9PhhOz7a06j+4qdpcZwOozuJMiISPEVejLWRzdU2R8qkrDXd5zN1ZYsn2mF5t4YXuFdMbQSekqURSxnAGr1iB1qGw8SU8BTK29TURHEEPDg6wVomnghO0ncT6lL0/6U+jNmty5fwjHj7llT4Gbd5V48+wi78zUAboT8TZrs3oJGSGu0pHhAh+/oY+spaMQo2kqE+UGXhgThCEZQ2PR9tE0hZSp0nTXj5FLBQy0Q8ldZ36NQnu1NrSXKxQzOh/ZXaRuB7hhxF37e7l9Xw/QXjw5UkyROtDX3fph5fG6m7FznoSMEFehc0LA4aE8WUvneLnB5HwTLwrZU0xx78F+Xn5vjpmag+dHqAqEy3WZzgpqaO8PE77PntPK3DEUOLqn2O4S+Q0ylkbaVJmYbdKbNXD8mGLa4OjuYrcF09n7d7Bgds/bXvv9dbpVHwYJGSGuwvh0lRffnWN/f4YHj44wkLf4WdNh9qzDVBiz0PIIYjBUBTuGIFrdWtE7m1FFV76W6VLSlsr0ksNEuY6qqvRlTWwvwlcDLF2hN2ut2gsHLhR7q7Z/UXdpfLrK708tAkjICJGktdPtO2/v
KaXZ35/B0FReeHeWk3NNLE0jaxnMNRxmaw4pSyefMlHwsf0A329/zRiIYsia7eXTQbDxSmtod4WypkoYRTSXizmWBqoCpq5i6SoZ06BSs2n6ESkdNFUhVuDwSJ66E7DQdJmYbVCzfUZ7s6R0lXLdpbM/8coTF6B96sLK1x8GCRkh1rG2KNp5+9BQjr+9Zx8zVYfv/+YEr51domjp9GRMimmDlhcQRDG2F1LKmexPZXmv0sANIoIwJowhiiJMTSUVR3jh6u7PSiHgBhExkNZhX1+WUsbEDyP8MMTzY4ZLKebqGmcXWqQNFZYHnTpndE8utFhotEfCHD9GUWIm521Q4P4jgxzdvTpMDg59+CdPSsgIsY6180g6r1O62m3hHB7Kc6xcR1VB1VRGcgY1J8Q0FJpOSMrQ6M2193aZb7pML9lUWwFx3G7RdIagNS4EzdpicKdmrEeQMzVSmkbLDWh6IQoKYaRwaDDPUCGNaShUWyEDy/N8+nIWI8UUo72ZbgumZnscHMoBMVXb73ajklwoKSEjxDrWTsjrvL3ywLTb9/fR8iL8OMJQFfwoxgkcPraniKEr/PxP05yv2vTnLRw/wgsidBUsQ6UnbTLfcvGCGFMHY3moOmu0t7eqOdGqRZUBMFW1Gci3w8ELIoaLKdKGih/FfOamQQ4P5Tm3ZLOnlGZ8qrpqxvJ8w2V8ukrdCTgwkKVSdzk526SYNla11Drf64cpsZD5h3/4B37+85/z5ptvYpomS0tLSV1KiE2zsmUzUmzXNKq2x8lKE9sLGMiaGLrK/5yocHbBJorb+wDPNRzCENKmxkgx1d5Y3PbaW0cst3h6UiYoMN/wsH0HLVqu40TtbSNi2mc0BWFMHAd8/IY+SmmTE7N1DE3FDSMG8ilOzjV59eQCh4Zy/OVNQzhBRNX2qdsh+bQOtFeN59P6hi22D1NiIeN5Hl/4whf4xCc+wb/+678mdRkhNtV6Sw6qto+qQNowyKc1/nBqkYVWQCndfnq5QYyqtjtFhq4QxVBpeqiKStaKUIGlRsBSw0PXNBwvwA9A0+BAX5ZYAS+IKKUNbhvt4Y3JJWwvwtBVju4psuQEnJ5vcGq+SUbXsQOf6aoDKgwVU8SxwmDB5LZ9pW6IdAq+ne8lyY3GEwuZb3zjGwA8++yzSV1CiC0133B58USFuh1yYDAHQM32ODCYpVK3cf32rnmKEqAqMWGsEEQQRBF9GRNNgZrdnqXr+X576BuF3pwFDZeUoVFI6aiaQtMN22uUFlv0ZgzmWx6T8y1uGi6yp5RabqVo7O/PYWkKu4oZ+nImhwbz624ivpn7/G6rmozruriu2327Vqtt4d0IcWkzVafb7eicazRbc+nNmuQzJnrdoy+r0fA0GnZArMcUUjrFjIGuqezryzDX8FiyfSxDxfNCxvqz3Lq3h7rjU6m7nK+6+FFMX1bH9iPOzNtExFiqwmLT5+UTc9x7qJ9bRzOrguTOsb4t/ulcsK1C5umnn+62gITYapfbUDylqxwYyNIZJup0RSbKNWotj7G+NP/p1t28cnKeuYaLBty0q4DthUwvuphFi719OYoNDz+OsP2IvQN5bt/Xy+RCk3zaxDJUziy0UFAZ609zolxjtuaRS+lYuoKurZxHvPE9byX18g+54IknnkBRlEu+HDt27H3fzJNPPkm1Wu2+TE5Ovu+vJcQH1Rlx6axUhgtdpDfOLuIEEcW0yWzN7e6Ze3R3kclFm3NLNrM1j5Sp0Z8zyVk6xbRJxjLQdY1CRufwcI7hgkWkwg2DOY7uLuD5PjNLLc4u2JycbaCrKildZa7p0vQCdpWy5JZrPU0vpFxzmKna3fucqTq8cXaRF09UmG+4G31rm+qqWjJf+9rXePjhhy/5mAMHDrzvm7EsC8vaHukrxHojLjNVpztKs/L9K//9iQN9vHu+hhu0R5b29mW5eVeRUsZgIGvS8gKqTsRA3mJ6yaZhtwu1fVmDSiPgRLlOytTIpzUWmwF1J8DU
VPaUMtw11serJ+c4UWnhBT6VRsxcw+Xusf4LLamUvmXnXq/nqkJmYGCAgYGBpO5FiG1lvRGXkWIKlkdpVo7MrHTnWB+VhsfLJ+YAhUODOdwgwtI1HD9isGAxfm6Rl07MEUQx+ZTOYM7ixuE8A3mH6cV2K2igYBKHPrmUwa27S/ztPfvoyZrMVFvMN136shlSpsZnDg+tmrnbme27FederyexmszZs2dZWFjg7NmzhGHIm2++CcDBgwfJ5XJJXVaIRF3pUG8hrTNUTAHx8gmONrfsKXa3Vnh7aom5houlqxweLrC7J82nDw3gBBH/1x8mcUOHwYJFEMMgcPeBXpwg4tVTC/z+1CJzDZdC2uQ/37L7oqUBW3nu9XoSC5mvf/3r/Pu//3v37dtuuw2AX//619x3331JXVaITbey2Arts7Jrtr+8p0vMn6Zq1O2AQtrsBoKuKWQtjcPDeW4aKRJFSneo+a8+MkxnAWO57gAKNdvj18dmMbX23JfhosXdY33da27Hgm9HYiHz7LPPyhwZcU1be7IkQNX2+I+3ZsindP6XW3av6LIoHBm+sPDQD2IMTWe0J8vnj46sCqmVE+U6ixX/460ZylWHW0YL3Hto8KIw2axd7t6PbTWELcRO0nliDxbM7rlMb1aa6KrK7tKFeStHhgurRqgA9vZlODCQZW9fZlX3ZnyqyhtnlphIa3zq0ED3/Z3u1+6ezEXdI9jcjcGvloSMEO/T2qNTxqeqxLHCx/aWVgXEeq2Mu8f6GO3NXtTdSekq+bRG3QlWjQ4dGS5QXN5Kcz3brQ6zkoSMEO/T2if2eiNP3feveN0JFMcLePFEhaO7ijhB1N2vZr3Roe0cIpcjISPEh+RS5zWtVz85t9jkfLVdy/nUofbUkPUOlNvpJGSE2GQpXUVRYm4eKTDaG3B0V/GaC5aVJGSE2GROEBHHCkPFNH+5ThH3WiMhI8Qm284jQUmQkBFik11t12g7T7S7EhIyQmxz23mi3ZWQkBFim9vp3SsJGSG2uZ0+8nRVm1YJIcTVkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiQqsZA5ffo0X/7ylxkbGyOdTnPDDTfw1FNP4XleUpcUQmxDiZ27dOzYMaIo4gc/+AEHDx5kfHycRx55hGazybe+9a2kLiuE2GaUOI7jzbrYN7/5Tb73ve9x8uTJK3p8rVajWCxSrVYpFAoJ350Q4kpdzXNzU0+QrFar9Pb2bvhx13VxXbf7dq1W24zbEkIkaNMKvxMTE3znO9/hK1/5yoaPefrppykWi92X0dHRzbo9IURCrjpknnjiCRRFueTLsWPHVn3O1NQUn//85/nCF77AI488suHXfvLJJ6lWq92XycnJq/+OhBDbylXXZCqVCvPz85d8zIEDBzBNE4Dp6Wnuu+8+Pv7xj/Pss8+iqleea1KTEWJ7SrQmMzAwwMDAwBU9dmpqivvvv5877riDZ5555qoCRghxbUis8Ds1NcV9993Hvn37+Na3vkWlUul+bHh4OKnLCiG2mcRC5pe//CUTExNMTEywZ8+eVR/bxFFzIcQWS6z/8vDDDxPH8bovQojrhxRJhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkSgJGSFEoiRkhBCJkpARQiRKQkYIkahNPab2anX2A5bjaoXYXjrPySvZs3tbh0y9XgeQ42qF2Kbq9TrFYvGSj7nqEyQ3UxRFTE9Pk8/nURQl0WvVajVGR0eZnJyU0yoTIj/jZG3mzzeOY+r1Ort27brsoY3buiWjqupFZzYlrVAoyBMgYfIzTtZm/Xwv14LpkMKvECJR
EjJCiERJyCyzLIunnnoKy7K2+lauWfIzTtZ2/flu68KvEGLnk5aMECJREjJCiERJyAghEiUhI4RIlISMECJREjLrOH36NF/+8pcZGxsjnU5zww038NRTT+F53lbf2o713e9+l/3795NKpbjnnnt49dVXt/qWrhlPP/00d911F/l8nsHBQR566CGOHz++1bfVJSGzjmPHjhFFET/4wQ94++23+ad/+ie+//3v8/d///dbfWs70o9+9CMef/xxnnrqKV5//XVuvfVWPve5zzE7O7vVt3ZNeP7553n00Ud55ZVX+OUvf4nv+3z2s5+l2Wxu9a0BMk/min3zm9/ke9/7HidPntzqW9lx7rnnHu666y7++Z//GWgvfB0dHeWrX/0qTzzxxBbf3bWnUqkwODjI888/z6c//emtvh1pyVyparVKb2/vVt/GjuN5Hq+99hoPPPBA932qqvLAAw/w29/+dgvv7NpVrVYBts3vq4TMFZiYmOA73/kOX/nKV7b6Vnacubk5wjBkaGho1fuHhoY4f/78Ft3VtSuKIv7u7/6Oe++9l6NHj2717QDXWcg88cQTKIpyyZdjx46t+pypqSk+//nP84UvfIFHHnlki+5ciCvz6KOPMj4+zn/5L/9lq2+la1vvJ/Nh+9rXvsbDDz98ycccOHCg++/p6Wnuv/9+/uIv/oJ/+Zd/Sfjurk39/f1omka5XF71/nK5zPDw8Bbd1bXpscce42c/+xkvvPDCpu/DdCnXVcgMDAwwMDBwRY+dmpri/vvv54477uCZZ5657O5fYn2maXLHHXfw3HPP8dBDDwHtJv1zzz3HY489trU3d42I45ivfvWr/PjHP+Y3v/kNY2NjW31Lq1xXIXOlpqamuO+++9i3bx/f+ta3qFQq3Y/JX9+r9/jjj/PFL36RO++8k7vvvptvf/vbNJtNvvSlL231rV0THn30UX74wx/y05/+lHw+3611FYtF0un0Ft8dEIuLPPPMMzGw7ot4f77zne/Ee/fujU3TjO++++74lVde2epbumZs9Lv6zDPPbPWtxXEcxzJPRgiRKCk0CCESJSEjhEiUhIwQIlESMkKIREnICCESJSEjhEiUhIwQIlESMkKIREnICCESJSEjhEiUhIwQIlH/P1ZRopUkNoCIAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter the reference posterior draws of the first data set.\n",
    "f, ax = plt.subplots(figsize=(3, 3), subplot_kw={\"box_aspect\": 1})\n",
    "first_posterior = posterior_samples[0]\n",
    "ax.scatter(first_posterior[:, 0], first_posterior[:, 1], alpha=0.2, s=1)\n",
    "ax.set_aspect(\"equal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save(\"test.npy\", posterior_samples.astype(np.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BayesFlow trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural posterior estimator: spline coupling flow conditioned on a DeepSet summary of the data.\n",
    "trainer_npe = bf.trainers.Trainer(\n",
    "    amortizer=bf.amortizers.AmortizedPosterior(\n",
    "        inference_net=bf.networks.InvertibleNetwork(\n",
    "            # Use the dimensionality measured from the training data instead of a literal 2.\n",
    "            num_params=num_params, num_coupling_layers=4, coupling_design=\"spline\", permutation=\"learnable\"\n",
    "        ),\n",
    "        summary_net=bf.networks.DeepSet(summary_dim=2 * 2),\n",
    "    ),\n",
    "    generative_model=generative_model,\n",
    "    default_lr=1e-3,\n",
    "    memory=False,\n",
    "    checkpoint_path=\"checkpoints/gmm_npe\",\n",
    "    max_to_keep=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = trainer_npe.train_offline(train_data, epochs=30, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add your models here\n",
    "\n",
    "trainers = (trainer_npe,)\n",
    "names = (\"NPE\",)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation on a single data set (pretty plots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hand-picked parameter vector and a data set simulated from it.\n",
    "theta = np.array([[-1.8, -1.0]])\n",
    "x = simulator(theta)\n",
    "test_sims = dict(summary_conditions=x, parameters=theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HMC\n",
    "logger = logging.getLogger(\"cmdstanpy\")\n",
    "logger.addHandler(logging.NullHandler())\n",
    "logger.propagate = False\n",
    "logger.setLevel(logging.CRITICAL)\n",
    "\n",
    "n_sim, n_obs, data_dim = x.shape\n",
    "_, param_dim = theta.shape\n",
    "\n",
    "stan_data = {\"n_obs\": n_obs, \"data_dim\": data_dim, \"x\": x[0].numpy()}\n",
    "model = CmdStanModel(stan_file=\"gmm.stan\")\n",
    "fit = model.sample(data=stan_data, iter_warmup=2000, iter_sampling=5000, chains=1, show_progress=False)\n",
    "posterior_samples_chain = fit.stan_variable(\"theta\")\n",
    "posterior_samples = np.concatenate([posterior_samples_chain, -1.0 * posterior_samples_chain], axis=0)\n",
    "\n",
    "# Evaluation grid covering the plotted parameter range [-3, 3] x [-3, 3].\n",
    "side = np.linspace(-3.0, 3.0, 200)\n",
    "xx, yy = np.meshgrid(side, side)\n",
    "params = np.c_[xx.flatten(), yy.flatten()]\n",
    "# Pair every grid point with a copy of the observed data. The number of copies is taken\n",
    "# from the grid itself so it cannot drift out of sync with the resolution of `side`.\n",
    "obs_data_rep = np.concatenate([x] * len(params), axis=0)\n",
    "\n",
    "f, axes = plt.subplots(1, 5, figsize=(10, 2))\n",
    "\n",
    "# Reference panel: Gaussian KDE of the HMC draws, evaluated on the grid.\n",
    "hmc_kde = stats.gaussian_kde(posterior_samples.T, bw_method=0.05)\n",
    "hmc_pdf = hmc_kde(np.vstack([xx.ravel(), yy.ravel()]))\n",
    "\n",
    "axes[0].pcolormesh(xx, yy, hmc_pdf.reshape(xx.shape), cmap=cm.viridis, rasterized=True)\n",
    "axes[0].scatter(theta[:, 0], theta[:, 1], color=\"magenta\", s=50, marker=\"*\")\n",
    "axes[0].set_aspect(\"equal\")\n",
    "axes[0].set_title(\"True\", fontsize=24)\n",
    "axes[0].set_xlim(-3, 3)\n",
    "axes[0].set_ylim(-3, 3)\n",
    "\n",
    "for i, (trainer, name) in enumerate(zip(trainers, names), 1):\n",
    "    # Log posterior density of this approximator at every grid point\n",
    "    lpdf = trainer.amortizer.log_posterior({\"parameters\": params, \"summary_conditions\": obs_data_rep})\n",
    "\n",
    "    # Draw the density on the grid (reshaped with the grid's own shape), then mark the true parameters\n",
    "    axes[i].pcolormesh(xx, yy, np.exp(lpdf).reshape(xx.shape), cmap=cm.viridis, rasterized=True)\n",
    "    axes[i].scatter(theta[:, 0], theta[:, 1], color=\"magenta\", s=50, marker=\"*\")\n",
    "    axes[i].set_title(name, fontsize=24)\n",
    "    axes[i].set_aspect(\"equal\")\n",
    "    axes[i].set_xlim(-3, 3)\n",
    "    axes[i].set_ylim(-3, 3)\n",
    "\n",
    "plt.savefig(\"gmm_density.pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "f, axes = plt.subplots(1, 5, figsize=(10, 2))\n",
    "\n",
    "# Leftmost panel: HMC reference draws with the true parameters on top.\n",
    "panel = axes[0]\n",
    "panel.scatter(posterior_samples[:, 0], posterior_samples[:, 1], alpha=0.1, color=cmap(4), s=1, rasterized=True)\n",
    "panel.scatter(theta[:, 0], theta[:, 1], color=\"magenta\", s=150, marker=\"*\")\n",
    "panel.set_title(\"True\", fontsize=24)\n",
    "panel.set_aspect(\"equal\")\n",
    "panel.set_xlim(-3, 3)\n",
    "panel.set_ylim(-3, 3)\n",
    "\n",
    "# Remaining panels: one per trained approximator, colored by its position in `trainers`.\n",
    "for i, (trainer, name) in enumerate(zip(trainers, names), start=1):\n",
    "    samples = trainer.amortizer.sample(test_sims, n_samples=2000)\n",
    "    panel = axes[i]\n",
    "\n",
    "    # Approximate posterior draws first, then the true parameters on top.\n",
    "    panel.scatter(samples[:, 0], samples[:, 1], alpha=0.5, color=cmap(i - 1), s=1, label=name, rasterized=True)\n",
    "    panel.scatter(theta[:, 0], theta[:, 1], color=\"magenta\", s=150, marker=\"*\")\n",
    "\n",
    "    panel.set_title(name, fontsize=24)\n",
    "    panel.set_aspect(\"equal\")\n",
    "    panel.set_xlim(-3, 3)\n",
    "    panel.set_ylim(-3, 3)\n",
    "\n",
    "plt.savefig(\"gmm_samples.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation across the prior space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence cmdstanpy: only CRITICAL records pass the level check, nothing is forwarded\n",
    "# to the root logger, and the NullHandler absorbs whatever is left.\n",
    "logger = logging.getLogger(\"cmdstanpy\")\n",
    "logger.setLevel(logging.CRITICAL)\n",
    "logger.propagate = False\n",
    "logger.addHandler(logging.NullHandler())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_hmc(test_sims, iter_warmup=2000, iter_sampling=2000, **kwargs):\n",
    "    \"\"\"Draw reference posterior samples with Stan for every simulated data set.\n",
    "\n",
    "    One single-chain Stan run is performed per data set. The draws of each run are\n",
    "    stacked with their sign-flipped copy, so ``2 * iter_sampling`` samples are\n",
    "    stored per data set.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    test_sims : dict\n",
    "        Must provide ``\"summary_conditions\"`` (unpacked as ``(n_sim, n_obs, data_dim)``)\n",
    "        and ``\"parameters\"`` (unpacked as ``(n_sim, param_dim)``).\n",
    "    iter_warmup : int, optional\n",
    "        Warm-up iterations per Stan run.\n",
    "    iter_sampling : int, optional\n",
    "        Sampling iterations per Stan run.\n",
    "    **kwargs\n",
    "        Forwarded unchanged to ``CmdStanModel.sample``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    posterior_samples : np.ndarray\n",
    "        Array of shape ``(n_sim, 2 * iter_sampling, param_dim)``.\n",
    "    trace_plots : list\n",
    "        Items obtained by iterating over each ``az.plot_trace`` result, accumulated\n",
    "        over all runs.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    NOTE(review): ``az`` (arviz) must already be bound in the notebook namespace; it is\n",
    "    not among the imports of the notebook's first cell -- confirm it is imported earlier.\n",
    "    \"\"\"\n",
    "    x = test_sims[\"summary_conditions\"]\n",
    "    theta_true = test_sims[\"parameters\"]\n",
    "    n_sim, n_obs, data_dim = x.shape\n",
    "    _, param_dim = theta_true.shape\n",
    "\n",
    "    # Each run contributes its draws plus their mirrored copy.\n",
    "    num_posterior_samples = iter_sampling * 2\n",
    "\n",
    "    posterior_samples = np.zeros((n_sim, num_posterior_samples, param_dim))\n",
    "    trace_plots = []\n",
    "\n",
    "    # The Stan program is the same for every data set, so build the model once\n",
    "    # instead of re-instantiating it on every loop iteration.\n",
    "    model = CmdStanModel(stan_file=\"gmm.stan\")\n",
    "\n",
    "    for i in tqdm(range(n_sim), desc=\"HMC running ...\"):\n",
    "        # np.asarray is a no-op for NumPy input and converts other array-likes\n",
    "        # (e.g. eager tensors) before the data is handed to Stan.\n",
    "        stan_data = {\"n_obs\": n_obs, \"data_dim\": data_dim, \"x\": np.asarray(x[i])}\n",
    "        fit = model.sample(\n",
    "            data=stan_data,\n",
    "            iter_warmup=iter_warmup,\n",
    "            iter_sampling=iter_sampling,\n",
    "            chains=1,\n",
    "            show_progress=False,\n",
    "            **kwargs,\n",
    "        )\n",
    "        posterior_samples_chain = fit.stan_variable(\"theta\")\n",
    "        posterior_samples[i] = np.concatenate([posterior_samples_chain, -1.0 * posterior_samples_chain], axis=0)\n",
    "        dat = az.InferenceData(posterior=fit.draws_xr())\n",
    "\n",
    "        # Keep the trace-plot axes for later inspection, then close the current figure\n",
    "        # so figures do not pile up across iterations.\n",
    "        trace_plot = az.plot_trace(dat, var_names=\"theta\", compact=False, show=False)\n",
    "        trace_plots.extend(trace_plot)\n",
    "        plt.close()\n",
    "\n",
    "    return posterior_samples, trace_plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of test data sets requested from generative_model for the evaluation below.\n",
    "NUM_TEST = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate NUM_TEST data sets and pass them through the NPE trainer's configurator.\n",
    "test_sims = trainer_npe.configurator(generative_model(NUM_TEST))\n",
    "# Convenience handles on the two entries of the configured dict.\n",
    "y = test_sims[\"summary_conditions\"]\n",
    "theta_true = test_sims[\"parameters\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stan reference posteriors, one run per test data set, with sample_hmc's default iteration counts.\n",
    "hmc_posterior_samples, traces = sample_hmc(test_sims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4000 draws per data set matches what sample_hmc stores with its default iter_sampling=2000 (2 * iter_sampling).\n",
    "npe_posterior_samples = trainer_npe.amortizer.sample(test_sims, n_samples=4000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cmpe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
