{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "47c6a3d8",
   "metadata": {},
   "source": [
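    "# Colored MNIST Translation 32x32\n",
    "\n",
    "Trains a `DDPM` network (`config.model.name = 'ncsnpp'`) to regress a normalized ground-truth field that links colored MNIST digits `2` (source, placed at $z = 0$) to digits `3` (target, placed at $z = L$), then translates source digits into target digits by integrating the learned field with an explicit Euler ODE solver."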
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28627ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import wandb\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import clear_output\n",
    "%matplotlib inline\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from src.utils import Config, optimization_manager, get_scale\n",
    "from src.data import GetDigitMNIST\n",
    "from src.ode import ODESolver\n",
    "from src.utils import compute_field, random_color\n",
    "from src.models import DDPM, ExponentialMovingAverage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "705e153c",
   "metadata": {},
   "source": [
    "## 1. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8818a3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch (CPU and all CUDA devices) and numpy so that Restart & Run All is reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "config = Config()\n",
    "\n",
    "# Geometry: source images sit at z = p.x_loc (= 0), target images at z = q.x_loc (= L).\n",
    "config.SCALE = 1.\n",
    "config.L = 10.\n",
    "config.K = math.pi/10.\n",
    "config.epsilon = 0.05  # training z is kept inside [p.x_loc + epsilon, q.x_loc - epsilon]\n",
    "config.interpolation = 'both_side'  # 'left_side' | 'both_side' | 'ode' | 'uniform' | 'gauss'\n",
    "config.epochs = 1_000\n",
    "config.training_steps = 100_000  # NOTE(review): the training loop below runs a fixed step count instead\n",
    "\n",
    "# z-offset schedule for 'left_side' / 'both_side': |N(0, sigma_end^2)| * (1 + tau)^U(0, M)\n",
    "config.M = 291\n",
    "config.sigma_end = 0.01\n",
    "config.tau = 0.03\n",
    "config.device = 'cuda'\n",
    "\n",
    "config.D = math.pi/(2*config.K)\n",
    "config.field_type = \"Shifted\"\n",
    "config.plan_type = \"Independent\"\n",
    "config.field_form = \"exponential\"\n",
    "\n",
    "# Data: colored MNIST, digit 2 (source) -> digit 3 (target), 3 x 32 x 32 images.\n",
    "config.data = Config()\n",
    "config.data.num_channels = 3\n",
    "config.data.centered = True\n",
    "config.data.image_size = 32\n",
    "config.data.classes = (2, 3)\n",
    "config.data.batch_size = 128\n",
    "config.data.colored = True\n",
    "\n",
    "config.name = \"untoy\"\n",
    "# Flattened C*H*W image plus one extra coordinate for z.\n",
    "config.DIM = config.data.image_size*config.data.image_size*config.data.num_channels + 1\n",
    "\n",
    "config.gamma = 0.  # added to field_z in the denominator of the Euler / 'ode' step\n",
    "\n",
    "config.p = Config()\n",
    "config.p.x_loc = 0.\n",
    "\n",
    "config.q = Config()\n",
    "config.q.x_loc = config.L\n",
    "\n",
    "# Network hyper-parameters passed to DDPM(config) via config.model.\n",
    "config.model = Config()\n",
    "config.model.name = 'ncsnpp'\n",
    "config.model.scale_by_sigma = False\n",
    "config.model.ema_rate = 0.9999\n",
    "config.model.normalization = 'GroupNorm'\n",
    "config.model.nonlinearity = 'swish'\n",
    "config.model.nf = 128\n",
    "config.model.ch_mult = (1, 2, 2, 2)\n",
    "config.model.num_res_blocks = 4\n",
    "config.model.attn_resolutions = (16,)\n",
    "config.model.resamp_with_conv = True\n",
    "config.model.conditional = True\n",
    "config.model.fir = False\n",
    "config.model.fir_kernel = [1, 3, 3, 1]\n",
    "config.model.skip_rescale = True\n",
    "config.model.resblock_type = 'biggan'\n",
    "config.model.progressive = 'none'\n",
    "config.model.progressive_input = 'none'\n",
    "config.model.progressive_combine = 'sum'\n",
    "config.model.attention_type = 'ddpm'\n",
    "config.model.init_scale = 0.\n",
    "config.model.fourier_scale = 16\n",
    "config.model.embedding_type = 'positional'\n",
    "config.model.conv_size = 3\n",
    "config.model.dropout = 0.1\n",
    "\n",
    "# Optimizer settings; warmup / grad_clip are presumably consumed by optimization_manager.\n",
    "config.optim = Config()\n",
    "config.optim.weight_decay = 0\n",
    "config.optim.optimizer = 'Adam'\n",
    "config.optim.lr = 2e-4\n",
    "config.optim.beta1 = 0.9\n",
    "config.optim.eps = 1e-8\n",
    "config.optim.warmup = 4000\n",
    "config.optim.grad_clip = 15.\n",
    "\n",
    "config.training = Config()\n",
    "config.training.sde = 'poisson'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3d612b6",
   "metadata": {},
   "source": [
    "## 2. Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db48a14e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize to config.data.image_size, then map pixels from [0, 1] (ToTensor) to [-1, 1] (Normalize).\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(config.data.image_size),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize([0.5], [0.5]),\n",
    "])\n",
    "\n",
    "# Full MNIST training split; GetDigitMNIST presumably narrows it to the requested digit(s).\n",
    "dataset = torchvision.datasets.MNIST('../data/MNIST', train=True, download=True, transform=transform)\n",
    "\n",
    "# Per digit (keyed by str(digit), e.g. '2' and '3'): a dataset, a shuffled loader, and a\n",
    "# live iterator over that loader. Built one digit at a time, in config.data.classes order.\n",
    "train_data, train_loader, train_iter = {}, {}, {}\n",
    "for digit in config.data.classes:\n",
    "    key = str(digit)\n",
    "    train_data[key] = GetDigitMNIST(dataset, [digit], config)\n",
    "    train_loader[key] = torch.utils.data.DataLoader(\n",
    "        train_data[key], batch_size=config.data.batch_size, shuffle=True)\n",
    "    train_iter[key] = iter(train_loader[key])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7015d5f",
   "metadata": {},
   "source": [
    "## 3. Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abc878cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network built from config.model (name='ncsnpp'), placed on the training device.\n",
    "net = DDPM(config).to(config.device)\n",
    "\n",
    "# Adam: lr / beta1 / eps / weight_decay come from config.optim; beta2 is fixed at 0.999.\n",
    "params = net.parameters()\n",
    "adam_kwargs = dict(\n",
    "    lr=config.optim.lr,\n",
    "    betas=(config.optim.beta1, 0.999),\n",
    "    eps=config.optim.eps,\n",
    "    weight_decay=config.optim.weight_decay,\n",
    ")\n",
    "optimizer = torch.optim.Adam(params, **adam_kwargs)\n",
    "\n",
    "# A OneCycleLR schedule (max_lr=config.optim.lr over config.epochs epochs of the '2' loader)\n",
    "# is not used; optimize_fn below presumably handles warmup (see src.utils.optimization_manager).\n",
    "\n",
    "# Exponential moving average of the weights (decay = config.model.ema_rate).\n",
    "ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)\n",
    "\n",
    "state = {'optimizer': optimizer, 'model': net, 'ema': ema, 'step': 0}\n",
    "optimize_fn = optimization_manager(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dafb163",
   "metadata": {},
   "source": [
    "## 4. Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981b9622",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The run name encodes the main hyper-parameters of this experiment.\n",
    "run_name = (\n",
    "    f'FinalLinearLR_clip_{config.optim.grad_clip}_{config.interpolation}'\n",
    "    f'_L={config.L}_s_{config.SCALE}_D_{config.D}_eps_{config.epsilon}'\n",
    ")\n",
    "wandb.init(project='SFMTranslation', name=run_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eac679b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LearnedODESolver:\n",
    "    \"\"\"Translate images by integrating the learned field from the source plane up to z = config.L.\n",
    "\n",
    "    Each step is an explicit Euler update\n",
    "        x_next = x + step_size / (field_z + config.gamma) * field,  (field_x, field_z) = net(images, z),\n",
    "    so with config.gamma == 0 every active sample advances by step_size along z (up to rounding).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net, config, step_size=0.125, max_steps=10_000):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            net: model called as net(images [B, C, H, W], z [B]) -> (field_x, field_z).\n",
    "            config: experiment Config; reads data.num_channels, data.image_size, gamma and L.\n",
    "            step_size: numerator of the Euler step size (0.125 reproduces the original setting).\n",
    "            max_steps: cap on the number of steps (None disables it) so that samples whose z\n",
    "                never reaches config.L (e.g. field_z + gamma ~ 0, or NaN) cannot loop forever.\n",
    "        \"\"\"\n",
    "        self.config = config\n",
    "        self.net = net\n",
    "        self.step_size = step_size\n",
    "        self.max_steps = max_steps\n",
    "\n",
    "    def _images(self, x):\n",
    "        \"\"\"View the image part of [B, 1 + C*H*W] states as [B, C, H, W].\"\"\"\n",
    "        data = self.config.data\n",
    "        return x[:, 1:].view(-1, data.num_channels, data.image_size, data.image_size)\n",
    "\n",
    "    def __call__(self, x_init):\n",
    "        \"\"\"Integrate from x_init and record the image part of every state.\n",
    "\n",
    "        Args:\n",
    "            x_init: [B, 1 + C*H*W] states; column 0 is z, the rest is the flattened image.\n",
    "\n",
    "        Returns:\n",
    "            (x, trajectory): the final states and a list of detached CPU [B, C, H, W]\n",
    "            snapshots, starting with the initial images and growing by one per step. A sample\n",
    "            stops moving once its own z >= config.L; trajectory[-1] is the state at which all\n",
    "            samples had stopped (or max_steps was reached).\n",
    "        \"\"\"\n",
    "        x = x_init\n",
    "        trajectory = [self._images(x).clone().detach().cpu()]\n",
    "        # Every sample takes at least one step; afterwards stopping is decided per sample\n",
    "        # from its own z rather than from a single representative sample.\n",
    "        active = torch.ones(x.shape[0], dtype=torch.bool, device=x.device)\n",
    "        n_steps = 0\n",
    "        while active.any() and (self.max_steps is None or n_steps < self.max_steps):\n",
    "            field_x, field_z = self.net(self._images(x), x[:, 0])\n",
    "            field = torch.cat([field_z.view(-1, 1), field_x.reshape(x.shape[0], -1)], dim=1)  # [B, 1 + C*H*W]\n",
    "            update = (self.step_size / (field_z.view(-1, 1) + self.config.gamma)) * field\n",
    "            # Frozen samples keep their state; torch.where does not carry NaN/inf from the\n",
    "            # unselected branch into the result.\n",
    "            x = torch.where(active[:, None], x + update, x)\n",
    "            trajectory.append(self._images(x).clone().detach().cpu())\n",
    "            active = x[:, 0] < self.config.L\n",
    "            n_steps += 1\n",
    "        return x, trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9892aec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regress the network onto the normalized ground-truth field between source-digit images\n",
    "# (z = p.x_loc) and target-digit images (z = q.x_loc); periodically evaluate with EMA weights.\n",
    "NUM_STEPS = 47_000   # NOTE(review): config.training_steps (100_000) is not used by this loop\n",
    "EVAL_EVERY = 1_000   # checkpoint + translate + log images every N steps (step 0 included)\n",
    "PLOT_EVERY = 100     # redraw the loss curve every N steps; redrawing every step is O(n^2) overall\n",
    "NUM_EVAL = 16        # number of source digits translated at each evaluation\n",
    "GAUSS_SPREAD = 10.   # scale of the |N(0, 1)| z-offsets used by the 'gauss' interpolation\n",
    "CKPT_DIR = '../ckpt/Trans'\n",
    "DATA_DIM = config.DIM - 1  # flattened C*H*W image size (config.DIM also counts z)\n",
    "SOURCE_KEY, TARGET_KEY = (str(c) for c in config.data.classes)  # ('2', '3')\n",
    "\n",
    "\n",
    "def _next_batch(key):\n",
    "    \"\"\"Return the next full batch of digit `key` on config.device.\n",
    "\n",
    "    Restarts the loader when it is exhausted and skips the ragged last batch of an\n",
    "    epoch, so source and target batches always have matching sizes.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        batch = next(train_iter[key])\n",
    "    except StopIteration:\n",
    "        batch = None\n",
    "    if batch is None or batch.shape[0] < config.data.batch_size:\n",
    "        train_iter[key] = iter(train_loader[key])\n",
    "        batch = next(train_iter[key])\n",
    "    return batch.to(config.device)\n",
    "\n",
    "\n",
    "def _lift(batch, z_value):\n",
    "    \"\"\"Prepend a constant z coordinate: [B, C, H, W] images -> [B, 1 + C*H*W] points.\"\"\"\n",
    "    z = torch.full((batch.shape[0], 1), float(z_value), device=batch.device)\n",
    "    return torch.cat([z, batch.reshape(batch.shape[0], DATA_DIM)], dim=1)\n",
    "\n",
    "\n",
    "def _geometric_offsets(n, device):\n",
    "    \"\"\"Return n offsets |N(0, sigma_end^2)| * (1 + tau)^U(0, M) used to place z near a plane.\"\"\"\n",
    "    m = torch.rand((n,), device=device) * config.M\n",
    "    z = torch.randn((n, 1, 1, 1)).to(device) * config.sigma_end\n",
    "    return z.abs().view(-1) * (1 + config.tau) ** m\n",
    "\n",
    "\n",
    "def _resample_out_of_range(z):\n",
    "    \"\"\"Redraw, in place, every z outside [p.x_loc + eps, q.x_loc - eps] uniformly from that range.\"\"\"\n",
    "    low = config.p.x_loc + config.epsilon\n",
    "    high = config.q.x_loc - config.epsilon\n",
    "    uniform = torch.distributions.Uniform(low=low, high=high)\n",
    "    # The two masks are disjoint, so computing both before either redraw is safe.\n",
    "    for idx in (torch.nonzero(z > high).view(-1), torch.nonzero(z < low).view(-1)):\n",
    "        z[idx] = uniform.sample(torch.Size([len(idx)])).to(z.device)\n",
    "    return z\n",
    "\n",
    "\n",
    "def _lerp(z, batch_x, batch_y, scale):\n",
    "    \"\"\"Per-sample straight line between paired images: weight z/scale on batch_y, 1 - z/scale on batch_x.\"\"\"\n",
    "    w = (z / scale).view(-1, 1, 1, 1)\n",
    "    return batch_y * w + (1 - w) * batch_x\n",
    "\n",
    "\n",
    "def _perturb(batch_x, batch_y, quarks, anti_quarks):\n",
    "    \"\"\"Sample [B, 1 + C*H*W] training points between the two planes according to config.interpolation.\"\"\"\n",
    "    n, device = batch_x.shape[0], batch_x.device\n",
    "    low = config.p.x_loc + config.epsilon\n",
    "    high = config.q.x_loc - config.epsilon\n",
    "\n",
    "    if config.interpolation == 'left_side':\n",
    "        # All z near the source plane, spread geometrically away from it.\n",
    "        z = _resample_out_of_range(config.epsilon + _geometric_offsets(n, device))\n",
    "        perturbed_x = _lerp(z, batch_x, batch_y, config.L)\n",
    "    elif config.interpolation == 'both_side':\n",
    "        # First half near the source plane, second half near the target plane (works for odd n).\n",
    "        n_left = n // 2\n",
    "        z_left = config.epsilon + _geometric_offsets(n_left, device)\n",
    "        z_right = config.L - _geometric_offsets(n - n_left, device) - config.epsilon\n",
    "        z = _resample_out_of_range(torch.cat([z_left, z_right], dim=0))\n",
    "        perturbed_x = _lerp(z, batch_x, batch_y, config.L)\n",
    "    elif config.interpolation == 'ode':\n",
    "        # Start at z = low and move along the ground-truth field to a uniformly drawn t.\n",
    "        points = quarks.clone()\n",
    "        points[:, 0] = low\n",
    "        field = compute_field(quarks, anti_quarks, points, config)\n",
    "        t = torch.distributions.Uniform(low=low, high=high).sample(\n",
    "            torch.Size([field.shape[0]])).to(config.device)[:, None]  # [B, 1]\n",
    "        return points + ((t - config.epsilon) / (field[:, 0][:, None] + config.gamma)) * field\n",
    "    elif config.interpolation == 'uniform':\n",
    "        z = torch.distributions.Uniform(low=low, high=high).sample(torch.Size([n])).to(config.device)\n",
    "        # NOTE(review): z / (L - 2*eps) exceeds 1 for z > L - 2*eps, i.e. this extrapolates slightly\n",
    "        # past batch_y near the target plane; (z - eps) / (L - 2*eps) may have been intended.\n",
    "        perturbed_x = _lerp(z, batch_x, batch_y, config.L - 2 * config.epsilon)\n",
    "    elif config.interpolation == 'gauss':\n",
    "        n_left = n // 2\n",
    "        z_left = torch.randn(n_left).abs().to(device) * GAUSS_SPREAD\n",
    "        z_right = config.q.x_loc - config.epsilon - torch.randn(n - n_left).abs().to(device) * GAUSS_SPREAD\n",
    "        z = torch.clamp(torch.cat([z_left, z_right], dim=0),\n",
    "                        min=config.epsilon, max=config.L - config.epsilon)\n",
    "        # No extra noise is added to perturbed_x (a get_scale(...)-scaled Gaussian term is disabled).\n",
    "        perturbed_x = _lerp(z, batch_x, batch_y, config.L - 2 * config.epsilon)\n",
    "    else:\n",
    "        raise ValueError(f'There is no such interpolation as {config.interpolation}')\n",
    "\n",
    "    return torch.cat([z[:, None], perturbed_x.reshape(n, DATA_DIM)], dim=1)\n",
    "\n",
    "\n",
    "def _to_display(img):\n",
    "    \"\"\"Turn a [C, H, W] image tensor into an [H, W, C] tensor in [0, 1] for imshow.\"\"\"\n",
    "    if config.data.centered:\n",
    "        # Assumes centered data lives in [-1, 1] (Normalize([0.5], [0.5]) in the data cell).\n",
    "        img = (img + 1) / 2\n",
    "    return img.clamp(0, 1).permute(1, 2, 0)\n",
    "\n",
    "\n",
    "def _evaluate(step):\n",
    "    \"\"\"Checkpoint the EMA weights, translate NUM_EVAL source digits with them, and log the result.\"\"\"\n",
    "    ema = state['ema']\n",
    "    with torch.no_grad():\n",
    "        ema.store(net.parameters())\n",
    "        ema.copy_to(net.parameters())\n",
    "        net.eval()  # disable dropout (config.model.dropout) while sampling\n",
    "        try:\n",
    "            # Save CPU copies of the weights without moving the live model off the GPU.\n",
    "            os.makedirs(CKPT_DIR, exist_ok=True)\n",
    "            cpu_state = {name: tensor.cpu() for name, tensor in net.state_dict().items()}\n",
    "            torch.save(cpu_state, f'{CKPT_DIR}/MNISTTranslation_{step}.pth')\n",
    "\n",
    "            x_eval = _next_batch(SOURCE_KEY)[:NUM_EVAL]\n",
    "            start = _lift(x_eval, config.p.x_loc + config.epsilon)\n",
    "            _, traj = LearnedODESolver(net, config)(start)\n",
    "        finally:\n",
    "            net.train()\n",
    "            ema.restore(net.parameters())\n",
    "\n",
    "    # Row 0: input digits; row 1: traj[-2], the last state before the one that ended integration.\n",
    "    fig, ax = plt.subplots(2, NUM_EVAL, figsize=(NUM_EVAL, 2), squeeze=False)\n",
    "    for col in range(NUM_EVAL):\n",
    "        for row, frames in enumerate((traj[0], traj[-2])):\n",
    "            ax[row, col].imshow(_to_display(frames[col]))\n",
    "            ax[row, col].set_xticks([])\n",
    "            ax[row, col].set_yticks([])\n",
    "    fig.tight_layout(pad=0.001)\n",
    "    wandb.log({'Generated Images': wandb.Image(fig)}, step=step)\n",
    "    plt.close(fig)  # otherwise later pyplot calls (the loss plot) would draw into this figure\n",
    "\n",
    "\n",
    "losses = []\n",
    "for step in tqdm(range(NUM_STEPS)):\n",
    "    batch_x = _next_batch(SOURCE_KEY)  # source digits\n",
    "    batch_y = _next_batch(TARGET_KEY)  # target digits\n",
    "    quarks = _lift(batch_x, config.p.x_loc)       # [B, 1 + C*H*W] on the source plane\n",
    "    anti_quarks = _lift(batch_y, config.q.x_loc)  # [B, 1 + C*H*W] on the target plane\n",
    "\n",
    "    optimizer = state['optimizer']\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    perturbed_samples_vec = _perturb(batch_x, batch_y, quarks, anti_quarks)\n",
    "\n",
    "    # Ground-truth field at the perturbed points, rescaled to norm ~ sqrt(DIM).\n",
    "    field = compute_field(quarks, anti_quarks, perturbed_samples_vec, config)\n",
    "    field = math.sqrt(config.DIM) * field / (torch.norm(field, dim=1, keepdim=True) + 1e-5)\n",
    "\n",
    "    perturbed_samples_x = perturbed_samples_vec[:, 1:].view_as(batch_x)\n",
    "    perturbed_samples_z = perturbed_samples_vec[:, 0]\n",
    "    net_x, net_z = net(perturbed_samples_x, perturbed_samples_z)\n",
    "    pred = torch.cat([net_z[:, None], net_x.reshape(net_x.shape[0], DATA_DIM)], dim=1)\n",
    "\n",
    "    loss = torch.mean((pred - field) ** 2)\n",
    "    loss.backward()\n",
    "    optimize_fn(optimizer, net, step=state['step'], config=config)\n",
    "\n",
    "    state['step'] += 1\n",
    "    state['ema'].update(net.parameters())\n",
    "\n",
    "    loss_value = loss.item()\n",
    "    losses.append(loss_value)\n",
    "    wandb.log({'loss train': loss_value}, step=step)\n",
    "\n",
    "    if step % PLOT_EVERY == 0:\n",
    "        clear_output(wait=True)\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.plot(losses)\n",
    "        ax.set(title='Training loss', xlabel='step', ylabel='MSE')\n",
    "        plt.show()\n",
    "        plt.close(fig)\n",
    "\n",
    "    if step % EVAL_EVERY == 0:\n",
    "        _evaluate(step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83dae3f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
