{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import trimesh\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from RASF import RASF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the RASF model and move it to the GPU in one step.\n",
    "model = RASF().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1000, 35]) torch.Size([1000, 35])\n"
     ]
    }
   ],
   "source": [
    "# testing for a point cloud\n",
    "# The same entry point is exercised with a torch tensor and with a NumPy array;\n",
    "# only the resulting shapes are printed.\n",
    "pc = torch.rand(1000, 3)\n",
    "pcnp = np.random.randn(1000, 3)\n",
    "\n",
    "pceb, pcnpeb = (model.point_clouds_inference(cloud) for cloud in (pc, pcnp))\n",
    "print(pceb.shape, pcnpeb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1000, 35]) torch.Size([16, 1000, 35])\n"
     ]
    }
   ],
   "source": [
    "# testing for a batch of point cloud\n",
    "batch_shape = (16, 1000, 3)\n",
    "\n",
    "# torch tensor input\n",
    "bpc = torch.rand(batch_shape)\n",
    "bpceb = model.point_clouds_batch_inference(bpc)\n",
    "\n",
    "# NumPy array input\n",
    "bpcnp = np.random.randn(*batch_shape)\n",
    "bpcnpeb = model.point_clouds_batch_inference(bpcnp)\n",
    "\n",
    "print(bpceb.shape, bpcnpeb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [],
   "source": [
    "# testing for a mesh\n",
    "mesh = trimesh.load_mesh('T1.obj')\n",
    "\n",
    "# Mesh described by the arrays trimesh exposes (vertices / edges / faces).\n",
    "meshobj={'V':mesh.vertices,'E':mesh.edges,'F':mesh.faces}\n",
    "meshout=model.mesh_inference(meshobj)\n",
    "\n",
    "# Same mesh, but with every entry wrapped in a torch tensor first.\n",
    "# NOTE(review): torch.Tensor(...) produces float32, so the edge/face indices are\n",
    "# stored as floats here - confirm that mesh_inference expects that.\n",
    "meshobjtorch={'V':torch.Tensor(mesh.vertices),'E':torch.Tensor(mesh.edges),'F':torch.Tensor(mesh.faces)}\n",
    "meshouttorch=model.mesh_inference(meshobjtorch)\n",
    "\n",
    "# Report only the per-key shapes (as the other test cells do) instead of dumping\n",
    "# the full tensors into the notebook output.\n",
    "for label, result in (('numpy', meshout), ('torch', meshouttorch)):\n",
    "    print(label, {key: tuple(value.shape) for key, value in result.items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [],
   "source": [
    "# testing for a batch of mesh\n",
    "# Each batch is a two-element list that holds the same mesh dict twice.\n",
    "\n",
    "meshbatch = [meshobj] * 2\n",
    "meshbatchout = model.mesh_batch_inference(meshbatch)\n",
    "\n",
    "meshbatchtorch = [meshobjtorch] * 2\n",
    "meshbatchouttorch = model.mesh_batch_inference(meshbatchtorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 195,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([33, 8, 8, 8])\n"
     ]
    }
   ],
   "source": [
    "# testing for a voxel\n",
    "# Threshold Gaussian noise at 0.5 to get a random 0/1 float grid of size 8x8x8.\n",
    "voxel = torch.randn(8, 8, 8).gt(0.5).float()\n",
    "f = model.voxels_inference(voxel)\n",
    "print(f.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 33, 8, 8, 8]) torch.Size([16, 33, 8, 8, 8])\n"
     ]
    }
   ],
   "source": [
    "# testing for a batch of voxel\n",
    "# One random 0/1 float batch is fed in twice: as a torch tensor, then as the\n",
    "# NumPy array obtained from that same tensor.\n",
    "batch_voxel = torch.randn(16, 8, 8, 8).gt(0.5).float()\n",
    "f = model.voxels_batch_inference(batch_voxel)\n",
    "\n",
    "batch_voxelnp = batch_voxel.numpy()\n",
    "fnp = model.voxels_batch_inference(batch_voxelnp)\n",
    "\n",
    "print(f.shape, fnp.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}