{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7088d12-95dc-4015-a112-e4c011bbdbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the CUDA caching allocator; this must happen before torch touches CUDA.\n",
    "import os\n",
    "\n",
    "os.environ.update(PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59897ab1-3a65-4ac7-9012-f5ad86ca4def",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import scanpy as sc\n",
    "import anndata as ad\n",
    "import os\n",
    "import pandas as pd\n",
    "from utils.preprocess import *\n",
    "from datasets.process import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b280876-8632-421f-aa60-c377cd6298e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(os.path.abspath(\"conditional-flow-matching\"))\n",
    "    \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "import torch\n",
    "import torchsde\n",
    "from torchdyn.core import NeuralODE\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models import MLP\n",
    "from torchcfm.utils import plot_trajectories, torch_wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f6b368-8234-4bcb-819b-8b36c266c977",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import umap\n",
    "from models.modules import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c2a0ef-2134-4d03-a2ab-ebae8e145949",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1620f58-0462-4b18-b718-3592e168323b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.run_model import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07098bc2-0902-463a-b67b-a53afbfd4a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3276ba58-34fa-48af-aa8e-64f6be36da27",
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119c2889-e7f2-4c3d-9422-094b40bca7ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdyn.datasets import generate_moons\n",
    "\n",
    "from torchcfm.conditional_flow_matching import *\n",
    "from torchcfm.models.models import *\n",
    "from torchcfm.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8372eebd-0a3f-4293-818c-6ffc7d630b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "from simulate.simulate import *\n",
    "from simulate.s import *\n",
    "from simulate.sphere import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2339305d-4f07-4ca8-a682-91501d1e1ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook uses: Python's `random` (imported above but not\n",
    "# seeded before), torch (CPU and all CUDA devices) and numpy's global generator.\n",
    "import random\n",
    "\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aee349b-2ea6-41f5-9640-87952cdd7829",
   "metadata": {},
   "outputs": [],
   "source": [
    "### SETTINGS ###\n",
    "\n",
    "# Open TODOs:\n",
    "#  - scale up dimension\n",
    "#  - add noise back to the sampling procedure\n",
    "#  - calculate geodesic disparity\n",
    "\n",
    "# Synthetic manifold to generate (key into the dataset dispatch table below).\n",
    "dataname = 'sphere'\n",
    "\n",
    "# Number of points and ambient dimension.\n",
    "# Earlier low-dim setting: size = 5000, d = 3 (d as a dummy dimension).\n",
    "size, d = 20000, 20\n",
    "\n",
    "normalize = False\n",
    "energy_only = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a8c465-7e1b-4a7d-8a0c-a7b8392c84e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "042e5fbf-9194-4324-899b-1168a44e7bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dispatch table: dataset name -> generator called as f(size, d, normalize),\n",
    "# returning (adata, values, fs).\n",
    "dic = {\n",
    "    \"sphere\": process_sphere_data,\n",
    "    \"cylinder\": process_cylinder_data,\n",
    "    \"ellipse\": process_ellipse_data,\n",
    "}\n",
    "if dataname not in dic:\n",
    "    # Fail with the list of valid names instead of a bare KeyError.\n",
    "    raise KeyError(f\"Unknown dataname {dataname!r}; expected one of {sorted(dic)}\")\n",
    "adata, values, fs = dic[dataname](size, d, normalize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b4a047-5c21-4b9f-9630-e92d38a502d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `conditions` and `dataset` are the inputs consumed by run_full_model below.\n",
    "(conditions, dataset) = extract_dataset(adata, values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98514ba1-d0e2-4ad8-aa0a-67fba2c75346",
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4cf5d82-b07b-44f6-83a2-48253917e337",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of the `X_pca_raw` embedding (raw_samples is reused later).\n",
    "raw_samples = torch.from_numpy(adata.obsm['X_pca_raw']).float()\n",
    "fig, ax = plt.subplots()\n",
    "# Small translucent markers: with `size` points the default markers overplot.\n",
    "ax.scatter(raw_samples[:, 0], raw_samples[:, 1], s=2, alpha=0.3)\n",
    "ax.set(title='X_pca_raw', xlabel='component 0', ylabel='component 1')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65255ffd-51db-492f-b02e-0082f7943909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of the `X_pca` embedding.\n",
    "samples = torch.from_numpy(adata.obsm['X_pca']).float()\n",
    "fig, ax = plt.subplots()\n",
    "# Small translucent markers: with `size` points the default markers overplot.\n",
    "ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.3)\n",
    "ax.set(title='X_pca', xlabel='component 0', ylabel='component 1')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a9e2e36-182b-463a-b446-aec794342b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = Config(\n",
    "    {\n",
    "        \"model_class\": \"metricflow\",\n",
    "\n",
    "        # Maximum epochs per training stage.\n",
    "        \"score_max_epochs\": 2000,\n",
    "        \"energy_max_epochs\": 4000,\n",
    "        \"metric_max_epochs\": 2,\n",
    "        \"embed_max_epochs\": 2 if energy_only else 4000,\n",
    "        # NOTE(review): this was `2 if energy_only else 2`; both branches were identical,\n",
    "        # so the conditional was dead. Confirm the intended non-energy-only value.\n",
    "        \"flow_max_epochs\": 2,\n",
    "\n",
    "        \"lr\": 1e-4,\n",
    "        \"dropout\": 0.0,\n",
    "        \"pc_dim\": d,\n",
    "        \"cond_dim\": d,\n",
    "        \"hidden_dim\": 256,\n",
    "        \"score_batch_size\": 4096,\n",
    "        \"flow_batch_size\": 256,\n",
    "        \"num_freq\": 32,\n",
    "        \"num_layers\": 4,\n",
    "        \"force_cpu\": False,\n",
    "        \"gradient_clip_val\": 0,\n",
    "        \"loader_batch_size\": 6,\n",
    "        \"warmup_steps\": 0,\n",
    "        \"ema_decay\": 1-1e-3,\n",
    "\n",
    "        \"mfm_benchmark\": False,\n",
    "\n",
    "        \"pita_steps\": 1,\n",
    "\n",
    "        \"score_alpha\": 1.0,\n",
    "\n",
    "        \"energy_noise_sigma\": 0.0,\n",
    "        \"metric_scale\": 10,\n",
    "        \"metric_sigma\": 0.05,\n",
    "\n",
    "        \"ot_in_embed\": False,  # embed will train over arbitrary paths, not just OT paths\n",
    "        \"fast_ot\": False,\n",
    "\n",
    "        \"num_sigmas\": 20,\n",
    "        \"sigma_min\": 0.005,\n",
    "        \"sigma_max\": 0.2,\n",
    "\n",
    "        \"sigma_dim\": 32,\n",
    "\n",
    "        \"score_beta_min\": 1.0,\n",
    "        \"score_beta_max\": 1.0,\n",
    "\n",
    "        \"latent_dim\": 100,\n",
    "\n",
    "        \"skip\": False,\n",
    "        \"rescale\": 0.5,\n",
    "\n",
    "        \"pre_low_q\": .05,\n",
    "        \"pre_high_q\": .98,\n",
    "        \"low_q\": .05,\n",
    "        \"high_q\": .95,\n",
    "\n",
    "        \"weight_beta\": 1.0,\n",
    "\n",
    "        \"gamma\": 0.2,\n",
    "\n",
    "        \"sigma\": 0.1,\n",
    "\n",
    "        \"n_neighbors\": 10,\n",
    "        \"resolution\": 0.3,\n",
    "    }\n",
    ")\n",
    "\n",
    "project = \"fm-ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd69d1a0-283c-4237-8e16-c78af12041bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # MFM BENCHMARK ###\n",
    "# config.score_max_epochs = 2\n",
    "# config.energy_max_epochs = 2\n",
    "# config.metric_max_epochs = 5000\n",
    "# config.mfm_benchmark = True\n",
    "# config.K = 10000\n",
    "# config.kappa = 0.5\n",
    "# config.pita_steps = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa1b6a9-4a46-40ab-9dc0-d2545ddd4ccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # MFM BENCHMARK for small d###\n",
    "# config.score_max_epochs = 2\n",
    "# config.energy_max_epochs = 2\n",
    "# config.metric_max_epochs = 2000\n",
    "# config.mfm_benchmark = True\n",
    "# config.K = 1000\n",
    "# config.kappa = 1.0\n",
    "# config.pita_steps = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74b1a806-9f89-4738-8a60-073b25911614",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the full pipeline (run_full_model presumably comes from scripts.run_model);\n",
    "# it returns the pre-trained and final models in the order unpacked here.\n",
    "(\n",
    "    pre_score_models,\n",
    "    pre_energy_models,\n",
    "    score_model,\n",
    "    energy_model,\n",
    "    metric_model,\n",
    "    embed_model,\n",
    "    flow_model,\n",
    ") = run_full_model(config, project, adata, values, conditions, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d376e4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize(model, samples, grid_steps=30, sigma_value=0.05, dim=None):\n",
    "    \"\"\"Draw a model's vector field (first two components) over a scatter of the samples.\n",
    "\n",
    "    The model is evaluated on a ``grid_steps x grid_steps`` grid covering\n",
    "    ``[-m, m]^2`` (``m`` = largest absolute sample coordinate + 1). Grid points are\n",
    "    zero-padded up to ``dim`` coordinates.\n",
    "\n",
    "    Args:\n",
    "        model: callable invoked as ``model(z, sigma)`` with ``z`` of shape\n",
    "            (grid_steps**2, dim) and ``sigma`` of shape (grid_steps**2, 1).\n",
    "        samples: tensor whose first two columns are scattered.\n",
    "        grid_steps: grid resolution per axis (formerly hard-coded to 30).\n",
    "        sigma_value: constant placed in every row of ``sigma`` (formerly 0.05).\n",
    "        dim: ambient dimension of the grid points; defaults to the global ``d``.\n",
    "    \"\"\"\n",
    "    plt.close('all')\n",
    "    if dim is None:\n",
    "        dim = d\n",
    "    m = torch.max(torch.abs(samples)).item() + 1\n",
    "\n",
    "    axis = torch.linspace(-m, m, steps=grid_steps)\n",
    "    x, y = torch.meshgrid(axis, axis, indexing='xy')\n",
    "    # reshape(-1, 2) instead of a hard-coded row count, so any grid_steps works.\n",
    "    z = torch.stack([x, y], dim=-1).reshape(-1, 2)\n",
    "    zeros = torch.zeros(z.shape[0], dim - 2, device=z.device, dtype=z.dtype)\n",
    "    z = torch.cat([z, zeros], dim=1)\n",
    "\n",
    "    sigma = sigma_value * torch.ones(z.shape[0], 1, device=z.device)\n",
    "    # .cpu() so a GPU-resident output can still be handed to matplotlib.\n",
    "    v = model(z, sigma).detach().cpu()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.scatter(samples[:, 0], samples[:, 1])\n",
    "    ax.quiver(z[:, 0], z[:, 1], v[:, 0], v[:, 1])\n",
    "    ax.set(title='model vector field', xlabel='dim 0', ylabel='dim 1')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28028955-705b-4228-862b-e9be5c1d3c97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attempted workaround (unconfirmed) for a wandb.run.summary issue: drop the\n",
    "# forward hooks registered on the trained models.\n",
    "def remove_all_forward_hooks(model):\n",
    "    \"\"\"Clear the forward-hook registry of ``model`` and of every submodule.\"\"\"\n",
    "    for submodule in model.modules():\n",
    "        submodule._forward_hooks.clear()\n",
    "\n",
    "\n",
    "for trained_model in (score_model, energy_model, embed_model, flow_model):\n",
    "    remove_all_forward_hooks(trained_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6731a7d7-a9eb-4126-84b6-3bf5d46aa323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SCORE VALIDATION: score_model output drawn as arrows over the X_pca scatter.\n",
    "samples = torch.from_numpy(adata.obsm[\"X_pca\"]).float()\n",
    "visualize(score_model, samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e71353-db39-4396-a785-4f5a91533b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_energy(energy_model, samples, grid_steps=30, dim=None):\n",
    "    \"\"\"Filled contour of the negated energy over a 2-D grid, samples overlaid.\n",
    "\n",
    "    Args:\n",
    "        energy_model: object exposing ``get_energy(z)`` (one value per row of ``z``)\n",
    "            and, optionally, ``get_device()``.\n",
    "        samples: tensor whose first two columns are scattered.\n",
    "        grid_steps: grid resolution per axis.\n",
    "        dim: ambient dimension of the grid points (coordinates beyond the first\n",
    "            two are zero); defaults to the global ``d``.\n",
    "    \"\"\"\n",
    "    plt.close('all')\n",
    "    if dim is None:\n",
    "        dim = d\n",
    "\n",
    "    m = torch.max(torch.abs(samples)).item() + 1.0\n",
    "    axis = torch.linspace(-m, m, steps=grid_steps)\n",
    "    xg, yg = torch.meshgrid(axis, axis, indexing='xy')\n",
    "    z = torch.stack([xg, yg], dim=-1).reshape(-1, 2)\n",
    "    zeros = torch.zeros(z.shape[0], dim - 2, device=z.device, dtype=z.dtype)\n",
    "    z = torch.cat([z, zeros], dim=1)\n",
    "\n",
    "    # Evaluate on the model's own device when it reports one, then bring the\n",
    "    # result back to the CPU for matplotlib.\n",
    "    if hasattr(energy_model, 'get_device'):\n",
    "        z = z.to(energy_model.get_device())\n",
    "    E = energy_model.get_energy(z).detach().cpu().reshape(grid_steps, grid_steps)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    contour = ax.contourf(xg, yg, -E, levels=100, cmap='inferno')\n",
    "    ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.005)\n",
    "    ax.set_aspect('equal')\n",
    "    fig.colorbar(contour, ax=ax, label=\"negative energy\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eac60792-6738-4a88-aea6-04d6149f0c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_metric(metric_model, samples, grid_steps=30, dim=None, device=None):\n",
    "    \"\"\"Filled contour of the negated metric-model output over a 2-D grid.\n",
    "\n",
    "    Args:\n",
    "        metric_model: callable mapping grid points of shape (grid_steps**2, dim)\n",
    "            to one value per point.\n",
    "        samples: tensor whose first two columns are scattered.\n",
    "        grid_steps: grid resolution per axis.\n",
    "        dim: ambient dimension of the grid points (coordinates beyond the first\n",
    "            two are zero); defaults to the global ``d``.\n",
    "        device: device to evaluate on. Defaults to 'cuda' when\n",
    "            ``config.mfm_benchmark`` is set and to 'cpu' otherwise.\n",
    "    \"\"\"\n",
    "    plt.close('all')\n",
    "    if dim is None:\n",
    "        dim = d\n",
    "    if device is None:\n",
    "        device = 'cuda' if config.mfm_benchmark else 'cpu'\n",
    "\n",
    "    m = torch.max(torch.abs(samples)).item() + 1.0\n",
    "    axis = torch.linspace(-m, m, steps=grid_steps)\n",
    "    xg, yg = torch.meshgrid(axis, axis, indexing='xy')\n",
    "    z = torch.stack([xg, yg], dim=-1).reshape(-1, 2)\n",
    "    zeros = torch.zeros(z.shape[0], dim - 2, device=z.device, dtype=z.dtype)\n",
    "    z = torch.cat([z, zeros], dim=1)\n",
    "\n",
    "    # Always bring the result back to the CPU so matplotlib can consume it.\n",
    "    E = metric_model(z.to(device)).detach().cpu().reshape(grid_steps, grid_steps)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    contour = ax.contourf(xg, yg, -E, levels=100, cmap='inferno')\n",
    "    ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.005)\n",
    "    ax.set_aspect('equal')\n",
    "    fig.colorbar(contour, ax=ax, label=\"negative metric tensor\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a4d4525-dc4d-4474-bf57-fff43227298a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the score_beta attribute of each pre-trained score model.\n",
    "for pre_score_model in pre_score_models:\n",
    "    print(pre_score_model.score_beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd17d250-0b1e-408e-8916-4f177d1e5bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One energy plot per pre-trained energy model, however many were returned.\n",
    "samples = torch.from_numpy(adata.obsm['X_pca']).float()\n",
    "for pre_energy_model in pre_energy_models:\n",
    "    visualize_energy(pre_energy_model, samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3174ca21-758f-44d1-9dc3-5456c27cf7c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Energy landscape of the `energy_model` returned by run_full_model.\n",
    "samples = torch.from_numpy(adata.obsm[\"X_pca\"]).float()\n",
    "visualize_energy(energy_model, samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e72401-1ca7-46ce-8aec-bb378d650d0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output of `metric_model` over the same grid as the energy plots.\n",
    "samples = torch.from_numpy(adata.obsm[\"X_pca\"]).float()\n",
    "visualize_metric(metric_model, samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07ffa832-bcd7-4f23-aae2-68f1399ca28f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Energy histograms of the X_pca points scaled about the origin by factors in\n",
    "# [1 - dt, 1 + dt]; one row per scale factor.\n",
    "X = torch.from_numpy(adata.obsm['X_pca']).float()\n",
    "dt = 0.2\n",
    "n_scales = 10\n",
    "sigmas = np.linspace(1.0 - dt, 1.0 + dt, n_scales).tolist()\n",
    "energy_device = energy_model.get_device()\n",
    "\n",
    "# Fixed height per row so the stacked histograms stay legible; squeeze=False keeps\n",
    "# `axs` 2-D even for a single scale.\n",
    "fig, axs = plt.subplots(len(sigmas), figsize=(8, 1.5 * len(sigmas)), sharex=True, squeeze=False)\n",
    "for ax, sigma in zip(axs[:, 0], sigmas):\n",
    "    Z = (X * sigma).to(energy_device)\n",
    "    E = energy_model.get_energy(Z).detach().cpu().numpy()\n",
    "    ax.hist(E)\n",
    "    # Formatted title instead of raw float reprs such as 0.8000000000000002.\n",
    "    ax.set_title(f\"scale = {sigma:.3f}\")\n",
    "axs[-1, 0].set_xlabel(\"energy\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96cb6dca-25ee-451b-ab6a-7774057d3ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, Sampler, DataLoader\n",
    "from datasets.dataset import *\n",
    "\n",
    "# Rebuild the dataset from the `X_pca_raw` embedding and wrap it for batching.\n",
    "_, raw_dataset = extract_dataset(adata, values, use_rep='X_pca_raw')\n",
    "train_dataset = ShufflingDataset(raw_dataset, 500, conditions)\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=config.loader_batch_size,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "071499cb-4221-49be-bb8c-6441d6b30a29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GEO VALIDATION: sample learned geodesics for the first training batch only.\n",
    "if not energy_only:\n",
    "    for batch in train_dataloader:\n",
    "        x0, x1, _, _, _ = batch\n",
    "        paths = embed_model.sample_geodesic(batch, points=100, ot_sample=False)\n",
    "        break  # one batch is enough for the plots and metrics below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1d50bf-21de-404c-b232-0a86aa419cb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the sampled learned paths.\n",
    "paths.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf4d5fd-c973-448d-bfad-6088e95dc63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D data only: overlay the first five learned paths on the raw samples.\n",
    "if d == 2:\n",
    "    fig, axes = plt.subplots(5, figsize=(8, 8), sharex=True, sharey=True)\n",
    "    for i, ax_i in enumerate(axes):\n",
    "        traj = paths[:, i]\n",
    "        ax_i.scatter(raw_samples[:, 0], raw_samples[:, 1])\n",
    "        ax_i.plot(traj[:, 0], traj[:, 1], marker='.', linestyle='-', alpha=0.7, lw=1.5, color='orange')\n",
    "        ax_i.set(adjustable='box', aspect='equal')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "263ff7ba-b4bc-409b-aa7e-c790123905be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slerp(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:\n",
    "    \"\"\"Batched spherical linear interpolation (slerp) between directions.\n",
    "\n",
    "    ``x0`` and ``x1`` are normalised along the last axis; the point returned for\n",
    "    parameter ``t`` lies on the great-circle arc between them and is re-normalised\n",
    "    to unit length.\n",
    "\n",
    "    Args:\n",
    "        x0: start points, shape (..., D).\n",
    "        x1: end points, broadcastable with ``x0``.\n",
    "        t: interpolation parameters broadcastable with ``x0.shape[:-1]`` (e.g. (N,)\n",
    "            for (N, D) inputs); t=0 gives ``x0`` and t=1 gives ``x1``. It is moved\n",
    "            to the dtype/device of ``x0``.\n",
    "        eps: when ``|sin(omega)| < eps`` the arc is numerically degenerate and\n",
    "            linear interpolation weights are used instead.\n",
    "\n",
    "    Returns:\n",
    "        Unit-norm interpolated points with the broadcast shape of the inputs.\n",
    "\n",
    "    Note:\n",
    "        For antipodal endpoints the geodesic is not unique; the linear fallback\n",
    "        then produces a near-zero vector before re-normalisation.\n",
    "    \"\"\"\n",
    "    import torch.nn.functional as F\n",
    "\n",
    "    x0 = F.normalize(x0, p=2, dim=-1)\n",
    "    x1 = F.normalize(x1, p=2, dim=-1)\n",
    "\n",
    "    # Angle between the unit vectors; the clamp keeps acos inside its domain\n",
    "    # despite rounding. Trailing singleton axis for broadcasting against (..., D).\n",
    "    omega = torch.acos(torch.clamp((x0 * x1).sum(dim=-1), -1.0, 1.0)).unsqueeze(-1)\n",
    "    sin_omega = torch.sin(omega)\n",
    "    t = torch.as_tensor(t, dtype=x0.dtype, device=x0.device).unsqueeze(-1)\n",
    "\n",
    "    is_small_angle = sin_omega.abs() < eps\n",
    "    # Divide by 1 instead of ~0 on degenerate rows: those entries are replaced by\n",
    "    # the linear weights anyway, but inf/nan would still poison gradients.\n",
    "    safe_sin = torch.where(is_small_angle, torch.ones_like(sin_omega), sin_omega)\n",
    "    c0 = torch.where(is_small_angle, 1.0 - t, torch.sin((1.0 - t) * omega) / safe_sin)\n",
    "    c1 = torch.where(is_small_angle, t, torch.sin(t * omega) / safe_sin)\n",
    "\n",
    "    return F.normalize(c0 * x0 + c1 * x1, p=2, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e61e5e-bc3d-4648-bcea-8d74a47a732a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_geodesics(x0, x1, paths, title_prefix = \"\"):\n",
    "    \"\"\"Plot endpoints and paths inside a unit-sphere wireframe on 3-D axes.\n",
    "\n",
    "    Only the first three coordinates of every point are drawn, so the picture is\n",
    "    meaningful for data on the unit sphere in 3-D.\n",
    "\n",
    "    Args:\n",
    "        x0 (torch.Tensor): start points, shape (batch_size, >=3).\n",
    "        x1 (torch.Tensor): end points, shape (batch_size, >=3).\n",
    "        paths (Iterable[torch.Tensor]): one (num_points, >=3) tensor per path.\n",
    "        title_prefix (str): text placed before 'Geodesic Paths' in the title.\n",
    "    \"\"\"\n",
    "    def to_numpy(tensor):\n",
    "        # detach + cpu so tensors that require grad or live on a GPU (e.g. paths\n",
    "        # sampled from a model) convert; a plain .numpy() raises for both.\n",
    "        return tensor.detach().cpu().numpy()\n",
    "\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "    # --- 1. Unit-sphere wireframe from spherical coordinates ---\n",
    "    u = np.linspace(0, 2 * np.pi, 100)\n",
    "    v = np.linspace(0, np.pi, 100)\n",
    "    x = np.outer(np.cos(u), np.sin(v))\n",
    "    y = np.outer(np.sin(u), np.sin(v))\n",
    "    z = np.outer(np.ones(np.size(u)), np.cos(v))\n",
    "    ax.plot_wireframe(x, y, z, color='gray', alpha=0.2, rstride=5, cstride=5)\n",
    "\n",
    "    # --- 2. Endpoints (blue = start, red = end) and paths (green) ---\n",
    "    x0_np = to_numpy(x0)\n",
    "    x1_np = to_numpy(x1)\n",
    "    ax.scatter(x0_np[:, 0], x0_np[:, 1], x0_np[:, 2], c='blue', s=100, label='Start Points (x0)')\n",
    "    ax.scatter(x1_np[:, 0], x1_np[:, 1], x1_np[:, 2], c='red', s=100, label='End Points (x1)')\n",
    "    for i, path in enumerate(paths):\n",
    "        path_np = to_numpy(path)\n",
    "        # A single legend entry stands for all paths; '_nolegend_' hides the rest.\n",
    "        ax.plot(path_np[:, 0], path_np[:, 1], path_np[:, 2], color='green',\n",
    "                label='Paths' if i == 0 else '_nolegend_')\n",
    "\n",
    "    # --- 3. Aesthetics ---\n",
    "    ax.set_xlabel('X Axis')\n",
    "    ax.set_ylabel('Y Axis')\n",
    "    ax.set_zlabel('Z Axis')\n",
    "    # strip() avoids a leading space when no prefix is given.\n",
    "    ax.set_title(f\"{title_prefix} Geodesic Paths\".strip())\n",
    "    ax.set_box_aspect([1, 1, 1])  # equal aspect so the sphere is not distorted\n",
    "    ax.set_xlim([-1.1, 1.1])\n",
    "    ax.set_ylim([-1.1, 1.1])\n",
    "    ax.set_zlim([-1.1, 1.1])\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346a369a-ab19-4adf-9bb0-7b6e531b720b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# Closed-form reference geodesics: slerp arcs between the normalised endpoints of\n",
    "# each learned path, sampled at the same resolution. `paths` is indexed as\n",
    "# (path point, batch item, coordinate).\n",
    "path_resolution = paths.shape[0]\n",
    "batch_size = paths.shape[1]\n",
    "\n",
    "x0 = F.normalize(paths[0], p=2, dim=-1)\n",
    "x1 = F.normalize(paths[-1], p=2, dim=-1)\n",
    "\n",
    "# Loop-invariant, and built on the device/dtype of `paths` so slerp never mixes\n",
    "# CPU and GPU tensors when the learned paths live on the GPU.\n",
    "t_path = torch.linspace(0, 1, path_resolution, device=paths.device, dtype=paths.dtype)\n",
    "geodesic_paths = [\n",
    "    slerp(x0[i].expand(path_resolution, -1), x1[i].expand(path_resolution, -1), t_path)\n",
    "    for i in range(batch_size)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce4e9f9-d612-4987-b9ce-f2574cefd0b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the first five learned paths.\n",
    "paths[:, :5].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b78759-4dbc-4bdf-a5bc-946323c665e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3-D data only: closed-form vs learned geodesics for the first five batch items.\n",
    "if d == 3:\n",
    "    n_show = 5\n",
    "    learned_paths = [paths[:, i] for i in range(n_show)]\n",
    "    visualize_geodesics(x0[:n_show], x1[:n_show], geodesic_paths[:n_show], 'True')\n",
    "    visualize_geodesics(x0[:n_show], x1[:n_show], learned_paths, 'Learned')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aeffb3f-f9e4-46a1-b5e4-b86ba2bc1b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-check the learned-path shape before computing the discrepancy.\n",
    "paths.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0719a617-d81b-4da4-ac97-467f3f297cde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norms along the first closed-form geodesic (slerp output is re-normalised).\n",
    "p = geodesic_paths[0]\n",
    "print(p.norm(dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8586d31-b48c-4b55-957b-caca1baacda0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norms along learned path #3, to compare with the unit-norm reference.\n",
    "p = paths[:, 3]\n",
    "print(p.norm(dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0196de0-9a12-4f53-8718-deb72e46640d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the first path point across the batch.\n",
    "paths[0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b9ad63-59d9-45f2-8f26-bbe9f53db642",
   "metadata": {},
   "outputs": [],
   "source": [
    "def geodesic_disc(paths, true_paths, rtol=1e-05, atol=1e-08):\n",
    "    \"\"\"Batch mean/std of the time-averaged distance between two sets of paths.\n",
    "\n",
    "    Args:\n",
    "        paths: learned paths, shape (num_points, batch_size, dim); the first and\n",
    "            last points of every path must have unit norm.\n",
    "        true_paths: reference paths with the same shape.\n",
    "        rtol, atol: tolerances of the unit-norm endpoint checks (defaults match\n",
    "            ``torch.allclose``).\n",
    "\n",
    "    Returns:\n",
    "        ``(mean, std)`` over the batch of each path's mean Euclidean distance to its\n",
    "        reference. ``std`` is the unbiased estimate, so it is NaN for a batch of one.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the shapes differ or an endpoint is not unit-norm.\n",
    "    \"\"\"\n",
    "    assert paths.shape == true_paths.shape, (paths.shape, true_paths.shape)\n",
    "    for endpoint in (paths[0], paths[-1]):\n",
    "        norms = torch.linalg.norm(endpoint, dim=1).float()\n",
    "        # ones_like keeps the reference on the same device as `paths`; a CPU\n",
    "        # torch.ones makes torch.allclose raise for GPU tensors.\n",
    "        assert torch.allclose(norms, torch.ones_like(norms), rtol=rtol, atol=atol)\n",
    "    per_path = torch.linalg.norm(paths - true_paths, dim=2).mean(0)\n",
    "    return per_path.mean(), per_path.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9c16ca4-025f-4ada-be26-ed06489f3907",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learned vs closed-form geodesics: (mean, std) over the batch of the average\n",
    "# pointwise distance; geodesic_paths is stacked back to (point, batch, dim).\n",
    "geodesic_disc(paths, true_paths=torch.stack(geodesic_paths, dim=1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:spatialenv] *",
   "language": "python",
   "name": "conda-env-spatialenv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
