{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "import numpy as np\n",
    "from numpy import fabs, dot, cross\n",
    "from numpy.linalg import norm\n",
    "import torch\n",
    "\n",
    "\n",
    "def fix_seed(seed: int = 42) -> None:\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch RNGs of the current process.\n",
    "\n",
    "    Only the calling process is affected; worker processes (e.g. those started\n",
    "    by joblib) keep their own NumPy RNG state and need their own seeding.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    # Only affects interpreters started after this point: hash randomization of\n",
    "    # the running interpreter is fixed at startup.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    # torch.manual_seed seeds the CPU generator and every CUDA device, so no\n",
    "    # separate torch.cuda.manual_seed / manual_seed_all call is needed.\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        # Trade cuDNN autotuning speed for run-to-run determinism.\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "fix_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_G(a, b, c):\n",
    "    \"\"\"Centroid of the tetrahedron (0, a, b, c), relative to the origin vertex.\n",
    "\n",
    "    The origin vertex adds nothing to the sum but still counts in the divisor.\n",
    "    \"\"\"\n",
    "    vertex_sum = a + b + c\n",
    "    return vertex_sum / 4\n",
    "\n",
    "def get_I(a, b, c):\n",
    "    \"\"\"Incentre of the tetrahedron (0, a, b, c), relative to the origin vertex.\n",
    "\n",
    "    Each vertex is weighted by the area of the face opposite it; the common\n",
    "    factor 1/2 of every area cancels, so doubled areas are used throughout.\n",
    "    \"\"\"\n",
    "    area_oab = norm(cross(a, b))\n",
    "    area_obc = norm(cross(b, c))\n",
    "    area_oca = norm(cross(c, a))\n",
    "    # Face ABC, opposite the origin: (b - a) x (c - a) == a x b + b x c + c x a.\n",
    "    area_abc = norm(cross(a, b) + cross(b, c) + cross(c, a))\n",
    "    weighted_sum = area_obc * a + area_oca * b + area_oab * c\n",
    "    # 1e-6 keeps the division finite for a fully collapsed tetrahedron.\n",
    "    total_area = area_oab + area_obc + area_oca + area_abc + 1e-6\n",
    "    return weighted_sum / total_area\n",
    "    \n",
    "def get_O(a, b, c, eps=1e-6):\n",
    "    \"\"\"Circumcentre of the tetrahedron (0, a, b, c), relative to the origin vertex.\n",
    "\n",
    "    Solves 2 v . x = |v|^2 for v in {a, b, c}, whose closed form is\n",
    "        x = (|a|^2 (b x c) + |b|^2 (c x a) + |c|^2 (a x b)) / (2 a . (b x c)).\n",
    "    `eps` keeps the division finite for a (near-)flat tetrahedron.\n",
    "    \"\"\"\n",
    "    numerator = (\n",
    "        dot(a, a) * cross(b, c)\n",
    "        + dot(b, b) * cross(c, a)\n",
    "        + dot(c, c) * cross(a, b)\n",
    "    )\n",
    "    triple = 2 * dot(a, cross(b, c))\n",
    "    # The triple product is signed (negative for a left-handed vertex order);\n",
    "    # add eps with the same sign so it pushes the denominator away from zero\n",
    "    # instead of cancelling it.\n",
    "    denominator = triple + np.copysign(eps, triple)\n",
    "    return numerator / denominator\n",
    "\n",
    "def get_M(a, b, c, eps=1e-6):\n",
    "    \"\"\"Monge point of the tetrahedron (0, a, b, c), relative to the origin vertex.\n",
    "\n",
    "    Closed form of (a + b + c) / 2 - circumcentre, i.e. the reflection of the\n",
    "    circumcentre through the centroid:\n",
    "        M = sum_cyc (a . (b + c)) (b x c) / (2 a . (b x c)).\n",
    "    `eps` keeps the division finite for a (near-)flat tetrahedron.\n",
    "    \"\"\"\n",
    "    numerator = (\n",
    "        dot(a, b + c) * cross(b, c)\n",
    "        + dot(b, c + a) * cross(c, a)\n",
    "        + dot(c, a + b) * cross(a, b)\n",
    "    )\n",
    "    triple = 2 * dot(a, cross(b, c))\n",
    "    # The triple product is signed (negative for a left-handed vertex order);\n",
    "    # add eps with the same sign so it pushes the denominator away from zero\n",
    "    # instead of cancelling it.\n",
    "    denominator = triple + np.copysign(eps, triple)\n",
    "    return numerator / denominator\n",
    "\n",
    "def get_T(a, b, c):\n",
    "    \"\"\"Point one third of the way from the Monge point to the circumcentre.\n",
    "\n",
    "    T = O / 3 + 2 M / 3 (the twelve-point-sphere centre), relative to the\n",
    "    origin vertex like the helpers it combines.\n",
    "    \"\"\"\n",
    "    circumcenter = get_O(a, b, c)\n",
    "    monge_point = get_M(a, b, c)\n",
    "    return (1 / 3) * circumcenter + (2 / 3) * monge_point\n",
    "\n",
    "def get_pos():\n",
    "    \"\"\"Sample the four vertices of a random tetrahedron as a (4, 3) array.\n",
    "\n",
    "    Directions are drawn uniformly on the unit sphere, re-centred on their mean,\n",
    "    scaled by a radius in [1, 6) and shifted by a random 3-D translation.\n",
    "    \"\"\"\n",
    "    # (U[0, 1) + 0.2) * 5 lies in [1, 6); the cap at 6 is kept as a safety net.\n",
    "    radius = min(6.0, (np.random.rand() + 0.2) * 5)\n",
    "    pos = np.random.randn(4, 3)\n",
    "    pos = pos / norm(pos, axis=1, keepdims=True)\n",
    "    pos = pos - np.mean(pos, axis=0, keepdims=True)\n",
    "    pos = pos * radius\n",
    "    # One independent offset per axis (shape (1, 3)); a shape-(1,) draw would\n",
    "    # broadcast the same scalar onto x, y and z, shifting only along (1, 1, 1).\n",
    "    translation = np.random.randn(1, 3)\n",
    "    return pos + translation\n",
    "    \n",
    "\n",
    "def get_tetrahedron():\n",
    "    \"\"\"Sample one tetrahedron and return its vertices plus its special points.\n",
    "\n",
    "    The get_* helpers work on edge vectors from the first vertex, so each\n",
    "    result is shifted back by that vertex into absolute coordinates.\n",
    "    \"\"\"\n",
    "    pos = get_pos()\n",
    "    origin, vertex_a, vertex_b, vertex_c = pos\n",
    "    edges = (vertex_a - origin, vertex_b - origin, vertex_c - origin)\n",
    "    centers = (('G', get_G), ('I', get_I), ('O', get_O), ('M', get_M), ('T', get_T))\n",
    "    sample = {'pos': pos}\n",
    "    for key, center_fn in centers:\n",
    "        sample[key] = center_fn(*edges) + origin\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:00<00:00, 4391.40it/s]\n",
      "100%|██████████| 2000/2000 [00:00<00:00, 6899.74it/s]\n",
      "100%|██████████| 2000/2000 [00:00<00:00, 6281.29it/s]\n"
     ]
    }
   ],
   "source": [
    "def _seeded_tetrahedron(base_seed, index):\n",
    "    # Runs inside a joblib worker: reseed that process's global NumPy RNG so the\n",
    "    # sample depends only on (base_seed, index), not on which worker runs it.\n",
    "    np.random.seed([base_seed, index])\n",
    "    return get_tetrahedron()\n",
    "\n",
    "\n",
    "def generate_dataset(num_tetrahedron, n_jobs=64):\n",
    "    \"\"\"Generate `num_tetrahedron` samples from `get_tetrahedron` in parallel.\n",
    "\n",
    "    Worker processes do not share the RNG state set by `fix_seed`, so each task\n",
    "    receives an explicit seed derived from the parent's NumPy RNG; the dataset\n",
    "    is therefore reproducible for a fixed parent seed.\n",
    "    \"\"\"\n",
    "    base_seed = int(np.random.randint(0, 2**31 - 1))\n",
    "    tasks = (\n",
    "        delayed(_seeded_tetrahedron)(base_seed, i)\n",
    "        for i in tqdm(range(num_tetrahedron))\n",
    "    )\n",
    "    return Parallel(n_jobs=n_jobs)(tasks)\n",
    "\n",
    "\n",
    "os.makedirs('./data', exist_ok=True)\n",
    "# Each split is a list of dicts, saved as an object array: load it back with\n",
    "# np.load(path, allow_pickle=True).\n",
    "np.save('./data/train.npy', generate_dataset(500))\n",
    "np.save('./data/valid.npy', generate_dataset(2000))\n",
    "np.save('./data/test.npy', generate_dataset(2000))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aesc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
