{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "from functorch import vmap, jacfwd\n",
    "\n",
    "from morphomnist import io\n",
    "from vae import VAE\n",
    "from geodesic import GeodesicMaker, geodesic\n",
    "from dataset import MorphoMNISTDataset, r2r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin GPU 0 and seed every RNG used in this notebook (python, numpy, torch CPU/CUDA).\n",
    "# NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)  # covers the current device as well\n",
    "torch.backends.cudnn.deterministic = True\n",
    "# benchmark=True lets cuDNN pick kernels by timing them, which is non-deterministic\n",
    "# and defeats `deterministic = True`; keep it off for reproducible runs.\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MorphoMNIST train and test splits, merge them, and re-split into\n",
    "# train/test with `r2r` (split rule defined in dataset.py).\n",
    "concat_set = MorphoMNISTDataset(\n",
    "    train_set=MorphoMNISTDataset(train=True),\n",
    "    test_set=MorphoMNISTDataset(train=False),\n",
    ")\n",
    "train_set, test_set = r2r(concat_set)\n",
    "\n",
    "train_loader = DataLoader(train_set, batch_size=100, shuffle=True)\n",
    "test_loader = DataLoader(test_set, batch_size=100, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build both networks and move them to the GPU straight away.\n",
    "vae_model = VAE().cuda()\n",
    "geodesic_maker = GeodesicMaker().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a cached VAE checkpoint if one exists; `loaded` gates training and saving later.\n",
    "loaded = os.path.exists('vae_model.pth')\n",
    "if not loaded:\n",
    "    print(\"Pretrained model not found.\")\n",
    "else:\n",
    "    vae_model.load_state_dict(torch.load('vae_model.pth'))\n",
    "    vae_model.eval()\n",
    "    print(\"Pretrained model loaded successfully.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(vae_model.parameters(), lr=1e-3)\n",
    "geo_optim = optim.Adam(geodesic_maker.parameters(), lr=1e-3)\n",
    "\n",
    "\n",
    "def loss_function(output, x, mu, logvar):\n",
    "    \"\"\"Negative ELBO summed over the batch: BCE reconstruction + KL(q(z|x) || N(0, I)).\"\"\"\n",
    "    bce = nn.functional.binary_cross_entropy(output, x, reduction=\"sum\")\n",
    "    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "    return bce + kld\n",
    "\n",
    "\n",
    "def geodesic_loss_function(z, n_steps=10):\n",
    "    \"\"\"Per-sample loss for fitting curves from `geodesic_maker.anchor` to each row of `z`.\n",
    "\n",
    "    The curve produced by `geodesic_maker(z)` is evaluated at t=0 and t=1 and L1-pinned\n",
    "    to the anchor and to `z` respectively. Half of a discretised energy over `n_steps`\n",
    "    equal segments of [0, 1] is added. The total is divided by the batch size.\n",
    "    \"\"\"\n",
    "    batch_size = z.size(0)\n",
    "    w1, w2, b1, b2 = geodesic_maker(z)\n",
    "\n",
    "    # Create the t-values on the same device as `z` instead of hard-coding `.cuda()`.\n",
    "    zero = torch.zeros((batch_size, 1, 1), device=z.device)\n",
    "    one = torch.ones((batch_size, 1, 1), device=z.device)\n",
    "\n",
    "    gamma_zero, mu_old, logvar_old = geodesic(zero, w1, w2, b1, b2, vae_model.encoder)\n",
    "    gamma_one, _, _ = geodesic(one, w1, w2, b1, b2, vae_model.encoder)\n",
    "\n",
    "    energy = 0\n",
    "    for i in range(1, n_steps + 1):\n",
    "        p = one * (i / n_steps)\n",
    "        _, mu_p, logvar_p = geodesic(p, w1, w2, b1, b2, vae_model.encoder)\n",
    "        # NOTE(review): a Wasserstein-2 segment cost would use\n",
    "        # (exp(0.5*logvar_old) - exp(0.5*logvar_p))**2; the variance *sum* below does not\n",
    "        # vanish for identical neighbouring points. Kept unchanged - confirm intent.\n",
    "        energy += torch.pow(mu_old - mu_p, 2) + (logvar_old.exp() + logvar_p.exp())\n",
    "        mu_old, logvar_old = mu_p, logvar_p\n",
    "\n",
    "    anchor_batch = geodesic_maker.anchor.unsqueeze(0).repeat(batch_size, 1)\n",
    "    geodesic_loss = (\n",
    "        nn.functional.l1_loss(gamma_zero, anchor_batch, reduction=\"sum\")\n",
    "        + nn.functional.l1_loss(gamma_one, z, reduction=\"sum\")\n",
    "        + energy.sum() / 2\n",
    "    )\n",
    "    return geodesic_loss / batch_size\n",
    "\n",
    "\n",
    "def train(epoch):\n",
    "    \"\"\"Run one VAE training epoch over `train_loader` and print the per-sample loss.\"\"\"\n",
    "    vae_model.train()\n",
    "    train_loss = 0\n",
    "    for data, _, _ in train_loader:\n",
    "        data = data.cuda()\n",
    "        optimizer.zero_grad()\n",
    "        output, z, mu, logvar = vae_model(data)\n",
    "        loss = loss_function(output, data, mu, logvar)\n",
    "        loss.backward()\n",
    "        train_loss += loss.item()\n",
    "        optimizer.step()\n",
    "    # The loss is summed over samples, so normalise by the true sample count\n",
    "    # (the final batch may hold fewer than batch_size items).\n",
    "    train_loss /= len(train_loader.dataset)\n",
    "    print(\"Train epoch: {},\\tLoss: {:.6f}\".format(epoch, train_loss))\n",
    "\n",
    "\n",
    "def test():\n",
    "    \"\"\"Evaluate the VAE on `test_loader` and print the per-sample loss.\"\"\"\n",
    "    vae_model.eval()\n",
    "    test_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for data, _, _ in test_loader:\n",
    "            data = data.cuda()\n",
    "            output, z, mu, logvar = vae_model(data)\n",
    "            test_loss += loss_function(output, data, mu, logvar).item()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print(\"test loss\", test_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train from scratch only when no checkpoint was loaded.\n",
    "if loaded:\n",
    "    print(\"skip training\")\n",
    "else:\n",
    "    for epoch in range(50):\n",
    "        train(epoch)\n",
    "        test()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist freshly trained weights; a loaded checkpoint is left untouched.\n",
    "if not loaded:\n",
    "    torch.save(vae_model.state_dict(), 'vae_model.pth')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular 200x200 grid over [-5, 5]^2 in the latent plane.\n",
    "latent_x = np.linspace(-5, 5, 200)\n",
    "latent_y = np.linspace(-5, 5, 200)\n",
    "xv, yv = np.meshgrid(latent_x, latent_y)\n",
    "# Flatten row-major so grid point k is (xv.flat[k], yv.flat[k]).\n",
    "xtensor, ytensor = torch.tensor(xv.ravel()), torch.tensor(yv.ravel())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (N, 2) float32 batch of latent grid points, N = 200 * 200.\n",
    "# `.to` avoids the copy-construct warning raised by `torch.tensor(<tensor>)`.\n",
    "ztensor = torch.stack((xtensor, ytensor), dim=1).to(torch.float32)\n",
    "ztensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder Jacobian at every grid point, computed on the CPU: jacfwd differentiates the\n",
    "# decoder for one latent point and vmap maps that over the rows of `ztensor`.\n",
    "# NOTE(review): `functorch` is deprecated in favour of `torch.func` (same API).\n",
    "vae_model = vae_model.cpu()\n",
    "jacobian = vmap(jacfwd(vae_model.decoder))(ztensor)\n",
    "# Move the model back to the GPU for the cells that follow.\n",
    "vae_model = vae_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the raw Jacobian shape before it is flattened to (N, 28*28, 2).\n",
    "jacobian.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (28*28, 2) Jacobian per grid point. Take N from `ztensor` instead of\n",
    "# hard-coding 40000 (= 200 * 200); `reshape` also tolerates non-contiguous input.\n",
    "jacobian = jacobian.reshape(ztensor.shape[0], 28*28, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull-back metric G = J^T J: one 2x2 matrix per latent grid point.\n",
    "metric = torch.bmm(jacobian.permute(0, 2, 1), jacobian)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Magnification factor sqrt(det G). det is clamped to >= 1e-5 so sqrt and the later\n",
    "# np.log stay finite.\n",
    "indi = torch.sqrt(torch.linalg.det(metric).clamp_(min=1e-5)).detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat array: one magnification value per grid point.\n",
    "indi.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the (len(latent_y), len(latent_x)) layout produced by np.meshgrid,\n",
    "# rather than hard-coding (200, 200).\n",
    "indi = indi.reshape(xv.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log magnification factor over the latent grid, with the learned anchor in red.\n",
    "h = plt.contourf(latent_x, latent_y, np.log(indi), extend=\"both\")\n",
    "anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1], c=\"r\")\n",
    "# Pass the contour set explicitly: plt.scatter makes itself the current mappable,\n",
    "# so a bare plt.colorbar() would describe the scatter, not the contour map.\n",
    "plt.colorbar(h)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy copy of the vanilla geodesic model's learned anchor point.\n",
    "anchor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the test split (test_loader is not shuffled) with the vanilla VAE and\n",
    "# collect the latent codes `z` it returns.\n",
    "vae_model.eval()\n",
    "z_list = []\n",
    "with torch.no_grad():\n",
    "    for data, _, _ in test_loader:\n",
    "        _, z, _, _ = vae_model(data.cuda())\n",
    "        z_list.append(z.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same encoding pass over the training split (train_loader order is shuffled).\n",
    "vae_model.eval()\n",
    "train_z_list = []\n",
    "with torch.no_grad():\n",
    "    for data, _, _ in train_loader:\n",
    "        _, z, _, _ = vae_model(data.cuda())\n",
    "        train_z_list.append(z.detach().cpu().numpy())\n",
    "z_train = np.concatenate(train_z_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join the per-batch test latents into one array, one row per test sample.\n",
    "zs = np.concatenate(z_list, axis=0)\n",
    "zs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random plotting subsets: 2000 test and 18000 train latents.\n",
    "# NOTE: permuting drops the link between a row and its dataset index.\n",
    "random_z_test = np.random.permutation(zs)[:2000]\n",
    "random_z_train = np.random.permutation(z_train)[:18000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,6))\n",
    "# Keep the contour set so the colorbar describes it; a bare plt.colorbar() would\n",
    "# attach to the most recent scatter instead.\n",
    "cont = plt.contourf(latent_x, latent_y, np.log(indi), extend=\"both\")\n",
    "anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1], c=\"r\")\n",
    "# Blue: sampled test latents, green: sampled train latents.\n",
    "plt.scatter(random_z_test[:,0], random_z_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "plt.scatter(random_z_train[:,0], random_z_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "plt.colorbar(cont)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a cached geodesic model if one exists; `geo_loaded` gates training and saving.\n",
    "geo_loaded = os.path.exists('geodesic.pth')\n",
    "if not geo_loaded:\n",
    "    print(\"Pretrained geodesic model not found.\")\n",
    "else:\n",
    "    geodesic_maker.load_state_dict(torch.load('geodesic.pth'))\n",
    "    geodesic_maker.eval()\n",
    "    print(\"Pretrained geodesic model loaded successfully.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def geodesic_training(epoch=None):\n",
    "    \"\"\"Run one epoch of geodesic-model training against the vanilla VAE.\n",
    "\n",
    "    Only `geo_optim` (the geodesic_maker parameters) is stepped. Prints the mean\n",
    "    per-batch geodesic loss, prefixed with `epoch` when one is given.\n",
    "    \"\"\"\n",
    "    vae_model.eval()\n",
    "    train_geodesic_loss = 0\n",
    "    for data, _, _ in train_loader:\n",
    "        data = data.cuda()\n",
    "        geo_optim.zero_grad()\n",
    "        # Only the latent codes are needed, and they were detached anyway, so do not\n",
    "        # build an autograd graph through the VAE forward pass.\n",
    "        with torch.no_grad():\n",
    "            _, z, _, _ = vae_model(data)\n",
    "        # NOTE(review): gradients presumably also reach vae_model.encoder via `geodesic`;\n",
    "        # they are never applied here (only geo_optim steps).\n",
    "        geodesic_loss = geodesic_loss_function(z)\n",
    "        geodesic_loss.backward()\n",
    "        train_geodesic_loss += geodesic_loss.item()\n",
    "        geo_optim.step()\n",
    "\n",
    "    # geodesic_loss_function already averages within a batch; average over batches too.\n",
    "    train_geodesic_loss /= len(train_loader)\n",
    "    prefix = \"\" if epoch is None else \"Epoch {}: \".format(epoch)\n",
    "    print(prefix + \"Geodesic loss {:.6f}\".format(train_geodesic_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch shape check: run the geodesic model on 100 random latent points and evaluate\n",
    "# the curves at t=0 and t=1. The following cells do not use these results (`w1..b2`\n",
    "# and `one` are reassigned before reuse), but torch.randn still advances the RNG.\n",
    "w1, w2, b1, b2 = geodesic_maker(torch.randn(100,2).cuda())\n",
    "\n",
    "zero = torch.zeros((100, 1, 1)).cuda()\n",
    "one = torch.ones((100, 1, 1)).cuda()\n",
    "\n",
    "energy = 0\n",
    "\n",
    "gamma_zero, mu_zero, logvar_zero = geodesic(zero, w1, w2, b1, b2, vae_model.encoder)\n",
    "gamma_one, mu_one, logvar_one = geodesic(one, w1, w2, b1, b2, vae_model.encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the geodesic model only when no checkpoint was loaded, then save it.\n",
    "if not geo_loaded:\n",
    "    for i in range(100):\n",
    "        geodesic_training()\n",
    "    torch.save(geodesic_maker.state_dict(), 'geodesic.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode training image 80 to get its latent point z0 (unsqueeze(0) adds a batch axis).\n",
    "x0 = train_set[80][0].cuda()\n",
    "_, z0, _, _ = vae_model(x0.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample the learned curve through z0 at t = 1, 0.98, ..., -0.98 (100 values).\n",
    "# Training pins t=1 to the input point and t=0 to the anchor; t < 0 extrapolates.\n",
    "one = torch.ones(1).cuda()\n",
    "w1, w2, b1, b2 = geodesic_maker(z0)\n",
    "geo_list = [\n",
    "    geodesic(one * (50 - i) / 50, w1, w2, b1, b2, vae_model.encoder)[0]\n",
    "    for i in range(100)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the curve samples along a new leading axis and copy them to NumPy.\n",
    "geo_tensor = torch.stack(geo_list, dim=0)\n",
    "geo_np = geo_tensor.detach().cpu().numpy()\n",
    "geo_np.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Curve sample at t = 0.98; the plots use it as the start of the reflection line.\n",
    "geo_np[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest log-magnification value on the grid.\n",
    "np.log(indi).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-range view of the vanilla VAE: log magnification (grayscale), anchor (red),\n",
    "# train/test latents (green/blue), the sampled curve through z0 (magenta), and the\n",
    "# straight point reflection through the anchor (black dash-dot) for comparison.\n",
    "plt.figure(figsize=(15,13))\n",
    "# Black iso-lines at log-magnification 0 and 3.\n",
    "plt.contour(latent_x, latent_y, np.log(indi), levels=[0, 3], colors='k')\n",
    "cont = plt.contourf(latent_x, latent_y, np.log(indi), extend=\"both\", cmap=\"gist_gray\")\n",
    "anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1],c=\"r\")\n",
    "plt.scatter(random_z_train[:,0], random_z_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "plt.scatter(random_z_test[:,0], random_z_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "plt.plot(geo_np[:, 0], geo_np[:,1], 'm')\n",
    "# Segment from geo_np[1] to its reflection 2*anchor - geo_np[1].\n",
    "plt.plot([geo_np[1,0], 2*anchor[0]-geo_np[1,0]], [geo_np[1,1],2*anchor[1]-geo_np[1,1]], 'k-.')\n",
    "plt.colorbar(cont)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same vanilla-VAE view, zoomed to [-2.5, 2.5]^2 with iso-lines at 0 and 4.\n",
    "plt.figure(figsize=(15,10))\n",
    "plt.contour(latent_x, latent_y, np.log(indi), levels=[0, 4], colors='k')\n",
    "cont = plt.contourf(latent_x, latent_y, np.log(indi), extend=\"both\", cmap='gist_gray')\n",
    "anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1],c=\"r\")\n",
    "plt.scatter(random_z_train[:, 0], random_z_train[:, 1], c='g', alpha=0.5, s=0.5)\n",
    "plt.scatter(random_z_test[:, 0], random_z_test[:, 1], c='b', alpha=0.5, s=0.5)\n",
    "plt.plot(geo_np[:, 0], geo_np[:,1], 'm')\n",
    "# Segment from geo_np[1] to its reflection 2*anchor - geo_np[1].\n",
    "plt.plot([geo_np[1,0], 2*anchor[0]-geo_np[1,0]], [geo_np[1,1],2*anchor[1]-geo_np[1,1]], 'k-.')\n",
    "plt.colorbar(cont)\n",
    "plt.axis([-2.5, 2.5, -2.5, 2.5])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gsvae import GSVAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the GPU *before* creating its optimizer, as the nn.Module.cuda\n",
    "# docs recommend; the optimizer then holds the GPU parameters from the start.\n",
    "gsvae = GSVAE().cuda()\n",
    "gsoptim = optim.Adam(gsvae.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse a cached GSVAE checkpoint if one exists; `gs_loaded` gates training and saving.\n",
    "gs_loaded = os.path.exists('gsvae.pth')\n",
    "if not gs_loaded:\n",
    "    print(\"Pretrained gsvae model not found.\")\n",
    "else:\n",
    "    gsvae.load_state_dict(torch.load('gsvae.pth'))\n",
    "    gsvae.eval()\n",
    "    print(\"Pretrained gsvae model loaded successfully.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight of the geodesic-symmetry term in the total loss (shared by train and test).\n",
    "GS_WEIGHT = 5\n",
    "\n",
    "\n",
    "def gstrain(epoch):\n",
    "    \"\"\"Run one GSVAE training epoch; print the total loss divided by the sample count.\"\"\"\n",
    "    gsvae.train()\n",
    "    train_loss = 0\n",
    "    for data, _, _ in train_loader:\n",
    "        data = data.cuda()\n",
    "        gsoptim.zero_grad()\n",
    "        (rec, kld, gs), output, z, mu, logvar = gsvae(data)\n",
    "        loss = rec + kld + GS_WEIGHT * gs\n",
    "        loss.backward()\n",
    "        train_loss += loss.item()\n",
    "        gsoptim.step()\n",
    "    # Normalise by the true sample count (the last batch may be smaller than 100).\n",
    "    train_loss /= len(train_loader.dataset)\n",
    "    print(\"Train epoch: {},\\tLoss: {:.6f}\".format(epoch, train_loss))\n",
    "\n",
    "\n",
    "def gstest():\n",
    "    \"\"\"Evaluate GSVAE on `test_loader`; print total and per-term losses per sample.\"\"\"\n",
    "    gsvae.eval()\n",
    "    test_loss = rec_loss = kld_loss = gs_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for data, _, _ in test_loader:\n",
    "            data = data.cuda()\n",
    "            (rec, kld, gs), output, z, mu, logvar = gsvae(data)\n",
    "            test_loss += (rec + kld + GS_WEIGHT * gs).item()\n",
    "            rec_loss += rec.item()\n",
    "            kld_loss += kld.item()\n",
    "            gs_loss += gs.item()\n",
    "    n_samples = len(test_loader.dataset)\n",
    "    print(\"test loss\", test_loss / n_samples)\n",
    "    print(\"rec loss\", rec_loss / n_samples)\n",
    "    print(\"kld loss\", kld_loss / n_samples)\n",
    "    print(\"gs loss\", gs_loss / n_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure GSVAE is on the GPU before training/evaluation.\n",
    "gsvae = gsvae.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train GSVAE only when no checkpoint was loaded, then save the result.\n",
    "if gs_loaded:\n",
    "    print(\"skip training\")\n",
    "else:\n",
    "    for i in range(50):\n",
    "        gstrain(i)\n",
    "        gstest()\n",
    "    torch.save(gsvae.state_dict(), 'gsvae.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Magnification-factor map for the GSVAE decoder, computed on the CPU like the vanilla one.\n",
    "# NOTE: gsvae is left on the CPU here; the next cell moves it back to the GPU.\n",
    "gsvae = gsvae.cpu()\n",
    "gsjacobian = vmap(jacfwd(gsvae.vae.decoder))(ztensor)\n",
    "# Take sizes from the latent grid instead of hard-coding 40000 / (200, 200).\n",
    "gsjacobian = gsjacobian.reshape(ztensor.shape[0], 28*28, 2)\n",
    "gsmetric = torch.bmm(gsjacobian.permute(0, 2, 1), gsjacobian)  # pull-back metric J^T J\n",
    "# sqrt(det G), clamped so the later np.log stays finite.\n",
    "gsindi = torch.sqrt(torch.linalg.det(gsmetric).clamp_(min=1e-5)).detach().cpu().numpy()\n",
    "gsindi = gsindi.reshape(xv.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gsvae = gsvae.cuda()\n",
    "x0 = train_set[80][0].cuda()\n",
    "# Add a leading batch axis, matching the vanilla VAE call on x0; this encoding pass\n",
    "# needs no gradients (gstest runs the same forward under no_grad).\n",
    "with torch.no_grad():\n",
    "    _, _, z0, _, _ = gsvae(x0.unsqueeze(0))\n",
    "one = torch.ones(1).cuda()\n",
    "\n",
    "w1, w2, b1, b2 = gsvae.gamma(z0)\n",
    "\n",
    "# Sample the learned curve at t = 1, 0.98, ..., -0.98 (100 values).\n",
    "geo_list0 = []\n",
    "for i in range(100):\n",
    "    p = one * (50 - i) / 50\n",
    "    z, _, _ = geodesic(p, w1, w2, b1, b2, gsvae.vae.encoder)\n",
    "    geo_list0.append(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = test_set[10][0].cuda()\n",
    "# Add a leading batch axis, matching the vanilla VAE call on x0; this encoding pass\n",
    "# needs no gradients (gstest runs the same forward under no_grad).\n",
    "with torch.no_grad():\n",
    "    _, _, z1, _, _ = gsvae(x1.unsqueeze(0))\n",
    "one = torch.ones(1).cuda()\n",
    "\n",
    "w1, w2, b1, b2 = gsvae.gamma(z1)\n",
    "\n",
    "# Sample the learned curve at t = 1, 0.98, ..., -0.98 (100 values).\n",
    "geo_list1 = []\n",
    "for i in range(100):\n",
    "    p = one * (50 - i) / 50\n",
    "    z, _, _ = geodesic(p, w1, w2, b1, b2, gsvae.vae.encoder)\n",
    "    geo_list1.append(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the curve samples along a new leading axis and copy them to NumPy.\n",
    "geo_tensor0 = torch.stack(geo_list0, dim=0)\n",
    "geo_np0 = geo_tensor0.detach().cpu().numpy()\n",
    "geo_np0.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the curve samples along a new leading axis and copy them to NumPy.\n",
    "geo_tensor1 = torch.stack(geo_list1, dim=0)\n",
    "geo_np1 = geo_tensor1.detach().cpu().numpy()\n",
    "geo_np1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode both splits with GSVAE in dataset order, so row k of `gszs` / `gsz_train` is\n",
    "# the latent of test_set[k] / train_set[k]. Later cells index the datasets with row\n",
    "# positions from these arrays. `train_loader` shuffles, so it cannot be used here.\n",
    "gsvae.eval()\n",
    "\n",
    "z_list = list()\n",
    "with torch.no_grad():\n",
    "    for data, _, _ in test_loader:\n",
    "        data = data.cuda()\n",
    "        loss, output, z, mu, logvar = gsvae(data)\n",
    "        z_list.append(z.detach().cpu().numpy())\n",
    "gszs = np.concatenate(z_list, axis=0)\n",
    "\n",
    "ordered_train_loader = DataLoader(train_set, batch_size=100, shuffle=False)\n",
    "train_z_list = list()\n",
    "with torch.no_grad():\n",
    "    for data, _, _ in ordered_train_loader:\n",
    "        data = data.cuda()\n",
    "        loss, output, z, mu, logvar = gsvae(data)\n",
    "        train_z_list.append(z.detach().cpu().numpy())\n",
    "gsz_train = np.concatenate(train_z_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random plotting subsets of GSVAE latents (2000 test / 18000 train).\n",
    "# NOTE: rows of these arrays no longer line up with dataset indices.\n",
    "random_gsz_test = np.random.permutation(gszs)[:2000]\n",
    "random_gsz_train = np.random.permutation(gsz_train)[:18000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-range view of GSVAE: log magnification (grayscale), anchor (red), train/test\n",
    "# latents (green/blue), the curve through training image 80 (magenta), and its\n",
    "# straight point reflection through the anchor (black dash-dot).\n",
    "plt.figure(figsize=(15,13))\n",
    "# Black iso-lines at log-magnification 0 and 4.\n",
    "plt.contour(latent_x, latent_y, np.log(gsindi), levels=[0, 4], colors='k')\n",
    "cont = plt.contourf(latent_x, latent_y, np.log(gsindi), extend=\"both\", cmap=\"gist_gray\")\n",
    "anchor = gsvae.gamma.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1],c=\"r\")\n",
    "plt.scatter(random_gsz_train[:,0], random_gsz_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "plt.scatter(random_gsz_test[:,0], random_gsz_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "plt.plot(geo_np0[:, 0], geo_np0[:,1], 'm-')\n",
    "# Segment from the t=1 sample geo_np0[0] to its reflection 2*anchor - geo_np0[0].\n",
    "plt.plot([geo_np0[0,0], 2*anchor[0]-geo_np0[0,0]], [geo_np0[0,1], 2*anchor[1]-geo_np0[0,1]], 'k-.')\n",
    "plt.colorbar(cont)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zoomed GSVAE view ([-2.5, 2.5]^2) with both sampled curves: training image 80\n",
    "# (magenta, black dash-dot reflection) and test image 10 (yellow, cyan reflection).\n",
    "plt.figure(figsize=(15,10))\n",
    "plt.contour(latent_x, latent_y, np.log(gsindi), levels=[0, 4], colors='k')\n",
    "cont = plt.contourf(latent_x, latent_y, np.log(gsindi), extend=\"both\", cmap=\"gist_gray\")\n",
    "anchor = gsvae.gamma.anchor.detach().clone().cpu().numpy()\n",
    "plt.scatter(anchor[0], anchor[1],c=\"r\")\n",
    "plt.scatter(random_gsz_train[:,0], random_gsz_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "plt.scatter(random_gsz_test[:,0], random_gsz_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "plt.plot(geo_np0[:, 0], geo_np0[:,1], 'm-')\n",
    "plt.plot([geo_np0[0,0], 2*anchor[0]-geo_np0[0,0]], [geo_np0[0,1], 2*anchor[1]-geo_np0[0,1]], 'k-.')\n",
    "plt.plot(geo_np1[:, 0], geo_np1[:,1], 'y')\n",
    "plt.plot([geo_np1[0,0], 2*anchor[0]-geo_np1[0,0]], [geo_np1[0,1], 2*anchor[1]-geo_np1[0,1]], 'c')\n",
    "plt.colorbar(cont)\n",
    "plt.axis([-2.5, 2.5, -2.5, 2.5])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample one curve from each model for the side-by-side comparison plots.\n",
    "vanilla_geodesic_list = list()\n",
    "vae_model.eval()\n",
    "z0 = torch.tensor(random_z_test[50]).unsqueeze(0)\n",
    "# z0 is already a tensor; re-wrapping it with torch.tensor only warns and copies.\n",
    "w1, w2, b1, b2 = geodesic_maker(z0.cuda())\n",
    "one = torch.ones(1).cuda()\n",
    "\n",
    "# NOTE(review): the vanilla curve uses t = 0, 0.01, ..., 0.99 (anchor side up to, but\n",
    "# not including, t=1), whereas the GSVAE curve below uses t = 1, ..., -0.98. Confirm intent.\n",
    "for i in range(100):\n",
    "    p = one * i / 100\n",
    "    z, _, _ = geodesic(p, w1, w2, b1, b2, vae_model.encoder)\n",
    "    vanilla_geodesic_list.append(z)\n",
    "\n",
    "vanilla_geodesic_tensor = torch.stack(vanilla_geodesic_list, dim=0)\n",
    "vanilla_geodesic_np = vanilla_geodesic_tensor.detach().cpu().numpy()\n",
    "\n",
    "gs_geodesic_list = list()\n",
    "gsvae.eval()\n",
    "z0 = torch.tensor(random_gsz_test[17]).unsqueeze(0)\n",
    "w1, w2, b1, b2 = gsvae.gamma(z0.cuda())\n",
    "one = torch.ones(1).cuda()\n",
    "\n",
    "for i in range(100):\n",
    "    p = one * (50 - i) / 50\n",
    "    z, _, _ = geodesic(p, w1, w2, b1, b2, gsvae.vae.encoder)\n",
    "    gs_geodesic_list.append(z)\n",
    "\n",
    "gs_geodesic_tensor = torch.stack(gs_geodesic_list, dim=0)\n",
    "gs_geodesic_np = gs_geodesic_tensor.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each panel shows its own model's anchor; the global `anchor` holds whichever model\n",
    "# last assigned it. Draw everything through the explicit axes.\n",
    "vanilla_anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "gs_anchor = gsvae.gamma.anchor.detach().clone().cpu().numpy()\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(16, 8))\n",
    "\n",
    "# Plot 1: vanilla VAE log magnification, latents, and sampled curve.\n",
    "axs[0].contour(latent_x, latent_y, np.log(indi), levels=[0, 3], colors='k')\n",
    "axs[0].contourf(latent_x, latent_y, np.log(indi), extend=\"both\", levels=np.arange(0, 6, 0.2), cmap='gist_gray')\n",
    "axs[0].scatter(vanilla_anchor[0], vanilla_anchor[1], c=\"r\")\n",
    "axs[0].scatter(random_z_train[:, 0], random_z_train[:, 1], c='g', alpha=0.5, s=0.5)\n",
    "axs[0].scatter(random_z_test[:, 0], random_z_test[:, 1], c='b', alpha=0.5, s=0.5)\n",
    "axs[0].plot(vanilla_geodesic_np[:, 0], vanilla_geodesic_np[:, 1], 'm')\n",
    "axs[0].axis([-4, 4, -4, 4])\n",
    "axs[0].set_title('Vanilla VAE')\n",
    "\n",
    "# Plot 2: GSVAE, drawn on axs[1] explicitly rather than via pyplot's current axes.\n",
    "axs[1].contour(latent_x, latent_y, np.log(gsindi), levels=[0, 3], colors='k')\n",
    "axs[1].contourf(latent_x, latent_y, np.log(gsindi), extend=\"both\", levels=np.arange(0, 6, 0.2), cmap=\"gist_gray\")\n",
    "axs[1].scatter(gs_anchor[0], gs_anchor[1], c=\"r\")\n",
    "axs[1].scatter(random_gsz_train[:,0], random_gsz_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "axs[1].scatter(random_gsz_test[:,0], random_gsz_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "axs[1].plot(gs_geodesic_np[:, 0], gs_geodesic_np[:, 1], 'm')\n",
    "axs[1].axis([-4, 4, -4, 4])\n",
    "axs[1].set_title('VAE + geodesic symmetry')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same comparison with denser iso-lines (every 0.5 in log magnification). Each panel\n",
    "# shows its own model's anchor; the global `anchor` holds whichever model last set it.\n",
    "vanilla_anchor = geodesic_maker.anchor.detach().clone().cpu().numpy()\n",
    "gs_anchor = gsvae.gamma.anchor.detach().clone().cpu().numpy()\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(16, 8))\n",
    "\n",
    "# Plot 1: vanilla VAE.\n",
    "axs[0].contour(latent_x, latent_y, np.log(indi), levels=np.arange(0, 5, 0.5), colors='k')\n",
    "axs[0].contourf(latent_x, latent_y, np.log(indi), extend=\"both\", levels=np.arange(0, 6, 0.2), cmap='gist_gray')\n",
    "axs[0].scatter(vanilla_anchor[0], vanilla_anchor[1], c=\"r\")\n",
    "axs[0].scatter(random_z_train[:, 0], random_z_train[:, 1], c='g', alpha=0.5, s=0.5)\n",
    "axs[0].scatter(random_z_test[:, 0], random_z_test[:, 1], c='b', alpha=0.5, s=0.5)\n",
    "axs[0].plot(vanilla_geodesic_np[:, 0], vanilla_geodesic_np[:, 1], 'm')\n",
    "axs[0].axis([-4, 4, -4, 4])\n",
    "axs[0].set_title('Vanilla VAE')\n",
    "\n",
    "# Plot 2: GSVAE.\n",
    "axs[1].contour(latent_x, latent_y, np.log(gsindi), levels=np.arange(0, 5, 0.5), colors='k')\n",
    "axs[1].contourf(latent_x, latent_y, np.log(gsindi), extend=\"both\", levels=np.arange(0, 6, 0.2), cmap=\"gist_gray\")\n",
    "axs[1].scatter(gs_anchor[0], gs_anchor[1], c=\"r\")\n",
    "axs[1].scatter(random_gsz_train[:,0], random_gsz_train[:,1], c='g', alpha=0.5, s=0.5)\n",
    "axs[1].scatter(random_gsz_test[:,0], random_gsz_test[:,1], c='b', alpha=0.5, s=0.5)\n",
    "axs[1].plot(gs_geodesic_np[:, 0], gs_geodesic_np[:, 1], 'm')\n",
    "axs[1].axis([-4, 4, -4, 4])\n",
    "axs[1].set_title('VAE + geodesic symmetry')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the sampled GSVAE test latents by Euclidean distance to random_gsz_test[17].\n",
    "distances = np.linalg.norm(random_gsz_test - random_gsz_test[17], axis=1)\n",
    "dist_dict = dict(enumerate(distances))\n",
    "# List of (row index, distance) pairs, nearest first; ties keep index order.\n",
    "sorted_dist_dict = sorted(dist_dict.items(), key=lambda item: item[1])\n",
    "\n",
    "sorted_dist_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick three sampled GSVAE test latents (the 100th and 200th nearest to row 17, plus\n",
    "# row 17 itself), reflect each through the GSVAE anchor, and sample each learned curve.\n",
    "anchor = gsvae.gamma.anchor.detach().clone().cpu().numpy()\n",
    "\n",
    "point1 = random_gsz_test[sorted_dist_dict[100][0]]\n",
    "point2 = random_gsz_test[sorted_dist_dict[200][0]]\n",
    "point3 = random_gsz_test[17]\n",
    "\n",
    "# Straight point reflection through the anchor: p_ref = 2 * anchor - p.\n",
    "point1_ref = 2 * anchor - point1\n",
    "point2_ref = 2 * anchor - point2\n",
    "point3_ref = 2 * anchor - point3\n",
    "\n",
    "# NOTE(review): these names overwrite geo_list1 / geo_tensor1 / geo_np1 from the earlier\n",
    "# test_set[10] curve, so re-running that earlier GSVAE plot now shows point1's curve.\n",
    "geo_list1 = list()\n",
    "geo_list2 = list()\n",
    "geo_list3 = list()\n",
    "\n",
    "# Curve parameters for each point (unsqueeze(0) adds a batch axis).\n",
    "w11, w21, b11, b21 = gsvae.gamma(torch.tensor(point1).cuda().unsqueeze(0))\n",
    "w12, w22, b12, b22 = gsvae.gamma(torch.tensor(point2).cuda().unsqueeze(0))\n",
    "w13, w23, b13, b23 = gsvae.gamma(torch.tensor(point3).cuda().unsqueeze(0))\n",
    "\n",
    "\n",
    "gsvae.eval()\n",
    "# NOTE(review): `one` is (1, 1) here but (1,) in the other curve cells - confirm geodesic\n",
    "# treats both the same.\n",
    "one = torch.ones((1)).unsqueeze(0).cuda()\n",
    "\n",
    "# Sample t = 1, 0.98, ..., -0.98 (100 values) for all three curves.\n",
    "for i in range(100):\n",
    "    p = one * (50 - i) / 50\n",
    "    z1, _, _ = geodesic(p, w11, w21, b11, b21, gsvae.vae.encoder)\n",
    "    z2, _, _ = geodesic(p, w12, w22, b12, b22, gsvae.vae.encoder)\n",
    "    z3, _, _ = geodesic(p, w13, w23, b13, b23, gsvae.vae.encoder)\n",
    "    \n",
    "    geo_list1.append(z1)\n",
    "    geo_list2.append(z2)\n",
    "    geo_list3.append(z3)\n",
    "    \n",
    "geo_tensor1 = torch.stack(geo_list1, axis=0)\n",
    "geo_np1 = geo_tensor1.detach().cpu().numpy()\n",
    "\n",
    "geo_tensor2 = torch.stack(geo_list2, axis=0)\n",
    "geo_np2 = geo_tensor2.detach().cpu().numpy()\n",
    "\n",
    "geo_tensor3 = torch.stack(geo_list3, axis=0)\n",
    "geo_np3 = geo_tensor3.detach().cpu().numpy()\n",
    "\n",
    "# Last curve sample (t = -0.98): the curve-based counterpart of each straight reflection.\n",
    "point1_geo_symm = geo_np1[-1]\n",
    "point2_geo_symm = geo_np2[-1]\n",
    "point3_geo_symm = geo_np3[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_sample(dataset, idx):\n",
    "    \"\"\"Return (28x28 numpy image, feature entry) for dataset[idx], indexing the dataset once.\"\"\"\n",
    "    sample = dataset[idx]\n",
    "    return sample[0].numpy().reshape(28, 28), sample[1]\n",
    "\n",
    "\n",
    "# Images and feature entries of the same three test samples as point1-point3.\n",
    "image1, feature1 = load_sample(test_set, sorted_dist_dict[100][0])\n",
    "image2, feature2 = load_sample(test_set, sorted_dist_dict[200][0])\n",
    "image3, feature3 = load_sample(test_set, 17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nearest_train_sample(target):\n",
    "    \"\"\"Find the encoded training latent closest (Euclidean) to `target`.\n",
    "\n",
    "    Returns (index, 28x28 numpy image, feature entry, latent row gsz_train[index]).\n",
    "    Assumes gsz_train rows are aligned with train_set indices -- TODO confirm.\n",
    "    The dataset is indexed only once per lookup.\n",
    "    \"\"\"\n",
    "    idx = np.argmin(np.linalg.norm(gsz_train - target, axis=1))\n",
    "    sample = train_set[idx]\n",
    "    return idx, sample[0].numpy().reshape(28, 28), sample[1], gsz_train[idx]\n",
    "\n",
    "\n",
    "# Training samples nearest to each straight-line reflection through the anchor.\n",
    "nearest_point1_ref, image1_ref, feature_ref1, exact_point1_ref = nearest_train_sample(point1_ref)\n",
    "nearest_point2_ref, image2_ref, feature_ref2, exact_point2_ref = nearest_train_sample(point2_ref)\n",
    "nearest_point3_ref, image3_ref, feature_ref3, exact_point3_ref = nearest_train_sample(point3_ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas\n",
    "from morphomnist.measure import measure_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Morpho-MNIST measurements of the three test images and of the training images nearest\n",
    "# to their reflections, unpacked as (area, length, thickness, slant, width, height).\n",
    "morpho = [measure_image(img) for img in (image1, image2, image3, image1_ref, image2_ref, image3_ref)]\n",
    "\n",
    "area1, length1, thickness1, slant1, width1, height1 = morpho[0]\n",
    "area2, length2, thickness2, slant2, width2, height2 = morpho[1]\n",
    "area3, length3, thickness3, slant3, width3, height3 = morpho[2]\n",
    "\n",
    "area_ref1, length_ref1, thickness_ref1, slant_ref1, width_ref1, height_ref1 = morpho[3]\n",
    "area_ref2, length_ref2, thickness_ref2, slant_ref2, width_ref2, height_ref2 = morpho[4]\n",
    "area_ref3, length_ref3, thickness_ref3, slant_ref3, width_ref3, height_ref3 = morpho[5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.figure import Figure\n",
    "from matplotlib.offsetbox import OffsetImage, AnnotationBbox, VPacker, TextArea"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detached CPU numpy copies of the anchor tensors of the vanilla geodesic maker and of\n",
    "# the GS-VAE's gamma module.\n",
    "vanilla_anchor, gs_anchor = (\n",
    "    module.anchor.clone().detach().cpu().numpy()\n",
    "    for module in (geodesic_maker, gsvae.gamma)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GS-VAE latent space: log(gsindi) as a grey background, train (green) and test (blue)\n",
    "# latents, the three geodesics (magenta), the anchor (red), and straight segments from\n",
    "# each point (blue X) to its reflection through the anchor (green square).\n",
    "fig, ax = plt.subplots(figsize=(20,20))\n",
    "log_gsindi = np.log(gsindi)\n",
    "ax.contourf(latent_x, latent_y, log_gsindi, extend=\"both\", levels=np.arange(0, 6, 0.2), cmap='gist_gray')\n",
    "for latents, colour in ((random_gsz_train, 'g'), (random_gsz_test, 'b')):\n",
    "    ax.scatter(latents[:, 0], latents[:, 1], c=colour, alpha=0.3, s=10)\n",
    "\n",
    "# NOTE(review): 'm-' and 'm' (like 'k-' and 'k' below) both draw solid lines, so\n",
    "# curves 2 and 3 are visually indistinguishable -- confirm this is intended.\n",
    "for geo, fmt in zip((geo_np1, geo_np2, geo_np3), ('m-.', 'm-', 'm')):\n",
    "    ax.plot(geo[:, 0], geo[:, 1], fmt, linewidth=2)\n",
    "ax.scatter(gs_anchor[0], gs_anchor[1], c='r', s=100)\n",
    "ax.axis([-0.3, 0.3, -0.1, 0.5])\n",
    "\n",
    "points = (point1, point2, point3)\n",
    "reflections = (point1_ref, point2_ref, point3_ref)\n",
    "for start, end, fmt in zip(points, reflections, ('k-.', 'k-', 'k')):\n",
    "    ax.plot([start[0], end[0]], [start[1], end[1]], fmt, linewidth=2)\n",
    "for pt in points:\n",
    "    ax.scatter(pt[0], pt[1], c='b', s=100, marker='X')\n",
    "for pt in reflections:\n",
    "    ax.scatter(pt[0], pt[1], c='g', s=100, marker='s')\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "fig.savefig('gs_vae_geodesic.png', dpi=100, bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotated latent-space figure: each test point (blue X) is joined to its straight-line\n",
    "# reflection through the anchor (green square); each test sample and the training sample\n",
    "# nearest to its reflection are shown with their Morpho-MNIST thickness and slant.\n",
    "fig, ax = plt.subplots(figsize=(20,20))\n",
    "log_gsindi = np.log(gsindi)\n",
    "ax.contour(latent_x, latent_y, log_gsindi, extend=\"both\", levels=[0,2.5], colors='k')\n",
    "ax.contourf(latent_x, latent_y, log_gsindi, extend=\"both\", levels=np.arange(0, 6, 0.2), cmap='gist_gray')\n",
    "for latents, colour in ((random_gsz_train, 'g'), (random_gsz_test, 'b')):\n",
    "    ax.scatter(latents[:, 0], latents[:, 1], c=colour, alpha=0.3, s=10)\n",
    "ax.scatter(gs_anchor[0], gs_anchor[1], c='r', s=100)\n",
    "ax.axis([-0.4, 0.6, -0.3, 0.6])\n",
    "\n",
    "points = (point1, point2, point3)\n",
    "reflections = (point1_ref, point2_ref, point3_ref)\n",
    "# NOTE(review): 'k-' and 'k' both draw solid lines, so segments 2 and 3 look identical.\n",
    "for start, end, fmt in zip(points, reflections, ('k-.', 'k-', 'k')):\n",
    "    ax.plot([start[0], end[0]], [start[1], end[1]], fmt, linewidth=2)\n",
    "for pt in points:\n",
    "    ax.scatter(pt[0], pt[1], c='b', s=100, marker='X')\n",
    "for pt in reflections:\n",
    "    ax.scatter(pt[0], pt[1], c='g', s=100, marker='s')\n",
    "\n",
    "\n",
    "def image_card(image, label, thickness, slant):\n",
    "    \"\"\"Stack the digit image (zoom 4) above a caption with its label, thickness and slant.\"\"\"\n",
    "    caption = TextArea(f\"{label}\\nThickness: {thickness:.3f}\\nSlant: {slant:.3f}\",\n",
    "                       textprops=dict(color='black', fontsize=20))\n",
    "    return VPacker(children=[OffsetImage(image, zoom=4), caption])\n",
    "\n",
    "\n",
    "def arrow_box(card, xy, xybox):\n",
    "    \"\"\"Place `card` at data coordinates `xybox`, with an arrow pointing at `xy`.\"\"\"\n",
    "    return AnnotationBbox(card, xy, xybox=xybox, arrowprops=dict(arrowstyle=\"-|>\", lw=0.5), frameon=False)\n",
    "\n",
    "\n",
    "# Test samples (primed labels) on the right, arrows to their latents; training samples\n",
    "# nearest to the reflections (unprimed labels) on the left, arrows to their exact latents.\n",
    "cards = [\n",
    "    (image1, \"B'\", thickness1, slant1, point1, [0.5, 0.15]),\n",
    "    (image2, \"A'\", thickness2, slant2, point2, [0.5, -0.1]),\n",
    "    (image3, \"C'\", thickness3, slant3, point3, [0.5, 0.4]),\n",
    "    (image1_ref, \"B\", thickness_ref1, slant_ref1, exact_point1_ref, [-0.3, 0.15]),\n",
    "    (image2_ref, \"A\", thickness_ref2, slant_ref2, exact_point2_ref, [-0.3, 0.4]),\n",
    "    (image3_ref, \"C\", thickness_ref3, slant_ref3, exact_point3_ref, [-0.3, -0.1]),\n",
    "]\n",
    "for image, label, thickness, slant, xy, xybox in cards:\n",
    "    ax.add_artist(arrow_box(image_card(image, label, thickness, slant), xy, xybox))\n",
    "\n",
    "summary_box = VPacker(children=[\n",
    "    TextArea(\"Thickness\\nA ⇔ B ~ C → A\\' ⇔ B\\' ~ C\\'\", textprops=dict(color='black', fontsize=20)),\n",
    "    TextArea(\"Slant\\nA ~ B ⇔ C → A\\' ~ B\\' ⇔ C\\'\", textprops=dict(color='black', fontsize=20)),\n",
    "])\n",
    "ax.add_artist(AnnotationBbox(summary_box, [0.1, 0.55], frameon=True,\n",
    "                             bboxprops=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.5')))\n",
    "\n",
    "ax.axis('off')\n",
    "\n",
    "fig.savefig('gs_vae_symm.png', dpi=100, bbox_inches='tight', pad_inches=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "disentanglement_cuda_10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
