{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sVZTMxtb8qdf"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import plotly.graph_objects as go\n",
        "import plotly.figure_factory as ff\n",
        "\n",
        "from skimage.measure import marching_cubes\n",
        "from scipy.interpolate import griddata\n",
        "from IPython.core.display import clear_output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n-Vttl5kIYSB"
      },
      "outputs": [],
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bxq7aV8qCkTO"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xPCXAnYrrxXG"
      },
      "source": [
        "## Utils\n",
        "\n",
        "To be used with the shapes from the **Shapes** section."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y4MN8oXxIYlU"
      },
      "outputs": [],
      "source": [
        "def dist(coords, shape_samples):\n",
        "  coords_expanded = coords.unsqueeze(1)\n",
        "  shape_expanded = shape_samples.unsqueeze(0)\n",
        "\n",
        "  distance_squared = torch.sum((coords_expanded - shape_expanded) ** 2, dim=-1)\n",
        "  min_distance_squared = torch.min(distance_squared, dim=1).values\n",
        "\n",
        "  return torch.sqrt(min_distance_squared)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_lH7pI4LIsL1"
      },
      "outputs": [],
      "source": [
        "def dataset(shape, inside, num_points=10000):\n",
        "  coords = torch.rand([num_points, 3]) * 2 - 1\n",
        "  shape_samples = shape()\n",
        "  distance = dist(coords, shape_samples)\n",
        "\n",
        "  mask = inside(coords)\n",
        "  distance[mask] = -1 * distance[mask]\n",
        "\n",
        "  return coords, distance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7KQuXEYgIwgl"
      },
      "source": [
        "## Shapes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-N97zwj5Ivol"
      },
      "outputs": [],
      "source": [
        "def sphere(resolution=100):\n",
        "    _s, _t = np.linspace(0, 1, resolution), np.linspace(0, 1, resolution)\n",
        "    s, t = np.meshgrid(_s, _t)\n",
        "\n",
        "    x = np.sin(np.pi * s) * np.cos(2 * np.pi * t)\n",
        "    y = np.sin(np.pi * s) * np.sin(2 * np.pi * t)\n",
        "    z = np.cos(np.pi * s)\n",
        "\n",
        "    x, y, z = x.ravel(), y.ravel(), z.ravel()\n",
        "\n",
        "    return torch.from_numpy(np.stack([x, y, z], axis=1)).float()\n",
        "\n",
        "def inside_sphere(coords):\n",
        "  return torch.norm(coords, dim=1) < 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwlQKT63PlEY"
      },
      "source": [
        "## Activations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vxPgp-VrPqKt"
      },
      "outputs": [],
      "source": [
        "class SIREN(nn.Module):\n",
        "    def __init__(self, frequency=1.0):\n",
        "        super(SIREN, self).__init__()\n",
        "        self.frequency=frequency\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.frequency * torch.sin(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "syWu5JakPwBy"
      },
      "outputs": [],
      "source": [
        "class HOSC(nn.Module):\n",
        "    def __init__(self, frequency=1.0, sharpness=1.0):\n",
        "        super(HOSC, self).__init__()\n",
        "        self.frequency=frequency\n",
        "        self.sharpness=sharpness\n",
        "\n",
        "    def forward(self, x):\n",
        "        return torch.tanh(self.sharpness * torch.sin(self.frequency * x))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BuKUBiupJUoA"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ccaIhJBxJVSF"
      },
      "outputs": [],
      "source": [
        "class MLP(nn.Module):\n",
        "    def __init__(self, activation=nn.ReLU()):\n",
        "        super(MLP, self).__init__()\n",
        "\n",
        "        self.layers = nn.Sequential(\n",
        "            nn.Linear(3, 256), activation,\n",
        "            nn.Linear(256, 256),  activation,\n",
        "            nn.Linear(256, 256), activation,\n",
        "            nn.Linear(256, 1)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.layers(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jM0widu3JV2H"
      },
      "source": [
        "## Training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gKF9MTVGJauG"
      },
      "outputs": [],
      "source": [
        "def train(\n",
        "    model,\n",
        "    coords,\n",
        "    distance,\n",
        "    checkpoint_path,\n",
        "    lr=0.001,\n",
        "    num_epochs=250,\n",
        "    batch_size=128,\n",
        "    from_epoch=0,\n",
        "    optimizer_state=None,\n",
        "    lossf=None\n",
        "  ):\n",
        "\n",
        "  optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "\n",
        "  if optimizer_state is not None:\n",
        "    optimizer.load_state_dict(optimizer_state)\n",
        "\n",
        "  if lossf is None:\n",
        "    lossf = nn.MSELoss()\n",
        "\n",
        "  dataset = TensorDataset(coords, distance)\n",
        "  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "  model.train()\n",
        "  epoch_mean_loss = []\n",
        "\n",
        "  for epoch in range(from_epoch, num_epochs):\n",
        "    epoch_loss = 0.0\n",
        "    num_batches = 0\n",
        "    for batch_x, batch_y in dataloader:\n",
        "      optimizer.zero_grad()\n",
        "      output = model(batch_x)\n",
        "      loss = lossf(output.flatten(), batch_y)\n",
        "      loss.backward()\n",
        "      optimizer.step()\n",
        "\n",
        "      epoch_loss = epoch_loss + loss.item()\n",
        "      num_batches = num_batches + 1\n",
        "\n",
        "    epoch_mean_loss.append(epoch_loss / num_batches)\n",
        "    print(f\"Epoch: [{epoch + 1} / {num_epochs}], MSE: {epoch_mean_loss[-1]}\")\n",
        "\n",
        "    torch.save({\n",
        "        \"epoch\": epoch,\n",
        "        \"model_state\": model.state_dict(),\n",
        "        \"optimizer_state\": optimizer.state_dict(),\n",
        "        \"epoch_mean_loss\": epoch_mean_loss,\n",
        "        \"learning_rate\": learning_rate,\n",
        "        \"batch_size\": batch_size\n",
        "    }, checkpoint_path)\n",
        "\n",
        "  return epoch_mean_loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tkT30KBEJbRb"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZoT14xldJeWH"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(model, meshgrid):\n",
        "    X, Y, Z = meshgrid\n",
        "\n",
        "    stack = np.c_[X.ravel(), Y.ravel(), Z.ravel()]\n",
        "    points = torch.FloatTensor(stack).to(device)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        predictions = model(points)\n",
        "    return predictions.cpu()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1cPbMFY3EvL0"
      },
      "source": [
        "## Visualization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gsUx2jp2KDMH"
      },
      "outputs": [],
      "source": [
        "def plot_inr(meshgrid, predictions, precision=0.05):\n",
        "  X, Y, Z = meshgrid\n",
        "\n",
        "  # flatten\n",
        "  _X = X.ravel()\n",
        "  _Y = Y.ravel()\n",
        "  _Z = Z.ravel()\n",
        "  _predictions = predictions.ravel()\n",
        "\n",
        "  mask = np.abs(_predictions) < precision\n",
        "\n",
        "  surface = go.Scatter3d(x=_X[mask], y=_Y[mask], z=_Z[mask])\n",
        "  fig = go.Figure(data=surface)\n",
        "  fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n9a4bwd5XPhW"
      },
      "outputs": [],
      "source": [
        "def plot_surface(predictions, resolution, title=\"\"):\n",
        "  data = predictions.reshape(3 * [resolution]).numpy()\n",
        "  verts, faces, _, _ = marching_cubes(data)\n",
        "\n",
        "  fig = ff.create_trisurf(\n",
        "    x=verts[:, 0],\n",
        "    y=verts[:, 1],\n",
        "    z=verts[:, 2],\n",
        "    simplices =faces,\n",
        "    title=title\n",
        "  )\n",
        "\n",
        "  fig.update_layout(scene=dict(\n",
        "    xaxis=dict(visible=False),\n",
        "    yaxis=dict(visible=False),\n",
        "    zaxis=dict(visible=False)),\n",
        "    coloraxis_showscale=False\n",
        "  )\n",
        "\n",
        "  fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_gt(data_path):\n",
        "  coords, distances = get_sdf_data(data_path)\n",
        "  coords, distances = coords.cpu().numpy(), distances.cpu().numpy()\n",
        "\n",
        "  res = 64 # resolution of the grid\n",
        "  range = 1\n",
        "  x = np.linspace(-range, range, res)\n",
        "  y = np.linspace(-range, range, res)\n",
        "  z = np.linspace(-range, range, res)\n",
        "  meshgrid = np.meshgrid(x, y, z)\n",
        "\n",
        "  _X, _Y, _Z = meshgrid\n",
        "  xi = np.c_[_X.ravel(), _Y.ravel(), _Z.ravel()]\n",
        "\n",
        "  grid = griddata(coords, distances.ravel(), xi, method=\"linear\", fill_value=1)\n",
        "  grid = grid.reshape(3 * [res])\n",
        "\n",
        "  verts, faces, _, _ = marching_cubes(grid)\n",
        "\n",
        "  fig = ff.create_trisurf(\n",
        "    x=verts[:, 0],\n",
        "    y=verts[:, 1],\n",
        "    z=verts[:, 2],\n",
        "    simplices =faces,\n",
        "    title=\"\"\n",
        "  )\n",
        "\n",
        "  fig.update_layout(scene=dict(\n",
        "    xaxis=dict(visible=False),\n",
        "    yaxis=dict(visible=False),\n",
        "    zaxis=dict(visible=False)),\n",
        "    coloraxis_showscale=False\n",
        "  )\n",
        "\n",
        "  fig.show()"
      ],
      "metadata": {
        "id": "VGlBpWmuMJZw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PJjjUEjiEn8z"
      },
      "source": [
        "## Prep data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yqfQFtxuHlYV"
      },
      "outputs": [],
      "source": [
        "def sample_sdf_data(coords, distance, fr=0.1):\n",
        "  assert len(coords) == len(distance)\n",
        "\n",
        "  num_samples = int(fr * len(coords))\n",
        "  idx = torch.randint(0, len(coords), (num_samples, ))\n",
        "\n",
        "  return coords[idx], distance[idx]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4-nnmgo6aG2v"
      },
      "outputs": [],
      "source": [
        "def get_sdf_data(data_path):\n",
        "  # load data\n",
        "  shape = np.load(data_path)\n",
        "\n",
        "  coords = shape[\"position\"]\n",
        "  distance = shape[\"distance\"]\n",
        "\n",
        "  coords = coords[~np.isnan(coords).any(axis=1)]\n",
        "  distance = distance[~np.isnan(distance)]\n",
        "\n",
        "  coords = torch.from_numpy(coords).to(device)\n",
        "  distance = torch.from_numpy(distance).to(device)\n",
        "\n",
        "  return coords, distance"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hqbIVWcPOhNN"
      },
      "outputs": [],
      "source": [
        "def train_3d_sdf(\n",
        "    activation,\n",
        "    data_path,\n",
        "    num_epochs,\n",
        "    learning_rate,\n",
        "    batch_size,\n",
        "    data_sample_fr,\n",
        "    checkpoint_path\n",
        "  ):\n",
        "\n",
        "  # model\n",
        "  model = MLP(activation).to(device)\n",
        "\n",
        "  coords, distance = get_sdf_data(data_path)\n",
        "  coords, distance = sample_sdf_data(coords, distance, fr=data_sample_fr)\n",
        "\n",
        "  # train\n",
        "  epoch_mean_loss = train(model, coords, distance, checkpoint_path=checkpoint_path, num_epochs=num_epochs, lr=learning_rate, batch_size=batch_size)\n",
        "\n",
        "  return model, epoch_mean_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j1geWIWFXUW_"
      },
      "outputs": [],
      "source": [
        "def train_3d_sdf_from_checkpoint(\n",
        "    activation,\n",
        "    data_path,\n",
        "    num_epochs, # overall\n",
        "    data_sample_fr,\n",
        "    checkpoint_path\n",
        "  ):\n",
        "\n",
        "  # load checkpoint\n",
        "  checkpoint = torch.load(checkpoint_path)\n",
        "\n",
        "  # set model and learning hyperparams\n",
        "  model = MLP(activation).to(device)\n",
        "  model.load_state_dict(checkpoint[\"model_state\"])\n",
        "\n",
        "  optimizer_state = checkpoint[\"optimizer_state\"]\n",
        "\n",
        "  learning_rate = checkpoint[\"learning_rate\"]\n",
        "  batch_size = checkpoint[\"batch_size\"]\n",
        "\n",
        "  from_epoch = checkpoint[\"epoch\"]\n",
        "\n",
        "  # data\n",
        "  coords, distance = get_sdf_data(data_path)\n",
        "  coords, distance = sample_sdf_data(coords, distance, fr=data_sample_fr)\n",
        "\n",
        "  epoch_mean_loss = train(\n",
        "      model=model,\n",
        "      coords=coords,\n",
        "      distance=distance,\n",
        "      checkpoint_path=checkpoint_path,\n",
        "      num_epochs=num_epochs,\n",
        "      lr=learning_rate,\n",
        "      batch_size=batch_size,\n",
        "      from_epoch=from_epoch,\n",
        "      optimizer_state=optimizer_state\n",
        "    )\n",
        "\n",
        "  return model, epoch_mean_loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AfRpC_8zCgpR"
      },
      "source": [
        "# Run Experiments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0eIVE9hePQ_X"
      },
      "source": [
        "## Cylinder Experiment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m7sAjRTlB-D0"
      },
      "outputs": [],
      "source": [
        "data_path = \"/content/drive/MyDrive/HOSC/3d-sdf-experiments/Cylinder/Cylinder.npz\" # path to shape\n",
        "\n",
        "num_epochs = 200\n",
        "learning_rate = 0.001\n",
        "batch_size = 256\n",
        "data_sample_fr = 1\n",
        "\n",
        "checkpoint_path = \"/content/drive/MyDrive/HOSC/3d-sdf-experiments/Cylinder/checkpoints/\"\n",
        "\n",
        "def run_cylinder_experiment(activation, checkpoint_path):\n",
        "  return train_3d_sdf(\n",
        "    activation=activation,\n",
        "    data_path=data_path,\n",
        "    num_epochs=num_epochs,\n",
        "    learning_rate=learning_rate,\n",
        "    batch_size=batch_size,\n",
        "    data_sample_fr=data_sample_fr,\n",
        "    checkpoint_path=checkpoint_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "twdsYQHj-Gl2"
      },
      "outputs": [],
      "source": [
        "# marching cubes\n",
        "res = 64 # resolution of the grid\n",
        "range = 1\n",
        "x = np.linspace(-range, range, res)\n",
        "y = np.linspace(-range, range, res)\n",
        "z = np.linspace(-range, range, res)\n",
        "meshgrid = np.meshgrid(x, y, z)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Y48bjKYBOpR"
      },
      "source": [
        "### ReLU"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b1S1vMNtBRTA"
      },
      "outputs": [],
      "source": [
        "model, epoch_mean_loss = run_cylinder_experiment(nn.ReLU(), checkpoint_path + \"relu_model.pth\") # training\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9hhcZ3T8-7ZQ"
      },
      "outputs": [],
      "source": [
        "predictions = evaluate_model(model, meshgrid) # results\n",
        "plot_surface(predictions, res)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pAtZSg1Jm_5b"
      },
      "source": [
        "### SIREN"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OjGbqDVHnBKt"
      },
      "outputs": [],
      "source": [
        "model, epoch_mean_loss = run_cylinder_experiment(SIREN(), checkpoint_path + \"siren_model.pth\") # training\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zlKadLTa_Hoj"
      },
      "outputs": [],
      "source": [
        "predictions = evaluate_model(model, meshgrid) # results\n",
        "plot_surface(predictions, res)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CIreIM3Gm9s4"
      },
      "source": [
        "### HOSC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGrqnImo_KPe"
      },
      "outputs": [],
      "source": [
        "model, epoch_mean_loss = run_cylinder_experiment(HOSC(), checkpoint_path + \"hosc_model.pth\") # training\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9L_5xpEyeFPb"
      },
      "outputs": [],
      "source": [
        "predictions = evaluate_model(model, meshgrid) # results\n",
        "plot_surface(predictions, res)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AnzljY60CQKg"
      },
      "source": [
        "## Icosahedron Experiment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TI10WV7qCWIY"
      },
      "outputs": [],
      "source": [
        "data_path = \"/content/drive/MyDrive/HOSC/3d-sdf-experiments/Icosahedron/Icosahedron.npz\" # path to shape\n",
        "\n",
        "num_epochs = 2000\n",
        "learning_rate = 0.001\n",
        "batch_size = 128\n",
        "data_sample_fr = 1\n",
        "\n",
        "checkpoint_path = \"/content/drive/MyDrive/HOSC/3d-sdf-experiments/Icosahedron/checkpoints/\"\n",
        "\n",
        "def run_icosahedron_experiment(activation, checkpoint_path):\n",
        "  return train_3d_sdf(\n",
        "    activation=activation,\n",
        "    data_path=data_path,\n",
        "    num_epochs=num_epochs,\n",
        "    learning_rate=learning_rate,\n",
        "    batch_size=batch_size,\n",
        "    data_sample_fr=data_sample_fr,\n",
        "    checkpoint_path=checkpoint_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NWyk_F1YP73G"
      },
      "outputs": [],
      "source": [
        "# marching cubes\n",
        "res = 64 # resolution of the grid\n",
        "range = 1\n",
        "x = np.linspace(-range, range, res)\n",
        "y = np.linspace(-range, range, res)\n",
        "z = np.linspace(-range, range, res)\n",
        "meshgrid = np.meshgrid(x, y, z)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ground truth\n",
        "from scipy.interpolate import griddata\n",
        "\n",
        "coords, distances = get_sdf_data(data_path)\n",
        "coords, distances = coords.cpu().numpy(), distances.cpu().numpy()\n",
        "\n",
        "_X, _Y, _Z = meshgrid\n",
        "xi = np.c_[_X.ravel(), _Y.ravel(), _Z.ravel()]\n",
        "\n",
        "grid = griddata(coords, distances.ravel(), xi, method=\"linear\", fill_value=1)\n",
        "grid = grid.reshape(3 * [res])\n",
        "\n",
        "verts, faces, _, _ = marching_cubes(grid)\n",
        "\n",
        "fig = ff.create_trisurf(\n",
        "  x=verts[:, 0],\n",
        "  y=verts[:, 1],\n",
        "  z=verts[:, 2],\n",
        "  simplices =faces,\n",
        "  title=\"\"\n",
        ")\n",
        "\n",
        "fig.update_layout(scene=dict(\n",
        "  xaxis=dict(visible=False),\n",
        "  yaxis=dict(visible=False),\n",
        "  zaxis=dict(visible=False)),\n",
        "  coloraxis_showscale=False\n",
        ")\n",
        "\n",
        "fig.show()"
      ],
      "metadata": {
        "id": "TBdJj7HfwgHJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### SIREN"
      ],
      "metadata": {
        "id": "LVdkDOrVphfO"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model, epoch_mean_loss = run_icosahedron_experiment(SIREN(), checkpoint_path + \"siren_model.pth\") # training\n",
        "clear_output()"
      ],
      "metadata": {
        "id": "FgLrA_u4pk-y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model, epoch_mean_loss = train_3d_sdf_from_checkpoint(\n",
        "    activation=SIREN(),\n",
        "    data_path=data_path,\n",
        "    num_epochs=num_epochs,\n",
        "    data_sample_fr=data_sample_fr,\n",
        "    checkpoint_path=checkpoint_path + \"siren_model.pth\"\n",
        "  )\n",
        "clear_output()"
      ],
      "metadata": {
        "id": "J__HiWLnpp5_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "predictions = evaluate_model(model, meshgrid) # results\n",
        "plot_surface(predictions, res)"
      ],
      "metadata": {
        "id": "Oo59WW4_pq8w"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mpdEk4KUAB_g"
      },
      "source": [
        "### HOSC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_i1x-XN9CZFH"
      },
      "outputs": [],
      "source": [
        "model, epoch_mean_loss = run_icosahedron_experiment(HOSC(), checkpoint_path + \"hosc_model.pth\") # training\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1UxBr3eYWGGy"
      },
      "outputs": [],
      "source": [
        "model, epoch_mean_loss = train_3d_sdf_from_checkpoint(\n",
        "    activation=HOSC(),\n",
        "    data_path=data_path,\n",
        "    num_epochs=num_epochs,\n",
        "    data_sample_fr=data_sample_fr,\n",
        "    checkpoint_path=checkpoint_path + \"hosc_model.pth\"\n",
        "  )\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QfQ1bchROOwP"
      },
      "outputs": [],
      "source": [
        "predictions = evaluate_model(model, meshgrid) # results\n",
        "plot_surface(predictions, res)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}