{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import AutoModel\n",
    "\n",
    "vae_128 = AutoModel.load_from_folder(\n",
    "    'my_models_on_cifar/final_model'\n",
    "    )\n",
    "vae_128.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lode the cifar 10 data set\n",
    "from torchvision.datasets import CIFAR10\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor()]\n",
    ")\n",
    "\n",
    "#train_data = CIFAR10(root='data', train=True, download=True, transform=transform)\n",
    "test_data = CIFAR10(root='data', train=False, download=True, transform=transform)\n",
    "\n",
    "#train_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_data, batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img, label in test_loader:\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reconstructed_128 = vae_128.reconstruct(img.to('cuda')).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pythae.data.datasets import DatasetOutput\n",
    "\n",
    "dataset = DatasetOutput(data=img.to('cuda'))\n",
    "out = vae_128(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out['z']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_tangent_space(decoder, z, device='cuda'):\n",
    "    assert len(z.shape) == 1, \"compute_tangent_space: batch dimension in z is not supported. z has to be a 1-dimensional vector\"\n",
    "    decoder.to(device)\n",
    "    z = z.to(device)\n",
    "    latent_dim = z.shape[0]\n",
    "    z.requires_grad = True\n",
    "    out = decoder(z.unsqueeze(0))\n",
    "    out = out['reconstruction'].squeeze()      # remove singleton batch dimension\n",
    "    output_shape = out.shape # store original output shape\n",
    "    out = out.reshape(-1)    # and transform the output into a vector\n",
    "    tangent_space = torch.zeros((latent_dim, out.shape[0]))\n",
    "    for i in range(out.shape[0]):\n",
    "        out[i].backward(retain_graph=True)\n",
    "        tangent_space[:, i] = z.grad\n",
    "        z.grad.zero_()\n",
    "    tangent_space = tangent_space.reshape((-1, *output_shape)) # tangent space in model output shape\n",
    "    return tangent_space\n",
    "\n",
    "tangent_space = compute_tangent_space(vae_128.decoder, out['z'].detach()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import tqdm\n",
    "\n",
    "tangent_space_batch = []\n",
    "for i in tqdm.tqdm(range(32)):\n",
    "    tangent_space_batch.append(compute_tangent_space(vae_128.decoder, out['z'].detach()[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_image_grid(images):\n",
    "    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "    for i in range(5):\n",
    "        for j in range(5):\n",
    "            axes[i][j].imshow(images[i*5 +j].cpu().squeeze(0).numpy().transpose((1,2,0)), cmap='gray')\n",
    "            axes[i][j].axis('off')\n",
    "    plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# show the original images\n",
    "plot_image_grid(img)\n",
    " \n",
    "# show the reconstructed images\n",
    "plot_image_grid(reconstructed_128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.linalg import orth\n",
    "\n",
    "def project_into_tangent_space(tangent_space, vector):\n",
    "    BATCH_DIM = tangent_space.shape[1]\n",
    "    IMG_DIM = tangent_space.shape[2]\n",
    "    tangent_space_orth = orth(tangent_space.reshape((-1, BATCH_DIM*IMG_DIM*IMG_DIM)).T).T.reshape((-1, BATCH_DIM, IMG_DIM, IMG_DIM))\n",
    "    dim = tangent_space_orth.shape[0]\n",
    "    coeff = np.zeros(dim)\n",
    "    for i in range(dim):\n",
    "        coeff[i] = tangent_space_orth[i, :, :].flatten() @ vector.flatten()\n",
    "    vector_in_tangent_space = (coeff @ tangent_space_orth.reshape((dim, -1))).reshape((BATCH_DIM, IMG_DIM, IMG_DIM))\n",
    "    return vector_in_tangent_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# project vectors onto the respective tangent spaces\n",
    "noise = torch.randn_like(img)\n",
    "tangent_noise = torch.zeros_like(noise)\n",
    "orthogonal_noise = torch.zeros_like(noise)\n",
    "\n",
    "def direction_to_image(direction):\n",
    "    direction /= direction.abs().max()\n",
    "    direction = (1 + direction) / 2\n",
    "    return direction\n",
    "\n",
    "for i in range(32):\n",
    "    tangent_noise[i] = torch.Tensor(project_into_tangent_space(tangent_space_batch[i].numpy(), noise[i].numpy()))\n",
    "    orthogonal_noise[i] = noise[i] - tangent_noise[i]\n",
    "    \n",
    "    tangent_noise[i] = direction_to_image(tangent_noise[i])\n",
    "    orthogonal_noise[i] = direction_to_image(orthogonal_noise[i])\n",
    "\n",
    "plot_image_grid(tangent_noise)\n",
    "plot_image_grid(orthogonal_noise)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate the tangent space for all images in the test set of cifar10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = DataLoader(test_data, batch_size=1, shuffle=False)\n",
    "\n",
    "test_tangent_spaces = []\n",
    "for idx, (img, label) in tqdm.tqdm(enumerate(loader)):\n",
    "    dataset = DatasetOutput(data=img.to('cuda'))\n",
    "    out = vae_128(dataset)\n",
    "    test_tangent_spaces.append(compute_tangent_space(vae_128.decoder, out['z'].detach()[0]))\n",
    "    if idx % 10 == 0:\n",
    "        torch.save(test_tangent_spaces, f'results/test_tangent_spaces_{idx}.pt')\n",
    "        if idx > 0:\n",
    "            os.remove(f'results/test_tangent_spaces_{idx-10}.pt')\n",
    "    if idx > 999:\n",
    "        break\n",
    "torch.save(test_tangent_spaces, f'results/test_tangent_spaces.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.allclose(test_tangent_spaces[0][0, 0, 0, :], tangent_space_batch[0][0, 0, 0, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_tangent_spaces = torch.load(f'results/test_tangent_spaces.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
