{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Plotting\n",
    "\n",
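    "This notebook compares several interpolants (CFM, OT-CFM, SB-CFM, VP-CFM and Action Matching) side by side using saved models. For each time step it plots the density, the learned vector field and sample trajectories, then assembles the frames into GIFs."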
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import time\n",
    "\n",
    "import imageio\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot as pot\n",
    "import torch\n",
    "import torchdyn\n",
    "from torchdyn.core import DEFunc, NeuralODE\n",
    "from torchdyn.datasets import generate_moons\n",
    "from torchdyn.nn import Augmenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Implement some helper functions\n",
    "\n",
    "\n",
    "def sample_normal(n):\n",
    "    \"\"\"Draw ``n`` samples from the 2-D standard normal N(0, I); returns shape (n, 2).\"\"\"\n",
    "    standard_normal = torch.distributions.MultivariateNormal(\n",
    "        loc=torch.zeros(2), covariance_matrix=torch.eye(2)\n",
    "    )\n",
    "    return standard_normal.sample((n,))\n",
    "\n",
    "\n",
    "def log_normal_density(x):\n",
    "    \"\"\"Log-density of the standard normal N(0, I), evaluated over the last axis of ``x``.\"\"\"\n",
    "    dim = x.shape[-1]\n",
    "    standard_normal = torch.distributions.MultivariateNormal(\n",
    "        loc=torch.zeros(dim), covariance_matrix=torch.eye(dim)\n",
    "    )\n",
    "    return standard_normal.log_prob(x)\n",
    "\n",
    "\n",
    "def eight_normal_sample(n, dim, scale=1, var=1):\n",
    "    \"\"\"Draw ``n`` points from a uniform mixture of 8 Gaussians placed on a ring.\n",
    "\n",
    "    Args:\n",
    "        n: number of samples.\n",
    "        dim: event dimension; must be 2 because the centers below are 2-D.\n",
    "        scale: radius of the ring of centers.\n",
    "        var: the component covariance is ``sqrt(var) * I`` (not ``var * I``). The\n",
    "            convention is kept unchanged so samples agree with ``log_8gaussian_density``,\n",
    "            which uses the same covariance.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (n, dim).\n",
    "    \"\"\"\n",
    "    m = torch.distributions.multivariate_normal.MultivariateNormal(\n",
    "        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)\n",
    "    )\n",
    "    centers = [\n",
    "        (1, 0),\n",
    "        (-1, 0),\n",
    "        (0, 1),\n",
    "        (0, -1),\n",
    "        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "    ]\n",
    "    centers = torch.tensor(centers) * scale\n",
    "    # Draw the noise before the component indices so a seeded RNG stream is unchanged.\n",
    "    noise = m.sample((n,))\n",
    "    multi = torch.multinomial(torch.ones(8), n, replacement=True)\n",
    "    # Vectorized gather: row k is centers[multi[k]] + noise[k] (no Python loop / stack).\n",
    "    return centers[multi] + noise\n",
    "\n",
    "\n",
    "def log_8gaussian_density(x, scale=5, var=0.1, normalize=False):\n",
    "    \"\"\"Log-density of the 8-Gaussian ring used by ``sample_8gaussians``.\n",
    "\n",
    "    Args:\n",
    "        x: points of shape (batch, 2).\n",
    "        scale: radius of the ring of component centers.\n",
    "        var: the component covariance is ``sqrt(var) * I``, the same convention as\n",
    "            ``eight_normal_sample``.\n",
    "        normalize: if False (default, backward compatible), return the logsumexp over\n",
    "            the 8 components, i.e. the log of the *sum* of component densities. That is\n",
    "            ``log(8)`` above the log-density of the uniform mixture\n",
    "            ``eight_normal_sample`` draws from. The 8-Gaussian plotting cell uses a fixed\n",
    "            ``vmax=1`` on this scale. Pass True for a properly normalized log-density.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch,).\n",
    "    \"\"\"\n",
    "    centers = [\n",
    "        (1, 0),\n",
    "        (-1, 0),\n",
    "        (0, 1),\n",
    "        (0, -1),\n",
    "        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "    ]\n",
    "    # Build centers and the component distribution in x's dtype/device so inputs never\n",
    "    # mix dtypes (or devices) inside MultivariateNormal.log_prob.\n",
    "    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()\n",
    "    centers = torch.tensor(centers, dtype=dtype, device=x.device) * scale  # [centers, dims]\n",
    "    dim = x.shape[-1]\n",
    "    # shifted xs [batch, centers, dims]\n",
    "    shifted = x[:, None, :] - centers[None, :, :]\n",
    "    m = torch.distributions.multivariate_normal.MultivariateNormal(\n",
    "        torch.zeros(dim, dtype=dtype, device=x.device),\n",
    "        math.sqrt(var) * torch.eye(dim, dtype=dtype, device=x.device),\n",
    "    )\n",
    "    log_probs = torch.logsumexp(m.log_prob(shifted), -1)\n",
    "    if normalize:\n",
    "        # Uniform mixture weight 1/8 per component.\n",
    "        log_probs = log_probs - math.log(centers.shape[0])\n",
    "    return log_probs\n",
    "\n",
    "\n",
    "def sample_moons(n):\n",
    "    \"\"\"Draw ``n`` two-moons points (torchdyn, noise 0.2), rescaled as ``x * 3 - 1``.\"\"\"\n",
    "    moons, _labels = generate_moons(n, noise=0.2)\n",
    "    return moons * 3 - 1\n",
    "\n",
    "\n",
    "def sample_8gaussians(n):\n",
    "    \"\"\"Draw ``n`` float32 points from the 2-D 8-Gaussian ring (radius 5, ``var=0.1``).\"\"\"\n",
    "    points = eight_normal_sample(n, 2, scale=5, var=0.1)\n",
    "    return points.float()\n",
    "\n",
    "\n",
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Four-linear-layer SELU perceptron.\n",
    "\n",
    "    Args:\n",
    "        dim: input dimension (excluding the optional time column).\n",
    "        out_dim: output dimension; defaults to ``dim``.\n",
    "        w: hidden width.\n",
    "        time_varying: if truthy, the first layer takes one extra input column.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        out_dim = dim if out_dim is None else out_dim\n",
    "        in_dim = dim + int(bool(time_varying))\n",
    "        # Layers are created in forward order so parameter init consumes the RNG as before.\n",
    "        layers = [torch.nn.Linear(in_dim, w), torch.nn.SELU()]\n",
    "        for _ in range(2):\n",
    "            layers += [torch.nn.Linear(w, w), torch.nn.SELU()]\n",
    "        layers.append(torch.nn.Linear(w, out_dim))\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class MLP2(torch.nn.Module):\n",
    "    \"\"\"Variant of ``MLP`` with ReLU then SiLU (Swish) activations, used for Action Matching.\n",
    "\n",
    "    Args:\n",
    "        dim: input dimension (excluding the optional time column).\n",
    "        out_dim: output dimension; defaults to ``dim``.\n",
    "        w: hidden width.\n",
    "        time_varying: if truthy, the first layer takes one extra input column.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n",
    "        super().__init__()\n",
    "        self.time_varying = time_varying\n",
    "        out_dim = dim if out_dim is None else out_dim\n",
    "        in_dim = dim + int(bool(time_varying))\n",
    "        # Layers are created in forward order so parameter init consumes the RNG as before.\n",
    "        layers = [torch.nn.Linear(in_dim, w), torch.nn.ReLU()]\n",
    "        for _ in range(2):\n",
    "            layers += [torch.nn.Linear(w, w), torch.nn.SiLU()]\n",
    "        layers.append(torch.nn.Linear(w, out_dim))\n",
    "        self.net = torch.nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class GradModel(torch.nn.Module):\n",
    "    \"\"\"Vector field defined as the input-gradient of a scalar potential ``action``.\n",
    "\n",
    "    The gradient w.r.t. the last input column (the time column appended by the callers\n",
    "    in this notebook) is dropped from the output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, action):\n",
    "        super().__init__()\n",
    "        self.action = action\n",
    "\n",
    "    def forward(self, x):\n",
    "        # NOTE: flags the caller's tensor as requiring grad (in place). Autograd must be\n",
    "        # enabled, so this cannot be evaluated under torch.no_grad().\n",
    "        x.requires_grad_(True)\n",
    "        potential = self.action(x).sum()\n",
    "        (grad,) = torch.autograd.grad(potential, x, create_graph=True)\n",
    "        return grad[:, :-1]\n",
    "\n",
    "\n",
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Adapt a ``model(cat([x, t]))`` network to torchdyn's ``forward(t, x)`` signature.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        \"\"\"Append ``t`` as an extra input column and evaluate the wrapped model.\n",
    "\n",
    "        ``t`` is presumably the scalar (0-dim) time tensor passed by torchdyn; it is\n",
    "        repeated once per row of ``x``.\n",
    "        \"\"\"\n",
    "        time_column = t.repeat(x.shape[0])[:, None]\n",
    "        # Use the wrapped network, not whatever global `model` happens to be defined.\n",
    "        return self.model(torch.cat([x, time_column], 1))\n",
    "\n",
    "\n",
    "def autograd_trace(x_out, x_in, **kwargs):\n",
    "    \"\"\"Exact trace of the Jacobian d(x_out)/d(x_in), using one autograd call per dimension.\n",
    "\n",
    "    Extra keyword arguments (e.g. ``noise``) are accepted so the signature matches\n",
    "    stochastic estimators; they are ignored. Returns a (batch,) tensor.\n",
    "    \"\"\"\n",
    "    diagonal_terms = (\n",
    "        torch.autograd.grad(x_out[:, d].sum(), x_in, allow_unused=False, create_graph=True)[0][:, d]\n",
    "        for d in range(x_in.shape[1])\n",
    "    )\n",
    "    return sum(diagonal_terms, 0.0)\n",
    "\n",
    "\n",
    "class CNF(torch.nn.Module):\n",
    "    \"\"\"Continuous normalizing flow dynamics for torchdyn.\n",
    "\n",
    "    The state ``x`` holds the accumulated log-density change in column 0 and the data\n",
    "    coordinates in columns 1:. ``forward`` returns ``[-tr(dv/dx), v]`` where\n",
    "    ``v = net(cat([x[:, 1:], t]))``.\n",
    "\n",
    "    Args:\n",
    "        net: vector-field network taking data coordinates with time appended.\n",
    "        trace_estimator: callable ``(x_out, x_in, noise=...) -> trace``; defaults to\n",
    "            the exact ``autograd_trace``.\n",
    "        noise_dist: stored on the module; not read by ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net, trace_estimator=None, noise_dist=None):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        if trace_estimator is None:\n",
    "            trace_estimator = autograd_trace\n",
    "        self.trace_estimator = trace_estimator\n",
    "        self.noise_dist = noise_dist\n",
    "        self.noise = None\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        # Re-enable autograd locally so the trace can be computed even when the caller\n",
    "        # integrates under torch.no_grad().\n",
    "        with torch.set_grad_enabled(True):\n",
    "            state = x[:, 1:].requires_grad_(True)\n",
    "            time_column = t * torch.ones(x.shape[0], 1).type_as(state)\n",
    "            velocity = self.net(torch.cat([state, time_column], dim=-1))\n",
    "            divergence = self.trace_estimator(velocity, state, noise=self.noise)\n",
    "        # `+ 0 * x` only connects x[:, 0] to the autograd graph; it does not change values.\n",
    "        return torch.cat([-divergence[:, None], velocity], 1) + 0 * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"models/8gaussian-moons\"\n",
    "# The checkpoints are whole pickled modules (they are called directly below), presumably\n",
    "# instances of the classes defined above, which must therefore be in scope. Unpickling can\n",
    "# execute arbitrary code: only load trusted files. weights_only=False is required on\n",
    "# PyTorch >= 2.6, whose torch.load defaults to weights_only=True and rejects such pickles.\n",
    "# map_location=\"cpu\" matches `device` below and lets GPU-saved checkpoints load without CUDA.\n",
    "checkpoints = {\n",
    "    \"CFM\": \"cfm_v1.pt\",\n",
    "    \"OT-CFM (ours)\": \"otcfm_v1.pt\",\n",
    "    \"SB-CFM (ours)\": \"sbcfm_v1.pt\",\n",
    "    \"VP-CFM\": \"stochastic_interpolant_v1.pt\",\n",
    "    \"Action-Matching\": \"action_matching_v1.pt\",\n",
    "    \"Action-Matching (Swish)\": \"action_matching_swish_v1.pt\",\n",
    "}\n",
    "models = {\n",
    "    name: torch.load(f\"{savedir}/{filename}\", map_location=\"cpu\", weights_only=False)\n",
    "    for name, filename in checkpoints.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "w = 7\n",
    "points = 100j\n",
    "points_real = 100\n",
    "device = \"cpu\"\n",
    "Y, X = np.mgrid[-w:w:points, -w:w:points]\n",
    "gridpoints = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1)).type(torch.float32)\n",
    "points_small = 20j\n",
    "points_real_small = 20\n",
    "Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]\n",
    "gridpoints_small = torch.tensor(np.stack([X_small.flatten(), Y_small.flatten()], axis=1)).type(\n",
    "    torch.float32\n",
    ")\n",
    "# Figure columns, left to right.\n",
    "names = [\n",
    "    \"CFM\",\n",
    "    \"Action-Matching\",\n",
    "    \"Action-Matching (Swish)\",\n",
    "    \"VP-CFM\",\n",
    "    \"SB-CFM (ours)\",\n",
    "    \"OT-CFM (ours)\",\n",
    "]\n",
    "outdir = \"figures/trajectory/v3\"\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "torch.manual_seed(42)\n",
    "sample = sample_8gaussians(1024)\n",
    "ts = torch.linspace(0, 1, 101)\n",
    "# Row 2 data: integrate every model forward from the same source sample. Autograd stays\n",
    "# enabled because models that differentiate internally (GradModel, presumably the\n",
    "# Action-Matching checkpoints) need it.\n",
    "trajs = {}\n",
    "for name, model in models.items():\n",
    "    nde = NeuralODE(DEFunc(torch_wrapper(model)), solver=\"euler\").to(device)\n",
    "    traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n",
    "    trajs[name] = traj\n",
    "\n",
    "# Row 0 machinery, built once per model instead of once per frame: a CNF whose state\n",
    "# column 0 accumulates the log-density change, and the grid with that column prepended.\n",
    "cnf_odes = {\n",
    "    name: NeuralODE(DEFunc(CNF(models[name])), solver=\"euler\", sensitivity=\"adjoint\").to(device)\n",
    "    for name in names\n",
    "}\n",
    "aug_gridpoints = Augmenter(augment_idx=1, augment_dims=1)(gridpoints).to(device)\n",
    "\n",
    "for i, t in enumerate(ts):\n",
    "    fig, axes = plt.subplots(3, len(names), figsize=(6 * len(names), 6 * 3))\n",
    "    for axis, name in zip(axes.T, names):\n",
    "        model = models[name]\n",
    "        # Row 0: density at time t, by integrating the CNF from t back to 0.\n",
    "        with torch.no_grad():\n",
    "            if t > 0:\n",
    "                aug_traj = cnf_odes[name].trajectory(\n",
    "                    aug_gridpoints, t_span=torch.linspace(t, 0, 201).to(device)\n",
    "                )[-1].cpu()\n",
    "                log_probs = log_8gaussian_density(aug_traj[:, 1:]) - aug_traj[:, 0]\n",
    "            else:\n",
    "                log_probs = log_8gaussian_density(gridpoints)\n",
    "        log_probs = log_probs.reshape(Y.shape)\n",
    "        ax = axis[0]\n",
    "        ax.pcolormesh(X, Y, torch.exp(log_probs), vmax=1)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "        ax.set_title(f\"{name}\", fontsize=30)\n",
    "        # Row 1: learned vector field at time t (no torch.no_grad, see above).\n",
    "        out = model(\n",
    "            torch.cat(\n",
    "                [gridpoints_small, torch.ones((gridpoints_small.shape[0], 1)) * t], dim=1\n",
    "            ).to(device)\n",
    "        )\n",
    "        out = out.reshape([points_real_small, points_real_small, 2]).cpu().detach().numpy()\n",
    "        ax = axis[1]\n",
    "        ax.quiver(\n",
    "            X_small,\n",
    "            Y_small,\n",
    "            out[:, :, 0],\n",
    "            out[:, :, 1],\n",
    "            np.sqrt(np.sum(out**2, axis=-1)),\n",
    "            cmap=\"coolwarm\",\n",
    "            scale=50.0,\n",
    "            width=0.015,\n",
    "            pivot=\"mid\",\n",
    "        )\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "        # Row 2: source sample (black), path so far (olive), current positions (blue).\n",
    "        ax = axis[2]\n",
    "        sample_traj = trajs[name]\n",
    "        ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c=\"black\")\n",
    "        ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c=\"olive\")\n",
    "        ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c=\"blue\")\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "    fig.suptitle(f\"8gaussians to Moons T={t:0.2f}\", fontsize=40)\n",
    "    fig.savefig(f\"{outdir}/{t:0.2f}.png\", dpi=40)\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gif_name = \"8gaussians-to-moons\"\n",
    "with imageio.get_writer(f\"{gif_name}.gif\", mode=\"I\") as writer:\n",
    "    # One frame per time step, then the final frame repeated 10 more times as a pause.\n",
    "    frame_files = [f\"figures/trajectory/v3/{t:0.2f}.png\" for t in ts]\n",
    "    frame_files += [f\"figures/trajectory/v3/{ts[-1].item():0.2f}.png\"] * 10\n",
    "    for filename in frame_files:\n",
    "        writer.append_data(imageio.imread(filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"models/8gaussian-moons\"\n",
    "# NOTE(review): these are the same 8gaussian-moons checkpoints as in the first loading cell,\n",
    "# although the next cell integrates them from a standard-normal source -- confirm intended.\n",
    "# Not plotted: \"FM\" (flow_matching_v1.pt) and \"VP-SDE\" (vp_flow_v1.pt).\n",
    "# weights_only=False: whole pickled modules (PyTorch >= 2.6 defaults to weights_only=True and\n",
    "# rejects them); unpickling can run arbitrary code, so only load trusted files.\n",
    "# map_location=\"cpu\" matches `device` in the next cell.\n",
    "checkpoints = {\n",
    "    \"CFM\": \"cfm_v1.pt\",\n",
    "    \"OT-CFM (ours)\": \"otcfm_v1.pt\",\n",
    "    \"SB-CFM (ours)\": \"sbcfm_v1.pt\",\n",
    "    \"VP-CFM\": \"stochastic_interpolant_v1.pt\",\n",
    "    \"Action-Matching\": \"action_matching_v1.pt\",\n",
    "    \"Action-Matching (Swish)\": \"action_matching_swish_v1.pt\",\n",
    "}\n",
    "models = {\n",
    "    name: torch.load(f\"{savedir}/{filename}\", map_location=\"cpu\", weights_only=False)\n",
    "    for name, filename in checkpoints.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 7\n",
    "points = 100j\n",
    "points_real = 100\n",
    "device = \"cpu\"\n",
    "Y, X = np.mgrid[-w:w:points, -w:w:points]\n",
    "gridpoints = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1)).type(torch.float32)\n",
    "points_small = 20j\n",
    "points_real_small = 20\n",
    "Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]\n",
    "gridpoints_small = torch.tensor(np.stack([X_small.flatten(), Y_small.flatten()], axis=1)).type(\n",
    "    torch.float32\n",
    ")\n",
    "# Figure columns, left to right (\"FM\" and \"VP-SDE\" are not loaded, so they are not plotted).\n",
    "names = [\n",
    "    \"CFM\",\n",
    "    \"Action-Matching\",\n",
    "    \"Action-Matching (Swish)\",\n",
    "    \"VP-CFM\",\n",
    "    \"SB-CFM (ours)\",\n",
    "    \"OT-CFM (ours)\",\n",
    "]\n",
    "outdir = \"figures/trajectory2/v3\"\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "torch.manual_seed(42)\n",
    "sample = sample_normal(1024)\n",
    "ts = torch.linspace(0, 1, 101)\n",
    "# Row 2 data: integrate every model forward from the same source sample. Autograd stays\n",
    "# enabled because models that differentiate internally (GradModel, presumably the\n",
    "# Action-Matching checkpoints) need it.\n",
    "trajs = {}\n",
    "for name, model in models.items():\n",
    "    nde = NeuralODE(DEFunc(torch_wrapper(model)), solver=\"euler\").to(device)\n",
    "    traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n",
    "    trajs[name] = traj\n",
    "\n",
    "# Row 0 machinery, built once per model instead of once per frame: a CNF whose state\n",
    "# column 0 accumulates the log-density change, and the grid with that column prepended.\n",
    "cnf_odes = {\n",
    "    name: NeuralODE(DEFunc(CNF(models[name])), solver=\"euler\").to(device) for name in names\n",
    "}\n",
    "aug_gridpoints = Augmenter(augment_idx=1, augment_dims=1)(gridpoints).to(device)\n",
    "\n",
    "for i, t in enumerate(ts):\n",
    "    # 3 rows x len(names) columns of 6-inch panels; the height must not scale with len(names).\n",
    "    fig, axes = plt.subplots(3, len(names), figsize=(6 * len(names), 6 * 3))\n",
    "    for axis, name in zip(axes.T, names):\n",
    "        model = models[name]\n",
    "        # Row 0: density at time t, by integrating the CNF from t back to 0.\n",
    "        with torch.no_grad():\n",
    "            if t > 0:\n",
    "                aug_traj = cnf_odes[name].trajectory(\n",
    "                    aug_gridpoints, t_span=torch.linspace(t, 0, 201).to(device)\n",
    "                )[-1].cpu()\n",
    "                log_probs = log_normal_density(aug_traj[:, 1:]) - aug_traj[:, 0]\n",
    "            else:\n",
    "                log_probs = log_normal_density(gridpoints)\n",
    "        log_probs = log_probs.reshape(Y.shape)\n",
    "        ax = axis[0]\n",
    "        ax.pcolormesh(X, Y, torch.exp(log_probs))\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "        ax.set_title(f\"{name}\", fontsize=30)\n",
    "        # Row 1: learned vector field at time t (no torch.no_grad, see above).\n",
    "        out = model(\n",
    "            torch.cat(\n",
    "                [gridpoints_small, torch.ones((gridpoints_small.shape[0], 1)) * t], dim=1\n",
    "            ).to(device)\n",
    "        )\n",
    "        out = out.reshape([points_real_small, points_real_small, 2]).cpu().detach().numpy()\n",
    "        ax = axis[1]\n",
    "        ax.quiver(\n",
    "            X_small,\n",
    "            Y_small,\n",
    "            out[:, :, 0],\n",
    "            out[:, :, 1],\n",
    "            np.sqrt(np.sum(out**2, axis=-1)),\n",
    "            cmap=\"coolwarm\",\n",
    "            scale=50.0,\n",
    "            width=0.015,\n",
    "            pivot=\"mid\",\n",
    "        )\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "        # Row 2: source sample (black), path so far (olive), current positions (blue).\n",
    "        ax = axis[2]\n",
    "        sample_traj = trajs[name]\n",
    "        ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c=\"black\")\n",
    "        ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c=\"olive\")\n",
    "        ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c=\"blue\")\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-w, w)\n",
    "        ax.set_ylim(-w, w)\n",
    "    fig.suptitle(f\"Gaussian to Moons T={t:0.2f}\", fontsize=40)\n",
    "    fig.savefig(f\"{outdir}/{t:0.2f}.png\", dpi=40)\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gif_name = \"gaussians-to-moons\"\n",
    "ts = torch.linspace(0, 1, 101)\n",
    "with imageio.get_writer(f\"{gif_name}.gif\", mode=\"I\") as writer:\n",
    "    # One frame per time step, then the final frame repeated 10 more times as a pause.\n",
    "    frame_files = [f\"figures/trajectory2/v3/{t:0.2f}.png\" for t in ts]\n",
    "    frame_files += [f\"figures/trajectory2/v3/{ts[-1].item():0.2f}.png\"] * 10\n",
    "    for filename in frame_files:\n",
    "        writer.append_data(imageio.imread(filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm2",
   "language": "python",
   "name": "torchcfm2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}