{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c7bb001",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import *\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b1e927bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "abalone = pd.read_csv('synthetic/abalone/real.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "08d70787",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = pd.read_csv('synthetic/abalone/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8bde934c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttvae = TTVAE(epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c65b5280",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['sex', 'length', 'diameter', 'height', 'wholeweight', 'shuckedweight',\n",
       "       'visceraweight', 'shellweight', 'rings'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "abalone.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0d884122",
   "metadata": {},
   "outputs": [],
   "source": [
    "discrete_columns = abalone.columns[abalone.dtypes=='object']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "aeb89cf5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['sex'], dtype='object')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "discrete_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "cb695da1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/binhducvu/anaconda3/lib/python3.9/site-packages/torch/nn/modules/transformer.py:379: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      "  warnings.warn(\n",
      "/home/binhducvu/anaconda3/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84 32 128 8 1028 0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/300: 100%|██████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.98it/s, Loss=233]\n",
      "Epoch 2/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.71it/s, Loss=92.7]\n",
      "Epoch 3/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.74it/s, Loss=59.7]\n",
      "Epoch 4/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.97it/s, Loss=42.1]\n",
      "Epoch 5/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.00it/s, Loss=42.8]\n",
      "Epoch 6/300: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.77it/s, Loss=43]\n",
      "Epoch 7/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.13it/s, Loss=39.8]\n",
      "Epoch 8/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.19it/s, Loss=36.1]\n",
      "Epoch 9/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.18it/s, Loss=37.9]\n",
      "Epoch 10/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.38it/s, Loss=34.8]\n",
      "Epoch 11/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.79it/s, Loss=33.9]\n",
      "Epoch 12/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.61it/s, Loss=29.6]\n",
      "Epoch 13/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.74it/s, Loss=28.7]\n",
      "Epoch 14/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.38it/s, Loss=26.2]\n",
      "Epoch 15/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.01it/s, Loss=24.3]\n",
      "Epoch 16/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.68it/s, Loss=22.5]\n",
      "Epoch 17/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.65it/s, Loss=21.6]\n",
      "Epoch 18/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.66it/s, Loss=21.4]\n",
      "Epoch 19/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.27it/s, Loss=20.4]\n",
      "Epoch 20/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.95it/s, Loss=18.5]\n",
      "Epoch 21/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.79it/s, Loss=17.1]\n",
      "Epoch 22/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.79it/s, Loss=16.4]\n",
      "Epoch 23/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.78it/s, Loss=15.2]\n",
      "Epoch 24/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.67it/s, Loss=14.1]\n",
      "Epoch 25/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.05it/s, Loss=12.7]\n",
      "Epoch 26/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.71it/s, Loss=10.8]\n",
      "Epoch 27/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.69it/s, Loss=7.47]\n",
      "Epoch 28/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.44it/s, Loss=5.16]\n",
      "Epoch 29/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.21it/s, Loss=3.42]\n",
      "Epoch 30/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.65it/s, Loss=1.16]\n",
      "Epoch 31/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.84it/s, Loss=-.19]\n",
      "Epoch 32/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.48it/s, Loss=-2.05]\n",
      "Epoch 33/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.59it/s, Loss=-4.2]\n",
      "Epoch 34/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.63it/s, Loss=-5.41]\n",
      "Epoch 35/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.75it/s, Loss=-7.78]\n",
      "Epoch 36/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.95it/s, Loss=-9.09]\n",
      "Epoch 37/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.17it/s, Loss=-10.5]\n",
      "Epoch 38/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.03it/s, Loss=-12.4]\n",
      "Epoch 39/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.93it/s, Loss=-13]\n",
      "Epoch 40/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.33it/s, Loss=-14]\n",
      "Epoch 41/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.84it/s, Loss=-15.6]\n",
      "Epoch 42/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.10it/s, Loss=-16.4]\n",
      "Epoch 43/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.93it/s, Loss=-17.7]\n",
      "Epoch 44/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.25it/s, Loss=-18.5]\n",
      "Epoch 45/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.83it/s, Loss=-20]\n",
      "Epoch 46/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.65it/s, Loss=-21.2]\n",
      "Epoch 47/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.12it/s, Loss=-21.7]\n",
      "Epoch 48/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.98it/s, Loss=-22.6]\n",
      "Epoch 49/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.06it/s, Loss=-23.3]\n",
      "Epoch 50/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.01it/s, Loss=-23.4]\n",
      "Epoch 51/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.12it/s, Loss=-23.8]\n",
      "Epoch 52/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.50it/s, Loss=-24.6]\n",
      "Epoch 53/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.00it/s, Loss=-25.1]\n",
      "Epoch 54/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.65it/s, Loss=-25.5]\n",
      "Epoch 55/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.82it/s, Loss=-25.2]\n",
      "Epoch 56/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.08it/s, Loss=-26.2]\n",
      "Epoch 57/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.31it/s, Loss=-26.6]\n",
      "Epoch 58/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.61it/s, Loss=-27]\n",
      "Epoch 59/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.87it/s, Loss=-27.4]\n",
      "Epoch 60/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.37it/s, Loss=-27.6]\n",
      "Epoch 61/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.67it/s, Loss=-28.2]\n",
      "Epoch 62/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.34it/s, Loss=-28.4]\n",
      "Epoch 63/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.81it/s, Loss=-28.7]\n",
      "Epoch 64/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.80it/s, Loss=-28.3]\n",
      "Epoch 65/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.71it/s, Loss=-29]\n",
      "Epoch 66/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.81it/s, Loss=-29.2]\n",
      "Epoch 67/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s, Loss=-29.5]\n",
      "Epoch 68/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s, Loss=-29.8]\n",
      "Epoch 69/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.89it/s, Loss=-30.5]\n",
      "Epoch 70/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.63it/s, Loss=-30.4]\n",
      "Epoch 71/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s, Loss=-30.8]\n",
      "Epoch 72/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.28it/s, Loss=-31.3]\n",
      "Epoch 73/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.78it/s, Loss=-31.6]\n",
      "Epoch 74/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.70it/s, Loss=-31.7]\n",
      "Epoch 75/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.20it/s, Loss=-31.9]\n",
      "Epoch 76/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.07it/s, Loss=-31.8]\n",
      "Epoch 77/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.50it/s, Loss=-32.3]\n",
      "Epoch 78/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.56it/s, Loss=-32.7]\n",
      "Epoch 79/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.54it/s, Loss=-32.9]\n",
      "Epoch 80/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.00it/s, Loss=-32.2]\n",
      "Epoch 81/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.21it/s, Loss=-32.9]\n",
      "Epoch 82/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.87it/s, Loss=-32.2]\n",
      "Epoch 83/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.09it/s, Loss=-32.9]\n",
      "Epoch 84/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.82it/s, Loss=-33.9]\n",
      "Epoch 85/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.95it/s, Loss=-34.1]\n",
      "Epoch 86/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.04it/s, Loss=-34.8]\n",
      "Epoch 87/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.51it/s, Loss=-34.7]\n",
      "Epoch 88/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.63it/s, Loss=-35]\n",
      "Epoch 89/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.42it/s, Loss=-35.8]\n",
      "Epoch 90/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.68it/s, Loss=-36.1]\n",
      "Epoch 91/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.56it/s, Loss=-36]\n",
      "Epoch 92/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.28it/s, Loss=-35.4]\n",
      "Epoch 93/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.66it/s, Loss=-34.8]\n",
      "Epoch 94/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.43it/s, Loss=-36]\n",
      "Epoch 95/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.37it/s, Loss=-35.9]\n",
      "Epoch 96/300: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.09it/s, Loss=-36]\n",
      "Epoch 97/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.25it/s, Loss=-37.3]\n",
      "Epoch 98/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.51it/s, Loss=-37.8]\n",
      "Epoch 99/300: 100%|███████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.08it/s, Loss=-37.6]\n",
      "Epoch 100/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.23it/s, Loss=-38]\n",
      "Epoch 101/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.77it/s, Loss=-37.9]\n",
      "Epoch 102/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.97it/s, Loss=-38.9]\n",
      "Epoch 103/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.02it/s, Loss=-39.9]\n",
      "Epoch 104/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.37it/s, Loss=-37.8]\n",
      "Epoch 105/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.99it/s, Loss=-39.4]\n",
      "Epoch 106/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.11it/s, Loss=-38.1]\n",
      "Epoch 107/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.13it/s, Loss=-39.3]\n",
      "Epoch 108/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.01it/s, Loss=-39.8]\n",
      "Epoch 109/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.29it/s, Loss=-40.1]\n",
      "Epoch 110/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.94it/s, Loss=-40.1]\n",
      "Epoch 111/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.35it/s, Loss=-38.8]\n",
      "Epoch 112/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.07it/s, Loss=-42]\n",
      "Epoch 113/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.19it/s, Loss=-40.4]\n",
      "Epoch 114/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.14it/s, Loss=-42.8]\n",
      "Epoch 115/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.24it/s, Loss=-35.8]\n",
      "Epoch 116/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.02it/s, Loss=-36.7]\n",
      "Epoch 117/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.96it/s, Loss=-39.3]\n",
      "Epoch 118/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.94it/s, Loss=-40]\n",
      "Epoch 119/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.09it/s, Loss=-42.7]\n",
      "Epoch 120/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.10it/s, Loss=-42.7]\n",
      "Epoch 121/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.66it/s, Loss=-44.9]\n",
      "Epoch 122/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.00it/s, Loss=-44.7]\n",
      "Epoch 123/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.95it/s, Loss=-45.6]\n",
      "Epoch 124/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.05it/s, Loss=-43.7]\n",
      "Epoch 125/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.86it/s, Loss=-44.8]\n",
      "Epoch 126/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.91it/s, Loss=-43.1]\n",
      "Epoch 127/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.98it/s, Loss=-44]\n",
      "Epoch 128/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.11it/s, Loss=-44.8]\n",
      "Epoch 129/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.90it/s, Loss=-46.2]\n",
      "Epoch 130/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.09it/s, Loss=-45.9]\n",
      "Epoch 131/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.03it/s, Loss=-45.6]\n",
      "Epoch 132/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.12it/s, Loss=-43.7]\n",
      "Epoch 133/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.17it/s, Loss=-44.3]\n",
      "Epoch 134/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.97it/s, Loss=-45]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 135/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.29it/s, Loss=-45.4]\n",
      "Epoch 136/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.04it/s, Loss=-47.1]\n",
      "Epoch 137/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.35it/s, Loss=-44.7]\n",
      "Epoch 138/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.67it/s, Loss=-49.8]\n",
      "Epoch 139/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.62it/s, Loss=-49]\n",
      "Epoch 140/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.68it/s, Loss=-41.2]\n",
      "Epoch 141/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s, Loss=-42.9]\n",
      "Epoch 142/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.94it/s, Loss=-46.1]\n",
      "Epoch 143/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.04it/s, Loss=-46.8]\n",
      "Epoch 144/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.69it/s, Loss=-45.6]\n",
      "Epoch 145/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.10it/s, Loss=-47.9]\n",
      "Epoch 146/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.79it/s, Loss=-47.6]\n",
      "Epoch 147/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.87it/s, Loss=-48.2]\n",
      "Epoch 148/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.58it/s, Loss=-48.5]\n",
      "Epoch 149/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.84it/s, Loss=-50.2]\n",
      "Epoch 150/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.29it/s, Loss=-50.6]\n",
      "Epoch 151/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.19it/s, Loss=-46.9]\n",
      "Epoch 152/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.79it/s, Loss=-48.8]\n",
      "Epoch 153/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.61it/s, Loss=-49.5]\n",
      "Epoch 154/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.40it/s, Loss=-47.9]\n",
      "Epoch 155/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.30it/s, Loss=-50.2]\n",
      "Epoch 156/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.36it/s, Loss=-49.6]\n",
      "Epoch 157/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.57it/s, Loss=-50.9]\n",
      "Epoch 158/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.33it/s, Loss=-49.6]\n",
      "Epoch 159/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.63it/s, Loss=-50.5]\n",
      "Epoch 160/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.59it/s, Loss=-51.3]\n",
      "Epoch 161/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.86it/s, Loss=-50.6]\n",
      "Epoch 162/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.68it/s, Loss=-49.2]\n",
      "Epoch 163/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.86it/s, Loss=-53.1]\n",
      "Epoch 164/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.37it/s, Loss=-50.7]\n",
      "Epoch 165/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.68it/s, Loss=-51.7]\n",
      "Epoch 166/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.52it/s, Loss=-50.2]\n",
      "Epoch 167/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.74it/s, Loss=-52.3]\n",
      "Epoch 168/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.61it/s, Loss=-51.7]\n",
      "Epoch 169/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.65it/s, Loss=-49.8]\n",
      "Epoch 170/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.74it/s, Loss=-53.5]\n",
      "Epoch 171/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.77it/s, Loss=-53]\n",
      "Epoch 172/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.40it/s, Loss=-52.6]\n",
      "Epoch 173/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.15it/s, Loss=-54.1]\n",
      "Epoch 174/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.82it/s, Loss=-52.1]\n",
      "Epoch 175/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.83it/s, Loss=-47]\n",
      "Epoch 176/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.54it/s, Loss=-50.2]\n",
      "Epoch 177/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.90it/s, Loss=-51.4]\n",
      "Epoch 178/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.65it/s, Loss=-54]\n",
      "Epoch 179/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.17it/s, Loss=-55.2]\n",
      "Epoch 180/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.02it/s, Loss=-50.6]\n",
      "Epoch 181/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.02it/s, Loss=-54.2]\n",
      "Epoch 182/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.79it/s, Loss=-54.2]\n",
      "Epoch 183/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.49it/s, Loss=-51]\n",
      "Epoch 184/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.63it/s, Loss=-50.9]\n",
      "Epoch 185/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.60it/s, Loss=-52.7]\n",
      "Epoch 186/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.63it/s, Loss=-52.4]\n",
      "Epoch 187/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.57it/s, Loss=-53.1]\n",
      "Epoch 188/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.71it/s, Loss=-52.4]\n",
      "Epoch 189/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s, Loss=-55.6]\n",
      "Epoch 190/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.62it/s, Loss=-55.4]\n",
      "Epoch 191/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.83it/s, Loss=-55.2]\n",
      "Epoch 192/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.73it/s, Loss=-54.6]\n",
      "Epoch 193/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.88it/s, Loss=-52.6]\n",
      "Epoch 194/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.69it/s, Loss=-51.5]\n",
      "Epoch 195/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.71it/s, Loss=-55.2]\n",
      "Epoch 196/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.85it/s, Loss=-55.4]\n",
      "Epoch 197/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.45it/s, Loss=-56]\n",
      "Epoch 198/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.36it/s, Loss=-51.8]\n",
      "Epoch 199/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.91it/s, Loss=-53.1]\n",
      "Epoch 200/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.23it/s, Loss=-53.6]\n",
      "Epoch 201/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.76it/s, Loss=-56.7]\n",
      "Epoch 202/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.04it/s, Loss=-56.7]\n",
      "Epoch 203/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.39it/s, Loss=-55.9]\n",
      "Epoch 204/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.45it/s, Loss=-58.4]\n",
      "Epoch 205/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.81it/s, Loss=-57.7]\n",
      "Epoch 206/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.45it/s, Loss=-57.6]\n",
      "Epoch 207/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.87it/s, Loss=-59]\n",
      "Epoch 208/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.97it/s, Loss=-56.1]\n",
      "Epoch 209/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.53it/s, Loss=-51]\n",
      "Epoch 210/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.85it/s, Loss=-56.1]\n",
      "Epoch 211/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.28it/s, Loss=-56.4]\n",
      "Epoch 212/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.01it/s, Loss=-55.9]\n",
      "Epoch 213/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.78it/s, Loss=-58]\n",
      "Epoch 214/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.65it/s, Loss=-58]\n",
      "Epoch 215/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.77it/s, Loss=-54.6]\n",
      "Epoch 216/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.77it/s, Loss=-56.2]\n",
      "Epoch 217/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.78it/s, Loss=-57.2]\n",
      "Epoch 218/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.85it/s, Loss=-52.6]\n",
      "Epoch 219/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.94it/s, Loss=-56]\n",
      "Epoch 220/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.10it/s, Loss=-56.9]\n",
      "Epoch 221/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.44it/s, Loss=-59.3]\n",
      "Epoch 222/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.52it/s, Loss=-58.5]\n",
      "Epoch 223/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.87it/s, Loss=-56]\n",
      "Epoch 224/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.90it/s, Loss=-56.5]\n",
      "Epoch 225/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.81it/s, Loss=-58.9]\n",
      "Epoch 226/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.89it/s, Loss=-59.5]\n",
      "Epoch 227/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.05it/s, Loss=-57.5]\n",
      "Epoch 228/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.14it/s, Loss=-59.9]\n",
      "Epoch 229/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.21it/s, Loss=-55.6]\n",
      "Epoch 230/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.50it/s, Loss=-57.5]\n",
      "Epoch 231/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.00it/s, Loss=-58.6]\n",
      "Epoch 232/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.87it/s, Loss=-57.2]\n",
      "Epoch 233/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.99it/s, Loss=-57.3]\n",
      "Epoch 234/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.96it/s, Loss=-57.1]\n",
      "Epoch 235/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.93it/s, Loss=-58.5]\n",
      "Epoch 236/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s, Loss=-56.8]\n",
      "Epoch 237/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.74it/s, Loss=-58.1]\n",
      "Epoch 238/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s, Loss=-60.5]\n",
      "Epoch 239/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.71it/s, Loss=-60.2]\n",
      "Epoch 240/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.99it/s, Loss=-56.3]\n",
      "Epoch 241/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.82it/s, Loss=-53.7]\n",
      "Epoch 242/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.67it/s, Loss=-57.4]\n",
      "Epoch 243/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.79it/s, Loss=-55.4]\n",
      "Epoch 244/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.94it/s, Loss=-56.5]\n",
      "Epoch 245/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.44it/s, Loss=-58]\n",
      "Epoch 246/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.78it/s, Loss=-61.6]\n",
      "Epoch 247/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.79it/s, Loss=-58.3]\n",
      "Epoch 248/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.77it/s, Loss=-60.2]\n",
      "Epoch 249/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.08it/s, Loss=-62.1]\n",
      "Epoch 250/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.88it/s, Loss=-62.5]\n",
      "Epoch 251/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.84it/s, Loss=-58.1]\n",
      "Epoch 252/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.19it/s, Loss=-61.1]\n",
      "Epoch 253/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.93it/s, Loss=-59.3]\n",
      "Epoch 254/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.95it/s, Loss=-59.2]\n",
      "Epoch 255/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.55it/s, Loss=-58.9]\n",
      "Epoch 256/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.99it/s, Loss=-58.9]\n",
      "Epoch 257/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.95it/s, Loss=-59.2]\n",
      "Epoch 258/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.71it/s, Loss=-59.9]\n",
      "Epoch 259/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.42it/s, Loss=-60.2]\n",
      "Epoch 260/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.64it/s, Loss=-62]\n",
      "Epoch 261/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.04it/s, Loss=-57]\n",
      "Epoch 262/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.24it/s, Loss=-56.4]\n",
      "Epoch 263/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.87it/s, Loss=-61.4]\n",
      "Epoch 264/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.81it/s, Loss=-60.3]\n",
      "Epoch 265/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s, Loss=-57.5]\n",
      "Epoch 266/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.88it/s, Loss=-59.3]\n",
      "Epoch 267/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.98it/s, Loss=-58.2]\n",
      "Epoch 268/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.18it/s, Loss=-58.7]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 269/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.72it/s, Loss=-60.9]\n",
      "Epoch 270/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.09it/s, Loss=-62.3]\n",
      "Epoch 271/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.12it/s, Loss=-59.2]\n",
      "Epoch 272/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.81it/s, Loss=-62.2]\n",
      "Epoch 273/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.53it/s, Loss=-66.2]\n",
      "Epoch 274/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.47it/s, Loss=-60.4]\n",
      "Epoch 275/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.51it/s, Loss=-61.2]\n",
      "Epoch 276/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.24it/s, Loss=-60.7]\n",
      "Epoch 277/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.14it/s, Loss=-56.1]\n",
      "Epoch 278/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.81it/s, Loss=-57.2]\n",
      "Epoch 279/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.67it/s, Loss=-61.6]\n",
      "Epoch 280/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.06it/s, Loss=-59.2]\n",
      "Epoch 281/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.98it/s, Loss=-60.8]\n",
      "Epoch 282/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.72it/s, Loss=-50.6]\n",
      "Epoch 283/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.52it/s, Loss=-61.9]\n",
      "Epoch 284/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.25it/s, Loss=-61.8]\n",
      "Epoch 285/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.84it/s, Loss=-58.8]\n",
      "Epoch 286/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.89it/s, Loss=-61]\n",
      "Epoch 287/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.10it/s, Loss=-61.3]\n",
      "Epoch 288/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.69it/s, Loss=-63.3]\n",
      "Epoch 289/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.93it/s, Loss=-58.9]\n",
      "Epoch 290/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.05it/s, Loss=-61.4]\n",
      "Epoch 291/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.90it/s, Loss=-63.7]\n",
      "Epoch 292/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.02it/s, Loss=-60.3]\n",
      "Epoch 293/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.03it/s, Loss=-61.9]\n",
      "Epoch 294/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.11it/s, Loss=-59]\n",
      "Epoch 295/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.91it/s, Loss=-62.4]\n",
      "Epoch 296/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.87it/s, Loss=-62.9]\n",
      "Epoch 297/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.00it/s, Loss=-62.3]\n",
      "Epoch 298/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.76it/s, Loss=-62.5]\n",
      "Epoch 299/300: 100%|██████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.30it/s, Loss=-60.1]\n",
      "Epoch 300/300: 100%|████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.11it/s, Loss=-64]\n"
     ]
    }
   ],
   "source": [
    "ttvae.fit(abalone, discrete_columns, 'ttvae/ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2bf435c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformed = ttvae.transformer.transform(test)\n",
    "transformed = torch.from_numpy(transformed.astype('float32')).to('cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "460c0bba",
   "metadata": {},
   "outputs": [],
   "source": [
    "mu, std, logvar, enc_embed = ttvae.encoder(transformed)\n",
    "synthetic_embeddings=mu\n",
    "noise = torch.Tensor(synthetic_embeddings).to(ttvae._device)\n",
    "ttvae.decoder.eval()\n",
    "with torch.no_grad():\n",
    "    fake, sigmas = ttvae.decoder(noise,enc_embed)\n",
    "    fake = torch.tanh(fake).cpu().detach().numpy()\n",
    "final = ttvae.transformer.inverse_transform(fake)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "5c231396",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>length</th>\n",
       "      <th>diameter</th>\n",
       "      <th>height</th>\n",
       "      <th>wholeweight</th>\n",
       "      <th>shuckedweight</th>\n",
       "      <th>visceraweight</th>\n",
       "      <th>shellweight</th>\n",
       "      <th>rings</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F</td>\n",
       "      <td>0.582236</td>\n",
       "      <td>0.458723</td>\n",
       "      <td>0.164730</td>\n",
       "      <td>0.979701</td>\n",
       "      <td>0.345323</td>\n",
       "      <td>0.252814</td>\n",
       "      <td>0.307930</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>I</td>\n",
       "      <td>0.407512</td>\n",
       "      <td>0.304142</td>\n",
       "      <td>0.092103</td>\n",
       "      <td>0.291460</td>\n",
       "      <td>0.140331</td>\n",
       "      <td>0.057360</td>\n",
       "      <td>0.071587</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>M</td>\n",
       "      <td>0.281496</td>\n",
       "      <td>0.219529</td>\n",
       "      <td>0.074950</td>\n",
       "      <td>0.110056</td>\n",
       "      <td>0.039673</td>\n",
       "      <td>0.019856</td>\n",
       "      <td>0.030733</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>M</td>\n",
       "      <td>0.613830</td>\n",
       "      <td>0.462636</td>\n",
       "      <td>0.170174</td>\n",
       "      <td>1.058509</td>\n",
       "      <td>0.477705</td>\n",
       "      <td>0.274402</td>\n",
       "      <td>0.267761</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>F</td>\n",
       "      <td>0.470045</td>\n",
       "      <td>0.355667</td>\n",
       "      <td>0.115330</td>\n",
       "      <td>0.490489</td>\n",
       "      <td>0.192078</td>\n",
       "      <td>0.130016</td>\n",
       "      <td>0.151151</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>831</th>\n",
       "      <td>F</td>\n",
       "      <td>0.391565</td>\n",
       "      <td>0.298649</td>\n",
       "      <td>0.096326</td>\n",
       "      <td>0.232699</td>\n",
       "      <td>0.077161</td>\n",
       "      <td>0.051461</td>\n",
       "      <td>0.076708</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>832</th>\n",
       "      <td>F</td>\n",
       "      <td>0.478929</td>\n",
       "      <td>0.326496</td>\n",
       "      <td>0.137793</td>\n",
       "      <td>0.577165</td>\n",
       "      <td>0.212173</td>\n",
       "      <td>0.112431</td>\n",
       "      <td>0.207373</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>833</th>\n",
       "      <td>I</td>\n",
       "      <td>0.555228</td>\n",
       "      <td>0.426127</td>\n",
       "      <td>0.144914</td>\n",
       "      <td>0.804024</td>\n",
       "      <td>0.345306</td>\n",
       "      <td>0.179452</td>\n",
       "      <td>0.221822</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>834</th>\n",
       "      <td>M</td>\n",
       "      <td>0.465924</td>\n",
       "      <td>0.363853</td>\n",
       "      <td>0.080361</td>\n",
       "      <td>0.488604</td>\n",
       "      <td>0.188113</td>\n",
       "      <td>0.125006</td>\n",
       "      <td>0.151681</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>835</th>\n",
       "      <td>I</td>\n",
       "      <td>0.313014</td>\n",
       "      <td>0.235208</td>\n",
       "      <td>0.071889</td>\n",
       "      <td>0.134805</td>\n",
       "      <td>0.053159</td>\n",
       "      <td>0.030641</td>\n",
       "      <td>0.034579</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>836 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    sex    length  diameter    height  wholeweight  shuckedweight  \\\n",
       "0     F  0.582236  0.458723  0.164730     0.979701       0.345323   \n",
       "1     I  0.407512  0.304142  0.092103     0.291460       0.140331   \n",
       "2     M  0.281496  0.219529  0.074950     0.110056       0.039673   \n",
       "3     M  0.613830  0.462636  0.170174     1.058509       0.477705   \n",
       "4     F  0.470045  0.355667  0.115330     0.490489       0.192078   \n",
       "..   ..       ...       ...       ...          ...            ...   \n",
       "831   F  0.391565  0.298649  0.096326     0.232699       0.077161   \n",
       "832   F  0.478929  0.326496  0.137793     0.577165       0.212173   \n",
       "833   I  0.555228  0.426127  0.144914     0.804024       0.345306   \n",
       "834   M  0.465924  0.363853  0.080361     0.488604       0.188113   \n",
       "835   I  0.313014  0.235208  0.071889     0.134805       0.053159   \n",
       "\n",
       "     visceraweight  shellweight  rings  \n",
       "0         0.252814     0.307930     12  \n",
       "1         0.057360     0.071587      8  \n",
       "2         0.019856     0.030733      5  \n",
       "3         0.274402     0.267761     10  \n",
       "4         0.130016     0.151151      8  \n",
       "..             ...          ...    ...  \n",
       "831       0.051461     0.076708     10  \n",
       "832       0.112431     0.207373     15  \n",
       "833       0.179452     0.221822      9  \n",
       "834       0.125006     0.151681     11  \n",
       "835       0.030641     0.034579      6  \n",
       "\n",
       "[836 rows x 9 columns]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "34455112",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sex</th>\n",
       "      <th>length</th>\n",
       "      <th>diameter</th>\n",
       "      <th>height</th>\n",
       "      <th>wholeweight</th>\n",
       "      <th>shuckedweight</th>\n",
       "      <th>visceraweight</th>\n",
       "      <th>shellweight</th>\n",
       "      <th>rings</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>F</td>\n",
       "      <td>0.585</td>\n",
       "      <td>0.455</td>\n",
       "      <td>0.165</td>\n",
       "      <td>0.9980</td>\n",
       "      <td>0.3450</td>\n",
       "      <td>0.2495</td>\n",
       "      <td>0.315</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>I</td>\n",
       "      <td>0.410</td>\n",
       "      <td>0.300</td>\n",
       "      <td>0.090</td>\n",
       "      <td>0.2800</td>\n",
       "      <td>0.1410</td>\n",
       "      <td>0.0575</td>\n",
       "      <td>0.075</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>M</td>\n",
       "      <td>0.285</td>\n",
       "      <td>0.215</td>\n",
       "      <td>0.075</td>\n",
       "      <td>0.1060</td>\n",
       "      <td>0.0415</td>\n",
       "      <td>0.0230</td>\n",
       "      <td>0.035</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>M</td>\n",
       "      <td>0.615</td>\n",
       "      <td>0.460</td>\n",
       "      <td>0.170</td>\n",
       "      <td>1.0565</td>\n",
       "      <td>0.4815</td>\n",
       "      <td>0.2720</td>\n",
       "      <td>0.270</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>F</td>\n",
       "      <td>0.470</td>\n",
       "      <td>0.350</td>\n",
       "      <td>0.115</td>\n",
       "      <td>0.4870</td>\n",
       "      <td>0.1955</td>\n",
       "      <td>0.1270</td>\n",
       "      <td>0.155</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>831</th>\n",
       "      <td>F</td>\n",
       "      <td>0.395</td>\n",
       "      <td>0.295</td>\n",
       "      <td>0.095</td>\n",
       "      <td>0.2245</td>\n",
       "      <td>0.0780</td>\n",
       "      <td>0.0540</td>\n",
       "      <td>0.080</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>832</th>\n",
       "      <td>F</td>\n",
       "      <td>0.465</td>\n",
       "      <td>0.390</td>\n",
       "      <td>0.140</td>\n",
       "      <td>0.5555</td>\n",
       "      <td>0.2130</td>\n",
       "      <td>0.1075</td>\n",
       "      <td>0.215</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>833</th>\n",
       "      <td>I</td>\n",
       "      <td>0.555</td>\n",
       "      <td>0.425</td>\n",
       "      <td>0.145</td>\n",
       "      <td>0.7905</td>\n",
       "      <td>0.3485</td>\n",
       "      <td>0.1765</td>\n",
       "      <td>0.225</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>834</th>\n",
       "      <td>M</td>\n",
       "      <td>0.465</td>\n",
       "      <td>0.360</td>\n",
       "      <td>0.080</td>\n",
       "      <td>0.4880</td>\n",
       "      <td>0.1910</td>\n",
       "      <td>0.1250</td>\n",
       "      <td>0.155</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>835</th>\n",
       "      <td>I</td>\n",
       "      <td>0.310</td>\n",
       "      <td>0.230</td>\n",
       "      <td>0.070</td>\n",
       "      <td>0.1245</td>\n",
       "      <td>0.0505</td>\n",
       "      <td>0.0265</td>\n",
       "      <td>0.038</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>836 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    sex  length  diameter  height  wholeweight  shuckedweight  visceraweight  \\\n",
       "0     F   0.585     0.455   0.165       0.9980         0.3450         0.2495   \n",
       "1     I   0.410     0.300   0.090       0.2800         0.1410         0.0575   \n",
       "2     M   0.285     0.215   0.075       0.1060         0.0415         0.0230   \n",
       "3     M   0.615     0.460   0.170       1.0565         0.4815         0.2720   \n",
       "4     F   0.470     0.350   0.115       0.4870         0.1955         0.1270   \n",
       "..   ..     ...       ...     ...          ...            ...            ...   \n",
       "831   F   0.395     0.295   0.095       0.2245         0.0780         0.0540   \n",
       "832   F   0.465     0.390   0.140       0.5555         0.2130         0.1075   \n",
       "833   I   0.555     0.425   0.145       0.7905         0.3485         0.1765   \n",
       "834   M   0.465     0.360   0.080       0.4880         0.1910         0.1250   \n",
       "835   I   0.310     0.230   0.070       0.1245         0.0505         0.0265   \n",
       "\n",
       "     shellweight  rings  \n",
       "0          0.315     12  \n",
       "1          0.075      8  \n",
       "2          0.035      5  \n",
       "3          0.270     10  \n",
       "4          0.155      8  \n",
       "..           ...    ...  \n",
       "831        0.080     10  \n",
       "832        0.215     15  \n",
       "833        0.225      9  \n",
       "834        0.155     11  \n",
       "835        0.038      6  \n",
       "\n",
       "[836 rows x 9 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802d2f4f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
