{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "285f37f4-4074-4f1f-bd9e-d8622344ddc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ca61cea-00dc-4795-b86f-ca4e0846a6e2",
   "metadata": {},
   "source": [
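    "# 0. Data Generation\n",
    "\n",
    "Toy 2-D problem. The source is $x_0 \\sim \\mathcal{N}(0, I)$; the target $x_1$ is an equal-weight mixture of $\\mathcal{N}((-6, 20), I)$ and $\\mathcal{N}((6, 20), I)$. Each target point is paired with an intermediate observation $x_{0.5}$ (used at $t = 0.5$) whose $y$-coordinate is about 10 and whose $x$-coordinate is about $0.5\\,x_{1,x} - 10$, i.e. to the left of its endpoint."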
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d611ce9b-1aba-49e0-9357-251b2798097f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy data: N paired 2-D samples at t = 0 (x0), t = 0.5 (x_05) and t = 1 (x1).\n",
    "N = 100  # number of samples drawn for each distribution\n",
    "yend = 20  # y-coordinate of both target-mixture means\n",
    "\n",
    "np.random.seed(0)\n",
    "# start q(x0): standard 2-D Gaussian\n",
    "p = .5  # probability of picking the left (x = -6) target component\n",
    "x0 = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], N)\n",
    "\n",
    "# end q(x1): mixture of N((-6, yend), I) and N((6, yend), I).\n",
    "# Both components are sampled for every row and z_id1 selects one; the data depend\n",
    "# on the order of the seeded draws, so keep these calls in this order.\n",
    "z_id1 = np.random.binomial(1, p, N)[:,None]\n",
    "x1 = z_id1*np.random.multivariate_normal([-6, yend], [[1, 0], [0, 1]], N) +\\\n",
    "(1-z_id1)*np.random.multivariate_normal([6, yend], [[1, 0], [0, 1]], N)\n",
    "\n",
    "\n",
    "# intermediate points, paired row-by-row with x1:\n",
    "# y ~ N(yend/2, 1) and x = 0.5*x1_x - 10 + N(0, 0.5^2), i.e. left of the endpoint\n",
    "x_05 = np.zeros_like(x1)\n",
    "x_05[:,1] = np.random.normal(0,1,N) + yend/2\n",
    "x_05[:,0] = 0.5*x1[:,0] - 10 + np.random.normal(0,1,N)*0.5\n",
    "# float32 torch tensors for training\n",
    "x0 = torch.from_numpy(x0).to(torch.float32)\n",
    "x1 = torch.from_numpy(x1).to(torch.float32)\n",
    "x_05 = torch.from_numpy(x_05).to(torch.float32)\n",
    "\n",
    "# Optional figure of the x_05 -> x1 pairings (disabled; uncomment to write 1_sim_samp.svg)\n",
    "# plt.rcParams['svg.fonttype'] = 'none'\n",
    "# plt.rcParams['text.usetex'] = False\n",
    "# plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "# plt.rcParams['figure.figsize'] = [4, 3]\n",
    "# plt.scatter(x_05[:,0], x_05[:,1], s = 4, c = \"red\");\n",
    "# plt.scatter(x1[:,0], x1[:,1], s= 4, c = \"orange\");\n",
    "# for ii in range(100):\n",
    "#     xx_tmp = torch.stack((x_05[ii,0], x1[ii,0]))\n",
    "#     yy_tmp = torch.stack((x_05[ii,1], x1[ii,1]))\n",
    "#     plt.plot(xx_tmp, yy_tmp, c = 'black', alpha = 0.2, linestyle='dashed')\n",
    "# plt.plot()\n",
    "# plt.xlabel(\"x\")\n",
    "# plt.ylabel(\"y\")\n",
    "# plt.xlim([-16, 9]);\n",
    "# plt.ylim([-5, 25]);\n",
    "# plt.savefig(\"1_sim_samp.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3498fef8-db05-4b3e-a9c9-4c7a422c5996",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2acf82ab-5ca8-4920-96f1-20cc0c0889b8",
   "metadata": {},
   "source": [
    "## 1.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "551feaf1-73fb-4c0a-8979-bd9ef235ea5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        if out_dim is None:\n",
    "            out_dim = dim\n",
    "        self.net = torch.nn.Sequential(\n",
    "            torch.nn.Linear(dim + (1 if time_varying else 0), w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "315b5373-e8ec-425e-b807-970c62aa272c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "453e5373-2dd1-4bb4-8d7e-2bbc15bce7f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None, dim = None):\n",
    "    \"\"\"Integrate the learned vector field from t = 0 to t = 1 with torchdyn.\n",
    "\n",
    "    Args:\n",
    "        model: network evaluated on rows ``[x, t]`` (wrapped by ``torch_wrapper``).\n",
    "        n_samp: number of starting points drawn when ``x_start`` is None.\n",
    "        nt_gen: number of evenly spaced output times in [0, 1].\n",
    "        seed: torch seed, used only when the starting points are drawn here.\n",
    "        x_start: optional starting points; when given, ``n_samp`` and ``seed``\n",
    "            are ignored.\n",
    "        dim: state dimension of the random start. If None it is inferred from\n",
    "            the first ``torch.nn.Linear`` layer of ``model`` (its input width\n",
    "            minus the time column that ``torch_wrapper`` appends).\n",
    "\n",
    "    Returns:\n",
    "        The ``NeuralODE.trajectory`` output, indexed as traj[time, sample, coord].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the start must be drawn, ``dim`` is None and ``model``\n",
    "            has no ``torch.nn.Linear`` layer to infer it from.\n",
    "    \"\"\"\n",
    "    node = NeuralODE(torch_wrapper(model), solver=\"dopri5\",\n",
    "                     sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    if x_start is None:\n",
    "        if dim is None:\n",
    "            # Infer the state size instead of reading a notebook-global ``dim``,\n",
    "            # which does not exist yet on a fresh kernel.\n",
    "            linears = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]\n",
    "            if not linears:\n",
    "                raise ValueError(\"gen_traj: pass `dim` or `x_start`; cannot infer dim from model\")\n",
    "            dim = linears[0].in_features - 1\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))\n",
    "\n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5febff6-000e-41bc-8341-1754bdc51577",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_traj(traj, nt_gen, mid_pts = True, start_color = \"black\", end_color = \"orange\"):\n",
    "    \"\"\"Scatter a generated trajectory on the current matplotlib axes.\n",
    "\n",
    "    Args:\n",
    "        traj: trajectory indexed as traj[time, sample, coord]; coords 0/1 are x/y.\n",
    "        nt_gen: number of time points used to generate ``traj``. Kept for backward\n",
    "            compatibility only: the midpoint index is taken from ``traj`` itself,\n",
    "            so a mismatched value (e.g. for a half-length trajectory) cannot\n",
    "            index out of range.\n",
    "        mid_pts: also highlight the middle time slice in red.\n",
    "        start_color, end_color: colours of the first and last time slices.\n",
    "    \"\"\"\n",
    "    mid = traj.shape[0] // 2  # same as int(nt_gen / 2) when nt_gen == len(traj)\n",
    "    plt.scatter(traj[0, :, 0], traj[0, :, 1], s=4, alpha=1, c=start_color)\n",
    "    if mid_pts:\n",
    "        plt.scatter(traj[mid, :, 0], traj[mid, :, 1], s=4, alpha=1, c=\"red\")\n",
    "    # all time slices, drawn faintly\n",
    "    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c=\"blue\")\n",
    "    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], s=4, alpha=1, c=end_color)\n",
    "\n",
    "    # Labels follow the order in which the scatters were added above.\n",
    "    if mid_pts:\n",
    "        plt.legend([\"x0\", \"x_05\", \"Flow\", \"x1\"])\n",
    "    else:\n",
    "        plt.legend([\"x0\", \"Flow\", \"x1\"])\n",
    "\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "616fdd64-e403-4cc8-9f16-129a11f58bbb",
   "metadata": {},
   "source": [
    "## 1.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3774c383-a152-4a30-ad94-a40507dbb3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    r = ti[...,None] - tj[...,None,:]\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "def k11(r, alpha, l):\n",
    "    return (alpha**2)*torch.exp(-0.5 * ((r/l)**2))\n",
    "def k12(r, alpha, l):\n",
    "    return (alpha**2/l**2)*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    return (alpha**2/l**4)*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d8af2f-5ebf-4325-abfa-4abff714ea2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \"\"\"Joint prior covariance of a GP and its time derivative.\n",
    "\n",
    "    Builds the block matrix [[k11, k12], [k12.T, k22]] on the differences\n",
    "    ``ti - tj`` and adds ``sig2_diag`` to the diagonal of the value-value block.\n",
    "    ``ti`` and ``tj`` must have the same length (the jitter uses a square\n",
    "    identity and the lower-left block is the transpose of the k12 block).\n",
    "    \"\"\"\n",
    "    r = calc_r(ti, tj)\n",
    "    jitter = torch.eye(r.shape[0])*sig2_diag\n",
    "    k_xx = k11(r, alpha, l) + jitter\n",
    "    k_xdx = k12(r, alpha, l)\n",
    "    k_dxdx = k22(r, alpha, l)\n",
    "\n",
    "    upper = torch.cat([k_xx, k_xdx], dim=1)\n",
    "    lower = torch.cat([k_xdx.T, k_dxdx], dim=1)\n",
    "    full = torch.cat([upper, lower], dim=0)\n",
    "    # Average with the transpose to enforce exact symmetry.\n",
    "    return (full + full.T)/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adbae751-707d-41fb-a5ca-6d914fb0a5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx(t, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    \"\"\"Draw GP positions and velocities at times ``t`` conditioned on observations.\n",
    "\n",
    "    Every trajectory ``b`` and coordinate ``d`` is an independent GP with kernel\n",
    "    ``k11``. The joint vector [x(t); dx/dt(t)] (length 2*nt) is sampled from the\n",
    "    GP posterior given x(t_obs) = x_obs[b, :, d].\n",
    "\n",
    "    Args:\n",
    "        t: 1-D tensor of nt query times.\n",
    "        alpha, l: kernel amplitude and length scale.\n",
    "        x_obs: (nB, nt_obs, dim) tensor of observed positions.\n",
    "        t_obs: 1-D tensor of the nt_obs observation times.\n",
    "        sig2_diag: diagonal jitter added to the value-value covariance blocks.\n",
    "\n",
    "    Returns:\n",
    "        (x_samps, dx_samps), each of shape (nB, nt, dim).\n",
    "    \"\"\"\n",
    "    nB, nt, dim = x_obs.shape[0], t.shape[0], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    # Prior covariance of [x(t); dx(t)] and its cross-covariance with x(t_obs).\n",
    "    Sig_11 = cov_mat(t, t, alpha, l, sig2_diag)\n",
    "    r_obs_x = calc_r(t_obs, t)\n",
    "    Sig_21 = torch.cat([k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)], dim=1)  # (nt_obs, 2*nt)\n",
    "    Sig_12 = Sig_21.T\n",
    "    Sig_22 = k11(calc_r(t_obs, t_obs), alpha, l) + torch.eye(nt_obs)*sig2_diag\n",
    "\n",
    "    # Gaussian conditioning. Solve against Sig_22 instead of forming its inverse\n",
    "    # (numerically better); Sig_22 is symmetric, so A.T == Sig_12 @ inv(Sig_22).\n",
    "    A = torch.linalg.solve(Sig_22, Sig_21)\n",
    "    mu_A = A.T  # (2*nt, nt_obs)\n",
    "    Sig_cond = Sig_11 - Sig_12 @ A\n",
    "    Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "    # Round-off can leave tiny negative eigenvalues; rebuild a PSD matrix with jitter.\n",
    "    if not bool((torch.linalg.eigvalsh(Sig_cond) >= 0).all()):\n",
    "        _, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "        Sig_cond = Vh.T @ torch.diag(S + 1e-6) @ Vh\n",
    "        Sig_cond = (Sig_cond + Sig_cond.T)/2\n",
    "\n",
    "    # Posterior means, (nB, 2*nt, dim). Move ``dim`` in front of the time axis\n",
    "    # before flattening so that row b*dim + d holds the 2*nt means of trajectory b,\n",
    "    # coordinate d. Reshaping (nB, 2*nt, dim) directly would mix time points and\n",
    "    # coordinates within a row and give the samples the wrong covariance.\n",
    "    mu = mu_A @ x_obs\n",
    "    mu_new = mu.transpose(1, 2).reshape(nB * dim, 2 * nt)\n",
    "\n",
    "    try:\n",
    "        # The (2*nt, 2*nt) covariance broadcasts over the nB*dim rows, so its\n",
    "        # Cholesky factor is computed only once.\n",
    "        dist = MultivariateNormal(loc=mu_new, covariance_matrix=Sig_cond)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, dim, 2 * nt).transpose(1, 2)\n",
    "    except (RuntimeError, ValueError):\n",
    "        # RuntimeError: Cholesky failed; ValueError: argument validation rejected\n",
    "        # the covariance as not positive definite.\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        cov_single = Sig_cond.detach().cpu().numpy()\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb * dim + dd].detach().cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # First nt entries along axis 1 are positions, the last nt are velocities.\n",
    "    x_samps = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7831ecbb-a7ef-4122-8897-96f02f73328c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FM(x_data, alpha, l, nt, batch_size, t_obs, n_epochs, sig2_diag = 0):\n",
    "    \"\"\"Train a time-dependent MLP vector field on GP-interpolated paths (GP-ICFM).\n",
    "\n",
    "    For each mini-batch the first observation of every trajectory is replaced by\n",
    "    fresh N(0, I) noise, ``nt`` times are drawn uniformly in [0, 1), positions and\n",
    "    velocities at those times are sampled with ``samp_x_dx`` conditioned on the\n",
    "    observations at ``t_obs``, and the network is regressed onto the velocities.\n",
    "\n",
    "    Args:\n",
    "        x_data: (N, len(t_obs), dim) tensor of observed trajectories.\n",
    "        alpha, l: GP kernel amplitude and length scale.\n",
    "        nt: number of random times per mini-batch.\n",
    "        batch_size: trajectories per mini-batch. Batches are fixed consecutive\n",
    "            blocks of rows; the last ``N % batch_size`` rows are not used.\n",
    "        t_obs: observation times matching axis 1 of ``x_data``.\n",
    "        n_epochs: number of passes over all mini-batches.\n",
    "        sig2_diag: diagonal jitter forwarded to ``samp_x_dx``.\n",
    "\n",
    "    Returns:\n",
    "        (model, losses): the trained ``MLP`` and the loss of every optimizer step.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``batch_size`` is larger than the number of trajectories.\n",
    "    \"\"\"\n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "\n",
    "    nbatch = N // batch_size\n",
    "    if nbatch == 0:\n",
    "        raise ValueError(f\"batch_size={batch_size} exceeds the number of trajectories N={N}\")\n",
    "    # Index only the rows that fill whole batches; reshaping all N indices\n",
    "    # fails whenever N is not a multiple of batch_size.\n",
    "    batch_idx = np.arange(nbatch * batch_size).reshape(nbatch, batch_size)\n",
    "\n",
    "    model = MLP(dim=dim, time_varying=True)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "    losses: List[float] = []\n",
    "    model.train()\n",
    "    for _ in tqdm(range(n_epochs)):\n",
    "\n",
    "        for bb in range(nbatch):\n",
    "            # Indexing with an index array returns a copy, so overwriting the first\n",
    "            # observation with source noise leaves x_data untouched.\n",
    "            x0 = torch.randn((batch_size, dim))\n",
    "            x_obs = x_data[batch_idx[bb, :], :, :]\n",
    "            x_obs[:, 0, :] = x0\n",
    "\n",
    "            t_batch = torch.rand(nt)\n",
    "            xt_batch, ut_batch = samp_x_dx(t_batch, alpha, l, x_obs, t_obs, sig2_diag)\n",
    "\n",
    "            # Flattening (batch, nt, dim) row-major puts time t_batch[i % nt] in row i,\n",
    "            # which is exactly t_batch tiled batch_size times.\n",
    "            t = t_batch.repeat(1, batch_size).T\n",
    "            xt = torch.reshape(xt_batch, (-1, dim))\n",
    "            ut = torch.reshape(ut_batch, (-1, dim))\n",
    "\n",
    "            vt = model(torch.cat([xt, t], dim=-1))\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            losses.append(loss.item())\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46530eeb-8b76-4c8e-b663-7d29a2e5e495",
   "metadata": {},
   "source": [
    "## 1.3 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e098db-9d63-49e7-9156-964a62becbcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_conditional_pt(x0, x1, t, sigma):\n",
    "    t = t.reshape(-1, *([1] * (x0.dim() - 1)))\n",
    "    mu_t = t * x1 + (1 - t) * x0\n",
    "    epsilon = torch.randn_like(x0)\n",
    "    return mu_t + sigma * epsilon\n",
    "\n",
    "def compute_conditional_vector_field(x0, x1):\n",
    "    return x1 - x0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eb54d75-de59-42e0-9d24-c3d62ce42917",
   "metadata": {},
   "outputs": [],
   "source": [
    "def I_FM(x1, model, optimizer, sigma = 1e-1, n_epochs = 10000, x0 = None):\n",
    "    \"\"\"Train ``model`` with independent conditional flow matching (I-CFM).\n",
    "\n",
    "    Each step draws one t ~ U[0, 1) per row, samples x_t around the straight\n",
    "    line between paired rows of the source and ``x1`` (noise scale ``sigma``),\n",
    "    and regresses ``model([x_t, t])`` onto the constant velocity ``x1 - x0``.\n",
    "\n",
    "    Args:\n",
    "        x1: target samples, one per row.\n",
    "        model: network taking ``[x, t]`` rows (state features plus one time column).\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        sigma: standard deviation of the Gaussian around the straight path.\n",
    "        n_epochs: number of full-batch gradient steps.\n",
    "        x0: optional fixed source samples paired row-wise with ``x1``. If None,\n",
    "            fresh N(0, I) noise is drawn at every step.\n",
    "\n",
    "    Returns:\n",
    "        (model, losses): the trained model and the loss of every step.\n",
    "    \"\"\"\n",
    "    losses: List[float] = []\n",
    "\n",
    "    model.train()\n",
    "    for _ in tqdm(range(n_epochs)):\n",
    "        # Keep the source in a local: assigning the noise back to ``x0`` would make\n",
    "        # every later step reuse the first noise draw instead of sampling anew.\n",
    "        x0_step = torch.randn_like(x1) if x0 is None else x0\n",
    "\n",
    "        t = torch.rand(x0_step.shape[0]).type_as(x0_step)\n",
    "        xt = sample_conditional_pt(x0_step, x1, t, sigma=sigma)\n",
    "        ut = compute_conditional_vector_field(x0_step, x1)\n",
    "        vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "        loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        losses.append(loss.item())\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbf64563-36a8-4f8e-acb3-4216ea65a6f0",
   "metadata": {},
   "source": [
    "# 2. Fitting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c716ba9-a909-4b9b-8ebe-a35ef54cf4f4",
   "metadata": {},
   "source": [
    "## 2.1 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd6562d9-2a0c-424b-9f4a-6387510d51bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observations per trajectory at t_obs = (0, 0.5, 1). Slot 0 is a placeholder:\n",
    "# GP_FM replaces it with fresh N(0, I) noise for every mini-batch.\n",
    "dim = x1.shape[1]\n",
    "x_data = torch.stack([torch.zeros(N, dim), x_05, x1], dim=1)\n",
    "\n",
    "# GP-ICFM training settings\n",
    "alpha = 0.1  # GP kernel amplitude\n",
    "nt = 10  # random times sampled per mini-batch\n",
    "batch_size = 20\n",
    "t_obs = torch.tensor([0, 0.5, 1])\n",
    "\n",
    "# trajectory generation settings\n",
    "n_samp = 100\n",
    "nt_gen = 100\n",
    "seed = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e071bb2-b9f7-49fa-a89e-42a40b9ab129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GP-ICFM with kernel length scale l = 1, trained for 10000 epochs.\n",
    "model_1_10000, losses_1_10000 = GP_FM(x_data, alpha, l=1, nt=nt, batch_size=batch_size,\n",
    "                                      t_obs=t_obs, n_epochs=10000)\n",
    "traj_1_10000 = gen_traj(model_1_10000, n_samp=n_samp, nt_gen=nt_gen, seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "178e6d2f-267d-4e94-9e3c-c52683f66539",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_traj(traj_1_10000, nt_gen)\n",
    "# Same axis limits as the I-CFM figure so the two are directly comparable.\n",
    "plt.xlim(-16, 9)\n",
    "plt.ylim(-5, 25)\n",
    "plt.savefig(\"2_GP_path.svg\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fe0139a-dd88-4189-bc1a-85ba2d8c5c83",
   "metadata": {},
   "source": [
    "## 2.2 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06574d2b-35ef-4b66-9ae2-a01ffd096a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-stage I-CFM: stage 1 maps N(0, I) noise to x_05, stage 2 maps x_05 to x1\n",
    "# using the row-wise pairing of x_05 and x1.\n",
    "model_icfm_1 = MLP(dim=dim, out_dim=dim, time_varying=True)\n",
    "optimizer_icfm_1 = torch.optim.Adam(model_icfm_1.parameters(), lr=1e-3)\n",
    "model_icfm_1, _ = I_FM(x_05, model_icfm_1, optimizer_icfm_1, sigma=0.1, n_epochs=10000)\n",
    "\n",
    "model_icfm_2 = MLP(dim=dim, out_dim=dim, time_varying=True)\n",
    "optimizer_icfm_2 = torch.optim.Adam(model_icfm_2.parameters(), lr=1e-3)\n",
    "model_icfm_2, _ = I_FM(x1, model_icfm_2, optimizer_icfm_2, sigma=0.1, n_epochs=10000, x0=x_05)\n",
    "\n",
    "# Each stage is integrated on half as many time points as the GP-ICFM run;\n",
    "# stage 2 starts from the endpoints generated by stage 1.\n",
    "traj_icfm_1 = gen_traj(model_icfm_1, n_samp, nt_gen // 2, 0)\n",
    "x_05_gen = traj_icfm_1[-1, :, :]\n",
    "traj_icfm_2 = gen_traj(model_icfm_2, n_samp, nt_gen // 2, 0, x_05_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddd0efc0-2920-43a8-9fbc-8fa4fb98c21a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 1 ends (red) where stage 2 starts, so the red points are the generated x_05.\n",
    "plot_traj(traj_icfm_1, nt_gen, mid_pts=False, end_color=\"red\")\n",
    "plot_traj(traj_icfm_2, nt_gen, mid_pts=False, start_color=\"red\")\n",
    "plt.xlim(-16, 9)\n",
    "plt.ylim(-5, 25)\n",
    "plt.savefig(\"3_ICFM_path.svg\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
