{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.autograd import Function\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch import optim\n",
    "from IPython import display\n",
    "import time\n",
    "import pickle \n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import os\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def load_point_cloud_from_obj(filepath):\n",
    "    vertices = []\n",
    "\n",
    "    with open(filepath, 'r') as file:\n",
    "        for line in file:\n",
    "            if line.startswith('v '): \n",
    "                parts = line.strip().split()\n",
    "                vertex = list(map(float, parts[1:4]))\n",
    "                vertices.append(vertex)\n",
    "\n",
    "    return np.array(vertices, dtype=np.float32)\n",
    "\n",
    "# Load the provided .obj file\n",
    "file_path = \"elephant-reference.obj\"\n",
    "point_cloud1 = load_point_cloud_from_obj(file_path)\n",
    "point_cloud1.shape\n",
    "indices = np.random.choice(point_cloud1.shape[0], 2000, replace=False)\n",
    "elepant = point_cloud1[indices]\n",
    "\n",
    "point_cloud2 = load_point_cloud_from_obj('horse-01.obj')\n",
    "point_cloud2.shape\n",
    "indices = np.random.choice(point_cloud2.shape[0], 2000, replace=False)\n",
    "horse = point_cloud2[indices]\n",
    "\n",
    "point_cloud3 = load_point_cloud_from_obj('flam-reference.obj')\n",
    "point_cloud3.shape\n",
    "indices = np.random.choice(point_cloud3.shape[0], 2000, replace=False)\n",
    "flam = point_cloud3[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure = [elepant,horse,flam]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import sys\n",
    "from SEINT.SEINT_numpy import SEINT\n",
    "step_indices = [0, 20, 40, 60, 80, 100]\n",
    "colors = ['#2C3E50', '#A2B9C8', '#36597D', '#D6E4F0',\n",
    "'#E37C59', '#F5BCA9', '#C95D39', '#FFDAC6',\n",
    "'#4A7C59', '#B4CFB0', '#6F9E84', '#D5E7D2']\n",
    "fig = plt.figure(figsize=(24, 10))\n",
    "gs_top = fig.add_gridspec(nrows=3, ncols=6, top=0.95, bottom=0.45, hspace=0.05, wspace=0.05)\n",
    "gs_bottom = fig.add_gridspec(nrows=2, ncols=6, top=0.4, bottom=0.11, hspace=0.4, wspace=0.5)\n",
    "\n",
    "for t in range(3):\n",
    "        X = figure[t]\n",
    "\n",
    "        loss3 = []\n",
    "        loss4 = []\n",
    "        name = ['Elepant','Horse','Flam']\n",
    "        angle = [50,60,70]\n",
    "        Z = [X]\n",
    "        steps = np.linspace(0, 0.51, 101)  \n",
    "\n",
    "        noise_std = 0.2  \n",
    "        noise = np.random.normal(scale=noise_std, size=X.shape)  \n",
    "\n",
    "        for idx, alpha in enumerate(steps):\n",
    "                zi = (1-alpha)* X + alpha * noise  \n",
    "                Z.append(zi)\n",
    "                loss3.append(SEINT(zi,X,maxed=True, rep = 50, set_seed = 42,rd_rad = 3))\n",
    "                loss4.append(SEINT(zi,X,maxed=False, rep = 50, set_seed = 42,rd_rad = 3))\n",
    "\n",
    "\n",
    "\n",
    "        for j, idx0 in enumerate(step_indices):\n",
    "                if (j==0):\n",
    "                        ax = fig.add_subplot(gs_top[t, j], projection='3d')\n",
    "                        zi = Z[idx0]  \n",
    "                        ax.scatter(zi[:, 2], zi[:, 0], zi[:, 1], c=colors[4*t+1], s=0.5,alpha=0.8)\n",
    "                        ax.set_title(f'Source shape : {name[t]}', fontsize=22, fontname='Times New Roman',fontweight='bold')\n",
    "                        ax.view_init(elev=0, azim=angle[t])\n",
    "                        ax.set_axis_off()\n",
    "                else:\n",
    "                        ax = fig.add_subplot(gs_top[t, j], projection='3d')\n",
    "                        zi = Z[idx0]  \n",
    "                        ax.scatter(zi[:, 2], zi[:, 0], zi[:, 1], c=colors[4*t], s=0.5,alpha=0.8)\n",
    "                        ax.set_title(f'Step = {idx0}', fontsize=22, fontname='Times New Roman')\n",
    "                        ax.view_init(elev=0, azim=angle[t])\n",
    "                        ax.set_axis_off()\n",
    "\n",
    "        ax = fig.add_subplot(gs_bottom[0:, 2*t:2*t+2])\n",
    "        ax.plot(loss3, \n",
    "                color=colors[4*t+2], \n",
    "                linestyle='--', \n",
    "                linewidth=1, \n",
    "                marker='o', \n",
    "                markersize=3, \n",
    "                label='SEINT')\n",
    "\n",
    "        ax.plot(loss4, \n",
    "                color=colors[4*t+3], \n",
    "                linestyle='-', \n",
    "                linewidth=1, \n",
    "                marker='x', \n",
    "                markersize=3, \n",
    "                label='ISEINT')\n",
    "        ax.set_title(f'Loss Curve for {name[t]}',fontsize=22,fontname='Times New Roman')\n",
    "        ax.set_xlabel('Step',fontsize=22,fontname='Times New Roman')\n",
    "        ax.set_ylabel('Loss',fontsize=22,fontname='Times New Roman')\n",
    "        ax.tick_params(axis='both', labelsize=20)\n",
    "        ax.legend(fontsize=15, frameon=False, loc='lower right')\n",
    "plt.tight_layout()\n",
    "#fig.subplots_adjust(wspace=0.15, hspace=0.25)\n",
    "plt.savefig(\"noise.png\", dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "HW",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
