{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d7ac179-e95b-48de-afe7-24c064777416",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "from copy import deepcopy\n",
    "from typing import List\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot\n",
    "import ot.plot\n",
    "import torch\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from torchdyn.core import NeuralODE\n",
    "from tqdm import tqdm\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "# import warnings\n",
    "# warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f71fb95-fa6d-4c80-81df-6220516382ff",
   "metadata": {},
   "source": [
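    "# 0. Data Generation\n",
    "\n",
    "Source $q(x_0)$: standard 2-D Gaussian. Target $q(x_1)$: an equal-weight mixture of two Gaussians with covariance $0.05 I$, centred at $(-1.5, 5)$ and $(1.5, 5)$. $N = 100$ samples of each are drawn after `np.random.seed(0)`."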
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b787f2-2788-4c45-96ad-6f8a04a77bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 100\n",
    "yend = 5\n",
    "p = .5\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# start q(x0): standard bivariate normal\n",
    "x0 = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], N)\n",
    "\n",
    "# end q(x1): two tight Gaussians at (-1.5, yend) and (1.5, yend);\n",
    "# a sample comes from the left component with probability p\n",
    "cov_end = [[.05, 0], [0, .05]]\n",
    "z_id1 = np.random.binomial(1, p, N)[:,None]\n",
    "x1_left = np.random.multivariate_normal([-1.5, yend], cov_end, N)\n",
    "x1_right = np.random.multivariate_normal([1.5, yend], cov_end, N)\n",
    "x1 = z_id1*x1_left + (1-z_id1)*x1_right\n",
    "\n",
    "x0 = torch.from_numpy(x0).float()\n",
    "x1 = torch.from_numpy(x1).float()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97ef4631-254f-4bad-93aa-d075f66d386b",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "777814fd-96fa-465d-ad8a-bd6c5674e97e",
   "metadata": {},
   "source": [
    "## 1.1 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36155921-c0f4-4c00-a4c1-3f2f95e826e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Feed-forward network: three hidden layers of width ``w`` with SELU activations.\n",
    "\n",
    "    With ``time_varying=True`` the input width is ``dim + 1`` (one extra column, which\n",
    "    the callers in this notebook use for time); the output width is ``out_dim``,\n",
    "    defaulting to ``dim``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        in_features = dim + 1 if time_varying else dim\n",
    "        out_features = dim if out_dim is None else out_dim\n",
    "\n",
    "        # Linear -> SELU blocks: in_features -> w -> w -> w, then a linear read-out\n",
    "        widths = [in_features, w, w, w]\n",
    "        layers = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            layers.append(torch.nn.Linear(fan_in, fan_out))\n",
    "            layers.append(torch.nn.SELU())\n",
    "        layers.append(torch.nn.Linear(w, out_features))\n",
    "        # keep the attribute name `net`: saved state_dicts use keys 'net.<idx>.*'\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the network to ``x`` (features on the last axis).\"\"\"\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8131b95-fbd1-4693-b967-470290d54afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Adapt a model that takes ``[x, t]`` rows to torchdyn's ``f(t, x)`` call signature.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        \"\"\"Append ``t`` as a last column to every row of ``x`` and evaluate the model.\n",
    "\n",
    "        Extra positional and keyword arguments are accepted and ignored.\n",
    "        \"\"\"\n",
    "        n_rows = x.shape[0]\n",
    "        time_column = t.repeat(n_rows).unsqueeze(1)\n",
    "        return self.model(torch.cat((x, time_column), dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e4f9f1c-302f-4e4b-bf86-ce152c0f1b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_traj(model, n_samp, nt_gen, seed, x_start = None, dim = None):\n",
    "    \"\"\"Integrate the learned vector field ``model`` from t = 0 to t = 1.\n",
    "\n",
    "    Args:\n",
    "        model: network evaluated on ``[x, t]`` rows (wrapped with ``torch_wrapper``).\n",
    "        n_samp: number of N(0, I) starting points drawn when ``x_start`` is None.\n",
    "        nt_gen: number of evenly spaced output times in [0, 1], both ends included.\n",
    "        seed: ``torch.manual_seed`` value applied before drawing the starting points\n",
    "            (only when ``x_start`` is None; this reseeds torch's global RNG).\n",
    "        x_start: optional tensor of starting points; ``n_samp`` and ``seed`` are then unused.\n",
    "        dim: state dimension of the drawn starting points. When None it is inferred\n",
    "            from the first ``torch.nn.Linear`` in ``model`` (its ``in_features`` minus the\n",
    "            time column appended by ``torch_wrapper``) instead of relying on a\n",
    "            notebook-global ``dim``.\n",
    "\n",
    "    Returns:\n",
    "        The solver trajectory; its first axis indexes the ``nt_gen`` output times.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dim`` is needed, not given, and ``model`` has no Linear layer.\n",
    "    \"\"\"\n",
    "    node = NeuralODE(torch_wrapper(model), solver=\"dopri5\",\n",
    "                     sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    if x_start is None:\n",
    "        if dim is None:\n",
    "            first_linear = next((mod for mod in model.modules()\n",
    "                                 if isinstance(mod, torch.nn.Linear)), None)\n",
    "            if first_linear is None:\n",
    "                raise ValueError(\"gen_traj: cannot infer dim from model; pass dim explicitly\")\n",
    "            dim = first_linear.in_features - 1\n",
    "        torch.manual_seed(seed)\n",
    "        x_start = torch.randn(n_samp, dim)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(x_start, t_span=torch.linspace(0, 1, nt_gen))\n",
    "\n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc97184-4f72-413f-bffc-1e1f0275a46d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_traj(traj, nt_gen, mid_pts = True, start_color = \"black\", end_color = \"orange\"):\n",
    "    \"\"\"Scatter a trajectory ``traj[time, sample, coord]`` on the current axes (coord 0 vs 1).\n",
    "\n",
    "    Draw order: start points, optionally the points at time index ``int(nt_gen/2)``\n",
    "    in red, every time step in faint blue, then the end points. The legend follows\n",
    "    that order.\n",
    "    \"\"\"\n",
    "    labels = [\"x0\"]\n",
    "    plt.scatter(traj[0, :, 0], traj[0, :, 1], s=4, alpha=1, c=start_color)\n",
    "\n",
    "    if mid_pts:\n",
    "        mid = int(nt_gen/2)\n",
    "        plt.scatter(traj[mid, :, 0], traj[mid, :, 1], s=4, alpha=1, c=\"red\")\n",
    "        labels.append(\"x_05\")\n",
    "\n",
    "    plt.scatter(traj[:, :, 0], traj[:, :, 1], s=1, alpha=0.1, c=\"blue\")\n",
    "    plt.scatter(traj[-1, :, 0], traj[-1, :, 1], s=4, alpha=1, c=end_color)\n",
    "    labels.extend([\"Flow\", \"x1\"])\n",
    "\n",
    "    plt.legend(labels)\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ff5ea1f-2bc2-48b0-9f51-d226d485f6e1",
   "metadata": {},
   "source": [
    "## 1.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2570e5f5-35ac-4acd-bf9e-49c0e33db7c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    \"\"\"Pairwise differences r[..., i, j] = ti[..., i] - tj[..., j], broadcast over leading axes.\n",
    "\n",
    "    Exact zeros are replaced by 1e-15.\n",
    "    \"\"\"\n",
    "    diff = ti.unsqueeze(-1) - tj.unsqueeze(-2)\n",
    "    return torch.where(diff == 0, torch.full_like(diff, 1e-15), diff)\n",
    "\n",
    "def _se_envelope(r, l):\n",
    "    \"\"\"Squared-exponential factor exp(-r^2 / (2 l^2)) shared by k11, k12 and k22.\"\"\"\n",
    "    return torch.exp(-0.5 * ((r/l)**2))\n",
    "\n",
    "def k11(r, alpha, l):\n",
    "    \"\"\"SE kernel alpha^2 exp(-r^2 / (2 l^2)) as a function of r = ti - tj.\"\"\"\n",
    "    return (alpha**2) * _se_envelope(r, l)\n",
    "\n",
    "def k12(r, alpha, l):\n",
    "    \"\"\"Derivative of k11 w.r.t. the second time argument tj: alpha^2 r / l^2 exp(...).\"\"\"\n",
    "    return (alpha**2/l**2) * r * _se_envelope(r, l)\n",
    "\n",
    "def k22(r, alpha, l):\n",
    "    \"\"\"Mixed second derivative d^2 k11 / (dti dtj): alpha^2 (l^2 - r^2) / l^4 exp(...).\"\"\"\n",
    "    return (alpha**2/l**4) * (l**2 - r**2) * _se_envelope(r, l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea3c227-f949-4499-8b1b-293146756ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, beta = 1e-3, decrease = None):\n",
    "    \"\"\"Joint prior covariance of a GP path x(t) and its time derivative x'(t).\n",
    "\n",
    "    Returns the (nB, 2*nt, 2*nt) block matrix\n",
    "        [[Cov(x(ti), x(tj)),   Cov(x(ti), x'(tj))],\n",
    "         [Cov(x'(ti), x(tj)),  Cov(x'(ti), x'(tj))]]\n",
    "    for a squared-exponential kernel (amplitude ``alpha``, length-scale ``l``) plus\n",
    "    an extra term selected by ``decrease``:\n",
    "      * None  -> ``beta * I`` added to the x-x block only;\n",
    "      * True  -> linear kernel beta*(t-1)*(s-1), whose variance is 0 at t = 1;\n",
    "      * False -> linear kernel beta*t*s, whose variance is 0 at t = 0.\n",
    "    For a linear kernel beta*g(t)*g(s) with g(t) = t - c the derivative terms are\n",
    "    d/ds = beta*g(t) (x-x' block) and d2/(dt ds) = beta (x'-x' block).\n",
    "\n",
    "    NOTE(review): the earlier version added beta*g(ti) a second time on the diagonal\n",
    "    of the x-x' block and left beta out of the x'-x' block; for alpha -> 0 that joint\n",
    "    matrix is indefinite.\n",
    "\n",
    "    Args:\n",
    "        ti, tj: (nB, nt) time matrices; ``samp_x_dx2`` passes the same tensor for both.\n",
    "        alpha, l: SE kernel amplitude and length-scale.\n",
    "        beta: weight of the extra term.\n",
    "        decrease: None, True or False, see above.\n",
    "    \"\"\"\n",
    "    r = calc_r(ti, tj)\n",
    "    nt = r.shape[1]\n",
    "\n",
    "    Sig11 = k11(r, alpha, l)\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "\n",
    "    if decrease is None:\n",
    "        # independent noise on x only (contributes nothing to the derivative blocks)\n",
    "        Sig11 = Sig11 + beta * torch.eye(nt)\n",
    "    else:\n",
    "        gi = ti - 1 if decrease else ti\n",
    "        gj = tj - 1 if decrease else tj\n",
    "        Sig11 = Sig11 + beta * torch.bmm(gi.unsqueeze(2), gj.unsqueeze(1))\n",
    "        # Cov(x(ti_a), x'(tj_b)) = beta*g(ti_a), constant along b\n",
    "        Sig12 = Sig12 + beta * gi.unsqueeze(2)\n",
    "        # Cov(x'(ti_a), x'(tj_b)) = beta\n",
    "        Sig22 = Sig22 + beta\n",
    "\n",
    "    Sig21 = Sig12.permute(0, 2, 1)\n",
    "    Sig = torch.cat([torch.cat([Sig11, Sig12], dim=2),\n",
    "                     torch.cat([Sig21, Sig22], dim=2)], dim=1)\n",
    "    # average with the transpose to remove round-off asymmetry\n",
    "    return (Sig + Sig.permute(0, 2, 1)) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0fc5e15-82e0-4d7e-958b-d1d1f69475a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _repair_psd(Sig, jitter=1e-8):\n",
    "    \"\"\"Make each symmetric matrix in the batch ``Sig`` positive semi-definite.\n",
    "\n",
    "    Batch members with a negative eigenvalue are replaced by V diag(s + jitter) V^T from\n",
    "    their SVD; for a symmetric matrix the singular values s are the absolute eigenvalues,\n",
    "    so negative eigenvalues are reflected. Other members are returned unchanged and the\n",
    "    input tensor is not modified.\n",
    "    \"\"\"\n",
    "    bad = (torch.linalg.eigvalsh(Sig) < 0).any(dim=-1)\n",
    "    if not bool(bad.any()):\n",
    "        return Sig\n",
    "    _, S, Vh = torch.linalg.svd(Sig[bad])\n",
    "    fixed = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + jitter)), Vh)\n",
    "    Sig = Sig.clone()\n",
    "    Sig[bad] = (fixed + fixed.permute(0, 2, 1)) / 2\n",
    "    return Sig\n",
    "\n",
    "\n",
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, beta = 1e-3, decrease = None):\n",
    "    \"\"\"Sample GP paths and their time derivatives conditioned on observed points.\n",
    "\n",
    "    Every coordinate ``dd`` of ``x_obs`` is conditioned and sampled separately with the\n",
    "    same covariance: the joint prior of (x(t_mat), x'(t_mat)) from ``cov_mat2`` is\n",
    "    conditioned on x(t_obs) = x_obs[:, :, dd].\n",
    "\n",
    "    Args:\n",
    "        t_mat: (nB, nt) sampling times, one row per batch element.\n",
    "        alpha, l: squared-exponential kernel amplitude and length-scale.\n",
    "        x_obs: (nB, nt_obs, dim) observed values at ``t_obs``.\n",
    "        t_obs: (nt_obs,) observation times shared by the batch.\n",
    "        beta, decrease: weight and kind of the extra prior term (None: noise on x;\n",
    "            True: linear term in t - 1; False: linear term in t).\n",
    "\n",
    "    Returns:\n",
    "        (x_samps, dx_samps), each a float32 tensor of shape (nB, nt, dim).\n",
    "    \"\"\"\n",
    "    nB = x_obs.shape[0]\n",
    "    dim = x_obs.shape[2]\n",
    "    nt = t_mat.shape[1]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    r_obs_x = calc_r(t_obs, t_mat)      # (nB, nt_obs, nt)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)    # (nt_obs, nt_obs)\n",
    "\n",
    "    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, beta, decrease)\n",
    "    k_obs_x = k11(r_obs_x, alpha, l)\n",
    "    k_obs_dx = k12(r_obs_x, alpha, l)\n",
    "    Sig_22 = k11(r_obs_obs, alpha, l)\n",
    "    if decrease is None:\n",
    "        Sig_22 = Sig_22 + torch.eye(nt_obs)*beta\n",
    "    else:\n",
    "        g_obs = t_obs - 1 if decrease else t_obs\n",
    "        g_mat = t_mat - 1 if decrease else t_mat\n",
    "        k_obs_x = k_obs_x + beta*g_obs[None, :, None]*g_mat[:, None, :]\n",
    "        k_obs_dx = k_obs_dx + beta*g_obs[None, :, None]\n",
    "        Sig_22 = Sig_22 + beta*torch.outer(g_obs, g_obs)\n",
    "\n",
    "    # cross-covariance of the observations with (x, x') at t_mat\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)   # (nB, nt_obs, 2*nt)\n",
    "    Sig_12 = Sig_21.permute(0, 2, 1)\n",
    "    Sig_22_inv = torch.linalg.inv(Sig_22)             # shared by every batch element\n",
    "\n",
    "    # Gaussian conditioning: mean = A x_obs with A = Sig_12 Sig_22^-1, cov = Sig_11 - A Sig_21\n",
    "    mu_A = torch.matmul(Sig_12, Sig_22_inv)\n",
    "    Sig_cond = Sig_11 - torch.bmm(mu_A, Sig_21)\n",
    "    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2\n",
    "    Sig_cond = _repair_psd(Sig_cond)\n",
    "\n",
    "    x_samps = torch.zeros((nB, nt, dim))\n",
    "    dx_samps = torch.zeros((nB, nt, dim))\n",
    "    for dd in range(dim):\n",
    "        obs_d = x_obs[:, :, dd].reshape(nB, nt_obs, 1)\n",
    "        mu_new = torch.bmm(mu_A, obs_d).reshape(nB, 2*nt)\n",
    "        try:\n",
    "            x_dx = MultivariateNormal(loc=mu_new, covariance_matrix=Sig_cond).rsample()\n",
    "        except (ValueError, RuntimeError):\n",
    "            # torch rejected the covariance (argument validation or Cholesky failure);\n",
    "            # fall back to numpy's sampler, one batch element at a time\n",
    "            draws = np.zeros((nB, 2*nt))\n",
    "            for bb in range(nB):\n",
    "                draws[bb, :] = np.random.multivariate_normal(mu_new[bb].detach().numpy(),\n",
    "                                                             Sig_cond[bb].detach().numpy())\n",
    "            x_dx = torch.from_numpy(draws)\n",
    "\n",
    "        x_samps[:, :, dd] = x_dx[:, :nt]\n",
    "        dx_samps[:, :, dd] = x_dx[:, nt:]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d735ba-b612-4156-962b-2c5d53ae9d8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_FM2(x_data, model, optimizer, alpha, l,\n",
    "          nt, batch_size, t_obs, n_epochs, beta = 0, decrease = None,\n",
    "          ImpSamp = False, beta_a = 1.0, beta_b = 0.5, storeCheck = False, epoch_check_step = 100):\n",
    "    \"\"\"Train ``model`` by regressing it onto derivatives of conditional GP paths.\n",
    "\n",
    "    For every batch the first observation of each path is replaced by fresh N(0, I)\n",
    "    noise (on a copy; ``x_data`` is not modified), ``samp_x_dx2`` draws path values\n",
    "    ``xt`` and derivatives ``ut`` at ``nt`` random times per path, and the loss is the\n",
    "    mean squared error between ``model([xt, t])`` and ``ut``.\n",
    "\n",
    "    Args:\n",
    "        x_data: (N, nt_obs, dim) observations at times ``t_obs``.\n",
    "        model, optimizer: network taking ``[x, t]`` rows and its optimizer.\n",
    "        alpha, l, beta, decrease: GP prior settings forwarded to ``samp_x_dx2``.\n",
    "        nt: sampled times per path.\n",
    "        batch_size: paths per step; batches are contiguous and unshuffled, and the\n",
    "            last N % batch_size paths are dropped.\n",
    "        t_obs: (nt_obs,) observation times.\n",
    "        n_epochs: number of passes over the batches.\n",
    "        ImpSamp: if True, draw times from Beta(beta_a, beta_b) and weight each squared\n",
    "            error by 1 / density(t).\n",
    "        storeCheck: if True, also keep a deep copy of ``model.state_dict()`` at the end\n",
    "            of every ``epoch_check_step``-th epoch (epochs 0, step, 2*step, ...).\n",
    "\n",
    "    Returns:\n",
    "        ``(model, losses)``, or ``(model, losses, check_pts, check_steps)`` when\n",
    "        ``storeCheck`` is True. ``losses`` holds one value per completed step; a step\n",
    "        whose GP sampling raises is reported and skipped.\n",
    "    \"\"\"\n",
    "    N = x_data.shape[0]\n",
    "    dim = x_data.shape[2]\n",
    "\n",
    "    if ImpSamp:\n",
    "        m = torch.distributions.beta.Beta(torch.tensor([beta_a]), torch.tensor([beta_b])) # put more weight on t = 1\n",
    "\n",
    "    nbatch = N // batch_size\n",
    "    # only the first nbatch*batch_size indices, so the reshape is valid for any N\n",
    "    batch_idx = np.arange(nbatch * batch_size).reshape(nbatch, batch_size)\n",
    "\n",
    "    losses: List[float] = []\n",
    "    check_pts = []\n",
    "    check_steps = []\n",
    "\n",
    "    model.train()\n",
    "    for k in tqdm(range(n_epochs)):\n",
    "        for bb in range(nbatch):\n",
    "            x0 = torch.randn((batch_size, dim))\n",
    "            x_obs = x_data[batch_idx[bb, :], :, :]   # advanced indexing returns a copy\n",
    "            x_obs[:, 0, :] = x0\n",
    "\n",
    "            if ImpSamp:\n",
    "                t_mat = m.sample((batch_size, nt))[:, :, 0]\n",
    "            else:\n",
    "                t_mat = torch.rand((batch_size, nt))\n",
    "\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, beta, decrease)\n",
    "            except Exception as err:\n",
    "                # skip the step instead of training on samples left over from another batch\n",
    "                tqdm.write('GP_FM2: epoch {}, batch {}: GP sampling failed ({!r}); step skipped'.format(k, bb, err))\n",
    "                continue\n",
    "\n",
    "            t = torch.reshape(t_mat, (-1, 1))\n",
    "            xt = torch.reshape(xt_batch, (-1, dim))\n",
    "            ut = torch.reshape(ut_batch, (-1, dim))\n",
    "\n",
    "            vt = model(torch.cat([xt, t], dim=-1))\n",
    "            sq_err = (vt - ut) ** 2                      # (batch_size*nt, dim)\n",
    "            if ImpSamp:\n",
    "                # per-row weight 1/q(t) of shape (batch_size*nt, 1), broadcast over the dim\n",
    "                # columns; the old [:, None] indexing paired every weight with every row\n",
    "                weights = torch.exp(-m.log_prob(t))\n",
    "                loss = torch.mean(weights * sq_err)\n",
    "            else:\n",
    "                loss = torch.mean(sq_err)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Logging\n",
    "            losses.append(loss.item())\n",
    "\n",
    "        if storeCheck and k % epoch_check_step == 0:\n",
    "            check_pts.append(deepcopy(model.state_dict()))\n",
    "            check_steps.append(k)\n",
    "\n",
    "    if storeCheck:\n",
    "        return model, losses, check_pts, check_steps\n",
    "    return model, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb871d7f-5c49-41ff-bf06-b4d656efdeeb",
   "metadata": {},
   "source": [
    "## 1.3 W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d345ba7-f18f-41ac-b548-da81d25ee8ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def w_mat_dist(x1_test, x1_gen, p = 2, ot_mat = False):\n",
    "    \"\"\"Exact optimal-transport cost between two uniformly weighted point clouds.\n",
    "\n",
    "    The ground cost is ||x - y||^p, so ``d`` is W_p^p -- for p = 2 the *squared*\n",
    "    2-Wasserstein distance (take ``d ** 0.5`` for W2 itself).\n",
    "\n",
    "    Args:\n",
    "        x1_test: (n_test, dim) array of reference samples.\n",
    "        x1_gen: (n_gen, dim) array of generated samples.\n",
    "        p: order of the cost, any p >= 1 (p = 1 and p = 2 behave as before).\n",
    "        ot_mat: if True also compute and return the optimal coupling.\n",
    "\n",
    "    Returns:\n",
    "        (G0, d): the (n_test, n_gen) coupling from ``ot.emd`` (None unless ``ot_mat``)\n",
    "        and the cost from ``ot.emd2``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if p < 1.\n",
    "    \"\"\"\n",
    "    if p < 1:\n",
    "        raise ValueError(\"w_mat_dist: p must be >= 1, got {}\".format(p))\n",
    "\n",
    "    n_test = x1_test.shape[0]\n",
    "    n_gen = x1_gen.shape[0]\n",
    "    a = np.ones((n_test,)) / n_test  # uniform weights on the reference samples\n",
    "    b = np.ones((n_gen,)) / n_gen    # uniform weights on the generated samples\n",
    "\n",
    "    if p == 2:\n",
    "        M = ot.dist(x1_test, x1_gen)  # POT's default metric is squared Euclidean\n",
    "    else:\n",
    "        M = ot.dist(x1_test, x1_gen, metric='euclidean')\n",
    "        if p != 1:\n",
    "            M = M ** p\n",
    "\n",
    "    G0 = ot.emd(a, b, M) if ot_mat else None\n",
    "    d = ot.emd2(a, b, M)\n",
    "\n",
    "    return G0, d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8f84e40-9a33-4be8-8e48-b706d1ff988b",
   "metadata": {},
   "source": [
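    "# 2. Example Paths\n",
    "\n",
    "Conditional GP sample paths between an observed start value ($t=0$) and end value ($t=1$), one panel for each `(beta, decrease)` prior setting used in training: `(0, None)`, `(1e-2, None)`, `(1e-2, True)` and `(1e-2, False)`."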
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1301ae-89ad-4a3b-9b5b-f7781bd817e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 50 conditional GP paths (coordinate 0) for the first (x0, x1) pair -- batch element 0\n",
    "# of a 2-element batch -- one panel per prior variant.\n",
    "torch.manual_seed(0)   # makes t_mat and the sampled paths reproducible\n",
    "t_mat = torch.rand((2,100))\n",
    "t_obs = torch.tensor([0., 1.])\n",
    "x_obs = torch.zeros((2, 2, 1))\n",
    "x_obs[:, 0, 0] = x0[:2, 0]   # observed start values\n",
    "x_obs[:, 1, 0] = x1[:2, 0]   # observed end values\n",
    "\n",
    "alpha = 1\n",
    "l = 1\n",
    "beta_all = [0, 1e-2, 1e-2, 1e-2]\n",
    "decrease_all = [None, None, True, False]\n",
    "fig, axs = plt.subplots(1, 4, figsize=(10, 3))\n",
    "for ax, beta_v, decrease_v in zip(axs.flatten(), beta_all, decrease_all):\n",
    "    for _ in range(50):\n",
    "        x_samp, dx_samp = samp_x_dx2(t_mat, alpha, l, x_obs,\n",
    "                                     t_obs, beta = beta_v, decrease = decrease_v)\n",
    "        ax.scatter(t_mat[0,:], x_samp[0,:,:], s = 2)\n",
    "    ax.set_title('beta={}, decrease={}'.format(beta_v, decrease_v), fontsize=9)\n",
    "    ax.set_xlabel('t')\n",
    "axs[0].set_ylabel('x (coordinate 0)')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e9f5b6-6d9c-4bb9-a4d1-6c31309cf6fd",
   "metadata": {},
   "source": [
    "# 3. Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e00fc753-d73b-4b00-b215-8c638dafde9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sizes (sigma, n_samp and nt_gen are not used by the cells below)\n",
    "dim = x1.shape[1]\n",
    "sigma = 1e-2\n",
    "n_samp = 1000\n",
    "nt_gen = 100\n",
    "\n",
    "# training data (N, 2, dim): slot 0 is a zero placeholder that GP_FM2 refills with\n",
    "# N(0, I) noise every batch, slot 1 holds the target samples x1\n",
    "x_data = torch.stack([torch.zeros_like(x1), x1], dim=1)\n",
    "\n",
    "# GP prior / observation setup\n",
    "alpha = 1\n",
    "l = 1\n",
    "t_obs = torch.tensor([0., 1.])\n",
    "\n",
    "# optimisation\n",
    "nt = 10\n",
    "batch_size = 100\n",
    "n_epochs = 5000\n",
    "lr_GP = 2e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2548d88a-7718-45c8-876c-ac03c5aceaa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_gp_variant(beta, decrease):\n",
    "    \"\"\"Train a fresh time-varying MLP with GP_FM2 using the settings from the config cell.\n",
    "\n",
    "    Returns (model, optimizer, losses).\n",
    "    \"\"\"\n",
    "    net = MLP(dim=dim, time_varying=True)\n",
    "    opt = torch.optim.Adam(net.parameters(), lr=lr_GP)\n",
    "    net, losses = GP_FM2(x_data, net, opt, alpha, l, nt, batch_size, t_obs, n_epochs,\n",
    "                         beta = beta, decrease = decrease)\n",
    "    return net, opt, losses\n",
    "\n",
    "\n",
    "# one fit per (beta, decrease) prior variant, in a fixed order\n",
    "model_GP0, optimizer_GP0, losses_GP0 = fit_gp_variant(beta = 0, decrease = None)\n",
    "model_GP_noise, optimizer_GP_noise, losses_GP_noise = fit_gp_variant(beta = 1e-2, decrease = None)\n",
    "model_GP_dec, optimizer_GP_dec, losses_GP_dec = fit_gp_variant(beta = 1e-2, decrease = True)\n",
    "model_GP_inc, optimizer_GP_inc, losses_GP_inc = fit_gp_variant(beta = 1e-2, decrease = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "733ff4d3-fdd6-4504-8193-666ff2fe4a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint locations; set the FM_VAR_CHANGE_DIR environment variable to relocate them.\n",
    "# Both folders end with a path separator because later cells append file names directly.\n",
    "saveFolder = os.path.join(os.environ.get(\"FM_VAR_CHANGE_DIR\", \"/hpc/group/mastatlab/gw74/fm_var_change/\"), \"\")\n",
    "rep_saveFolder = os.path.join(saveFolder, \"100_seeds\", \"\")\n",
    "\n",
    "# The repeat cell below warm-starts from saveFolder + 'model_GP*.pt', i.e. from whatever\n",
    "# is on disk, not necessarily from the models trained above. Set to True to (over)write\n",
    "# those checkpoints with the models from the previous cell.\n",
    "save_pretrained = False\n",
    "if save_pretrained:\n",
    "    os.makedirs(saveFolder, exist_ok=True)\n",
    "    for ckpt_name, trained in [(\"model_GP0\", model_GP0), (\"model_GP_noise\", model_GP_noise),\n",
    "                               (\"model_GP_dec\", model_GP_dec), (\"model_GP_inc\", model_GP_inc)]:\n",
    "        torch.save(trained.state_dict(), saveFolder + ckpt_name + \".pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a30fc0-c760-4e3d-b3f2-eebc0acb6bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# Re-train every variant nSeeds times, each repeat warm-started from the checkpoint in\n",
    "# saveFolder and seeded with its index ll so each rep_saveFolder file can be regenerated.\n",
    "nSeeds = 100\n",
    "n_epochs_rep = 10000   # own name, so the n_epochs of the config cell is not overwritten\n",
    "\n",
    "# (checkpoint suffix, beta, decrease) for each variant\n",
    "gp_variants = [(\"GP0\", 0, None), (\"GP_noise\", 1e-2, None),\n",
    "               (\"GP_dec\", 1e-2, True), (\"GP_inc\", 1e-2, False)]\n",
    "\n",
    "os.makedirs(rep_saveFolder, exist_ok=True)\n",
    "for ll in range(0, nSeeds):\n",
    "    torch.manual_seed(ll)\n",
    "    np.random.seed(ll)   # numpy is used by the fallback sampler in samp_x_dx2\n",
    "    for variant, beta_v, decrease_v in gp_variants:\n",
    "        net = MLP(dim=dim, time_varying=True)\n",
    "        opt = torch.optim.Adam(net.parameters(), lr=lr_GP)\n",
    "        net.load_state_dict(torch.load(saveFolder + \"model_\" + variant + \".pt\"))\n",
    "        net, _ = GP_FM2(x_data, net, opt, alpha, l, nt, batch_size, t_obs, n_epochs_rep,\n",
    "                        beta = beta_v, decrease = decrease_v)\n",
    "        torch.save(net.state_dict(), rep_saveFolder + \"model_\" + variant + \"_\" + str(ll) + \".pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a13e417-f09e-4cbd-a3f1-d49eb8793aaa",
   "metadata": {},
   "source": [
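    "# 4. W2\n",
    "\n",
    "For each of the `nSeeds` repeats, 1000 fresh $N(0, I)$ points are pushed through each learned flow from $t=0$ to $t=1$. The exact optimal-transport cost to 1000 held-out target samples is then computed with a squared-Euclidean ground cost, i.e. $W_2^2$. The mean and standard deviation over repeats are reported."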
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a88f6af-c3e0-40a7-a099-6e4b9a28dad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# held-out target samples from the same two-component mixture as x1\n",
    "N_test = 1000\n",
    "np.random.seed(1)\n",
    "cov_end_test = [[.05, 0], [0, .05]]\n",
    "z_id1_test = np.random.binomial(1, p, N_test)[:,None]\n",
    "x1_test_left = np.random.multivariate_normal([-1.5, yend], cov_end_test, N_test)\n",
    "x1_test_right = np.random.multivariate_normal([1.5, yend], cov_end_test, N_test)\n",
    "x1_test = torch.from_numpy(z_id1_test*x1_test_left + (1-z_id1_test)*x1_test_right).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccbfd6da-a708-40e9-87f1-df50a873afdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT cost (squared-Euclidean ground cost, see w_mat_dist) between the held-out targets\n",
    "# and each repeat's t = 1 samples; for a given ss every variant starts from the same\n",
    "# N(0, I) points because gen_traj reseeds torch with ss.\n",
    "variant_names = [\"GP0\", \"GP_noise\", \"GP_dec\", \"GP_inc\"]\n",
    "dAll = {name: np.zeros((nSeeds)) for name in variant_names}\n",
    "x1_test_np = x1_test.numpy()\n",
    "\n",
    "for ss in range(nSeeds):\n",
    "    for name in variant_names:\n",
    "        net = MLP(dim=dim, time_varying=True)\n",
    "        net.load_state_dict(torch.load(rep_saveFolder + \"model_\" + name + \"_\" + str(ss) + \".pt\"))\n",
    "        traj = gen_traj(net, x1_test.shape[0], 2, ss)\n",
    "        _, dAll[name][ss] = w_mat_dist(x1_test_np, traj[-1,:,:].numpy(), p = 2)\n",
    "\n",
    "dAll_GP0 = dAll[\"GP0\"]\n",
    "dAll_GP_noise = dAll[\"GP_noise\"]\n",
    "dAll_GP_dec = dAll[\"GP_dec\"]\n",
    "dAll_GP_inc = dAll[\"GP_inc\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27c3bcf1-15d5-4ca9-aac7-28b5664287e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary = [(\"0\", dAll_GP0), (\"noise\", dAll_GP_noise),\n",
    "           (\"decrease\", dAll_GP_dec), (\"increase\", dAll_GP_inc)]\n",
    "for label, dists in summary:\n",
    "    print('I-GP-CFM, {}: {:.3f} +- {:.3f}'.format(label, np.mean(dists), np.std(dists)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d65e1b1-cff9-4e22-9590-0d6f0ee7f996",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01988dfe-1719-4d40-b2f4-d5d8caacb2a4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
