{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8abb343a-59d2-462a-897c-56f5fd8f9278",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.transforms import ToPILImage\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.unet import UNetModel\n",
    "from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "savedir = \"models/mnist\"\n",
    "os.makedirs(savedir, exist_ok=True)\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "# from torchcfm.optimal_transport import OTPlanSampler\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "# import ot\n",
    "# import ot.plot\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import gc\n",
    "\n",
    "import metric.pytorch_ssim\n",
    "from metric.IS_score import *\n",
    "from metric.Fid_score import *\n",
    "from torchmetrics.image.kid import KernelInceptionDistance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b952d154-4910-44e9-b9cb-02c510712c84",
   "metadata": {},
   "source": [
    "# 0. Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccd56fb4-194a-4adb-9164-c17df31c8284",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "torch.set_default_device(device)\n",
    "# torch.get_default_device()\n",
    "\n",
    "batch_size = 128\n",
    "\n",
    "trainset = datasets.MNIST(\n",
    "    \"../data\",\n",
    "    train=True,\n",
    "    download=False,\n",
    "    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]),\n",
    ")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=batch_size, shuffle=True, drop_last=True,\n",
    "    generator=torch.Generator(device)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7c8035-7eed-4219-95da-75b9a3414b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "testset = datasets.MNIST(\n",
    "    \"../data\",\n",
    "    train=False,\n",
    "    download=False,\n",
    "    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f8faebe-7196-4b62-8125-af2fb64c4556",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c7f8918-5b2e-4d2e-9b2b-c968e37a5d99",
   "metadata": {},
   "source": [
    "Compare 4 models:\n",
    "1. Independent CFM (I-CFM)\n",
    "2. Optimal transport CFM (OT-CFM)\n",
    "3. GP independent CFM (GP-I-CFM)\n",
    "4. GP optimal transport CFM (GP-OT-CFM)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4454484-707f-4858-bfa9-2a6f46fdc889",
   "metadata": {},
   "source": [
    "## 1.0 Common Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35f82842-a345-49ba-b21f-617816410240",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nonGP_model_fit(model, optimizer, FM, n_epochs, train_loader, device, subset = False, iMax = None):\n",
    "    for epoch in tqdm(range(n_epochs)):\n",
    "        for i, data in enumerate(train_loader):\n",
    "            \n",
    "            if subset:\n",
    "                if i > iMax:\n",
    "                    break\n",
    "            \n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            x1 = data[0].to(device)\n",
    "            x0 = torch.randn_like(x1)\n",
    "            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "            vt = model(t, xt)\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3842ac5d-6960-49af-87af-9f3272e96813",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    r = ti[...,None] - tj[...,None,:]\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "def k11(r, alpha, l):\n",
    "    return (alpha**2)*torch.exp(-0.5 * ((r/l)**2))\n",
    "def k12(r, alpha, l):\n",
    "    return (alpha**2/l**2)*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    return (alpha**2/l**4)*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0311a04-6102-483c-bc12-3fb14f72d958",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \n",
    "    r = calc_r(ti, tj)\n",
    "    nB = r.shape[0]\n",
    "    nt = r.shape[1]\n",
    "    \n",
    "    Sig11 = k11(r, alpha, l) + (torch.eye(nt)*sig2_diag).repeat(nB,1,1)\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig21 = Sig12.permute(0, 2, 1)\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "    \n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=2)\n",
    "    block_row2 = torch.cat([Sig21, Sig22], dim=2)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim = 1)\n",
    "    Sig = (Sig + Sig.permute(0, 2, 1))/2\n",
    "    \n",
    "    return Sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78564f01-8b7e-48b9-b9f7-a5ab59d2682b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    nB, nt, dim = x_obs.shape[0], t_mat.shape[1], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    # Compute necessary covariance matrices and kernel functions\n",
    "    r_obs_x = calc_r(t_obs, t_mat)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)\n",
    "    \n",
    "    # Precompute parts of the covariance matrices\n",
    "    k_obs_x, k_obs_dx = k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)\n",
    "    Sig_12 = Sig_21.permute(0, 2, 1)\n",
    "\n",
    "    Sig_22_sing = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs) * sig2_diag\n",
    "    Sig_22_inv_sing = torch.linalg.inv(Sig_22_sing)\n",
    "    Sig_22_inv = Sig_22_inv_sing.repeat(nB, 1, 1)\n",
    "\n",
    "    # Compute conditional covariance matrix with stability adjustment\n",
    "    Sig_cond = Sig_11 - torch.bmm(torch.bmm(Sig_12, Sig_22_inv), Sig_21)\n",
    "    # Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1)) / 2 + 1e-6 * torch.eye(Sig_cond.shape[1], device=Sig_cond.device)\n",
    "\n",
    "    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2\n",
    "    \n",
    "    svd_add_idx = torch.sum((torch.linalg.eigvals(Sig_cond).real>=0).T, axis = 0) != Sig_cond.shape[1]\n",
    "    U, S, Vh = torch.linalg.svd(Sig_cond[svd_add_idx,:,:])\n",
    "#     U, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "    Sig_cond_add = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + 1e-8)), Vh)\n",
    "    Sig_cond[svd_add_idx,:,:] = (Sig_cond_add + Sig_cond_add.permute(0, 2, 1))/2\n",
    "\n",
    "    # Mean adjustment matrix\n",
    "    mu_A = torch.bmm(Sig_12, Sig_22_inv)\n",
    "    x_obs_batch = x_obs.reshape(nB, nt_obs, dim)\n",
    "    mu_new = torch.bmm(mu_A, x_obs_batch).reshape(nB, 2 * nt, dim)\n",
    "\n",
    "    # Initialize sample matrices\n",
    "    x_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    dx_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    \n",
    "    mu_flat = mu_new.permute(0, 2, 1).reshape(nB * dim, 2 * nt)\n",
    "    Sig_cond_flat = Sig_cond.repeat_interleave(dim, dim=0)\n",
    "    \n",
    "    # Sampling in batch for all dimensions at once\n",
    "    try:\n",
    "        # Reshape mu_new and Sig_cond for compatible shapes\n",
    "#         mu_flat = mu_new.view(nB * dim, 2 * nt)\n",
    "#         Sig_cond_flat = Sig_cond.repeat(dim, 1, 1)\n",
    "        \n",
    "        dist = MultivariateNormal(loc=mu_flat, covariance_matrix=Sig_cond_flat)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, dim, 2 * nt).permute(0, 2, 1)\n",
    "    except RuntimeError:\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb, :, dd].cpu().numpy()\n",
    "                cov_single = Sig_cond[bb].cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps[:, :, :] = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps[:, :, :] = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee4ece8b-670c-417f-ac9b-9a86d6db55cd",
   "metadata": {},
   "source": [
    "## 1.1 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50553d31-1d9e-42fe-b415-4a3846e60470",
   "metadata": {},
   "outputs": [],
   "source": [
    "def icfm_fit(model, optimizer, sigma, n_epochs, train_loader, device, subset = False, iMax = None):\n",
    "    FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "    model = nonGP_model_fit(model, optimizer, FM, n_epochs, train_loader, device, subset, iMax)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc084bdc-8c15-4856-a718-6428182e286b",
   "metadata": {},
   "source": [
    "## 1.2 OT-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "412f97e8-fe36-4266-b095-54eaa2c991d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def otcfm_fit(model, optimizer, sigma, n_epochs, train_loader, device, subset = False, iMax = None):\n",
    "    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)\n",
    "    model = nonGP_model_fit(model, optimizer, FM, n_epochs, train_loader, device, subset, iMax)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26d5f1af-e4dc-471c-845a-52cf4280545d",
   "metadata": {},
   "source": [
    "## 1.3 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b28d41ff-6044-41a2-a525-3fd94c66c1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gp_icfm_fit(model, optimizer, alpha, l, sig2_diag, n_epochs, train_loader, device,\n",
    "               subset = False, iMax = None):\n",
    "    \n",
    "    for epoch in tqdm(range(n_epochs)):\n",
    "        for i, data in enumerate(train_loader):\n",
    "            \n",
    "            if subset:\n",
    "                if i > iMax:\n",
    "                    break\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            x1 = data[0].to(device)\n",
    "            x0 = torch.randn_like(x1)\n",
    "\n",
    "            btch_size = x1.shape[0]\n",
    "            x01_trans = torch.zeros(btch_size, 2, 28*28)\n",
    "            x01_trans[:,0,:] = torch.reshape(x0, (btch_size, -1))\n",
    "            x01_trans[:,1,:] = torch.reshape(x1, (btch_size, -1))\n",
    "\n",
    "            t_mat = torch.rand((btch_size,1))\n",
    "            \n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x01_trans,\n",
    "                                                torch.tensor([0, 1]), sig2_diag)\n",
    "            except:\n",
    "                print('fail')\n",
    "                pass\n",
    "\n",
    "            t = torch.reshape(t_mat, (-1, ))\n",
    "            xt = torch.reshape(xt_batch, (btch_size, 1, 28, 28))\n",
    "            ut = torch.reshape(ut_batch, (btch_size, 1, 28, 28))\n",
    "            vt = model(t, xt)\n",
    "\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dd13d90-1848-40cc-bf61-5aa86d1ccb40",
   "metadata": {},
   "source": [
    "## 1.4 GP-OTCFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4154f48-2a91-4af3-8b36-c101ec8448d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gp_otcfm_fit(model, optimizer, alpha, l, sig2_diag, n_epochs, train_loader, device,\n",
    "                subset = False, iMax = None):\n",
    "    \n",
    "    ot_sampler = OTPlanSampler(method=\"exact\")\n",
    "    \n",
    "    for epoch in tqdm(range(n_epochs)):\n",
    "        for i, data in enumerate(train_loader):\n",
    "            \n",
    "            if subset:\n",
    "                if i > iMax:\n",
    "                    break\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            x1 = data[0].to(device)\n",
    "            x0 = torch.randn_like(x1)\n",
    "            \n",
    "            x0, x1 = ot_sampler.sample_plan(x0, x1)\n",
    "            \n",
    "            btch_size = x1.shape[0]\n",
    "            x01_trans = torch.zeros(btch_size, 2, 28*28)\n",
    "            x01_trans[:,0,:] = torch.reshape(x0, (btch_size, -1))\n",
    "            x01_trans[:,1,:] = torch.reshape(x1, (btch_size, -1))\n",
    "\n",
    "            t_mat = torch.rand((btch_size,1))\n",
    "            \n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x01_trans,\n",
    "                                                torch.tensor([0, 1]), sig2_diag)\n",
    "            except:\n",
    "                print('fail')\n",
    "                pass\n",
    "\n",
    "            t = torch.reshape(t_mat, (-1, ))\n",
    "            xt = torch.reshape(xt_batch, (btch_size, 1, 28, 28))\n",
    "            ut = torch.reshape(ut_batch, (btch_size, 1, 28, 28))\n",
    "            vt = model(t, xt)\n",
    "\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cc817a0-1d9c-4f74-a636-358f6e446d9a",
   "metadata": {},
   "source": [
    "# 3. Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db6a4637-27b3-41f0-a138-e6845b797ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = 0.0\n",
    "n_epochs = 5\n",
    "\n",
    "# truncate the samples\n",
    "subset = False\n",
    "iMax = None\n",
    "\n",
    "alpha = 1\n",
    "l = 5 # let's try l = 4 later...\n",
    "sig2_diag = 0\n",
    "\n",
    "nRep = 5\n",
    "rootFolder = \"/hpc/home/gw74/diff_model/FM/image/model_store/unconditional_short\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335e57f1-8175-48c0-b5cb-25e104ee9d5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "torch.set_default_device(device)\n",
    "load_pre = False\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "for ll in range(nRep):\n",
    "    \n",
    "    # 1. icfm\n",
    "    model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    if load_pre:\n",
    "        model_icfm.load_state_dict(torch.load(rootFolder + \"/icfm\" + str(ll) + \".pt\"))\n",
    "    optimizer_icfm = torch.optim.Adam(model_icfm.parameters(), lr = 2e-4)\n",
    "    model_icfm = icfm_fit(model_icfm, optimizer_icfm, sigma, n_epochs,\n",
    "                          train_loader, device, subset = subset, iMax = iMax)\n",
    "    torch.save(model_icfm.state_dict(), rootFolder + \"/icfm\" + str(ll) + \".pt\")\n",
    "    \n",
    "    \n",
    "    # 2. otcfm\n",
    "    model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    if load_pre:\n",
    "        model_otcfm.load_state_dict(torch.load(rootFolder + \"/otcfm\" + str(ll) + \".pt\"))\n",
    "    optimizer_otcfm = torch.optim.Adam(model_otcfm.parameters(), lr = 2e-4)\n",
    "    model_otcfm = otcfm_fit(model_otcfm, optimizer_otcfm, sigma, n_epochs,\n",
    "                          train_loader, device, subset = subset, iMax = iMax)\n",
    "    torch.save(model_otcfm.state_dict(), rootFolder + \"/otcfm\" + str(ll) + \".pt\")\n",
    "    \n",
    "    \n",
    "    # 3. gp-icfm\n",
    "    model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    if load_pre:\n",
    "        model_gp_icfm.load_state_dict(torch.load(rootFolder + \"/gp_icfm\" + str(ll) + \".pt\"))\n",
    "    optimizer_gp_icfm = torch.optim.Adam(model_gp_icfm.parameters(), lr = 2e-4)\n",
    "    model_gp_icfm = gp_icfm_fit(model_gp_icfm, optimizer_gp_icfm, alpha, l, sig2_diag, n_epochs, train_loader,\n",
    "                                device, subset = subset, iMax = iMax)\n",
    "    torch.save(model_gp_icfm.state_dict(), rootFolder + \"/gp_icfm\" + str(ll) + \".pt\")\n",
    "    \n",
    "    \n",
    "    # 4. gp-otcfm\n",
    "    model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    if load_pre:\n",
    "        model_gp_otcfm.load_state_dict(torch.load(rootFolder + \"/gp_otcfm\" + str(ll) + \".pt\"))\n",
    "    optimizer_gp_otcfm = torch.optim.Adam(model_gp_otcfm.parameters(), lr = 2e-4)\n",
    "    model_gp_otcfm = gp_otcfm_fit(model_gp_otcfm, optimizer_gp_otcfm, alpha, l, sig2_diag, n_epochs, train_loader,\n",
    "                                device, subset = subset, iMax = iMax)\n",
    "    torch.save(model_gp_otcfm.state_dict(), rootFolder + \"/gp_otcfm\" + str(ll) + \".pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4c72a27-3114-4e97-9075-5a4be95ceae6",
   "metadata": {},
   "source": [
    "# 4. Plotting (check)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f19a5e7-6c29-4b07-b971-8ffcd9a90d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plotFun(model, x0_gen):\n",
    "    node = NeuralODE(model, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(\n",
    "            x0_gen, t_span=torch.linspace(0, 1, 2, device=device),\n",
    "        )\n",
    "    grid = make_grid(\n",
    "        traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
    "    )\n",
    "\n",
    "    img = ToPILImage()(grid)\n",
    "    plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac417803-ca6d-408c-b3e2-9b9445733b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "x0_gen = torch.randn(100, 1, 28, 28, device=device)\n",
    "nCheck = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4670f885-6925-483d-90aa-cd8ce16bd9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_icfm.load_state_dict(torch.load(rootFolder + \"/icfm\" + str(nCheck) + \".pt\"))\n",
    "plotFun(model_icfm, x0_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e07ecab-8bea-4e55-a3d3-7c732b3c87b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_otcfm.load_state_dict(torch.load(rootFolder + \"/otcfm\" + str(nCheck) + \".pt\"))\n",
    "plotFun(model_otcfm, x0_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec95a7b7-c9ff-4ddb-a1d3-5635a77ea427",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_gp_icfm.load_state_dict(torch.load(rootFolder + \"/gp_icfm\" + str(nCheck) + \".pt\"))\n",
    "plotFun(model_gp_icfm, x0_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce98954-6999-44be-88d3-c67ba70319cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_gp_otcfm.load_state_dict(torch.load(rootFolder + \"/gp_otcfm\" + str(nCheck) + \".pt\"))\n",
    "plotFun(model_gp_otcfm, x0_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1276ed79-7bf5-4d21-bb5e-59861ce8c41c",
   "metadata": {},
   "source": [
    "# 5. Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "054032e9-4378-45d9-a806-053f99029e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import metric.pytorch_ssim\n",
    "from metric.IS_score import *\n",
    "from metric.Fid_score import *\n",
    "from torchmetrics.image.kid import KernelInceptionDistance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252f34c5-bbe5-4579-93a5-f43c92357001",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_eval = 32\n",
    "rep_eval = 40\n",
    "\n",
    "torch.set_default_device('cpu')\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=batch_eval, shuffle=True, drop_last=True,\n",
    "    generator=torch.Generator('cpu')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6718a345-abca-41ce-9f5b-b0ce9e3f334f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_1_img(node, x0_gen):\n",
    "    with torch.no_grad():\n",
    "        traj = node.trajectory(\n",
    "        x0_gen, t_span=torch.linspace(0, 1, 2, device=device),)\n",
    "    traj = traj[-1, :].clip(-1, 1)\n",
    "    return traj\n",
    "\n",
    "def gen_img(node, rep_eval, batch_eval):\n",
    "    for i in range(rep_eval):\n",
    "        torch.manual_seed(i)\n",
    "        x0_gen_id = torch.randn(batch_eval, 1, 28, 28, device=device)\n",
    "        sampled_x = gen_1_img(node, x0_gen_id)\n",
    "\n",
    "        recon_images = sampled_x.detach().cpu().numpy()\n",
    "        if i==0:\n",
    "            all_images=recon_images\n",
    "        else:\n",
    "            all_images = np.concatenate((all_images,recon_images),axis=0)\n",
    "\n",
    "    return all_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ab4f2ee-93b2-4b4a-96de-ce529220914c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "node_list_icfm = []\n",
    "node_list_otcfm = []\n",
    "node_list_gp_icfm = []\n",
    "node_list_gp_otcfm = []\n",
    "\n",
    "img_list_icfm = []\n",
    "img_list_otcfm = []\n",
    "img_list_gp_icfm = []\n",
    "img_list_gp_otcfm = []\n",
    "\n",
    "for ll in tqdm(range(nRep)):\n",
    "    \n",
    "    # 1. icfm\n",
    "    model_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    model_icfm.load_state_dict(torch.load(rootFolder + \"/icfm\" + str(ll) + \".pt\"))\n",
    "    node = NeuralODE(model_icfm, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    all_images = gen_img(node, rep_eval, batch_eval)\n",
    "    node_list_icfm.append(node)\n",
    "    img_list_icfm.append(all_images)\n",
    "    \n",
    "    # 2. otcfm\n",
    "    model_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    model_otcfm.load_state_dict(torch.load(rootFolder + \"/otcfm\" + str(ll) + \".pt\"))\n",
    "    node = NeuralODE(model_otcfm, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    all_images = gen_img(node, rep_eval, batch_eval)\n",
    "    node_list_otcfm.append(node)\n",
    "    img_list_otcfm.append(all_images)\n",
    "    \n",
    "    # 3. gp_icfm\n",
    "    model_gp_icfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    model_gp_icfm.load_state_dict(torch.load(rootFolder + \"/gp_icfm\" + str(ll) + \".pt\"))\n",
    "    node = NeuralODE(model_gp_icfm, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    all_images = gen_img(node, rep_eval, batch_eval)\n",
    "    node_list_gp_icfm.append(node)\n",
    "    img_list_gp_icfm.append(all_images)\n",
    "    \n",
    "    # 4. gp_otcfm\n",
    "    model_gp_otcfm = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "    model_gp_otcfm.load_state_dict(torch.load(rootFolder + \"/gp_otcfm\" + str(ll) + \".pt\"))\n",
    "    node = NeuralODE(model_gp_otcfm, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "    all_images = gen_img(node, rep_eval, batch_eval)\n",
    "    node_list_gp_otcfm.append(node)\n",
    "    img_list_gp_otcfm.append(all_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "975e0aea-6ffb-4381-8882-3339aa69d0ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"img_list_icfm\", \"wb\") as fp: pickle.dump(img_list_icfm, fp);\n",
    "with open(\"img_list_otcfm\", \"wb\") as fp: pickle.dump(img_list_otcfm, fp);    \n",
    "with open(\"img_list_gp_icfm\", \"wb\") as fp: pickle.dump(img_list_gp_icfm, fp);    \n",
    "with open(\"img_list_gp_otcfm\", \"wb\") as fp: pickle.dump(img_list_gp_otcfm, fp);\n",
    "\n",
    "# with open(\"img_list_icfm\", \"rb\") as fp: img_list_icfm = pickle.load(fp);\n",
    "# with open(\"img_list_otcfm\", \"rb\") as fp: img_list_otcfm = pickle.load(fp);\n",
    "# with open(\"img_list_gp_icfm\", \"rb\") as fp: img_list_gp_icfm = pickle.load(fp);\n",
    "# with open(\"img_list_gp_otcfm\", \"rb\") as fp: img_list_gp_otcfm = pickle.load(fp);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4be8a6d3-e545-4b3c-82c0-434c2e9477ad",
   "metadata": {},
   "source": [
    "## 5.1 KID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec5b4f26-e72a-4b95-b7ba-b9be28684add",
   "metadata": {},
   "outputs": [],
   "source": [
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a8365f-4c18-498f-8dbf-382ceb4c92e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, (images, labels) in enumerate(test_loader):\n",
    "    \n",
    "    if i==0:\n",
    "        real_images_all = np.repeat(images,3,axis=1)\n",
    "    else:\n",
    "        real_image = np.repeat(images,3,axis=1)\n",
    "        real_images_all = np.concatenate((real_images_all,real_image),axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ea1db67-1e4c-45b9-bcb7-e3f1bc65a6d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_real_n = real_images_all.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef03e05-1914-4f3d-8caa-413ea76bf6b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def kid_calc(real_images, all_images, subset_size = rep_eval):\n",
    "    \n",
    "    kid = KernelInceptionDistance(subset_size = subset_size)\n",
    "    \n",
    "    real_trans = 255*(real_images + 1)/2\n",
    "    all_trans = 255*(np.repeat(all_images,3,axis=1) + 1)/2\n",
    "    \n",
    "    A = torch.from_numpy(real_trans).type(torch.uint8)\n",
    "    B = torch.from_numpy(all_trans).type(torch.uint8)\n",
    "    \n",
    "    kid.update(A, real=True)\n",
    "    kid.update(B, real=False)\n",
    "    kid_mean, kid_std = kid.compute()\n",
    "    \n",
    "    return kid_mean.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae101d54-f327-410b-afa1-e2383de0c7a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "kid_icfm = np.zeros(nRep)\n",
    "kid_otcfm = np.zeros(nRep)\n",
    "kid_gp_icfm = np.zeros(nRep)\n",
    "kid_gp_otcfm = np.zeros(nRep)\n",
    "\n",
    "for ll in tqdm(range(nRep)):\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    real_img_id = np.random.choice(all_real_n, batch_eval*rep_eval, replace = False)\n",
    "    real_images = real_images_all[real_img_id,:,:,:]\n",
    "    \n",
    "    kid_icfm[ll] = kid_calc(real_images, img_list_icfm[ll], subset_size = rep_eval)\n",
    "    kid_otcfm[ll] = kid_calc(real_images, img_list_otcfm[ll], subset_size = rep_eval)\n",
    "    kid_gp_icfm[ll] = kid_calc(real_images, img_list_gp_icfm[ll], subset_size = rep_eval)\n",
    "    kid_gp_otcfm[ll] = kid_calc(real_images, img_list_gp_otcfm[ll], subset_size = rep_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b34408e-9e6d-492b-8c92-3f47a52d0887",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('icfm: {:.3f} +- {:.3f}'.format(np.mean(kid_icfm), np.std(kid_icfm)))\n",
    "print('otcfm: {:.3f} +- {:.3f}'.format(np.mean(kid_otcfm), np.std(kid_otcfm)))\n",
    "print('gp_icfm: {:.3f} +- {:.3f}'.format(np.mean(kid_gp_icfm), np.std(kid_gp_icfm)))\n",
    "print('gp_otcfm: {:.3f} +- {:.3f}'.format(np.mean(kid_gp_otcfm), np.std(kid_gp_otcfm)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41513adf-c3a7-4eea-b7c4-0ae258132a90",
   "metadata": {},
   "source": [
    "## 5.2 FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "580dba27-4385-433d-a084-ea6204fa0126",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fid_calc(all_images, test_loader, rep_eval = rep_eval):\n",
    "    \"\"\"Compute the FID between generated images and real test batches.\n",
    "\n",
    "    Both image sets receive the same preprocessing before calculate_fid is\n",
    "    called: rescale with (x + 1) / 2, bilinear upsampling to 299x299,\n",
    "    expansion to three channels, and conversion to channel-last numpy arrays.\n",
    "\n",
    "    Args:\n",
    "        all_images: generated images; assumed to be (N, 1, H, W) with values\n",
    "            in [-1, 1] -- TODO confirm against the sampling cells.\n",
    "        test_loader: iterable yielding (images, labels) batches of real images,\n",
    "            assumed to share the layout and value range of all_images.\n",
    "        rep_eval: number of leading batches of test_loader used as the real set.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by calculate_fid for the two prepared image sets.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no real batch was collected (empty loader or rep_eval == 0).\n",
    "    \"\"\"\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    up = nn.Upsample(size=(299, 299), mode='bilinear').type(torch.cuda.FloatTensor)\n",
    "\n",
    "    # Generated images: rescale once, upsample on the GPU, go channel-last, then make 3 channels.\n",
    "    all_images = up(torch.Tensor((all_images + 1)/2).cuda(0)).cpu().numpy()\n",
    "    all_images = np.transpose(all_images, (0, 2, 3, 1))\n",
    "    all_images = np.repeat(all_images, 3, axis=3)\n",
    "\n",
    "    real_batches = []\n",
    "    for i, (images, _labels) in enumerate(test_loader):\n",
    "        if i == rep_eval:\n",
    "            break\n",
    "\n",
    "        # Rescale exactly once, like the generated images above. Applying\n",
    "        # (x + 1) / 2 a second time would put the real images on a different\n",
    "        # value range than the generated ones and bias the FID.\n",
    "        images = (images + 1)/2\n",
    "\n",
    "        real_image = np.repeat(images, 3, axis=1)\n",
    "        real_image = up(real_image.cuda(0)).cpu().numpy()\n",
    "        real_batches.append(np.transpose(real_image, (0, 2, 3, 1)))\n",
    "\n",
    "    if not real_batches:\n",
    "        raise ValueError('fid_calc: no real images collected from test_loader')\n",
    "\n",
    "    # Concatenate once instead of growing the array batch by batch.\n",
    "    real_images = np.concatenate(real_batches, axis=0)\n",
    "\n",
    "    return calculate_fid(all_images, real_images, use_multiprocessing=False, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d21c548-05ea-4ee8-9a8c-76edd88db8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached memory before the FID evaluation starts.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# One FID score per repetition for each of the four models.\n",
    "fid_icfm, fid_otcfm, fid_gp_icfm, fid_gp_otcfm = (np.zeros(nRep) for _ in range(4))\n",
    "\n",
    "# Pair every score array with the generated images it is computed from.\n",
    "_fid_jobs = (\n",
    "    (fid_icfm, img_list_icfm),\n",
    "    (fid_otcfm, img_list_otcfm),\n",
    "    (fid_gp_icfm, img_list_gp_icfm),\n",
    "    (fid_gp_otcfm, img_list_gp_otcfm),\n",
    ")\n",
    "\n",
    "for ll in tqdm(range(nRep)):\n",
    "    for _fid_scores, _fid_images in _fid_jobs:\n",
    "        _fid_scores[ll] = fid_calc(_fid_images[ll], test_loader, rep_eval = rep_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3f733c4-a479-49fa-8b66-73f9d233d03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID per model over the nRep repetitions: mean +- np.std.\n",
    "for _template, _scores in (\n",
    "    ('icfm: {:.3f} +- {:.3f}', fid_icfm),\n",
    "    ('otcfm: {:.3f} +- {:.3f}', fid_otcfm),\n",
    "    ('gp_icfm: {:.3f} +- {:.3f}', fid_gp_icfm),\n",
    "    ('gp_otcfm: {:.3f} +- {:.3f}', fid_gp_otcfm),\n",
    "):\n",
    "    print(_template.format(np.mean(_scores), np.std(_scores)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9c831ab-7548-48c3-b341-28407aec3497",
   "metadata": {},
   "source": [
    "# 5. Summarize"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26d5ba0a-dc4a-4ed8-8734-afde77032db6",
   "metadata": {},
   "source": [
    "Fit the model 100 ($5\\times 20$) times, and summarize the results via histograms, mean, and standard deviation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdde04d0-799c-4e7b-8b4b-044a349faf73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the pickled KID/FID results.\n",
    "# Set the MNIST_EVAL_DIR environment variable to override the default cluster path.\n",
    "readFolder = os.environ.get(\"MNIST_EVAL_DIR\", \"/hpc/group/mastatlab/gw74/mnist_eval\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3debf21e-d2c1-4bd3-972c-4032df9a0889",
   "metadata": {},
   "outputs": [],
   "source": [
    "nFiles = 20  # number of saved result files per metric/model (indices 0..nFiles-1)\n",
    "\n",
    "def _load_results(prefix):\n",
    "    \"\"\"Return the unpickled contents of readFolder/<prefix>_<ll> for ll in range(nFiles), in order.\"\"\"\n",
    "    results = []\n",
    "    for ll in range(nFiles):\n",
    "        # NOTE(review): pickle.load can execute arbitrary code -- only read files you produced yourself.\n",
    "        with open(readFolder + \"/\" + prefix + \"_\" + str(ll), \"rb\") as fp:\n",
    "            results.append(pickle.load(fp))\n",
    "    return results\n",
    "\n",
    "# Loading by prefix keeps every variable tied to the files of the same name\n",
    "# (fid_gp_icfm <- fid_gp_icfm_*, fid_gp_otcfm <- fid_gp_otcfm_*).\n",
    "kid_icfm = _load_results(\"kid_icfm\")\n",
    "kid_otcfm = _load_results(\"kid_otcfm\")\n",
    "kid_gp_icfm = _load_results(\"kid_gp_icfm\")\n",
    "kid_gp_otcfm = _load_results(\"kid_gp_otcfm\")\n",
    "\n",
    "fid_icfm = _load_results(\"fid_icfm\")\n",
    "fid_otcfm = _load_results(\"fid_otcfm\")\n",
    "fid_gp_icfm = _load_results(\"fid_gp_icfm\")\n",
    "fid_gp_otcfm = _load_results(\"fid_gp_otcfm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c37556-60ea-4022-a98c-33e00ff0f7aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the per-file results of every model into one flat array per metric.\n",
    "kid_icfm_all, kid_otcfm_all, kid_gp_icfm_all, kid_gp_otcfm_all = (\n",
    "    np.asarray(_vals).ravel()\n",
    "    for _vals in (kid_icfm, kid_otcfm, kid_gp_icfm, kid_gp_otcfm)\n",
    ")\n",
    "\n",
    "fid_icfm_all, fid_otcfm_all, fid_gp_icfm_all, fid_gp_otcfm_all = (\n",
    "    np.asarray(_vals).ravel()\n",
    "    for _vals in (fid_icfm, fid_otcfm, fid_gp_icfm, fid_gp_otcfm)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a30711-190a-4ff7-8759-6196e496cc44",
   "metadata": {},
   "source": [
    "## 5.1 mean (standard deviation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db252c03-0f85-494f-a7ca-c6bb420837b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KID over all pooled runs: mean +- np.std.\n",
    "for _template, _values in (\n",
    "    ('icfm: {:.4f} +- {:.4f}', kid_icfm_all),\n",
    "    ('otcfm: {:.4f} +- {:.4f}', kid_otcfm_all),\n",
    "    ('gp_icfm: {:.4f} +- {:.4f}', kid_gp_icfm_all),\n",
    "    ('gp_otcfm: {:.4f} +- {:.4f}', kid_gp_otcfm_all),\n",
    "):\n",
    "    print(_template.format(np.mean(_values), np.std(_values)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0426a1f-ab86-45d5-981b-b3f5180778d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID over all pooled runs: mean +- np.std.\n",
    "for _template, _values in (\n",
    "    ('icfm: {:.4f} +- {:.4f}', fid_icfm_all),\n",
    "    ('otcfm: {:.4f} +- {:.4f}', fid_otcfm_all),\n",
    "    ('gp_icfm: {:.4f} +- {:.4f}', fid_gp_icfm_all),\n",
    "    ('gp_otcfm: {:.4f} +- {:.4f}', fid_gp_otcfm_all),\n",
    "):\n",
    "    print(_template.format(np.mean(_values), np.std(_values)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a52dcbf3-7229-4e4b-b656-1f882d8d6b28",
   "metadata": {},
   "source": [
    "## 5.2 histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f8e9fda-9adc-4e46-acf2-0a7bed59b6ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder for the figures.\n",
    "# Set the MNIST_PLOT_DIR environment variable to override the default cluster path.\n",
    "plot_dir = os.environ.get(\"MNIST_PLOT_DIR\", \"/hpc/home/gw74/diff_model/FM/submission/plots/4_MNIST\")\n",
    "os.makedirs(plot_dir, exist_ok=True)  # savefig does not create missing folders\n",
    "\n",
    "# Keep SVG text as text instead of paths, render without LaTeX, and use a base font size of 14.\n",
    "plt.rcParams['svg.fonttype'] = 'none'\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams.update({'font.size': 14})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a33c2c08-d017-4956-b719-6bb09f563cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [4, 3]\n",
    "\n",
    "# Overlay the pooled KID values of the four models, 50 bins each.\n",
    "for _kid_values in (kid_icfm_all, kid_otcfm_all, kid_gp_icfm_all, kid_gp_otcfm_all):\n",
    "    plt.hist(_kid_values, alpha = 0.5, bins = 50)\n",
    "\n",
    "# Legend entries follow the plotting order above.\n",
    "plt.legend(['i-cfm', 'ot-cfm', 'gp-icfm', 'gp-otcfm'])\n",
    "plt.title('KID, 100 seeds')\n",
    "plt.savefig(plot_dir + \"/1_kid.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeff9178-5ac2-4cb1-a996-5ad15aeb6fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the pooled FID values of the four models, 50 bins each.\n",
    "for _fid_values in (fid_icfm_all, fid_otcfm_all, fid_gp_icfm_all, fid_gp_otcfm_all):\n",
    "    plt.hist(_fid_values, alpha = 0.5, bins = 50)\n",
    "\n",
    "# Legend entries follow the plotting order above.\n",
    "plt.legend(['i-cfm', 'ot-cfm', 'gp-icfm', 'gp-otcfm'])\n",
    "plt.title('FID, 100 seeds')\n",
    "plt.savefig(plot_dir + \"/2_fid.svg\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
