{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "d4ecb6de",
      "metadata": {
        "id": "d4ecb6de"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import math\n",
        "import numpy as np\n",
        "import torch\n",
        "from sklearn import datasets as sklearn_datasets\n",
        "import torch.nn.functional as F\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import clear_output\n",
        "import torch.nn as nn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "3b397641-d042-4b41-9942-05c29942ad73",
      "metadata": {
        "id": "3b397641-d042-4b41-9942-05c29942ad73"
      },
      "outputs": [],
      "source": [
        "class Sampler:\n",
        "    \"\"\"Base class for samplers that remember which device to sample on.\n",
        "\n",
        "    Concrete samplers override `sample` to return a batch of points.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, device='cpu'):\n",
        "        # Kept on the instance so subclasses can hand it to torch.\n",
        "        self.device = device\n",
        "\n",
        "    def sample(self, size=5):\n",
        "        \"\"\"Placeholder implementation; always raises NotImplementedError.\"\"\"\n",
        "        message = \"Subclasses must implement sample()\"\n",
        "        raise NotImplementedError(message)\n",
        "\n",
        "class StandardNormalSampler(Sampler):\n",
        "    \"\"\"Sampler that draws i.i.d. standard-normal vectors with `dim` components.\"\"\"\n",
        "\n",
        "    def __init__(self, dim=1, device='cpu'):\n",
        "        super().__init__(device=device)\n",
        "        self.dim = dim\n",
        "\n",
        "    def sample(self, batch_size=10):\n",
        "        \"\"\"Return a `(batch_size, dim)` tensor of N(0, 1) draws created on `self.device`.\"\"\"\n",
        "        shape = (batch_size, self.dim)\n",
        "        return torch.randn(shape, device=self.device)\n",
        "\n",
        "\n",
        "class SwissRollSampler(Sampler):\n",
        "    \"\"\"Sampler for a 2-D swiss roll built from scikit-learn's 3-D generator.\n",
        "\n",
        "    Only columns 0 and 2 of the generated points are kept, and the result is\n",
        "    divided by `scale`.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dim=2, device='cpu', noise=0.8, scale=7.5):\n",
        "        super().__init__(device=device)\n",
        "        # Only the planar projection is implemented.\n",
        "        assert dim == 2\n",
        "        self.dim = 2\n",
        "        self.noise = noise\n",
        "        self.scale = scale\n",
        "\n",
        "    def sample(self, batch_size=10):\n",
        "        \"\"\"Return a `(batch_size, 2)` float32 tensor of swiss-roll points on `self.device`.\"\"\"\n",
        "        # make_swiss_roll returns (points, position_along_roll); only the points are used.\n",
        "        points, _ = sklearn_datasets.make_swiss_roll(n_samples=batch_size, noise=self.noise)\n",
        "        planar = points.astype('float32')[:, [0, 2]]\n",
        "        batch = planar / self.scale\n",
        "        return torch.tensor(batch, device=self.device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "dd7cdfda",
      "metadata": {
        "id": "dd7cdfda"
      },
      "outputs": [],
      "source": [
        "class EOTConfig:\n",
        "    \"\"\"Hyper-parameter container read by `EOTTrainer`.\n",
        "\n",
        "    Args:\n",
        "        eps: either a number or a callable `step -> float`. The trainer always\n",
        "            evaluates `config.eps(step)`, so a plain number is wrapped into a\n",
        "            constant schedule here.\n",
        "        batch_size: number of source/target samples drawn per training step.\n",
        "        device: device (string or `torch.device`) tensors are moved to.\n",
        "        K: number of noisy copies drawn per source sample in the loss.\n",
        "        epoch: number of training steps after which `train` stops.\n",
        "        lmc_steps: number of Langevin updates used when sampling.\n",
        "        lmc_step_size: step size of each Langevin update.\n",
        "        seed: random seed value; this class only stores it.\n",
        "        grad_clip: max gradient norm passed to `clip_grad_norm_`.\n",
        "        max_diff_exp_clip: upper clamp applied to the exponent in the loss.\n",
        "        ema_momentum: decay factor of the EMA copy of `f_theta`.\n",
        "        score_clip: symmetric clamp applied to the Langevin score.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 eps=0.1,\n",
        "                 batch_size: int = 2048,\n",
        "                 device=\"cpu\",\n",
        "                 K: int = 32,\n",
        "                 epoch: int = 100,\n",
        "                 lmc_steps: int = 100,\n",
        "                 lmc_step_size: float = 0.003,\n",
        "                 seed: int = 42,\n",
        "                 grad_clip=100000.0,\n",
        "                 max_diff_exp_clip=100,\n",
        "                 ema_momentum=0.999,\n",
        "                 score_clip=10000000.0\n",
        "        ):\n",
        "        self.device = device\n",
        "        # Normalise eps to a schedule so `config.eps(step)` works for numbers too\n",
        "        # (the previous float default crashed with \"'float' object is not callable\").\n",
        "        if callable(eps):\n",
        "            self.eps = eps\n",
        "        else:\n",
        "            eps_value = float(eps)\n",
        "            self.eps = lambda step: eps_value\n",
        "        self.batch_size = batch_size\n",
        "        self.K = K\n",
        "        self.epoch = epoch\n",
        "        self.score_clip = score_clip\n",
        "        self.grad_clip = grad_clip\n",
        "        self.max_diff_exp_clip = max_diff_exp_clip\n",
        "        self.lmc_steps = lmc_steps\n",
        "        self.lmc_step_size = lmc_step_size\n",
        "        self.seed = seed\n",
        "        self.ema_momentum = ema_momentum"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "2e862d2d",
      "metadata": {
        "id": "2e862d2d"
      },
      "outputs": [],
      "source": [
        "class MLP(nn.Module):\n",
        "    \"\"\"Fully connected network: three SiLU-activated hidden layers plus a linear head.\"\"\"\n",
        "\n",
        "    def __init__(self, din=2, hidden=128, dout=1):\n",
        "        super().__init__()\n",
        "        widths = [din, hidden, hidden, hidden]\n",
        "        layers = []\n",
        "        # Linear + SiLU pairs, created in order so parameter names stay net.0, net.2, ...\n",
        "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
        "            layers.append(nn.Linear(fan_in, fan_out))\n",
        "            layers.append(nn.SiLU())\n",
        "        # Output layer has no activation.\n",
        "        layers.append(nn.Linear(hidden, dout))\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Apply the layer stack to `x`.\"\"\"\n",
        "        out = self.net(x)\n",
        "        return out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "0217d2d4",
      "metadata": {
        "id": "0217d2d4"
      },
      "outputs": [],
      "source": [
        "class EOTTrainer:\n",
        "    \"\"\"Trains two potential networks and samples with Langevin dynamics.\n",
        "\n",
        "    `f_theta` and `f_phi` are optimised jointly with `opt_both`. A frozen\n",
        "    exponential-moving-average copy of `f_theta` (`f_theta_ema`) is what the\n",
        "    sampling methods and `save_checkpoint` use.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, config, source_sampler, target_sampler, model_theta, model_phi, name,\n",
        "                 optimizer=None):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            config: object exposing `eps(step)`, `batch_size`, `device`, `K`, `epoch`,\n",
        "                `lmc_steps`, `lmc_step_size`, `grad_clip`, `max_diff_exp_clip`,\n",
        "                `score_clip` and `ema_momentum`.\n",
        "            source_sampler: object whose `sample(n)` returns source points.\n",
        "            target_sampler: object whose `sample(n)` returns target points.\n",
        "            model_theta: network stored as `f_theta` (also deep-copied for the EMA).\n",
        "            model_phi: network stored as `f_phi`.\n",
        "            name: experiment label printed by `train`.\n",
        "            optimizer: optional optimizer stepping both networks; it can also be\n",
        "                assigned later through `trainer.opt_both`.\n",
        "        \"\"\"\n",
        "        import copy\n",
        "\n",
        "        self.experiment_name = name\n",
        "\n",
        "        self.config = config\n",
        "        self.source_sampler = source_sampler\n",
        "        self.target_sampler = target_sampler\n",
        "\n",
        "        self.f_theta = model_theta\n",
        "        self.f_phi = model_phi\n",
        "        # The EMA copy is never optimised directly, so its parameters are frozen.\n",
        "        self.f_theta_ema = copy.deepcopy(self.f_theta).eval()\n",
        "        for p in self.f_theta_ema.parameters():\n",
        "            p.requires_grad_(False)\n",
        "        self.opt_both = optimizer\n",
        "        self.last_loss = None\n",
        "        self.current_step = 0\n",
        "\n",
        "    def ema_update(self, model, ema):\n",
        "        \"\"\"In-place update `ema <- m * ema + (1 - m) * model` with `m = config.ema_momentum`.\"\"\"\n",
        "        m = self.config.ema_momentum\n",
        "        with torch.no_grad():\n",
        "            for p, pe in zip(model.parameters(), ema.parameters()):\n",
        "                pe.mul_(m).add_(p, alpha=1 - m)\n",
        "\n",
        "    def compute_loss(self, x, y):\n",
        "        \"\"\"Return the scalar training loss for one batch.\n",
        "\n",
        "        With `eps = config.eps(current_step)` and `z ~ N(0, I)` (K draws per row of x):\n",
        "\n",
        "            eps * (mean(f_phi(x)) + mean(exp(min(f_theta(x - sqrt(eps) z) / eps - f_phi(x), clip))))\n",
        "            - mean(f_theta(y))\n",
        "\n",
        "        Args:\n",
        "            x: source batch; indexed as `x[:, None, :]`, i.e. assumed `(batch, dim)` - TODO confirm.\n",
        "            y: target batch fed to `f_theta`.\n",
        "        \"\"\"\n",
        "        cfg = self.config\n",
        "        eps = cfg.eps(self.current_step)\n",
        "        # Use the real batch size so batches that differ from cfg.batch_size also work.\n",
        "        n = x.shape[0]\n",
        "        noise_shape = list(x.shape)\n",
        "        noise_shape.insert(1, cfg.K)\n",
        "        z = torch.randn(size=noise_shape, device=x.device)\n",
        "\n",
        "        x_noisy = x[:, None, :] - math.sqrt(eps) * z\n",
        "\n",
        "        fphi_x = self.f_phi(x)\n",
        "        # The view to (n, K) assumes f_theta emits one value per input row - TODO confirm.\n",
        "        ftheta_xnoisy = self.f_theta(x_noisy.reshape(-1, *x_noisy.shape[2:])).view(n, cfg.K) / eps\n",
        "        ftheta_y = self.f_theta(y)\n",
        "\n",
        "        diff_in_exp = ftheta_xnoisy - fphi_x\n",
        "        # Clamp from above before exponentiating, then exponentiate in float64.\n",
        "        diff_in_exp_truncated = torch.clamp(diff_in_exp, min=None, max=cfg.max_diff_exp_clip)\n",
        "        exp_term = torch.exp(diff_in_exp_truncated.to(torch.float64))\n",
        "\n",
        "        return (fphi_x.mean() + exp_term.mean()) * eps - ftheta_y.mean()\n",
        "\n",
        "    def train_step(self):\n",
        "        \"\"\"Run one optimisation step and return the detached loss tensor.\n",
        "\n",
        "        Raises:\n",
        "            RuntimeError: if no optimizer has been provided.\n",
        "        \"\"\"\n",
        "        optimizer = getattr(self, 'opt_both', None)\n",
        "        if optimizer is None:\n",
        "            raise RuntimeError(\n",
        "                'No optimizer set: pass optimizer=... to EOTTrainer or assign trainer.opt_both'\n",
        "            )\n",
        "        cfg = self.config\n",
        "        x = self.source_sampler.sample(cfg.batch_size).to(cfg.device)\n",
        "        y = self.target_sampler.sample(cfg.batch_size).to(cfg.device)\n",
        "        self.train_theta = True\n",
        "        self.train_phi = True\n",
        "        loss = self.compute_loss(x, y)\n",
        "        optimizer.zero_grad(set_to_none=True)\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_(self.f_theta.parameters(), max_norm=cfg.grad_clip)\n",
        "        torch.nn.utils.clip_grad_norm_(self.f_phi.parameters(), max_norm=cfg.grad_clip)\n",
        "        optimizer.step()\n",
        "        self.ema_update(self.f_theta, self.f_theta_ema)\n",
        "        self.current_step += 1\n",
        "        # Kept as a tensor (no .item()) so no host/device sync is forced every step.\n",
        "        self.last_loss = loss.detach()\n",
        "        return self.last_loss\n",
        "\n",
        "    def train(self, viz_callback=None):\n",
        "        \"\"\"Call `train_step` until `current_step` reaches `config.epoch`; returns 0.\n",
        "\n",
        "        Args:\n",
        "            viz_callback: optional callable invoked as `viz_callback(self)` after every step.\n",
        "        \"\"\"\n",
        "        print(f\"Starting training name = {self.experiment_name}\")\n",
        "        print(\"-\" * 60)\n",
        "\n",
        "        # Checking before stepping means re-running train() after completion is a no-op.\n",
        "        while self.current_step < self.config.epoch:\n",
        "            self.train_step()\n",
        "            if viz_callback is not None:\n",
        "                viz_callback(self)\n",
        "\n",
        "        print(\"Training complete!\")\n",
        "        return 0\n",
        "\n",
        "    def score_y_given_x(self, y, x):\n",
        "        \"\"\"Return `(grad_y f_theta_ema(y) - (y - x)) / eps`, clamped to +/- `config.score_clip`.\"\"\"\n",
        "        y = y.detach().requires_grad_(True)\n",
        "\n",
        "        ft = self.f_theta_ema(y).sum()\n",
        "        (gy,) = torch.autograd.grad(ft, y, retain_graph=False, create_graph=False)\n",
        "        sc = (gy - (y - x)) / self.config.eps(self.current_step)\n",
        "\n",
        "        return torch.clamp(sc, -self.config.score_clip, self.config.score_clip)\n",
        "\n",
        "    def _langevin_sample(self, y_init, x):\n",
        "        \"\"\"Run `config.lmc_steps` Langevin updates on `y_init` using `score_y_given_x(., x)`.\"\"\"\n",
        "        cfg = self.config\n",
        "        step_size = cfg.lmc_step_size\n",
        "        noise_scale = math.sqrt(2 * step_size)\n",
        "        y = y_init\n",
        "        for _ in range(cfg.lmc_steps):\n",
        "            y = y.detach()\n",
        "            sc = self.score_y_given_x(y, x)\n",
        "\n",
        "            with torch.no_grad():\n",
        "                y = y + step_size * sc + noise_scale * torch.randn_like(y)\n",
        "\n",
        "        return y.detach()\n",
        "\n",
        "    def sample_pi_given_x(self, x):\n",
        "        \"\"\"Langevin-sample one y per row of `x`, starting the chain at `x`.\"\"\"\n",
        "        cfg = self.config\n",
        "        # The initial perturbation is multiplied by 0.0, so the chain starts exactly at x;\n",
        "        # the expression is kept so the random-number stream is consumed as before.\n",
        "        y = x + 0.0 * math.sqrt(cfg.eps(self.current_step)) * torch.randn(x.shape[0], x.shape[1], device=cfg.device)\n",
        "        return self._langevin_sample(y, x)\n",
        "\n",
        "    def sample_marginal(self, n_x=500):\n",
        "        \"\"\"Draw `n_x` source points and Langevin-sample one y for each of them.\"\"\"\n",
        "        xs = self.source_sampler.sample(n_x).to(self.config.device)\n",
        "        return self._langevin_sample(xs, xs)\n",
        "\n",
        "    def save_checkpoint(self, filename):\n",
        "        \"\"\"Write the EMA `f_theta` weights, `f_phi` weights, step count and current eps to `filename`.\"\"\"\n",
        "        checkpoint = {\n",
        "            'f_theta_state_dict': self.f_theta_ema.state_dict(),\n",
        "            'f_phi_state_dict': self.f_phi.state_dict(),\n",
        "            'epoch': self.current_step,\n",
        "            'eps': self.config.eps(self.current_step),\n",
        "        }\n",
        "        torch.save(checkpoint, filename)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "fff95e37",
      "metadata": {},
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "1f1144bd",
      "metadata": {
        "id": "1f1144bd"
      },
      "outputs": [],
      "source": [
        "config = EOTConfig(\n",
        "    eps=lambda step: 1.0,  # constant schedule: the same eps at every step\n",
        "    batch_size=256,\n",
        "    device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
        "    K=256,\n",
        "    epoch=5000,\n",
        "    lmc_steps=1000,\n",
        "    lmc_step_size=0.001,\n",
        "    seed=42,\n",
        "    ema_momentum=0.999,\n",
        "    max_diff_exp_clip=20,\n",
        "    grad_clip=1e20,  # very large threshold, so gradient clipping is effectively off\n",
        ")\n",
        "\n",
        "# Apply the configured seed; it was stored on the config but never used.\n",
        "# NumPy's global RNG is seeded as well because make_swiss_roll is called without random_state.\n",
        "random.seed(config.seed)\n",
        "np.random.seed(config.seed)\n",
        "torch.manual_seed(config.seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "2a632a03",
      "metadata": {
        "id": "2a632a03"
      },
      "outputs": [],
      "source": [
        "def visualize_training(trainer):\n",
        "    \"\"\"Progress callback: print the step counter whenever it is a multiple of 1000.\"\"\"\n",
        "    step = trainer.current_step\n",
        "    if step % 1000 != 0:\n",
        "        return\n",
        "    print(step)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "2e497caf",
      "metadata": {
        "id": "2e497caf"
      },
      "outputs": [],
      "source": [
        "dimension = 2\n",
        "\n",
        "# Source: standard normal in `dimension` dims; target: 2-D swiss roll.\n",
        "source_sampler = StandardNormalSampler(dim=dimension)\n",
        "target_sampler = SwissRollSampler()\n",
        "\n",
        "# Both potentials share the same architecture (theta is created first, then phi).\n",
        "model_theta = MLP(din=dimension, hidden=256).to(config.device)\n",
        "model_phi = MLP(din=dimension, hidden=256).to(config.device)\n",
        "\n",
        "trainer = EOTTrainer(\n",
        "    config=config,\n",
        "    source_sampler=source_sampler,\n",
        "    target_sampler=target_sampler,\n",
        "    model_theta=model_theta,\n",
        "    model_phi=model_phi,\n",
        "    name=\"swissroll_test\",\n",
        ")\n",
        "\n",
        "# A single optimizer steps both networks; phi's parameters come first.\n",
        "trainable_params = [*trainer.f_phi.parameters(), *trainer.f_theta.parameters()]\n",
        "trainer.opt_both = torch.optim.AdamW(\n",
        "    trainable_params,\n",
        "    lr=3e-4,\n",
        "    weight_decay=1e-4,\n",
        "    betas=(0.7, 0.8),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "2aae6ec0",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2aae6ec0",
        "outputId": "d1adfb2c-c45f-45a2-aae6-1620e73af7fd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting training name = swissroll_test\n",
            "------------------------------------------------------------\n",
            "Training complete!\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# visualize_training already takes the trainer as its only argument, so it is passed directly.\n",
        "trainer.train(viz_callback=visualize_training)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5e061d2b",
      "metadata": {},
      "source": [
        "# Plotting"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8a8bb1a2-5b61-468b-b02f-0868565ae85a",
      "metadata": {
        "id": "8a8bb1a2-5b61-468b-b02f-0868565ae85a"
      },
      "outputs": [],
      "source": [
        "# Draw source points, push them through the Langevin sampler and move the result to the CPU.\n",
        "data = trainer.sample_marginal(n_x=5000).cpu()\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(6, 6))\n",
        "ax.scatter(data[:, 0], data[:, 1], alpha=0.1)\n",
        "ax.set(\n",
        "    title=\"Samples from trainer.sample_marginal (n_x=5000)\",\n",
        "    xlabel=\"dimension 0\",\n",
        "    ylabel=\"dimension 1\",\n",
        ")\n",
        "# Equal aspect keeps the 2-D point cloud undistorted.\n",
        "ax.set_aspect(\"equal\")\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
