{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jupyter environment detected. Enabling Open3D WebVisualizer.\n",
      "[Open3D INFO] WebRTC GUI backend enabled.\n",
      "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/csuser/mambaforge/envs/tim/lib/python3.8/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
      "################################################################################\n",
      "WARNING!\n",
      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
      "to learn more and leave feedback.\n",
      "################################################################################\n",
      "\n",
      "  deprecation_warning()\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import torch_scatter\n",
    "import torch\n",
    "import torch_geometric as pyg\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from neuralop.models import FNO\n",
    "from torch import nn\n",
    "from models.neuraloperator.neuralop.layers.mlp import MLP as NeuralOpMLP\n",
    "from models.neuraloperator.neuralop.layers.embeddings import PositionalEmbedding\n",
    "from models.neuraloperator.neuralop.layers.integral_transform import IntegralTransform\n",
    "from models.neuraloperator.neuralop.layers.neighbor_search import NeighborSearch\n",
    "import random\n",
    "from random import randint\n",
    "from dataloader import preprocess\n",
    "#from models.giorom2d import PhysicsEngine as LearnedSimulator\n",
    "#from Baselines.mmgpt import PhysicsEngine as LearnedSimulator\n",
    "from models.giorom2d_T import PhysicsEngine as LearnedSimulator\n",
    "#from GAT import preprocess, LearnedSimulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using Linear Attention\n",
      "Size mismatch while loading! torch.Size([128, 37]) != torch.Size([128, 30]) Skipping node_in.layers.0.weight...\n",
      "Size mismatch while loading! torch.Size([128, 4]) != torch.Size([128, 3]) Skipping edge_in.layers.0.weight...\n",
      "Size mismatch while loading! torch.Size([3, 128]) != torch.Size([2, 128]) Skipping node_out.layers.4.weight...\n",
      "Size mismatch while loading! torch.Size([3]) != torch.Size([2]) Skipping node_out.layers.4.bias...\n",
      "Size mismatch while loading! torch.Size([32, 64, 1]) != torch.Size([32, 32, 1]) Skipping fno_mapper.lifting.fcs.0.weight...\n",
      "Size mismatch while loading! torch.Size([256, 64]) != torch.Size([256, 32]) Skipping gnot_layer.trunk_mlp.linear_pre.weight...\n",
      "Size mismatch while loading! torch.Size([256, 3]) != torch.Size([256, 2]) Skipping gnot_layer.branch_mlps.0.linear_pre.weight...\n",
      "Size mismatch while loading! torch.Size([256, 64]) != torch.Size([256, 32]) Skipping gnot_layer.blocks.0.gatenet.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln1.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln1.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln2_branch.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln2_branch.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln3.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln3.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln4.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln4.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln5.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.ln5.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.selfattn.key.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.selfattn.key.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.selfattn.query.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.selfattn.query.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.selfattn.value.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.selfattn.value.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.selfattn.proj.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.selfattn.proj.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.crossattn.query.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.crossattn.query.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.crossattn.keys.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.crossattn.keys.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.crossattn.values.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.crossattn.values.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.crossattn.proj.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.crossattn.proj.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp1.0.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp1.0.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp1.0.2.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp1.0.2.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp1.1.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp1.1.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp1.1.2.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp1.1.2.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp2.0.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp2.0.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp2.0.2.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp2.0.2.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp2.1.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp2.1.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.moe_mlp2.1.2.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.moe_mlp2.1.2.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 32]) Skipping gnot_layer.blocks.1.gatenet.0.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.gatenet.0.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256, 256]) Skipping gnot_layer.blocks.1.gatenet.2.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([256]) Skipping gnot_layer.blocks.1.gatenet.2.bias...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([2, 256]) Skipping gnot_layer.blocks.1.gatenet.4.weight...\n",
      "Model Dict not in Saved Dict! 2 != torch.Size([2]) Skipping gnot_layer.blocks.1.gatenet.4.bias...\n",
      "Size mismatch while loading! torch.Size([32, 131]) != torch.Size([32, 4]) Skipping gno_in.mlp.fcs.0.weight...\n",
      "Size mismatch while loading! torch.Size([64, 64]) != torch.Size([32, 64]) Skipping gno_in.mlp.fcs.2.weight...\n",
      "Size mismatch while loading! torch.Size([64]) != torch.Size([32]) Skipping gno_in.mlp.fcs.2.bias...\n",
      "Size mismatch while loading! torch.Size([32, 3]) != torch.Size([32, 2]) Skipping gno_out.mlp.fcs.0.weight...\n"
     ]
    }
   ],
   "source": [
    "# Evaluation settings shared by the cells below.\n",
    "params = {\n",
    "    \"epoch\": 1000,\n",
    "    \"batch_size\": 4,\n",
    "    \"lr\": 1e-4,\n",
    "    \"noise\": 0,\n",
    "    \"save_interval\": 1000,\n",
    "    \"eval_interval\": 10000,\n",
    "    \"rollout_interval\": 20000,\n",
    "    \"connectivity_radius\": 0.032,\n",
    "    \"sampling\": False,\n",
    "    \"mesh_size\": 20,\n",
    "    \"sampling_strategy\": \"fps\",\n",
    "    \"graph_type\": \"radius\"\n",
    "}\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "simulator = LearnedSimulator(device)\n",
    "\n",
    "# Read the checkpoint once and map its tensors onto the active device so the\n",
    "# cell also works on a machine without CUDA.\n",
    "ckpt = torch.load('/home/csuser/Documents/Neural Operator/saved_models/giorom2d_T_Sand.pt', map_location=device)\n",
    "weights = ckpt['model']\n",
    "\n",
    "model_dict = dict(simulator.state_dict())\n",
    "ckpt_dict = {}\n",
    "# Becomes True as soon as the checkpoint and the model do not line up exactly.\n",
    "mismatch = False\n",
    "\n",
    "# Best-effort load: only tensors whose name and shape both match the model are\n",
    "# taken from the checkpoint; everything else is reported and skipped.\n",
    "for k, v in weights.items():\n",
    "    if k in model_dict:\n",
    "        if model_dict[k].size() == v.size():\n",
    "            ckpt_dict[k] = v\n",
    "        else:\n",
    "            print(\"Size mismatch while loading! %s != %s Skipping %s...\"%(str(model_dict[k].size()), str(v.size()), k))\n",
    "            mismatch = True\n",
    "    else:\n",
    "        print(\"Model Dict not in Saved Dict! %s != %s Skipping %s...\"%(2, str(v.size()), k))\n",
    "        mismatch = True\n",
    "# Model tensors that received nothing from the checkpoint also count as a mismatch.\n",
    "if len(model_dict) > len(ckpt_dict):\n",
    "    mismatch = True\n",
    "if mismatch:\n",
    "    print(\"Checkpoint only partially loaded; skipped or missing tensors keep their initial values.\")\n",
    "model_dict.update(ckpt_dict)\n",
    "simulator.load_state_dict(model_dict)\n",
    "\n",
    "# Follow the device chosen above instead of assuming CUDA is present.\n",
    "simulator = simulator.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(model, data, metadata, noise_std):\n",
    "    \"\"\"Autoregressively extend a trajectory with the learned simulator.\n",
    "\n",
    "    The first ``model.window_size + 1`` frames of ``data[\"position\"]`` seed the\n",
    "    trajectory; one new frame is predicted for every remaining time step.\n",
    "\n",
    "    Args:\n",
    "        model: module exposing ``window_size``; it is called on the graph built by ``preprocess``.\n",
    "        data: dict with \"position\" (time on axis 0) and \"particle_type\".\n",
    "        metadata: dict providing \"acc_mean\" and \"acc_std\" for de-normalising the model output.\n",
    "        noise_std: noise level folded into the acceleration scale.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the seed frames followed by the predicted frames; time is on\n",
    "        axis 1 because the seed is permuted with ``(1, 0, 2)``.\n",
    "\n",
    "    Note:\n",
    "        Reads \"connectivity_radius\" and \"graph_type\" from the notebook-level ``params`` dict.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    model.eval()\n",
    "    window_size = model.window_size + 1\n",
    "    total_time = data[\"position\"].size(0)\n",
    "\n",
    "    traj = data[\"position\"][:window_size].permute(1, 0, 2)\n",
    "    particle_type = data[\"particle_type\"]\n",
    "\n",
    "    # The de-normalisation constants are the same for every step, so build them once.\n",
    "    acc_scale = torch.sqrt(torch.tensor(metadata[\"acc_std\"]) ** 2 + noise_std ** 2)\n",
    "    acc_mean = torch.tensor(metadata[\"acc_mean\"])\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _ in range(total_time - window_size):\n",
    "            graph = preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0, radius=params['connectivity_radius'], graph_type=params['graph_type'])\n",
    "            graph = graph.to(device)\n",
    "            acceleration = model(graph).cpu() * acc_scale + acc_mean\n",
    "\n",
    "            # Velocity is the difference of the last two frames; it is updated with the\n",
    "            # predicted acceleration first and the position is then advanced with it.\n",
    "            recent_position = traj[:, -1]\n",
    "            recent_velocity = recent_position - traj[:, -2]\n",
    "            new_velocity = recent_velocity + acceleration\n",
    "            new_position = recent_position + new_velocity\n",
    "            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)\n",
    "    return traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def avg_velocity(rollout):\n",
    "    \"\"\"Print the shape of the frame-to-frame differences and return their largest entry.\n",
    "\n",
    "    Differences are taken between consecutive slices along axis 1 of ``rollout``.\n",
    "    NOTE(review): despite the name, the value returned is a maximum, not an average - confirm intent.\n",
    "    \"\"\"\n",
    "    later_frames = rollout[:, 1:]\n",
    "    earlier_frames = rollout[:, :-1]\n",
    "    step_delta = later_frames - earlier_frames\n",
    "    print(step_delta.shape)\n",
    "    as_array = step_delta.numpy()\n",
    "    return as_array.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def avg_spacing(rollout, max_steps=1000):\n",
    "    \"\"\"Mean, over frames, of the smallest non-zero pairwise distance within each frame.\n",
    "\n",
    "    Args:\n",
    "        rollout: indexable by frame; ``rollout[i]`` is compared against itself with ``torch.cdist``.\n",
    "        max_steps: upper bound on the number of leading frames inspected. The default\n",
    "            keeps the previous fixed limit of 1000 frames.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-frame minima as a NumPy scalar, or NaN when no inspected\n",
    "        frame contains a non-zero distance.\n",
    "    \"\"\"\n",
    "    # Never index past the end of a rollout shorter than max_steps.\n",
    "    num_steps = min(max_steps, len(rollout))\n",
    "    nearest = []\n",
    "    for step in range(num_steps):\n",
    "        pairwise = torch.cdist(rollout[step], rollout[step])\n",
    "        # Drop zero entries (self-distances and coincident points) before taking the minimum.\n",
    "        separated = pairwise[pairwise != 0]\n",
    "        if separated.numel() > 0:\n",
    "            nearest.append(separated.min())\n",
    "    if not nearest:\n",
    "        return float(\"nan\")\n",
    "    return torch.stack(nearest).numpy().mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (1791x30 and 37x128)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 89\u001b[0m\n\u001b[1;32m     86\u001b[0m rollout_data \u001b[38;5;241m=\u001b[39m rollout_dataset[sim_id]\n\u001b[1;32m     87\u001b[0m temp \u001b[38;5;241m=\u001b[39m rollout_data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mposition\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 89\u001b[0m rollout_out \u001b[38;5;241m=\u001b[39m \u001b[43mrollout\u001b[49m\u001b[43m(\u001b[49m\u001b[43msimulator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrollout_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrollout_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnoise\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     90\u001b[0m \u001b[38;5;66;03m#print(avg_velocity(rollout_out))\u001b[39;00m\n\u001b[1;32m     91\u001b[0m rollout_out \u001b[38;5;241m=\u001b[39m rollout_out\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n",
      "Cell \u001b[0;32mIn[3], line 20\u001b[0m, in \u001b[0;36mrollout\u001b[0;34m(model, data, metadata, noise_std)\u001b[0m\n\u001b[1;32m     18\u001b[0m graph \u001b[38;5;241m=\u001b[39m preprocess(particle_type, traj[:, \u001b[38;5;241m-\u001b[39mwindow_size:], \u001b[38;5;28;01mNone\u001b[39;00m, metadata, \u001b[38;5;241m0.0\u001b[39m, radius\u001b[38;5;241m=\u001b[39mparams[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconnectivity_radius\u001b[39m\u001b[38;5;124m'\u001b[39m], graph_type\u001b[38;5;241m=\u001b[39mparams[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgraph_type\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m     19\u001b[0m graph \u001b[38;5;241m=\u001b[39m graph\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 20\u001b[0m acceleration \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mcpu()\n\u001b[1;32m     21\u001b[0m acceleration \u001b[38;5;241m=\u001b[39m acceleration \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msqrt(torch\u001b[38;5;241m.\u001b[39mtensor(metadata[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124macc_std\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m+\u001b[39m noise_std \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m+\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(metadata[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124macc_mean\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m     23\u001b[0m recent_position \u001b[38;5;241m=\u001b[39m traj[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Documents/Neural Operator/models/giorom3d_T.py:180\u001b[0m, in \u001b[0;36mPhysicsEngine.forward\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m    175\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, data):\n\u001b[1;32m    176\u001b[0m     \u001b[38;5;66;03m# pre-processing\u001b[39;00m\n\u001b[1;32m    177\u001b[0m     \n\u001b[1;32m    178\u001b[0m     \u001b[38;5;66;03m# node feature: combine categorial feature data.x and contiguous feature data.pos.\u001b[39;00m\n\u001b[1;32m    179\u001b[0m     node_feature \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_type(data\u001b[38;5;241m.\u001b[39mx), data\u001b[38;5;241m.\u001b[39mpos), dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 180\u001b[0m     node_feature \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_in\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnode_feature\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    181\u001b[0m     edge_feature \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39medge_in(data\u001b[38;5;241m.\u001b[39medge_attr)\n\u001b[1;32m    183\u001b[0m     \u001b[38;5;66;03m# stack of GNN layers\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Documents/Neural Operator/models/layers.py:33\u001b[0m, in \u001b[0;36mMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m     32\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 33\u001b[0m         x \u001b[38;5;241m=\u001b[39m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     34\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/tim/lib/python3.8/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (1791x30 and 37x128)"
     ]
    }
   ],
   "source": [
    "from collections import OrderedDict\n",
    "from dgl.geometry import farthest_point_sampler\n",
    "from scipy.spatial import Delaunay\n",
    "class RolloutDataset(pyg.data.Dataset):\n",
    "    \"\"\"Whole-trajectory dataset used to seed rollouts.\n",
    "\n",
    "    ``torch.load(data_path)`` must yield a mapping with the keys \"particle_type\",\n",
    "    \"position\", \"n_particles_per_example\" and \"output\". ``metadata.json`` is read\n",
    "    from ``metadata_dir`` and exposed as ``self.metadata``.\n",
    "    \"\"\"\n",
    "\n",
    "    # Location of metadata.json used when no metadata_dir is supplied.\n",
    "    DEFAULT_METADATA_DIR = '/home/csuser/Documents/new_dataset/Sand/'\n",
    "\n",
    "    def __init__(self, data_path, split, window_length=7, sampling=False, sampling_strategy='random', graph_type='radius', mesh_size=170, radius=None, metadata_dir=None):\n",
    "        \"\"\"Load the trajectories and their metadata.\n",
    "\n",
    "        Args:\n",
    "            data_path: file handed to ``torch.load``.\n",
    "            split: accepted for interface compatibility; not used.\n",
    "            window_length: stored on the instance.\n",
    "            sampling: when truthy, ``get`` keeps only ``mesh_size`` particles.\n",
    "            sampling_strategy: 'random' or 'fps'; any other value leaves the data unsampled.\n",
    "            graph_type: stored on the instance.\n",
    "            mesh_size: number of particles kept when sampling.\n",
    "            radius: stored on the instance.\n",
    "            metadata_dir: directory containing metadata.json; defaults to DEFAULT_METADATA_DIR.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.data_path = data_path\n",
    "        if metadata_dir is None:\n",
    "            metadata_dir = self.DEFAULT_METADATA_DIR\n",
    "        with open(os.path.join(metadata_dir, \"metadata.json\")) as f:\n",
    "            self.metadata = json.load(f)\n",
    "\n",
    "        self.window_length = window_length\n",
    "        self.sampling = sampling\n",
    "        self.sampling_strategy = sampling_strategy\n",
    "        self.graph_type = graph_type\n",
    "        self.mesh_size = mesh_size\n",
    "        self.radius = radius\n",
    "\n",
    "        dataset = torch.load(data_path)\n",
    "        self.particle_type = dataset['particle_type']\n",
    "        self.position = dataset['position']\n",
    "        self.n_particles_per_example = dataset['n_particles_per_example']\n",
    "        self.outputs = dataset['output']\n",
    "\n",
    "        # Initial index set: 120 sorted indices drawn from range(0, 360). It is now set\n",
    "        # unconditionally, so a non-boolean ``sampling`` no longer raises NameError.\n",
    "        # The 'random' strategy overwrites it in ``get``.\n",
    "        # NOTE(review): 120 and 360 are hard-coded and independent of mesh_size - confirm.\n",
    "        self.points = sorted(random.sample(range(0, 360), 120))\n",
    "\n",
    "        self.dim = self.position[0].shape[2]\n",
    "\n",
    "    def len(self):\n",
    "        \"\"\"Number of stored trajectories.\"\"\"\n",
    "        return len(self.position)\n",
    "\n",
    "    def get(self, idx):\n",
    "        \"\"\"Return trajectory ``idx`` as a dict with \"particle_type\" and \"position\".\n",
    "\n",
    "        The stored position array is permuted with ``(1, 0, 2)`` before it is\n",
    "        returned. When sampling is enabled the same particle subset is applied to\n",
    "        both tensors (axis 0 of the particle types, axis 1 of the permuted positions).\n",
    "        \"\"\"\n",
    "        particle_type = torch.from_numpy(self.particle_type[idx])\n",
    "        position_seq = torch.from_numpy(self.position[idx]).permute(1, 0, 2)\n",
    "\n",
    "        if self.sampling:\n",
    "            keep = None\n",
    "            if self.sampling_strategy == 'random':\n",
    "                self.points = sorted(random.sample(range(0, particle_type.shape[0]), self.mesh_size))\n",
    "                keep = self.points\n",
    "            elif self.sampling_strategy == 'fps':\n",
    "                # Farthest-point sampling on the first slice of the permuted sequence.\n",
    "                init_pos = position_seq[0].unsqueeze(0)\n",
    "                keep = farthest_point_sampler(init_pos, self.mesh_size)[0]\n",
    "            if keep is not None:\n",
    "                particle_type = particle_type[keep]\n",
    "                # Index the particle axis directly instead of permuting back and forth.\n",
    "                position_seq = position_seq[:, keep]\n",
    "        return {\"particle_type\": particle_type, \"position\": position_seq}\n",
    "\n",
    "# Roll out a single trajectory with the trained simulator and score it\n",
    "# against the stored ground-truth positions.\n",
    "rollout_dataset = RolloutDataset(\n",
    "    '/home/csuser/Documents/new_dataset/Sand/rollout.pt',\n",
    "    \"train\",\n",
    "    sampling=params['sampling'],\n",
    "    sampling_strategy=params['sampling_strategy'],\n",
    "    graph_type=params['graph_type'],\n",
    "    radius=params['connectivity_radius'],\n",
    "    mesh_size=params['mesh_size'],\n",
    ")\n",
    "# print(len(rollout_dataset))\n",
    "simulator.eval()\n",
    "sim_id = 0\n",
    "rollout_data = rollout_dataset[sim_id]\n",
    "temp = rollout_data['position'][0]\n",
    "\n",
    "# The first two axes of the rollout are swapped before it is compared\n",
    "# element-wise with rollout_data[\"position\"].\n",
    "# print(avg_velocity(...)) can be applied to the un-permuted rollout if needed.\n",
    "rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params[\"noise\"]).permute(1, 0, 2)\n",
    "\n",
    "# Squared error summed over the last axis, then averaged over everything else.\n",
    "loss = ((rollout_out - rollout_data[\"position\"]) ** 2).sum(dim=-1).mean()\n",
    "print(\"Rollout Loss: \", loss)\n",
    "print(rollout_out[1:, :].shape)\n",
    "print(rollout_out.shape)\n",
    "\n",
    "# Persist the prediction and the ground-truth sample side by side.\n",
    "for payload, filename in (\n",
    "    (rollout_out, f'GIOROM_Sand_sim_{sim_id}.pt'),\n",
    "    (rollout_data, f'GIOROM_Sand_sim_{sim_id}_gt.pt'),\n",
    "):\n",
    "    torch.save(payload, filename)\n",
    "del payload, filename\n",
    "# print(avg_spacing(rollout_out))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from dataloader import OneStepDataset\n",
    "# from utils import visualize_graph\n",
    "\n",
    "# sample_dataset = OneStepDataset('/home/csuser/Documents/new_dataset/WaterDropSmall/', \"train.pt\", noise_std=params['noise'], return_pos=True, sampling=params['sampling'], sampling_strategy=params['sampling_strategy'], \n",
    "#                                  graph_type=params['graph_type'],radius=params['connectivity_radius'])\n",
    "# visualize_graph(sample_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Matplotlib colour for each particle-type id. visualize_prepare creates one\n",
    "# line artist per entry, in this insertion order, and visualize_pair only draws\n",
    "# particles whose type equals one of these keys.\n",
    "# NOTE(review): which material each id stands for is not visible here -- confirm.\n",
    "TYPE_TO_COLOR = {\n",
    "    3: \"black\",\n",
    "    0: \"green\",\n",
    "    7: \"magenta\",\n",
    "    6: \"gold\",\n",
    "    5: \"#14b1f5\",\n",
    "}\n",
    "\n",
    "\n",
    "def visualize_prepare(ax, particle_type, position, metadata):\n",
    "    \"\"\"Configure one axes for particle plotting and create its empty line artists.\n",
    "\n",
    "    The axis limits come from ``metadata[\"bounds\"]`` (entry 0 for x, entry 1\n",
    "    for y), ticks are removed and the aspect ratio is fixed to 1.  One empty\n",
    "    marker-only line is created per entry of ``TYPE_TO_COLOR``.\n",
    "    ``particle_type`` is accepted but not used here.\n",
    "\n",
    "    Returns:\n",
    "        tuple ``(ax, position, artists)`` where ``artists`` maps each particle\n",
    "        type id to its line artist and ``position`` is passed through unchanged.\n",
    "    \"\"\"\n",
    "    limits = metadata[\"bounds\"]\n",
    "    x_range = limits[0]\n",
    "    ax.set_xlim(x_range[0], x_range[1])\n",
    "    y_range = limits[1]\n",
    "    ax.set_ylim(y_range[0], y_range[1])\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_aspect(1.0)\n",
    "\n",
    "    artists = {}\n",
    "    for type_id, colour in TYPE_TO_COLOR.items():\n",
    "        # ax.plot returns a list of lines; only one line is created per call.\n",
    "        artists[type_id] = ax.plot([], [], \"o\", ms=2, color=colour)[0]\n",
    "    return ax, position, artists\n",
    "\n",
    "\n",
    "def visualize_pair(particle_type, position_pred, position_gt, metadata):\n",
    "    \"\"\"Build a side-by-side animation of ground-truth and predicted positions.\n",
    "\n",
    "    Args:\n",
    "        particle_type: per-particle type ids; compared with each key of\n",
    "            ``TYPE_TO_COLOR`` to build the mask of particles to draw.\n",
    "        position_pred: predicted positions, indexed as ``[step, mask, axis]``.\n",
    "        position_gt: ground-truth positions, indexed the same way; its\n",
    "            ``size(0)`` gives the number of animation frames.\n",
    "        metadata: mapping with a ``\"bounds\"`` entry, forwarded to\n",
    "            ``visualize_prepare``.\n",
    "\n",
    "    Returns:\n",
    "        A ``matplotlib.animation.FuncAnimation`` with the ground truth on the\n",
    "        left axes and the prediction on the right axes.\n",
    "    \"\"\"\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
    "    plot_info = [\n",
    "        visualize_prepare(axes[0], particle_type, position_gt, metadata),\n",
    "        visualize_prepare(axes[1], particle_type, position_pred, metadata),\n",
    "    ]\n",
    "    axes[0].set_title(\"Ground truth\")\n",
    "    axes[1].set_title(\"Prediction\")\n",
    "\n",
    "    # Close this specific figure; presumably this keeps the notebook from also\n",
    "    # rendering the still-empty static plot next to the animation.\n",
    "    plt.close(fig)\n",
    "\n",
    "    def update(step_i):\n",
    "        # With blit=True only the artists returned here are redrawn, so every\n",
    "        # line whose data changed has to be collected, not just the last one.\n",
    "        outputs = []\n",
    "        for _, position, points in plot_info:\n",
    "            for type_, line in points.items():\n",
    "                mask = particle_type == type_\n",
    "                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])\n",
    "                outputs.append(line)\n",
    "        return outputs\n",
    "\n",
    "    return animation.FuncAnimation(fig, update, frames=np.arange(0, position_gt.size(0)), interval=10, blit=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Animate ground truth vs. prediction for the trajectory held in rollout_data / rollout_out.\n",
    "anim = visualize_pair(rollout_data[\"particle_type\"], rollout_out, rollout_data[\"position\"], rollout_dataset.metadata)\n",
    "# Look up the registered 'ffmpeg' movie writer and write the animation to disc_4.mp4 at 30 fps.\n",
    "writer = animation.writers['ffmpeg'](fps=30)\n",
    "anim.save('disc_4.mp4',writer=writer,dpi=200)\n",
    "# Last expression of the cell: embed the animation as an HTML5 video in the cell output.\n",
    "HTML(anim.to_html5_video())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nclaw",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
