{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/path/to/dir/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import scoremodel\n",
    "\n",
    "# import sys\n",
    "\n",
    "# sys.path.append('/path/to/dir/mpd-public')\n",
    "\n",
    "from experiment_launcher import single_experiment_yaml, run_experiment\n",
    "from mpd import trainer\n",
    "from mpd.models import UNET_DIM_MULTS, TemporalUnet\n",
    "from mpd.trainer import get_dataset, get_model, get_loss, get_summary\n",
    "from mpd.trainer.trainer import get_num_epochs\n",
    "from torch_robotics.torch_utils.seed import fix_random_seed\n",
    "from torch_robotics.torch_utils.torch_utils import get_torch_device, dict_to_device\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so a fresh Restart & Run All draws the same random numbers\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# epsilon of step size (handed to AnnealedLangevinDynamic as `eps`)\n",
    "eps = 1.5e-5\n",
    "\n",
    "# sigma min and max of Langevin dynamic\n",
    "sigma_min = 0.005\n",
    "sigma_max = 10\n",
    "\n",
    "# step counts passed to the score model and to the annealed Langevin dynamic\n",
    "n_steps = 10\n",
    "annealed_step = 100\n",
    "\n",
    "# first CUDA GPU when one is available, CPU otherwise\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model from the local `scoremodel` module, configured with the values set above\n",
    "model = scoremodel.Model(\n",
    "    device,\n",
    "    n_steps,\n",
    "    sigma_min,\n",
    "    sigma_max,\n",
    ")\n",
    "\n",
    "# Adam optimiser over every parameter exposed by the model\n",
    "optim = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=0.005,\n",
    ")\n",
    "\n",
    "# Annealed Langevin dynamic object built around the same model instance\n",
    "dynamic = scoremodel.AnnealedLangevinDynamic(\n",
    "    sigma_min,\n",
    "    sigma_max,\n",
    "    n_steps,\n",
    "    annealed_step,\n",
    "    model,\n",
    "    device,\n",
    "    eps=eps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "---------------Loading data\n",
      "Precomputing the SDF grid and gradients took: 0.466 sec\n",
      "TrajectoryDataset\n",
      "n_trajs: 10000\n",
      "trajectory_dim: (64, 4)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Device / dtype handed to the dataset; the training loop reuses it when moving batches\n",
    "tensor_args = {'device': device, 'dtype': torch.float32}\n",
    "\n",
    "dataset_subdir = 'EnvSimple2D-RobotPointMass'\n",
    "results_dir = 'logs'\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "# get_dataset returns the train / validation subsets together with a dataloader for each\n",
    "train_subset, train_dataloader, val_subset, val_dataloader = get_dataset(\n",
    "    dataset_class='TrajectoryDataset',\n",
    "    include_velocity=True,\n",
    "    dataset_subdir=dataset_subdir,\n",
    "    batch_size=batch_size,\n",
    "    results_dir=results_dir,\n",
    "    save_indices=True,\n",
    "    tensor_args=tensor_args,\n",
    ")\n",
    "\n",
    "# Full dataset object behind the training subset\n",
    "dataset = train_subset.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes of the training loop and the counter it advances\n",
    "total_iteration = 3000\n",
    "current_iteration = 0\n",
    "\n",
    "# Only referenced by the commented-out logging / sampling code in the training cell\n",
    "display_iteration = 150\n",
    "sampling_number = 8\n",
    "only_final = True\n",
    "\n",
    "# Checkpoints are written to <models dir>/<run_name>; build the directory from run_name\n",
    "# so renaming the run cannot leave torch.save pointing at a folder that was never created\n",
    "run_name = \"trained\"\n",
    "os.makedirs(os.path.join('/path/to/dir/mpd-public/mpd/models/', run_name), exist_ok=True)\n",
    "best_val_loss = float('inf')  # Initialize the best validation loss\n",
    "\n",
    "# Running training-loss meter and the progress printer that reports it\n",
    "losses = scoremodel.AverageMeter('Loss', ':.4f')\n",
    "progress = scoremodel.ProgressMeter(total_iteration, [losses], prefix='Iteration ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration [ 829/3000]\tLoss 0.0773 (0.0892)"
     ]
    }
   ],
   "source": [
    "def _to_image_batch(batch, n_channels=4, side=8):\n",
    "    \"\"\"Reshape a dataloader batch into the tensor handed to model.loss_fn.\n",
    "\n",
    "    Args:\n",
    "        batch: mapping whose 'traj_normalized' entry is read here; it is treated as\n",
    "            (batch, side * side, n_channels), which matches the (64, 4) trajectories\n",
    "            reported when the dataset was loaded.\n",
    "        n_channels: size of the trailing trajectory axis, moved to dim 1.\n",
    "        side: edge of the square grid the side * side steps are folded into.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch, n_channels, side, side) placed with tensor_args.\n",
    "    \"\"\"\n",
    "    traj = batch['traj_normalized']\n",
    "    # permute + reshape lays the elements out exactly as a contiguous\n",
    "    # (batch, n_channels, side * side) copy would, without an intermediate buffer\n",
    "    images = traj.permute(0, 2, 1).reshape(traj.shape[0], n_channels, side, side)\n",
    "    return images.to(**tensor_args)\n",
    "\n",
    "\n",
    "# Each pass of this loop walks the whole train dataloader once and then validates.\n",
    "# `<` (rather than `!=`) guarantees termination even if the counter is already past the total.\n",
    "while current_iteration < total_iteration:\n",
    "\n",
    "    ## Training Routine ##\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    for batch in train_dataloader:\n",
    "        data = _to_image_batch(batch)\n",
    "\n",
    "        loss = model.loss_fn(data)\n",
    "\n",
    "        optim.zero_grad()\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "\n",
    "        losses.update(loss.item())\n",
    "\n",
    "    progress.display(current_iteration)\n",
    "    current_iteration += 1\n",
    "\n",
    "    ## Validation Routine ##\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    val_loss_accumulator = 0.0\n",
    "    val_steps = 0\n",
    "\n",
    "    # No gradients are needed while scoring the validation batches\n",
    "    with torch.no_grad():\n",
    "        for batch in val_dataloader:\n",
    "            val_loss = model.loss_fn(_to_image_batch(batch))\n",
    "            val_loss_accumulator += val_loss.item()\n",
    "            val_steps += 1\n",
    "\n",
    "    # Compute average validation loss for the epoch\n",
    "    avg_validation_loss = val_loss_accumulator / val_steps\n",
    "\n",
    "    # Checkpointing: keep only the weights with the lowest validation loss so far\n",
    "    if avg_validation_loss < best_val_loss:\n",
    "        best_val_loss = avg_validation_loss\n",
    "\n",
    "        # Save original model checkpoint\n",
    "        model_save_path = os.path.join(\"/path/to/dir/mpd-public/mpd/models/\", run_name, \"ckpt.pt\")\n",
    "        torch.save(model.state_dict(), model_save_path)\n",
    "\n",
    "        # Also save the optimizer state next to the model weights\n",
    "        optimizer_save_path = os.path.join(\"/path/to/dir/mpd-public/mpd/models/\", run_name, \"optim.pt\")\n",
    "        torch.save(optim.state_dict(), optimizer_save_path)\n",
    "\n",
    "    ## Logging ##\n",
    "\n",
    "    # if current_iteration % display_iteration == 0:\n",
    "\n",
    "    #     # Save original model checkpoint\n",
    "    #     model_save_path = os.path.join(\"models\", run_name, f\"ckpt_{current_iteration}.pt\")\n",
    "    #     torch.save(model.state_dict(), model_save_path)\n",
    "\n",
    "    #     dynamic = scoremodel.AnnealedLangevinDynamic(sigma_min, sigma_max, n_steps, annealed_step, model, device, eps=eps)\n",
    "    #     sample = dynamic.sampling(sampling_number, only_final)\n",
    "    #     losses.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (MPD)",
   "language": "python",
   "name": "mpd"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
