{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ddd7cee5-4a64-49ad-80b0-10480e2209e2",
   "metadata": {},
   "source": [
    "### A Demo About Fast Riemannian Attention\n",
    "In this demo, we show that:\n",
    "\n",
    "1. Fast Riemannian Attention closely matches the output of conventional Riemannian attention while substantially reducing the runtime.\n",
    "2. When the curvature is 0, our method reduces to Euclidean (distance-based) attention.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5e4d7e47-5965-46d6-b0c2-45f8c9b0cd0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Curvature = -1\n",
      "Riemannian_attention time:0.703925609588623\n",
      "Fast_Riemannian_Attention time:0.12346529960632324\n",
      "Maximum error = 0.0007220767438411713\n",
      "MAE = 8.976956451078877e-05\n",
      "Curvature = 1\n",
      "Riemannian_attention time:0.6795008182525635\n",
      "Fast_Riemannian_Attention time:0.12851357460021973\n",
      "Maximum error = 4.6566128730773926e-07\n",
      "MAE = 1.8362371623226181e-09\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import geoopt\n",
    "from geoopt.manifolds.stereographic.math import project\n",
    "from geoopt.manifolds.stereographic import StereographicExact\n",
    "from geoopt import ManifoldTensor\n",
    "from geoopt import ManifoldParameter\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "\n",
    "\n",
    "\n",
    "# Benchmark configuration: RNG seed, batch size N, sequence length L, feature dimension d.\n",
    "seed, N, L, d = 0, 500, 32, 200\n",
    "\n",
    "# Seed before sampling so that every run draws the same tensors.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Standard-normal inputs of shape [N, L, d], drawn in the order K, Q, V.\n",
    "K, Q, V = (torch.randn(N, L, d) for _ in range(3))\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "def Riemannian_attention(K,Q,V,kappa):\n",
    "    \"\"\"Reference distance-based attention built from geoopt manifold operations.\n",
    "\n",
    "    K, Q and V are passed through ``proju`` at the origin and ``expmap0`` of a\n",
    "    ``geoopt.Stereographic`` manifold with curvature ``kappa``. Attention weights\n",
    "    are the softmax over keys of the negated pairwise ``manifold.dist``. Each value\n",
    "    point is scaled by its weight with ``mobius_scalar_mul``, mapped back with\n",
    "    ``logmap0`` and summed over the key axis.\n",
    "\n",
    "    In this notebook the inputs are [N, L, d] tensors and the result is [N, L, d].\n",
    "    \"\"\"\n",
    "    manifold = geoopt.Stereographic(k=kappa, learnable=False)\n",
    "\n",
    "    def lift(t):\n",
    "        # proju at the origin followed by the exponential map at the origin.\n",
    "        tangent = manifold.proju(manifold.origin(t.shape), t)\n",
    "        return manifold.expmap0(tangent, project=True)\n",
    "\n",
    "    k = lift(K)\n",
    "    q = lift(Q)\n",
    "    v = lift(V)\n",
    "\n",
    "    # Broadcasting q over axis 2 and k over axis 1 yields every query-key pair.\n",
    "    pair_dist = manifold.dist(q.unsqueeze(2), k.unsqueeze(1))\n",
    "    weights = F.softmax(-pair_dist, dim=-1).unsqueeze(-1)\n",
    "\n",
    "    # Scale every value point by its attention weight, then map back with logmap0.\n",
    "    scaled = manifold.mobius_scalar_mul(weights, v.unsqueeze(1))\n",
    "    return torch.sum(manifold.logmap0(scaled), dim=2)\n",
    "\n",
    "def rie_polar_decompose(x, kappa, eps=1e-15):\n",
    "    \"\"\"Split points into polar form: norm, unit direction and a curvature-dependent radial term.\n",
    "\n",
    "    Args:\n",
    "        x: floating-point tensor whose last dimension holds the coordinates, e.g. [N, L, d].\n",
    "        kappa: scalar curvature, given as a Python number or a 0-dim tensor.\n",
    "        eps: margin keeping sqrt(-kappa) * r away from 1 before atanh (kappa < 0 only).\n",
    "            Values below the machine epsilon of x.dtype are raised to it.\n",
    "\n",
    "    Returns:\n",
    "        r: norm of x over the last dimension, floored at 1e-15.\n",
    "        u: x / r, same shape as x.\n",
    "        rho: 2 * atanh(sqrt(-kappa) * r) / sqrt(-kappa) for kappa < 0,\n",
    "            2 * atan(sqrt(kappa) * r) / sqrt(kappa) for kappa > 0, and 2 * r for kappa == 0.\n",
    "    \"\"\"\n",
    "    # r = ||x|| over the last dimension; the floor avoids dividing by zero below.\n",
    "    r = torch.norm(x, dim=-1).clamp_min(1e-15)\n",
    "    # u = x / ||x||\n",
    "    u = x / r.unsqueeze(-1)\n",
    "\n",
    "    # as_tensor accepts numbers and tensors alike; torch.tensor would copy-construct\n",
    "    # (and warn) when kappa is already a tensor.\n",
    "    kappa = torch.as_tensor(kappa, dtype=x.dtype, device=x.device)\n",
    "\n",
    "    if kappa < 0:\n",
    "        sqrt_neg_k = torch.sqrt(-kappa)\n",
    "        # 1 - eps rounds to exactly 1 when eps is below the dtype's machine epsilon\n",
    "        # (1e-15 in float32), which would let atanh return inf at the boundary.\n",
    "        margin = max(eps, torch.finfo(x.dtype).eps)\n",
    "        z = (sqrt_neg_k * r).clamp(-1.0 + margin, 1.0 - margin)\n",
    "        rho = 2.0 / sqrt_neg_k * torch.atanh(z)\n",
    "    elif kappa > 0:\n",
    "        sqrt_k = torch.sqrt(kappa)\n",
    "        rho = 2.0 / sqrt_k * torch.atan(sqrt_k * r)\n",
    "    else:  # kappa == 0\n",
    "        rho = 2.0 * r\n",
    "\n",
    "    return r, u, rho\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def Fast_Riemannian_Attention(K,Q,V,kappa):\n",
    "    \"\"\"Distance-based attention computed in closed form from polar coordinates.\n",
    "\n",
    "    K, Q and V are mapped onto the stereographic manifold of curvature ``kappa``\n",
    "    with ``expmap0`` and split by ``rie_polar_decompose`` into a unit direction\n",
    "    ``u`` and a radial term ``rho``. Query-key distances come from the law of\n",
    "    cosines selected by the sign of ``kappa``; the attention weights are the\n",
    "    softmax over keys of the negated distances.\n",
    "\n",
    "    Args:\n",
    "        K, Q, V: tensors of shape [N, L, d].\n",
    "        kappa: scalar curvature, given as a Python number or a 0-dim tensor.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [N, Lq, d]: for each query, the attention-weighted sum\n",
    "        over keys of 0.5 * rho_V * u_V.\n",
    "    \"\"\"\n",
    "    manifold = geoopt.Stereographic(k=kappa, learnable=False)\n",
    "\n",
    "    def to_polar(t):\n",
    "        # Map t onto the manifold, then return its unit direction and radial term.\n",
    "        tangent = manifold.proju(manifold.origin(t.shape), t)\n",
    "        point = manifold.expmap0(tangent, project=True)\n",
    "        _, unit, rho = rie_polar_decompose(point, kappa)\n",
    "        return unit, rho\n",
    "\n",
    "    k_u, k_rho = to_polar(K)\n",
    "    q_u, q_rho = to_polar(Q)\n",
    "    v_u, v_rho = to_polar(V)\n",
    "\n",
    "    if not torch.is_tensor(kappa):\n",
    "        kappa = torch.tensor(kappa, dtype=q_u.dtype, device=q_u.device)\n",
    "\n",
    "    # ----------------------------------------------------------\n",
    "    # 1. Angular part: cos(theta_ij) = u_Q_i . u_K_j\n",
    "    # ----------------------------------------------------------\n",
    "    # [N, Lq, d] @ [N, d, Lk] -> [N, Lq, Lk]; the batched matmul avoids\n",
    "    # materializing the [N, Lq, Lk, d] tensor of element-wise products.\n",
    "    dot = torch.matmul(q_u, k_u.transpose(-1, -2))\n",
    "    dot = dot.clamp(-1.0 + 1e-6, 1.0 - 1e-6)\n",
    "\n",
    "    # ----------------------------------------------------------\n",
    "    # 2. Radial part: expand rho so it broadcasts over query-key pairs\n",
    "    # ----------------------------------------------------------\n",
    "    rho_Q_exp = q_rho.unsqueeze(-1)  # [N, Lq, 1]\n",
    "    rho_K_exp = k_rho.unsqueeze(1)   # [N, 1, Lk]\n",
    "\n",
    "    if kappa < 0:\n",
    "        sqrt_neg_k = torch.sqrt(-kappa)\n",
    "        # Cap the arguments at 20 before they are fed to cosh/sinh.\n",
    "        arg_i = (sqrt_neg_k * rho_Q_exp).clamp_max(20.0)\n",
    "        arg_j = (sqrt_neg_k * rho_K_exp).clamp_max(20.0)\n",
    "        cosh_i = torch.cosh(arg_i)\n",
    "        cosh_j = torch.cosh(arg_j)\n",
    "        sinh_i = torch.sinh(arg_i)\n",
    "        sinh_j = torch.sinh(arg_j)\n",
    "\n",
    "        # ----------------------------------------------------------\n",
    "        # 3. Hyperbolic law of cosines\n",
    "        # ----------------------------------------------------------\n",
    "        cosh_d = cosh_i * cosh_j - sinh_i * sinh_j * dot  # [N, Lq, Lk]\n",
    "        # Keep the argument strictly inside the domain of acosh.\n",
    "        cosh_d = cosh_d.clamp_min(1.0 + 1e-6)\n",
    "        distance = torch.acosh(cosh_d) / sqrt_neg_k\n",
    "\n",
    "    elif kappa == 0:\n",
    "        # Euclidean law of cosines; the small offset keeps sqrt away from zero.\n",
    "        d2 = rho_Q_exp**2 + rho_K_exp**2 - 2 * rho_Q_exp * rho_K_exp * dot\n",
    "        distance = torch.sqrt(d2.clamp_min(0.0) + 1e-12)\n",
    "\n",
    "    else:  # kappa > 0\n",
    "        sqrt_k = torch.sqrt(kappa)\n",
    "\n",
    "        cos_i = torch.cos(sqrt_k * rho_Q_exp)\n",
    "        cos_j = torch.cos(sqrt_k * rho_K_exp)\n",
    "        sin_i = torch.sin(sqrt_k * rho_Q_exp)\n",
    "        sin_j = torch.sin(sqrt_k * rho_K_exp)\n",
    "\n",
    "        # Spherical law of cosines.\n",
    "        cos_d = cos_i * cos_j + sin_i * sin_j * dot\n",
    "        # Guard against values that fall numerically outside [-1, 1].\n",
    "        cos_d = torch.clamp(cos_d, -1.0 + 1e-6, 1.0 - 1e-6)\n",
    "        distance = torch.acos(cos_d) / sqrt_k\n",
    "\n",
    "    # Replace NaN / infinite distances by a large finite value before the softmax.\n",
    "    distance = torch.nan_to_num(distance, nan=1e6, posinf=1e6, neginf=1e6)\n",
    "    attn = F.softmax(-distance, dim=-1)  # [N, Lq, Lk]\n",
    "\n",
    "    # Attention-weighted sum over keys of 0.5 * rho_V * u_V, written as a batched\n",
    "    # matmul so that no [N, Lq, Lk, d] intermediate is created.\n",
    "    values = 0.5 * v_rho.unsqueeze(-1) * v_u  # [N, Lk, d]\n",
    "    return torch.matmul(attn, values)  # [N, Lq, d]\n",
    "    \n",
    "    \n",
    "# Time both implementations on the same inputs for negative and positive curvature,\n",
    "# then report how far the fast result deviates from the reference.\n",
    "for kappa_ori in (-1, 1):\n",
    "    # perf_counter is a monotonic high-resolution clock meant for measuring intervals;\n",
    "    # time.time() follows the wall clock and can jump.\n",
    "    time1 = time.perf_counter()\n",
    "    scores1 = Riemannian_attention(K,Q,V,kappa = kappa_ori)\n",
    "    time2 = time.perf_counter()\n",
    "    scores2 = Fast_Riemannian_Attention(K,Q,V,kappa = kappa_ori)\n",
    "    time3 = time.perf_counter()\n",
    "\n",
    "    print('Curvature = {}'.format(kappa_ori))\n",
    "    print('Riemannian_attention time:{}'.format(time2-time1))\n",
    "    print('Fast_Riemannian_Attention time:{}'.format(time3-time2))\n",
    "\n",
    "    # Element-wise discrepancy, reported as worst case and as mean absolute error.\n",
    "    abs_err = (scores1 - scores2).abs()\n",
    "    print(\"Maximum error =\", abs_err.max().item())\n",
    "    mae = abs_err.mean().item()\n",
    "    print(\"MAE =\", mae)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7ff262c0-23e6-4a4a-8388-9977e76d26d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum error = 7.867813110351562e-06\n",
      "MAE = 4.518415721577185e-07\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import geoopt\n",
    "from geoopt.manifolds.stereographic.math import project\n",
    "from geoopt.manifolds.stereographic import StereographicExact\n",
    "from geoopt import ManifoldTensor\n",
    "from geoopt import ManifoldParameter\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "\n",
    "\n",
    "\n",
    "# Configuration for the zero-curvature check: RNG seed, batch size N, sequence length L, feature dimension d.\n",
    "seed, N, L, d = 0, 100, 32, 200\n",
    "\n",
    "# Seed before sampling so that every run draws the same tensors.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Standard-normal inputs of shape [N, L, d], drawn in the order K, Q, V.\n",
    "K, Q, V = (torch.randn(N, L, d) for _ in range(3))\n",
    "\n",
    "\n",
    "def Euclidean_method(K, Q, V):\n",
    "    \"\"\"Distance-based attention in flat space, with distance 2 * ||q - k||.\n",
    "\n",
    "    Args:\n",
    "        K, Q, V: tensors of shape [N, L, d].\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [N, L, d]: softmax(-distance) over keys applied to V.\n",
    "    \"\"\"\n",
    "    # Distance used for the kappa = 0 stereographic model: 2 * ||q - k||.\n",
    "    pairwise_gap = Q.unsqueeze(2) - K.unsqueeze(1)      # [N, L, L, d]\n",
    "    scaled_dist = 2 * pairwise_gap.norm(dim=-1)         # [N, L, L]\n",
    "\n",
    "    # Closer keys receive larger weights.\n",
    "    weights = torch.softmax(-scaled_dist, dim=-1)\n",
    "\n",
    "    return torch.bmm(weights, V)                        # [N, L, d]\n",
    "    \n",
    "    \n",
    "def Fast_Riemannian_Attention(K,Q,V,kappa):\n",
    "    \"\"\"Distance-based attention computed in closed form from polar coordinates.\n",
    "\n",
    "    K, Q and V are mapped onto the stereographic manifold of curvature ``kappa``\n",
    "    with ``expmap0`` and split by ``rie_polar_decompose`` into a unit direction\n",
    "    ``u`` and a radial term ``rho``. Query-key distances come from the law of\n",
    "    cosines selected by the sign of ``kappa``; the attention weights are the\n",
    "    softmax over keys of the negated distances.\n",
    "\n",
    "    Args:\n",
    "        K, Q, V: tensors of shape [N, L, d].\n",
    "        kappa: scalar curvature, given as a Python number or a 0-dim tensor.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [N, Lq, d]: for each query, the attention-weighted sum\n",
    "        over keys of 0.5 * rho_V * u_V.\n",
    "    \"\"\"\n",
    "    manifold = geoopt.Stereographic(k=kappa, learnable=False)\n",
    "\n",
    "    def to_polar(t):\n",
    "        # Map t onto the manifold, then return its unit direction and radial term.\n",
    "        tangent = manifold.proju(manifold.origin(t.shape), t)\n",
    "        point = manifold.expmap0(tangent, project=True)\n",
    "        _, unit, rho = rie_polar_decompose(point, kappa)\n",
    "        return unit, rho\n",
    "\n",
    "    k_u, k_rho = to_polar(K)\n",
    "    q_u, q_rho = to_polar(Q)\n",
    "    v_u, v_rho = to_polar(V)\n",
    "\n",
    "    if not torch.is_tensor(kappa):\n",
    "        kappa = torch.tensor(kappa, dtype=q_u.dtype, device=q_u.device)\n",
    "\n",
    "    # ----------------------------------------------------------\n",
    "    # 1. Angular part: cos(theta_ij) = u_Q_i . u_K_j\n",
    "    # ----------------------------------------------------------\n",
    "    # [N, Lq, d] @ [N, d, Lk] -> [N, Lq, Lk]; the batched matmul avoids\n",
    "    # materializing the [N, Lq, Lk, d] tensor of element-wise products.\n",
    "    dot = torch.matmul(q_u, k_u.transpose(-1, -2))\n",
    "    dot = dot.clamp(-1.0 + 1e-6, 1.0 - 1e-6)\n",
    "\n",
    "    # ----------------------------------------------------------\n",
    "    # 2. Radial part: expand rho so it broadcasts over query-key pairs\n",
    "    # ----------------------------------------------------------\n",
    "    rho_Q_exp = q_rho.unsqueeze(-1)  # [N, Lq, 1]\n",
    "    rho_K_exp = k_rho.unsqueeze(1)   # [N, 1, Lk]\n",
    "\n",
    "    if kappa < 0:\n",
    "        sqrt_neg_k = torch.sqrt(-kappa)\n",
    "        # Cap the arguments at 20 before they are fed to cosh/sinh.\n",
    "        arg_i = (sqrt_neg_k * rho_Q_exp).clamp_max(20.0)\n",
    "        arg_j = (sqrt_neg_k * rho_K_exp).clamp_max(20.0)\n",
    "        cosh_i = torch.cosh(arg_i)\n",
    "        cosh_j = torch.cosh(arg_j)\n",
    "        sinh_i = torch.sinh(arg_i)\n",
    "        sinh_j = torch.sinh(arg_j)\n",
    "\n",
    "        # ----------------------------------------------------------\n",
    "        # 3. Hyperbolic law of cosines\n",
    "        # ----------------------------------------------------------\n",
    "        cosh_d = cosh_i * cosh_j - sinh_i * sinh_j * dot  # [N, Lq, Lk]\n",
    "        # Keep the argument strictly inside the domain of acosh.\n",
    "        cosh_d = cosh_d.clamp_min(1.0 + 1e-6)\n",
    "        distance = torch.acosh(cosh_d) / sqrt_neg_k\n",
    "\n",
    "    elif kappa == 0:\n",
    "        # Euclidean law of cosines; the small offset keeps sqrt away from zero.\n",
    "        d2 = rho_Q_exp**2 + rho_K_exp**2 - 2 * rho_Q_exp * rho_K_exp * dot\n",
    "        distance = torch.sqrt(d2.clamp_min(0.0) + 1e-12)\n",
    "\n",
    "    else:  # kappa > 0\n",
    "        sqrt_k = torch.sqrt(kappa)\n",
    "\n",
    "        cos_i = torch.cos(sqrt_k * rho_Q_exp)\n",
    "        cos_j = torch.cos(sqrt_k * rho_K_exp)\n",
    "        sin_i = torch.sin(sqrt_k * rho_Q_exp)\n",
    "        sin_j = torch.sin(sqrt_k * rho_K_exp)\n",
    "\n",
    "        # Spherical law of cosines.\n",
    "        cos_d = cos_i * cos_j + sin_i * sin_j * dot\n",
    "        # Guard against values that fall numerically outside [-1, 1].\n",
    "        cos_d = torch.clamp(cos_d, -1.0 + 1e-6, 1.0 - 1e-6)\n",
    "        distance = torch.acos(cos_d) / sqrt_k\n",
    "\n",
    "    # Replace NaN / infinite distances by a large finite value before the softmax.\n",
    "    distance = torch.nan_to_num(distance, nan=1e6, posinf=1e6, neginf=1e6)\n",
    "    attn = F.softmax(-distance, dim=-1)  # [N, Lq, Lk]\n",
    "\n",
    "    # Attention-weighted sum over keys of 0.5 * rho_V * u_V, written as a batched\n",
    "    # matmul so that no [N, Lq, Lk, d] intermediate is created.\n",
    "    values = 0.5 * v_rho.unsqueeze(-1) * v_u  # [N, Lk, d]\n",
    "    return torch.matmul(attn, values)  # [N, Lq, d]\n",
    "    \n",
    "# Zero curvature: compare the fast path against plain Euclidean distance attention.\n",
    "kappa_ori = 0\n",
    "scores1 = Euclidean_method(K, Q, V)\n",
    "scores2 = Fast_Riemannian_Attention(K, Q, V, kappa=kappa_ori)\n",
    "\n",
    "# Element-wise discrepancy, reported as worst case and as mean absolute error.\n",
    "print(\"Maximum error =\", torch.max(torch.abs(scores1 - scores2)).item())\n",
    "mae = torch.mean(torch.abs(scores1 - scores2)).item()\n",
    "print(\"MAE =\", mae)\n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformers",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
