{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087660fb-44bb-457b-970b-7b2b16540d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.transforms import ToPILImage\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.unet import UNetModel\n",
    "from myUnetWrapper import *\n",
    "import torchdiffeq\n",
    "\n",
    "import numpy as np\n",
    "from numpy import *\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import gc\n",
    "\n",
    "root_dir = \"/hpc/group/mastatlab/gw74/HWD/\"\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "torch.set_default_device(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9146159d-3383-49bb-b9ed-f6a20969ac8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import metric.pytorch_ssim\n",
    "from metric.IS_score import *\n",
    "from metric.Fid_score import *\n",
    "from torchmetrics.image.kid import KernelInceptionDistance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f0d06e2-1091-4ded-a611-e61eba62a85d",
   "metadata": {},
   "source": [
    "# 0. Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d5215b-c43d-4c9b-91ef-0319e393500d",
   "metadata": {},
   "outputs": [],
   "source": [
    "Images = np.load(root_dir + \"Images(28x28).npy\")\n",
    "WriterInfo = np.load(root_dir + \"WriterInfo.npy\")\n",
    "# first 2 number: digit & ID\n",
    "\n",
    "# normalize the images and send to torch and cuda\n",
    "Images = torch.tensor(Images/255).float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2315c5ff-74c6-4362-a959-e1d214b9793a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num00 = 0\n",
    "num05 = 8\n",
    "num10 = 6\n",
    "\n",
    "selId_00 = WriterInfo[:,0] == num00\n",
    "selId_05 = WriterInfo[:,0] == num05\n",
    "selId_10 = WriterInfo[:,0] == num10\n",
    "\n",
    "image_00 = Images[selId_00,:,:]\n",
    "image_05 = Images[selId_05,:,:]\n",
    "image_10 = Images[selId_10,:,:]\n",
    "id_00 = WriterInfo[selId_00,1]\n",
    "id_05 = WriterInfo[selId_05,1]\n",
    "id_10 = WriterInfo[selId_10,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8398403f-85dd-4133-8509-75c46b228f37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train-test split\n",
    "n_total = image_00.shape[0]\n",
    "train_prop = 0.8\n",
    "\n",
    "train_id = np.random.choice(n_total, int(n_total*train_prop), replace=False)\n",
    "test_id = np.setdiff1d(np.arange(n_total), train_id)\n",
    "\n",
    "image_00_train = image_00[train_id,:,:]\n",
    "image_05_train = image_05[train_id,:,:]\n",
    "image_10_train = image_10[train_id,:,:]\n",
    "id_00_train = id_00[train_id]\n",
    "id_05_train = id_05[train_id]\n",
    "id_10_train = id_10[train_id]\n",
    "\n",
    "image_00_test = image_00[test_id,:,:]\n",
    "image_05_test = image_05[test_id,:,:]\n",
    "image_10_test = image_10[test_id,:,:]\n",
    "id_00_test = id_00[test_id]\n",
    "id_05_test = id_05[test_id]\n",
    "id_10_test = id_10[test_id]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23001c5-9376-47d7-87d0-8bc07c30bdbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_unique = unique(id_00_train)\n",
    "n_per = 10\n",
    "n_sec = int(np.ceil(id_unique.shape[0]/n_per))\n",
    "\n",
    "id_grp = []\n",
    "for ll in range(n_sec):\n",
    "    if ll < n_sec-1:\n",
    "        idx_tmp = np.arange(n_per) + ll*n_per\n",
    "        id_grp.append(id_unique[idx_tmp])\n",
    "    else:\n",
    "        idx_tmp = np.arange(n_per*ll, id_unique.shape[0])\n",
    "        id_grp.append(id_unique[idx_tmp])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "793d9fd4-7e7d-4d75-b478-984873230631",
   "metadata": {},
   "source": [
    "# 1. Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20aa49a6-3788-4179-9648-6891bd96d41e",
   "metadata": {},
   "source": [
    "## 1.1 ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4349fcc-6c7d-4667-b3b2-6f23d8dc7b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def icfm_fit(model, optimizer, d0, d1, id0, id1, n_epochs, id_grp, sigma = 0.0,\n",
    "            cond = True):\n",
    "    \n",
    "    FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "    for epoch in tqdm(range(n_epochs)):\n",
    "        for ll in range(len(id_grp)):\n",
    "            optimizer.zero_grad()\n",
    "            d1_idx_tmp = np.concatenate([np.random.permutation(np.where((id1 == idx))[0]) \n",
    "                            for idx in id_grp[ll]], axis = 0)\n",
    "            d1_tmp = d1[d1_idx_tmp,:,:]\n",
    "            x1 = d1_tmp.reshape(-1,1,28,28)\n",
    "            \n",
    "            if d0 is None:\n",
    "                x0 = torch.randn_like(x1)\n",
    "            else:\n",
    "                d0_idx_tmp = np.concatenate([np.random.permutation(np.where((id0 == idx))[0]) \n",
    "                            for idx in id_grp[ll]], axis = 0)\n",
    "                d0_tmp = d0[d0_idx_tmp,:,:]\n",
    "                x0 = d0_tmp.reshape(-1,1,28,28)\n",
    "            \n",
    "            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "            if cond:\n",
    "                vt = model(t, torch.cat((x0, xt), 1))\n",
    "            else:\n",
    "                vt = model(t, xt)\n",
    "            \n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fc2cc21-68d3-42c7-a3c3-ea44ab423335",
   "metadata": {},
   "source": [
    "## 1.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d926581f-181c-4fc4-9701-77484bb4aec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    r = ti[...,None] - tj[...,None,:]\n",
    "    r[r == 0] = 1e-15\n",
    "    return r\n",
    "def k11(r, alpha, l):\n",
    "    return (alpha**2)*torch.exp(-0.5 * ((r/l)**2))\n",
    "def k12(r, alpha, l):\n",
    "    return (alpha**2/l**2)*r*torch.exp(-0.5*((r/l)**2))\n",
    "def k22(r, alpha, l):\n",
    "    return (alpha**2/l**4)*(l**2 - r**2)*torch.exp(-0.5*((r/l)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30483496-4374-40b3-86fe-02da41adec63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \n",
    "    r = calc_r(ti, tj)\n",
    "    nB = r.shape[0]\n",
    "    nt = r.shape[1]\n",
    "    \n",
    "    Sig11 = k11(r, alpha, l) + (torch.eye(nt)*sig2_diag).repeat(nB,1,1)\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig21 = Sig12.permute(0, 2, 1)\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "    \n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=2)\n",
    "    block_row2 = torch.cat([Sig21, Sig22], dim=2)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim = 1)\n",
    "    Sig = (Sig + Sig.permute(0, 2, 1))/2\n",
    "    \n",
    "    return Sig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2decc29-e572-4d79-ba51-8a469c24f9a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    nB, nt, dim = x_obs.shape[0], t_mat.shape[1], x_obs.shape[2]\n",
    "    nt_obs = t_obs.shape[0]\n",
    "\n",
    "    # Compute necessary covariance matrices and kernel functions\n",
    "    r_obs_x = calc_r(t_obs, t_mat)\n",
    "    r_obs_obs = calc_r(t_obs, t_obs)\n",
    "    Sig_11 = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)\n",
    "    \n",
    "    # Precompute parts of the covariance matrices\n",
    "    k_obs_x, k_obs_dx = k11(r_obs_x, alpha, l), k12(r_obs_x, alpha, l)\n",
    "    Sig_21 = torch.cat([k_obs_x, k_obs_dx], dim=2)\n",
    "    Sig_12 = Sig_21.permute(0, 2, 1)\n",
    "\n",
    "    Sig_22_sing = k11(r_obs_obs, alpha, l) + torch.eye(nt_obs) * sig2_diag\n",
    "    Sig_22_inv_sing = torch.linalg.inv(Sig_22_sing)\n",
    "    Sig_22_inv = Sig_22_inv_sing.repeat(nB, 1, 1)\n",
    "\n",
    "    # Compute conditional covariance matrix with stability adjustment\n",
    "    Sig_cond = Sig_11 - torch.bmm(torch.bmm(Sig_12, Sig_22_inv), Sig_21)\n",
    "    # Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1)) / 2 + 1e-6 * torch.eye(Sig_cond.shape[1], device=Sig_cond.device)\n",
    "\n",
    "    Sig_cond = (Sig_cond + Sig_cond.permute(0, 2, 1))/2\n",
    "    \n",
    "    svd_add_idx = torch.sum((torch.linalg.eigvals(Sig_cond).real>=0).T, axis = 0) != Sig_cond.shape[1]\n",
    "    U, S, Vh = torch.linalg.svd(Sig_cond[svd_add_idx,:,:])\n",
    "#     U, S, Vh = torch.linalg.svd(Sig_cond)\n",
    "    Sig_cond_add = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(S + 1e-8)), Vh)\n",
    "    Sig_cond[svd_add_idx,:,:] = (Sig_cond_add + Sig_cond_add.permute(0, 2, 1))/2\n",
    "\n",
    "    # Mean adjustment matrix\n",
    "    mu_A = torch.bmm(Sig_12, Sig_22_inv)\n",
    "    x_obs_batch = x_obs.reshape(nB, nt_obs, dim)\n",
    "    mu_new = torch.bmm(mu_A, x_obs_batch).reshape(nB, 2 * nt, dim)\n",
    "\n",
    "    # Initialize sample matrices\n",
    "    x_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    dx_samps = torch.zeros((nB, nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "    \n",
    "    mu_flat = mu_new.permute(0, 2, 1).reshape(nB * dim, 2 * nt)\n",
    "    Sig_cond_flat = Sig_cond.repeat_interleave(dim, dim=0)\n",
    "    \n",
    "    # Sampling in batch for all dimensions at once\n",
    "    try:\n",
    "        # Reshape mu_new and Sig_cond for compatible shapes\n",
    "#         mu_flat = mu_new.view(nB * dim, 2 * nt)\n",
    "#         Sig_cond_flat = Sig_cond.repeat(dim, 1, 1)\n",
    "        \n",
    "        dist = MultivariateNormal(loc=mu_flat, covariance_matrix=Sig_cond_flat)\n",
    "        x_dx_samps_flat = dist.rsample().reshape(nB, dim, 2 * nt).permute(0, 2, 1)\n",
    "    except RuntimeError:\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        x_dx_samps_flat = torch.zeros((nB, 2 * nt, dim), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for bb in range(nB):\n",
    "            for dd in range(dim):\n",
    "                mu_single = mu_new[bb, :, dd].cpu().numpy()\n",
    "                cov_single = Sig_cond[bb].cpu().numpy()\n",
    "                sample = np.random.multivariate_normal(mu_single, cov_single)\n",
    "                x_dx_samps_flat[bb, :, dd] = torch.from_numpy(sample)\n",
    "\n",
    "    # Separate x and dx samples\n",
    "    x_samps[:, :, :] = x_dx_samps_flat[:, :nt, :]\n",
    "    dx_samps[:, :, :] = x_dx_samps_flat[:, nt:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b5a5778-3576-4f3f-b65e-9374dc88d394",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gp_icfm_fit(model, optimizer, d00, d05, d10,\n",
    "                id00, id05, id10,\n",
    "                n_epochs, id_grp, alpha, l, sig2_diag = 0, cond = True):\n",
    "    for epoch in tqdm(range(n_epochs)):\n",
    "        for ll in range(len(id_grp)):\n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            d10_idx_tmp = np.concatenate([np.random.permutation(np.where((id10 == idx))[0]) \n",
    "                            for idx in id_grp[ll]], axis = 0)\n",
    "            d10_tmp = d10[d10_idx_tmp,:,:]\n",
    "            x10 = d10_tmp.reshape(-1,1,28,28)\n",
    "            \n",
    "            d05_idx_tmp = np.concatenate([np.random.permutation(np.where((id05 == idx))[0]) \n",
    "                            for idx in id_grp[ll]], axis = 0)\n",
    "            d05_tmp = d05[d05_idx_tmp,:,:]\n",
    "            x05 = d05_tmp.reshape(-1,1,28,28)\n",
    "            \n",
    "            if d00 is None:\n",
    "                x00 = torch.randn_like(x10)\n",
    "            else:\n",
    "                d00_idx_tmp = np.concatenate([np.random.permutation(np.where((id00 == idx))[0]) \n",
    "                            for idx in id_grp[ll]], axis = 0)\n",
    "                d00_tmp = d00[d00_idx_tmp,:,:]\n",
    "                x00 = d00_tmp.reshape(-1,1,28,28)\n",
    "            \n",
    "            n_samp = x10.shape[0]\n",
    "            \n",
    "            xall_trans = torch.zeros(n_samp, 3, 28*28)\n",
    "            xall_trans[:,0,:] = torch.reshape(x00, (n_samp, -1))\n",
    "            xall_trans[:,1,:] = torch.reshape(x05, (n_samp, -1))\n",
    "            xall_trans[:,2,:] = torch.reshape(x10, (n_samp, -1))\n",
    "            \n",
    "            t_mat = torch.rand((n_samp,1))\n",
    "            try:\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, xall_trans,\n",
    "                                                torch.tensor([0, 0.5, 1]), sig2_diag)\n",
    "            except:\n",
    "                print('sample fail')\n",
    "                pass\n",
    "            \n",
    "            t = torch.reshape(t_mat, (-1, ))\n",
    "            xt = torch.reshape(xt_batch, (n_samp, 1, 28, 28))\n",
    "            ut = torch.reshape(ut_batch, (n_samp, 1, 28, 28))\n",
    "            \n",
    "            if cond:\n",
    "                vt = model(t, torch.cat((x00, xt), 1))\n",
    "            else:\n",
    "                vt = model(t, xt)\n",
    "\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf64206c-7225-4294-b948-10cc5b537a3c",
   "metadata": {},
   "source": [
    "## 1.3 fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "add3c7e7-6ca0-4e88-99e4-1c8cadc9f278",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_icfm_grad(model, d0, d1, id0, id1, id_grp,\n",
    "                  sigma = 0.0,\n",
    "                  lr_grad = [1e-3, 8e-4, 5e-4, 2e-4],\n",
    "                  n_epoch_grad = [100, 100, 100, 100],\n",
    "                  cond = True):\n",
    "    n_grad = len(lr_grad)\n",
    "    for ll in range(n_grad):\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr = lr_grad[ll])\n",
    "        model = icfm_fit(model, optimizer, d0, d1,\n",
    "                         id0, id1, n_epoch_grad[ll],\n",
    "                         id_grp, sigma = sigma, cond = cond)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35260547-7ff9-441d-bf0c-ab725dbace8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_gp_icfm_grad(model, d00, d05, d10,\n",
    "                    id00, id05, id10, id_grp, alpha, l, sig2_diag = 0,\n",
    "                    lr_grad = [1e-3, 8e-4, 5e-4, 2e-4],\n",
    "                    n_epoch_grad = [100, 100, 100, 100],\n",
    "                    cond = True):\n",
    "    n_grad = len(lr_grad)\n",
    "    for ll in range(n_grad):\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr = lr_grad[ll])\n",
    "        \n",
    "        model = gp_icfm_fit(model, optimizer, d00, d05, d10,\n",
    "                            id00, id05, id10,\n",
    "                            n_epoch_grad[ll], id_grp, alpha, l,\n",
    "                            sig2_diag = sig2_diag, cond = cond)\n",
    "        \n",
    "    return model  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3270439-0768-408d-afe9-7809f96f544b",
   "metadata": {},
   "source": [
    "## 1.4 Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63b705ba-6016-4449-84b4-8680cf71249b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_samp(x0, model, device, cond = True):\n",
    "    \n",
    "    if cond:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, torch.cat((x0, x), 1)),\n",
    "            x0,\n",
    "            torch.linspace(0, 1, 2, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x),\n",
    "            x0,\n",
    "            torch.linspace(0, 1, 2, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    \n",
    "    samp_out = traj[-1,:,:,:,:].clip(0,1)\n",
    "    \n",
    "    return samp_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee290d63-7a48-437e-8b39-4d6fb5a76651",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid(x0_gen, model, device, n_sec = 5, cond = True):\n",
    "    \n",
    "    if cond:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, torch.cat((x0_gen, x), 1)),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, n_sec, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, n_sec, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    \n",
    "    grid = make_grid(\n",
    "        traj.view([-1, 1, 28, 28]).clip(0, 1), value_range=(0, 1), padding=0, nrow=10\n",
    "    )\n",
    "    img = ToPILImage()(grid)\n",
    "    plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da90a5a9-bf0c-4a7a-b77b-5d2e5abaa78e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_comb(x0_gen, model0, model1, device, cond = True):\n",
    "    \n",
    "    if cond:\n",
    "        traj0 = torchdiffeq.odeint(\n",
    "            lambda t, x: model0.forward(t, torch.cat((x0_gen, x), 1)),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, 5, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "        \n",
    "        traj1 = torchdiffeq.odeint(\n",
    "            lambda t, x: model1.forward(t, torch.cat((traj0[-1,:].view([-1, 1, 28, 28]).clip(0, 1), x), 1)),\n",
    "            traj0[-1,:].view([-1, 1, 28, 28]).clip(0, 1),\n",
    "            torch.linspace(0, 1, 5, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "        \n",
    "    else:\n",
    "        traj0 = torchdiffeq.odeint(\n",
    "            lambda t, x: model0.forward(t, x),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, 5, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "        \n",
    "        traj1 = torchdiffeq.odeint(\n",
    "            lambda t, x: model1.forward(t, x),\n",
    "            traj0[-1,:].view([-1, 1, 28, 28]).clip(0, 1),\n",
    "            torch.linspace(0, 1, 5, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    \n",
    "    \n",
    "    traj0_trans = traj0.view([-1, 1, 28, 28]).clip(0, 1)\n",
    "    traj1_trans = traj1[1:,:,:,:,:].view([-1, 1, 28, 28]).clip(0, 1)    \n",
    "    traj_cat = torch.cat((traj0_trans, traj1_trans), 0) \n",
    "\n",
    "    grid = make_grid(\n",
    "        traj_cat, value_range=(0, 1), padding=0, nrow=10\n",
    "    )\n",
    "\n",
    "    img = ToPILImage()(grid)\n",
    "    plt.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d95d6385-01c1-4d0e-81b4-91f5fdb86958",
   "metadata": {},
   "source": [
    "## 1.5 evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5640a59-f60b-437e-a376-1da08d7640b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_1_traj(x0_gen, model, device, n_sec = 5, cond = True):\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    if cond:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, torch.cat((x0_gen, x), 1)),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, n_sec, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x),\n",
    "            x0_gen,\n",
    "            torch.linspace(0, 1, n_sec, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    \n",
    "    traj = traj.clip(0, 1)\n",
    "    \n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c60fd4-0f1b-488d-81f1-cf21667825c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fid_calc(all_images, test_data):\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    up = nn.Upsample(size=(299, 299), mode='bilinear').type(torch.cuda.FloatTensor)\n",
    "    \n",
    "    \n",
    "    # all_images = up(torch.Tensor(all_images).cuda(0)).cpu().numpy()\n",
    "    all_images = up(torch.Tensor(all_images)).cpu().numpy()\n",
    "    all_images = np.transpose(all_images,(0,2,3,1))\n",
    "    all_images = np.repeat(all_images,3,axis=3)\n",
    "    \n",
    "    real_image = np.repeat(test_data,3,axis=1)\n",
    "    # real_image=up(real_image.cuda(0)).cpu().numpy()\n",
    "    real_image=up(real_image).cpu().numpy()\n",
    "    real_images=np.transpose(real_image,(0,2,3,1))\n",
    "    \n",
    "    Fid = calculate_fid(all_images, real_images, use_multiprocessing=False, batch_size=4)\n",
    "    \n",
    "    return Fid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beaf457f-28f8-47fe-8727-609575d89643",
   "metadata": {},
   "source": [
    "# 2. Fitting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c39a8bb-1531-4a1c-a900-259fb95333f3",
   "metadata": {},
   "source": [
    "## 2.0 noise to $x_0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dab4c8e-a7ab-4512-9712-4be530824de0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model0 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model0 = fit_icfm_grad(model0, None, image_00_train, id_00_train, id_05_train, id_grp,\n",
    "              sigma = 0.0,\n",
    "              lr_grad = [1e-3, 8e-4, 5e-4, 2e-4, 8e-5, 2e-5, 8e-6, 2e-6],\n",
    "              n_epoch_grad = [200, 200, 200, 200, 100, 100, 100, 100],\n",
    "              cond = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "154cf145-7c07-44ee-8855-4cbe41331dbd",
   "metadata": {},
   "source": [
    "## 2.1 ICFM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f668de3d-96ec-4651-b24c-79e3d0561f67",
   "metadata": {},
   "source": [
    "$x_0$ to $x_{0.5}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b64d7d-6df4-437f-9f01-b2097b32c207",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# unconditional\n",
    "model_icfm_uncond0 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_icfm_uncond0 = fit_icfm_grad(model_icfm_uncond0, image_00_train, image_05_train,\n",
    "                                   id_00_train, id_05_train, id_grp,\n",
    "                                   sigma = 0.0,\n",
    "                                   lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                   n_epoch_grad = [100, 100, 100, 100],\n",
    "                                   cond = False)\n",
    "\n",
    "# conditional\n",
    "model_icfm_cond0 = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,\n",
    "                                     out_channels = 1,\n",
    "                                     num_channels=32, num_res_blocks=1).to(device)\n",
    "model_icfm_cond0 = fit_icfm_grad(model_icfm_cond0, image_00_train, image_05_train,\n",
    "                                 id_00_train, id_05_train, id_grp,\n",
    "                                 sigma = 0.0,\n",
    "                                 lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                 n_epoch_grad = [100, 100, 100, 100],\n",
    "                                 cond = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ff37013-274f-4a92-8447-0c90e21054d3",
   "metadata": {},
   "source": [
    "$x_{0.5}$ to $x_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9422985-35da-47d9-8c8e-5fb20b1fbe17",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# unconditional\n",
    "model_icfm_uncond1 = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_icfm_uncond1 = fit_icfm_grad(model_icfm_uncond1, image_05_train, image_10_train,\n",
    "                                   id_05_train, id_10_train, id_grp,\n",
    "                                   sigma = 0.0,\n",
    "                                   lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                   n_epoch_grad = [100, 100, 100, 100],\n",
    "                                   cond = False)\n",
    "\n",
    "# conditional\n",
    "model_icfm_cond1 = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,\n",
    "                                     out_channels = 1,\n",
    "                                     num_channels=32, num_res_blocks=1).to(device)\n",
    "model_icfm_cond1 = fit_icfm_grad(model_icfm_cond1, image_05_train, image_10_train,\n",
    "                                 id_05_train, id_10_train, id_grp,\n",
    "                                 sigma = 0.0,\n",
    "                                 lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                 n_epoch_grad = [100, 100, 100, 100],\n",
    "                                 cond = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad1e307c-f90d-4067-8e24-643a90cfbe98",
   "metadata": {},
   "source": [
    "## 2.2 GP-ICFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd188c52-dfee-417c-af12-2830c270c703",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 1\n",
    "l = 6 # 6\n",
    "sig2_diag = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef192eef-e281-49fb-87da-5d8f9d039366",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# unconditional\n",
    "model_gp_icfm_uncond = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)\n",
    "model_gp_icfm_uncond = fit_gp_icfm_grad(model_gp_icfm_uncond,\n",
    "                                        image_00_train, image_05_train, image_10_train,\n",
    "                                        id_00_train, id_05_train, id_10_train, id_grp,\n",
    "                                        alpha, l, sig2_diag = sig2_diag,\n",
    "                                        lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                        n_epoch_grad = [100, 100, 100, 100],\n",
    "                                        cond = False)\n",
    "\n",
    "# conditional\n",
    "model_gp_icfm_cond = UNetModelWrapper2(dim=(1, 28, 28), in_channels = 2,\n",
    "                                       out_channels = 1,\n",
    "                                       num_channels=32, num_res_blocks=1).to(device)\n",
    "model_gp_icfm_cond = fit_gp_icfm_grad(model_gp_icfm_cond,\n",
    "                                      image_00_train, image_05_train, image_10_train,\n",
    "                                      id_00_train, id_05_train, id_10_train, id_grp,\n",
    "                                      alpha, l, sig2_diag = sig2_diag,\n",
    "                                      lr_grad = [8e-5, 2e-5, 8e-6, 2e-6],\n",
    "                                      n_epoch_grad = [100, 100, 100, 100],\n",
    "                                      cond = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a90b43-01ba-432a-a35e-d28967991d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "rootFolder = \"/hpc/group/mastatlab/gw74/HWD\"\n",
    "\n",
    "# torch.save(model0.state_dict(), rootFolder + \"/initial0.pt\")\n",
    "\n",
    "# torch.save(model_icfm_uncond0.state_dict(), rootFolder + \"/icfm_uncond0.pt\")\n",
    "# torch.save(model_icfm_cond0.state_dict(), rootFolder + \"/icfm_cond0.pt\")\n",
    "# torch.save(model_icfm_uncond1.state_dict(), rootFolder + \"/icfm_uncond1.pt\")\n",
    "# torch.save(model_icfm_cond1.state_dict(), rootFolder + \"/icfm_cond1.pt\")\n",
    "# torch.save(model_gp_icfm_uncond.state_dict(), rootFolder + \"/gp_icfm_uncond.pt\")\n",
    "# torch.save(model_gp_icfm_cond.state_dict(), rootFolder + \"/gp_icfm_cond.pt\")\n",
    "\n",
    "# model0.load_state_dict(torch.load(rootFolder + \"/initial0.pt\"))\n",
    "# model_icfm_uncond0.load_state_dict(torch.load(rootFolder + \"/icfm_uncond0.pt\"))\n",
    "# model_icfm_uncond1.load_state_dict(torch.load(rootFolder + \"/icfm_uncond1.pt\"))\n",
    "# model_icfm_cond0.load_state_dict(torch.load(rootFolder + \"/icfm_cond0.pt\"))\n",
    "# model_icfm_cond1.load_state_dict(torch.load(rootFolder + \"/icfm_cond1.pt\"))\n",
    "# model_gp_icfm_uncond.load_state_dict(torch.load(rootFolder + \"/gp_icfm_uncond.pt\"))\n",
    "# model_gp_icfm_cond.load_state_dict(torch.load(rootFolder + \"/gp_icfm_cond.pt\"))\n",
    "\n",
    "\n",
    "# CPU version, in case GPU is not available: every checkpoint is read with\n",
    "# map_location=cpu and then loaded into the matching model via load_state_dict.\n",
    "checkpoints = [\n",
    "    (model0, \"/initial0.pt\"),\n",
    "    (model_icfm_uncond0, \"/icfm_uncond0.pt\"),\n",
    "    (model_icfm_uncond1, \"/icfm_uncond1.pt\"),\n",
    "    (model_icfm_cond0, \"/icfm_cond0.pt\"),\n",
    "    (model_icfm_cond1, \"/icfm_cond1.pt\"),\n",
    "    (model_gp_icfm_uncond, \"/gp_icfm_uncond.pt\"),\n",
    "    (model_gp_icfm_cond, \"/gp_icfm_cond.pt\"),\n",
    "]\n",
    "load_device = torch.device('cpu')\n",
    "for ckpt_model, ckpt_file in checkpoints:\n",
    "    ckpt_model.load_state_dict(torch.load(rootFolder + ckpt_file, map_location=load_device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73580a01-b14f-4af7-a7a6-7deea9d0352a",
   "metadata": {},
   "source": [
    "# 3. Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d47a45-9ff9-47ec-b282-adce844b3963",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder for the figures saved below; create it so plt.savefig works on a fresh checkout.\n",
    "plot_dir = \"/hpc/home/gw74/diff_model/FM/submission/plots/5_HWD\"\n",
    "os.makedirs(plot_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6e8fc59-d03d-4cf9-b809-95e171b444c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached memory before sampling.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Seeded batch of 10 noise images, passed through model0 by gen_samp; the result x0_gen\n",
    "# is the shared starting batch of every trajectory plot in this section.\n",
    "torch.manual_seed(0)\n",
    "x0_noise = torch.randn(10, 1, 28, 28, device=device)\n",
    "x0_gen = gen_samp(x0_noise, model0, device, cond=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "516da60e-3e7e-4044-9ff6-94f61069e2de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unconditional ICFM: plot_grid_comb combines model_icfm_uncond0 and model_icfm_uncond1,\n",
    "# starting from x0_gen; 9 y-ticks spaced 28 px apart are labelled 0, 0.125, ..., 1.\n",
    "plot_grid_comb(x0_gen, model_icfm_uncond0, model_icfm_uncond1, device, cond=False)\n",
    "row_pos = np.arange(0, 9*28, 28) + 15\n",
    "row_lab = np.arange(0, 1.125, 0.125)\n",
    "plt.xticks([])\n",
    "plt.yticks(row_pos, row_lab)\n",
    "plt.savefig(plot_dir + \"/icfm_uncond.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f47723c-a5c8-4d9d-a2d7-638707380732",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conditional ICFM: plot_grid_comb combines model_icfm_cond0 and model_icfm_cond1,\n",
    "# starting from x0_gen; 9 y-ticks spaced 28 px apart are labelled 0, 0.125, ..., 1.\n",
    "plot_grid_comb(x0_gen, model_icfm_cond0, model_icfm_cond1, device, cond=True)\n",
    "row_pos = np.arange(0, 9*28, 28) + 15\n",
    "row_lab = np.arange(0, 1.125, 0.125)\n",
    "plt.xticks([])\n",
    "plt.yticks(row_pos, row_lab)\n",
    "plt.savefig(plot_dir + \"/icfm_cond.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b9aded-8f2b-4562-a97a-fa92ef735d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unconditional GP-ICFM: a single model plotted with n_sec = 9 from x0_gen;\n",
    "# 9 y-ticks spaced 28 px apart are labelled 0, 0.125, ..., 1.\n",
    "plot_grid(x0_gen, model_gp_icfm_uncond, device, n_sec=9, cond=False)\n",
    "row_pos = np.arange(0, 9*28, 28) + 15\n",
    "row_lab = np.arange(0, 1.125, 0.125)\n",
    "plt.xticks([])\n",
    "plt.yticks(row_pos, row_lab)\n",
    "plt.savefig(plot_dir + \"/gp_icfm_uncond.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc82e68-6ac8-4d2a-9607-1ded33b82147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conditional GP-ICFM: a single model plotted with n_sec = 9 from x0_gen;\n",
    "# 9 y-ticks spaced 28 px apart are labelled 0, 0.125, ..., 1.\n",
    "plot_grid(x0_gen, model_gp_icfm_cond, device, n_sec=9, cond=True)\n",
    "row_pos = np.arange(0, 9*28, 28) + 15\n",
    "row_lab = np.arange(0, 1.125, 0.125)\n",
    "plt.xticks([])\n",
    "plt.yticks(row_pos, row_lab)\n",
    "plt.savefig(plot_dir + \"/gp_icfm_cond.svg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "550bb345-91cd-468d-a724-12d35ddd03f9",
   "metadata": {},
   "source": [
    "# 4. FID"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97c49c43-6a42-43ba-8e0a-091afa6e0166",
   "metadata": {},
   "source": [
    "## 4.1 Generate samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e99d2a4-b8be-4e5d-9be5-8d9057d9415a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation size: rep_eval batches of batch_eval samples each.\n",
    "batch_eval = 32\n",
    "rep_eval = 10  # previously 40; 10 is enough, total test size = 272\n",
    "\n",
    "# Tensors are created on the CPU by default from here on; device_def is the fallback\n",
    "# device for sampling steps that fail on `device`.\n",
    "torch.set_default_device('cpu')\n",
    "device_def = torch.get_default_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81193c1-ef8b-4be4-8472-2e9277bfd5c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture output\n",
    "# generate starting points with model0, then integrate them with the four flow models.\n",
    "# Every step runs on `device` first and is retried on `device_def` if it raises a\n",
    "# RuntimeError (e.g. CUDA out of memory). The previous bare `except:` also caught\n",
    "# KeyboardInterrupt, so interrupting the kernel silently re-ran the step on the CPU.\n",
    "\n",
    "def run_with_fallback(step):\n",
    "    \"\"\"Free cached memory, then return step(device); on RuntimeError return step(device_def).\"\"\"\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    try:\n",
    "        return step(device)\n",
    "    except RuntimeError:\n",
    "        return step(device_def)\n",
    "\n",
    "def two_stage_traj(x_start, model_first, model_second, dev, cond):\n",
    "    \"\"\"Chain two ICFM models (n_sec = 5 each) into one trajectory.\n",
    "\n",
    "    The second model starts from the last frame of the first, so that shared frame is\n",
    "    dropped from the second trajectory before concatenating along the frame axis.\n",
    "    \"\"\"\n",
    "    traj_first = gen_1_traj(x_start.to(dev), model_first.to(dev), dev, n_sec = 5, cond = cond)\n",
    "    traj_second = gen_1_traj(traj_first[-1,:,:,:,:], model_second.to(dev), dev, n_sec = 5, cond = cond)\n",
    "    return torch.cat((traj_first, traj_second[1:,:,:,:,:]), axis = 0)\n",
    "\n",
    "def gp_traj(x_start, model, dev, cond):\n",
    "    \"\"\"Single GP-ICFM trajectory (n_sec = 9).\"\"\"\n",
    "    return gen_1_traj(x_start.to(dev), model.to(dev), dev, n_sec = 9, cond = cond)\n",
    "\n",
    "def to_numpy(traj):\n",
    "    \"\"\"Detached CPU numpy copy of a trajectory tensor.\"\"\"\n",
    "    return traj.detach().cpu().numpy()\n",
    "\n",
    "def stack_frames(reps, n_frames = 9):\n",
    "    \"\"\"For each of the first n_frames frames, concatenate that frame of every repetition\n",
    "    along the sample axis (one np.concatenate per frame instead of one per repetition).\"\"\"\n",
    "    if not reps:\n",
    "        return []\n",
    "    return [np.concatenate([rep[jj] for rep in reps], axis=0) for jj in range(n_frames)]\n",
    "\n",
    "init_list = []\n",
    "reps_icfm_uncond = []\n",
    "reps_icfm_cond = []\n",
    "reps_gpicfm_uncond = []\n",
    "reps_gpicfm_cond = []\n",
    "\n",
    "for ii in tqdm(range(rep_eval)):\n",
    "    x0_tm = run_with_fallback(lambda dev: gen_samp(torch.randn(batch_eval, 1, 28, 28, device=dev),\n",
    "                                                   model0.to(dev), dev, cond = False))\n",
    "    init_list.append(x0_tm)\n",
    "\n",
    "    # the lambdas are called immediately, so they always see this iteration's x0_tm\n",
    "    reps_icfm_uncond.append(to_numpy(run_with_fallback(\n",
    "        lambda dev: two_stage_traj(x0_tm, model_icfm_uncond0, model_icfm_uncond1, dev, cond = False))))\n",
    "    reps_icfm_cond.append(to_numpy(run_with_fallback(\n",
    "        lambda dev: two_stage_traj(x0_tm, model_icfm_cond0, model_icfm_cond1, dev, cond = True))))\n",
    "    reps_gpicfm_uncond.append(to_numpy(run_with_fallback(\n",
    "        lambda dev: gp_traj(x0_tm, model_gp_icfm_uncond, dev, cond = False))))\n",
    "    reps_gpicfm_cond.append(to_numpy(run_with_fallback(\n",
    "        lambda dev: gp_traj(x0_tm, model_gp_icfm_cond, dev, cond = True))))\n",
    "\n",
    "# 9 grids each: element jj holds frame jj of all rep_eval * batch_eval samples\n",
    "trace_list_icfm_uncond = stack_frames(reps_icfm_uncond)\n",
    "trace_list_icfm_cond = stack_frames(reps_icfm_cond)\n",
    "trace_list_gpicfm_uncond = stack_frames(reps_gpicfm_uncond)\n",
    "trace_list_gpicfm_cond = stack_frames(reps_gpicfm_cond)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ec7be47-7e53-42b0-a395-003ebcf88044",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Merge the trajectories saved by sampling runs 1..8: each pickle file holds a list\n",
    "# indexed by jj = 0..8, and entry jj of all runs is concatenated along axis 0.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load files this project wrote.\n",
    "def merge_runs(prefix, run_ids=range(1, 9), n_frames=9):\n",
    "    runs = []\n",
    "    for run_id in run_ids:\n",
    "        with open(prefix + str(run_id), \"rb\") as fp:\n",
    "            runs.append(pickle.load(fp))\n",
    "    return [np.concatenate([run[jj] for run in runs], axis=0) for jj in range(n_frames)]\n",
    "\n",
    "image_list_icfm_uncond = merge_runs(\"trace_list_icfm_uncond_\")  # 9 grids\n",
    "image_list_icfm_cond = merge_runs(\"trace_list_icfm_cond_\")  # 9 grids\n",
    "image_list_gpicfm_uncond = merge_runs(\"trace_list_gpicfm_uncond_\")  # 9 grids\n",
    "image_list_gpicfm_cond = merge_runs(\"trace_list_gpicfm_cond_\")  # 9 grids"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de5269fa-7656-4239-8d52-3bbbe7a44ec5",
   "metadata": {},
   "source": [
    "## 4.2 Calculate FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ef7d01-f8a5-4943-8053-e0a3d45aa36e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fid_calc(all_images_raw, test_data, offset = 0):\n",
    "    \"\"\"FID between generated images and real test images.\n",
    "\n",
    "    Takes test_data.shape[0] generated images starting at `offset`, upsamples both sets\n",
    "    bilinearly to 299x299, repeats the channel axis 3 times, moves channels last and\n",
    "    passes the two numpy arrays to `calculate_fid`.\n",
    "\n",
    "    all_images_raw : generated images, indexed as (N, C, H, W)\n",
    "    test_data      : real images (tensor or array), indexed as (M, C, H, W)\n",
    "    offset         : index of the first generated image to use. It used to be hard-coded\n",
    "                     to 272, which threw away the first 272 images of the subsample the\n",
    "                     caller had already drawn to match the test-set size, leaving a\n",
    "                     smaller (or empty) generated set than the real one.\n",
    "\n",
    "    Returns the value of `calculate_fid`; raises ValueError if no generated images remain.\n",
    "    \"\"\"\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    # Upsample has no parameters, so the former `.type(torch.cuda.FloatTensor)` did nothing;\n",
    "    # torch.nn is referenced explicitly instead of relying on a star import providing `nn`.\n",
    "    up = torch.nn.Upsample(size=(299, 299), mode='bilinear')\n",
    "\n",
    "    n_real = test_data.shape[0]\n",
    "    all_images = all_images_raw[offset:(offset + n_real), :, :, :]\n",
    "    if all_images.shape[0] == 0:\n",
    "        raise ValueError(\"fid_calc: no generated images left after offset=%d\" % offset)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        gen = up(torch.as_tensor(all_images, dtype=torch.float32, device='cpu')).cpu().numpy()\n",
    "        # repeat_interleave on dim 1 matches np.repeat(test_data, 3, axis=1)\n",
    "        real = torch.repeat_interleave(torch.as_tensor(test_data).cpu(), 3, dim=1)\n",
    "        real = up(real).cpu().numpy()\n",
    "\n",
    "    # channels last; generated images get their channel repeated 3 times here,\n",
    "    # the real ones were repeated before upsampling (same result per channel)\n",
    "    all_images = np.repeat(np.transpose(gen, (0, 2, 3, 1)), 3, axis=3)\n",
    "    real_images = np.transpose(real, (0, 2, 3, 1))\n",
    "\n",
    "    return calculate_fid(all_images, real_images, use_multiprocessing=False, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c2ce40b-d01d-4623-9628-d83372b39830",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed-seed subsample (without replacement) of the generated images, sized like the\n",
    "# '0' test set; the same indices img_idx are reused for every FID evaluation below.\n",
    "np.random.seed(0)\n",
    "n_test = image_00_test.shape[0]\n",
    "n_tot = image_list_icfm_uncond[0].shape[0]\n",
    "img_idx = np.random.choice(n_tot, size = n_test, replace = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baacf09d-e69d-4bc0-a55c-0dfbfb50b410",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# FID of every model's generated frames against each real test set (plotted below as '0', '8', '6').\n",
    "# For frame ll the same subsample img_idx of generated images is used; results land in\n",
    "# fid_<set>_<model> (one value per frame) and, column by column, in fid.csv.\n",
    "fid_test_sets = {\"00\": image_00_test, \"05\": image_05_test, \"10\": image_10_test}\n",
    "fid_gen_sets = {\"icfm_uncond\": image_list_icfm_uncond, \"icfm_cond\": image_list_icfm_cond,\n",
    "                \"gp_uncond\": image_list_gpicfm_uncond, \"gp_cond\": image_list_gpicfm_cond}\n",
    "fid_res = {(s, g): np.zeros(9) for s in fid_test_sets for g in fid_gen_sets}\n",
    "\n",
    "for ll in tqdm(range(9)):\n",
    "    for s, real_set in fid_test_sets.items():\n",
    "        for g, gen_frames in fid_gen_sets.items():\n",
    "            fid_res[(s, g)][ll] = fid_calc(gen_frames[ll][img_idx,:,:,:], real_set.view([-1, 1, 28, 28]))\n",
    "\n",
    "fid_00_icfm_uncond, fid_00_icfm_cond, fid_00_gp_uncond, fid_00_gp_cond = [fid_res[(\"00\", g)] for g in fid_gen_sets]\n",
    "fid_05_icfm_uncond, fid_05_icfm_cond, fid_05_gp_uncond, fid_05_gp_cond = [fid_res[(\"05\", g)] for g in fid_gen_sets]\n",
    "fid_10_icfm_uncond, fid_10_icfm_cond, fid_10_gp_uncond, fid_10_gp_cond = [fid_res[(\"10\", g)] for g in fid_gen_sets]\n",
    "\n",
    "# column order: test sets 00, 05, 10 x (icfm_uncond, icfm_cond, gp_uncond, gp_cond)\n",
    "fid_all = np.column_stack([fid_res[key] for key in fid_res])\n",
    "np.savetxt(\"fid.csv\", fid_all, delimiter=\",\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba059c50-eb76-42f9-861f-0ca1ef5d4f0e",
   "metadata": {},
   "source": [
    "## 4.3 Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dcf1f69-c93b-4c64-8012-17cd2cbf9a0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_dir = \"/hpc/home/gw74/diff_model/FM/submission/plots/5_HWD\"\n",
    "# SVG text kept as text (not paths), no LaTeX rendering, 14 pt base font.\n",
    "plt.rcParams.update({\n",
    "    'svg.fonttype': 'none',\n",
    "    'text.usetex': False,\n",
    "    'font.size': 14,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86a2b7c-c1e3-43c9-accc-c577157bc2eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID to '0' for each model along the trajectory. The x-values reuse the 9 labels\n",
    "# 0, 0.125, ..., 1 of the trajectory grids (presumably flow time t) instead of frame indices.\n",
    "fid_t = np.arange(0, 1.125, 0.125)\n",
    "fig, ax = plt.subplots()\n",
    "for fid_curve in (fid_00_icfm_uncond, fid_00_icfm_cond, fid_00_gp_uncond, fid_00_gp_cond):\n",
    "    ax.plot(fid_t, fid_curve)\n",
    "ax.set(title=\"FID to '0'\", xlabel=\"t\", ylabel=\"FID\")\n",
    "ax.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])\n",
    "fig.savefig(plot_dir + \"/fid0.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a887443-b19f-405c-b855-bfe2958c65b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID to '8' for each model along the trajectory. The x-values reuse the 9 labels\n",
    "# 0, 0.125, ..., 1 of the trajectory grids (presumably flow time t) instead of frame indices.\n",
    "fid_t = np.arange(0, 1.125, 0.125)\n",
    "fig, ax = plt.subplots()\n",
    "for fid_curve in (fid_05_icfm_uncond, fid_05_icfm_cond, fid_05_gp_uncond, fid_05_gp_cond):\n",
    "    ax.plot(fid_t, fid_curve)\n",
    "ax.set(title=\"FID to '8'\", xlabel=\"t\", ylabel=\"FID\")\n",
    "ax.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])\n",
    "fig.savefig(plot_dir + \"/fid8.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b0ce81-001d-4f19-ba16-a9792f0dced8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FID to '6' for each model along the trajectory. The x-values reuse the 9 labels\n",
    "# 0, 0.125, ..., 1 of the trajectory grids (presumably flow time t) instead of frame indices.\n",
    "fid_t = np.arange(0, 1.125, 0.125)\n",
    "fig, ax = plt.subplots()\n",
    "for fid_curve in (fid_10_icfm_uncond, fid_10_icfm_cond, fid_10_gp_uncond, fid_10_gp_cond):\n",
    "    ax.plot(fid_t, fid_curve)\n",
    "ax.set(title=\"FID to '6'\", xlabel=\"t\", ylabel=\"FID\")\n",
    "ax.legend(['ICFM-uncond', 'ICFM-cond', 'GP-ICFM-uncond', 'GP-ICFM-cond'])\n",
    "fig.savefig(plot_dir + \"/fid6.svg\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
