{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import datetime\n",
    "import random\n",
    "import numpy as np\n",
    "import yaml\n",
    "import gzip\n",
    "import pickle\n",
    "\n",
    "import torch\n",
    "from pytorch3d import transforms\n",
    "\n",
    "from edf.utils import preprocess\n",
    "from edf.agent import PickAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to the demonstration data and the training / agent configuration files.\n",
    "sample_dir = 'demo/mug_task_rim.gzip'\n",
    "train_config_dir = 'config/train_config/train_pick.yaml'\n",
    "agent_config_dir = 'config/agent_config/pick_agent.yaml'\n",
    "# NOTE(review): tp_pickle_dir is not passed to train_pick in this notebook; confirm whether it should be.\n",
    "tp_pickle_dir = \"reproducible_pickles/pick/\"\n",
    "\n",
    "# Run-time options forwarded to train_pick.\n",
    "visualize = True\n",
    "show_plot = True\n",
    "save_plot = False\n",
    "save_checkpoint = False\n",
    "save_tp = False\n",
    "deterministic = True\n",
    "\n",
    "# Output locations; the checkpoint directory name is stamped with the time this cell runs.\n",
    "checkpoint_path = f'checkpoint/train_pick/{datetime.datetime.now().strftime(\"%b_%d_%Y__%H_%M_%S\")}/'\n",
    "plot_path = 'logs/train/pick_agent/plot/'\n",
    "\n",
    "# Only hand a plot directory to train_pick when plots are being saved.\n",
    "# Truthiness check (not `is False`) so any falsy save_plot value also clears the path.\n",
    "if not save_plot:\n",
    "    plot_path = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edf.train_pick import train_pick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the pick agent with the options configured above.\n",
    "train_pick(\n",
    "    sample_dir=sample_dir,\n",
    "    train_config_dir=train_config_dir,\n",
    "    agent_config_dir=agent_config_dir,\n",
    "    visualize=visualize,\n",
    "    show_plot=show_plot,\n",
    "    save_plot=save_plot,\n",
    "    save_checkpoint=save_checkpoint,\n",
    "    checkpoint_path=checkpoint_path,\n",
    "    plot_path=plot_path,\n",
    "    save_tp=save_tp,\n",
    "    deterministic=deterministic,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('edf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c0953e14757fdc95393fc3797619880897b68d726561afa3ebe46daeb55f7087"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
