{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0516fbcc-094f-4f65-a159-fd99e676b063",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34700fb1-b32b-47e8-9da7-c02458bc8789",
   "metadata": {},
   "source": [
    "# 0. Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e280d08-e0b6-4aca-bd2e-5a02f20fc2f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "yend = 20\n",
    "\n",
    "np.random.seed(0)\n",
    "# start q(x0)\n",
    "p = .5\n",
    "z_id0 = np.random.binomial(1, p, N)[:,None]\n",
    "x0 = z_id0*np.random.multivariate_normal([-4, 0], [[1, 0], [0, 1]], N) +\\\n",
    "(1-z_id0)*np.random.multivariate_normal([4, 0], [[1, 0], [0, 1]], N)\n",
    "\n",
    "# intermediate points\n",
    "x_05 = z_id0*np.random.multivariate_normal([3, yend/2], [[1, 0], [0, 1]], N) +\\\n",
    "(1-z_id0)*np.random.multivariate_normal([-3, yend/2], [[1, 0], [0, 1]], N)\n",
    "\n",
    "# z_id1 = np.random.binomial(1, p, N)[:,None]\n",
    "x1 = z_id0*np.random.multivariate_normal([-4, yend], [[1, 0], [0, 1]], N) +\\\n",
    "(1-z_id0)*np.random.multivariate_normal([4, yend], [[1, 0], [0, 1]], N)\n",
    "\n",
    "x0 = torch.from_numpy(x0).to(torch.float32)\n",
    "x1 = torch.from_numpy(x1).to(torch.float32)\n",
    "x_05 = torch.from_numpy(x_05).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a4f611-c886-4228-8860-f8de597efa77",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['svg.fonttype'] = 'none'\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [4, 3]\n",
    "\n",
    "plt.scatter(x0[:,0], x0[:,1], s = 4, c = \"black\");\n",
    "plt.scatter(x_05[:,0], x_05[:,1], s = 4, c = \"red\");\n",
    "plt.scatter(x1[:,0], x1[:,1], s= 4, c = \"orange\");\n",
    "for ii in range(N):\n",
    "    xx_tmp = torch.stack((x0[ii,0], x_05[ii,0]))\n",
    "    yy_tmp = torch.stack((x0[ii,1], x_05[ii,1]))\n",
    "    plt.plot(xx_tmp, yy_tmp, c = 'black', alpha = 0.2, linestyle='dashed')\n",
    "    \n",
    "    xx_tmp = torch.stack((x_05[ii,0], x1[ii,0]))\n",
    "    yy_tmp = torch.stack((x_05[ii,1], x1[ii,1]))\n",
    "    plt.plot(xx_tmp, yy_tmp, c = 'black', alpha = 0.1, linestyle='dashed')\n",
    "    \n",
    "plt.plot()\n",
    "plt.xlabel(\"x\")\n",
    "plt.ylabel(\"y\")\n",
    "plt.xlim([-8, 8]);\n",
    "plt.ylim([-6, 26]);\n",
    "# plt.savefig(\"1_sim_samp.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6ec3141-2dfa-4a07-9cb8-3cf0a316a4d2",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7c925e-f3e3-4b4e-9be2-e43fa7d9493e",
   "metadata": {},
   "source": [
    "## 1.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e6c4e7-dedb-4365-90c2-bdb4f98de940",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        if out_dim is None:\n",
    "            out_dim = dim\n",
    "        self.net = torch.nn.Sequential(\n",
    "            torch.nn.Linear(dim + (1 if time_varying else 0), w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12c65196-3560-458a-9032-9fd43ccafc5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56d62ffe-89bf-4481-b404-dd1a3911a5be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None):\n",
    "    \n",
    "    node = NeuralODE(torch_wrapper(model), solver=\"dopri5\",\n",
    "                 sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    if x_start is None:\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))\n",
    "        \n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76c66789-b47c-4d0b-996d-9c502527f5db",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_traj(traj, nt_gen, mid_pts = True, start_color = \"black\", end_color = \"orange\"):\n",
    "    plt.scatter(traj[0, :, 0], traj[0, :, 1], s=4, alpha=1, c=start_color)\n",
    "    if mid_pts:\n",
    "        plt.scatter(traj[int(nt_gen/2), :, 0], traj[int(nt_gen/2), :, 1], s=4, alpha=1, c=\"red\")\n",
    "    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c=\"blue\")\n",
    "    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], s=4, alpha=1, c=end_color)\n",
    "    \n",
    "    if mid_pts:\n",
    "        plt.legend([\"x0\", \"x_05\", \"Flow\", \"x1\"])\n",
    "    else:\n",
    "        plt.legend([\"x0\", \"Flow\", \"x1\"])\n",
    "        \n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9d19c1-13ca-4cbd-8dea-86d94daccec4",
   "metadata": {},
   "source": [
    "## 1.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d1e3db6-8b80-4ca3-801d-deed09ae807d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    r = ti[...,None] - tj[...,None,:]\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "def k11(r, alpha, l):\n",
    "    return (alpha**2)*torch.exp(-0.5 * ((r/l)**2))\n",
    "def k12(r, alpha, l):\n",
    "    return (alpha**2/l**2)*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    return (alpha**2/l**4)*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f35c562-3dc0-440c-b2f0-84b62a8ab0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    r = calc_r(ti, tj)\n",
    "    nt = r.shape[0]\n",
    "    \n",
    "    Sig11 = k11(r, alpha, l) + torch.eye(nt)*sig2_diag\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig21 = Sig12.T\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "    \n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=1)\n",
    "    block_row2 = torch.cat([Sig21, Sig22], dim=1)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim=0)\n",
    "    Sig = (Sig + Sig.T)/2\n",
    "    \n",
    "    return Sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7f9418-b56a-482b-bffc-f0467ad8e98e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx(t, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    \n",
    "    nB, nt, dim = x_obs.shape[0], t.shape[0], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "    \n",
    "    r_obs_x = calc_r(t_obs, t)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    Sig_11 = cov_mat(t, t, alpha, l, sig2_diag)\n",
    "    \n",
    "    k_obs_x, k_obs_dx = k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=1)\n",
    "    Sig_12 = Sig_21.T\n",
    "    \n",
    "    Sig_22 = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs)*sig2_diag\n",
    "    Sig_22_inv = torch.linalg.inv(Sig_22)\n",
    "    Sig_cond = Sig_11 - Sig_12 @ Sig_22_inv @ Sig_21\n",
    "    \n",
    "    Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "    if not bool((torch.linalg.eigvals(Sig_cond).real>=0).all()):\n",
    "        U, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "        Sig_cond  = Vh.T @ torch.diag(S + 1e-6) @ Vh\n",
    "        Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "    \n",
    "    mu_A = Sig_12 @ Sig_22_inv\n",
    "    mu_A_expand = mu_A.repeat(nB,1,1)\n",
    "    \n",
    "    x_samps = torch.zeros((nB, nt, dim))\n",
    "    dx_samps = torch.zeros((nB, nt, dim))\n",
    "    \n",
    "    \n",
    "    # Prepare the batch sampling\n",
    "    x_obs_batch = x_obs.reshape(nB, nt_obs, dim)\n",
    "    mu_new = torch.bmm(mu_A_expand, x_obs_batch).reshape(nB * dim, 2 * nt)\n",
    "\n",
    "    # Repeat Sig_cond for batch and dimensions\n",
    "    Sig_cond_expand = Sig_cond.unsqueeze(0).expand(nB * dim, -1, -1)\n",
    "\n",
    "    try:\n",
    "        # Perform batched sampling for all dimensions at once\n",
    "        dist = MultivariateNormal(loc=mu_new, covariance_matrix=Sig_cond_expand)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, 2 * nt, dim)\n",
    "    except RuntimeError:\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb * dim + dd].cpu().numpy()\n",
    "                cov_single = Sig_cond.cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e4b4fe-ca7e-4706-9f1d-30a16bba8caa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FM(model, optimizer, x_data, alpha, l, nt, batch_size, t_obs, n_epochs, sig2_diag = 1e-8):\n",
    "    \n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "    \n",
    "    nbatch = int(N/batch_size)\n",
    "    batch_idx = np.reshape(np.arange(0,N),[nbatch, batch_size])\n",
    "    \n",
    "    losses: List[float] = []\n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "\n",
    "        for bb in range(nbatch):\n",
    "#             x0 = torch.randn((batch_size,dim))\n",
    "            x_obs = x_data[batch_idx[bb,:],:,:]\n",
    "#             x_obs[:,0,:] = x0\n",
    "\n",
    "            t_batch = torch.rand(nt)\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx(t_batch, alpha, l, x_obs, t_obs, sig2_diag)\n",
    "            except:\n",
    "                pass\n",
    "\n",
    "            t = t_batch.repeat(1,batch_size).T\n",
    "            xt = torch.reshape(xt_batch, (-1,dim))\n",
    "            ut = torch.reshape(ut_batch, (-1,dim))\n",
    "\n",
    "            vt = model(torch.cat([xt, t], dim=-1))\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8fed4cf-f323-4fcc-b375-398439da0d72",
   "metadata": {},
   "source": [
    "## 1.3 ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e1545a0-8ced-4ab5-a600-953681d09691",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_conditional_pt(x0, x1, t, sigma):\n",
    "    t = t.reshape(-1, *([1] * (x0.dim() - 1)))\n",
    "    mu_t = t * x1 + (1 - t) * x0\n",
    "    epsilon = torch.randn_like(x0)\n",
    "    return mu_t + sigma * epsilon\n",
    "\n",
    "def compute_conditional_vector_field(x0, x1):\n",
    "    return x1 - x0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bde7f36-859e-4358-83a4-db632a929e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def I_FM(x1, model, optimizer, sigma = 1e-1, n_epochs = 10000, x0 = None):\n",
    "    \n",
    "    losses: List[float] = []\n",
    "    \n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        if x0 is None:\n",
    "            x0 = torch.randn_like(x1)\n",
    "            \n",
    "        # x0, x1 = ot_sampler.sample_plan(x0, y_train)\n",
    "        # x1 = y_train\n",
    "        # x0_ot, x1_ot = ot_sampler.sample_plan(x0, x1)\n",
    "\n",
    "        t = torch.rand(x0.shape[0]).type_as(x0)\n",
    "        xt = sample_conditional_pt(x0, x1, t, sigma=sigma)\n",
    "        ut = compute_conditional_vector_field(x0, x1)\n",
    "        vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "        loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Logging\n",
    "        losses.append(loss.item())\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38078347-9869-421f-a5a3-150f8077d7bc",
   "metadata": {},
   "source": [
    "# 2. Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b44e4bf8-144c-4f52-bee0-cebc777dd452",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = x1.shape[1]\n",
    "x_data = torch.zeros(N, 3, dim)\n",
    "x_data[:,0,:] = x0\n",
    "x_data[:,1,:] = x_05\n",
    "x_data[:,2,:] = x1\n",
    "\n",
    "alpha = 2\n",
    "l = 1\n",
    "nt = 10\n",
    "batch_size = 20\n",
    "t_obs = torch.tensor([0, 0.5, 1])\n",
    "\n",
    "n_samp = 100\n",
    "nt_gen = 100\n",
    "seed = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "927e338e-0cdd-43b3-926d-84088bd2b40c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_icfm0 = MLP(dim = dim, out_dim = dim, time_varying=True)\n",
    "optimizer_icfm0 = torch.optim.Adam(model_icfm0.parameters(), lr=1e-3)\n",
    "model_icfm0,_ = I_FM(x0, model_icfm0, optimizer_icfm0, 0, n_epochs = 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb67fa2c-4d0c-4ff8-8f3f-107a8a86e155",
   "metadata": {},
   "outputs": [],
   "source": [
    "traj_start = gen_traj(model_icfm0, n_samp, 2, 6)\n",
    "x0_gen = traj_start[-1,:,:]\n",
    "\n",
    "# import pickle\n",
    "# with open(\"x0_gen\", \"wb\") as fp: pickle.dump(x0_gen, fp);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a036d72-562c-4af4-b6d6-c096c1f922c2",
   "metadata": {},
   "source": [
    "## 2.1 Unconditional GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cec427-e123-4969-9029-9863e001ac8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_1_10000 = MLP(dim=dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model_1_10000.parameters(), lr=1e-3)\n",
    "model_1_10000, losses_1_10000 = GP_FM(model_1_10000, optimizer, x_data,\n",
    "                                      1, 5, nt, batch_size, t_obs, 10000, sig2_diag = 1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee38dad-31c5-42eb-97ad-986e2f94bb4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "traj_1_10000 = gen_traj(model_1_10000, n_samp, nt_gen, 1, x0_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ee0fa3-38cc-4f3b-a502-c13fb71804dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_traj(traj_1_10000, nt_gen)\n",
    "plt.xlim([-8, 8]);\n",
    "plt.ylim([-6, 26]);\n",
    "plt.savefig(\"2_GP_path_un.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8b79b4e-ca71-488b-a25f-f2874fd51a50",
   "metadata": {},
   "source": [
    "## 2.2 Conditional GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebe2027-8504-4a7c-9e15-4de2dd56bc4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_1_10000 = MLP(dim=dim + dim, out_dim = dim, time_varying=True)\n",
    "optimizer = torch.optim.Adam(model_1_10000.parameters(), lr=1e-3)\n",
    "\n",
    "model_1_10000, losses_1_10000 = GP_FM(model_1_10000, optimizer, x_data, 2, 1, nt,\n",
    "                                      batch_size, t_obs, 10000, sig2_diag = 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab1287f-3d58-4b6e-b2d5-43d9ea8bf04c",
   "metadata": {},
   "outputs": [],
   "source": [
    "traj_1_10000 = gen_traj(model_1_10000, n_samp, nt_gen, 1, x0_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc6761fb-242c-4182-ae64-5612208b1880",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['svg.fonttype'] = 'none'\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "plt.rcParams['figure.figsize'] = [4, 3]\n",
    "\n",
    "plot_traj(traj_1_10000, nt_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03ebb455-9de4-4051-9462-2ba1dcc9b7d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
