{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import datetime\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import yaml\n",
    "\n",
    "import torch\n",
    "from pytorch3d import transforms\n",
    "\n",
    "from edf.utils import preprocess\n",
    "from edf.agent import PlaceAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Input files and directories ---\n",
    "sample_dir = 'demo/mug_task_rim.gzip'\n",
    "train_config_dir = 'config/train_config/train_place.yaml'\n",
    "agent_config_dir = 'config/agent_config/place_agent.yaml'\n",
    "# NOTE(review): tp_pickle_dir is defined here but is not passed to train_place()\n",
    "# in this notebook -- confirm whether it is still needed.\n",
    "tp_pickle_dir = \"reproducible_pickles/place/\"\n",
    "\n",
    "# --- Behaviour switches forwarded to train_place() ---\n",
    "visualize = True\n",
    "show_plot = True\n",
    "save_plot = False\n",
    "save_checkpoint = False\n",
    "save_tp = False\n",
    "deterministic = True\n",
    "\n",
    "# --- Output locations ---\n",
    "# The checkpoint directory name embeds the wall-clock time at which this cell is run.\n",
    "checkpoint_path = f'checkpoint/train_place/{datetime.datetime.now().strftime(\"%b_%d_%Y__%H_%M_%S\")}/'\n",
    "# plot_path is None when save_plot is False; otherwise it is the plot directory.\n",
    "plot_path = None if save_plot is False else 'logs/train/place_agent/plot/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edf.train_place import train_place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call train_place with the values defined in the configuration cell above.\n",
    "train_place(\n",
    "    sample_dir=sample_dir,\n",
    "    train_config_dir=train_config_dir,\n",
    "    agent_config_dir=agent_config_dir,\n",
    "    visualize=visualize,\n",
    "    show_plot=show_plot,\n",
    "    save_plot=save_plot,\n",
    "    save_checkpoint=save_checkpoint,\n",
    "    checkpoint_path=checkpoint_path,\n",
    "    plot_path=plot_path,\n",
    "    save_tp=save_tp,\n",
    "    deterministic=deterministic,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('edf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c0953e14757fdc95393fc3797619880897b68d726561afa3ebe46daeb55f7087"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
