{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import statements\n",
    "\n",
    "import importlib\n",
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch_dct as dct\n",
    "import torchvision\n",
    "\n",
    "import samplers\n",
    "from network_unet import UNet\n",
    "from randomized_svd_jacobian import randomized_svd as rsvd\n",
    "\n",
    "importlib.reload(samplers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute device used by the later cells: CUDA device index 1 when CUDA is\n",
    "# available, otherwise the CPU.\n",
    "# NOTE(review): index 1 presumably assumes a machine with at least two GPUs -- confirm.\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load CelebA\n",
    "\n",
    "# Preprocessing: resize every image to 256x256, then convert it to a tensor.\n",
    "to_tensor = torchvision.transforms.ToTensor()\n",
    "downsize = torchvision.transforms.Resize((256, 256))\n",
    "composed_transform = torchvision.transforms.Compose([downsize, to_tensor])\n",
    "root = \"\" # path to CelebA dataset\n",
    "# Training split, with the transform above applied to each image.\n",
    "trainset = torchvision.datasets.CelebA(root=root, split='train', download=True, transform=composed_transform)\n",
    "# NOTE(review): trainset_abridged is not used anywhere else in this notebook; the\n",
    "# loader below wraps the full trainset -- confirm which one was intended.\n",
    "trainset_abridged = torch.utils.data.Subset(trainset, range(2000)) # 2000 images\n",
    "batch_size = 16\n",
    "# shuffle=False is passed, so the loader is not asked to randomize the order.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected network with ELU after every layer except the last.\n",
    "\n",
    "    The flat output of the last linear layer is reshaped to ``(-1, *output_shape)``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of each input vector.\n",
    "        output_dim: Number of features produced by the last linear layer.\n",
    "        hidden_dims: Sequence of hidden-layer widths. May be empty, in which case\n",
    "            the network is a single linear layer from input_dim to output_dim.\n",
    "        output_shape: Per-sample shape the flat output is reshaped to. Defaults to\n",
    "            ``(3, 80, 80)``, the shape that used to be hard-coded in ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, hidden_dims, output_shape=(3, 80, 80)):\n",
    "        super().__init__()\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.hidden_dims = hidden_dims\n",
    "        self.output_shape = tuple(output_shape)\n",
    "        # Consecutive width pairs give the same layer order as before, so existing\n",
    "        # checkpoints (keys ``layers.N.weight`` / ``layers.N.bias``) still load.\n",
    "        # Building from the width list also handles an empty ``hidden_dims``.\n",
    "        widths = [input_dim, *hidden_dims, output_dim]\n",
    "        self.layers = nn.ModuleList(\n",
    "            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # ELU on every layer but the last; the last layer stays linear.\n",
    "        for layer in self.layers[:-1]:\n",
    "            x = F.elu(layer(x))\n",
    "        return self.layers[-1](x).reshape(-1, *self.output_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enc1: first encoder stage, a single linear layer followed by ELU\n",
    "\n",
    "class Enc1(nn.Module):\n",
    "    \"\"\"Input stage of the encoder: ELU applied to Linear(input_dim, hidden_dim).\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, hidden_dim)\n",
    "        self.input_dim, self.hidden_dim = input_dim, hidden_dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        pre_activation = self.linear(x)\n",
    "        return F.elu(pre_activation)\n",
    "\n",
    "# Enc2: second encoder stage, a single linear layer with no activation\n",
    "\n",
    "class Enc2(nn.Module):\n",
    "    \"\"\"Output stage of the encoder: Linear(hidden_dim, output_dim), no activation.\"\"\"\n",
    "\n",
    "    def __init__(self, hidden_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(hidden_dim, output_dim)\n",
    "        self.hidden_dim, self.output_dim = hidden_dim, output_dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        projected = self.linear(x)\n",
    "        return projected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TruncatedDCT(nn.Module):\n",
    "    \"\"\"2-D DCT that keeps only the leading num_coeffs x num_coeffs block of coefficients.\n",
    "\n",
    "    Args:\n",
    "        num_coeffs: Number of coefficients kept along each of the last two axes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_coeffs):\n",
    "        super().__init__()\n",
    "        self.num_coeffs = num_coeffs\n",
    "\n",
    "    def forward(self, x):\n",
    "        keep = self.num_coeffs\n",
    "        coeffs = dct.dct_2d(x)\n",
    "        # The indexing expects a 4-D input; only the last two axes are cropped.\n",
    "        return coeffs[:, :, :keep, :keep]\n",
    "\n",
    "class TruncatedIDCT(nn.Module):\n",
    "    \"\"\"Zero-pad a truncated coefficient block to full size and apply the inverse 2-D DCT.\n",
    "\n",
    "    Args:\n",
    "        num_coeffs: Side length of the truncated coefficient block given to forward.\n",
    "        original_size: Side length of the zero-padded grid passed to the inverse DCT.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_coeffs, original_size):\n",
    "        super().__init__()\n",
    "        self.num_coeffs = num_coeffs\n",
    "        self.original_size = original_size\n",
    "\n",
    "    def forward(self, X_dct_trunc):\n",
    "        pad = self.original_size - self.num_coeffs\n",
    "        batch, channels = X_dct_trunc.shape[0], X_dct_trunc.shape[1]\n",
    "        # new_zeros allocates the padding on the input's own device and dtype, so this\n",
    "        # module does not depend on the notebook-level `device` global.\n",
    "        extra_rows = X_dct_trunc.new_zeros(batch, channels, pad, self.num_coeffs)\n",
    "        X_dct_trunc = torch.cat([X_dct_trunc, extra_rows], dim=2)\n",
    "        extra_cols = X_dct_trunc.new_zeros(batch, channels, self.original_size, pad)\n",
    "        X_dct_trunc = torch.cat([X_dct_trunc, extra_cols], dim=3)\n",
    "        return dct.idct_2d(X_dct_trunc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load enc1, enc2, dec from checkpoint\n",
    "\n",
    "checkpt_fname = \"\" # fill in with checkpoint filename\n",
    "# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "checkpt = torch.load(checkpt_fname, map_location=device)\n",
    "\n",
    "# Flattened size of a 3 x 80 x 80 coefficient block (80 matches num_coeffs below\n",
    "# and the reshape at the end of MLP.forward).\n",
    "D = 3*80*80\n",
    "input_dim = D\n",
    "latent_dim = 700\n",
    "hidden_dim = 10000\n",
    "\n",
    "# Number of DCT coefficients kept along each of the last two axes.\n",
    "num_coeffs = 80\n",
    "\n",
    "trunc_dct = TruncatedDCT(num_coeffs=num_coeffs)\n",
    "# original_size=256 matches the 256x256 resize applied when loading the dataset.\n",
    "trunc_idct = TruncatedIDCT(num_coeffs=num_coeffs, original_size=256)\n",
    "\n",
    "enc1 = Enc1(input_dim, hidden_dim).to(device)\n",
    "enc2 = Enc2(hidden_dim, latent_dim).to(device)\n",
    "output_unet = UNet(in_nc=3, out_nc=3).to(device)\n",
    "# Decoder: MLP from latent to coefficient block, zero-padded inverse DCT, then the UNet.\n",
    "dec = nn.Sequential(MLP(latent_dim,input_dim,[hidden_dim]), trunc_idct, output_unet).to(device)\n",
    "\n",
    "# Restore weights from the three state dicts stored in the checkpoint.\n",
    "enc1.load_state_dict(checkpt['enc1_state_dict'])\n",
    "enc2.load_state_dict(checkpt['enc2_state_dict'])\n",
    "dec.load_state_dict(checkpt['dec_state_dict'])\n",
    "\n",
    "# The 'losses' entry of the checkpoint (plotted on a log scale in a later cell).\n",
    "losses = checkpt['losses']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the name bound to the loaded checkpoint now that the state dicts and\n",
    "# `losses` have been read from it.\n",
    "del checkpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the natural log of each entry of `losses`.\n",
    "plt.plot(torch.log(torch.tensor(losses)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chain enc1 and enc2 to get the encoder\n",
    "\n",
    "enc = nn.Sequential(enc1, enc2)\n",
    "\n",
    "# Chain enc and dec to get the autoencoder\n",
    "# NOTE(review): `autoencoder` is not referenced by any later cell in this notebook.\n",
    "\n",
    "autoencoder = nn.Sequential(enc, dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first batch from the training loader\n",
    "# (the loader is built with shuffle=False, so this is not a random draw)\n",
    "\n",
    "X, _ = next(iter(trainloader))\n",
    "\n",
    "X = X.to(device)\n",
    "# Take truncated DCT of X\n",
    "X_dct = trunc_dct(X) # X.shape = (batch_size, 3, 256, 256); X_dct.shape = (batch_size, 3, 80, 80)\n",
    "# Flatten each sample's coefficient block into a vector of length D\n",
    "X_dct = X_dct.reshape(X_dct.shape[0], D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Singular values of encoder Jacobian at training point\n",
    "\n",
    "idx = 0\n",
    "x = X_dct[idx]\n",
    "rank = 700\n",
    "# Wall-clock timing of the randomized SVD call.\n",
    "# NOTE(review): if `device` is a GPU, asynchronously queued CUDA work may make this\n",
    "# timing inaccurate unless the device is synchronized first -- confirm.\n",
    "start = time.time()\n",
    "U, S, V = rsvd(enc, x, rank, oversampling_factor=10)\n",
    "end = time.time()\n",
    "print('Time for rsvd: %.4f' % (end - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the values in S returned by rsvd, detached and moved to the CPU first.\n",
    "plt.plot(S.detach().cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show(img):\n",
    "    \"\"\"Draw a channel-first image tensor in a new 20x5 inch figure with no axis ticks.\"\"\"\n",
    "    pixels = img.numpy()\n",
    "    plt.figure(figsize=(20, 5))\n",
    "    # Hide the tick marks on both axes.\n",
    "    for clear_ticks in (plt.xticks, plt.yticks):\n",
    "        clear_ticks([])\n",
    "    # Move axis 0 (channels) to the end before handing the array to imshow.\n",
    "    channel_last = pixels.transpose(1, 2, 0)\n",
    "    plt.imshow(channel_last, interpolation='nearest')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the first 8 elements of X\n",
    "\n",
    "X_viz = X[:8].cpu()\n",
    "\n",
    "# Visualize X_viz on torch grid\n",
    "# (nrow=10 is larger than the 8 images selected above)\n",
    "\n",
    "grid = torchvision.utils.make_grid(X_viz.reshape(-1,3,256,256), nrow=10)\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize outputs of decoder MLP\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    # dec[0] is the decoder MLP; its output is a truncated coefficient block, so\n",
    "    # apply the zero-padded inverse DCT to view it as an image.\n",
    "    dec_mlp_out = dec[0](enc(X_dct[:5]))\n",
    "    dec_mlp_out = trunc_idct(dec_mlp_out)\n",
    "dec_mlp_out = dec_mlp_out.cpu()\n",
    "\n",
    "# Visualize dec_mlp_out on torch grid\n",
    "\n",
    "grid = torchvision.utils.make_grid(dec_mlp_out.reshape(-1,3,256,256), nrow=5)\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize reconstructions\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward pass.\n",
    "with torch.no_grad():\n",
    "    X_rec = dec(enc(X_dct[:8]))\n",
    "X_rec = X_rec.cpu()\n",
    "\n",
    "# Visualize X_rec on torch grid\n",
    "\n",
    "grid = torchvision.utils.make_grid(X_rec.reshape(-1,3,256,256), nrow=10)\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpolate between latents and visualize reconstructions\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward passes.\n",
    "with torch.no_grad():\n",
    "    Z = enc(X_dct)\n",
    "    # Ten evenly spaced points on the segment from Z[0] to Z[3], endpoints included.\n",
    "    Z_interp = torch.zeros(10, latent_dim, device=device)\n",
    "    Z_interp[0] = Z[0]\n",
    "    Z_interp[-1] = Z[3]\n",
    "    for i in range(1, 9):\n",
    "        Z_interp[i] = (i/9) * Z[3] + ((9-i)/9) * Z[0]\n",
    "    X_interp = dec(Z_interp)\n",
    "X_interp = X_interp.cpu()\n",
    "\n",
    "# Visualize X_interp on torch grid\n",
    "\n",
    "grid = torchvision.utils.make_grid(X_interp.reshape(-1,3,256,256), nrow=10)\n",
    "show(grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the top 5 left-singular vectors of enc at training point\n",
    "\n",
    "idx = 0\n",
    "x = X_dct[idx]\n",
    "rank = 700\n",
    "U, S, V = rsvd(enc, x, rank, oversampling_factor=10)\n",
    "u = U[:,:5]\n",
    "\n",
    "# Move along each of the top 5 left-singular vectors of enc at training point\n",
    "# Visualize reconstructions along rows\n",
    "\n",
    "t = torch.linspace(-1, 1, 8).to(device)\n",
    "scale = 20000 # 2k for sigma=0.5, 20k for sigma=0\n",
    "# Inference only: no_grad avoids building an autograd graph for the forward passes.\n",
    "with torch.no_grad():\n",
    "    Z = enc(X_dct[:10])\n",
    "    Z_interp = torch.zeros(u.shape[1]*8, latent_dim).to(device)\n",
    "    # Row i of the grid holds 8 steps along singular vector i, centered on Z[idx].\n",
    "    for i in range(u.shape[1]):\n",
    "        for j in range(8):\n",
    "            Z_interp[i*8+j] = Z[idx] + t[j] * scale * u[:,i]\n",
    "    X_interp = dec(Z_interp)\n",
    "X_interp = X_interp.cpu()\n",
    "\n",
    "# Visualize X_interp on torch grid\n",
    "\n",
    "grid = torchvision.utils.make_grid(X_interp.reshape(-1,3,256,256), nrow=8)\n",
    "show(grid)\n",
    "\n",
    "# Save the grid\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target folder exists.\n",
    "os.makedirs('results', exist_ok=True)\n",
    "plt.savefig('results/sigma0_top5_left_singular_vecs_training_0.png', bbox_inches='tight', dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ldm4",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
