{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeb3333f-d989-42dc-839c-3c53678dc2de",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train the temporal score model on the MPD trajectory dataset, checkpointing on improvement.\n",
    "#\n",
    "# NOTE(review): the last saved run of this cell failed while importing mpd.trainer:\n",
    "# torch_robotics/robots/robot_panda.py runs `from torch import vmap`, which raised\n",
    "# ImportError: cannot import name 'vmap' -- the installed PyTorch has no torch.vmap.\n",
    "# torch.vmap is a stable API from PyTorch 2.0 on, and PyTorch 2.0 needs Python >= 3.8 (this\n",
    "# kernel reports 3.6), so the environment must be upgraded before this cell can run.\n",
    "\n",
    "# Presumably provides ScoreNet, marginal_prob_std_fn, loss_fn and device, none of which are\n",
    "# defined in this notebook -- TODO confirm and replace with explicit imports.\n",
    "from scoremodel_temporal import *\n",
    "import functools\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import tqdm\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "sys.path.append('/path/to/dir/mpd-public')\n",
    "\n",
    "from mpd import trainer\n",
    "from mpd.trainer import get_dataset, get_model, get_loss, get_summary\n",
    "\n",
    "# ---- Configuration ----\n",
    "SEED = 0\n",
    "N_EPOCHS = 10000\n",
    "BATCH_SIZE = 32\n",
    "LR = 1e-4\n",
    "# Only checkpoints whose epoch-average loss beats this are saved. Kept from the original script\n",
    "# (presumably the best loss of an earlier run); set float('inf') to save from the first epoch.\n",
    "MIN_LOSS_INIT = 14.0\n",
    "CKPT_PATH = '../autor/ckpt.pth'\n",
    "DATASET_SUBDIR = 'EnvSimple2D-RobotPointMass'\n",
    "RESULTS_DIR = 'logs'\n",
    "\n",
    "# Seed the RNGs this cell draws from (context-length sampling below, torch-side randomness\n",
    "# such as model initialisation) so runs are reproducible.\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))\n",
    "score_model = score_model.to(device)\n",
    "# To resume from a saved checkpoint:\n",
    "# score_model.load_state_dict(torch.load(CKPT_PATH, map_location=device))\n",
    "\n",
    "tensor_args = {'device': device, 'dtype': torch.float32}\n",
    "\n",
    "train_subset, train_dataloader, val_subset, val_dataloader = get_dataset(\n",
    "    dataset_class='TrajectoryDataset',\n",
    "    include_velocity=True,\n",
    "    dataset_subdir=DATASET_SUBDIR,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    results_dir=RESULTS_DIR,\n",
    "    save_indices=True,\n",
    "    tensor_args=tensor_args,\n",
    ")\n",
    "dataset = train_subset.dataset\n",
    "\n",
    "optimizer = Adam(score_model.parameters(), lr=LR)\n",
    "min_loss = MIN_LOSS_INIT\n",
    "# torch.save does not create missing parent directories.\n",
    "os.makedirs(os.path.dirname(CKPT_PATH) or '.', exist_ok=True)\n",
    "\n",
    "print(\"start\")\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "    avg_loss = 0.\n",
    "    num_items = 0\n",
    "    # NOTE(review): this loop assumes each batch is a (data, label) pair whose data has rank >= 4\n",
    "    # and at least 5 entries on dim 1; it looks adapted from an image-sequence dataset --\n",
    "    # confirm against what TrajectoryDataset actually yields.\n",
    "    for data, _ in train_dataloader:  # was `data_loader`, which is never defined\n",
    "        # Random context length: entries [0, split) of dim 1 condition the model, entry `split` is the target.\n",
    "        split = random.randrange(1, 5, 1)\n",
    "\n",
    "        # Flatten each context entry: (B, split, *rest) -> (B, split, prod(rest)).\n",
    "        c = data[:, :split].flatten(start_dim=2).to(device)\n",
    "        # Target entry with dim 1 kept: (B, 1, *rest).\n",
    "        x = data[:, split:split + 1].to(device)\n",
    "\n",
    "        loss = loss_fn(score_model, x, c, marginal_prob_std_fn)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        avg_loss += loss.item() * x.shape[0]\n",
    "        num_items += x.shape[0]\n",
    "\n",
    "    if num_items == 0:\n",
    "        raise RuntimeError('train_dataloader yielded no samples')\n",
    "    epoch_loss = avg_loss / num_items\n",
    "\n",
    "    # Save and report only when the epoch-average training loss improves.\n",
    "    if epoch_loss < min_loss:\n",
    "        min_loss = epoch_loss\n",
    "        torch.save(score_model.state_dict(), CKPT_PATH)\n",
    "        print('Average Loss: {:5f}, Epoch: {:5d}'.format(epoch_loss, epoch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad3b4fdc-ddb7-49cf-af58-a89372a92b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Completion marker: printed once the training cell above has returned.\n",
    "print(\"done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (MPD)",
   "language": "python",
   "name": "mpd"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
