{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6d4638-af43-47ea-95fe-909946e8453c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import warnings\n",
    "from copy import deepcopy\n",
    "from typing import List\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot\n",
    "import ot.plot\n",
    "import torch\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from torchdyn.core import NeuralODE\n",
    "from tqdm import tqdm\n",
    "# from torchcfm.optimal_transport import OTPlanSampler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb22251c-4564-4648-8f68-7ec35096d5e5",
   "metadata": {},
   "source": [
    "# 0. Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce91305-00d4-4217-838b-10457f77ed60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes and geometry of the toy problem.\n",
    "N = 100      # number of samples drawn from each distribution\n",
    "yend = 10    # y-coordinate shared by both target modes\n",
    "p = .5       # probability of assigning a sample to the mode at x = -3\n",
    "\n",
    "# Seed NumPy (used for the data below) and torch (used later for model\n",
    "# initialisation and training noise) so that a fresh run is reproducible.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# start q(x0): standard 2-D Gaussian\n",
    "x0 = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], N)\n",
    "\n",
    "# end q(x1): mixture of two Gaussians centred at (-3, yend) and (3, yend).\n",
    "# The draws happen in the order: labels, left mode, right mode.\n",
    "z_id1 = np.random.binomial(1, p, N)[:, None]\n",
    "x1_left = np.random.multivariate_normal([-3, yend], [[.1, 0], [0, .1]], N)\n",
    "x1_right = np.random.multivariate_normal([3, yend], [[.1, 0], [0, .1]], N)\n",
    "x1 = z_id1*x1_left + (1 - z_id1)*x1_right\n",
    "\n",
    "# Convert to float32 tensors for training.\n",
    "x0 = torch.from_numpy(x0).to(torch.float32)\n",
    "x1 = torch.from_numpy(x1).to(torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc491dbc-af39-4490-bb83-0fe322e1c476",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc6988c4-5b8d-44fa-bc5a-26381fadb52f",
   "metadata": {},
   "source": [
    "## 1.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d42b6bc7-afb3-4701-86ba-f97f99224b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Fully connected network with three SELU hidden layers of width ``w``.\n",
    "\n",
    "    With ``time_varying=True`` the first layer takes ``dim + 1`` input\n",
    "    features; ``forward`` does not append the extra column itself, so the\n",
    "    caller must supply it. ``out_dim`` falls back to ``dim`` when omitted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        in_features = dim + 1 if time_varying else dim\n",
    "        out_features = dim if out_dim is None else out_dim\n",
    "\n",
    "        # Resulting module order: Linear, SELU, Linear, SELU, Linear, SELU, Linear\n",
    "        # (state-dict keys net.0, net.2, net.4, net.6).\n",
    "        widths = [in_features, w, w, w]\n",
    "        layers = []\n",
    "        for n_in, n_out in zip(widths[:-1], widths[1:]):\n",
    "            layers.append(torch.nn.Linear(n_in, n_out))\n",
    "            layers.append(torch.nn.SELU())\n",
    "        layers.append(torch.nn.Linear(w, out_features))\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the network output for the input rows ``x``.\"\"\"\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83cbbb44-f248-494c-9d5d-ae53d8b8297f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Adapts ``model`` to the ``forward(t, x)`` call signature used by torchdyn.\n",
    "\n",
    "    The time ``t`` is repeated once per row of ``x`` and appended as the\n",
    "    last input column before the wrapped model is called.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        n_rows = x.shape[0]\n",
    "        t_column = t.repeat(n_rows)[:, None]\n",
    "        model_input = torch.cat([x, t_column], 1)\n",
    "        return self.model(model_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f954704-77a7-41e8-b04a-51f148b1916a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None, dim = None):\n",
    "    \"\"\"Integrate the learned vector field from t = 0 to t = 1.\n",
    "\n",
    "    Args:\n",
    "        model: network called on rows ``[x, t]`` (it is wrapped in ``torch_wrapper``).\n",
    "        n_samp: number of start points to draw when ``x_start`` is None.\n",
    "        nt_gen: number of time points in the ``linspace(0, 1, nt_gen)`` grid.\n",
    "        seed: torch seed set before drawing start points; unused when\n",
    "            ``x_start`` is given.\n",
    "        x_start: optional start points; when None, standard-normal points are drawn.\n",
    "        dim: state dimension of the drawn start points. When None it is\n",
    "            inferred as ``in_features - 1`` of the first ``torch.nn.Linear``\n",
    "            found in ``model`` (the wrapper appends exactly one time column).\n",
    "            NOTE(review): assumes the first registered Linear is the input layer.\n",
    "\n",
    "    Returns:\n",
    "        The output of ``NeuralODE.trajectory``; callers in this notebook index\n",
    "        it as ``traj[time, sample, coordinate]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``x_start`` and ``dim`` are both None and ``model``\n",
    "            contains no Linear layer to infer the dimension from.\n",
    "    \"\"\"\n",
    "    node = NeuralODE(torch_wrapper(model), solver=\"dopri5\",\n",
    "                 sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    if x_start is None:\n",
    "        if dim is None:\n",
    "            # Avoid depending on a notebook-level global `dim`.\n",
    "            first_linear = next(\n",
    "                (mod for mod in model.modules() if isinstance(mod, torch.nn.Linear)),\n",
    "                None,\n",
    "            )\n",
    "            if first_linear is None:\n",
    "                raise ValueError(\"cannot infer the state dimension; pass dim or x_start\")\n",
    "            dim = first_linear.in_features - 1\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    # Generation only: no gradients are needed.\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))\n",
    "\n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "136de1cc-de95-4a06-9e1c-7e15fcee9ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_traj(traj, nt_gen, mid_pts = True, start_color = \"black\", end_color = \"orange\"):\n",
    "    \"\"\"Scatter-plot a 2-D trajectory bundle on the current matplotlib axes.\n",
    "\n",
    "    Args:\n",
    "        traj: array indexed as ``traj[time, sample, coordinate]``; only\n",
    "            coordinates 0 and 1 are drawn.\n",
    "        nt_gen: kept for backward compatibility. The midpoint is taken from\n",
    "            the trajectory's own length, so a mismatched value can no longer\n",
    "            index out of range.\n",
    "        mid_pts: also highlight the middle time slice (in red).\n",
    "        start_color: colour of the first time slice.\n",
    "        end_color: colour of the last time slice.\n",
    "    \"\"\"\n",
    "    # Same index as int(nt_gen/2) whenever nt_gen equals the number of time points.\n",
    "    mid = traj.shape[0] // 2\n",
    "\n",
    "    # Labels are attached to the artists themselves so the legend cannot be\n",
    "    # mis-assigned if the axes already contain other artists.\n",
    "    plt.scatter(traj[0, :, 0], traj[0, :, 1], s=4, alpha=1, c=start_color, label=\"x0\")\n",
    "    if mid_pts:\n",
    "        plt.scatter(traj[mid, :, 0], traj[mid, :, 1], s=4, alpha=1, c=\"red\", label=\"x_05\")\n",
    "    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c=\"blue\", label=\"Flow\")\n",
    "    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], s=4, alpha=1, c=end_color, label=\"x1\")\n",
    "\n",
    "    plt.legend()\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aa08f2e-3a8c-4e36-b678-b3836f8e75b4",
   "metadata": {},
   "source": [
    "## 1.2 ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57cfd554-617c-4e06-8bba-e89baeaa7fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_conditional_pt(x0, x1, t, sigma):\n",
    "    \"\"\"Sample x_t = t*x1 + (1 - t)*x0 + sigma*eps with eps ~ N(0, I).\n",
    "\n",
    "    ``t`` holds one time per leading row of ``x0``; it is reshaped so that it\n",
    "    broadcasts over all remaining dimensions of ``x0``.\n",
    "    \"\"\"\n",
    "    trailing_ones = (1,) * (x0.dim() - 1)\n",
    "    t_col = t.reshape(-1, *trailing_ones)\n",
    "    mean = t_col * x1 + (1 - t_col) * x0\n",
    "    noise = torch.randn_like(x0)\n",
    "    return mean + sigma * noise\n",
    "\n",
    "def compute_conditional_vector_field(x0, x1):\n",
    "    \"\"\"Return the regression target of the straight-line path: ``x1 - x0``.\"\"\"\n",
    "    displacement = x1 - x0\n",
    "    return displacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6282dc8-5455-4671-9fa0-286451b5fcdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def I_FM2(x1, model, optimizer, batch_size, nt = 1, sigma = 1e-2, n_epochs = 10000, x0 = None,\n",
    "         ImpSamp = False, beta_a = 1.0, beta_b = 0.5, storeCheck = True, epoch_check_step = 100):\n",
    "    \"\"\"Train ``model`` with the independent-coupling flow-matching loss.\n",
    "\n",
    "    Each step pairs a batch of targets with sources (rows of ``x0``, or fresh\n",
    "    standard-normal draws when ``x0`` is None), repeats the pairs ``nt``\n",
    "    times, draws one time per copy and regresses ``model([x_t, t])`` onto\n",
    "    ``x1 - x0``.\n",
    "\n",
    "    Args:\n",
    "        x1: target samples, indexed as (N, dim).\n",
    "        model: network taking rows ``[x_t, t]``.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        batch_size: rows per batch; a shorter final batch is used when N is\n",
    "            not a multiple of ``batch_size``.\n",
    "        nt: number of time draws per (x0, x1) pair in a step.\n",
    "        sigma: standard deviation of the noise added to x_t.\n",
    "        n_epochs: number of passes over ``x1``.\n",
    "        x0: optional fixed source samples aligned row-wise with ``x1``.\n",
    "        ImpSamp: draw t from Beta(beta_a, beta_b) and weight the squared error\n",
    "            by the inverse density instead of drawing t uniformly.\n",
    "        beta_a, beta_b: Beta parameters used when ``ImpSamp`` is true.\n",
    "        storeCheck: also return state-dict snapshots.\n",
    "        epoch_check_step: snapshot every this many epochs (one snapshot per\n",
    "            epoch, taken after its last batch).\n",
    "\n",
    "    Returns:\n",
    "        ``(model, losses, check_pts, check_steps)`` when ``storeCheck`` is\n",
    "        true, otherwise ``(model, losses)``. ``losses`` has one entry per\n",
    "        optimizer step.\n",
    "    \"\"\"\n",
    "    N = x1.shape[0]\n",
    "    dim = x1.shape[1]\n",
    "\n",
    "    if ImpSamp:\n",
    "        m = torch.distributions.beta.Beta(torch.tensor([beta_a]), torch.tensor([beta_b])) # put more weight on t = 1\n",
    "\n",
    "    # Contiguous index chunks; the last chunk is shorter when N is not a\n",
    "    # multiple of batch_size (the fixed-shape reshape used to fail there).\n",
    "    batch_idx = [chunk for chunk in torch.arange(N).split(batch_size) if chunk.numel() > 0]\n",
    "\n",
    "    losses: List[float] = []\n",
    "    check_pts = []\n",
    "    check_steps = []\n",
    "\n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for idx in batch_idx:\n",
    "            n_b = idx.shape[0]\n",
    "\n",
    "            if x0 is None:\n",
    "                x0_batch = torch.randn((n_b, dim))\n",
    "            else:\n",
    "                x0_batch = x0[idx, :]\n",
    "\n",
    "            x1_batch = x1[idx, :]\n",
    "\n",
    "            if ImpSamp:\n",
    "                t_expand = m.sample((nt*n_b,)).flatten()\n",
    "            else:\n",
    "                t_expand = torch.rand(nt*n_b).type_as(x0_batch)\n",
    "\n",
    "            x1_expand = x1_batch.repeat(nt, 1)\n",
    "            x0_expand = x0_batch.repeat(nt, 1)\n",
    "\n",
    "            xt = sample_conditional_pt(x0_expand, x1_expand, t_expand, sigma=sigma)\n",
    "            ut = compute_conditional_vector_field(x0_expand, x1_expand)\n",
    "            vt = model(torch.cat([xt, t_expand[:, None]], dim=-1))\n",
    "\n",
    "            if ImpSamp:\n",
    "                # t_expand is 1-D here, so the inverse density needs a trailing\n",
    "                # axis to broadcast over the coordinates.\n",
    "                loss = torch.mean((1/torch.exp(m.log_prob(t_expand))[:,None])*((vt - ut) ** 2))\n",
    "            else:\n",
    "                loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "            # Clear any stale gradients before accumulating new ones.\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "\n",
    "        # Snapshot once per checkpoint epoch rather than once per batch.\n",
    "        if storeCheck and k % epoch_check_step == 0:\n",
    "            check_pts.append(deepcopy(model.state_dict()))\n",
    "            check_steps.append(k)\n",
    "\n",
    "    if storeCheck:\n",
    "        return model, losses, check_pts, check_steps\n",
    "    else:\n",
    "        return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b6f8d6b-b205-4105-b792-ab806548f208",
   "metadata": {},
   "source": [
    "## 1.3 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36839eb-9fd3-4f31-8aa6-f77a9afe1cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    \"\"\"Pairwise lags ``ti[..., a] - tj[..., b]``, stacked on two trailing axes.\n",
    "\n",
    "    Entries that are exactly zero are overwritten with 1e-15.\n",
    "    \"\"\"\n",
    "    lag = ti.unsqueeze(-1) - tj.unsqueeze(-2)\n",
    "    is_zero = lag == 0\n",
    "    lag[is_zero] = 1e-15\n",
    "    return lag\n",
    "def k11(r, alpha, l):\n",
    "    \"\"\"Squared-exponential kernel ``alpha^2 * exp(-r^2 / (2 l^2))`` at lag ``r``.\"\"\"\n",
    "    scaled_sq = (r/l)**2\n",
    "    return (alpha**2)*torch.exp(-0.5 * scaled_sq)\n",
    "def k12(r, alpha, l):\n",
    "    \"\"\"Return ``(alpha^2 / l^2) * r * exp(-r^2 / (2 l^2))``.\n",
    "\n",
    "    With ``r = ti - tj`` this is the derivative of ``k11`` with respect to ``tj``.\n",
    "    \"\"\"\n",
    "    amplitude = alpha**2/l**2\n",
    "    return amplitude*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    \"\"\"Return ``(alpha^2 / l^4) * (l^2 - r^2) * exp(-r^2 / (2 l^2))``.\"\"\"\n",
    "    amplitude = alpha**2/l**4\n",
    "    return amplitude*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77fd544d-8ba2-43e5-b965-d58935285bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \"\"\"Batched block covariance built from ``k11``, ``k12`` and ``k22``.\n",
    "\n",
    "    The result is ``[[k11 + sig2_diag*I, k12], [k12^T, k22]]`` evaluated on\n",
    "    the lags ``calc_r(ti, tj)`` and symmetrised. The lag tensor is indexed as\n",
    "    (batch, time, time), so the output is (batch, 2*time, 2*time).\n",
    "    \"\"\"\n",
    "    lag = calc_r(ti, tj)\n",
    "    n_batch, n_time = lag.shape[0], lag.shape[1]\n",
    "\n",
    "    # Diagonal jitter, tiled over the batch.\n",
    "    jitter = (torch.eye(n_time)*sig2_diag).repeat(n_batch,1,1)\n",
    "    cov_xx = k11(lag, alpha, l) + jitter\n",
    "    cov_xd = k12(lag, alpha, l)\n",
    "    cov_dx = cov_xd.permute(0, 2, 1)\n",
    "    cov_dd = k22(lag, alpha, l)\n",
    "\n",
    "    top = torch.cat([cov_xx, cov_xd], dim=2)\n",
    "    bottom = torch.cat([cov_dx, cov_dd], dim=2)\n",
    "    full = torch.cat([top, bottom], dim = 1)\n",
    "\n",
    "    # Average with the transpose to remove round-off asymmetry.\n",
    "    return (full + full.permute(0, 2, 1))/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d76bfc-bb96-41ca-84c9-15937f1511f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    \"\"\"Sample values and time derivatives of a GP conditioned on observations.\n",
    "\n",
    "    For every batch element and coordinate, the stacked vector\n",
    "    ``[x(t_mat), dx(t_mat)]`` is drawn from the Gaussian obtained by\n",
    "    conditioning the kernel built from ``k11``/``k12``/``k22`` on the values\n",
    "    ``x_obs`` observed at times ``t_obs``.\n",
    "\n",
    "    Args:\n",
    "        t_mat: query times, indexed as (nB, nt).\n",
    "        alpha, l: kernel parameters forwarded to ``k11``, ``k12`` and ``cov_mat2``.\n",
    "        x_obs: observed values, indexed as (nB, nt_obs, dim).\n",
    "        t_obs: observation times of length nt_obs, shared by the whole batch.\n",
    "        sig2_diag: jitter added to the diagonal of the value-value blocks.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_samps, dx_samps)``, each of shape (nB, nt, dim).\n",
    "    \"\"\"\n",
    "    nB, nt, dim = x_obs.shape[0], t_mat.shape[1], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    # Lags between observation times and query times / themselves.\n",
    "    r_obs_x = calc_r(t_obs, t_mat)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    # Prior covariance of the stacked [x, dx] vector at the query times.\n",
    "    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)\n",
    "\n",
    "    # Cross-covariance between the observations and [x, dx] at the query times.\n",
    "    k_obs_x, k_obs_dx = k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)\n",
    "    Sig_12 = Sig_21.permute(0, 2, 1)\n",
    "\n",
    "    # The observation covariance is the same for every batch element:\n",
    "    # invert it once and tile the inverse.\n",
    "    Sig_22_sing = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs) * sig2_diag\n",
    "    Sig_22_inv_sing = torch.linalg.inv(Sig_22_sing)\n",
    "    Sig_22_inv = Sig_22_inv_sing.repeat(nB, 1, 1)\n",
    "\n",
    "    # Sig_12 Sig_22^{-1} is shared by the conditional mean and covariance.\n",
    "    mu_A = torch.bmm(Sig_12, Sig_22_inv)\n",
    "\n",
    "    # Conditional covariance, symmetrised against round-off.\n",
    "    Sig_cond = Sig_11 - torch.bmm(mu_A, Sig_21)\n",
    "    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2\n",
    "\n",
    "    # Rebuild every matrix that has an eigenvalue failing the >= 0 test as\n",
    "    # V diag(S + 1e-8) V^T from its SVD.\n",
    "    eig_real = torch.linalg.eigvals(Sig_cond).real\n",
    "    svd_add_idx = ~(eig_real >= 0).all(dim=1)\n",
    "    if svd_add_idx.any():\n",
    "        _, S, Vh = torch.linalg.svd(Sig_cond[svd_add_idx,:,:])\n",
    "        Sig_cond_add = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + 1e-8)), Vh)\n",
    "        Sig_cond[svd_add_idx,:,:] = (Sig_cond_add + Sig_cond_add.permute(0, 2, 1))/2\n",
    "\n",
    "    # Conditional mean of the stacked vector, one column per coordinate.\n",
    "    x_obs_batch = x_obs.reshape(nB, nt_obs, dim)\n",
    "    mu_new = torch.bmm(mu_A, x_obs_batch).reshape(nB, 2 * nt, dim)\n",
    "\n",
    "    # Initialize sample matrices\n",
    "    x_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    dx_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "\n",
    "    # One Gaussian per (batch element, coordinate); coordinates of the same\n",
    "    # batch element share a covariance.\n",
    "    mu_flat = mu_new.permute(0, 2, 1).reshape(nB * dim, 2 * nt)\n",
    "    Sig_cond_flat = Sig_cond.repeat_interleave(dim, dim=0)\n",
    "\n",
    "    # Sampling in batch for all dimensions at once\n",
    "    try:\n",
    "        dist = MultivariateNormal(loc=mu_flat, covariance_matrix=Sig_cond_flat)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, dim, 2 * nt).permute(0, 2, 1)\n",
    "    except (RuntimeError, ValueError):\n",
    "        # A covariance that is not positive definite is rejected by argument\n",
    "        # validation with ValueError, or by the Cholesky factorisation with a\n",
    "        # RuntimeError subclass; both must reach the fallback.\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb, :, dd].cpu().numpy()\n",
    "                cov_single = Sig_cond[bb].cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps[:, :, :] = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps[:, :, :] = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa33547-f25a-473f-96db-e09aa7a6444f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FM2(x_data, model, optimizer, alpha, l,\n",
    "          nt, batch_size, t_obs, n_epochs, sig2_diag = 0,\n",
    "          ImpSamp = False, beta_a = 1.0, beta_b = 0.5, storeCheck = True, epoch_check_step = 100):\n",
    "    \"\"\"Train ``model`` on flow-matching targets sampled from a conditional GP.\n",
    "\n",
    "    For each batch, slot 0 of the observations is overwritten with fresh\n",
    "    standard-normal noise (on a copy; ``x_data`` is not modified), ``nt``\n",
    "    times are drawn per row, and ``samp_x_dx2`` supplies points and\n",
    "    derivatives that serve as inputs and regression targets.\n",
    "\n",
    "    Args:\n",
    "        x_data: observations, indexed as (N, n_obs, dim).\n",
    "        model: network taking rows ``[x_t, t]``.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        alpha, l, sig2_diag: forwarded to ``samp_x_dx2``.\n",
    "        nt: number of times drawn per row and step.\n",
    "        batch_size: rows per batch; a shorter final batch is used when N is\n",
    "            not a multiple of ``batch_size``.\n",
    "        t_obs: observation times forwarded to ``samp_x_dx2``.\n",
    "        n_epochs: number of passes over ``x_data``.\n",
    "        ImpSamp: draw t from Beta(beta_a, beta_b) and weight the squared error\n",
    "            by the inverse density instead of drawing t uniformly.\n",
    "        beta_a, beta_b: Beta parameters used when ``ImpSamp`` is true.\n",
    "        storeCheck: also return state-dict snapshots.\n",
    "        epoch_check_step: snapshot every this many epochs (one snapshot per\n",
    "            epoch, taken after its last batch).\n",
    "\n",
    "    Returns:\n",
    "        ``(model, losses, check_pts, check_steps)`` when ``storeCheck`` is\n",
    "        true, otherwise ``(model, losses)``. ``losses`` has one entry per\n",
    "        optimizer step; batches whose sampling failed are skipped.\n",
    "    \"\"\"\n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "\n",
    "    if ImpSamp:\n",
    "        m = torch.distributions.beta.Beta(torch.tensor([beta_a]), torch.tensor([beta_b])) # put more weight on t = 1\n",
    "\n",
    "    # Contiguous index chunks; the last chunk is shorter when N is not a\n",
    "    # multiple of batch_size (the fixed-shape reshape used to fail there).\n",
    "    batch_idx = [chunk for chunk in torch.arange(N).split(batch_size) if chunk.numel() > 0]\n",
    "\n",
    "    losses: List[float] = []\n",
    "    check_pts = []\n",
    "    check_steps = []\n",
    "\n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for idx in batch_idx:\n",
    "            n_b = idx.shape[0]\n",
    "            x0 = torch.randn((n_b, dim))\n",
    "            # Indexing with an index tensor returns a copy, so writing the\n",
    "            # noise into slot 0 leaves x_data untouched.\n",
    "            x_obs = x_data[idx, :, :]\n",
    "            x_obs[:, 0, :] = x0\n",
    "\n",
    "            if ImpSamp:\n",
    "                t_mat = m.sample((n_b, nt))[:, :, 0]\n",
    "            else:\n",
    "                t_mat = torch.rand((n_b, nt))\n",
    "\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag)\n",
    "            except Exception as err:\n",
    "                # Skip the batch instead of silently pairing stale (or\n",
    "                # undefined) samples with the newly drawn times.\n",
    "                warnings.warn(\"GP sampling failed; skipping batch: {}\".format(err), RuntimeWarning)\n",
    "                continue\n",
    "\n",
    "            t = torch.reshape(t_mat, (-1, 1))\n",
    "            xt = torch.reshape(xt_batch, (-1, dim))\n",
    "            ut = torch.reshape(ut_batch, (-1, dim))\n",
    "\n",
    "            vt = model(torch.cat([xt, t], dim=-1))\n",
    "            if ImpSamp:\n",
    "                # t is a column here, so log_prob(t) is already (n, 1) and\n",
    "                # broadcasts row-wise; an extra axis would produce an\n",
    "                # (n, n, dim) product that decouples weights from samples.\n",
    "                inv_density = 1/torch.exp(m.log_prob(t))\n",
    "                loss = torch.mean(inv_density*((vt - ut) ** 2))\n",
    "            else:\n",
    "                loss = torch.mean((vt - ut) ** 2)\n",
    "\n",
    "            # Clear any stale gradients before accumulating new ones.\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "\n",
    "        # Snapshot once per checkpoint epoch rather than once per batch.\n",
    "        if storeCheck and k % epoch_check_step == 0:\n",
    "            check_pts.append(deepcopy(model.state_dict()))\n",
    "            check_steps.append(k)\n",
    "\n",
    "    if storeCheck:\n",
    "        return model, losses, check_pts, check_steps\n",
    "    else:\n",
    "        return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1d969aa-8041-4205-9132-2e562853c435",
   "metadata": {},
   "source": [
    "## 1.4 W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1da2d2-f7e1-4967-9ce4-c0d70bd72ac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def w_mat_dist(x1_train, x1_gen, p = 2, ot_mat = False):\n",
    "    \"\"\"Exact optimal-transport cost between two uniformly weighted point clouds.\n",
    "\n",
    "    Args:\n",
    "        x1_train: array of points, one per row.\n",
    "        x1_gen: array of points, one per row.\n",
    "        p: 1 for a Euclidean ground cost, 2 for a squared-Euclidean ground cost.\n",
    "        ot_mat: also compute the transport plan with ``ot.emd``.\n",
    "\n",
    "    Returns:\n",
    "        ``(G0, d)``: ``G0`` is the transport plan (None unless ``ot_mat``) and\n",
    "        ``d`` is the value returned by ``ot.emd2``. No root is taken, so for\n",
    "        ``p = 2`` this is the cost under the squared-Euclidean ground cost.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``p`` is neither 1 nor 2.\n",
    "    \"\"\"\n",
    "    # Validate first: an unsupported p used to leave the cost matrix undefined.\n",
    "    if p == 1:\n",
    "        metric = 'euclidean'\n",
    "    elif p == 2:\n",
    "        metric = 'sqeuclidean'\n",
    "    else:\n",
    "        raise ValueError(\"p must be 1 or 2, got {!r}\".format(p))\n",
    "\n",
    "    n_train = x1_train.shape[0]\n",
    "    n_gen = x1_gen.shape[0]\n",
    "\n",
    "    a, b = np.ones((n_train,)) / n_train, np.ones((n_gen,)) / n_gen  # uniform distribution on samples\n",
    "    M = ot.dist(x1_train, x1_gen, metric=metric)\n",
    "\n",
    "    G0 = None\n",
    "    if ot_mat:\n",
    "        G0 = ot.emd(a, b, M)\n",
    "\n",
    "    d = ot.emd2(a, b, M)\n",
    "\n",
    "    return G0, d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad326580-6ef5-4d5e-8090-16603fa8111c",
   "metadata": {},
   "source": [
    "# 2. Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b02f37-d1c1-45dd-a152-f87fe254e871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- problem size ---\n",
    "dim = x1.shape[1]\n",
    "\n",
    "# --- I-CFM noise level ---\n",
    "# sigma = 1e-2\n",
    "sigma = 0\n",
    "\n",
    "# --- generation settings ---\n",
    "n_samp = 1000\n",
    "nt_gen = 100\n",
    "\n",
    "# --- GP-I-CFM observations: slot 0 is a zero placeholder, slot 1 holds x1 ---\n",
    "x_data = torch.stack([torch.zeros_like(x1), x1], dim=1)\n",
    "\n",
    "# --- kernel parameters, time draws per sample, batch size, observation times ---\n",
    "alpha = 1\n",
    "l = 2\n",
    "nt = 10\n",
    "batch_size = 100\n",
    "t_obs = torch.tensor([0, 1])\n",
    "\n",
    "# --- optimisation ---\n",
    "n_epochs = 5000\n",
    "lr_icfm = 1e-3\n",
    "lr_GP = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54a60d8f-e1ed-4763-a057-c06e17f9bc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit one I-CFM model ...\n",
    "model_icfm = MLP(dim = dim, out_dim = dim, time_varying=True)\n",
    "optimizer_icfm = torch.optim.Adam(model_icfm.parameters(), lr=lr_icfm)\n",
    "model_icfm, losses_icfm, check_pts_icfm, check_steps_icfm = I_FM2(x1, model_icfm, optimizer_icfm,\n",
    "                                                                  batch_size = batch_size, nt = nt,\n",
    "                                                                  sigma = sigma, n_epochs = n_epochs,\n",
    "                                                                  storeCheck = True,\n",
    "                                                                  epoch_check_step = 100)\n",
    "\n",
    "# ... and one GP-I-CFM model.\n",
    "model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "optimizer_GP0 = torch.optim.Adam(model_GP0.parameters(), lr=lr_GP)\n",
    "model_GP0, losses_GP0, check_pts_GP0, check_steps_GP0 = GP_FM2(x_data, model_GP0, optimizer_GP0,\n",
    "                                                               alpha,l,nt, batch_size, t_obs,\n",
    "                                                               n_epochs, sig2_diag = 0, \n",
    "                                                               storeCheck = True,\n",
    "                                                               epoch_check_step = 100)\n",
    "\n",
    "saveFolder = \"/hpc/group/mastatlab/gw74/bv_trade/\"\n",
    "rep_saveFolder = \"/hpc/group/mastatlab/gw74/bv_trade/100_seeds/\"\n",
    "\n",
    "# The replication cell starts every run from model_icfm_1.pt / model_GP0_1.pt.\n",
    "# Write them when they are missing so a fresh Restart & Run All works;\n",
    "# existing files are never overwritten.\n",
    "os.makedirs(saveFolder, exist_ok=True)\n",
    "if not os.path.exists(saveFolder + \"model_icfm_1.pt\"):\n",
    "    torch.save(model_icfm.state_dict(), saveFolder + \"model_icfm_1.pt\")\n",
    "if not os.path.exists(saveFolder + \"model_GP0_1.pt\"):\n",
    "    torch.save(model_GP0.state_dict(), saveFolder + \"model_GP0_1.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c452c8c6-1e64-4316-889f-18ca4c3523f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# replicate 100 times...\n",
    "# (the cell magic above has to be the very first line of the cell)\n",
    "nSeeds = 100\n",
    "n_epochs = 5000\n",
    "lr_icfm = 2e-3\n",
    "lr_GP = 2e-3\n",
    "\n",
    "# torch.save does not create missing directories.\n",
    "os.makedirs(rep_saveFolder, exist_ok=True)\n",
    "\n",
    "for ll in range(0, nSeeds):\n",
    "\n",
    "    # I-CFM: start from the saved initial fit, train further, store replicate ll.\n",
    "    model_icfm = MLP(dim = dim, out_dim = dim, time_varying=True)\n",
    "    optimizer_icfm = torch.optim.Adam(model_icfm.parameters(), lr=lr_icfm)\n",
    "    model_icfm.load_state_dict(torch.load(saveFolder + \"model_icfm_1.pt\"))\n",
    "    model_icfm, losses_icfm, check_pts_icfm, check_steps_icfm = I_FM2(x1, model_icfm, optimizer_icfm,\n",
    "                                                                  batch_size = batch_size, nt = nt,\n",
    "                                                                  sigma = 0, n_epochs = n_epochs,\n",
    "                                                                  storeCheck = True, epoch_check_step = 100)\n",
    "    torch.save(model_icfm.state_dict(), rep_saveFolder + \"model_icfm_\" + str(ll) + \".pt\")\n",
    "\n",
    "    # GP-I-CFM: same procedure.\n",
    "    model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "    optimizer_GP0 = torch.optim.Adam(model_GP0.parameters(), lr=lr_GP)\n",
    "    model_GP0.load_state_dict(torch.load(saveFolder + \"model_GP0_1.pt\"))\n",
    "    model_GP0, losses_GP0, check_pts_GP0, check_steps_GP0 = GP_FM2(x_data, model_GP0, optimizer_GP0, alpha,\n",
    "                                   l,nt, batch_size, t_obs, n_epochs, sig2_diag = 0, \n",
    "                                   storeCheck = True, epoch_check_step = 100)\n",
    "    torch.save(model_GP0.state_dict(), rep_saveFolder + \"model_gp_icfm_\" + str(ll) + \".pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a40a17ad-1b9e-41d8-98f9-dc62407cfdd6",
   "metadata": {},
   "source": [
    "# 3. W2 distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bacff84b-dba6-491a-8062-75fbf9f89ad8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out draws from the target mixture q(x1), with their own NumPy seed.\n",
    "N_test = 1000\n",
    "np.random.seed(1)\n",
    "\n",
    "# Draw order: component labels, left mode, right mode.\n",
    "z_id1_test = np.random.binomial(1, p, N_test)[:, None]\n",
    "x1_test_left = np.random.multivariate_normal([-3, yend], [[.1, 0], [0, .1]], N_test)\n",
    "x1_test_right = np.random.multivariate_normal([3, yend], [[.1, 0], [0, .1]], N_test)\n",
    "\n",
    "x1_test = z_id1_test*x1_test_left + (1 - z_id1_test)*x1_test_right\n",
    "x1_test = torch.from_numpy(x1_test).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "647347ab-ce2c-43c8-a14c-a6edb163b25f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One distance per replicate and method, measured against the held-out set.\n",
    "dAll_icfm = np.zeros(nSeeds)\n",
    "dAll_GP0 = np.zeros(nSeeds)\n",
    "n_test = x1_test.shape[0]\n",
    "x1_test_np = x1_test.numpy()\n",
    "\n",
    "for ss in range(nSeeds):\n",
    "    icfm_file = rep_saveFolder + \"model_icfm_\" + str(ss) + \".pt\"\n",
    "    gp_file = rep_saveFolder + \"model_gp_icfm_\" + str(ss) + \".pt\"\n",
    "\n",
    "    # Two-point time grid: only the end of the flow is needed here.\n",
    "    model_icfm = MLP(dim = dim, out_dim = dim, time_varying=True)\n",
    "    model_icfm.load_state_dict(torch.load(icfm_file))\n",
    "    traj_icfm = gen_traj(model_icfm, n_test, 2, ss)\n",
    "\n",
    "    model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "    model_GP0.load_state_dict(torch.load(gp_file))\n",
    "    traj_GP0 = gen_traj(model_GP0, n_test, 2, ss)\n",
    "\n",
    "    # w_mat_dist returns (plan, cost); keep the cost of the final time slice.\n",
    "    dAll_icfm[ss] = w_mat_dist(x1_test_np, traj_icfm[-1].numpy(), p = 2)[1]\n",
    "    dAll_GP0[ss] = w_mat_dist(x1_test_np, traj_GP0[-1].numpy(), p = 2)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f77961-9306-4d16-b33f-b3052d1d7766",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation of the distances over the replicates.\n",
    "icfm_mean, icfm_std = np.mean(dAll_icfm), np.std(dAll_icfm)\n",
    "gp_mean, gp_std = np.mean(dAll_GP0), np.std(dAll_GP0)\n",
    "print('I-CFM: {:.3f} +- {:.3f}'.format(icfm_mean, icfm_std))\n",
    "print('GP-I-CFM: {:.3f} +- {:.3f}'.format(gp_mean, gp_std))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "873ef570-85d5-4846-98af-1a0d07e095fb",
   "metadata": {},
   "source": [
    "# 4. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80134646-42d6-4666-a33f-b598c14d6f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the replicate with the largest distance, per method.\n",
    "icfm_maxid = dAll_icfm.argmax()\n",
    "gp_icfm_maxid = dAll_GP0.argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07353bda-07a4-4a67-95a2-3f1516dee753",
   "metadata": {},
   "outputs": [],
   "source": [
    "sinit = 0       # seed for the start points of the plotted flow\n",
    "n_start = 1000  # number of start points\n",
    "\n",
    "# Reload the I-CFM replicate with the largest distance and generate a 50-step flow.\n",
    "model_icfm = MLP(dim = dim, out_dim = dim, time_varying=True)\n",
    "model_icfm.load_state_dict(torch.load(rep_saveFolder + \"model_icfm_\" + str(icfm_maxid) + \".pt\"))\n",
    "traj_icfm = gen_traj(model_icfm, n_start, 50, sinit)\n",
    "\n",
    "# Pass the trajectory's real number of time points (nt_gen is 100, the\n",
    "# trajectory has 50), so the midpoint index stays valid if mid_pts is enabled.\n",
    "plot_traj(traj_icfm, traj_icfm.shape[0], mid_pts = False, end_color = \"orange\")\n",
    "plt.title(\"I-CFM: replicate with the largest distance\")\n",
    "# Training targets in red for comparison with the generated end points.\n",
    "plt.scatter(x1[:,0], x1[:,1], s= 4, c = \"red\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e27045-5a28-4567-a9fc-fc68aacb23d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the GP-I-CFM replicate with the largest distance and generate a 50-step flow.\n",
    "model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "model_GP0.load_state_dict(torch.load(rep_saveFolder + \"model_gp_icfm_\" + str(gp_icfm_maxid) + \".pt\"))\n",
    "traj_GP0 = gen_traj(model_GP0, n_start, 50, sinit)\n",
    "\n",
    "# Pass the trajectory's real number of time points (nt_gen is 100, the\n",
    "# trajectory has 50), so the midpoint index stays valid if mid_pts is enabled.\n",
    "plot_traj(traj_GP0, traj_GP0.shape[0], mid_pts = False, end_color = \"orange\")\n",
    "plt.title(\"GP-I-CFM: replicate with the largest distance\")\n",
    "# Training targets in red for comparison with the generated end points.\n",
    "plt.scatter(x1[:,0], x1[:,1], s= 4, c = \"red\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
