{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   0%|                                                                                                                                                                         | 0/10 [00:00<?, ?it/s]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.57it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:08,  1.57it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:07,  1.51it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.58it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:03<00:06,  1.64it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.58it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:04,  1.60it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:04,  1.64it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.67it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.69it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.72it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:07<00:01,  1.73it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.74it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:08<00:00,  1.75it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.71it/s]\u001b[A\n",
      "Epoch:  10%|████████████████                                                                                                                                                 | 1/10 [00:08<01:18,  8.76s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:06,  2.01it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:06,  1.96it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:06,  1.87it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:05,  1.88it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:02<00:05,  1.85it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:04,  1.84it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:03<00:04,  1.84it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:03,  1.83it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:04<00:03,  1.81it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.79it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.78it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:06<00:01,  1.78it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.78it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:07<00:00,  1.80it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.80it/s]\u001b[A\n",
      "Epoch:  20%|████████████████████████████████▏                                                                                                                                | 2/10 [00:17<01:08,  8.51s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.74it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.85it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:06,  1.89it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:05,  1.87it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:02<00:05,  1.88it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:04,  1.86it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:03<00:04,  1.85it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:03,  1.85it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:04<00:03,  1.83it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.85it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:05<00:02,  1.87it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:06<00:01,  1.86it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.85it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:07<00:00,  1.85it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.85it/s]\u001b[A\n",
      "Epoch:  30%|████████████████████████████████████████████████▎                                                                                                                | 3/10 [00:25<00:58,  8.32s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.74it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.85it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:06,  1.90it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:05,  1.91it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:02<00:05,  1.90it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:04,  1.86it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:03<00:04,  1.86it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:03,  1.87it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:04<00:03,  1.85it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.76it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.68it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:07<00:01,  1.66it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.65it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:08<00:00,  1.67it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.71it/s]\u001b[A\n",
      "Epoch:  40%|████████████████████████████████████████████████████████████████▍                                                                                                | 4/10 [00:33<00:50,  8.50s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.70it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.68it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:07,  1.59it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:07,  1.56it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:03<00:06,  1.64it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.66it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:04,  1.68it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:05<00:05,  1.38it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:06<00:04,  1.42it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:07<00:03,  1.40it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:07<00:02,  1.44it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:08<00:02,  1.47it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:08<00:01,  1.52it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:09<00:00,  1.54it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:09<00:00,  1.56it/s]\u001b[A\n",
      "Epoch:  50%|████████████████████████████████████████████████████████████████████████████████▌                                                                                | 5/10 [00:43<00:44,  8.92s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.73it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.77it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:06,  1.74it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.72it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:03<00:06,  1.55it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.59it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:05,  1.56it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:05<00:04,  1.58it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.57it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:06<00:03,  1.59it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.61it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:07<00:01,  1.63it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.65it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:08<00:00,  1.67it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.67it/s]\u001b[A\n",
      "Epoch:  60%|████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                | 6/10 [00:52<00:35,  8.94s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.63it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.72it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:07,  1.64it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.57it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:03<00:06,  1.62it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.66it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:04,  1.67it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:04,  1.69it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.71it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.72it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.75it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:06<00:01,  1.75it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.77it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:07<00:00,  1.77it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.72it/s]\u001b[A\n",
      "Epoch:  70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                | 7/10 [01:01<00:26,  8.88s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:10,  1.34it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:09,  1.36it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:02<00:08,  1.46it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.57it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:03<00:06,  1.60it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.65it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:04,  1.67it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:04,  1.65it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.67it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:06<00:02,  1.67it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.68it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:07<00:01,  1.70it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.69it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:08<00:00,  1.71it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.69it/s]\u001b[A\n",
      "Epoch:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                | 8/10 [01:10<00:17,  8.88s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:07,  1.90it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.82it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:06,  1.82it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.77it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:02<00:05,  1.77it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.77it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:03<00:04,  1.79it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:03,  1.78it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.78it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.78it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.77it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:06<00:01,  1.77it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.79it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:07<00:00,  1.77it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.77it/s]\u001b[A\n",
      "Epoch:  90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                | 9/10 [01:18<00:08,  8.75s/it]\n",
      "Iteration:   0%|                                                                                                                                                                     | 0/15 [00:00<?, ?it/s]\u001b[A\n",
      "Iteration:   7%|██████████▍                                                                                                                                                  | 1/15 [00:00<00:08,  1.70it/s]\u001b[A\n",
      "Iteration:  13%|████████████████████▉                                                                                                                                        | 2/15 [00:01<00:07,  1.68it/s]\u001b[A\n",
      "Iteration:  20%|███████████████████████████████▍                                                                                                                             | 3/15 [00:01<00:07,  1.68it/s]\u001b[A\n",
      "Iteration:  27%|█████████████████████████████████████████▊                                                                                                                   | 4/15 [00:02<00:06,  1.72it/s]\u001b[A\n",
      "Iteration:  33%|████████████████████████████████████████████████████▎                                                                                                        | 5/15 [00:02<00:05,  1.70it/s]\u001b[A\n",
      "Iteration:  40%|██████████████████████████████████████████████████████████████▊                                                                                              | 6/15 [00:03<00:05,  1.72it/s]\u001b[A\n",
      "Iteration:  47%|█████████████████████████████████████████████████████████████████████████▎                                                                                   | 7/15 [00:04<00:04,  1.70it/s]\u001b[A\n",
      "Iteration:  53%|███████████████████████████████████████████████████████████████████████████████████▋                                                                         | 8/15 [00:04<00:04,  1.71it/s]\u001b[A\n",
      "Iteration:  60%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 9/15 [00:05<00:03,  1.70it/s]\u001b[A\n",
      "Iteration:  67%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 10/15 [00:05<00:02,  1.71it/s]\u001b[A\n",
      "Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                         | 11/15 [00:06<00:02,  1.71it/s]\u001b[A\n",
      "Iteration:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 12/15 [00:06<00:01,  1.72it/s]\u001b[A\n",
      "Iteration:  87%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                    | 13/15 [00:07<00:01,  1.72it/s]\u001b[A\n",
      "Iteration:  93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 14/15 [00:08<00:00,  1.72it/s]\u001b[A\n",
      "Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:08<00:00,  1.71it/s]\u001b[A\n",
      "Epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:27<00:00,  8.75s/it]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Finetune the sentence transformer model on a custom dataset.\"\"\"\n",
    "from sentence_transformers import InputExample\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from torch.utils.data import DataLoader\n",
    "from sentence_transformers import losses\n",
    "from datasets import load_dataset\n",
    "import torch\n",
    "import sys\n",
    "import json, os\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pandas as pd\n",
    "    \n",
    "# model = SentenceTransformer(\"embedding-data/distilroberta-base-sentence-transformer\")\n",
    "model = SentenceTransformer(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
    "\n",
    "# to merge all the mini-datasets we created\n",
    "csv_file_path = \"combined_data.csv\"\n",
    "# csv_file_path = \"data.csv\"\n",
    "dataset = load_dataset(\"csv\", data_files=csv_file_path)\n",
    "\n",
    "train_examples = []\n",
    "train_data = dataset[\"train\"]\n",
    "n_examples = dataset[\"train\"].num_rows\n",
    "\n",
    "for i in range(n_examples): # can change number of examples for faster training\n",
    "    task = train_data[i][\"Task\"]\n",
    "    function = train_data[i][\"Function\"]\n",
    "    train_examples.append(InputExample(texts=[task, function]))\n",
    "\n",
    "train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)\n",
    "train_loss = losses.MultipleNegativesRankingLoss(model=model)\n",
    "num_epochs = 10\n",
    "warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data\n",
    "\n",
    "\n",
    "model.fit(\n",
    "    train_objectives=[(train_dataloader, train_loss)],\n",
    "    epochs=num_epochs,\n",
    "    warmup_steps=warmup_steps,\n",
    "    output_path='finetuned_model'\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"../objects_actions.tsv\", sep=\"\\t\")\n",
    "\n",
    "\n",
    "def get_openable_objects():\n",
    "    return df[df[\"Openable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_pickable_objects():\n",
    "    return df[df[\"Pickupable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_toggleable_objects():\n",
    "    return df[df[\"On/Off\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_receptacle_objects():\n",
    "    return df[df[\"Receptacle\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_fillable_objects():\n",
    "    return df[df[\"Fillable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_sliceable_objects():\n",
    "    return df[df[\"Sliceable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_cookable_objects():\n",
    "    return df[df[\"Cookable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_breakable_objects():\n",
    "    return df[df[\"Breakable\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_dirty_objects():\n",
    "    return df[df[\"Dirty\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "def get_usable_objects():\n",
    "    return df[df[\"UsedUp\"] == \"yes\"][\"Object Type\"].tolist()\n",
    "\n",
    "\n",
    "all_actions = []\n",
    "open = [f\"OpenObject({obj})\" for obj in get_openable_objects()]\n",
    "close = [f\"CloseObject({obj})\" for obj in get_openable_objects()]\n",
    "pick = [f\"PickupObject({obj})\" for obj in get_pickable_objects()]\n",
    "put = [f\"PutObject({obj})\" for obj in get_receptacle_objects()]\n",
    "toggle_on = [f\"ToggleObjectOn({obj})\" for obj in get_toggleable_objects()]\n",
    "toggle_off = [f\"ToggleObjectOff({obj})\" for obj in get_toggleable_objects()]\n",
    "# fill = [f\"FillObject({obj})\" for obj in get_fillable_objects()]\n",
    "slice = [f\"SliceObject({obj})\" for obj in get_sliceable_objects()]\n",
    "clean = [f\"CleanObject({obj})\" for obj in get_dirty_objects()]\n",
    "# cook = [f\"CookObject({obj})\" for obj in get_cookable_objects()]\n",
    "navigate = [f\"NavigateTo({obj})\" for obj in df[\"Object Type\"].tolist()]\n",
    "rotate = [f\"Rotate({obj})\" for obj in [\"Left\", \"Right\"]]\n",
    "lookup = [f\"LookUp({obj})\" for obj in [30, 60, 90, 120, 150, 180]]\n",
    "lookdown = [f\"LookDown({obj})\" for obj in [30, 60, 90, 120, 150, 180]]\n",
    "move = [f\"Move({obj})\" for obj in [\"Ahead\", \"Back\", \"Left\", \"Right\"]]\n",
    "done = [\"Done\"]\n",
    "idle = [\"Idle\"]\n",
    "\n",
    "all_actions.extend(pick)\n",
    "all_actions.extend(put)\n",
    "all_actions.extend(open)\n",
    "all_actions.extend(close)\n",
    "all_actions.extend(toggle_on)\n",
    "all_actions.extend(toggle_off)\n",
    "all_actions.extend(slice)\n",
    "all_actions.extend(clean)\n",
    "all_actions.extend(navigate)\n",
    "all_actions.extend(rotate)\n",
    "all_actions.extend(lookup)\n",
    "all_actions.extend(lookdown)\n",
    "all_actions.extend(move)\n",
    "all_actions.extend(done)\n",
    "all_actions.extend(idle)\n",
    "\n",
    "embeddings = torch.FloatTensor(model.encode(all_actions))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_closest_feasible_action(action: str):\n",
    "    \"\"\"To convert actions like RotateLeft to Rotate(Left)\"\"\"\n",
    "    action_embedding = torch.FloatTensor(model.encode([action]))\n",
    "    scores = torch.cosine_similarity(embeddings, action_embedding)\n",
    "    max_score, max_idx = torch.max(scores, 0)\n",
    "    return all_actions[max_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rotate(Left)\n"
     ]
    }
   ],
   "source": [
    "action = [\"look up by angle 30\", \"switch object off faucet\"]\n",
    "# print(get_closest_feasible_action(\"look up by angle 90\"))\n",
    "print(get_closest_feasible_action(\"rotate left\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "urop1",
   "language": "python",
   "name": "urop1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
