{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sys, os\n",
    "\n",
    "import imageio\n",
    "import pickle\n",
    "from geomstats import visualization\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn import datasets\n",
    "import copy\n",
    "\n",
    "import io\n",
    "from PIL import Image\n",
    "import imageio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "###############################################################################\n",
    "########################## Geometric Tools in Sphere ##########################\n",
    "###############################################################################\n",
    "###############################################################################\n",
    "\n",
    "class Sphere_manifold:\n",
    "    def __init__(self):\n",
    "        pass\n",
    "    \n",
    "    def to_extrinsic(self, z):\n",
    "        theta = z[:, 0:1]\n",
    "        phi = z[:, 1:2]\n",
    "        return torch.cat([\n",
    "                    torch.sin(theta) * torch.cos(phi),\n",
    "                    torch.sin(theta) * torch.sin(phi),\n",
    "                    torch.cos(theta)\n",
    "                ], dim=1)\n",
    "\n",
    "    def to_intrinsic(self, x):\n",
    "        theta = torch.acos(torch.clip(x[:, 2:3], min=-1, max=1))\n",
    "        phi = torch.atan2(x[:, 1:2], x[:, 0:1])\n",
    "        return torch.cat([theta, phi], dim=1)\n",
    "\n",
    "    def squared_geodesic_distance(self, z1, z2, output_extrinsic=True):\n",
    "        x1 = self.to_extrinsic(z1)\n",
    "        x2 = self.to_extrinsic(z2)\n",
    "        eps = 1.0e-6\n",
    "        if output_extrinsic:\n",
    "            return x1, x2, torch.arccos(torch.clip((x1*x2).sum(dim=1), min=-1 + eps, max=1 - eps))**2\n",
    "        else:\n",
    "            return torch.arccos(torch.clip((x1*x2).sum(dim=1), min=-1 + eps, max=1 - eps))**2 \n",
    "\n",
    "    def Riemannian_metric(self, z):\n",
    "        theta = z[:, 0]\n",
    "        eps = 1.0e-6\n",
    "        sintheta = torch.clip((torch.sin(theta)**2).unsqueeze(1).unsqueeze(1), min=eps, max=1)\n",
    "        return torch.cat([\n",
    "                    torch.cat([torch.ones((len(z), 1, 1)).to(z), torch.zeros((len(z), 1, 1)).to(z)], dim=2),\n",
    "                    torch.cat([torch.zeros((len(z), 1, 1)).to(z), sintheta], dim=2)\n",
    "                ], dim=1)\n",
    "\n",
    "    def get_derivative_of_detG(self, z):\n",
    "        theta = z[:, 0:1]\n",
    "        return torch.cat([(2 * torch.sin(theta) * torch.cos(theta)), torch.zeros((len(z), 1)).to(z)], dim=1)\n",
    "\n",
    "    def get_inv_root_G(self, z):\n",
    "        theta = z[:, 0]\n",
    "        eps = 1.0e-6\n",
    "        sintheta = torch.clip((torch.abs(torch.sin(theta))).unsqueeze(1).unsqueeze(1), min=eps, max=1)\n",
    "        return torch.cat([\n",
    "                    torch.cat([torch.ones((len(z), 1, 1)).to(z), torch.zeros((len(z), 1, 1)).to(z)], dim=2),\n",
    "                    torch.cat([torch.zeros((len(z), 1, 1)).to(z), 1/sintheta], dim=2)\n",
    "                ], dim=1)\n",
    "\n",
    "    def get_inv_G(self, z):\n",
    "        theta = z[:, 0]\n",
    "        eps = 1.0e-6\n",
    "        sintheta = torch.clip((torch.abs(torch.sin(theta))).unsqueeze(1).unsqueeze(1), min=eps, max=1)\n",
    "        return torch.cat([\n",
    "                    torch.cat([torch.ones((len(z), 1, 1)).to(z), torch.zeros((len(z), 1, 1)).to(z)], dim=2),\n",
    "                    torch.cat([torch.zeros((len(z), 1, 1)).to(z), 1/(sintheta**2)], dim=2)\n",
    "                ], dim=1)\n",
    "\n",
    "    def exponential_map(self, x, v, t=1):\n",
    "        term1 = torch.cos(torch.norm(v, dim=1) * t).view(len(x), 1) * x\n",
    "        term2 = torch.sin(torch.norm(v, dim=1) * t).view(len(x), 1) * v/torch.norm(v, dim=1).view(len(v), 1)\n",
    "        return term1 + term2\n",
    "    \n",
    "    def project_to_tangentSpace(self, x, e):\n",
    "        return e - x*torch.sum(x*e,axis=1).view(-1,1)\n",
    "\n",
    "    def logarithm_map(self, x, x_, eps=1.0e-6):\n",
    "        temp = x_ - (x*x_).sum(dim=1, keepdim=True)*x\n",
    "        return torch.sqrt(\n",
    "            torch.acos(torch.clip((x*x_).sum(dim=1, keepdim=True), min=-1 + eps, max=1 - eps))**2\n",
    "            ) * temp/torch.norm(temp, dim=1, keepdim=True)\n",
    "\n",
    "    def Gamma(self, x):\n",
    "        theta = x[:, 0:1]\n",
    "        sin = torch.sin(theta)\n",
    "        cos = torch.cos(theta)\n",
    "        temp = sin**3 * cos\n",
    "        return torch.cat([temp, torch.zeros_like(temp)], dim=1)\n",
    "    \n",
    "def get_manifold(manifold):\n",
    "    if manifold == \"S2\":\n",
    "        return Sphere_manifold()\n",
    "    \n",
    "def jacobian_decoder_jvp_parallel(func, inputs, v=None, create_graph=True):\n",
    "    batch_size, z_dim = inputs.size()\n",
    "    if v is None:\n",
    "        v = torch.eye(z_dim).unsqueeze(0).repeat(batch_size, 1, 1).view(-1, z_dim).to(inputs)\n",
    "    inputs = inputs.repeat(1, z_dim).view(-1, z_dim)\n",
    "    jac = (\n",
    "        torch.autograd.functional.jvp(\n",
    "            func, inputs, v=v, create_graph=create_graph\n",
    "        )[1].view(batch_size, z_dim, -1).permute(0, 2, 1)\n",
    "    )\n",
    "    return jac\n",
    "\n",
    "def Euclidean_langevin_sampler(grad_logp_func, init_points, step_size=0.01, iter=100, ambient=False):\n",
    "    if not ambient:\n",
    "        sampled_points = init_points\n",
    "        for _ in range(iter):\n",
    "            sampled_points += step_size/2 * grad_logp_func(sampled_points, detach=True) \\\n",
    "                + np.sqrt(step_size) * torch.randn_like(sampled_points)\n",
    "    else:\n",
    "        sampled_points = init_points\n",
    "        for _ in range(iter):\n",
    "            sampled_points += step_size/2 * grad_logp_func(sampled_points, detach=True, ambient=ambient) \\\n",
    "                + np.sqrt(step_size) * torch.randn_like(sampled_points)\n",
    "            sampled_points = sampled_points/torch.norm(sampled_points, dim=1, keepdim=True)\n",
    "    return sampled_points\n",
    "\n",
    "def Riemannian_langevin_sampler(grad_logp_g_func, manifold, init_points, step_size=0.01, iter=100, curvature='constant', type=1, ambient=False):\n",
    "    if curvature == 'constant':\n",
    "        if not ambient:\n",
    "            sampled_points = init_points\n",
    "            for _ in range(iter):\n",
    "                G = manifold.Riemannian_metric(sampled_points)\n",
    "                invG = torch.inverse(G)\n",
    "                term1 = torch.einsum(\n",
    "                    'nij, nj -> ni', \n",
    "                    invG, \n",
    "                    grad_logp_g_func(sampled_points, detach=True) + \\\n",
    "                        manifold.get_derivative_of_detG(sampled_points)/(2*torch.det(G).view(len(G), 1)))\n",
    "                inv_root_G = manifold.get_inv_root_G(sampled_points)\n",
    "                term2 = torch.einsum('nij, nj -> ni', inv_root_G, torch.randn_like(sampled_points))\n",
    "                sampled_points += step_size/2 * term1 + np.sqrt(step_size) * term2\n",
    "        else:\n",
    "            sampled_points = init_points\n",
    "            for _ in range(iter):\n",
    "                term1 = grad_logp_g_func(sampled_points, detach=True, ambient=ambient)\n",
    "                term2 = manifold.project_to_tangentSpace(sampled_points, torch.randn_like(sampled_points))\n",
    "                sampled_points = manifold.exponential_map(sampled_points, step_size/2 * term1 + np.sqrt(step_size) * term2)\n",
    "                sampled_points = sampled_points/torch.norm(sampled_points, dim=1, keepdim=True)\n",
    "        return sampled_points    \n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "    \n",
    "class Sphere:\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def visualize(self, list_data, mode='extrinsic', view_init=(70, 70), axis='off', list_point_draw_kwargs=[{}]):\n",
    "        if mode == 'extrinsic':\n",
    "            fig = plt.figure()\n",
    "            ax = fig.add_subplot(111, projection=\"3d\")\n",
    "            ax_s = 1.2\n",
    "            plt.setp(\n",
    "                ax,\n",
    "                xlim=(-ax_s, ax_s),\n",
    "                ylim=(-ax_s, ax_s),\n",
    "                zlim=(-ax_s, ax_s),\n",
    "                xlabel=\"X\",\n",
    "                ylabel=\"Y\",\n",
    "                zlabel=\"Z\",\n",
    "            )\n",
    "            ax.set_box_aspect([1.0, 1.0, 1.0])\n",
    "            ax.view_init(*view_init)\n",
    "            ax.axis(axis)\n",
    "            for data, point_draw_kwargs in zip(list_data, list_point_draw_kwargs):\n",
    "                visualization.plot(\n",
    "                    data,\n",
    "                    ax=ax,\n",
    "                    space='S2',\n",
    "                    **point_draw_kwargs\n",
    "                )\n",
    "            plt.close()\n",
    "            return fig\n",
    "        elif mode == 'intrinsic-spherical':\n",
    "            fig, ax = plt.subplots()\n",
    "            for data, point_draw_kwargs in zip(list_data, list_point_draw_kwargs):\n",
    "                ax.scatter(data[:,1], data[:,0], **point_draw_kwargs)\n",
    "            plt.xlim(-np.pi, np.pi)\n",
    "            plt.ylim(0, np.pi)\n",
    "            if axis == 'on':\n",
    "                plt.xlabel('phi')\n",
    "                plt.ylabel('theta')\n",
    "            else:\n",
    "                plt.axis('off')\n",
    "            plt.gca().invert_yaxis()\n",
    "            plt.close()\n",
    "            return fig\n",
    "    \n",
    "    def to_intrinsic(self, x, type='spherical'):\n",
    "        if type == 'spherical':\n",
    "            eps = 1.0e-6\n",
    "            theta = np.arccos(np.clip(x[:, 2], a_min=-1+eps, a_max=1-eps)).reshape(-1, 1)\n",
    "            phi = np.arctan2(x[:, 1], x[:, 0])\n",
    "            phi = phi.reshape(-1, 1)\n",
    "            return np.hstack([theta, phi])\n",
    "\n",
    "    def to_extrinsic(self, z, type='spherical'):\n",
    "        if type == 'spherical':\n",
    "            sintheta = np.sin(z[:, 0:1])\n",
    "            costheta = np.cos(z[:, 0:1])\n",
    "            sinphi = np.sin(z[:, 1:2])\n",
    "            cosphi = np.cos(z[:, 1:2])\n",
    "            x = np.hstack([\n",
    "                cosphi * sintheta,\n",
    "                sinphi * sintheta,\n",
    "                costheta\n",
    "            ])\n",
    "            return x/np.linalg.norm(x, axis=1).reshape(len(x), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "###############################################################################\n",
    "########################## Autoencoder Models #################################\n",
    "###############################################################################\n",
    "###############################################################################\n",
    "\n",
    "class AE(nn.Module):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(AE, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def encode(self, x):\n",
    "        return self.encoder(x)\n",
    "\n",
    "    def decode(self, z):\n",
    "        return self.decoder(z)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encode(x)\n",
    "        recon = self.decode(z)\n",
    "        return recon\n",
    "\n",
    "    def train_step(self, x, optimizer, **kwargs):\n",
    "        optimizer.zero_grad()\n",
    "        recon = self(x)\n",
    "        loss = ((recon - x) ** 2).view(len(x), -1).mean(dim=1).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        return {\"loss\": loss.item()}\n",
    "    \n",
    "    def validation_step(self, x, **kwargs):\n",
    "        recon = self(x)\n",
    "        loss = ((recon - x) ** 2).view(len(x), -1).mean(dim=1).mean()\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "class GRCAE(AE):\n",
    "    \"\"\"Autoencoder on intrinsic (theta, phi) sphere coordinates with a Riemannian\n",
    "    contractive penalty.\n",
    "\n",
    "    Training minimises the squared geodesic reconstruction error plus\n",
    "    sigma^2 * tr(J^T G(r(z)) J G(z)^{-1}), where r(z) = decode(encode(z)) and\n",
    "    J = dr/dz. The scaled residual G(z) (r(z) - z) / sigma^2 serves as the score\n",
    "    estimate used by validation and Langevin sampling.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, encoder, decoder, sigma=0.001, manifold=\"S2\", sampling_type=1\n",
    "    ):\n",
    "        super(GRCAE, self).__init__(encoder, decoder)\n",
    "        self.sigma = sigma\n",
    "        self.manifold = get_manifold(manifold)\n",
    "        self.sampling_type = sampling_type\n",
    "\n",
    "    def train_step(self, z, optimizer, **kwargs):\n",
    "        \"\"\"One optimisation step on intrinsic points z of shape (B, 2).\"\"\"\n",
    "        optimizer.zero_grad()\n",
    "        recon = self(z)\n",
    "        recon_loss = self.manifold.squared_geodesic_distance(z, recon, output_extrinsic=False)\n",
    "        invG1 = self.manifold.get_inv_G(z)\n",
    "        G2 = self.manifold.Riemannian_metric(recon)\n",
    "        jac = jacobian_decoder_jvp_parallel(self, z)\n",
    "        # per-sample tr(J^T G(recon) J G(z)^{-1})\n",
    "        rc_loss = torch.einsum('nij, nik, nkl, nlj -> n', jac, G2, jac, invG1)\n",
    "        loss = recon_loss.mean() + self.sigma**2 * rc_loss.mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        return {\"loss\": loss.item(), \"recon_loss\": recon_loss.mean().item(), \"rc_loss\": rc_loss.mean().item()}\n",
    "\n",
    "    def validation_step(self, z, **kwargs):\n",
    "        \"\"\"Objective mean(s^T G^{-1} (s + 2 Gamma)) + 2 mean(tr(dr/dz - I) / sigma^2),\n",
    "        where s is the score estimate from gradient_log_rho_g.\n",
    "        \"\"\"\n",
    "        score = self.gradient_log_rho_g(z) # bs x 2\n",
    "        invG = self.manifold.get_inv_G(z) # bs x 2 x 2\n",
    "\n",
    "        Gamma = self.manifold.Gamma(z) # bs x 2\n",
    "        dr_dz = jacobian_decoder_jvp_parallel(self, z) # bs x 2 x 2\n",
    "\n",
    "        term1 = torch.einsum(\n",
    "            'ni, nij, nj -> n',\n",
    "            score,\n",
    "            invG,\n",
    "            score + 2*Gamma\n",
    "            ).mean()\n",
    "        temp_term = (dr_dz - torch.eye(z.size(1)).unsqueeze(0).to(z))\n",
    "        term2 = (temp_term.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)/self.sigma**2).mean()\n",
    "\n",
    "        loss = term1 + 2*term2\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def gradient_log_rho_g(self, z, detach=True):\n",
    "        \"\"\"Score estimate G(z) (r(z) - z) / sigma^2 in intrinsic coordinates, shape (B, 2).\n",
    "\n",
    "        detach=True (default) stops gradients through the reconstruction.\n",
    "        The phi component of r(z) - z is wrapped to [-pi, pi]: the geodesic\n",
    "        training loss is 2*pi-periodic in phi, so the raw difference can be off\n",
    "        by a multiple of 2*pi near the +-pi seam.\n",
    "        \"\"\"\n",
    "        G = self.manifold.Riemannian_metric(z)\n",
    "        recon = self(z)\n",
    "        if detach:\n",
    "            recon = recon.detach()\n",
    "        diff = recon - z\n",
    "        dphi = diff[:, 1:2]\n",
    "        # subtract the nearest multiple of 2*pi; leaves |dphi| < pi exactly untouched\n",
    "        two_pi = 2 * np.pi\n",
    "        dphi = dphi - torch.round(dphi / two_pi) * two_pi\n",
    "        diff = torch.cat([diff[:, 0:1], dphi], dim=1)\n",
    "        return torch.einsum('nij, nj -> ni', G, diff/self.sigma**2)\n",
    "\n",
    "    def sample(self, inits, step_size=0.00001, iter=1000):\n",
    "        \"\"\"Run intrinsic Riemannian Langevin dynamics from inits (B, 2) with the learned score.\"\"\"\n",
    "        return Riemannian_langevin_sampler(\n",
    "            self.gradient_log_rho_g,\n",
    "            self.manifold,\n",
    "            inits,\n",
    "            step_size=step_size,\n",
    "            iter=iter,\n",
    "            curvature='constant',\n",
    "            type=self.sampling_type\n",
    "        )\n",
    "\n",
    "class ambientGRCAE(GRCAE):\n",
    "    \"\"\"GRCAE variant whose encoder/decoder act on extrinsic unit vectors (B, 3).\n",
    "\n",
    "    Training uses the arccos reconstruction error and a contractive penalty\n",
    "    restricted to the tangent space at x; validation and the non-ambient score\n",
    "    are evaluated in intrinsic (theta, phi) coordinates.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, encoder, decoder, sigma=0.001, manifold=\"S2\", sampling_type=1\n",
    "    ):\n",
    "        # GRCAE.__init__ stores sigma, the manifold helper and sampling_type.\n",
    "        super(ambientGRCAE, self).__init__(\n",
    "            encoder, decoder, sigma=sigma, manifold=manifold, sampling_type=sampling_type\n",
    "        )\n",
    "\n",
    "    def train_step(self, x, optimizer, **kwargs):\n",
    "        \"\"\"One optimisation step on extrinsic points x of shape (B, 3).\"\"\"\n",
    "        optimizer.zero_grad()\n",
    "        recon = self(x)\n",
    "        eps = 1.0e-6\n",
    "        recon_loss = torch.acos(torch.clip((x*recon).sum(dim=1), min=-1 + eps, max=1 - eps))**2\n",
    "        jac = jacobian_decoder_jvp_parallel(self, x)\n",
    "        # ||J (I - x x^T)||_F^2 = ||J||_F^2 - ||J x||^2 for unit x: only tangent directions count\n",
    "        dr_dx_multipy_x = torch.matmul(jac, x.unsqueeze(-1))\n",
    "        rc_loss = (jac**2).sum(-1).sum(-1) - (dr_dx_multipy_x**2).sum(-1).sum(-1)\n",
    "        loss = recon_loss.mean() + self.sigma**2 * rc_loss.mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        return {\"loss\": loss.item(), \"recon_loss\": recon_loss.mean().item(), \"rc_loss\": rc_loss.mean().item()}\n",
    "\n",
    "    def validation_step(self, x, **kwargs):\n",
    "        \"\"\"Intrinsic validation objective at z = to_intrinsic(x), using\n",
    "        r(z) = to_intrinsic(self(to_extrinsic(z))).\n",
    "        \"\"\"\n",
    "        z = self.manifold.to_intrinsic(x)\n",
    "        score = self.gradient_log_rho_g(z, ambient=False) # bs x 2\n",
    "        invG = self.manifold.get_inv_G(z) # bs x 2 x 2\n",
    "\n",
    "        Gamma = self.manifold.Gamma(z) # bs x 2\n",
    "        def r(z):\n",
    "            # intrinsic -> extrinsic -> network -> intrinsic\n",
    "            x = self.manifold.to_extrinsic(z)\n",
    "            return self.manifold.to_intrinsic(self(x))\n",
    "        dr_dz = jacobian_decoder_jvp_parallel(r, z) # bs x 2 x 2\n",
    "\n",
    "        term1 = torch.einsum(\n",
    "            'ni, nij, nj -> n',\n",
    "            score,\n",
    "            invG,\n",
    "            score + 2*Gamma\n",
    "            ).mean()\n",
    "        temp_term = (dr_dz - torch.eye(z.size(1)).unsqueeze(0).to(x))\n",
    "        term2 = (temp_term.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)/self.sigma**2).mean()\n",
    "\n",
    "        loss = term1 + 2*term2\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def gradient_log_rho_g(self, z, detach=True, ambient=False):\n",
    "        \"\"\"Score estimate scaled by 1 / sigma^2.\n",
    "\n",
    "        ambient=False: z is intrinsic (B, 2); returns G(z) (r(z) - z) / sigma^2 with\n",
    "        the phi difference wrapped to [-pi, pi] (atan2 outputs jump by 2*pi at\n",
    "        the +-pi seam). ambient=True: z is extrinsic (B, 3); returns\n",
    "        log_z(self(z)) / sigma^2. detach=True stops gradients through self(.).\n",
    "        \"\"\"\n",
    "        if ambient:\n",
    "            recon = self(z)\n",
    "            if detach:\n",
    "                recon = recon.detach()\n",
    "            return self.manifold.logarithm_map(z, recon)/self.sigma**2\n",
    "        G = self.manifold.Riemannian_metric(z)\n",
    "        x = self.manifold.to_extrinsic(z)\n",
    "        recon_z = self.manifold.to_intrinsic(self(x))\n",
    "        if detach:\n",
    "            recon_z = recon_z.detach()\n",
    "        diff = recon_z - z\n",
    "        dphi = diff[:, 1:2]\n",
    "        # subtract the nearest multiple of 2*pi; leaves |dphi| < pi exactly untouched\n",
    "        two_pi = 2 * np.pi\n",
    "        dphi = dphi - torch.round(dphi / two_pi) * two_pi\n",
    "        diff = torch.cat([diff[:, 0:1], dphi], dim=1)\n",
    "        return torch.einsum('nij, nj -> ni', G, diff/self.sigma**2)\n",
    "\n",
    "    def sample(self, inits, step_size=0.00001, iter=1000):\n",
    "        \"\"\"Run ambient Riemannian Langevin dynamics and return intrinsic (theta, phi).\n",
    "\n",
    "        inits go straight to the ambient sampler, so they are presumably extrinsic\n",
    "        unit vectors (B, 3); convert with manifold.to_extrinsic first if needed.\n",
    "        \"\"\"\n",
    "        outputs = Riemannian_langevin_sampler(\n",
    "            self.gradient_log_rho_g,\n",
    "            self.manifold,\n",
    "            inits,\n",
    "            step_size=step_size,\n",
    "            iter=iter,\n",
    "            curvature='constant',\n",
    "            type=self.sampling_type,\n",
    "            ambient=True\n",
    "        )\n",
    "        return self.manifold.to_intrinsic(outputs)\n",
    "        \n",
    "class GDAE(AE):\n",
    "    \"\"\"Denoising autoencoder on intrinsic (theta, phi) sphere coordinates.\n",
    "\n",
    "    Inputs are perturbed by a tangent Gaussian step of scale sigma pushed\n",
    "    through the exponential map; the network is trained to map the noisy point\n",
    "    back to the clean one under the squared geodesic distance.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, encoder, decoder, sigma=0.001, manifold=\"S2\", sampling_type=1\n",
    "    ):\n",
    "        super(GDAE, self).__init__(encoder, decoder)\n",
    "        self.sigma = sigma\n",
    "        self.manifold = get_manifold(manifold)\n",
    "        self.sampling_type = sampling_type\n",
    "\n",
    "    def train_step(self, z, optimizer, **kwargs):\n",
    "        \"\"\"One denoising step on intrinsic points z of shape (B, 2).\"\"\"\n",
    "        optimizer.zero_grad()\n",
    "        x1 = self.manifold.to_extrinsic(z)\n",
    "        v1 = self.manifold.project_to_tangentSpace(x1, self.sigma * torch.randn_like(x1))\n",
    "        x1n = self.manifold.exponential_map(x1, v1)\n",
    "        # renormalise to remove numerical drift off the sphere\n",
    "        x1n = x1n/torch.norm(x1n, dim=1).view(len(x1n), 1)\n",
    "        z1n = self.manifold.to_intrinsic(x1n)\n",
    "\n",
    "        recon = self(z1n)\n",
    "        loss = self.manifold.squared_geodesic_distance(z, recon, output_extrinsic=False).mean()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def validation_step(self, z, **kwargs):\n",
    "        \"\"\"Objective mean(s^T G^{-1} (s + 2 Gamma)) + 2 mean(tr(dr/dz - I) / sigma^2),\n",
    "        where s is the score estimate from gradient_log_rho_g.\n",
    "        \"\"\"\n",
    "        score = self.gradient_log_rho_g(z) # bs x 2\n",
    "        invG = self.manifold.get_inv_G(z) # bs x 2 x 2\n",
    "\n",
    "        Gamma = self.manifold.Gamma(z) # bs x 2\n",
    "        dr_dz = jacobian_decoder_jvp_parallel(self, z) # bs x 2 x 2\n",
    "\n",
    "        term1 = torch.einsum(\n",
    "            'ni, nij, nj -> n',\n",
    "            score,\n",
    "            invG,\n",
    "            score + 2*Gamma\n",
    "            ).mean()\n",
    "        temp_term = (dr_dz - torch.eye(z.size(1)).unsqueeze(0).to(z))\n",
    "        term2 = (temp_term.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)/self.sigma**2).mean()\n",
    "\n",
    "        loss = term1 + 2*term2\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def gradient_log_rho_g(self, z, detach=True):\n",
    "        \"\"\"Score estimate G(z) (r(z) - z) / sigma^2 in intrinsic coordinates, shape (B, 2).\n",
    "\n",
    "        detach=True (default) stops gradients through the reconstruction.\n",
    "        The phi component of r(z) - z is wrapped to [-pi, pi]: the geodesic\n",
    "        training loss is 2*pi-periodic in phi, so the raw difference can be off\n",
    "        by a multiple of 2*pi near the +-pi seam.\n",
    "        \"\"\"\n",
    "        G = self.manifold.Riemannian_metric(z)\n",
    "        recon = self(z)\n",
    "        if detach:\n",
    "            recon = recon.detach()\n",
    "        diff = recon - z\n",
    "        dphi = diff[:, 1:2]\n",
    "        # subtract the nearest multiple of 2*pi; leaves |dphi| < pi exactly untouched\n",
    "        two_pi = 2 * np.pi\n",
    "        dphi = dphi - torch.round(dphi / two_pi) * two_pi\n",
    "        diff = torch.cat([diff[:, 0:1], dphi], dim=1)\n",
    "        return torch.einsum('nij, nj -> ni', G, diff/self.sigma**2)\n",
    "\n",
    "    def sample(self, inits, step_size=0.00001, iter=1000):\n",
    "        \"\"\"Run intrinsic Riemannian Langevin dynamics from inits (B, 2) with the learned score.\"\"\"\n",
    "        return Riemannian_langevin_sampler(\n",
    "            self.gradient_log_rho_g,\n",
    "            self.manifold,\n",
    "            inits,\n",
    "            step_size=step_size,\n",
    "            iter=iter,\n",
    "            curvature='constant',\n",
    "            type=self.sampling_type\n",
    "        )\n",
    "\n",
    "class ambientGDAE(GDAE):\n",
    "    \"\"\"GDAE variant whose encoder/decoder act on extrinsic unit vectors (B, 3).\n",
    "\n",
    "    Training denoises tangent-Gaussian perturbations under the arccos distance;\n",
    "    validation and the non-ambient score are evaluated in (theta, phi).\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, encoder, decoder, sigma=0.001, manifold=\"S2\", sampling_type=1\n",
    "    ):\n",
    "        # GDAE.__init__ stores sigma, the manifold helper and sampling_type.\n",
    "        super(ambientGDAE, self).__init__(\n",
    "            encoder, decoder, sigma=sigma, manifold=manifold, sampling_type=sampling_type\n",
    "        )\n",
    "\n",
    "    def train_step(self, x, optimizer, **kwargs):\n",
    "        \"\"\"One denoising step on extrinsic points x of shape (B, 3).\"\"\"\n",
    "        optimizer.zero_grad()\n",
    "        x1 = x\n",
    "        v1 = self.manifold.project_to_tangentSpace(x1, self.sigma * torch.randn_like(x1))\n",
    "        x1n = self.manifold.exponential_map(x1, v1)\n",
    "        # renormalise to remove numerical drift off the sphere\n",
    "        x1n = x1n/torch.norm(x1n, dim=1).view(len(x1n), 1)\n",
    "\n",
    "        recon = self(x1n)\n",
    "        eps = 1.0e-6\n",
    "        loss = (torch.acos(torch.clip((x1*recon).sum(dim=1), min=-1 + eps, max=1 - eps))**2).mean()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def validation_step(self, x, **kwargs):\n",
    "        \"\"\"Intrinsic validation objective at z = to_intrinsic(x), using\n",
    "        r(z) = to_intrinsic(self(to_extrinsic(z))).\n",
    "        \"\"\"\n",
    "        z = self.manifold.to_intrinsic(x)\n",
    "        score = self.gradient_log_rho_g(z, ambient=False) # bs x 2\n",
    "        invG = self.manifold.get_inv_G(z) # bs x 2 x 2\n",
    "\n",
    "        Gamma = self.manifold.Gamma(z) # bs x 2\n",
    "        def r(z):\n",
    "            # intrinsic -> extrinsic -> network -> intrinsic\n",
    "            x = self.manifold.to_extrinsic(z)\n",
    "            return self.manifold.to_intrinsic(self(x))\n",
    "        dr_dz = jacobian_decoder_jvp_parallel(r, z) # bs x 2 x 2\n",
    "\n",
    "        term1 = torch.einsum(\n",
    "            'ni, nij, nj -> n',\n",
    "            score,\n",
    "            invG,\n",
    "            score + 2*Gamma\n",
    "            ).mean()\n",
    "        temp_term = (dr_dz - torch.eye(z.size(1)).unsqueeze(0).to(x))\n",
    "        term2 = (temp_term.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)/self.sigma**2).mean()\n",
    "\n",
    "        loss = term1 + 2*term2\n",
    "        return {\"loss\": loss.item()}\n",
    "\n",
    "    def gradient_log_rho_g(self, z, detach=True, ambient=False):\n",
    "        \"\"\"Score estimate scaled by 1 / sigma^2.\n",
    "\n",
    "        ambient=False: z is intrinsic (B, 2); returns G(z) (r(z) - z) / sigma^2 with\n",
    "        the phi difference wrapped to [-pi, pi] (atan2 outputs jump by 2*pi at\n",
    "        the +-pi seam). ambient=True: z is extrinsic (B, 3); returns\n",
    "        log_z(self(z)) / sigma^2. detach=True stops gradients through self(.).\n",
    "        \"\"\"\n",
    "        if ambient:\n",
    "            recon = self(z)\n",
    "            if detach:\n",
    "                recon = recon.detach()\n",
    "            return self.manifold.logarithm_map(z, recon)/self.sigma**2\n",
    "        G = self.manifold.Riemannian_metric(z)\n",
    "        x = self.manifold.to_extrinsic(z)\n",
    "        recon_z = self.manifold.to_intrinsic(self(x))\n",
    "        if detach:\n",
    "            recon_z = recon_z.detach()\n",
    "        diff = recon_z - z\n",
    "        dphi = diff[:, 1:2]\n",
    "        # subtract the nearest multiple of 2*pi; leaves |dphi| < pi exactly untouched\n",
    "        two_pi = 2 * np.pi\n",
    "        dphi = dphi - torch.round(dphi / two_pi) * two_pi\n",
    "        diff = torch.cat([diff[:, 0:1], dphi], dim=1)\n",
    "        return torch.einsum('nij, nj -> ni', G, diff/self.sigma**2)\n",
    "\n",
    "    def sample(self, inits, step_size=0.00001, iter=1000):\n",
    "        \"\"\"Run ambient Riemannian Langevin dynamics and return intrinsic (theta, phi).\n",
    "\n",
    "        inits go straight to the ambient sampler, so they are presumably extrinsic\n",
    "        unit vectors (B, 3); convert with manifold.to_extrinsic first if needed.\n",
    "        \"\"\"\n",
    "        outputs = Riemannian_langevin_sampler(\n",
    "            self.gradient_log_rho_g,\n",
    "            self.manifold,\n",
    "            inits,\n",
    "            step_size=step_size,\n",
    "            iter=iter,\n",
    "            curvature='constant',\n",
    "            type=self.sampling_type,\n",
    "            ambient=True\n",
    "        )\n",
    "        return self.manifold.to_intrinsic(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "###############################################################################\n",
    "########################## Torch Data Dataset  ################################\n",
    "###############################################################################\n",
    "###############################################################################\n",
    "\n",
    "class Spherical_Distrib(torch.utils.data.Dataset):\n",
    "    \"\"\"Toy 2-D distributions lifted onto the unit sphere around the north pole.\n",
    "\n",
    "    A noise-free sklearn sample (two moons, four blobs, circles or an S-curve\n",
    "    slice) is centred, placed on the plane z = 1, projected radially onto S^2,\n",
    "    perturbed by a tangent Gaussian step of scale `noise` through the\n",
    "    exponential map, and split by index into training / validation / test.\n",
    "\n",
    "    Args:\n",
    "        split: 'training', 'validation', 'test' or 'all'.\n",
    "        type: 'two_moons', 'four_blobs', 'circles' or 's-curve'.\n",
    "        n_samples: number of points before splitting.\n",
    "        noise: scale of the tangent perturbation; 0 disables it.\n",
    "        random_state: seed for sklearn and for torch's global RNG (reseeded here).\n",
    "        split_ratio: (train fraction, validation fraction); test gets the rest.\n",
    "        ambient: if True items are extrinsic points (3,), else (theta, phi) (2,).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown `type` or `split`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, split='training', type='two_moons', \n",
    "        n_samples=500, noise=0.1, random_state=0, \n",
    "        split_ratio=(0.6, 0.2), ambient=False, *args, **kwargs):\n",
    "        self.n_samples = n_samples\n",
    "        self.sphere = Sphere()\n",
    "        self.ambient = ambient\n",
    "\n",
    "        if type == 'two_moons':\n",
    "            xy = datasets.make_moons(\n",
    "                n_samples=n_samples, \n",
    "                noise=0, \n",
    "                random_state=random_state)[0]\n",
    "        elif type == 'four_blobs':\n",
    "            xy = datasets.make_blobs(\n",
    "                n_samples=n_samples, \n",
    "                n_features=2, \n",
    "                centers=np.array([[0.5, 0, -0.5, 0], [0, 0.5, 0, -0.5]]).transpose(),\n",
    "                cluster_std=0.0, \n",
    "                center_box=[-1.0, 1.0], \n",
    "                random_state=random_state)[0]\n",
    "        elif type == 'circles':\n",
    "            xy = datasets.make_circles(\n",
    "                n_samples=n_samples, \n",
    "                noise=0,\n",
    "                factor=0.5,\n",
    "                random_state=random_state\n",
    "            )[0]\n",
    "        elif type == 's-curve':\n",
    "            xyz = datasets.make_s_curve(\n",
    "                n_samples=n_samples,\n",
    "                noise=0.0,\n",
    "                random_state=random_state\n",
    "            )[0]\n",
    "            xy = np.hstack([xyz[:, 0:1], xyz[:, 2:3]])\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown dataset type: {type}\")\n",
    "\n",
    "        xy = xy - xy.mean(axis=0).reshape(-1, 2)\n",
    "        # place the centred points on the plane z = 1, then project onto the unit sphere\n",
    "        xyz = np.hstack([xy, np.ones((n_samples, 1))])\n",
    "        self.data = xyz/np.linalg.norm(xyz, axis=1).reshape(-1, 1) # extrinsic coordinate\n",
    "\n",
    "        sphere_torch = Sphere_manifold()\n",
    "        self.x = torch.tensor(self.data,  dtype=torch.float32)\n",
    "        torch.random.manual_seed(random_state)\n",
    "        v = torch.randn_like(self.x)\n",
    "        v = sphere_torch.project_to_tangentSpace(self.x, v)\n",
    "        v = v * noise\n",
    "        # With noise == 0, v is zero and exp_x(0) = x: skip the exp map (and its 0/0 in v/|v|).\n",
    "        if noise != 0:\n",
    "            self.x = sphere_torch.exponential_map(self.x, v)\n",
    "            self.x = self.x/torch.norm(self.x, dim=1, keepdim=True)\n",
    "        self.z = sphere_torch.to_intrinsic(self.x)\n",
    "\n",
    "        num_train = int(split_ratio[0]*self.n_samples)\n",
    "        num_valid = int(split_ratio[1]*self.n_samples)\n",
    "        if split == 'training':\n",
    "            self.x = self.x[:num_train]\n",
    "            self.z = self.z[:num_train]\n",
    "        elif split == 'validation':\n",
    "            self.x = self.x[num_train:num_train+num_valid]\n",
    "            self.z = self.z[num_train:num_train+num_valid]\n",
    "        elif split == 'test':\n",
    "            self.x = self.x[num_train+num_valid:]\n",
    "            self.z = self.z[num_train+num_valid:]\n",
    "        elif split == 'all':\n",
    "            pass\n",
    "        else:\n",
    "            # an unknown split used to fall through and silently expose all the data\n",
    "            raise ValueError(f\"Unknown split: {split}\")\n",
    "        print(f'split: {split}, num_data: {len(self.z)}')\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.z)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the extrinsic point if ambient, else the intrinsic (theta, phi).\"\"\"\n",
    "        if self.ambient:\n",
    "            return self.x[idx]\n",
    "        else:\n",
    "            return self.z[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "###############################################################################\n",
    "####################### Submodules for Models #################################\n",
    "###############################################################################\n",
    "###############################################################################\n",
    "\n",
    "def get_activation(s_act):\n",
    "    if s_act == \"relu\":\n",
    "        return nn.ReLU(inplace=True)\n",
    "    elif s_act == \"sigmoid\":\n",
    "        return nn.Sigmoid()\n",
    "    elif s_act == \"softplus\":\n",
    "        return nn.Softplus()\n",
    "    elif s_act == \"linear\":\n",
    "        return None\n",
    "    elif s_act == \"tanh\":\n",
    "        return nn.Tanh()\n",
    "    elif s_act == \"leakyrelu\":\n",
    "        return nn.LeakyReLU(0.2, inplace=True)\n",
    "    elif s_act == \"softmax\":\n",
    "        return nn.Softmax(dim=1)\n",
    "    elif s_act == \"selu\":\n",
    "        return nn.SELU()\n",
    "    elif s_act == \"elu\":\n",
    "        return nn.ELU()\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected activation: {s_act}\")\n",
    "\n",
    "class FC_vec(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_chan=2,\n",
    "        out_chan=1,\n",
    "        l_hidden=None,\n",
    "        activation=None,\n",
    "        out_activation=None,\n",
    "    ):\n",
    "        super(FC_vec, self).__init__()\n",
    "\n",
    "        self.in_chan = in_chan\n",
    "        self.out_chan = out_chan\n",
    "        l_neurons = l_hidden + [out_chan]\n",
    "        activation = activation + [out_activation]\n",
    "\n",
    "        l_layer = []\n",
    "        prev_dim = in_chan\n",
    "        for [n_hidden, act] in (zip(l_neurons, activation)):\n",
    "            l_layer.append(nn.Linear(prev_dim, n_hidden))\n",
    "            act_fn = get_activation(act)\n",
    "            if act_fn is not None:\n",
    "                l_layer.append(act_fn)\n",
    "            prev_dim = n_hidden\n",
    "\n",
    "        self.net = nn.Sequential(*l_layer)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "    \n",
    "def get_net(in_dim, out_dim, **kwargs):\n",
    "    if kwargs[\"arch\"] == \"fc_vec\":\n",
    "        l_hidden = kwargs[\"l_hidden\"]\n",
    "        activation = kwargs[\"activation\"]\n",
    "        out_activation = kwargs[\"out_activation\"]\n",
    "        net = FC_vec(\n",
    "            in_chan=in_dim,\n",
    "            out_chan=out_dim,\n",
    "            l_hidden=l_hidden,\n",
    "            activation=activation,\n",
    "            out_activation=out_activation,\n",
    "        )\n",
    "    return net\n",
    "\n",
    "class SphericalProjectionLayer(nn.Module):\n",
    "    def __init__(self, net):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.net(x)\n",
    "        return x/torch.norm(x, dim=1, keepdim=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "###############################################################################\n",
    "########################## Get Dataset, Loader, AE ############################\n",
    "###############################################################################\n",
    "###############################################################################\n",
    "\n",
    "def get_dataset(data_dict):\n",
    "    name = data_dict[\"dataset\"]\n",
    "    if name == 'Sphere':\n",
    "        dataset = Spherical_Distrib(**data_dict)\n",
    "    return dataset\n",
    "\n",
    "def get_dataloader(data_dict, **kwargs):\n",
    "    \"\"\"Return a ``DataLoader`` over ``get_dataset(data_dict)``.\n",
    "\n",
    "    ``data_dict[\"batch_size\"]`` is required; ``data_dict[\"shuffle\"]`` is\n",
    "    optional and defaults to ``True``. ``kwargs`` is accepted but ignored.\n",
    "    \"\"\"\n",
    "    shuffle = data_dict.get(\"shuffle\", True)\n",
    "    return torch.utils.data.DataLoader(\n",
    "        get_dataset(data_dict),\n",
    "        batch_size=data_dict[\"batch_size\"],\n",
    "        shuffle=shuffle,\n",
    "    )\n",
    "\n",
    "def get_ae(**model_cfg):\n",
    "    x_dim = model_cfg['x_dim']\n",
    "    z_dim = model_cfg['z_dim']\n",
    "    arch = model_cfg[\"arch\"]\n",
    "    if arch == \"agrcae\":\n",
    "        sigma = model_cfg.get(\"sigma\", 0.001)\n",
    "        encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg[\"encoder\"])\n",
    "        decoder = SphericalProjectionLayer(get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg[\"decoder\"]))\n",
    "        model = ambientGRCAE(encoder, decoder, sigma=sigma)\n",
    "    elif arch == \"agdae\":\n",
    "        sigma = model_cfg.get(\"sigma\", 0.001)\n",
    "        encoder = get_net(in_dim=x_dim, out_dim=z_dim, **model_cfg[\"encoder\"])\n",
    "        decoder = SphericalProjectionLayer(get_net(in_dim=z_dim, out_dim=x_dim, **model_cfg[\"decoder\"]))\n",
    "        model = ambientGDAE(encoder, decoder, sigma=sigma)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "split: all, num_data: 1000\n",
      "split: training, num_data: 800\n",
      "split: validation, num_data: 200\n"
     ]
    }
   ],
   "source": [
    "###########################################################################\n",
    "###########################################################################\n",
    "########################## Main Code Starts Here ##########################\n",
    "###########################################################################\n",
    "###########################################################################\n",
    "\n",
    "sphere = Sphere()\n",
    "# dataset = 's-curve'\n",
    "# dataset = 'four_blobs'\n",
    "dataset = 'two_moons'\n",
    "# dataset = 'circles'\n",
    "\n",
    "noise = 0.01\n",
    "n_samples = 1000\n",
    "view_init = (80, 5)\n",
    "# Fall back to CPU so the notebook also runs without a GPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "n_test = 1000\n",
    "\n",
    "# Seed torch (model init, DataLoader shuffling, Langevin noise) and numpy\n",
    "# so that Restart & Run All reproduces the same run.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "\n",
    "############################### Warning #################################\n",
    "## Parameter settings below are not exactly the same as in the paper.  ##\n",
    "## In the paper, the best parameters are chosen by validation score.   ##\n",
    "#########################################################################\n",
    "\n",
    "# # agrcae setting\n",
    "# model_name = 'agrcae'\n",
    "# x_dim = 3\n",
    "# z_dim = 512\n",
    "# hidden = 512\n",
    "# step_size = 1.0e-4\n",
    "# sigma = 0.01\n",
    "# ambient = True\n",
    "\n",
    "# agdae setting\n",
    "model_name = 'agdae'\n",
    "x_dim = 3\n",
    "z_dim = 512\n",
    "hidden = 512\n",
    "step_size = 1.0e-4\n",
    "sigma = 0.01\n",
    "ambient = True\n",
    "\n",
    "#########################################################################\n",
    "#########################################################################\n",
    "\n",
    "def _sphere_data_dict(split, n, batch_size, random_state, **extra):\n",
    "    \"\"\"Config dict for one split of the 'Sphere' dataset.\n",
    "\n",
    "    ``type`` and ``noise`` come from the globals ``dataset`` and ``noise``\n",
    "    set above; ``extra`` keys (e.g. ``ambient``) are appended at the end.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        'dataset': 'Sphere',\n",
    "        'split': split,\n",
    "        'type': dataset,\n",
    "        'n_samples': n,\n",
    "        'noise': noise,\n",
    "        'batch_size': batch_size,\n",
    "        'random_state': random_state,\n",
    "        'split_ratio': (0.8, 0.2),\n",
    "        **extra,\n",
    "    }\n",
    "\n",
    "\n",
    "# Points drawn as the reference data set in the visualization cell.\n",
    "data_dict_for_visualization = _sphere_data_dict('all', n_test, 32, 999)\n",
    "dl_for_visualization = get_dataloader(data_dict_for_visualization)\n",
    "\n",
    "# Training and validation share random_state and split_ratio; they differ\n",
    "# only in 'split'.\n",
    "train_data_dict = _sphere_data_dict('training', n_samples, 1000, 3, ambient=ambient)\n",
    "train_dl = get_dataloader(train_data_dict)\n",
    "\n",
    "val_data_dict = _sphere_data_dict('validation', n_samples, 1000, 3, ambient=ambient)\n",
    "val_dl = get_dataloader(val_data_dict)\n",
    "\n",
    "def _fc_vec_cfg():\n",
    "    \"\"\"Fresh ``get_net`` kwargs: two ReLU hidden layers of width ``hidden``, linear output.\"\"\"\n",
    "    return {\n",
    "        'arch': 'fc_vec',\n",
    "        'l_hidden': [hidden, hidden],\n",
    "        'activation': ['relu', 'relu'],\n",
    "        'out_activation': 'linear',\n",
    "    }\n",
    "\n",
    "\n",
    "# Encoder and decoder get separate (equal-valued) config dicts.\n",
    "model_cfg = {\n",
    "    'arch': model_name,\n",
    "    'sigma': sigma,\n",
    "    'x_dim': x_dim,\n",
    "    'z_dim': z_dim,\n",
    "    'encoder': _fc_vec_cfg(),\n",
    "    'decoder': _fc_vec_cfg(),\n",
    "}\n",
    "\n",
    "# NOTE(review): get_ae does not read 'sampling_type', and this branch never\n",
    "# runs for model_name in {'agrcae', 'agdae'}.\n",
    "if model_cfg['arch'] in ('gdae', 'grcae'):\n",
    "    model_cfg['sampling_type'] = 1\n",
    "model = get_ae(**model_cfg).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[epoch: 0]: 229362240.000000 < inf\n",
      "[epoch: 6]: 172017056.000000 < 229362240.000000\n",
      "[epoch: 10]: 167101632.000000 < 172017056.000000\n",
      "[epoch: 15]: 159051792.000000 < 167101632.000000\n",
      "[epoch: 16]: 149114496.000000 < 159051792.000000\n",
      "[epoch: 17]: 148722016.000000 < 149114496.000000\n",
      "[epoch: 18]: 148094528.000000 < 148722016.000000\n",
      "[epoch: 19]: 147878720.000000 < 148094528.000000\n",
      "[epoch: 20]: 146175744.000000 < 147878720.000000\n",
      "[epoch: 21]: 143584464.000000 < 146175744.000000\n",
      "[epoch: 22]: 138356000.000000 < 143584464.000000\n",
      "[epoch: 24]: 102861792.000000 < 138356000.000000\n",
      "[epoch: 25]: 7493075.000000 < 102861792.000000\n",
      "[epoch: 32]: 596940.125000 < 7493075.000000\n",
      "[epoch: 33]: 539916.312500 < 596940.125000\n",
      "[epoch: 46]: 178114.171875 < 539916.312500\n",
      "[epoch: 47]: 151346.406250 < 178114.171875\n",
      "[epoch: 48]: 131923.250000 < 151346.406250\n",
      "[epoch: 49]: 111282.664062 < 131923.250000\n",
      "[epoch: 50]: 94674.968750 < 111282.664062\n",
      "[epoch: 51]: 93980.257812 < 94674.968750\n",
      "[epoch: 57]: 67985.937500 < 93980.257812\n",
      "[epoch: 58]: 47542.542969 < 67985.937500\n",
      "[epoch: 59]: 38363.679688 < 47542.542969\n",
      "[epoch: 66]: 35096.292969 < 38363.679688\n",
      "[epoch: 69]: 30845.929688 < 35096.292969\n",
      "[epoch: 72]: 27194.187500 < 30845.929688\n",
      "[epoch: 73]: 22759.964844 < 27194.187500\n",
      "[epoch: 74]: 20794.183594 < 22759.964844\n",
      "[epoch: 75]: 19146.164062 < 20794.183594\n",
      "[epoch: 76]: 17441.369141 < 19146.164062\n",
      "[epoch: 79]: 15954.652344 < 17441.369141\n",
      "[epoch: 80]: 14478.742188 < 15954.652344\n",
      "[epoch: 81]: 12207.180664 < 14478.742188\n",
      "[epoch: 82]: 11625.037109 < 12207.180664\n",
      "[epoch: 83]: 11553.796875 < 11625.037109\n",
      "[epoch: 87]: 11252.414062 < 11553.796875\n",
      "[epoch: 88]: 10684.320312 < 11252.414062\n",
      "[epoch: 89]: 8677.145508 < 10684.320312\n",
      "[epoch: 90]: 8072.353516 < 8677.145508\n",
      "[epoch: 92]: 6752.696777 < 8072.353516\n",
      "[epoch: 93]: 6708.041504 < 6752.696777\n",
      "[epoch: 97]: 5126.771973 < 6708.041504\n",
      "[epoch: 99]: 4862.187988 < 5126.771973\n",
      "[epoch: 100]: 4763.079102 < 4862.187988\n",
      "[epoch: 103]: 4338.539062 < 4763.079102\n",
      "[epoch: 107]: 3233.111816 < 4338.539062\n",
      "[epoch: 112]: 2084.645996 < 3233.111816\n",
      "[epoch: 117]: 1552.987305 < 2084.645996\n",
      "[epoch: 121]: 1358.925293 < 1552.987305\n",
      "[epoch: 125]: 967.659668 < 1358.925293\n",
      "[epoch: 126]: 793.254395 < 967.659668\n",
      "[epoch: 136]: 432.507812 < 793.254395\n",
      "[epoch: 137]: -463.542236 < 432.507812\n",
      "[epoch: 144]: -565.242432 < -463.542236\n",
      "[epoch: 145]: -667.794189 < -565.242432\n",
      "[epoch: 151]: -966.824951 < -667.794189\n",
      "[epoch: 159]: -1716.655762 < -966.824951\n",
      "[epoch: 187]: -1862.761963 < -1716.655762\n",
      "[epoch: 194]: -2054.413818 < -1862.761963\n",
      "[epoch: 205]: -2448.915527 < -2054.413818\n",
      "[epoch: 212]: -2794.940430 < -2448.915527\n",
      "[epoch: 246]: -3366.865723 < -2794.940430\n",
      "[epoch: 319]: -3392.667236 < -3366.865723\n",
      "[epoch: 322]: -3505.363281 < -3392.667236\n",
      "[epoch: 331]: -3821.933350 < -3505.363281\n",
      "[epoch: 336]: -4000.197998 < -3821.933350\n",
      "[epoch: 369]: -4041.331299 < -4000.197998\n",
      "[epoch: 405]: -4084.207520 < -4041.331299\n",
      "[epoch: 411]: -4234.329102 < -4084.207520\n",
      "[epoch: 480]: -4527.641602 < -4234.329102\n",
      "[epoch: 556]: -4847.901367 < -4527.641602\n",
      "[epoch: 579]: -4868.739746 < -4847.901367\n",
      "[epoch: 761]: -5134.981445 < -4868.739746\n",
      "[epoch: 762]: -5160.802246 < -5134.981445\n",
      "[epoch: 763]: -5324.822754 < -5160.802246\n",
      "[epoch: 857]: -5333.768066 < -5324.822754\n",
      "[epoch: 882]: -5348.653809 < -5333.768066\n",
      "[epoch: 938]: -5398.116211 < -5348.653809\n",
      "[epoch: 952]: -5441.610352 < -5398.116211\n",
      "[epoch: 953]: -5494.146484 < -5441.610352\n",
      "[epoch: 1137]: -5521.209473 < -5494.146484\n",
      "[epoch: 1252]: -5599.901367 < -5521.209473\n",
      "[epoch: 1317]: -5776.736328 < -5599.901367\n",
      "[epoch: 1443]: -5839.322754 < -5776.736328\n",
      "[epoch: 1472]: -5941.106445 < -5839.322754\n",
      "[epoch: 1655]: -5996.120117 < -5941.106445\n",
      "[epoch: 1989]: -6176.425781 < -5996.120117\n",
      "[epoch: 2295]: -6199.855469 < -6176.425781\n",
      "[epoch: 2660]: -6215.451172 < -6199.855469\n",
      "[epoch: 2670]: -6298.730957 < -6215.451172\n",
      "[epoch: 3078]: -6347.077637 < -6298.730957\n",
      "[epoch: 3139]: -6445.459961 < -6347.077637\n",
      "[epoch: 3488]: -6775.252930 < -6445.459961\n",
      "early stopping at 4489\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 167,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opt = torch.optim.Adam(model.parameters(), lr=1.0e-3, weight_decay=1.0e-12)\n",
    "max_epochs = 5000\n",
    "patience = 1000  # stop after this many epochs without a new best val loss\n",
    "\n",
    "## Training\n",
    "best_val_loss = np.inf\n",
    "best_epoch = 0\n",
    "# state_dict() returns tensors sharing storage with the live parameters, so\n",
    "# it must be deep-copied: otherwise every later optimizer step also updates\n",
    "# the \"best\" snapshot and the final, not the best, weights get restored below.\n",
    "best_model_state_dict = copy.deepcopy(model.state_dict())\n",
    "for epoch in range(max_epochs):\n",
    "    for x in train_dl:\n",
    "        model.train_step(x.to(device), optimizer=opt)\n",
    "    val_loss = []\n",
    "    for x in val_dl:\n",
    "        val_dict_result = model.validation_step(x.to(device))\n",
    "        val_loss.append(val_dict_result['loss'])\n",
    "    val_loss = sum(val_loss)/len(val_loss)\n",
    "    if val_loss < best_val_loss:\n",
    "        best_epoch = epoch\n",
    "        best_model_state_dict = copy.deepcopy(model.state_dict())\n",
    "        print(f'[epoch: {best_epoch}]: {val_loss:.6f} < {best_val_loss:.6f}')\n",
    "        best_val_loss = val_loss\n",
    "    if epoch > best_epoch + patience:\n",
    "        print(f\"early stopping at {epoch}\")\n",
    "        break\n",
    "\n",
    "# best_model is the same object as model; this restores the best weights in place.\n",
    "best_model = model\n",
    "best_model.load_state_dict(best_model_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 done\n",
      "1 done\n",
      "2 done\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_898798/1863531782.py:23: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"fomat\" which is no longer supported as of 3.3 and will become an error two minor releases later\n",
      "  temp_fig.savefig(img_buf, fomat='png')\n",
      "/tmp/ipykernel_898798/1863531782.py:54: MatplotlibDeprecationWarning: savefig() got unexpected keyword argument \"fomat\" which is no longer supported as of 3.3 and will become an error two minor releases later\n",
      "  temp_fig.savefig(img_buf, fomat='png')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 done\n",
      "4 done\n",
      "5 done\n",
      "6 done\n",
      "7 done\n",
      "8 done\n",
      "9 done\n",
      "10 done\n",
      "11 done\n",
      "12 done\n",
      "13 done\n",
      "14 done\n",
      "15 done\n",
      "16 done\n",
      "17 done\n",
      "18 done\n",
      "19 done\n",
      "20 done\n",
      "21 done\n",
      "22 done\n",
      "23 done\n",
      "24 done\n",
      "25 done\n",
      "26 done\n",
      "27 done\n",
      "28 done\n",
      "29 done\n",
      "30 done\n",
      "31 done\n",
      "32 done\n",
      "33 done\n",
      "34 done\n",
      "35 done\n",
      "36 done\n",
      "37 done\n",
      "38 done\n",
      "39 done\n",
      "40 done\n",
      "41 done\n",
      "42 done\n",
      "43 done\n",
      "44 done\n",
      "45 done\n",
      "46 done\n",
      "47 done\n",
      "48 done\n",
      "49 done\n",
      "50 done\n",
      "51 done\n",
      "52 done\n",
      "53 done\n",
      "54 done\n",
      "55 done\n",
      "56 done\n",
      "57 done\n",
      "58 done\n",
      "59 done\n",
      "60 done\n",
      "61 done\n",
      "62 done\n",
      "63 done\n",
      "64 done\n",
      "65 done\n",
      "66 done\n",
      "67 done\n",
      "68 done\n",
      "69 done\n",
      "70 done\n",
      "71 done\n",
      "72 done\n",
      "73 done\n",
      "74 done\n",
      "75 done\n",
      "76 done\n",
      "77 done\n",
      "78 done\n",
      "79 done\n",
      "80 done\n",
      "81 done\n",
      "82 done\n",
      "83 done\n",
      "84 done\n",
      "85 done\n",
      "86 done\n",
      "87 done\n",
      "88 done\n",
      "89 done\n",
      "90 done\n",
      "91 done\n",
      "92 done\n",
      "93 done\n",
      "94 done\n",
      "95 done\n",
      "96 done\n",
      "97 done\n",
      "98 done\n",
      "99 done\n"
     ]
    }
   ],
   "source": [
    "# Sampling Results Save to GIF (at best model)\n",
    "n_langevin_steps = 100\n",
    "plot_every = 1      # render a GIF frame every `plot_every` Langevin steps\n",
    "n_hold_frames = 10  # number of GIF frames showing the initial points\n",
    "\n",
    "\n",
    "def _fig_to_image(fig):\n",
    "    \"\"\"Render `fig` to an in-memory PNG, close it, and return a PIL image.\"\"\"\n",
    "    buf = io.BytesIO()\n",
    "    fig.savefig(buf, format='png')\n",
    "    # Close explicitly; otherwise one open figure per frame accumulates.\n",
    "    plt.close(fig)\n",
    "    buf.seek(0)\n",
    "    return Image.open(buf)\n",
    "\n",
    "\n",
    "# Initial points: standard Gaussian samples rescaled to unit norm.\n",
    "inits = torch.randn(n_test, 3).to(device)\n",
    "inits = inits / torch.norm(inits, dim=-1, keepdim=True)\n",
    "\n",
    "list_imgs = []\n",
    "\n",
    "temp_fig = sphere.visualize(\n",
    "    [\n",
    "        dl_for_visualization.dataset.x.numpy(),\n",
    "        inits.detach().cpu().numpy()\n",
    "    ],\n",
    "    axis='off',\n",
    "    view_init=view_init,\n",
    "    mode='extrinsic',\n",
    "    list_point_draw_kwargs=[\n",
    "        {'c': 'tab:red', 'alpha': 0.5},\n",
    "        {'c': 'tab:green', 'alpha': 0.5}\n",
    "    ],\n",
    ")\n",
    "temp_fig.suptitle('Initial Sampled Points', y=0.8)\n",
    "list_imgs.extend([_fig_to_image(temp_fig)] * n_hold_frames)\n",
    "\n",
    "# Each step moves along step_size/2 * term1 plus sqrt(step_size) times\n",
    "# tangent-projected Gaussian noise via the exponential map, then\n",
    "# renormalizes so the points stay on the unit sphere.\n",
    "# sampled_points = best_model.sample(inits, step_size=step_size, iter=1000)\n",
    "sampled_points = inits\n",
    "for iter_ in range(n_langevin_steps):\n",
    "    term1 = best_model.gradient_log_rho_g(sampled_points, detach=True, ambient=ambient)\n",
    "    term2 = best_model.manifold.project_to_tangentSpace(sampled_points, torch.randn_like(sampled_points))\n",
    "    sampled_points = best_model.manifold.exponential_map(sampled_points, step_size/2 * term1 + np.sqrt(step_size) * term2)\n",
    "    sampled_points = sampled_points/torch.norm(sampled_points, dim=1, keepdim=True)\n",
    "\n",
    "    if iter_ % plot_every == 0:\n",
    "        temp_fig = sphere.visualize(\n",
    "            [sampled_points.detach().cpu().numpy()],\n",
    "            axis='off',\n",
    "            view_init=view_init,\n",
    "            mode='extrinsic',\n",
    "            list_point_draw_kwargs=[{'c': 'tab:green', 'alpha': 0.5}],\n",
    "        )\n",
    "        temp_fig.suptitle(f'{model_name} \\n Riemannian Langevin Sampling \\n ({iter_}-th step)', y=0.8)\n",
    "        list_imgs.append(_fig_to_image(temp_fig))\n",
    "        print(f'{iter_} done')\n",
    "\n",
    "# SAVE to gif\n",
    "# NOTE(review): ':' in the file name is not allowed on Windows.\n",
    "imageio.mimsave(f'{model_name}:{dataset}.gif', list_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "IsoFeature",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "eec87773ff7702733e8166a8c29f5904aadb5156d3bbcbf98cac38c84966ab91"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
