{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e0e186-3c69-47e4-9f6a-86e4c0eb72d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import math\n",
    "\n",
    "from unet import UNetModel\n",
    "from diffusion import GaussianDiffusion\n",
    "\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "mpl.rc('image', cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea4cdeea-d9eb-4378-bf75-2d64c11c74e3",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "085ef277-6ee3-436e-93f6-27c6d7522108",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST dataset and show a few example digits\n",
    "# Fall back to CPU when no GPU is present so the notebook still runs end-to-end\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "batch_size = 128\n",
    "\n",
    "# Pad 28x28 digits to 32x32 and map pixel values to [-1, 1]\n",
    "# ('transform', not 'transforms', to avoid shadowing torchvision.transforms)\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Pad(2),\n",
    "    torchvision.transforms.Normalize(0.5, 0.5),\n",
    "])\n",
    "mnist_train = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)\n",
    "data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "# Grab a single batch for visualization\n",
    "img, labels = next(iter(data_loader))\n",
    "\n",
    "fig, ax = plt.subplots(1, 4, figsize=(15,15))\n",
    "for i in range(4):\n",
    "    ax[i].imshow(img[i,0,:,:].numpy())\n",
    "    ax[i].set_title(str(labels[i].item()))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc84dffb-4990-4d9b-91e0-7dbb4f1ab6b5",
   "metadata": {},
   "source": [
    "## Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df759cd7-3720-42ab-b465-72339cb929ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save/Load model\n",
    "#torch.save(net.state_dict(), 'models/mnist_unet.pth')\n",
    "#print('Saved model')\n",
    "\n",
    "net = UNetModel(image_size=32, in_channels=1, out_channels=1, \n",
    "                model_channels=64, num_res_blocks=2, channel_mult=(1,2,3,4),\n",
    "                attention_resolutions=[8,4], num_heads=4).to(device)\n",
    "# map_location keeps the checkpoint loadable on CPU-only machines;\n",
    "# the model was already moved to `device` at construction, so no extra .to() is needed\n",
    "net.load_state_dict(torch.load('models/mnist_unet.pth', map_location=device))\n",
    "net.train()\n",
    "print('Loaded model')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc6a4fe-97c8-4109-9d6a-e082a534c52d",
   "metadata": {},
   "source": [
    "## Inference with hand-crafted conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a7fdbe-5113-4e56-995b-ff758664becf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the t annealing schedule used during inference:\n",
    "# t decays linearly from T to 1 with a superimposed cosine oscillation.\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "steps = 1000\n",
    "\n",
    "def anneal_t(i, steps, T):\n",
    "    \"\"\"Diffusion time step for optimization step i: linear decay + cosine, clipped to [1, T].\"\"\"\n",
    "    t = ((steps-i) + (steps-i)//4*math.cos(i/50))/steps*T\n",
    "    return int(np.clip(int(t), 1, T))\n",
    "\n",
    "t_vals = [anneal_t(i, steps, diffusion.T) for i in range(steps)]\n",
    "\n",
    "plt.figure(figsize=(8,5))\n",
    "plt.plot(range(steps), t_vals, linewidth=2)\n",
    "plt.title('$t$ Annealing Schedule')\n",
    "plt.xlabel('Steps')\n",
    "plt.ylabel('$t$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a9f7f91-515b-46a3-9555-12531054563d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference Model\n",
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.img = nn.Parameter(torch.randn(1,1,32,32))\n",
    "        self.img.requires_grad = True\n",
    "                \n",
    "    def encode(self):\n",
    "        return self.img\n",
    "    \n",
    "model = Model().to(device)\n",
    "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "net.train()\n",
    "\n",
    "steps = 1000\n",
    "bar = tqdm.tqdm(range(steps))\n",
    "losses = []\n",
    "update_every = 50\n",
    "for i, _ in enumerate(bar):\n",
    "    sample_img = model.encode()\n",
    "   \n",
    "    # Select t\n",
    "    t = ((steps-i) + (steps-i)//4*math.cos(i/50))/steps*diffusion.T # Linearly decreasing + cosine\n",
    "    t = np.array([t]).astype(int)\n",
    "    t = np.clip(t, 1, diffusion.T)\n",
    "    \n",
    "    # Denoise\n",
    "    xt, epsilon = diffusion.sample(sample_img, t)       \n",
    "    t = torch.from_numpy(t).float().view(sample_img.shape[0])\n",
    "    epsilon_pred = net(xt.float(), t.to(device))\n",
    "\n",
    "    # Hand-crafted conditions\n",
    "    sample_img_clipped = torch.clip(sample_img, -1, 1)\n",
    "    #vertical_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(model.encode()))\n",
    "    #horizontal_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(model.encode()))\n",
    "    #vertical_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(sample_img_clipped))\n",
    "    horizontal_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(sample_img_clipped))\n",
    "\n",
    "    # Denoising loss + aux loss\n",
    "    loss = F.mse_loss(epsilon_pred, epsilon) + 0.01*(steps-i)/steps*horizontal_dissimilarity\n",
    "    \n",
    "    # Update\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    \n",
    "    losses.append(loss.item())\n",
    "    if i % update_every == 0:\n",
    "        bar.set_postfix({'Loss': np.mean(losses)})\n",
    "        losses = []\n",
    "\n",
    "    # Visualize sample\n",
    "    if (i+1) % 100 == 0 or i == 0:\n",
    "        with torch.no_grad():\n",
    "            fig, ax = plt.subplots(1, 1, figsize=(5,5))\n",
    "            ax.imshow(model.encode()[0].detach().cpu().numpy().transpose([1,2,0]), vmin=-1, vmax=1)\n",
    "            ax.set_title('Inferred sample')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce03e0f9-5610-4554-9b36-f96c2663b911",
   "metadata": {},
   "source": [
    "## Inference with a learned condition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "421b3f44-9110-4711-9ddd-7605dfa4396e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a binary classifier that outputs ~0 for the target digit and ~1 otherwise.\n",
    "# Its (differentiable) output can then be minimized as a learned condition during inference.\n",
    "class Classifier(nn.Module):\n",
    "    \"\"\"Small CNN: 1x32x32 image -> scalar score in (0, 1).\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Classifier, self).__init__()\n",
    "        \n",
    "        # Three stride-2 convs: 32x32 -> 16x16 -> 8x8 -> 4x4 feature maps\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, 3, 2, 1)\n",
    "        self.out = nn.Linear(4*4*32, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = self.out(x.flatten(1))\n",
    "        \n",
    "        return torch.sigmoid(x)\n",
    "    \n",
    "# Train network\n",
    "class_net = Classifier().to(device)\n",
    "class_opt = torch.optim.Adam(class_net.parameters(), lr=1e-4)\n",
    "\n",
    "target_label = 3 # Digit to distinguish\n",
    "epochs = 5\n",
    "update_every = 100\n",
    "for e in range(epochs):\n",
    "    print(f'Epoch [{e+1}/{epochs}]')\n",
    "    \n",
    "    losses = []\n",
    "    batch_bar = tqdm.tqdm(data_loader)\n",
    "    for i, batch in enumerate(batch_bar):\n",
    "        img, labels = batch\n",
    "        \n",
    "        # Binary target: 0 for the target digit, 1 for every other digit\n",
    "        labels = (labels != target_label).float().to(device)\n",
    "        \n",
    "        # Pass through network\n",
    "        out = class_net(img.float().to(device))\n",
    "        \n",
    "        # Compute loss and backprop\n",
    "        loss = F.binary_cross_entropy(out.squeeze(-1), labels)\n",
    "        \n",
    "        class_opt.zero_grad()\n",
    "        loss.backward()\n",
    "        class_opt.step()\n",
    "        \n",
    "        losses.append(loss.item())\n",
    "        if i % update_every == 0:\n",
    "            batch_bar.set_postfix({'Loss': np.mean(losses)})\n",
    "            losses = []\n",
    "            \n",
    "    # Guard against an empty window: np.mean([]) is nan when the final batch\n",
    "    # index happened to land exactly on an update boundary\n",
    "    if losses:\n",
    "        batch_bar.set_postfix({'Loss': np.mean(losses)})\n",
    "    losses = []\n",
    "    \n",
    "    # Show the last training image alongside its predicted score\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.imshow(img.numpy()[0,0,:,:])\n",
    "    plt.title(f'Score {out[0].item():.3f}')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a89ade1f-c4ff-44d4-a7ad-5ed07edcb728",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference by optimization with a learned condition: the classifier score\n",
    "# (trained above to be ~0 for the target digit) is minimized alongside the\n",
    "# denoising loss, steering the sample toward the target digit.\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Wraps a single 1x1x32x32 image as a learnable parameter.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        # nn.Parameter sets requires_grad=True by default, so no extra flag is needed\n",
    "        self.img = nn.Parameter(torch.randn(1,1,32,32))\n",
    "\n",
    "    def encode(self):\n",
    "        # Returns the raw (unclipped) image parameter\n",
    "        return self.img\n",
    "\n",
    "model = Model().to(device)\n",
    "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "net.train()\n",
    "class_net.train()\n",
    "\n",
    "steps = 1000\n",
    "bar = tqdm.tqdm(range(steps))\n",
    "losses = []\n",
    "update_every = 50\n",
    "for i in bar:\n",
    "    sample_img = model.encode()\n",
    "\n",
    "    # Select t: anneal from T down to 1 (linear decay + cosine oscillation)\n",
    "    t = ((steps-i) + (steps-i)//4*math.cos(i/50))/steps*diffusion.T\n",
    "    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)\n",
    "    \n",
    "    # Noise the current image to level t and predict that noise with the UNet\n",
    "    xt, epsilon = diffusion.sample(sample_img, t)\n",
    "    t = torch.from_numpy(t).float().view(sample_img.shape[0])\n",
    "    epsilon_pred = net(xt.float(), t.to(device))\n",
    "\n",
    "    # Learned condition: lower classifier score = closer to the target digit\n",
    "    sample_img_clipped = torch.clip(sample_img, -1, 1)\n",
    "    class_loss = class_net(sample_img_clipped).mean()\n",
    "\n",
    "    # Denoising loss + condition loss; the condition weight decays linearly to 0\n",
    "    loss = F.mse_loss(epsilon_pred, epsilon) + 0.01*(steps-i)/steps*class_loss\n",
    "    \n",
    "    # Update the image parameter\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    \n",
    "    losses.append(loss.item())\n",
    "    if i % update_every == 0:\n",
    "        bar.set_postfix({'Loss': np.mean(losses)})\n",
    "        losses = []\n",
    "\n",
    "    # Visualize the current sample periodically\n",
    "    if (i+1) % 100 == 0 or i == 0:\n",
    "        with torch.no_grad():\n",
    "            fig, ax = plt.subplots(1, 1, figsize=(5,5))\n",
    "            ax.imshow(model.encode()[0].detach().cpu().numpy().transpose([1,2,0]), vmin=-1, vmax=1)\n",
    "            ax.set_title('Inferred sample')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e53d0ea8-5351-46b0-8972-0fb202b2b128",
   "metadata": {},
   "source": [
    "## Inference with hand-crafted and learned conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95be9a3-fbaf-46ae-9266-1ceb953ca86d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference by optimization with BOTH a hand-crafted condition (horizontal\n",
    "# mirror similarity) and the learned classifier condition.\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Wraps a single 1x1x32x32 image as a learnable parameter.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        # nn.Parameter sets requires_grad=True by default, so no extra flag is needed\n",
    "        self.img = nn.Parameter(torch.randn(1,1,32,32))\n",
    "\n",
    "    def encode(self):\n",
    "        # Returns the raw (unclipped) image parameter\n",
    "        return self.img\n",
    "\n",
    "model = Model().to(device)\n",
    "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "net.train()\n",
    "class_net.train()\n",
    "\n",
    "steps = 1000\n",
    "bar = tqdm.tqdm(range(steps))\n",
    "losses = []\n",
    "update_every = 50\n",
    "for i in bar:\n",
    "    sample_img = model.encode()\n",
    "\n",
    "    # Select t: anneal from T down to 1 (linear decay + cosine oscillation)\n",
    "    t = ((steps-i) + (steps-i)//4*math.cos(i/50))/steps*diffusion.T\n",
    "    t = np.clip(np.array([t]).astype(int), 1, diffusion.T)\n",
    "    \n",
    "    # Noise the current image to level t and predict that noise with the UNet\n",
    "    xt, epsilon = diffusion.sample(sample_img, t)\n",
    "    t = torch.from_numpy(t).float().view(sample_img.shape[0])\n",
    "    epsilon_pred = net(xt.float(), t.to(device))\n",
    "\n",
    "    # Conditions\n",
    "    sample_img_clipped = torch.clip(sample_img, -1, 1)\n",
    "    # Hand-crafted: encourage symmetry under a vertical flip (uncomment alternatives to try them)\n",
    "    #vertical_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(model.encode()))\n",
    "    horizontal_similarity = F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(model.encode()))\n",
    "    #vertical_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.hflip(sample_img_clipped))\n",
    "    #horizontal_dissimilarity = -F.mse_loss(sample_img_clipped, torchvision.transforms.functional.vflip(sample_img_clipped))\n",
    "    # Learned: lower classifier score = closer to the target digit\n",
    "    class_loss = class_net(sample_img_clipped).mean()\n",
    "\n",
    "    # Denoising loss + both condition losses (each with linearly decaying weight)\n",
    "    loss = F.mse_loss(epsilon_pred, epsilon) + 0.01*(steps-i)/steps*class_loss + 0.01*(steps-i)/steps*horizontal_similarity\n",
    "    \n",
    "    # Update the image parameter\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    \n",
    "    losses.append(loss.item())\n",
    "    if i % update_every == 0:\n",
    "        bar.set_postfix({'Loss': np.mean(losses)})\n",
    "        losses = []\n",
    "\n",
    "    # Visualize the current sample periodically\n",
    "    if (i+1) % 100 == 0 or i == 0:\n",
    "        with torch.no_grad():\n",
    "            fig, ax = plt.subplots(1, 1, figsize=(5,5))\n",
    "            ax.imshow(model.encode()[0].detach().cpu().numpy().transpose([1,2,0]), vmin=-1, vmax=1)\n",
    "            ax.set_title('Inferred sample')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fedd502-8134-4833-a50b-7979903ad872",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch_geo_simple",
   "language": "python",
   "name": "conda-env-torch_geo_simple-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
