{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7353b311-654d-4f17-828b-b3d80dfcbc06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "import ot\n",
    "import ot.plot\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9519a87-ba75-44ca-be7c-57041dcd3dbe",
   "metadata": {},
   "source": [
    "# 0. Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb866098-ea81-490f-ab03-fda2e4a809a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "yend = 5\n",
    "\n",
    "np.random.seed(0)\n",
    "# start q(x0)\n",
    "p = .5\n",
    "z_id0 = np.random.binomial(1, p, N)[:,None]\n",
    "x0 = z_id0*np.random.multivariate_normal([-1, 0], [[.05, 0], [0, .05]], N) +\\\n",
    "(1-z_id0)*np.random.multivariate_normal([1, 0], [[.05, 0], [0, .05]], N)\n",
    "\n",
    "# end q(x1)\n",
    "z_id1 = np.random.binomial(1, p, N)[:,None]\n",
    "x1 = z_id1*np.random.multivariate_normal([-2, yend], [[.05, 0], [0, .05]], N) +\\\n",
    "(1-z_id1)*np.random.multivariate_normal([2, yend], [[.05, 0], [0, .05]], N)\n",
    "\n",
    "x0 = torch.from_numpy(x0).to(torch.float32)\n",
    "x1 = torch.from_numpy(x1).to(torch.float32)\n",
    "\n",
    "# plt.rcParams['figure.figsize'] = [4, 3]\n",
    "# plt.scatter(x0[:,0], x0[:,1], s = 4, c = \"black\");\n",
    "# plt.scatter(x1[:,0], x1[:,1], s= 4, c = \"orange\");\n",
    "\n",
    "# plt.plot()\n",
    "# plt.xlabel(\"x\")\n",
    "# plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "209b4e93-1089-4436-aedf-2d3506589afe",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f7dd90f-7365-41a2-a253-5278674a2337",
   "metadata": {},
   "source": [
    "## 1.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d182eb-fd45-4c8c-9d55-1178b6e9b96f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        if out_dim is None:\n",
    "            out_dim = dim\n",
    "        self.net = torch.nn.Sequential(\n",
    "            torch.nn.Linear(dim + (1 if time_varying else 0), w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, w),\n",
    "            torch.nn.SELU(),\n",
    "            torch.nn.Linear(w, out_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1035d4-f8d2-4fb2-a005-cbbdb88180a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21c64c0e-f9d9-49ea-b202-029f15dc923a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None):\n",
    "    \n",
    "    node = NeuralODE(torch_wrapper(model), solver=\"dopri5\",\n",
    "                 sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    if x_start is None:\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))\n",
    "        \n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6769099f-684d-4e41-b5a1-f56475620235",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_traj(traj, nt_gen, mid_pts = True, start_color = \"black\", end_color = \"orange\"):\n",
    "    plt.scatter(traj[0, :, 0], traj[0, :, 1], s=4, alpha=1, c=start_color)\n",
    "    if mid_pts:\n",
    "        plt.scatter(traj[int(nt_gen/2), :, 0], traj[int(nt_gen/2), :, 1], s=4, alpha=1, c=\"red\")\n",
    "    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c=\"blue\")\n",
    "    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], s=4, alpha=1, c=end_color)\n",
    "    \n",
    "    if mid_pts:\n",
    "        plt.legend([\"x0\", \"x_05\", \"Flow\", \"x1\"])\n",
    "    else:\n",
    "        plt.legend([\"x0\", \"Flow\", \"x1\"])\n",
    "\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78d59f46-b0bb-4138-af6c-f6b1c16a6f00",
   "metadata": {},
   "source": [
    "## 1.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ffd3dd-48b4-4beb-add1-ca935021e80a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    r = ti[...,None] - tj[...,None,:]\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "def k11(r, alpha, l):\n",
    "    return (alpha**2)*torch.exp(-0.5 * ((r/l)**2))\n",
    "def k12(r, alpha, l):\n",
    "    return (alpha**2/l**2)*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    return (alpha**2/l**4)*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91369549-7ba6-4926-9aa3-e5e15108b858",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, beta = 1e-3, decrease = None):\n",
    "    \n",
    "    r = calc_r(ti, tj)\n",
    "    nB = r.shape[0]\n",
    "    nt = r.shape[1]\n",
    "    \n",
    "    if decrease is None:\n",
    "        Sig11 = k11(r, alpha, l) + (torch.eye(nt)*beta).repeat(nB,1,1)\n",
    "        Sig12 = k12(r, alpha, l)\n",
    "    elif decrease:\n",
    "        Sig11 = k11(r, alpha, l) + beta*torch.bmm((ti-1).unsqueeze(2), (tj-1).unsqueeze(1))\n",
    "        Sig12 = k12(r, alpha, l) + beta*(ti-1)[:,:,None].repeat(1,1,nt) + torch.diag_embed(beta*(ti-1))\n",
    "    else:\n",
    "        Sig11 = k11(r, alpha, l) + beta*torch.bmm(ti.unsqueeze(2), tj.unsqueeze(1))\n",
    "        Sig12 = k12(r, alpha, l) + beta*ti[:,:,None].repeat(1,1,nt) + torch.diag_embed(beta*ti)\n",
    "    \n",
    "    Sig21 = Sig12.permute(0, 2, 1)\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "    \n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=2)\n",
    "    block_row2 = torch.cat([Sig21, Sig22], dim=2)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim = 1)\n",
    "    Sig = (Sig + Sig.permute(0, 2, 1))/2\n",
    "    \n",
    "    return Sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae4a9cc8",
   "metadata": {
    "vscode": {
     "languageId": "julia"
    }
   },
   "outputs": [],
   "source": [
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    nB, nt, dim = x_obs.shape[0], t_mat.shape[1], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    # Compute necessary covariance matrices and kernel functions\n",
    "    r_obs_x = calc_r(t_obs, t_mat)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)\n",
    "    \n",
    "    # Precompute parts of the covariance matrices\n",
    "    k_obs_x, k_obs_dx = k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)\n",
    "    Sig_12 = Sig_21.permute(0, 2, 1)\n",
    "\n",
    "    Sig_22_sing = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs) * sig2_diag\n",
    "    Sig_22_inv_sing = torch.linalg.inv(Sig_22_sing)\n",
    "    Sig_22_inv = Sig_22_inv_sing.repeat(nB, 1, 1)\n",
    "\n",
    "    # Compute conditional covariance matrix with stability adjustment\n",
    "    Sig_cond = Sig_11 - torch.bmm(torch.bmm(Sig_12, Sig_22_inv), Sig_21)\n",
    "    # Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1)) / 2 + 1e-6 * torch.eye(Sig_cond.shape[1], device=Sig_cond.device)\n",
    "\n",
    "    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2\n",
    "    \n",
    "    svd_add_idx = torch.sum((torch.linalg.eigvals(Sig_cond).real>=0).T, axis = 0) != Sig_cond.shape[1]\n",
    "    U, S, Vh = torch.linalg.svd(Sig_cond[svd_add_idx,:,:])\n",
    "#     U, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "    Sig_cond_add = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + 1e-8)), Vh)\n",
    "    Sig_cond[svd_add_idx,:,:] = (Sig_cond_add + Sig_cond_add.permute(0, 2, 1))/2\n",
    "\n",
    "    # Mean adjustment matrix\n",
    "    mu_A = torch.bmm(Sig_12, Sig_22_inv)\n",
    "    x_obs_batch = x_obs.reshape(nB, nt_obs, dim)\n",
    "    mu_new = torch.bmm(mu_A, x_obs_batch).reshape(nB, 2 * nt, dim)\n",
    "\n",
    "    # Initialize sample matrices\n",
    "    x_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    dx_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    \n",
    "    mu_flat = mu_new.permute(0, 2, 1).reshape(nB * dim, 2 * nt)\n",
    "    Sig_cond_flat = Sig_cond.repeat_interleave(dim, dim=0)\n",
    "    \n",
    "    # Sampling in batch for all dimensions at once\n",
    "    try:\n",
    "        # Reshape mu_new and Sig_cond for compatible shapes\n",
    "#         mu_flat = mu_new.view(nB * dim, 2 * nt)\n",
    "#         Sig_cond_flat = Sig_cond.repeat(dim, 1, 1)\n",
    "        \n",
    "        dist = MultivariateNormal(loc=mu_flat, covariance_matrix=Sig_cond_flat)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, dim, 2 * nt).permute(0, 2, 1)\n",
    "    except RuntimeError:\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb, :, dd].cpu().numpy()\n",
    "                cov_single = Sig_cond[bb].cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps[:, :, :] = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps[:, :, :] = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e44746e-7b7e-4f62-b16b-c7d7d00f4dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FM2(x_data, model, optimizer, alpha, l,\n",
    "          nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None,\n",
    "          ImpSamp = False, beta_a = 1.0, beta_b = 0.5, storeCheck = False, epoch_check_step = 100):\n",
    "    \n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "    \n",
    "    if ImpSamp:\n",
    "        m = torch.distributions.beta.Beta(torch.tensor([beta_a]), torch.tensor([beta_b])) # put more weight on t = 1\n",
    "    \n",
    "    nbatch = int(N/batch_size)\n",
    "    batch_idx = np.reshape(np.arange(0,N),[nbatch, batch_size])\n",
    "    \n",
    "    losses: List[float] = []\n",
    "    if storeCheck:\n",
    "        check_pts = []\n",
    "        check_steps = []\n",
    "        \n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for bb in range(nbatch):\n",
    "#             x0 = torch.randn((batch_size,dim))\n",
    "            x_obs = x_data[batch_idx[bb,:],:,:]\n",
    "            x_obs[:,0,:] = x_obs[np.random.permutation(batch_size),0,:]\n",
    "            x_obs[:,1,:] = x_obs[np.random.permutation(batch_size),1,:]\n",
    "            \n",
    "#             x_obs[:,0,:] = x0\n",
    "            \n",
    "            if ImpSamp:\n",
    "                t_mat = m.sample((batch_size,nt))[:,:,0]\n",
    "            else:\n",
    "                t_mat = torch.rand((batch_size,nt))\n",
    "            \n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, beta, decrease)\n",
    "            except:\n",
    "                pass\n",
    "            \n",
    "            t = torch.reshape(t_mat, (-1, 1))\n",
    "            xt = torch.reshape(xt_batch, (-1,dim))\n",
    "            ut = torch.reshape(ut_batch, (-1,dim))\n",
    "            \n",
    "            vt = model(torch.cat([xt, t], dim=-1))\n",
    "            if ImpSamp:\n",
    "                loss = torch.mean((1/torch.exp(m.log_prob(t))[:,None])*((vt - ut) ** 2))\n",
    "            else:\n",
    "                loss = torch.mean((vt - ut) ** 2)\n",
    "            \n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "            \n",
    "            if storeCheck:\n",
    "                if k % epoch_check_step == 0:\n",
    "                    check_pts.append(deepcopy(model.state_dict()))\n",
    "                    check_steps.append(k)\n",
    "            \n",
    "    if storeCheck:       \n",
    "        return model, losses, check_pts, check_steps\n",
    "    else:\n",
    "        return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bdfbd9-74e6-4a24-9df4-d00ccfc693ef",
   "metadata": {},
   "source": [
    "## 1.3 W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36eaf197-ab2a-4c75-aeb5-bddbf82fa502",
   "metadata": {},
   "outputs": [],
   "source": [
    "def w_mat_dist(x1_test, x1_gen, p = 2, ot_mat = False):\n",
    "    n_test = x1_test.shape[0]\n",
    "    n_gen = x1_gen.shape[0]\n",
    "    \n",
    "    a, b = np.ones((n_test,)) / n_test, np.ones((n_gen,)) / n_gen  # uniform distribution on samples\n",
    "    if p == 1:\n",
    "        M = ot.dist(x1_test, x1_gen, metric='euclidean')\n",
    "    elif p == 2:\n",
    "        M = ot.dist(x1_test, x1_gen)\n",
    "    G0 = None\n",
    "    if ot_mat:\n",
    "        G0 = ot.emd(a, b, M)\n",
    "    \n",
    "    d = ot.emd2(a, b, M)\n",
    "    \n",
    "    return G0, d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d85b3d64-1c99-4e36-9d5e-918ae5ba2d76",
   "metadata": {},
   "source": [
    "# 2. Sample Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63f2b783-3dfd-40a6-b59a-166453ddea30",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_mat = torch.rand((2,100))\n",
    "t_obs = torch.tensor([0., 1.])\n",
    "x_obs = torch.zeros((2, 2, 1))\n",
    "x_obs[0,0,0] = x0[0,0]\n",
    "x_obs[0,1,0] = x1[0,0]\n",
    "x_obs[1,0,0] = x0[1,0]\n",
    "x_obs[1,1,0] = x1[1,0]\n",
    "\n",
    "alpha = 1\n",
    "l = 2\n",
    "plt.rcParams['figure.figsize'] = [10, 3]\n",
    "fig, axs = plt.subplots(1, 4)\n",
    "beta_all = [0, 1e-3, 1e-3, 1e-3]\n",
    "decrease_all = [None, None, True, False]\n",
    "for tt, ax in enumerate(axs.flatten()):\n",
    "    for ii in range(50):\n",
    "        x_samp, dx_samp =  samp_x_dx2(t_mat, alpha, l, x_obs,\n",
    "                                      t_obs, beta = beta_all[tt], decrease = decrease_all[tt])\n",
    "        ax.scatter(t_mat[0,:], x_samp[0,:,:], s = 2)\n",
    "#     plt.scatter(t_mat[0,:], dx_samp[0,:,:], s = 2)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "plt.rcParams['figure.figsize'] = [6, 4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e81cee98-b012-4e87-9d59-ed7a3d661e71",
   "metadata": {},
   "source": [
    "# 3. Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737aedda-ca94-4b40-aea8-801207474b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = x1.shape[1]\n",
    "# sigma = 1e-2\n",
    "n_samp = 1000\n",
    "nt_gen = 100\n",
    "\n",
    "x_data = torch.zeros(N, 2, dim)\n",
    "x_data[:,0,:] = x0\n",
    "x_data[:,1,:] = x1\n",
    "\n",
    "alpha = 1\n",
    "l = 2\n",
    "nt = 10\n",
    "batch_size = 100\n",
    "t_obs = torch.tensor([0., 1.])\n",
    "\n",
    "n_epochs = 5000\n",
    "lr_GP = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e77fdaf3-9ac8-4946-9129-4d7bf7fa8900",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "optimizer_GP0 = torch.optim.Adam(model_GP0.parameters(), lr=lr_GP)\n",
    "model_GP0, losses_GP0 = GP_FM2(x_data, model_GP0, optimizer_GP0, alpha,\n",
    "                               l,nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None)\n",
    "model_GP_noise = MLP(dim=dim, time_varying=True)\n",
    "optimizer_GP_noise = torch.optim.Adam(model_GP_noise.parameters(), lr=lr_GP)\n",
    "model_GP_noise, losses_GP_noise = GP_FM2(x_data, model_GP_noise, optimizer_GP_noise, alpha,\n",
    "                                         l,nt, batch_size, t_obs, n_epochs, beta = 1e-3, decrease = None)\n",
    "model_GP_dec = MLP(dim=dim, time_varying=True)\n",
    "optimizer_GP_dec = torch.optim.Adam(model_GP_dec.parameters(), lr=lr_GP)\n",
    "model_GP_dec, losses_GP_dec = GP_FM2(x_data, model_GP_dec, optimizer_GP_dec, alpha,\n",
    "                               l,nt, batch_size, t_obs, n_epochs, beta = 1e-3, decrease = True)\n",
    "model_GP_inc = MLP(dim=dim, time_varying=True)\n",
    "optimizer_GP_inc = torch.optim.Adam(model_GP_inc.parameters(), lr=lr_GP)\n",
    "model_GP_inc, losses_GP_inc = GP_FM2(x_data, model_GP_inc, optimizer_GP_inc, alpha,\n",
    "                               l,nt, batch_size, t_obs, n_epochs, beta = 1e-3, decrease = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6585ee6d-1d38-4be3-8b95-502f9fb59f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "saveFolder = \"/hpc/group/mastatlab/gw74/fm_var_change5/\"\n",
    "rep_saveFolder = \"/hpc/group/mastatlab/gw74/fm_var_change5/100_seeds/\"\n",
    "\n",
    "# torch.save(model_GP0.state_dict(), saveFolder + \"model_GP0.pt\")\n",
    "# torch.save(model_GP_noise.state_dict(), saveFolder + \"model_GP_noise.pt\")\n",
    "# torch.save(model_GP_dec.state_dict(), saveFolder + \"model_GP_dec.pt\")\n",
    "# torch.save(model_GP_inc.state_dict(), saveFolder + \"model_GP_inc.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f8b4f7-e1b8-43b3-93ef-506589c20c09",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "nSeeds = 100\n",
    "n_epochs = 5000\n",
    "\n",
    "for ll in range(0):\n",
    "    \n",
    "    model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "    optimizer_GP0 = torch.optim.Adam(model_GP0.parameters(), lr=lr_GP)\n",
    "    model_GP0.load_state_dict(torch.load(saveFolder + \"model_GP0.pt\"))\n",
    "    model_GP0, losses_GP0 = GP_FM2(x_data, model_GP0, optimizer_GP0, alpha,\n",
    "                               l,nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None)\n",
    "    torch.save(model_GP0.state_dict(), rep_saveFolder + \"model_GP0_\" + str(ll) + \".pt\")\n",
    "    \n",
    "    model_GP_noise = MLP(dim=dim, time_varying=True)\n",
    "    optimizer_GP_noise = torch.optim.Adam(model_GP_noise.parameters(), lr=lr_GP)\n",
    "    model_GP_noise.load_state_dict(torch.load(saveFolder + \"model_GP_noise.pt\"))\n",
    "    model_GP_noise, losses_GP_noise = GP_FM2(x_data, model_GP_noise, optimizer_GP_noise, alpha,\n",
    "                                         l,nt, batch_size, t_obs, n_epochs,\n",
    "                                             beta = 1e-3, decrease = None)\n",
    "    torch.save(model_GP_noise.state_dict(), rep_saveFolder + \"model_GP_noise_\" + str(ll) + \".pt\")\n",
    "    \n",
    "    model_GP_dec = MLP(dim=dim, time_varying=True)\n",
    "    optimizer_GP_dec = torch.optim.Adam(model_GP_dec.parameters(), lr=lr_GP)\n",
    "    model_GP_dec.load_state_dict(torch.load(saveFolder + \"model_GP_dec.pt\"))\n",
    "    model_GP_dec, losses_GP_dec = GP_FM2(x_data, model_GP_dec, optimizer_GP_dec, alpha,\n",
    "                                   l,nt, batch_size, t_obs, n_epochs, beta = 1e-3, decrease = True)\n",
    "    torch.save(model_GP_dec.state_dict(), rep_saveFolder + \"model_GP_dec_\" + str(ll) + \".pt\")\n",
    "    \n",
    "    model_GP_inc = MLP(dim=dim, time_varying=True)\n",
    "    optimizer_GP_inc = torch.optim.Adam(model_GP_inc.parameters(), lr=lr_GP)\n",
    "    model_GP_inc.load_state_dict(torch.load(saveFolder + \"model_GP_inc.pt\"))\n",
    "    model_GP_inc, losses_GP_inc = GP_FM2(x_data, model_GP_inc, optimizer_GP_inc, alpha,\n",
    "                                   l,nt, batch_size, t_obs, n_epochs, beta = 1e-3, decrease = False)\n",
    "    torch.save(model_GP_inc.state_dict(), rep_saveFolder + \"model_GP_inc_\" + str(ll) + \".pt\")\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e3a5da3-e7f6-4a99-8bad-cb708fa5251e",
   "metadata": {},
   "source": [
    "# 4. W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4067aec3-597b-4778-8995-732cc53fc567",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_test = 1000\n",
    "np.random.seed(0)\n",
    "z_id1_test = np.random.binomial(1, p, N_test)[:,None]\n",
    "x1_test = z_id1_test*np.random.multivariate_normal([-2, yend], [[.05, 0], [0, .05]], N_test) +\\\n",
    "(1-z_id1_test)*np.random.multivariate_normal([2, yend], [[.05, 0], [0, .05]], N_test)\n",
    "x1_test = torch.from_numpy(x1_test).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f198d50-e87e-4155-9f59-4d8a84457122",
   "metadata": {},
   "outputs": [],
   "source": [
    "dAll_GP0 = np.zeros((nSeeds))\n",
    "dAll_GP_noise = np.zeros((nSeeds))\n",
    "dAll_GP_dec = np.zeros((nSeeds))\n",
    "dAll_GP_inc = np.zeros((nSeeds))\n",
    "\n",
    "for ss in range(nSeeds):\n",
    "    \n",
    "    z_id0 = np.random.binomial(1, p, N_test)[:,None]\n",
    "    x0_tmp = z_id0*np.random.multivariate_normal([-1, 0], [[.05, 0], [0, .05]], N_test) +\\\n",
    "    (1-z_id0)*np.random.multivariate_normal([1, 0], [[.05, 0], [0, .05]], N_test)\n",
    "    x0_tmp = torch.from_numpy(x0_tmp).to(torch.float32)\n",
    "    \n",
    "    \n",
    "    model_GP0 = MLP(dim=dim, time_varying=True)\n",
    "    model_GP0.load_state_dict(torch.load(rep_saveFolder + \"model_GP0_\" + str(ss) + \".pt\"))\n",
    "    traj_GP0 = gen_traj(model_GP0, x1_test.shape[0], 2, ss, x_start = x0_tmp)\n",
    "    \n",
    "    model_GP_noise = MLP(dim=dim, time_varying=True)\n",
    "    model_GP_noise.load_state_dict(torch.load(rep_saveFolder + \"model_GP_noise_\" + str(ss) + \".pt\"))\n",
    "    traj_GP_noise = gen_traj(model_GP_noise, x1_test.shape[0], 2, ss, x_start = x0_tmp)\n",
    "    \n",
    "    model_GP_dec = MLP(dim=dim, time_varying=True)\n",
    "    model_GP_dec.load_state_dict(torch.load(rep_saveFolder + \"model_GP_dec_\" + str(ss) + \".pt\"))\n",
    "    traj_GP_dec = gen_traj(model_GP_dec, x1_test.shape[0], 2, ss, x_start = x0_tmp)\n",
    "    \n",
    "    model_GP_inc = MLP(dim=dim, time_varying=True)\n",
    "    model_GP_inc.load_state_dict(torch.load(rep_saveFolder + \"model_GP_inc_\" + str(ss) + \".pt\"))\n",
    "    traj_GP_inc = gen_traj(model_GP_inc, x1_test.shape[0], 2, ss, x_start = x0_tmp)\n",
    "    \n",
    "    _, dAll_GP0[ss] = w_mat_dist(x1_test.numpy(), traj_GP0[-1,:,:].numpy(), p = 2)\n",
    "    _, dAll_GP_noise[ss] = w_mat_dist(x1_test.numpy(), traj_GP_noise[-1,:,:].numpy(), p = 2)\n",
    "    _, dAll_GP_dec[ss] = w_mat_dist(x1_test.numpy(), traj_GP_dec[-1,:,:].numpy(), p = 2)\n",
    "    _, dAll_GP_inc[ss] = w_mat_dist(x1_test.numpy(), traj_GP_inc[-1,:,:].numpy(), p = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb820d79-d6b9-48ea-9624-21c5375b9b1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('I-GP-CFM, 0: {:.3f} +- {:.3f}'.format(np.mean(dAll_GP0), np.std(dAll_GP0)))\n",
    "print('I-GP-CFM, noise: {:.3f} +- {:.3f}'.format(np.mean(dAll_GP_noise), np.std(dAll_GP_noise)))\n",
    "print('I-GP-CFM, decrease: {:.3f} +- {:.3f}'.format(np.mean(dAll_GP_dec), np.std(dAll_GP_dec)))\n",
    "print('I-GP-CFM, increase: {:.3f} +- {:.3f}'.format(np.mean(dAll_GP_inc), np.std(dAll_GP_inc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07bbfbc-5147-433f-b24d-7943b29e0dff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2981886-05a4-4cc9-b899-cfc9ea9d0340",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
