{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models import MLP\n",
    "from torchcfm.utils import sample_8gaussians, sample_moons, torch_wrapper\n",
    "\n",
    "savedir = \"models/2d\"\n",
    "os.makedirs(savedir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_trajectories_sb(traj, legend=True):\n",
    "    n = 2000\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c=\"black\")\n",
    "    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.4, alpha=0.1, c=\"olive\")\n",
    "    plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c=\"blue\")\n",
    "\n",
    "    for i in range(5, 15):\n",
    "        plt.plot(traj[:, i, 0], traj[:, i, 1], alpha=0.9, c=\"red\")\n",
    "    if legend:\n",
    "        plt.legend([r\"$p_0$\", r\"$p_t$\", r\"$p_1$\", r\"$X_t \\mid X_0$\"])\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "batch_size = 256\n",
    "sigma = 1.0\n",
    "dim = 2\n",
    "model = MLP(dim=dim, time_varying=True, w=64)\n",
    "score_model = MLP(dim=dim, time_varying=True, w=64)\n",
    "optimizer = torch.optim.Adam(list(model.parameters()) + list(score_model.parameters()), 0.01)\n",
    "# FM = ConditionalFlowMatcher(sigma=sigma)\n",
    "FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in tqdm(range(10000)):\n",
    "    optimizer.zero_grad()\n",
    "    x0 = sample_8gaussians(batch_size)\n",
    "    x1 = sample_moons(batch_size)\n",
    "    t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)\n",
    "    lambda_t = FM.compute_lambda(t)\n",
    "    vt = model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    st = score_model(torch.cat([xt, t[:, None]], dim=-1))\n",
    "    flow_loss = torch.mean((vt - ut) ** 2)\n",
    "    score_loss = torch.mean((lambda_t[:, None] * st + eps) ** 2)\n",
    "    loss = flow_loss + score_loss\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model, f\"{savedir}/sf2m_v1.pt\")\n",
    "torch.save(score_model, f\"{savedir}/sf2m_v1.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node = NeuralODE(torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\", atol=1e-4, rtol=1e-4)\n",
    "x0 = sample_8gaussians(1024)\n",
    "\n",
    "with torch.no_grad():\n",
    "    traj = node.trajectory(\n",
    "        x0,\n",
    "        t_span=torch.linspace(0, 1, 100, device=device),\n",
    "    )\n",
    "\n",
    "\n",
    "class SDE(torch.nn.Module):\n",
    "    noise_type = \"diagonal\"\n",
    "    sde_type = \"ito\"\n",
    "\n",
    "    def __init__(self, ode_drift, score, input_size=(3, 32, 32), sigma=1.0):\n",
    "        super().__init__()\n",
    "        self.drift = ode_drift\n",
    "        self.score = score\n",
    "        self.input_size = input_size\n",
    "        self.sigma = sigma\n",
    "\n",
    "    # Drift\n",
    "    def f(self, t, y):\n",
    "        y = y.view(-1, *self.input_size)\n",
    "        if len(t.shape) == len(y.shape):\n",
    "            x = torch.cat([y, t], 1)\n",
    "        else:\n",
    "            x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)\n",
    "        return self.drift(x).flatten(start_dim=1) + self.score(x).flatten(start_dim=1)\n",
    "\n",
    "    # Diffusion\n",
    "    def g(self, t, y):\n",
    "        return torch.ones_like(y) * self.sigma\n",
    "\n",
    "\n",
    "sde = SDE(model, score_model, input_size=(2,), sigma=sigma)\n",
    "with torch.no_grad():\n",
    "    sde_traj = torchsde.sdeint(\n",
    "        sde,\n",
    "        x0,\n",
    "        ts=torch.linspace(0, 1, 100),\n",
    "        solver=\"euler\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_trajectories_sb(traj.cpu().numpy(), legend=False)\n",
    "plot_trajectories_sb(sde_traj.cpu().numpy(), legend=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm2",
   "language": "python",
   "name": "torchcfm2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}