{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9abc68e1-3b5f-4fe7-b766-0c41d46e7627",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.transforms import ToPILImage\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.unet import UNetModel\n",
    "from myUnetWrapper import *\n",
    "import torchdiffeq\n",
    "\n",
    "import numpy as np\n",
    "from numpy import *\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from typing import List\n",
    "import time\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "import pickle\n",
    "from copy import deepcopy\n",
    "import gc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12b89ea-f340-4587-b646-4608761bb4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from absl import app, flags\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import trange\n",
    "from utils_cifar import ema, generate_samples, infiniteloop\n",
    "\n",
    "from torchcfm.conditional_flow_matching import (\n",
    "    ConditionalFlowMatcher,\n",
    "    ExactOptimalTransportConditionalFlowMatcher,\n",
    "    TargetConditionalFlowMatcher,\n",
    "    VariancePreservingConditionalFlowMatcher,\n",
    ")\n",
    "from torchcfm.models.unet.unet import UNetModelWrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0bb078b-6497-4a71-82e0-f28a0d7c6df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e79576f-bf36-4ea1-9118-6b198e7701bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is visible to PyTorch; otherwise stay on the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")\n",
    "# torch.set_default_device(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ff91c0d-598a-4201-9b93-93e121630cb1",
   "metadata": {},
   "source": [
    "# 1. Read Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d699cab-8b8c-434b-948d-793cc459faf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR-10 training split (downloaded if missing).  Each image gets a random\n",
    "# horizontal flip, is converted to a tensor, and is normalised with mean 0.5 /\n",
    "# std 0.5 per channel.\n",
    "# NOTE(review): `root` is an absolute, user-specific path; adjust before running elsewhere.\n",
    "dataset_full = datasets.CIFAR10(\n",
    "    root=\"/hpc/home/gw74/diff_model/FM/cifar10/data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "    ]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87fdfef-4e5f-4238-9b58-78643340ca4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the complete training split: `dataset` is just an alias of `dataset_full`.\n",
    "dataset = dataset_full"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ed1373a-4545-4414-a3f3-ecec4249a305",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DataLoader settings: samples per mini-batch and number of loader worker processes.\n",
    "batch_size = 128\n",
    "num_workers = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c6a2857-9d1d-4f47-b06b-3511a6c3ea3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled training loader.  drop_last=True discards the final incomplete\n",
    "# batch, so every batch holds exactly `batch_size` samples.\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=num_workers,\n",
    "    drop_last=True,\n",
    "    # generator = torch.Generator(device=device)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9951236-1320-42da-825f-6036aeb64c0e",
   "metadata": {},
   "source": [
    "# 2. Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b728089d-4033-4fed-a60e-dfb044bbe652",
   "metadata": {},
   "outputs": [],
   "source": [
    "def makemodel(num_channel):\n",
    "    \"\"\"Build the U-Net used as the velocity network and move it to `device`.\n",
    "\n",
    "    Args:\n",
    "        num_channel: forwarded as ``num_channels`` to ``UNetModelWrapper``.\n",
    "\n",
    "    Returns:\n",
    "        The ``UNetModelWrapper`` instance after ``.to(device)``, where\n",
    "        ``device`` is the notebook-level global.\n",
    "    \"\"\"\n",
    "    unet_kwargs = dict(\n",
    "        dim=(3, 32, 32),\n",
    "        num_res_blocks=2,\n",
    "        num_channels=num_channel,\n",
    "        channel_mult=[1, 2, 2, 2],\n",
    "        num_heads=4,\n",
    "        num_head_channels=64,\n",
    "        attention_resolutions=\"16\",\n",
    "        dropout=0.1,\n",
    "    )  # new dropout + bs of 128\n",
    "    return UNetModelWrapper(**unet_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "002cac03-2fb6-4921-bb97-c103979811dd",
   "metadata": {},
   "source": [
    "## 2.1 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ead4d6-2857-4662-9912-b3290f6e65e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def icfm(net_model, ema_model,\n",
    "         save_step, savedir, device,\n",
    "         dataloader, x0_all = None,\n",
    "         sigma = 0.0, total_steps = 5001, grad_clip = 1.0, lr = 2e-4):\n",
    "    \"\"\"Train `net_model` with the I-CFM objective.\n",
    "\n",
    "    Every step draws a data batch x1, pairs it with a source sample x0, asks\n",
    "    ``ConditionalFlowMatcher`` for ``(t, xt, ut)`` and regresses\n",
    "    ``net_model(t, xt)`` onto ``ut`` with a mean-squared error.  The function\n",
    "    reads the notebook globals ``warmup_lr`` (LR schedule) and ``ema_decay``\n",
    "    (EMA update).\n",
    "\n",
    "    Args:\n",
    "        net_model: network being optimised; called as ``net_model(t, xt)``.\n",
    "        ema_model: model updated via ``ema(net_model, ema_model, ema_decay)``.\n",
    "        save_step: when > 0, samples and a checkpoint are written every time\n",
    "            ``step % save_step == 0`` (this includes step 0).\n",
    "        savedir: string prefix for outputs; it is concatenated directly with\n",
    "            the checkpoint file name.\n",
    "        device: device that x1 and x0 are moved to.\n",
    "        dataloader: wrapped by ``infiniteloop`` to supply batches.\n",
    "        x0_all: optional pre-drawn source samples, indexed along dim 0 with one\n",
    "            entry per step; when None, x0 is drawn with ``torch.randn_like(x1)``.\n",
    "        sigma: forwarded to ``ConditionalFlowMatcher``.\n",
    "        total_steps: number of optimisation steps.\n",
    "        grad_clip: maximum gradient norm passed to ``clip_grad_norm_``.\n",
    "        lr: Adam learning rate (scaled by the ``warmup_lr`` schedule).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(net_model, ema_model)``.\n",
    "    \"\"\"\n",
    "    optim = torch.optim.Adam(net_model.parameters(), lr=lr)\n",
    "    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)\n",
    "\n",
    "    # Decide once which stored x0 entry each step will use.\n",
    "    use_stored_x0 = x0_all is not None\n",
    "    if use_stored_x0:\n",
    "        n_stored = x0_all.shape[0]\n",
    "        x0_idx = (\n",
    "            np.arange(total_steps)\n",
    "            if n_stored == total_steps\n",
    "            else np.random.choice(n_stored, total_steps)\n",
    "        )\n",
    "\n",
    "    FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "    datalooper = infiniteloop(dataloader)\n",
    "    with trange(total_steps, dynamic_ncols=True) as pbar:\n",
    "        for step in pbar:\n",
    "            optim.zero_grad()\n",
    "            x1 = next(datalooper).to(device)\n",
    "\n",
    "            # Source sample: stored entry for this step, or fresh Gaussian noise.\n",
    "            x0 = x0_all[x0_idx[step], :].to(device) if use_stored_x0 else torch.randn_like(x1)\n",
    "\n",
    "            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "            loss = torch.mean((net_model(t, xt) - ut) ** 2)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(net_model.parameters(), grad_clip)  # new\n",
    "            optim.step()\n",
    "            sched.step()\n",
    "            ema(net_model, ema_model, ema_decay)  # new\n",
    "\n",
    "            # Nothing more to do unless this is a checkpoint step.\n",
    "            if save_step <= 0 or step % save_step != 0:\n",
    "                continue\n",
    "\n",
    "            # Write sample grids for both models, then the full training state.\n",
    "            generate_samples(net_model, False, savedir, step, net_=\"normal\")\n",
    "            generate_samples(ema_model, False, savedir, step, net_=\"ema\")\n",
    "            training_state = {\n",
    "                \"net_model\": net_model.state_dict(),\n",
    "                \"ema_model\": ema_model.state_dict(),\n",
    "                \"sched\": sched.state_dict(),\n",
    "                \"optim\": optim.state_dict(),\n",
    "                \"step\": step,\n",
    "            }\n",
    "            torch.save(training_state, savedir + f\"cifar10_weights_step_{step}.pt\")\n",
    "\n",
    "    return net_model, ema_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b1a6906-d963-4664-bcd4-68eaf72debd6",
   "metadata": {},
   "source": [
    "## 2.2 GP-I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3315f919-f6c7-4ed7-8374-a4ed8b2ce423",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_r(ti, tj):\n",
    "    \"\"\"Pairwise differences ``ti[..., i] - tj[..., j]`` over the last axes.\n",
    "\n",
    "    Entries that are exactly zero are overwritten in the result with 1e-15.\n",
    "    \"\"\"\n",
    "    diff = ti[..., None] - tj[..., None, :]\n",
    "    zero_mask = diff == 0\n",
    "    diff[zero_mask] = 1e-15\n",
    "    return diff\n",
    "\n",
    "\n",
    "def k11(r, alpha, l):\n",
    "    \"\"\"Squared-exponential kernel ``alpha^2 * exp(-r^2 / (2 l^2))``.\"\"\"\n",
    "    envelope = torch.exp(-0.5 * ((r / l)**2))\n",
    "    return (alpha**2) * envelope\n",
    "\n",
    "\n",
    "def k12(r, alpha, l):\n",
    "    \"\"\"``(alpha^2 / l^2) * r * exp(-r^2 / (2 l^2))``.\n",
    "\n",
    "    With ``r = ti - tj`` this is the derivative of ``k11`` with respect to tj.\n",
    "    \"\"\"\n",
    "    coeff = alpha**2 / l**2\n",
    "    envelope = torch.exp(-0.5 * ((r / l)**2))\n",
    "    return coeff * r * envelope\n",
    "\n",
    "\n",
    "def k22(r, alpha, l):\n",
    "    \"\"\"``(alpha^2 / l^4) * (l^2 - r^2) * exp(-r^2 / (2 l^2))``.\n",
    "\n",
    "    With ``r = ti - tj`` this is the mixed second derivative of ``k11`` with\n",
    "    respect to ti and tj.\n",
    "    \"\"\"\n",
    "    coeff = alpha**2 / l**4\n",
    "    envelope = torch.exp(-0.5 * ((r / l)**2))\n",
    "    return coeff * (l**2 - r**2) * envelope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f91115d-28c2-4b7c-b91c-4082b42be154",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cov_mat2(ti, tj, alpha, l, sig2_diag = 1e-8):\n",
    "    \"\"\"Batched joint covariance of a GP and its time-derivative.\n",
    "\n",
    "    Assembles the block matrix ``[[K11 + sig2_diag * I, K12], [K12^T, K22]]``\n",
    "    from the kernels ``k11`` / ``k12`` / ``k22`` evaluated on\n",
    "    ``calc_r(ti, tj)`` and symmetrises the result.\n",
    "\n",
    "    Args:\n",
    "        ti, tj: time tensors; ``calc_r(ti, tj)`` must be 3-D with a square\n",
    "            trailing block, i.e. shape (nB, nt, nt).\n",
    "        alpha: kernel amplitude.\n",
    "        l: kernel length-scale.\n",
    "        sig2_diag: jitter added to the diagonal of the K11 block only.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (nB, 2*nt, 2*nt).\n",
    "    \"\"\"\n",
    "    r = calc_r(ti, tj)\n",
    "    nt = r.shape[1]\n",
    "\n",
    "    # Build the jitter on the inputs' own device (not the notebook-global\n",
    "    # `device`) so the function also works for tensors living elsewhere;\n",
    "    # broadcasting applies it to every batch entry without an explicit repeat.\n",
    "    jitter = torch.eye(nt, device=r.device) * sig2_diag\n",
    "\n",
    "    Sig11 = k11(r, alpha, l) + jitter\n",
    "    Sig12 = k12(r, alpha, l)\n",
    "    Sig21 = Sig12.permute(0, 2, 1)\n",
    "    Sig22 = k22(r, alpha, l)\n",
    "\n",
    "    block_row1 = torch.cat([Sig11, Sig12], dim=2)\n",
    "    block_row2 = torch.cat([Sig21, Sig22], dim=2)\n",
    "    Sig = torch.cat([block_row1, block_row2], dim=1)\n",
    "\n",
    "    # Enforce exact symmetry against floating-point round-off.\n",
    "    return (Sig + Sig.permute(0, 2, 1)) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6178b09f-c0da-4159-b0a2-bffdfe33558c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def samp_x_dx2(t_mat, alpha, l, x_obs, t_obs, sig2_diag=1e-8):\n",
    "    \"\"\"Sample GP values and time-derivatives at `t_mat`, conditioned on observations.\n",
    "\n",
    "    Uses Gaussian conditioning: the joint prior over ``[x(t), dx(t)]`` comes\n",
    "    from ``cov_mat2`` and the observations ``x_obs`` at times ``t_obs`` enter\n",
    "    through ``k11`` / ``k12`` cross-covariances.  All features (last axis of\n",
    "    ``x_obs``) share one covariance per batch entry and are sampled\n",
    "    independently.\n",
    "\n",
    "    Args:\n",
    "        t_mat: query times; assumed shape (n_batch, n_query) -- TODO confirm.\n",
    "        alpha: kernel amplitude.\n",
    "        l: kernel length-scale.\n",
    "        x_obs: observed values, reshaped to (n_batch, n_obs, n_feat).\n",
    "        t_obs: observation times; assumed 1-D of length n_obs -- TODO confirm.\n",
    "        sig2_diag: jitter added to the prior and observation covariances.\n",
    "\n",
    "    Returns:\n",
    "        ``(x_samps, dx_samps)``, each of shape (n_batch, n_query, n_feat) with\n",
    "        the dtype and device of ``x_obs``.\n",
    "    \"\"\"\n",
    "    n_batch, n_obs, n_feat = x_obs.shape[0], t_obs.shape[0], x_obs.shape[2]\n",
    "    n_query = t_mat.shape[1]\n",
    "    n_joint = 2 * n_query\n",
    "\n",
    "    # Prior covariance of [x(t), dx(t)] at the query times.\n",
    "    prior_cov = cov_mat2(t_mat, t_mat, alpha, l, sig2_diag)\n",
    "\n",
    "    # Cross-covariance between the observations and [x(t), dx(t)].\n",
    "    lag_obs_query = calc_r(t_obs, t_mat)\n",
    "    cross_obs_joint = torch.cat(\n",
    "        [k11(lag_obs_query, alpha, l), k12(lag_obs_query, alpha, l)], dim=2\n",
    "    )\n",
    "    cross_joint_obs = cross_obs_joint.permute(0, 2, 1)\n",
    "\n",
    "    # The observation covariance is shared by the whole batch: invert once, then tile.\n",
    "    lag_obs_obs = calc_r(t_obs, t_obs)\n",
    "    obs_cov = k11(lag_obs_obs, alpha, l) + torch.eye(n_obs, device=device) * sig2_diag\n",
    "    obs_cov_inv = torch.linalg.inv(obs_cov).repeat(n_batch, 1, 1)\n",
    "\n",
    "    # gain = Sig_12 Sig_22^{-1}; it serves both the conditional covariance and the mean.\n",
    "    gain = torch.bmm(cross_joint_obs, obs_cov_inv)\n",
    "    cond_cov = prior_cov - torch.bmm(gain, cross_obs_joint)\n",
    "    cond_cov = (cond_cov + cond_cov.permute(0, 2, 1)) / 2\n",
    "\n",
    "    # Batch entries with any negative eigenvalue are rebuilt from their SVD\n",
    "    # (singular values are non-negative) with a small 1e-8 floor added.\n",
    "    needs_fix = (torch.linalg.eigvals(cond_cov).real >= 0).sum(dim=1) != cond_cov.shape[1]\n",
    "    _, sing_vals, Vh = torch.linalg.svd(cond_cov[needs_fix])\n",
    "    rebuilt = torch.bmm(torch.bmm(Vh.permute(0, 2, 1), torch.diag_embed(sing_vals + 1e-8)), Vh)\n",
    "    cond_cov[needs_fix] = (rebuilt + rebuilt.permute(0, 2, 1)) / 2\n",
    "\n",
    "    # Conditional mean, computed on the notebook-global `device`.\n",
    "    obs_vals = x_obs.reshape(n_batch, n_obs, n_feat).to(device)\n",
    "    cond_mean = torch.bmm(gain.to(device), obs_vals).reshape(n_batch, n_joint, n_feat)\n",
    "\n",
    "    # One (2*n_query)-dimensional Gaussian per (batch entry, feature) pair.\n",
    "    loc_flat = cond_mean.permute(0, 2, 1).reshape(n_batch * n_feat, n_joint)\n",
    "    cov_flat = cond_cov.repeat_interleave(n_feat, dim=0)\n",
    "\n",
    "    try:\n",
    "        mvn = MultivariateNormal(loc=loc_flat, covariance_matrix=cov_flat)\n",
    "        draws = mvn.rsample().reshape(n_batch, n_feat, n_joint).permute(0, 2, 1)\n",
    "    except RuntimeError:\n",
    "        # Batched torch sampling failed: draw each (batch, feature) pair with numpy.\n",
    "        print('Sampling failed; using numpy fallback.')\n",
    "        draws = torch.zeros((n_batch, n_joint, n_feat), dtype=x_obs.dtype, device=x_obs.device)\n",
    "        for b in range(n_batch):\n",
    "            cov_np = cond_cov[b].cpu().numpy()\n",
    "            for d in range(n_feat):\n",
    "                mean_np = cond_mean[b, :, d].cpu().numpy()\n",
    "                draws[b, :, d] = torch.from_numpy(np.random.multivariate_normal(mean_np, cov_np))\n",
    "\n",
    "    # First n_query rows are values, the remaining rows are derivatives; copy\n",
    "    # them into fresh tensors that match x_obs in dtype and device.\n",
    "    out_opts = dict(dtype=x_obs.dtype, device=x_obs.device)\n",
    "    x_samps = torch.zeros((n_batch, n_query, n_feat), **out_opts)\n",
    "    dx_samps = torch.zeros((n_batch, n_query, n_feat), **out_opts)\n",
    "    x_samps[:, :, :] = draws[:, :n_query, :]\n",
    "    dx_samps[:, :, :] = draws[:, n_query:, :]\n",
    "\n",
    "    return x_samps, dx_samps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae913abb-5d83-4e7d-9745-94df8e439fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # to debug...\n",
    "# net_model = makemodel(128)\n",
    "# ema_model = copy.deepcopy(net_model)\n",
    "# save_step = 2000\n",
    "# savedir = \"/cwork/gw74/iclr_cifar10/results\"\n",
    "# # device\n",
    "# # dataloader\n",
    "# x0_all = None\n",
    "# alpha  = 1\n",
    "# l = 1\n",
    "# sig2_diag = 0.0\n",
    "# total_steps = 50001\n",
    "# grad_clip = 1.0\n",
    "# lr = 2e-4\n",
    "# warmup = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590e6a93-5dad-4fb3-9bbd-9f9235a7dc12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def warmup_lr(step):\n",
    "    \"\"\"Linear learning-rate warm-up factor (passed to ``LambdaLR`` as ``lr_lambda``).\n",
    "\n",
    "    Returns ``step / warmup`` while ``step < warmup`` and 1.0 afterwards, where\n",
    "    ``warmup`` is a notebook-level global.  A non-positive ``warmup`` disables\n",
    "    warm-up (factor 1.0) instead of raising ``ZeroDivisionError``.\n",
    "    \"\"\"\n",
    "    if warmup <= 0:\n",
    "        return 1.0\n",
    "    # Written without `min` so that a star-import rebinding that name cannot break it.\n",
    "    return step / warmup if step < warmup else 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ff2f98b-d214-4737-bb5b-fc28ad2bcb56",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gp_icfm(net_model, ema_model,\n",
    "         save_step, savedir, device,\n",
    "         dataloader, x0_all = None,\n",
    "         sig2_diag = 0.0, total_steps = 5001, grad_clip = 1.0, lr = 2e-4,\n",
    "         btch_size = 128, pixel = 32, n_channel = 3):\n",
    "    \"\"\"Train `net_model` on GP-sampled paths between x0 and x1 (GP-I-CFM).\n",
    "\n",
    "    For every step and every image channel, ``samp_x_dx2`` draws a location\n",
    "    ``xt`` and a velocity ``ut`` at a random time t from a GP conditioned on\n",
    "    x0 at t = 0 and x1 at t = 1; ``net_model(t, xt)`` is regressed onto ``ut``\n",
    "    with a mean-squared error.  The function reads the notebook globals\n",
    "    ``alpha`` and ``l`` (kernel hyper-parameters), ``warmup_lr`` (LR schedule)\n",
    "    and ``ema_decay`` (EMA update).\n",
    "\n",
    "    Args:\n",
    "        net_model: network being optimised; called as ``net_model(t, xt)``.\n",
    "        ema_model: model updated via ``ema(net_model, ema_model, ema_decay)``.\n",
    "        save_step: when > 0, samples and a checkpoint are written every time\n",
    "            ``step % save_step == 0`` (this includes step 0).\n",
    "        savedir: string prefix for outputs; it is concatenated directly with\n",
    "            the checkpoint file name.\n",
    "        device: device used for all tensors created here.\n",
    "        dataloader: wrapped by ``infiniteloop`` to supply batches.\n",
    "        x0_all: optional pre-drawn source samples, indexed along dim 0 with one\n",
    "            entry per step; when None, x0 is drawn with ``torch.randn_like(x1)``.\n",
    "        sig2_diag: jitter forwarded to ``samp_x_dx2``.\n",
    "        total_steps: number of optimisation steps.\n",
    "        grad_clip: maximum gradient norm passed to ``clip_grad_norm_``.\n",
    "        lr: Adam learning rate (scaled by the ``warmup_lr`` schedule).\n",
    "        btch_size: kept for backward compatibility only; the batch size is\n",
    "            taken from the batch actually drawn.\n",
    "        pixel: image height and width.\n",
    "        n_channel: number of image channels.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(net_model, ema_model)``.\n",
    "    \"\"\"\n",
    "    optim = torch.optim.Adam(net_model.parameters(), lr=lr)\n",
    "    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)\n",
    "\n",
    "    if x0_all is not None:\n",
    "        if x0_all.shape[0] == total_steps:\n",
    "            x0_idx = np.arange(total_steps)\n",
    "        else:\n",
    "            x0_idx = np.random.choice(x0_all.shape[0], total_steps)\n",
    "\n",
    "    # The conditioning times (t = 0 for x0, t = 1 for x1) never change, so\n",
    "    # build the tensor once instead of once per channel per step.\n",
    "    t_obs = torch.tensor([0, 1]).to(device)\n",
    "\n",
    "    datalooper = infiniteloop(dataloader)\n",
    "    with trange(total_steps, dynamic_ncols=True) as pbar:\n",
    "        for step in pbar:\n",
    "            optim.zero_grad()\n",
    "\n",
    "            x1 = next(datalooper).to(device)\n",
    "            if x0_all is None:\n",
    "                x0 = torch.randn_like(x1)\n",
    "            else:\n",
    "                x0 = x0_all[x0_idx[step],:].to(device)\n",
    "\n",
    "            # Size every buffer from the batch actually drawn.  The previous\n",
    "            # version read the notebook-global `batch_size` (ignoring the\n",
    "            # `btch_size` argument) and broke whenever the two disagreed.\n",
    "            n_batch = x1.shape[0]\n",
    "\n",
    "            t_mat = torch.rand((n_batch, 1)).to(device)\n",
    "            xt = torch.zeros(n_batch, n_channel, pixel, pixel).to(device)\n",
    "            ut = torch.zeros(n_batch, n_channel, pixel, pixel).to(device)\n",
    "            for ii in range(n_channel):\n",
    "                # Endpoints of channel `ii`, flattened per image.  Allocated on\n",
    "                # `device` so the sampler does not bounce data through the CPU.\n",
    "                x01_trans = torch.zeros(n_batch, 2, pixel * pixel, device=device)\n",
    "                x01_trans[:, 0, :] = x0[:, ii].reshape(n_batch, -1)\n",
    "                x01_trans[:, 1, :] = x1[:, ii].reshape(n_batch, -1)\n",
    "                xt_batch, ut_batch = samp_x_dx2(t_mat, alpha, l, x01_trans,\n",
    "                                                t_obs, sig2_diag)\n",
    "\n",
    "                xt[:, ii] = xt_batch.reshape(n_batch, pixel, pixel)\n",
    "                ut[:, ii] = ut_batch.reshape(n_batch, pixel, pixel)\n",
    "\n",
    "            t = torch.reshape(t_mat, (-1, )).to(device)\n",
    "            vt = net_model(t, xt)\n",
    "\n",
    "            loss = torch.mean((vt - ut) ** 2)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(net_model.parameters(), grad_clip)  # new\n",
    "            optim.step()\n",
    "            sched.step()\n",
    "            ema(net_model, ema_model, ema_decay)  # new\n",
    "\n",
    "            # Periodically write sample grids and the full training state.\n",
    "            if save_step > 0 and step % save_step == 0:\n",
    "                generate_samples(net_model, False, savedir, step, net_=\"normal\")\n",
    "                generate_samples(ema_model, False, savedir, step, net_=\"ema\")\n",
    "                torch.save(\n",
    "                    {\n",
    "                        \"net_model\": net_model.state_dict(),\n",
    "                        \"ema_model\": ema_model.state_dict(),\n",
    "                        \"sched\": sched.state_dict(),\n",
    "                        \"optim\": optim.state_dict(),\n",
    "                        \"step\": step,\n",
    "                    },\n",
    "                    savedir + f\"cifar10_weights_step_{step}.pt\",\n",
    "                )\n",
    "\n",
    "    return net_model, ema_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0dd3166-cf1d-4bd0-a60f-eb88b5499c0f",
   "metadata": {},
   "source": [
    "# 3. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a55e37-0b69-4fd4-a0a1-72a9436a38ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared training configuration.\n",
    "# num_channel: passed to makemodel.\n",
    "# save_step:   samples/checkpoints are written whenever step % save_step == 0.\n",
    "# total_steps: number of optimisation steps per training run.\n",
    "# warmup:      read as a global by warmup_lr (number of linear warm-up steps).\n",
    "# ema_decay:   read as a global by icfm / gp_icfm and passed to ema().\n",
    "num_channel = 128\n",
    "save_step = 5000\n",
    "total_steps = 400001\n",
    "warmup = 5000\n",
    "ema_decay =  0.9999"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "936056e9-99b4-48d1-8f41-f406a453c23a",
   "metadata": {},
   "source": [
    "## 3.1 I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23982c60-ba71-42c3-92c5-e488b695d590",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh network plus an EMA model that starts as an exact copy of it.\n",
    "# `savedir` must end with \"/\" because icfm concatenates it directly with the\n",
    "# checkpoint file name.\n",
    "net_model_icfm = makemodel(num_channel)\n",
    "ema_model_icfm = copy.deepcopy(net_model_icfm)\n",
    "savedir = \"/cwork/gw74/iclr_cifar10/results/icfm/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c83f67cd-e75d-437b-a6d1-98808388d4af",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture output\n",
    "# Train the I-CFM baseline; cell output is captured into `output`.\n",
    "sigma = 1e-3\n",
    "net_model_icfm, ema_model_icfm = icfm(\n",
    "    net_model_icfm,\n",
    "    ema_model_icfm,\n",
    "    save_step,\n",
    "    savedir,\n",
    "    device,\n",
    "    dataloader,\n",
    "    x0_all=None,\n",
    "    sigma=sigma,\n",
    "    total_steps=total_steps,\n",
    "    grad_clip=1.0,\n",
    "    lr=2e-4,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfa8d852-a976-49ea-b896-28fded100f76",
   "metadata": {},
   "source": [
    "## 3.2 GP-I-CFM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a837a295-e33b-4cb4-b9d1-b5c43d694a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: alpha and l still need to be checked (gp_icfm reads both as globals).\n",
    "alpha = 1\n",
    "l = 2\n",
    "sig2_diag = 1e-6\n",
    "savedir = \"/cwork/gw74/iclr_cifar10/results/gpicfm/\"\n",
    "\n",
    "# makemodel already places the network on `device`.\n",
    "net_model_gpicfm = makemodel(num_channel)\n",
    "ema_model_gpicfm = copy.deepcopy(net_model_gpicfm)\n",
    "\n",
    "net_model_gpicfm, ema_model_gpicfm = gp_icfm(\n",
    "    net_model_gpicfm,\n",
    "    ema_model_gpicfm,\n",
    "    save_step,\n",
    "    savedir,\n",
    "    device,\n",
    "    dataloader,\n",
    "    x0_all=None,\n",
    "    sig2_diag=sig2_diag,\n",
    "    total_steps=total_steps,\n",
    "    grad_clip=1.0,\n",
    "    lr=2e-4,\n",
    "    btch_size=128,\n",
    "    pixel=32,\n",
    "    n_channel=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3665780b-1cc3-45ed-a859-1d5541ca8548",
   "metadata": {},
   "source": [
    "# 4. FID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd21d7b-a1d3-4b36-add5-021d7f88aeb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdiffeq import odeint\n",
    "from cleanfid import fid\n",
    "import warnings\n",
    "\n",
    "# Configure warnings to display only once\n",
    "warnings.filterwarnings(\"once\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97d65fc0-16e5-45fe-8e0a-de65915e99fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tol:      used as both rtol and atol of the ODE solver during sampling.\n",
    "# step_all: checkpoint steps to evaluate: 20000, 40000, ..., 400000.\n",
    "tol = 1e-5\n",
    "step_all = np.arange(1, 21)*20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62cae2e-2be3-45f1-857a-7db56391aa87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute one FID score per checkpoint listed in `step_all`, using the EMA weights.\n",
    "# NOTE(review): scores are stored in `score_all_gpicfm`, but PATH below points at\n",
    "# the icfm results (the gpicfm PATH is commented out) -- confirm which run this\n",
    "# cell is meant to evaluate.\n",
    "num_channel = 128\n",
    "batch_size_fid = 32\n",
    "num_gen = 2000\n",
    "\n",
    "# One slot per evaluated checkpoint; sized from step_all so the two cannot drift apart.\n",
    "score_all_gpicfm = np.zeros(len(step_all))\n",
    "\n",
    "for cc, step in enumerate(step_all):\n",
    "    # Release the previous checkpoint's model before loading the next one.\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    PATH = f\"/cwork/gw74/iclr_cifar10/results/icfm/cifar10_weights_step_{step}.pt\"\n",
    "    # PATH = f\"/cwork/gw74/iclr_cifar10/results/gpicfm/cifar10_weights_step_{step}.pt\"\n",
    "\n",
    "    checkpoint = torch.load(PATH, map_location=device)\n",
    "    state_dict = checkpoint[\"ema_model\"]\n",
    "    new_net = makemodel(num_channel)\n",
    "    try:\n",
    "        new_net.load_state_dict(state_dict)\n",
    "    except RuntimeError:\n",
    "        # Retry without a 7-character \"module.\" key prefix (presumably left by a\n",
    "        # DataParallel wrapper -- verify); keys lacking the prefix are kept as-is.\n",
    "        from collections import OrderedDict\n",
    "\n",
    "        new_state_dict = OrderedDict(\n",
    "            (k[7:] if k.startswith(\"module.\") else k, v) for k, v in state_dict.items()\n",
    "        )\n",
    "        new_net.load_state_dict(new_state_dict)\n",
    "    new_net.eval()\n",
    "\n",
    "    def gen_1_img(unused_latent):\n",
    "        \"\"\"Generate one batch of uint8 images by integrating `new_net` from t=0 to t=1.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            x = torch.randn(batch_size_fid, 3, 32, 32, device=device)\n",
    "            t_span = torch.linspace(0, 1, 2, device=device)\n",
    "            traj = odeint(\n",
    "                new_net, x, t_span, rtol=tol, atol=tol, method=\"dopri5\"\n",
    "            )\n",
    "        traj = traj[-1, :]  # .view([-1, 3, 32, 32]).clip(-1, 1)\n",
    "        img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8)  # .permute(1, 2, 0)\n",
    "        return img\n",
    "\n",
    "    score = fid.compute_fid(\n",
    "        gen=gen_1_img,\n",
    "        dataset_name=\"cifar10\",\n",
    "        batch_size=batch_size_fid,\n",
    "        dataset_res=32,\n",
    "        num_gen=num_gen,\n",
    "        dataset_split=\"train\",\n",
    "        mode=\"legacy_tensorflow\",\n",
    "        device=device,\n",
    "        use_dataparallel=False,\n",
    "        num_workers=2\n",
    "    )\n",
    "    score_all_gpicfm[cc] = score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5d2d38-291a-4fda-bf6f-41a1bd80c350",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the FID scores to the current working directory.\n",
    "# NOTE(review): the evaluation loop currently loads checkpoints from the icfm\n",
    "# results directory -- confirm this file name matches the run that was scored.\n",
    "np.save(\"gpicfm_later.npy\", score_all_gpicfm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a5aa78-6308-4c3c-8617-a91658eff548",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
