{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchdiffeq\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision.transforms import ToPILImage\n",
    "from torchvision.utils import make_grid\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.unet import UNetModel\n",
    "\n",
    "savedir = \"models/cond_mnist\"\n",
    "os.makedirs(savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "batch_size = 128\n",
    "n_epochs = 10\n",
    "\n",
    "trainset = datasets.MNIST(\n",
    "    \"../data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]),\n",
    ")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=batch_size, shuffle=True, drop_last=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################\n",
    "#    Class Conditional CFM\n",
    "#################################\n",
    "\n",
    "sigma = 0.0\n",
    "model = UNetModel(\n",
    "    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True\n",
    ").to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "# Users can try target FM by changing the above line by\n",
    "# FM = TargetConditionalFlowMatcher(sigma=sigma)\n",
    "node = NeuralODE(model, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(n_epochs):\n",
    "    for i, data in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        x1 = data[0].to(device)\n",
    "        y = data[1].to(device)\n",
    "        x0 = torch.randn_like(x1)\n",
    "        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
    "        vt = model(t, xt, y)\n",
    "        loss = torch.mean((vt - ut) ** 2)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        print(f\"epoch: {epoch}, steps: {i}, loss: {loss.item():.4}\", end=\"\\r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_TORCH_DIFFEQ = True\n",
    "generated_class_list = torch.arange(10, device=device).repeat(10)\n",
    "with torch.no_grad():\n",
    "    if USE_TORCH_DIFFEQ:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x, generated_class_list),\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            torch.linspace(0, 1, 2, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = node.trajectory(\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            t_span=torch.linspace(0, 1, 2, device=device),\n",
    "        )\n",
    "grid = make_grid(\n",
    "    traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
    ")\n",
    "img = ToPILImage()(grid)\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################\n",
    "#            OT-CFM\n",
    "#################################\n",
    "\n",
    "sigma = 0.0\n",
    "model = UNetModel(\n",
    "    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True\n",
    ").to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)\n",
    "node = NeuralODE(model, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(n_epochs):\n",
    "    for i, data in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        x1 = data[0].to(device)\n",
    "        y = data[1].to(device)\n",
    "        x0 = torch.randn_like(x1)\n",
    "        t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)\n",
    "        vt = model(t, xt, y1)\n",
    "        loss = torch.mean((vt - ut) ** 2)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        print(f\"epoch: {epoch}, steps: {i}, loss: {loss.item():.4}\", end=\"\\r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_TORCH_DIFFEQ = True\n",
    "generated_class_list = torch.arange(10, device=device).repeat(10)\n",
    "with torch.no_grad():\n",
    "    if USE_TORCH_DIFFEQ:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x, generated_class_list),\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            torch.linspace(0, 1, 2, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = node.trajectory(\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            t_span=torch.linspace(0, 1, 2, device=device),\n",
    "        )\n",
    "grid = make_grid(\n",
    "    traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
    ")\n",
    "img = ToPILImage()(grid)\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################\n",
    "#            SF2M\n",
    "#################################\n",
    "batch_size = 128\n",
    "n_epochs = 10\n",
    "sigma = 0.1\n",
    "\n",
    "\n",
    "model = UNetModel(\n",
    "    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True\n",
    ").to(device)\n",
    "score_model = UNetModel(\n",
    "    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(list(model.parameters()) + list(score_model.parameters()))\n",
    "FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)\n",
    "node = NeuralODE(model, solver=\"dopri5\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(n_epochs):\n",
    "    for i, data in tqdm(enumerate(train_loader)):\n",
    "        optimizer.zero_grad()\n",
    "        x1 = data[0].to(device)\n",
    "        y = data[1].to(device)\n",
    "        x0 = torch.randn_like(x1)\n",
    "        t, xt, ut, _, y1, eps = FM.guided_sample_location_and_conditional_flow(\n",
    "            x0, x1, y1=y, return_noise=True\n",
    "        )\n",
    "        lambda_t = FM.compute_lambda(t)\n",
    "        vt = model(t, xt, y1)\n",
    "        st = score_model(t, xt, y1)\n",
    "        flow_loss = torch.mean((vt - ut) ** 2)\n",
    "        score_loss = torch.mean((lambda_t[:, None, None, None] * st + eps) ** 2)\n",
    "        loss = flow_loss + score_loss\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_TORCH_DIFFEQ = True\n",
    "generated_class_list = torch.arange(10, device=device).repeat(10)\n",
    "\n",
    "node = NeuralODE(model, solver=\"euler\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "# Evaluate the ODE\n",
    "with torch.no_grad():\n",
    "    if USE_TORCH_DIFFEQ:\n",
    "        traj = torchdiffeq.odeint(\n",
    "            lambda t, x: model.forward(t, x, generated_class_list),\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            torch.linspace(0, 1, 2, device=device),\n",
    "            atol=1e-4,\n",
    "            rtol=1e-4,\n",
    "            method=\"dopri5\",\n",
    "        )\n",
    "    else:\n",
    "        traj = node.trajectory(\n",
    "            torch.randn(100, 1, 28, 28, device=device),\n",
    "            t_span=torch.linspace(0, 1, 2, device=device),\n",
    "        )\n",
    "grid = make_grid(\n",
    "    traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
    ")\n",
    "img = ToPILImage()(grid)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# follows example from https://github.com/google-research/torchsde/blob/master/examples/cont_ddpm.py\n",
    "\n",
    "\n",
    "class SDE(torch.nn.Module):\n",
    "    noise_type = \"diagonal\"\n",
    "    sde_type = \"ito\"\n",
    "\n",
    "    def __init__(self, ode_drift, score, labels=None, reverse=False, sigma=0.1):\n",
    "        super().__init__()\n",
    "        self.drift = ode_drift\n",
    "        self.score = score\n",
    "        self.reverse = reverse\n",
    "        self.labels = labels\n",
    "        self.sigma = sigma\n",
    "\n",
    "    # Drift\n",
    "\n",
    "    def f(self, t, y):\n",
    "        y = y.view(-1, 1, 28, 28)\n",
    "        if self.reverse:\n",
    "            t = 1 - t\n",
    "            return -self.drift(t, y, self.labels) + self.score(t, y, self.labels)\n",
    "        return self.drift(t, y, self.labels).flatten(start_dim=1) + self.score(\n",
    "            t, y, self.labels\n",
    "        ).flatten(start_dim=1)\n",
    "\n",
    "    # Diffusion\n",
    "    def g(self, t, y):\n",
    "        return torch.ones_like(y) * self.sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sde = SDE(model, score_model, labels=torch.arange(10, device=device).repeat(10), sigma=0.1)\n",
    "with torch.no_grad():\n",
    "    sde_traj = torchsde.sdeint(\n",
    "        sde,\n",
    "        # x0.view(x0.size(0), -1),\n",
    "        torch.randn(100, 1 * 28 * 28, device=device),\n",
    "        ts=torch.linspace(0, 1, 2, device=device),\n",
    "        dt=0.01,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid = make_grid(\n",
    "    sde_traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10\n",
    ")\n",
    "img = ToPILImage()(grid)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm2",
   "language": "python",
   "name": "torchcfm2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}