{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62d52a00",
   "metadata": {},
   "source": [
    "# Generating MNIST 32x32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48253448",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    " \n",
    "import numpy as np\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import wandb\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import clear_output\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from src.utils import Config, optimization_manager, get_scale\n",
    "from src.data import GetDigitMNIST\n",
    "from src.ode import ODESolver \n",
    "from src.utils import compute_field, random_color\n",
    "from src.models import DDPM, ExponentialMovingAverage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0164e1c",
   "metadata": {},
   "source": [
    "## 1. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cea56e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root settings object; nested Config() instances hold the sub-sections.\n",
    "config = Config()\n",
    "\n",
    "# Extra-coordinate geometry, interpolation mode and run length.\n",
    "config.SCALE = 1.0\n",
    "config.L = 20.0\n",
    "config.K = math.pi / 20.0\n",
    "config.epsilon = 0.05\n",
    "config.interpolation = 'right_side'\n",
    "config.training_steps = 100_000\n",
    "\n",
    "# Parameters used when sampling the intermediate coordinate in the training loop.\n",
    "config.M = 291\n",
    "config.sigma_end = 0.01\n",
    "config.tau = 0.03\n",
    "config.device = 'cuda'\n",
    "\n",
    "# D is derived from K, so K has to be assigned before this line.\n",
    "config.D = math.pi / (2 * config.K)\n",
    "config.field_type = \"Shifted\"\n",
    "config.plan_type = \"Independent\"\n",
    "config.field_form = \"exponential\"\n",
    "\n",
    "# Image data settings.\n",
    "config.data = Config()\n",
    "config.data.num_channels = 3\n",
    "config.data.centered = True\n",
    "config.data.image_size = 32\n",
    "config.data.batch_size = 128\n",
    "config.data.colored = True\n",
    "\n",
    "config.name = \"untoy\"\n",
    "# Flattened image size (H * W * C) plus one extra coordinate.\n",
    "config.DIM = config.data.image_size * config.data.image_size * config.data.num_channels + 1\n",
    "\n",
    "config.gamma = 0.0\n",
    "\n",
    "# Extra-coordinate location of the noise side (p) and of the data side (q).\n",
    "config.p = Config()\n",
    "config.p.x_loc = 0.0\n",
    "\n",
    "config.q = Config()\n",
    "config.q.x_loc = config.L\n",
    "\n",
    "# Network hyper-parameters (consumed by the model built from `config`).\n",
    "config.model = Config()\n",
    "config.model.name = 'ncsnpp'\n",
    "config.model.scale_by_sigma = False\n",
    "config.model.ema_rate = 0.9999\n",
    "config.model.normalization = 'GroupNorm'\n",
    "config.model.nonlinearity = 'swish'\n",
    "config.model.nf = 128\n",
    "config.model.ch_mult = (1, 2, 2, 2)\n",
    "config.model.num_res_blocks = 4\n",
    "config.model.attn_resolutions = (16,)\n",
    "config.model.resamp_with_conv = True\n",
    "config.model.conditional = True\n",
    "config.model.fir = False\n",
    "config.model.fir_kernel = [1, 3, 3, 1]\n",
    "config.model.skip_rescale = True\n",
    "config.model.resblock_type = 'biggan'\n",
    "config.model.progressive = 'none'\n",
    "config.model.progressive_input = 'none'\n",
    "config.model.progressive_combine = 'sum'\n",
    "config.model.attention_type = 'ddpm'\n",
    "config.model.init_scale = 0.0\n",
    "config.model.fourier_scale = 16\n",
    "config.model.embedding_type = 'positional'\n",
    "config.model.conv_size = 3\n",
    "config.model.dropout = 0.1\n",
    "\n",
    "# Optimizer settings.\n",
    "config.optim = Config()\n",
    "config.optim.weight_decay = 0\n",
    "config.optim.optimizer = 'Adam'\n",
    "config.optim.lr = 2e-4\n",
    "config.optim.beta1 = 0.9\n",
    "config.optim.eps = 1e-8\n",
    "config.optim.warmup = 4000\n",
    "config.optim.grad_clip = 15.0\n",
    "\n",
    "config.training = Config()\n",
    "config.training.sde = 'poisson'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d924f03",
   "metadata": {},
   "source": [
    "## 2. Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d67e1b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample preprocessing: resize, convert to a tensor, then apply `random_color`.\n",
    "TRANSFORM = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.Resize(config.data.image_size),\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        random_color,\n",
    "    ]\n",
    ")\n",
    "\n",
    "# MNIST train / test splits, downloaded into the parent directory when missing.\n",
    "train_data = torchvision.datasets.MNIST(\n",
    "    root='..', train=True, download=True, transform=TRANSFORM\n",
    ")\n",
    "eval_data = torchvision.datasets.MNIST(\n",
    "    root='..', train=False, download=True, transform=TRANSFORM\n",
    ")\n",
    "\n",
    "# Both loaders shuffle and share the configured batch size.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_data,\n",
    "    batch_size=config.data.batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "eval_loader = torch.utils.data.DataLoader(\n",
    "    eval_data,\n",
    "    batch_size=config.data.batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "# Iterators that later cells advance with next().\n",
    "train_iter = iter(train_loader)\n",
    "eval_iter = iter(eval_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2975831",
   "metadata": {},
   "source": [
    "## 3. Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611e1e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network on the configured device.\n",
    "net = DDPM(config).to(config.device)\n",
    "\n",
    "# Adam over all network parameters; the second beta is fixed at 0.999.\n",
    "params = net.parameters()\n",
    "optimizer = torch.optim.Adam(\n",
    "    params,\n",
    "    lr=config.optim.lr,\n",
    "    betas=(config.optim.beta1, 0.999),\n",
    "    eps=config.optim.eps,\n",
    "    weight_decay=config.optim.weight_decay,\n",
    ")\n",
    "\n",
    "# No LR scheduler object is created in this cell; the OneCycleLR experiment\n",
    "# that used to sit here is disabled.\n",
    "\n",
    "# Exponential moving average of the weights plus the mutable training state.\n",
    "ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)\n",
    "state = {\n",
    "    'optimizer': optimizer,\n",
    "    'model': net,\n",
    "    'ema': ema,\n",
    "    'step': 0,\n",
    "}\n",
    "optimize_fn = optimization_manager(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e40889b4",
   "metadata": {},
   "source": [
    "## 4. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa5e205",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Start the Weights & Biases run; the run name encodes the key hyper-parameters.\n",
    "wandb.init(\n",
    "    project=\"SFMGeneration\",\n",
    "    name=(\n",
    "        f\"FinalLinearLR_clip_{config.optim.grad_clip}_{config.interpolation}\"\n",
    "        f\"_L={config.L}_s_{config.SCALE}_D_{config.D}_eps_{config.epsilon}\"\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c0e89b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LearnedODESolver:\n",
    "\n",
    "    def __init__(self, net, config):\n",
    "        self.config = config\n",
    "        self.net = net\n",
    "\n",
    "    def __call__(self, x_init ):\n",
    "        trajectory = [x_init[:,1:].view(-1,self.config.data.num_channels,\n",
    "                                               self.config.data.image_size,\n",
    "                                               self.config.data.image_size).clone().detach().cpu()]\n",
    "        mask = torch.tensor(x_init.shape[0]*[True]).to(config.device)\n",
    "        \n",
    "        while mask.any():\n",
    " \n",
    "                \n",
    "            field_x, field_z = self.net(x_init[:,1:].view(-1,self.config.data.num_channels,\n",
    "                                                             self.config.data.image_size,\n",
    "                                                             self.config.data.image_size) , x_init[:,0]   )\n",
    "            \n",
    "            field = torch.cat([field_z.view(-1,1),\n",
    "                               field_x.view(-1, self.config.data.num_channels*\\\n",
    "                                                self.config.data.image_size*\\\n",
    "                                                self.config.data.image_size)], dim=1) # [B, 1+C*H*W]\n",
    " \n",
    "            x_init  = x_init  + (0.125/ ( field_z.view(-1,1)  + self.config.gamma ))*field # [B, C*H*W+1]\n",
    "            trajectory.append(x_init[:,1:].view(-1,self.config.data.num_channels,\n",
    "                                               self.config.data.image_size,\n",
    "                                               self.config.data.image_size).clone().detach().cpu())\n",
    "            t = x_init[:,0]\n",
    "            mask = t[0] < self.config.L \n",
    "            #print(t[0])\n",
    "            \n",
    "            \n",
    "            \n",
    "        return x_init, trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe348c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses = []      # per-step training loss, used for the live curve below\n",
    "plot_every = 100  # redraw the loss curve every `plot_every` steps, not on every step\n",
    "n_eval = 16       # number of images sampled at each evaluation\n",
    "\n",
    "for step in tqdm(range(config.training_steps)):\n",
    "\n",
    "    # ---- Batches ----------------------------------------------------------\n",
    "    # The loader iterator is re-created only once it is exhausted, so every\n",
    "    # epoch walks through the whole training set.\n",
    "    try:\n",
    "        batch_y, _ = next(train_iter)\n",
    "    except StopIteration:\n",
    "        train_iter = iter(train_loader)\n",
    "        batch_y, _ = next(train_iter)\n",
    "    batch_y = batch_y.to(config.device)\n",
    "\n",
    "    # Noise is sized from the real batch so that the shorter last batch of an\n",
    "    # epoch still lines up element-wise.\n",
    "    batch_size = batch_y.shape[0]\n",
    "    batch_x = torch.randn(batch_size,\n",
    "                          config.data.num_channels,\n",
    "                          config.data.image_size,\n",
    "                          config.data.image_size,\n",
    "                          device=config.device)\n",
    "\n",
    "    # Flat [B, DIM] states: column 0 is the extra coordinate (p.x_loc for the\n",
    "    # noise side, q.x_loc for the data side), the rest is the flattened image.\n",
    "    quarks = torch.cat([config.p.x_loc * torch.ones(batch_size, 1, device=config.device),\n",
    "                        batch_x.reshape(batch_size, config.DIM - 1)], dim=1)\n",
    "    anti_quarks = torch.cat([config.q.x_loc * torch.ones(batch_size, 1, device=config.device),\n",
    "                             batch_y.reshape(batch_size, config.DIM - 1)], dim=1)\n",
    "\n",
    "    optimizer = state['optimizer']\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # ---- Intermediate points ----------------------------------------------\n",
    "    if config.interpolation == 'right_side':\n",
    "        m = torch.rand((batch_size,), device=batch_x.device) * config.M\n",
    "        z = (torch.randn(batch_size, device=batch_x.device) * config.sigma_end).abs()\n",
    "        multiplier = (1 + config.tau) ** m\n",
    "        perturbed_z = config.L - z * multiplier - config.epsilon\n",
    "\n",
    "        # Coordinates outside [p.x_loc + epsilon, q.x_loc - epsilon] are redrawn\n",
    "        # uniformly from that interval.\n",
    "        low = config.p.x_loc + config.epsilon\n",
    "        high = config.q.x_loc - config.epsilon\n",
    "        outside = (perturbed_z > high) | (perturbed_z < low)\n",
    "        redraw = low + (high - low) * torch.rand_like(perturbed_z)\n",
    "        perturbed_z = torch.where(outside, redraw, perturbed_z)\n",
    "\n",
    "        # Linear blend between noise (coordinate 0) and data (coordinate L).\n",
    "        weight = perturbed_z[:, None, None, None] / config.L\n",
    "        perturbed_x = batch_y * weight + (1 - weight) * batch_x\n",
    "        perturbed_samples_vec = torch.cat([perturbed_z[:, None],\n",
    "                                           perturbed_x.reshape(batch_size, config.DIM - 1)], dim=1)\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported config.interpolation: {config.interpolation!r}\")\n",
    "\n",
    "    # ---- Ground-truth field, rescaled to norm sqrt(DIM) -------------------\n",
    "    field = compute_field(quarks, anti_quarks,\n",
    "                          perturbed_samples_vec,\n",
    "                          config)\n",
    "    field = math.sqrt(config.DIM) * field / (torch.norm(field,\n",
    "                                             dim=1, keepdim=True) + 1e-5)\n",
    "\n",
    "    # ---- Prediction and regression loss -----------------------------------\n",
    "    perturbed_samples_x = perturbed_samples_vec[:, 1:].view_as(batch_x)\n",
    "    perturbed_samples_z = perturbed_samples_vec[:, 0]\n",
    "    net_x, net_z = net(perturbed_samples_x, perturbed_samples_z)\n",
    "    pred = torch.cat([net_z[:, None], net_x.reshape(net_x.shape[0], config.DIM - 1)], dim=1)\n",
    "\n",
    "    loss = torch.mean((pred - field) ** 2)\n",
    "    loss.backward()\n",
    "\n",
    "    optimize_fn(optimizer, net, step=state['step'], config=config)\n",
    "\n",
    "    state['step'] += 1\n",
    "    state['ema'].update(net.parameters())\n",
    "    loss_value = loss.item()\n",
    "    wandb.log({\"loss train\": loss_value}, step=step)\n",
    "    losses.append(loss_value)\n",
    "\n",
    "    # Redrawing the full history on every step makes plotting cost grow with\n",
    "    # the run length, so the curve is refreshed periodically instead.\n",
    "    if step % plot_every == 0:\n",
    "        clear_output(wait=True)\n",
    "        plt.plot(losses)\n",
    "        plt.show()\n",
    "\n",
    "    # ---- Periodic checkpoint and sampling with the EMA weights ------------\n",
    "    if step % 1_000 == 0:\n",
    "\n",
    "        with torch.no_grad():\n",
    "            ema = state['ema']\n",
    "            ema.store(net.parameters())\n",
    "            ema.copy_to(net.parameters())\n",
    "            net.eval()  # sample without dropout\n",
    "            try:\n",
    "                torch.save(net.cpu().state_dict(), f\"../ckpt/MNISTgeneration_{step}.pth\")\n",
    "                net = net.to(config.device)\n",
    "\n",
    "                ode = LearnedODESolver(net, config)\n",
    "                # Start from Gaussian noise at coordinate epsilon.\n",
    "                x_eval = torch.randn(n_eval, config.DIM - 1, device=config.device)\n",
    "                start_coord = torch.full((n_eval, 1), config.epsilon, device=config.device)\n",
    "                s = torch.cat([start_coord, x_eval], dim=1)\n",
    "                _, traj = ode(s)\n",
    "            finally:\n",
    "                # Always return to the training weights, device and mode, even\n",
    "                # if saving or sampling fails.\n",
    "                net = net.to(config.device)\n",
    "                net.train()\n",
    "                ema.restore(net.parameters())\n",
    "\n",
    "        # Top row: initial noise; bottom row: second-to-last state of the trajectory.\n",
    "        fig, ax = plt.subplots(2, n_eval, figsize=(16, 2))\n",
    "        for idx in range(n_eval):\n",
    "            ax[0, idx].imshow(traj[0][idx].permute(1, 2, 0))\n",
    "            ax[1, idx].imshow(traj[-2][idx].permute(1, 2, 0))\n",
    "            for row in (0, 1):\n",
    "                ax[row, idx].set_xticks([])\n",
    "                ax[row, idx].set_yticks([])\n",
    "        fig.tight_layout(pad=0.001)\n",
    "\n",
    "        wandb.log({\"Generated Images\": fig}, step=step)\n",
    "        plt.show()\n",
    "        plt.close(fig)  # do not let evaluation figures pile up over the run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b58aa30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e899370",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ac4254f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
