{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 429,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from e3nn_jax import SO3Signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def compute_Z(lambd: float, F: jnp.ndarray) -> float:\n",
    "    \"\"\"Return the partition function Z, evaluated by quadrature over SO(3).\n",
    "\n",
    "    Z is the integral of exp(tr(lambd * F @ R)) over rotations R, computed with\n",
    "    SO3Signal on a Gauss-Legendre grid with res_beta = res_alpha = res_theta = 100.\n",
    "    NOTE(review): the normalization of the SO(3) measure is whatever\n",
    "    SO3Signal.integrate uses - confirm against the e3nn-jax docs.\n",
    "    \"\"\"\n",
    "\n",
    "    def boltzmann_weight(R):\n",
    "        return jnp.exp(jnp.trace(lambd * F @ R))\n",
    "\n",
    "    grid = dict(res_beta=100, res_alpha=100, res_theta=100, quadrature=\"gausslegendre\")\n",
    "    return SO3Signal.from_function(boltzmann_weight, **grid).integrate()\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def compute_RZ(lambd: float, F: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"Return the unnormalized first moment of R, i.e. Z * E[R], by quadrature.\n",
    "\n",
    "    Integrates the matrix-valued function exp(tr(lambd * F @ R)) * R over SO(3)\n",
    "    on the same quadrature grid as compute_Z (res_beta = res_alpha = res_theta = 100,\n",
    "    Gauss-Legendre), so dividing the result by compute_Z(lambd, F) gives the expected\n",
    "    rotation under the density proportional to exp(tr(lambd * F @ R)).\n",
    "    Jitted like compute_Z so repeated calls with same-shaped F reuse the compiled code.\n",
    "\n",
    "    Returns:\n",
    "        A 3x3 array (same shape as R).\n",
    "    \"\"\"\n",
    "    sig = SO3Signal.from_function(\n",
    "        lambda R: jnp.exp(jnp.trace(lambd * F @ R)) * R,\n",
    "        res_beta=100,\n",
    "        res_alpha=100,\n",
    "        res_theta=100,\n",
    "        quadrature=\"gausslegendre\",\n",
    "    )\n",
    "    RZ = sig.integrate()\n",
    "    return RZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_order_correction_diag(S: jnp.ndarray, lambd: float = 1):\n",
    "    s1, s2, s3 = S * lambd\n",
    "    c1 = 1 / (s1 + s2) + 1 / (s1 + s3)\n",
    "    c2 = 1 / (s2 + s1) + 1 / (s2 + s3)\n",
    "    c3 = 1 / (s3 + s1) + 1 / (s3 + s2)\n",
    "    return -jnp.array([c1, c2, c3]) / 2\n",
    "\n",
    "\n",
    "def first_order_matrix(S: jnp.ndarray, lambd: float = 1):\n",
    "    \"\"\"Diagonal 3x3 matrix I + diag(first_order_correction_diag(S, lambd)).\"\"\"\n",
    "    correction = jnp.diag(first_order_correction_diag(S, lambd))\n",
    "    return correction + jnp.eye(3)\n",
    "\n",
    "\n",
    "def second_order_correction_diag(S: jnp.ndarray, lambd: float = 1):\n",
    "    s1, s2, s3 = S * lambd\n",
    "    c1 = 1 / (s1 + s2) ** 2 + 1 / (s1 + s3) ** 2\n",
    "    c2 = 1 / (s2 + s1) ** 2 + 1 / (s2 + s3) ** 2\n",
    "    c3 = 1 / (s3 + s1) ** 2 + 1 / (s3 + s2) ** 2\n",
    "    return -jnp.array([c1, c2, c3]) / 8\n",
    "\n",
    "\n",
    "def second_order_matrix(S: jnp.ndarray, lambd: float = 1):\n",
    "    \"\"\"Diagonal 3x3 matrix: first_order_matrix plus the second-order diagonal correction.\"\"\"\n",
    "    first = first_order_matrix(S, lambd)\n",
    "    return first + jnp.diag(second_order_correction_diag(S, lambd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 432,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array([[-0.79156095,  0.00976491,  0.61101246],\n",
       "        [-0.29899886, -0.8782002 , -0.37331507],\n",
       "        [ 0.53294563, -0.4781936 ,  0.69806856]], dtype=float32),\n",
       " Array([1.8923466 , 1.5216088 , 0.47148645], dtype=float32),\n",
       " Array([[-0.16220267, -0.7547595 ,  0.6356323 ],\n",
       "        [-0.87639874,  0.40620592,  0.25869286],\n",
       "        [ 0.45344856,  0.5151066 ,  0.7273579 ]], dtype=float32),\n",
       " Array(1.0000001, dtype=float32),\n",
       " Array(-1.0000001, dtype=float32))"
      ]
     },
     "execution_count": 432,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lambd = 1\n",
    "F = jax.random.normal(jax.random.PRNGKey(2), shape=(3, 3))\n",
    "U, S, W = jnp.linalg.svd(F)\n",
    "U, S, W, jnp.linalg.det(U), jnp.linalg.det(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2300029\n",
      "2.230003\n",
      "2.2300022\n",
      "2.2300026\n",
      "2.2300014\n",
      "2.2300014\n",
      "2.230002\n",
      "2.2300024\n",
      "2.2300029\n",
      "2.2300036\n",
      "[[ 0.01545128  0.79341394  0.14040026]\n",
      " [ 0.5156793   0.01212729 -0.752028  ]\n",
      " [-0.72449064 -0.289583    0.04557032]]\n",
      "[[ 0.01545143  0.7934138   0.1404004 ]\n",
      " [ 0.5156792   0.01212744 -0.7520278 ]\n",
      " [-0.7244905  -0.28958294  0.04557034]]\n",
      "[[ 0.01545135  0.7934146   0.1404005 ]\n",
      " [ 0.51567966  0.0121274  -0.7520285 ]\n",
      " [-0.72449136 -0.2895829   0.04557026]]\n",
      "[[ 0.01545119  0.7934147   0.14040056]\n",
      " [ 0.5156792   0.01212743 -0.7520285 ]\n",
      " [-0.724491   -0.28958303  0.04557035]]\n",
      "[[ 0.01545116  0.79341525  0.14040065]\n",
      " [ 0.5156797   0.01212742 -0.7520292 ]\n",
      " [-0.7244919  -0.28958336  0.04557042]]\n",
      "[[ 0.0154511   0.79341435  0.14040054]\n",
      " [ 0.5156793   0.01212731 -0.7520284 ]\n",
      " [-0.72449094 -0.28958297  0.04557046]]\n",
      "[[ 0.01545097  0.7934151   0.14040042]\n",
      " [ 0.5156797   0.01212737 -0.75202894]\n",
      " [-0.7244915  -0.2895833   0.04557047]]\n",
      "[[ 0.0154513   0.7934138   0.14040026]\n",
      " [ 0.5156788   0.01212736 -0.75202763]\n",
      " [-0.7244902  -0.2895829   0.04557034]]\n",
      "[[ 0.01545115  0.7934146   0.14040045]\n",
      " [ 0.5156796   0.01212736 -0.7520287 ]\n",
      " [-0.72449124 -0.28958306  0.04557042]]\n",
      "[[ 0.01545127  0.7934142   0.14040051]\n",
      " [ 0.51567936  0.01212735 -0.7520283 ]\n",
      " [-0.7244908  -0.28958303  0.04557028]]\n"
     ]
    }
   ],
   "source": [
    "# Sanity checks for rotations R1, R2 in SO(3):\n",
    "#   Z is invariant:    Z(F @ R2) == Z(F)\n",
    "#   RZ is equivariant: R2 @ RZ(R1 @ F @ R2) @ R1 == RZ(F)\n",
    "# Printed deviations should be at float32 round-off / quadrature-error level.\n",
    "Z_ref = compute_Z(lambd, F)\n",
    "RZ_ref = compute_RZ(lambd, F)\n",
    "\n",
    "\n",
    "def random_rotation(key):\n",
    "    \"\"\"Sample an orthogonal 3x3 matrix and flip its sign if needed so that det == +1.\"\"\"\n",
    "    Q = jax.random.orthogonal(key, 3)\n",
    "    # In odd dimension det(-Q) == -det(Q), so multiplying by sign(det Q) lands in SO(3).\n",
    "    # sign() avoids rescaling Q by a float determinant such as -1.0000001.\n",
    "    return Q * jnp.sign(jnp.linalg.det(Q))\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "    R_dash = random_rotation(jax.random.PRNGKey(i * 10))\n",
    "    Z_dash = compute_Z(lambd, F @ R_dash)\n",
    "    print(Z_dash, \"abs deviation:\", jnp.abs(Z_dash - Z_ref))\n",
    "\n",
    "for i in range(10):\n",
    "    # Split one key into two independent subkeys instead of doing arithmetic on a key.\n",
    "    key1, key2 = jax.random.split(jax.random.PRNGKey(i * 2))\n",
    "    R_dash1 = random_rotation(key1)\n",
    "    R_dash2 = random_rotation(key2)\n",
    "    RZ_dash = R_dash2 @ compute_RZ(lambd, R_dash1 @ F @ R_dash2) @ R_dash1\n",
    "    print(\"max abs deviation:\", jnp.max(jnp.abs(RZ_dash - RZ_ref)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 438,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array([[-0.79156095,  0.00976491,  0.61101246],\n",
       "        [-0.29899886, -0.8782002 , -0.37331507],\n",
       "        [ 0.53294563, -0.4781936 ,  0.69806856]], dtype=float32),\n",
       " Array([1.8923466 , 1.5216088 , 0.47148645], dtype=float32),\n",
       " Array([[-0.16220267, -0.7547595 ,  0.6356323 ],\n",
       "        [-0.87639874,  0.40620592,  0.25869286],\n",
       "        [ 0.45344856,  0.5151066 ,  0.7273579 ]], dtype=float32),\n",
       " Array(1.0000001, dtype=float32),\n",
       " Array(-1.0000001, dtype=float32))"
      ]
     },
     "execution_count": 438,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lambd = 1\n",
    "F = jax.random.normal(jax.random.PRNGKey(2), shape=(3, 3))\n",
    "# F = F*lambd\n",
    "# lambd = 1\n",
    "U, S, W = jnp.linalg.svd(F)\n",
    "U, S, W, jnp.linalg.det(U), jnp.linalg.det(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array([[-0.1572274 ,  0.98743117,  0.01610483],\n",
       "        [ 0.28666812,  0.06123919, -0.95607066],\n",
       "        [-0.9450404 , -0.14570373, -0.29269356]], dtype=float32),\n",
       " Array([[ 0.01353253,  0.34391186,  0.06036871],\n",
       "        [ 0.24707372,  0.01164275, -0.3369285 ],\n",
       "        [-0.3278697 , -0.13438539,  0.03593323]], dtype=float32),\n",
       " Array([[ 0.05382817,  0.21522303,  0.07012825],\n",
       "        [ 0.25834626,  0.00580794, -0.22058417],\n",
       "        [-0.21374026, -0.13998269,  0.11566643]], dtype=float32),\n",
       " Array([[ 0.0069288 ,  0.355791  ,  0.06295981],\n",
       "        [ 0.23124614,  0.00543828, -0.33723226],\n",
       "        [-0.32488358, -0.12985776,  0.02043513]], dtype=float32))"
      ]
     },
     "execution_count": 439,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form estimates of E[R] vs. the quadrature result, with F == U @ diag(S) @ W.\n",
    "# The sign of det(U @ W) is folded into the last singular value so every estimate has\n",
    "# det == +1. jnp.sign snaps the float determinant (e.g. -1.0000001) to exactly +/-1,\n",
    "# so it does not rescale the third singular value or the estimates.\n",
    "det_sign = jnp.sign(jnp.linalg.det(U @ W))\n",
    "signs = jnp.array([1.0, 1.0, det_sign])\n",
    "signed_I = jnp.diag(signs)\n",
    "S_signed = S * signs  # elementwise; equals S @ signed_I for the 1-D S\n",
    "vanilla_R = W.T @ signed_I @ U.T\n",
    "first_order = first_order_matrix(S_signed, lambd)\n",
    "second_order = second_order_matrix(S_signed, lambd)\n",
    "spicy1_R = W.T @ first_order @ signed_I @ U.T\n",
    "spicy2_R = W.T @ second_order @ signed_I @ U.T\n",
    "\n",
    "Z = compute_Z(lambd, F)\n",
    "RZ = compute_RZ(lambd, F)\n",
    "numeric_expected_R = RZ / Z\n",
    "\n",
    "vanilla_R, spicy1_R, spicy2_R, numeric_expected_R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 440,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array([[ 1.00000012e+00,  8.63127880e-08, -1.89471848e-07],\n",
       "        [ 1.00896045e-07,  1.00000012e+00,  4.79464646e-09],\n",
       "        [ 2.31079980e-07, -4.35952288e-08, -1.00000012e+00]],      dtype=float32),\n",
       " Array([[ 5.0164282e-01,  3.1422992e-08, -1.0486684e-07],\n",
       "        [ 1.7235038e-08,  3.7740722e-01, -4.7247202e-08],\n",
       "        [ 3.2465948e-08, -9.5322443e-09, -1.7196546e-01]], dtype=float32),\n",
       " Array([[ 4.2900121e-01,  4.3695447e-08, -7.2269444e-08],\n",
       "        [ 9.8090283e-09,  2.5333005e-01, -2.3294056e-08],\n",
       "        [-1.7784902e-08, -8.9003498e-09,  3.3035157e-03]], dtype=float32),\n",
       " Array([[ 4.8280004e-01,  6.4708168e-09, -5.3772574e-08],\n",
       "        [ 3.6206586e-08,  3.9079815e-01, -7.7078312e-08],\n",
       "        [-3.4654221e-08, -5.1178727e-08, -1.8665485e-01]], dtype=float32))"
      ]
     },
     "execution_count": 440,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rotate each estimate into the SVD frame of F; the closed-form estimates\n",
    "# (built as W.T @ D @ U.T) come out as their diagonal D.\n",
    "estimates = (vanilla_R, spicy1_R, spicy2_R, numeric_expected_R)\n",
    "tuple(W @ R_est @ U for R_est in estimates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 441,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array(1.1402427, dtype=float32),\n",
       " Array(0.02738877, dtype=float32),\n",
       " Array(0.24057426, dtype=float32))"
      ]
     },
     "execution_count": 441,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Frobenius-norm error of each closed-form estimate against the quadrature result.\n",
    "vanilla_diff, spicy1_diff, spicy2_diff = (\n",
    "    jnp.linalg.norm(R_est - numeric_expected_R) for R_est in (vanilla_R, spicy1_R, spicy2_R)\n",
    ")\n",
    "\n",
    "vanilla_diff, spicy1_diff, spicy2_diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "e3nn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
