{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "russian-table",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.autograd import grad, Variable\n",
    "import autograd\n",
    "import autograd.numpy as np\n",
    "import copy\n",
    "import scipy as sp\n",
    "from scipy import stats\n",
    "from sklearn import metrics\n",
    "import sys\n",
    "import ot\n",
    "import gwot\n",
    "from gwot import models, sim, ts, util\n",
    "import gwot.bridgesampling as bs\n",
    "\n",
    "import models\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "norman-pearl",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep all new floating-point tensors in double precision and cap CPU parallelism at 16 threads\n",
    "torch.set_default_dtype(torch.float64)\n",
    "torch.set_num_threads(16)\n",
    "\n",
    "# prefer the first CUDA device when one is available, otherwise use the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "united-double",
   "metadata": {},
   "outputs": [],
   "source": [
    "# size of one subplot panel; the figures below pass multiples of it to plt.figure(figsize = ...)\n",
    "PLT_CELL = 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "every-trade",
   "metadata": {},
   "outputs": [],
   "source": [
    "# seed every random number generator imported by this notebook so that a fresh run is reproducible\n",
    "SRAND = 0\n",
    "torch.manual_seed(SRAND)\n",
    "np.random.seed(SRAND)\n",
    "random.seed(SRAND) # the stdlib generator is imported above but was previously left unseeded\n",
    "# setup simulation parameters\n",
    "dim = 10 # dimension of simulation\n",
    "sim_steps = 250 # number of steps to use for Euler-Maruyama method\n",
    "T = 10 # number of timepoints\n",
    "D = 1.0 # diffusivity\n",
    "t_final = 0.5 # simulation run on [0, t_final]\n",
    "N = 50 # number of particles requested at each timepoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "changed-people",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Psi(x, t, dim = dim):\n",
    "    \"\"\"Potential evaluated on each row of x; t is accepted but not used.\n",
    "\n",
    "    First term: 1.25 times the product of the squared distances to x0 and x1,\n",
    "    two points on the (1, 1, 0, ..., 0) axis.\n",
    "    Second term: 10 times the squared norm of the coordinates from index 2 onwards.\n",
    "    \"\"\"\n",
    "    axis_11 = np.array([1, 1] + [0, ]*(dim - 2))\n",
    "    x0 = 1.4*axis_11\n",
    "    x1 = -1.25*axis_11\n",
    "    to_x0 = x - x0\n",
    "    to_x1 = x - x1\n",
    "    tail = x[:, 2:]\n",
    "    wells = 1.25*np.sum(to_x0*to_x0, axis = -1) * np.sum(to_x1*to_x1, axis = -1)\n",
    "    confinement = 10*np.sum(tail*tail, axis = -1)\n",
    "    return wells + confinement\n",
    "\n",
    "# gradient of the potential, obtained by automatic differentiation\n",
    "dPsi = autograd.elementwise_grad(Psi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "surrounded-intersection",
   "metadata": {},
   "outputs": [],
   "source": [
    "# branching rates\n",
    "R = 10 # upper bound of the birth rate\n",
    "\n",
    "def beta(x, t):\n",
    "    \"\"\"Birth rate: R times a tanh sigmoid of the first entry of x (t is unused).\"\"\"\n",
    "    return R*((np.tanh(2*x[0]) + 1)/2)\n",
    "\n",
    "def delta(x, t):\n",
    "    \"\"\"Death rate: identically zero.\"\"\"\n",
    "    return 0\n",
    "\n",
    "def r(x, t):\n",
    "    \"\"\"Net growth rate: birth rate minus death rate.\"\"\"\n",
    "    return beta(x, t) - delta(x, t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grand-missouri",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ic_func(N, d):\n",
    "    \"\"\"Initial condition: N particles in d dimensions, standard normal draws scaled by 0.1.\"\"\"\n",
    "    return np.random.randn(N, d)*0.1\n",
    "\n",
    "# simulation object with birth-death dynamics switched on\n",
    "# NOTE: this rebinds the name `sim`, which the import cell also bound to the gwot.sim module\n",
    "sim = gwot.sim.Simulation(\n",
    "    V = Psi, dV = dPsi,\n",
    "    D = D, d = dim,\n",
    "    T = T, t_final = t_final,\n",
    "    N = np.repeat(N, T),\n",
    "    birth_death = True, birth = beta, death = delta,\n",
    "    ic_func = ic_func,\n",
    "    pool = None,\n",
    ")\n",
    "\n",
    "# draw the sample, truncated to N (trailing semicolon suppresses the cell output)\n",
    "sim.sample(steps_scale = int(sim_steps/sim.T), trunc = N);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "natural-pixel",
   "metadata": {},
   "outputs": [],
   "source": [
    "# first coordinate of every sampled particle against its observation time\n",
    "# each particle's time is looked up through sim.t_idx instead of assuming exactly N rows per timepoint\n",
    "plt.scatter(np.linspace(0, t_final, T)[sim.t_idx], sim.x[:, 0], alpha = 0.1, color = \"red\")\n",
    "plt.xlabel(\"t\"); plt.ylabel(\"dim 1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "perceived-limitation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of model particles per timepoint (X0 below has shape (T, M, dim))\n",
    "M = 100\n",
    "\n",
    "# simulated particles and their timepoint indices as tensors on the compute device\n",
    "X_obs = torch.tensor(sim.x, device = device)\n",
    "t_idx_obs = torch.tensor(sim.t_idx, device = device)\n",
    "\n",
    "# initial particle positions: standard normal draws scaled by 0.1\n",
    "X0 = 0.1*torch.randn(T, M, dim)\n",
    "\n",
    "def R_func(x):\n",
    "    \"\"\"Torch counterpart of the birth rate, evaluated on column 0 of a batch x.\"\"\"\n",
    "    return R*((torch.tanh(2*x[:, 0]) + 1)/2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fancy-lyric",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(x):\n",
    "    \"\"\"Divide x by the sum of its entries.\"\"\"\n",
    "    total = x.sum()\n",
    "    return x/total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "exact-pulse",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trajectory loss that is given the branching rate R_func\n",
    "# (`models` is the module bound by the last `import models` in the import cell)\n",
    "model = models.TrajLoss(\n",
    "    X0.clone(), X_obs, t_idx_obs,\n",
    "    M = M,\n",
    "    dt = t_final/T,\n",
    "    tau = D,\n",
    "    sigma = None,\n",
    "    lamda_reg = 0.025,\n",
    "    lamda_cst = 0,\n",
    "    sigma_cst = float(\"Inf\"),\n",
    "    lamda_unbal = None,\n",
    "    branching_rate_fn = R_func,\n",
    "    sinkhorn_iters = 250,\n",
    "    warm_start = True,\n",
    "    device = device,\n",
    ")\n",
    "# disabled: double the first and last entries of model.w, then renormalise\n",
    "# model.w[0] *= 2\n",
    "# model.w[-1] *= 2\n",
    "# model.w = normalize(model.w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "complimentary-copying",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the optimiser on the branching model (trailing semicolon suppresses the cell output)\n",
    "output = models.optimize(\n",
    "    model,\n",
    "    n_iter = 2_500,\n",
    "    eta_final = 0.1, tau_final = D, sigma_final = 0.5,\n",
    "    temp_init = 1.0, temp_ratio = 1.0,\n",
    "    N = M, dim = dim,\n",
    "    tloss = model,\n",
    "    print_interval = 50,\n",
    ");\n",
    "# output2 = models.optimize(model, n_iter = 250, tau_final = 0.5, sigma_final = np.sqrt(D), eta_final = 0.5, temp_range = (1, 1), N = M, dim = dim, tloss = model, print_interval = 50);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "thrown-distributor",
   "metadata": {},
   "outputs": [],
   "source": [
    "def R_func_null(x):\n",
    "    \"\"\"Null branching rate: the rate expression of R_func multiplied by zero.\"\"\"\n",
    "    return 0*((torch.tanh(2*x[:, 0]) + 1)/2)\n",
    "\n",
    "# same settings as the branching model, but with the zero branching rate\n",
    "model_null = models.TrajLoss(\n",
    "    X0.clone(), X_obs, t_idx_obs,\n",
    "    M = M,\n",
    "    dt = t_final/T,\n",
    "    tau = D,\n",
    "    sigma = None,\n",
    "    lamda_reg = 0.025,\n",
    "    lamda_cst = 0,\n",
    "    sigma_cst = float(\"Inf\"),\n",
    "    lamda_unbal = None,\n",
    "    branching_rate_fn = R_func_null,\n",
    "    sinkhorn_iters = 250,\n",
    "    warm_start = True,\n",
    "    device = device,\n",
    ")\n",
    "\n",
    "# disabled: double the first and last entries of model_null.w, then renormalise\n",
    "# model_null.w[0] *= 2\n",
    "# model_null.w[-1] *= 2\n",
    "# model_null.w = normalize(model_null.w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fundamental-small",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the optimiser on the null model with the same schedule as the branching model\n",
    "# tloss must be the model being optimised; it previously pointed at `model` (copy-paste slip)\n",
    "output_null = models.optimize(\n",
    "    model_null,\n",
    "    n_iter = 2_500,\n",
    "    eta_final = 0.1, tau_final = D, sigma_final = 0.5,\n",
    "    temp_init = 1.0, temp_ratio = 1.0,\n",
    "    N = M, dim = dim,\n",
    "    tloss = model_null,\n",
    "    print_interval = 50,\n",
    ");\n",
    "# output2_null = models.optimize(model_null, n_iter = 250, tau_final = 0.5, sigma_final = np.sqrt(D), eta_final = 0.5, temp_range = (1, 1), N = M, dim = dim, tloss = model_null, print_interval = 50);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "colonial-tours",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled experiment with an unbalanced model. The keyword names used here (sigma2, eta, lamda,\n",
    "# eta_cst, branching_rate, n_sinkhorn_iter, unbalanced_lamda) differ from those passed to models.TrajLoss in the\n",
    "# active cells, presumably an older signature; confirm before re-enabling.\n",
    "# The last three lines modify `model`, not `model_unbal`.\n",
    "# model_unbal = models.TrajLoss(X0.clone(), X_obs, t_idx_obs, dt = t_final/T, sigma2 = D, \n",
    "#                         eta = None, M = M, lamda = 0.025, lamda_cst = 0, eta_cst = 5, \n",
    "#                         branching_rate = R_func,\n",
    "#                         n_sinkhorn_iter = 250, device = device, warm_start = True, \n",
    "#                         unbalanced_lamda = 10.0)\n",
    "# model.w[0] *= 2\n",
    "# model.w[-1] *= 2\n",
    "# model.w = normalize(model.w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "weighted-halifax",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled. If re-enabled as written this would overwrite `output` from the branching model,\n",
    "# and it passes temp_range rather than the temp_init/temp_ratio keywords used by the active optimize calls.\n",
    "# output = models.optimize(model_unbal, n_iter = 2_500, tau_final = 0.1, sigma_final = np.sqrt(D), eta_final = 0.5, temp_range = (1, 1), N = M, dim = dim, tloss = model_unbal, print_interval = 50);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "divine-deputy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# second element of the optimize output for the branching model, shown as the objective per the title\n",
    "# plt.subplot(1, 2, 1)\n",
    "plt.gca().plot(output[1])\n",
    "plt.title(\"Objective: primal (annealing)\")\n",
    "# plt.subplot(1, 2, 2)\n",
    "# plt.plot(output2[1])\n",
    "# plt.title(\"Objective: primal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "southeast-multiple",
   "metadata": {},
   "outputs": [],
   "source": [
    "# second element of the optimize output for the null model, shown as the objective per the title\n",
    "# plt.subplot(1, 2, 1)\n",
    "plt.gca().plot(output_null[1])\n",
    "plt.title(\"Objective: primal (annealing)\")\n",
    "# plt.subplot(1, 2, 2)\n",
    "# plt.plot(output2_null[1])\n",
    "# plt.title(\"Objective: primal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "legitimate-sample",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unit vector along the first coordinate axis; the plots below project points onto it\n",
    "u = np.zeros(dim)\n",
    "u[0] = 1\n",
    "u = u/np.linalg.norm(u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "square-latvia",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Particles of both models projected onto u, plotted against time and coloured by timepoint index.\n",
    "# The particle tensors are detached and copied to host memory before NumPy touches them:\n",
    "# np.dot cannot consume a tensor that lives on a CUDA device or that requires grad.\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.scatter(np.kron(np.linspace(0, t_final, T), np.ones(M)), np.dot(model.x.detach().cpu().numpy().reshape(-1, dim), u), c = np.kron(np.arange(T), np.ones(M)), alpha = 0.25)\n",
    "plt.ylim(-2.5, 2.5)\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.scatter(np.kron(np.linspace(0, t_final, T), np.ones(M)), np.dot(model_null.x.detach().cpu().numpy().reshape(-1, dim), u), c = np.kron(np.arange(T), np.ones(M)), alpha = 0.25)\n",
    "plt.ylim(-2.5, 2.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blind-deviation",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rownorm(x):\n",
    "    \"\"\"Divide every row of x by its row sum.\"\"\"\n",
    "    return x/x.sum(1).reshape(-1, 1)\n",
    "\n",
    "def _row_normalised_coupling(traj_model):\n",
    "    \"\"\"Return a function i -> row-normalised coupling of traj_model.loss_reg.ot_losses[i], moved to the CPU.\"\"\"\n",
    "    return lambda i: rownorm(traj_model.loss_reg.ot_losses[i].coupling().cpu())\n",
    "\n",
    "# number of paths to sample from each model\n",
    "N_paths = 250\n",
    "with torch.no_grad():\n",
    "    paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = model.x.cpu().numpy(),\n",
    "                            get_gamma_fn = _row_normalised_coupling(model), num_couplings = T-1)\n",
    "    paths_null = bs.sample_paths(None, N = N_paths, coord = True, x_all = model_null.x.cpu().numpy(),\n",
    "                                 get_gamma_fn = _row_normalised_coupling(model_null), num_couplings = T-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "behind-philosophy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw N_paths trajectories from the simulation, with the same steps_scale as the sampling call\n",
    "paths_gt = sim.sample_trajectory(steps_scale = int(sim_steps/sim.T), N = N_paths)\n",
    "# fraction of ground-truth paths whose first coordinate is positive at the last time index\n",
    "np.mean(paths_gt[:, -1, 0] > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "brutal-vanilla",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fraction of paths sampled from the branching model whose first coordinate is positive at the last time index\n",
    "np.mean(paths[:, -1, 0] > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "formed-character",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fraction of paths sampled from the null model whose first coordinate is positive at the last time index\n",
    "np.mean(paths_null[:, -1, 0] > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "collected-source",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (2*PLT_CELL, 1*PLT_CELL))\n",
    "\n",
    "# left panel: ground-truth trajectories projected onto u, annotated with the fraction ending above / below zero\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(np.linspace(0, t_final, paths_gt.shape[1]), np.dot(paths_gt, u).T, color = 'k', alpha = 0.025);\n",
    "# plt.scatter(np.linspace(0, t_final, T)[sim.t_idx], np.dot(sim.x, u).T, alpha = 0.25, c = sim.t_idx, marker = \".\")\n",
    "plt.text(0.4, 1, \"%0.2f\" % np.mean(paths_gt[:, -1, 0] > 0))\n",
    "plt.text(0.4, -1, \"%0.2f\" % (1-np.mean(paths_gt[:, -1, 0] > 0)))\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"$x_0$\")\n",
    "plt.title(\"Ground truth paths\")\n",
    "\n",
    "# right panel: the sampled particles, coloured by the branching rate R_func at each particle\n",
    "plt.subplot(1, 2, 2)\n",
    "# plt.scatter(np.linspace(0, t_final, T)[sim.t_idx], np.dot(sim.x, u).T, alpha = 0.25, c = sim.t_idx, marker = \".\")\n",
    "im = plt.scatter(np.linspace(0, t_final, T)[sim.t_idx], np.dot(sim.x, u).T, alpha = 0.25, c = R_func(torch.tensor(sim.x)), marker = \".\", cmap = \"magma\", vmin = 0, vmax = 12.5)\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlabel(\"t\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "plt.title(\"Sample\")\n",
    "plt.tight_layout()\n",
    "\n",
    "# reserve the right margin for a colorbar placed on its own axes\n",
    "fig.subplots_adjust(right=0.85)\n",
    "cbar_ax = fig.add_axes([0.875, 0.225, 0.025, 0.625])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# draw the colorbar opaque even though the scatter points are translucent\n",
    "cb.set_alpha(1)\n",
    "# Colorbar.draw_all() is absent from recent Matplotlib releases; force the redraw only where it exists\n",
    "if hasattr(cb, \"draw_all\"):\n",
    "    cb.draw_all()\n",
    "cbar_ax.set_title(\"$g$\")\n",
    "\n",
    "plt.savefig(\"fig3_growth_a.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blessed-ancient",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start a fresh figure so that `fig` below is the figure being drawn and saved here;\n",
    "# without it the colorbar and margin adjustment were applied to whatever `fig` was left over from an earlier cell\n",
    "fig = plt.figure(figsize = (2*PLT_CELL, 1*PLT_CELL))\n",
    "\n",
    "# left panel: paths and particles of the branching model, projected onto u\n",
    "# particle tensors are detached and copied to host memory before NumPy touches them (np.dot cannot read CUDA tensors)\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(np.linspace(0, t_final, paths.shape[1]), np.dot(paths, u).T, color = 'grey', alpha = 0.025);\n",
    "im = plt.scatter(np.kron(np.linspace(0, t_final, T), np.ones(M)), np.dot(model.x.detach().cpu().numpy().reshape(-1, dim), u), c = np.kron(np.linspace(0, t_final, T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "plt.text(0.4, 1, \"%0.2f\" % np.mean(paths[:, -1, 0] > 0))\n",
    "plt.text(0.4, -1, \"%0.2f\" % (1-np.mean(paths[:, -1, 0] > 0)))\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"$x_0$\")\n",
    "plt.title(\"MFL + Branching\")\n",
    "\n",
    "# right panel: the same for the null model (zero branching rate)\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(np.linspace(0, t_final, paths_null.shape[1]), np.dot(paths_null, u).T, color = 'grey', alpha = 0.025);\n",
    "plt.scatter(np.kron(np.linspace(0, t_final, T), np.ones(M)), np.dot(model_null.x.detach().cpu().numpy().reshape(-1, dim), u), c = np.kron(np.linspace(0, t_final, T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "plt.text(0.4, 1, \"%0.2f\" % np.mean(paths_null[:, -1, 0] > 0))\n",
    "plt.text(0.4, -1, \"%0.2f\" % (1-np.mean(paths_null[:, -1, 0] > 0)))\n",
    "plt.ylim(-2, 2)\n",
    "plt.xlabel(\"t\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "plt.title(\"MFL\")\n",
    "plt.tight_layout()\n",
    "\n",
    "# reserve the right margin for a colorbar (time of each particle) placed on its own axes\n",
    "fig.subplots_adjust(right=0.85)\n",
    "cbar_ax = fig.add_axes([0.875, 0.225, 0.025, 0.625])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# draw the colorbar opaque even though the scatter points are translucent\n",
    "cb.set_alpha(1)\n",
    "# Colorbar.draw_all() is absent from recent Matplotlib releases; force the redraw only where it exists\n",
    "if hasattr(cb, \"draw_all\"):\n",
    "    cb.draw_all()\n",
    "cbar_ax.set_title(\"$t$\")\n",
    "\n",
    "\n",
    "plt.savefig(\"fig3_growth_b.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "metropolitan-vanilla",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
