{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e1bc3506-f509-4b37-8179-15e7a944c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.append('/home/quickjkee/projects/Light-GCOT')\n",
    "\n",
    "import warnings\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "from sklearn.preprocessing import Normalizer\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "import moscot.plotting as mtp\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from moscot import datasets\n",
    "from moscot.problems.cross_modality import TranslationProblem\n",
    "from sklearn import preprocessing as pp\n",
    "\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.from_loader import PairedLoaderSampler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.models.light_gcot import LightGCOT\n",
    "from sklearn.manifold import TSNE\n",
    "tsne = TSNE(n_components=2, random_state=50)\n",
    "\n",
    "#https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a6ad3",
   "metadata": {},
   "source": [
    "# Data preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2366fd9",
   "metadata": {},
   "source": [
    "### PS\n",
    "1) Source  \n",
    "$X \\in \\mathbb{R}^{N \\times d_{1}}, N - \\text{number of locations}, d_{1} - \\text{features dim}$ \\\n",
    "$x = (\\mu, \\sigma) - \\text{for a given location in June}$ \\\n",
    "$N = 1396, d_{1} = 188$ \n",
    "\n",
    "2) Target  \n",
    "$Y \\in \\mathbb{R}^{N \\times M \\times d_{2}}, N - \\text{number of locations}, M - \\text{number of measurements for a given location in January by day}$ \\\n",
    "$M \\in [1, 31], d_{2} = 94$ \n",
    "\n",
    "**Note:** the splitting code below selects rows whose month column equals `1.0` for the source and `6.0` for the target; confirm that this matches the months named above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02965c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################################\n",
    "#-------------- RAW DATA -----------------\n",
    "##########################################\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "root = '../tabred/kal/weather'\n",
    "\n",
    "data = np.load(f'{root}/X_num.npy')\n",
    "data = np.stack([d for d in data if sum(np.isnan(d)) == 0])\n",
    "data_csv = pd.read_csv(f'{root}/csv/X_num.csv')\n",
    "#train_data = data[train_idx]\n",
    "#test_data = data[test_idx]\n",
    "\n",
    "target = np.load(f'{root}/Y.npy')\n",
    "meta = np.load(f'{root}/X_meta.npy')\n",
    "meta = np.stack([meta[i] for i, d in enumerate(data) if sum(np.isnan(d)) == 0])\n",
    "meta_csv = pd.read_csv(f'{root}/csv/X_meta.csv')\n",
    "\n",
    "names = list(data_csv.columns)\n",
    "names.append('location')\n",
    "data_new = np.concatenate((data, meta[:, -2].reshape(-1, 1)), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb1dee44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1653 1578\n"
     ]
    }
   ],
   "source": [
    "#########################################################\n",
    "#--------------- Month/location split -------------------\n",
    "#########################################################\n",
    "scaler = StandardScaler()  # not used in this cell; kept in case a later cell relies on it\n",
    "\n",
    "SRC_MONTH = 1.0  # value of the second-to-last column that selects source rows\n",
    "TRG_MONTH = 6.0  # value of the second-to-last column that selects target rows\n",
    "\n",
    "\n",
    "def split_by_location(rows, month, eps=1e-1):\n",
    "    \"\"\"Group the rows of one month by location and min-max scale every group.\n",
    "\n",
    "    Args:\n",
    "        rows: 2-D array whose last column is the location id and whose\n",
    "            second-to-last column is compared against `month` (presumably the\n",
    "            month of the measurement -- confirm against X_num.csv).\n",
    "        month: value of the second-to-last column to select.\n",
    "        eps: constant added to each per-feature range so that a feature that is\n",
    "            constant within a location does not divide by zero.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping location id -> array of shape (n_rows, n_columns - 7) with\n",
    "        the selected rows minus their last 7 columns, scaled per feature with\n",
    "        the group's own minimum and maximum. Locations with a single row are\n",
    "        dropped. Keys keep the order in which locations first appear in `rows`.\n",
    "    \"\"\"\n",
    "    grouped = {}\n",
    "    for row in rows:\n",
    "        if row[-2] == month:\n",
    "            grouped.setdefault(row[-1], []).append(row[:-7])\n",
    "\n",
    "    scaled = {}\n",
    "    for location, samples in grouped.items():\n",
    "        samples = np.stack(samples)\n",
    "        if samples.shape[0] > 1:\n",
    "            lowest = np.min(samples, axis=0)\n",
    "            highest = np.max(samples, axis=0)\n",
    "            scaled[location] = (samples - lowest) / (highest - lowest + eps)\n",
    "    return scaled\n",
    "\n",
    "\n",
    "dict_location_src = split_by_location(data_new, SRC_MONTH)\n",
    "dict_location_trg = split_by_location(data_new, TRG_MONTH)\n",
    "\n",
    "# Old names kept as aliases for any later cell that still refers to them.\n",
    "dict_location_src_new = dict_location_src\n",
    "dict_location_trg_new = dict_location_trg\n",
    "\n",
    "print(len(dict_location_trg), len(dict_location_src))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5c56fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################\n",
    "#--------------------- X, Y paired ----------------------\n",
    "#########################################################\n",
    "\n",
    "def location_summary(samples):\n",
    "    \"\"\"Return the per-feature mean and std of `samples` (rows x D) as one vector of length 2 * D.\"\"\"\n",
    "    return np.concatenate([np.mean(samples, axis=0), np.std(samples, axis=0)])\n",
    "\n",
    "\n",
    "# The first 200 target locations form the paired subset. The set gives O(1)\n",
    "# membership tests instead of scanning the list once per location.\n",
    "chosen_locs = list(dict_location_trg.keys())[:200]\n",
    "chosen_set = set(chosen_locs)\n",
    "\n",
    "# Paired data: chosen locations that also have source measurements.\n",
    "paired_keys = [key for key in dict_location_src if key in chosen_set]\n",
    "X_pair_orig = np.stack([location_summary(dict_location_src[key]) for key in paired_keys])\n",
    "Y_pair_orig = [dict_location_trg[key] for key in paired_keys]  # all target samples of the location\n",
    "\n",
    "\n",
    "#########################################################\n",
    "#----------------------- X, Y ---------------------------\n",
    "#########################################################\n",
    "\n",
    "# N x 1 x 2D - src\n",
    "# N x M x D - trg\n",
    "\n",
    "# sampling:\n",
    "# b x 1 x 2D,\n",
    "# b x M x D -> sample -> b x 1 x D\n",
    "\n",
    "# Unpaired data: every location outside the paired subset, source and target\n",
    "# collected independently of each other.\n",
    "X_orig = np.stack([\n",
    "    location_summary(samples)\n",
    "    for key, samples in dict_location_src.items()\n",
    "    if key not in chosen_set\n",
    "])\n",
    "Y_orig = [samples for key, samples in dict_location_trg.items() if key not in chosen_set]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b44ab264",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1386, 188) 1453 (192, 188) 192\n"
     ]
    }
   ],
   "source": [
    "# Sizes of the unpaired (X, Y) sets followed by the paired (X, Y) sets.\n",
    "unpaired_sizes = (X_orig.shape, len(Y_orig))\n",
    "paired_sizes = (X_pair_orig.shape, len(Y_pair_orig))\n",
    "print(*unpaired_sizes, *paired_sizes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11b6332b",
   "metadata": {},
   "source": [
    "# Running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f0a1831a-9e8f-480c-abb1-da3f6dc56ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- data handles and dimensions ---\n",
    "source_data, target_data = X_orig, Y_orig[0]\n",
    "X_DIM, Y_DIM = source_data.shape[1], target_data.shape[1]\n",
    "assert X_DIM > 1\n",
    "assert Y_DIM > 1\n",
    "\n",
    "OUTPUT_SEED = 50\n",
    "\n",
    "# --- model ---\n",
    "N_POTENTIALS = 10\n",
    "M_POTENTIALS = 1 #10\n",
    "EPSILON = 1\n",
    "A_DIAGONAL_INIT = 0.5\n",
    "\n",
    "# --- dataset sizes ---\n",
    "L_PAIRED_SAMPLES = len(X_pair_orig)  # one entry per paired location\n",
    "M_X_UNPAIRED_SAMPLES = N_Y_UNPAIRED_SAMPLES = 0\n",
    "\n",
    "# --- batching ---\n",
    "BATCH_SIZE = SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "# --- optimisation ---\n",
    "D_LR = 3e-4  # 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_GRADIENT_MAX_NORM = float('inf')\n",
    "\n",
    "NUM_LABELED = 10\n",
    "TRAIN_SUBSET_SIZE = 2\n",
    "\n",
    "# --- training loop ---\n",
    "PLOT_EVERY = 1000\n",
    "MAX_STEPS = 20000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dc8ded78-a9b3-4053-972e-e0fbed555f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_COST = 'MLP_deep_deep'\n",
    "EXP_COST_INCLUDED = True\n",
    "EXP_META_INFO = ''\n",
    "\n",
    "# The experiment name encodes the main hyper-parameters; it doubles as the\n",
    "# name of the checkpoint directory, so its format must stay stable.\n",
    "EXP_NAME = ''.join([\n",
    "    'Light-GCOT_Batch_Effect_',\n",
    "    f'EPSILON_{EPSILON}_',\n",
    "    f'N_{N_POTENTIALS}_',\n",
    "    f'M_{M_POTENTIALS}_',\n",
    "    f'with_{EXP_COST}_',\n",
    "    f'cost_included_{EXP_COST_INCLUDED}_',\n",
    "    f'N_PAIRED_{NUM_LABELED}_',\n",
    "    f'M_UNPAIRED_{len(source_data)}_',\n",
    "    EXP_META_INFO,\n",
    "])\n",
    "OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "# Hyper-parameters of this run collected in one dict.\n",
    "config = dict(\n",
    "    X_DIM=X_DIM,\n",
    "    Y_DIM=Y_DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    M_POTENTIALS=M_POTENTIALS,\n",
    "    A_DIAGONAL_INIT=A_DIAGONAL_INIT,\n",
    "    N_PAIRED_SAMPLES=NUM_LABELED,\n",
    "    M_UNPAIRED_SAMPLES=len(source_data),\n",
    ")\n",
    "\n",
    "# exist_ok=True creates the directory in one call, without the window between\n",
    "# an existence check and the creation in which another process could create it.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cb4506b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pytorch_total_params = sum(p.numel() for p in D.parameters())\n",
    "#pytorch_total_params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d32db3-1805-4e51-b492-e1eaaba959e6",
   "metadata": {},
   "source": [
    "## Ablation Study"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "86eea788",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _resolve_device(device):\n",
    "    \"\"\"Return `device`, or 'cuda' when available (else 'cpu') if `device` is None.\"\"\"\n",
    "    if device is None:\n",
    "        return 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    return device\n",
    "\n",
    "\n",
    "def _draw_rows(groups, idxs):\n",
    "    \"\"\"Pick one uniformly random row of `groups[i]` for every `i` in `idxs` and stack the rows.\"\"\"\n",
    "    # random.randint includes both bounds, hence the `- 1`.\n",
    "    return np.stack([groups[i][random.randint(0, len(groups[i]) - 1)] for i in idxs])\n",
    "\n",
    "\n",
    "def paired_sampler(X_pair, Y_pair, b_size, device=None):\n",
    "    \"\"\"Sample a batch of matching (x, y) pairs.\n",
    "\n",
    "    Args:\n",
    "        X_pair: array of shape (N, d_x); row i belongs to location i.\n",
    "        Y_pair: sequence of N arrays of shape (M_i, d_y); entry i holds the\n",
    "            target samples of location i.\n",
    "        b_size: number of pairs to draw (locations are drawn with replacement).\n",
    "        device: torch device for the batch; None selects CUDA when available\n",
    "            and the CPU otherwise.\n",
    "\n",
    "    Returns:\n",
    "        (x, y): float32 tensors of shape (b_size, d_x) and (b_size, d_y), where\n",
    "        y[k] is a random sample of the location that x[k] comes from.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `X_pair` is empty.\n",
    "    \"\"\"\n",
    "    device = _resolve_device(device)\n",
    "    # np.random.randint excludes `high`, so the bound is len(X_pair); with\n",
    "    # len(X_pair) - 1 the last location could never be drawn.\n",
    "    idxs = np.random.randint(low=0, high=len(X_pair), size=b_size)\n",
    "    x_pair_batch = torch.tensor(X_pair[idxs], dtype=torch.float32, device=device)\n",
    "    y_pair_batch = torch.tensor(_draw_rows(Y_pair, idxs), dtype=torch.float32, device=device)\n",
    "    return x_pair_batch, y_pair_batch\n",
    "\n",
    "\n",
    "def unpaired_sampler(X, Y, b_size, device=None):\n",
    "    \"\"\"Sample a batch of x and a batch of y independently of each other.\n",
    "\n",
    "    Args:\n",
    "        X: array of shape (N_x, d_x), one row per source location.\n",
    "        Y: sequence of N_y arrays of shape (M_j, d_y), one entry per target\n",
    "            location. N_y may differ from N_x.\n",
    "        b_size: number of samples to draw from each side (with replacement).\n",
    "        device: torch device for the batch; None selects CUDA when available\n",
    "            and the CPU otherwise.\n",
    "\n",
    "    Returns:\n",
    "        (x, y): float32 tensors of shape (b_size, d_x) and (b_size, d_y).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `X` or `Y` is empty.\n",
    "    \"\"\"\n",
    "    device = _resolve_device(device)\n",
    "    idxs = np.random.randint(low=0, high=len(X), size=b_size)\n",
    "    # Target locations are drawn on their own over the whole of Y. Mirroring the\n",
    "    # source indices never reaches Y[len(X):] and fails when Y is shorter than X.\n",
    "    idxs_y = np.random.randint(low=0, high=len(Y), size=b_size)\n",
    "\n",
    "    x_batch = torch.tensor(X[idxs], dtype=torch.float32, device=device)\n",
    "    y_batch = torch.tensor(_draw_rows(Y, idxs_y), dtype=torch.float32, device=device)\n",
    "    return x_batch, y_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7d300dd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.models import MyCDiscriminator, MyCGenerator\n",
    "from src.samplers.from_dataset import DatasetSampler\n",
    "from src.samplers.primary import StandardNormalSampler, SwissRollSampler\n",
    "from src.utils.discrete_ot import OTPlanSampler\n",
    "from src.utils.paired import generate_paired_data, get_GT_points, get_paired_sampler\n",
    "import torch.nn.functional as F\n",
    "from src.models.models import ConditionalRealNVP\n",
    "from scipy import linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0f272ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is visible and fall back to the CPU otherwise, so the\n",
    "# cell does not fail on machines without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Conditional flow over the target features (Y_DIM), conditioned on the\n",
    "# source features (X_DIM).\n",
    "T = ConditionalRealNVP(\n",
    "        features=Y_DIM,\n",
    "        context_features=X_DIM,\n",
    "        hidden_context_features=512,\n",
    "        hidden_features=128,\n",
    "        num_blocks_per_layer=4,\n",
    "        num_layers=5,\n",
    "        use_volume_preserving=False,\n",
    "    ).to(device)\n",
    "\n",
    "# Optimizer over all parameters of T.\n",
    "T_opt_paired = torch.optim.Adam(T.parameters(), lr=3e-4, weight_decay=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2131aaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One independent, initially empty list of logged values per loss name.\n",
    "history = {loss_name: [] for loss_name in (\"D_loss\", \"G_loss\")}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f5391bb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                         | 5/10000 [00:00<12:52, 12.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 124.69519805908203\n",
      "Loss: 122.06782531738281\n",
      "Loss: 119.5482406616211\n",
      "Loss: 117.70692443847656\n",
      "Loss: 116.25128173828125\n",
      "Loss: 114.65636444091797\n",
      "Loss: 113.2208251953125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                        | 13/10000 [00:00<07:11, 23.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 111.58411407470703\n",
      "Loss: 109.9658432006836\n",
      "Loss: 107.77027893066406\n",
      "Loss: 104.80323791503906\n",
      "Loss: 101.0558090209961\n",
      "Loss: 95.69197082519531\n",
      "Loss: 88.51145935058594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                        | 21/10000 [00:00<06:03, 27.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 79.2396240234375\n",
      "Loss: 67.57830810546875\n",
      "Loss: 57.212257385253906\n",
      "Loss: 45.635231018066406\n",
      "Loss: 31.225608825683594\n",
      "Loss: 19.936628341674805\n",
      "Loss: 14.301422119140625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                        | 25/10000 [00:01<05:44, 28.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 7.919950485229492\n",
      "Loss: 21.823177337646484\n",
      "Loss: 11.065366744995117\n",
      "Loss: 5.51217794418335\n",
      "Loss: 6.1929097175598145\n",
      "Loss: 3.7766590118408203\n",
      "Loss: 1.3375177383422852\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|▏                                       | 33/10000 [00:01<05:05, 32.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 0.06937730312347412\n",
      "Loss: -1.4516799449920654\n",
      "Loss: -2.1824748516082764\n",
      "Loss: -7.064180374145508\n",
      "Loss: -7.603277683258057\n",
      "Loss: -9.743797302246094\n",
      "Loss: -9.31951904296875\n",
      "Loss: -10.066746711730957\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|▏                                       | 41/10000 [00:01<04:40, 35.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -12.675296783447266\n",
      "Loss: -10.656030654907227\n",
      "Loss: -14.089376449584961\n",
      "Loss: -11.116109848022461\n",
      "Loss: -15.0326509475708\n",
      "Loss: -16.286359786987305\n",
      "Loss: -17.640466690063477\n",
      "Loss: -15.259725570678711\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|▏                                       | 49/10000 [00:01<04:44, 34.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -16.49637222290039\n",
      "Loss: -19.654804229736328\n",
      "Loss: -21.230602264404297\n",
      "Loss: -23.49593734741211\n",
      "Loss: -18.691104888916016\n",
      "Loss: -14.987218856811523\n",
      "Loss: -20.39842987060547\n",
      "Loss: -13.932546615600586\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▏                                       | 58/10000 [00:01<04:21, 38.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -22.476587295532227\n",
      "Loss: -20.25934600830078\n",
      "Loss: -20.330440521240234\n",
      "Loss: -22.38896369934082\n",
      "Loss: -28.75426483154297\n",
      "Loss: -24.87776756286621\n",
      "Loss: -23.20615577697754\n",
      "Loss: -27.150428771972656\n",
      "Loss: -26.66167640686035\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▎                                       | 66/10000 [00:02<04:30, 36.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -27.445621490478516\n",
      "Loss: -29.908367156982422\n",
      "Loss: -30.042869567871094\n",
      "Loss: -27.938640594482422\n",
      "Loss: -30.935089111328125\n",
      "Loss: -31.776988983154297\n",
      "Loss: -30.937129974365234\n",
      "Loss: -28.911514282226562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▎                                       | 74/10000 [00:02<04:29, 36.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -30.84530258178711\n",
      "Loss: -31.855754852294922\n",
      "Loss: -32.44939041137695\n",
      "Loss: -25.29859733581543\n",
      "Loss: -21.797876358032227\n",
      "Loss: -28.482566833496094\n",
      "Loss: -21.94959259033203\n",
      "Loss: -26.867347717285156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▎                                       | 82/10000 [00:02<04:37, 35.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -31.916807174682617\n",
      "Loss: -24.967618942260742\n",
      "Loss: -25.89249038696289\n",
      "Loss: -19.50304412841797\n",
      "Loss: -27.76620864868164\n",
      "Loss: -20.063743591308594\n",
      "Loss: -25.910938262939453\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▎                                       | 90/10000 [00:02<04:54, 33.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -29.7688045501709\n",
      "Loss: -25.59218406677246\n",
      "Loss: -35.55389404296875\n",
      "Loss: -30.50889778137207\n",
      "Loss: -29.313236236572266\n",
      "Loss: -32.26536178588867\n",
      "Loss: -31.83681869506836\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▍                                       | 98/10000 [00:03<04:49, 34.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -31.024293899536133\n",
      "Loss: -32.832847595214844\n",
      "Loss: -32.856239318847656\n",
      "Loss: -34.754730224609375\n",
      "Loss: -36.772178649902344\n",
      "Loss: -31.82395362854004\n",
      "Loss: -34.00893020629883\n",
      "Loss: -33.75880813598633\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▍                                      | 106/10000 [00:03<04:45, 34.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -36.08045959472656\n",
      "Loss: -35.57659912109375\n",
      "Loss: -31.533248901367188\n",
      "Loss: -37.95287322998047\n",
      "Loss: -40.47740936279297\n",
      "Loss: -39.97392654418945\n",
      "Loss: -34.053199768066406\n",
      "Loss: 11.926435470581055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▍                                      | 114/10000 [00:03<04:44, 34.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -6.664104461669922\n",
      "Loss: 10.103384017944336\n",
      "Loss: 17.63614845275879\n",
      "Loss: 25.101455688476562\n",
      "Loss: 30.13666534423828\n",
      "Loss: 30.692346572875977\n",
      "Loss: 29.733718872070312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  1%|▍                                      | 118/10000 [00:03<04:49, 34.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 25.88336181640625\n",
      "Loss: 18.970417022705078\n",
      "Loss: 9.190962791442871\n",
      "Loss: -0.05779796838760376\n",
      "Loss: 0.9450564384460449\n",
      "Loss: -6.989668369293213\n",
      "Loss: -17.105724334716797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▍                                      | 126/10000 [00:03<04:53, 33.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -21.354934692382812\n",
      "Loss: -23.201385498046875\n",
      "Loss: -24.01919174194336\n",
      "Loss: -26.011899948120117\n",
      "Loss: -28.41409683227539\n",
      "Loss: -27.420259475708008\n",
      "Loss: -35.79963684082031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▌                                      | 134/10000 [00:04<04:41, 35.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -33.608924865722656\n",
      "Loss: -34.34648895263672\n",
      "Loss: -32.228214263916016\n",
      "Loss: -32.78790283203125\n",
      "Loss: -32.563880920410156\n",
      "Loss: -32.05015563964844\n",
      "Loss: -35.037208557128906\n",
      "Loss: -33.10234069824219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▌                                      | 142/10000 [00:04<04:28, 36.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -36.555885314941406\n",
      "Loss: -36.22914505004883\n",
      "Loss: -37.401222229003906\n",
      "Loss: -34.72804260253906\n",
      "Loss: -37.5246696472168\n",
      "Loss: -40.48721694946289\n",
      "Loss: -38.92938232421875\n",
      "Loss: -41.0953369140625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▌                                      | 150/10000 [00:04<04:25, 37.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -38.337406158447266\n",
      "Loss: -41.44116973876953\n",
      "Loss: -38.76331329345703\n",
      "Loss: -41.211463928222656\n",
      "Loss: -38.3660888671875\n",
      "Loss: -41.757171630859375\n",
      "Loss: -44.56021499633789\n",
      "Loss: -43.49707794189453\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▌                                      | 158/10000 [00:04<04:35, 35.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -40.40340042114258\n",
      "Loss: -36.898590087890625\n",
      "Loss: -27.91596794128418\n",
      "Loss: -35.83551788330078\n",
      "Loss: -32.14555740356445\n",
      "Loss: -33.926849365234375\n",
      "Loss: -37.50649642944336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▋                                      | 166/10000 [00:05<04:25, 36.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -37.094512939453125\n",
      "Loss: -39.729652404785156\n",
      "Loss: -38.202392578125\n",
      "Loss: -34.0936164855957\n",
      "Loss: -37.8663330078125\n",
      "Loss: -37.36061096191406\n",
      "Loss: -36.434513092041016\n",
      "Loss: -40.06709671020508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▋                                      | 175/10000 [00:05<04:15, 38.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -45.84842300415039\n",
      "Loss: -43.73977279663086\n",
      "Loss: -42.70974349975586\n",
      "Loss: -33.25492477416992\n",
      "Loss: -35.9406852722168\n",
      "Loss: -35.5322151184082\n",
      "Loss: -34.84415817260742\n",
      "Loss: -38.971641540527344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▋                                      | 183/10000 [00:05<04:20, 37.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -39.82047653198242\n",
      "Loss: -42.57708740234375\n",
      "Loss: -40.97068405151367\n",
      "Loss: -41.43019104003906\n",
      "Loss: -42.62173080444336\n",
      "Loss: -43.01407241821289\n",
      "Loss: -41.52372741699219\n",
      "Loss: -42.064422607421875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  2%|▋                                      | 187/10000 [00:05<04:40, 35.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -44.90473556518555\n",
      "Loss: -45.65032196044922\n",
      "Loss: -44.404579162597656\n",
      "Loss: -45.81272888183594\n",
      "Loss: -49.848876953125\n",
      "Loss: -44.92730712890625\n",
      "Loss: -46.14917755126953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▊                                      | 195/10000 [00:05<04:46, 34.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -47.05402374267578\n",
      "Loss: -46.46839141845703\n",
      "Loss: -44.8110237121582\n",
      "Loss: -48.650543212890625\n",
      "Loss: -46.80640411376953\n",
      "Loss: -47.33873748779297\n",
      "Loss: -45.03166198730469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▊                                      | 203/10000 [00:06<04:29, 36.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -48.8537712097168\n",
      "Loss: -46.44843292236328\n",
      "Loss: -48.991119384765625\n",
      "Loss: -50.67597961425781\n",
      "Loss: -48.991859436035156\n",
      "Loss: -49.81085205078125\n",
      "Loss: -50.252769470214844\n",
      "Loss: -50.611000061035156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▊                                      | 212/10000 [00:06<04:17, 38.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -51.136695861816406\n",
      "Loss: -53.329593658447266\n",
      "Loss: -50.94403839111328\n",
      "Loss: -51.06364059448242\n",
      "Loss: -49.05675506591797\n",
      "Loss: -48.385704040527344\n",
      "Loss: -50.44016647338867\n",
      "Loss: -51.033992767333984\n",
      "Loss: -49.9166259765625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▊                                      | 222/10000 [00:06<04:08, 39.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -46.725730895996094\n",
      "Loss: -48.531883239746094\n",
      "Loss: -44.333457946777344\n",
      "Loss: -50.871482849121094\n",
      "Loss: -47.94940948486328\n",
      "Loss: -53.078269958496094\n",
      "Loss: -52.165855407714844\n",
      "Loss: -50.82698059082031\n",
      "Loss: -48.054046630859375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▉                                      | 230/10000 [00:06<04:10, 38.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -29.179685592651367\n",
      "Loss: -33.7745361328125\n",
      "Loss: -31.419382095336914\n",
      "Loss: -31.753494262695312\n",
      "Loss: -30.87979507446289\n",
      "Loss: -35.220149993896484\n",
      "Loss: -40.561126708984375\n",
      "Loss: -46.8463020324707\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▉                                      | 238/10000 [00:06<04:21, 37.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -40.147544860839844\n",
      "Loss: -46.04754638671875\n",
      "Loss: -45.41069793701172\n",
      "Loss: -47.25298309326172\n",
      "Loss: -44.47052001953125\n",
      "Loss: -44.35559844970703\n",
      "Loss: -45.57342529296875\n",
      "Loss: -47.1388053894043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▉                                      | 246/10000 [00:07<04:16, 38.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -48.19435119628906\n",
      "Loss: -50.17945861816406\n",
      "Loss: -48.766868591308594\n",
      "Loss: -49.7215690612793\n",
      "Loss: -50.369903564453125\n",
      "Loss: -51.401580810546875\n",
      "Loss: -49.097023010253906\n",
      "Loss: -52.39006423950195\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▉                                      | 255/10000 [00:07<04:11, 38.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -51.655696868896484\n",
      "Loss: -50.18073654174805\n",
      "Loss: -53.302921295166016\n",
      "Loss: -51.8776741027832\n",
      "Loss: -51.704322814941406\n",
      "Loss: -56.180267333984375\n",
      "Loss: -54.06134033203125\n",
      "Loss: -52.77210998535156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█                                      | 263/10000 [00:07<04:21, 37.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -54.56452560424805\n",
      "Loss: -52.40976333618164\n",
      "Loss: -52.38718795776367\n",
      "Loss: -42.018951416015625\n",
      "Loss: -50.64902877807617\n",
      "Loss: -46.64581298828125\n",
      "Loss: -50.63819122314453\n",
      "Loss: -51.85246658325195\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█                                      | 271/10000 [00:07<04:19, 37.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -52.364715576171875\n",
      "Loss: -53.647727966308594\n",
      "Loss: -51.0311279296875\n",
      "Loss: -52.34389114379883\n",
      "Loss: -49.03104782104492\n",
      "Loss: -48.76624298095703\n",
      "Loss: -51.18376541137695\n",
      "Loss: -49.737586975097656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█                                      | 279/10000 [00:08<04:22, 36.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -50.068878173828125\n",
      "Loss: -55.586219787597656\n",
      "Loss: -54.521759033203125\n",
      "Loss: -54.02113723754883\n",
      "Loss: -50.285404205322266\n",
      "Loss: -57.50719451904297\n",
      "Loss: -54.953792572021484\n",
      "Loss: -53.53101348876953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█                                      | 287/10000 [00:08<04:20, 37.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -55.483253479003906\n",
      "Loss: -55.804283142089844\n",
      "Loss: -58.08420181274414\n",
      "Loss: -54.36107635498047\n",
      "Loss: -57.53184509277344\n",
      "Loss: -53.72121810913086\n",
      "Loss: -56.88426971435547\n",
      "Loss: -57.159942626953125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▏                                     | 295/10000 [00:08<04:12, 38.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -59.37815475463867\n",
      "Loss: -54.229515075683594\n",
      "Loss: -56.275020599365234\n",
      "Loss: -57.117645263671875\n",
      "Loss: -53.33919906616211\n",
      "Loss: -55.1595458984375\n",
      "Loss: -58.35413360595703\n",
      "Loss: -57.40397644042969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▏                                     | 304/10000 [00:08<04:07, 39.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -56.19053268432617\n",
      "Loss: -55.95640563964844\n",
      "Loss: -60.160377502441406\n",
      "Loss: -57.58992004394531\n",
      "Loss: -61.68645095825195\n",
      "Loss: -57.310699462890625\n",
      "Loss: -56.873844146728516\n",
      "Loss: -56.52107238769531\n",
      "Loss: -57.480987548828125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▏                                     | 312/10000 [00:08<04:10, 38.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -57.09376525878906\n",
      "Loss: -57.29015350341797\n",
      "Loss: -59.48230743408203\n",
      "Loss: -60.27489471435547\n",
      "Loss: -56.483924865722656\n",
      "Loss: -52.79831314086914\n",
      "Loss: -38.19000244140625\n",
      "Loss: -39.06718444824219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▏                                     | 320/10000 [00:09<04:07, 39.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -35.820804595947266\n",
      "Loss: -36.85158157348633\n",
      "Loss: -36.88003158569336\n",
      "Loss: -43.46327590942383\n",
      "Loss: -43.52655792236328\n",
      "Loss: -47.106998443603516\n",
      "Loss: -54.05339813232422\n",
      "Loss: -54.871910095214844\n",
      "Loss: -56.89655303955078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▎                                     | 328/10000 [00:09<04:07, 39.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -47.437889099121094\n",
      "Loss: -52.63081359863281\n",
      "Loss: -58.0450439453125\n",
      "Loss: -53.92291259765625\n",
      "Loss: -54.878150939941406\n",
      "Loss: -56.85052490234375\n",
      "Loss: -59.31090545654297\n",
      "Loss: -57.7916259765625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▎                                     | 336/10000 [00:09<04:06, 39.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -55.51416778564453\n",
      "Loss: -59.66011047363281\n",
      "Loss: -59.22492218017578\n",
      "Loss: -58.31077194213867\n",
      "Loss: -59.222965240478516\n",
      "Loss: -59.52375030517578\n",
      "Loss: -63.942623138427734\n",
      "Loss: -56.54988098144531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|█▎                                     | 344/10000 [00:09<04:10, 38.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -59.11913299560547\n",
      "Loss: -57.60503005981445\n",
      "Loss: -59.05704879760742\n",
      "Loss: -58.28623580932617\n",
      "Loss: -58.36595916748047\n",
      "Loss: -61.18556213378906\n",
      "Loss: -57.58726501464844\n",
      "Loss: -59.29127883911133\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▎                                     | 352/10000 [00:09<04:16, 37.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.80419921875\n",
      "Loss: -57.956146240234375\n",
      "Loss: -55.863739013671875\n",
      "Loss: -48.56462478637695\n",
      "Loss: -55.35740280151367\n",
      "Loss: -51.999488830566406\n",
      "Loss: -58.287391662597656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▍                                     | 360/10000 [00:10<04:28, 35.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -56.012779235839844\n",
      "Loss: -61.30914306640625\n",
      "Loss: -59.209495544433594\n",
      "Loss: -57.227455139160156\n",
      "Loss: -59.67976379394531\n",
      "Loss: -55.999568939208984\n",
      "Loss: -54.59564971923828\n",
      "Loss: -59.568763732910156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▍                                     | 368/10000 [00:10<04:18, 37.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -55.99684143066406\n",
      "Loss: -56.7393798828125\n",
      "Loss: -57.11064910888672\n",
      "Loss: -59.54229736328125\n",
      "Loss: -60.998329162597656\n",
      "Loss: -61.44230651855469\n",
      "Loss: -60.965145111083984\n",
      "Loss: -60.190120697021484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▍                                     | 376/10000 [00:10<04:11, 38.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -61.27328109741211\n",
      "Loss: -63.57487869262695\n",
      "Loss: -63.17862319946289\n",
      "Loss: -61.62968826293945\n",
      "Loss: -60.50149154663086\n",
      "Loss: -60.637229919433594\n",
      "Loss: -61.58281707763672\n",
      "Loss: -58.08446502685547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▍                                     | 384/10000 [00:10<04:15, 37.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -62.93783950805664\n",
      "Loss: -59.45417404174805\n",
      "Loss: -60.54492950439453\n",
      "Loss: -62.421363830566406\n",
      "Loss: -60.24625778198242\n",
      "Loss: -62.25651550292969\n",
      "Loss: -63.75128936767578\n",
      "Loss: -60.5103759765625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▌                                     | 392/10000 [00:11<04:26, 36.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.541664123535156\n",
      "Loss: -59.23419189453125\n",
      "Loss: -55.840049743652344\n",
      "Loss: -54.87910461425781\n",
      "Loss: -62.00048828125\n",
      "Loss: -58.97480773925781\n",
      "Loss: -61.41978454589844\n",
      "Loss: -60.86193084716797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▌                                     | 400/10000 [00:11<04:26, 36.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -59.221595764160156\n",
      "Loss: -61.847869873046875\n",
      "Loss: -58.0927848815918\n",
      "Loss: -60.8890266418457\n",
      "Loss: -60.68858337402344\n",
      "Loss: -52.67642593383789\n",
      "Loss: -59.86717224121094\n",
      "Loss: -63.35324478149414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  4%|█▌                                     | 404/10000 [00:11<04:31, 35.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.235591888427734\n",
      "Loss: -62.753379821777344\n",
      "Loss: -60.1770133972168\n",
      "Loss: -63.06066131591797\n",
      "Loss: -61.22602844238281\n",
      "Loss: -61.974884033203125\n",
      "Loss: -61.44085693359375\n",
      "Loss: -62.15715026855469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▌                                     | 412/10000 [00:11<04:37, 34.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -64.35387420654297\n",
      "Loss: -58.20386505126953\n",
      "Loss: -48.965309143066406\n",
      "Loss: -46.887062072753906\n",
      "Loss: -50.10513687133789\n",
      "Loss: -52.386573791503906\n",
      "Loss: -50.20967483520508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▋                                     | 420/10000 [00:11<04:37, 34.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -50.37812805175781\n",
      "Loss: -58.208717346191406\n",
      "Loss: -58.654415130615234\n",
      "Loss: -56.301544189453125\n",
      "Loss: -61.038795471191406\n",
      "Loss: -61.850791931152344\n",
      "Loss: -62.48258590698242\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▋                                     | 428/10000 [00:12<04:34, 34.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.58407211303711\n",
      "Loss: -60.819217681884766\n",
      "Loss: -61.41960906982422\n",
      "Loss: -65.43515014648438\n",
      "Loss: -62.144813537597656\n",
      "Loss: -65.07119750976562\n",
      "Loss: -63.166343688964844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▋                                     | 436/10000 [00:12<04:41, 34.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -64.14207458496094\n",
      "Loss: -61.65345764160156\n",
      "Loss: -64.14176177978516\n",
      "Loss: -64.62527465820312\n",
      "Loss: -67.84147644042969\n",
      "Loss: -61.76140594482422\n",
      "Loss: -64.64866638183594\n",
      "Loss: -62.325775146484375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█▋                                     | 445/10000 [00:12<04:17, 37.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -64.51130676269531\n",
      "Loss: -63.296180725097656\n",
      "Loss: -64.00372314453125\n",
      "Loss: -68.06318664550781\n",
      "Loss: -63.42375183105469\n",
      "Loss: -65.16548919677734\n",
      "Loss: -49.317466735839844\n",
      "Loss: -38.11354064941406\n",
      "Loss: -48.84376525878906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▊                                     | 454/10000 [00:12<04:07, 38.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -42.586822509765625\n",
      "Loss: -38.971397399902344\n",
      "Loss: -40.22137451171875\n",
      "Loss: -45.63146209716797\n",
      "Loss: -55.42884826660156\n",
      "Loss: -56.47003173828125\n",
      "Loss: -56.596466064453125\n",
      "Loss: -60.02581024169922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▊                                     | 462/10000 [00:13<04:20, 36.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -57.14925765991211\n",
      "Loss: -57.75768280029297\n",
      "Loss: -57.25505828857422\n",
      "Loss: -59.78697967529297\n",
      "Loss: -60.161415100097656\n",
      "Loss: -62.7203369140625\n",
      "Loss: -59.49909210205078\n",
      "Loss: -61.81392288208008\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▊                                     | 470/10000 [00:13<04:18, 36.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -62.33844757080078\n",
      "Loss: -64.34060668945312\n",
      "Loss: -63.45174789428711\n",
      "Loss: -64.22872924804688\n",
      "Loss: -64.0615234375\n",
      "Loss: -63.20039749145508\n",
      "Loss: -65.75631713867188\n",
      "Loss: -62.53851318359375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  5%|█▊                                     | 474/10000 [00:13<04:31, 35.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -64.14844512939453\n",
      "Loss: -67.36192321777344\n",
      "Loss: -65.92691040039062\n",
      "Loss: -62.0549201965332\n",
      "Loss: -69.23869323730469\n",
      "Loss: -65.03593444824219\n",
      "Loss: -53.91619110107422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▉                                     | 482/10000 [00:13<04:23, 36.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -42.939064025878906\n",
      "Loss: -55.60239028930664\n",
      "Loss: -48.342742919921875\n",
      "Loss: -47.40602493286133\n",
      "Loss: -50.94705581665039\n",
      "Loss: -58.52149200439453\n",
      "Loss: -60.02227020263672\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▉                                     | 490/10000 [00:13<04:47, 33.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -55.0830078125\n",
      "Loss: -61.278751373291016\n",
      "Loss: -64.55178833007812\n",
      "Loss: -57.60838317871094\n",
      "Loss: -60.877845764160156\n",
      "Loss: -61.936241149902344\n",
      "Loss: -61.7861442565918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▉                                     | 499/10000 [00:14<04:25, 35.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -64.40766906738281\n",
      "Loss: -65.11514282226562\n",
      "Loss: -61.28396224975586\n",
      "Loss: -62.49808120727539\n",
      "Loss: -67.08512115478516\n",
      "Loss: -65.13383483886719\n",
      "Loss: -62.982295989990234\n",
      "Loss: -64.55462646484375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|█▉                                     | 507/10000 [00:14<04:12, 37.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -67.02059936523438\n",
      "Loss: -66.0541000366211\n",
      "Loss: -68.24034118652344\n",
      "Loss: -70.22966766357422\n",
      "Loss: -67.42308044433594\n",
      "Loss: -68.53370666503906\n",
      "Loss: -66.32374572753906\n",
      "Loss: -63.378822326660156\n",
      "Loss: -67.72154235839844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|██                                     | 516/10000 [00:14<04:07, 38.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.936279296875\n",
      "Loss: -66.7448501586914\n",
      "Loss: -67.69541931152344\n",
      "Loss: -66.77662658691406\n",
      "Loss: -71.7842025756836\n",
      "Loss: -70.35150146484375\n",
      "Loss: -65.7406005859375\n",
      "Loss: -42.673492431640625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|██                                     | 525/10000 [00:14<04:02, 39.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -30.5704345703125\n",
      "Loss: -28.280237197875977\n",
      "Loss: -22.186542510986328\n",
      "Loss: -17.139034271240234\n",
      "Loss: -15.255148887634277\n",
      "Loss: -18.17705535888672\n",
      "Loss: -26.709075927734375\n",
      "Loss: -38.068641662597656\n",
      "Loss: -47.61773681640625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|██                                     | 533/10000 [00:14<04:12, 37.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -43.826141357421875\n",
      "Loss: -54.350379943847656\n",
      "Loss: -57.18891906738281\n",
      "Loss: -50.455726623535156\n",
      "Loss: -53.13236999511719\n",
      "Loss: -52.26474380493164\n",
      "Loss: -58.26313781738281\n",
      "Loss: -60.68341064453125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|██                                     | 541/10000 [00:15<04:21, 36.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -59.4371452331543\n",
      "Loss: -61.60676574707031\n",
      "Loss: -60.75212097167969\n",
      "Loss: -63.45301055908203\n",
      "Loss: -62.1051139831543\n",
      "Loss: -61.145904541015625\n",
      "Loss: -58.313758850097656\n",
      "Loss: -63.239280700683594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|██▏                                    | 549/10000 [00:15<04:25, 35.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -61.84173583984375\n",
      "Loss: -64.34811401367188\n",
      "Loss: -62.88929748535156\n",
      "Loss: -66.34454345703125\n",
      "Loss: -66.93181610107422\n",
      "Loss: -62.36741638183594\n",
      "Loss: -62.11685562133789\n",
      "Loss: -62.697731018066406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  6%|██▏                                    | 553/10000 [00:15<04:34, 34.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.05841827392578\n",
      "Loss: -64.0075454711914\n",
      "Loss: -66.2413330078125\n",
      "Loss: -66.55397033691406\n",
      "Loss: -65.97813415527344\n",
      "Loss: -64.47344207763672\n",
      "Loss: -61.65254211425781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▏                                    | 561/10000 [00:15<04:31, 34.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.58737564086914\n",
      "Loss: -65.61083984375\n",
      "Loss: -64.59416198730469\n",
      "Loss: -66.85804748535156\n",
      "Loss: -67.23124694824219\n",
      "Loss: -66.73263549804688\n",
      "Loss: -59.722164154052734\n",
      "Loss: -61.68318176269531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▏                                    | 569/10000 [00:16<04:39, 33.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -67.13703918457031\n",
      "Loss: -64.47278594970703\n",
      "Loss: -65.36975860595703\n",
      "Loss: -66.03028869628906\n",
      "Loss: -67.63758850097656\n",
      "Loss: -65.99446105957031\n",
      "Loss: -65.89686584472656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▎                                    | 577/10000 [00:16<04:32, 34.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -65.18382263183594\n",
      "Loss: -70.3978500366211\n",
      "Loss: -68.76115417480469\n",
      "Loss: -68.57545471191406\n",
      "Loss: -69.439453125\n",
      "Loss: -71.27009582519531\n",
      "Loss: -68.63192749023438\n",
      "Loss: -64.43358612060547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▎                                    | 585/10000 [00:16<04:17, 36.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -52.28770446777344\n",
      "Loss: -58.649261474609375\n",
      "Loss: -57.50682067871094\n",
      "Loss: -61.22705841064453\n",
      "Loss: -56.40834045410156\n",
      "Loss: -61.951683044433594\n",
      "Loss: -68.28242492675781\n",
      "Loss: -63.003868103027344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▎                                    | 593/10000 [00:16<04:21, 35.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.92697143554688\n",
      "Loss: -68.00302124023438\n",
      "Loss: -71.57920837402344\n",
      "Loss: -67.5340576171875\n",
      "Loss: -67.70071411132812\n",
      "Loss: -68.38967895507812\n",
      "Loss: -70.53252410888672\n",
      "Loss: -73.03893280029297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▎                                    | 601/10000 [00:16<04:17, 36.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.21833801269531\n",
      "Loss: -72.29914855957031\n",
      "Loss: -71.96324920654297\n",
      "Loss: -72.71630096435547\n",
      "Loss: -71.38436126708984\n",
      "Loss: -70.0723876953125\n",
      "Loss: -73.03173828125\n",
      "Loss: -69.55973815917969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▍                                    | 609/10000 [00:17<04:20, 36.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -71.13148498535156\n",
      "Loss: -72.93273162841797\n",
      "Loss: -70.32929992675781\n",
      "Loss: -72.98783874511719\n",
      "Loss: -72.3287582397461\n",
      "Loss: -59.273929595947266\n",
      "Loss: -46.50648880004883\n",
      "Loss: -57.04352951049805\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▍                                    | 617/10000 [00:17<04:13, 37.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -44.98094940185547\n",
      "Loss: -44.37731170654297\n",
      "Loss: -47.68260192871094\n",
      "Loss: -52.082611083984375\n",
      "Loss: -57.57320785522461\n",
      "Loss: -56.66057586669922\n",
      "Loss: -63.86106491088867\n",
      "Loss: -66.73661804199219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▍                                    | 625/10000 [00:17<04:15, 36.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -63.85273742675781\n",
      "Loss: -64.94615936279297\n",
      "Loss: -66.39971160888672\n",
      "Loss: -65.9103012084961\n",
      "Loss: -65.96005249023438\n",
      "Loss: -64.95013427734375\n",
      "Loss: -67.50604248046875\n",
      "Loss: -69.98128509521484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▍                                    | 633/10000 [00:17<04:14, 36.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -67.30101013183594\n",
      "Loss: -69.44662475585938\n",
      "Loss: -67.95005798339844\n",
      "Loss: -67.46905517578125\n",
      "Loss: -72.01776885986328\n",
      "Loss: -68.72196960449219\n",
      "Loss: -66.8889389038086\n",
      "Loss: -71.12184143066406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▍                                    | 641/10000 [00:17<04:08, 37.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -70.92884826660156\n",
      "Loss: -70.40357208251953\n",
      "Loss: -73.33590698242188\n",
      "Loss: -74.78628540039062\n",
      "Loss: -71.72018432617188\n",
      "Loss: -66.66349792480469\n",
      "Loss: -68.29259490966797\n",
      "Loss: -72.47269439697266\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▌                                    | 649/10000 [00:18<04:04, 38.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.79513549804688\n",
      "Loss: -70.69731903076172\n",
      "Loss: -70.27366638183594\n",
      "Loss: -72.13507080078125\n",
      "Loss: -72.95592498779297\n",
      "Loss: -70.05115509033203\n",
      "Loss: -62.44244384765625\n",
      "Loss: -70.75888061523438\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▌                                    | 658/10000 [00:18<03:58, 39.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.699462890625\n",
      "Loss: -69.48241424560547\n",
      "Loss: -68.8961410522461\n",
      "Loss: -72.66226196289062\n",
      "Loss: -73.26515197753906\n",
      "Loss: -70.43888854980469\n",
      "Loss: -72.20912170410156\n",
      "Loss: -73.26339721679688\n",
      "Loss: -70.09518432617188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▌                                    | 668/10000 [00:18<03:54, 39.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -62.43928527832031\n",
      "Loss: -66.0365982055664\n",
      "Loss: -67.63404846191406\n",
      "Loss: -68.11599731445312\n",
      "Loss: -68.85075378417969\n",
      "Loss: -72.1707534790039\n",
      "Loss: -68.78335571289062\n",
      "Loss: -72.7143783569336\n",
      "Loss: -71.79023742675781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▋                                    | 676/10000 [00:18<04:10, 37.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -73.43539428710938\n",
      "Loss: -70.72036743164062\n",
      "Loss: -75.89004516601562\n",
      "Loss: -72.6031494140625\n",
      "Loss: -75.12905883789062\n",
      "Loss: -76.59942626953125\n",
      "Loss: -72.71942138671875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  7%|██▋                                    | 680/10000 [00:19<04:24, 35.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.39042663574219\n",
      "Loss: -76.14305877685547\n",
      "Loss: -76.14220428466797\n",
      "Loss: -71.80702209472656\n",
      "Loss: -69.74981689453125\n",
      "Loss: -58.557228088378906\n",
      "Loss: -69.33735656738281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▋                                    | 688/10000 [00:19<04:18, 35.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -60.24760818481445\n",
      "Loss: -63.70033264160156\n",
      "Loss: -67.13404083251953\n",
      "Loss: -67.52510070800781\n",
      "Loss: -70.83525085449219\n",
      "Loss: -70.26990509033203\n",
      "Loss: -71.34839630126953\n",
      "Loss: -68.29735565185547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▋                                    | 696/10000 [00:19<04:17, 36.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.11492919921875\n",
      "Loss: -72.03700256347656\n",
      "Loss: -72.34012603759766\n",
      "Loss: -70.92582702636719\n",
      "Loss: -68.48674011230469\n",
      "Loss: -72.77670288085938\n",
      "Loss: -72.38614654541016\n",
      "Loss: -69.95590209960938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▋                                    | 704/10000 [00:19<04:09, 37.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.89451599121094\n",
      "Loss: -73.18494415283203\n",
      "Loss: -72.73493957519531\n",
      "Loss: -75.41300964355469\n",
      "Loss: -71.76799011230469\n",
      "Loss: -75.29386138916016\n",
      "Loss: -76.30223083496094\n",
      "Loss: -77.16980743408203\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▊                                    | 713/10000 [00:19<04:00, 38.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.19721221923828\n",
      "Loss: -75.46971893310547\n",
      "Loss: -75.88287353515625\n",
      "Loss: -75.10006713867188\n",
      "Loss: -77.03361511230469\n",
      "Loss: -75.13041687011719\n",
      "Loss: -76.34730529785156\n",
      "Loss: -73.55489349365234\n",
      "Loss: -78.99774932861328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▊                                    | 721/10000 [00:20<03:57, 39.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.75740051269531\n",
      "Loss: -73.21927642822266\n",
      "Loss: -70.11009979248047\n",
      "Loss: -68.9644546508789\n",
      "Loss: -70.91805267333984\n",
      "Loss: -68.38260650634766\n",
      "Loss: -73.9644775390625\n",
      "Loss: -73.23646545410156\n",
      "Loss: -73.83100891113281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▊                                    | 730/10000 [00:20<03:54, 39.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -71.53514862060547\n",
      "Loss: -73.04397583007812\n",
      "Loss: -79.38135528564453\n",
      "Loss: -75.61334228515625\n",
      "Loss: -76.73883056640625\n",
      "Loss: -75.22409057617188\n",
      "Loss: -75.54130554199219\n",
      "Loss: -76.18544006347656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▉                                    | 738/10000 [00:20<03:55, 39.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -77.77752685546875\n",
      "Loss: -76.44548034667969\n",
      "Loss: -76.41276550292969\n",
      "Loss: -74.23541259765625\n",
      "Loss: -73.64712524414062\n",
      "Loss: -76.6842041015625\n",
      "Loss: -73.03519439697266\n",
      "Loss: -67.40086364746094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|██▉                                    | 746/10000 [00:20<04:08, 37.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -61.81608963012695\n",
      "Loss: -72.80317687988281\n",
      "Loss: -66.81201171875\n",
      "Loss: -70.7114028930664\n",
      "Loss: -70.46891784667969\n",
      "Loss: -71.00820922851562\n",
      "Loss: -73.18831634521484\n",
      "Loss: -73.87739562988281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|██▉                                    | 754/10000 [00:20<04:29, 34.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.50398254394531\n",
      "Loss: -71.35240173339844\n",
      "Loss: -77.05207824707031\n",
      "Loss: -76.70440673828125\n",
      "Loss: -73.24472045898438\n",
      "Loss: -79.22164916992188\n",
      "Loss: -76.13180541992188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|██▉                                    | 762/10000 [00:21<04:20, 35.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.93063354492188\n",
      "Loss: -76.58541870117188\n",
      "Loss: -77.59477996826172\n",
      "Loss: -75.01832580566406\n",
      "Loss: -76.08973693847656\n",
      "Loss: -78.38362884521484\n",
      "Loss: -78.98590850830078\n",
      "Loss: -80.01945495605469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███                                    | 770/10000 [00:21<04:31, 34.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.48335266113281\n",
      "Loss: -78.95361328125\n",
      "Loss: -78.11186218261719\n",
      "Loss: -75.90978240966797\n",
      "Loss: -60.564483642578125\n",
      "Loss: -42.40751647949219\n",
      "Loss: -62.807151794433594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███                                    | 778/10000 [00:21<04:46, 32.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -50.79949188232422\n",
      "Loss: -45.184532165527344\n",
      "Loss: -48.27311706542969\n",
      "Loss: -53.68470764160156\n",
      "Loss: -60.165748596191406\n",
      "Loss: -65.34095764160156\n",
      "Loss: -66.13836669921875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███                                    | 786/10000 [00:21<04:31, 33.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.01897430419922\n",
      "Loss: -71.41688537597656\n",
      "Loss: -70.25364685058594\n",
      "Loss: -70.75645446777344\n",
      "Loss: -72.74649810791016\n",
      "Loss: -70.62032318115234\n",
      "Loss: -70.1606674194336\n",
      "Loss: -73.82234954833984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███                                    | 794/10000 [00:22<04:25, 34.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -73.71515655517578\n",
      "Loss: -73.3988037109375\n",
      "Loss: -74.51151275634766\n",
      "Loss: -72.03755187988281\n",
      "Loss: -73.86386108398438\n",
      "Loss: -75.03668212890625\n",
      "Loss: -75.05567932128906\n",
      "Loss: -74.53230285644531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▏                                   | 802/10000 [00:22<04:15, 35.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.2371826171875\n",
      "Loss: -75.97653198242188\n",
      "Loss: -74.80064392089844\n",
      "Loss: -75.2420425415039\n",
      "Loss: -77.6231460571289\n",
      "Loss: -79.56057739257812\n",
      "Loss: -77.63394927978516\n",
      "Loss: -79.15635681152344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▏                                   | 810/10000 [00:22<04:08, 37.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.1712875366211\n",
      "Loss: -72.31978607177734\n",
      "Loss: -66.67570495605469\n",
      "Loss: -68.12987518310547\n",
      "Loss: -70.56867218017578\n",
      "Loss: -69.6461181640625\n",
      "Loss: -71.26256561279297\n",
      "Loss: -73.28884887695312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▏                                   | 818/10000 [00:22<04:02, 37.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.97464752197266\n",
      "Loss: -75.08708953857422\n",
      "Loss: -69.79027557373047\n",
      "Loss: -56.38963317871094\n",
      "Loss: -53.391571044921875\n",
      "Loss: -56.4820556640625\n",
      "Loss: -56.75360107421875\n",
      "Loss: -54.08219528198242\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  8%|███▏                                   | 822/10000 [00:22<04:20, 35.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -53.476768493652344\n",
      "Loss: -55.46284484863281\n",
      "Loss: -63.813148498535156\n",
      "Loss: -70.54255676269531\n",
      "Loss: -68.36724853515625\n",
      "Loss: -69.35916137695312\n",
      "Loss: -71.66442108154297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▏                                   | 830/10000 [00:23<04:10, 36.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.02407836914062\n",
      "Loss: -69.75\n",
      "Loss: -71.50534057617188\n",
      "Loss: -70.89051055908203\n",
      "Loss: -73.16644287109375\n",
      "Loss: -72.97989654541016\n",
      "Loss: -75.11932373046875\n",
      "Loss: -73.94650268554688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▎                                   | 838/10000 [00:23<04:03, 37.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -73.91642761230469\n",
      "Loss: -74.58847045898438\n",
      "Loss: -77.06082153320312\n",
      "Loss: -75.13755798339844\n",
      "Loss: -73.8177490234375\n",
      "Loss: -75.43342590332031\n",
      "Loss: -78.0821762084961\n",
      "Loss: -77.33827209472656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|███▎                                   | 846/10000 [00:23<04:07, 36.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.02836608886719\n",
      "Loss: -75.5401611328125\n",
      "Loss: -76.88057708740234\n",
      "Loss: -78.35406494140625\n",
      "Loss: -73.86251831054688\n",
      "Loss: -79.69178771972656\n",
      "Loss: -76.29283142089844\n",
      "Loss: -79.90165710449219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▎                                   | 854/10000 [00:23<04:13, 36.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.78070068359375\n",
      "Loss: -67.54544067382812\n",
      "Loss: -67.85508728027344\n",
      "Loss: -74.43399047851562\n",
      "Loss: -71.31153869628906\n",
      "Loss: -73.44099426269531\n",
      "Loss: -75.63755798339844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▎                                   | 862/10000 [00:24<04:21, 34.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.754638671875\n",
      "Loss: -74.0905990600586\n",
      "Loss: -68.10285949707031\n",
      "Loss: -64.91162109375\n",
      "Loss: -73.44244384765625\n",
      "Loss: -68.65899658203125\n",
      "Loss: -71.64430236816406\n",
      "Loss: -70.94416809082031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▍                                   | 870/10000 [00:24<04:15, 35.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.69062805175781\n",
      "Loss: -75.57678985595703\n",
      "Loss: -75.1390380859375\n",
      "Loss: -76.46541595458984\n",
      "Loss: -74.62989807128906\n",
      "Loss: -74.90769958496094\n",
      "Loss: -78.45759582519531\n",
      "Loss: -75.57489776611328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▍                                   | 878/10000 [00:24<04:08, 36.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.06805419921875\n",
      "Loss: -77.70257568359375\n",
      "Loss: -75.36125183105469\n",
      "Loss: -77.80056762695312\n",
      "Loss: -76.05347442626953\n",
      "Loss: -75.98367309570312\n",
      "Loss: -77.12989807128906\n",
      "Loss: -80.2895736694336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▍                                   | 886/10000 [00:24<04:05, 37.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.240478515625\n",
      "Loss: -79.24311065673828\n",
      "Loss: -75.29512786865234\n",
      "Loss: -74.20498657226562\n",
      "Loss: -73.2848892211914\n",
      "Loss: -79.40084838867188\n",
      "Loss: -74.73834991455078\n",
      "Loss: -77.145751953125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▍                                   | 894/10000 [00:24<04:13, 35.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.80597686767578\n",
      "Loss: -78.80872344970703\n",
      "Loss: -80.93395233154297\n",
      "Loss: -76.2099609375\n",
      "Loss: -71.7877197265625\n",
      "Loss: -74.55436706542969\n",
      "Loss: -74.45014953613281\n",
      "Loss: -75.48405456542969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▌                                   | 902/10000 [00:25<04:03, 37.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.1766357421875\n",
      "Loss: -73.87792205810547\n",
      "Loss: -77.05455780029297\n",
      "Loss: -79.26988983154297\n",
      "Loss: -78.8818359375\n",
      "Loss: -76.25900268554688\n",
      "Loss: -73.22822570800781\n",
      "Loss: -58.76342010498047\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▌                                   | 910/10000 [00:25<03:58, 38.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.49392700195312\n",
      "Loss: -57.01399230957031\n",
      "Loss: -60.98164749145508\n",
      "Loss: -63.997703552246094\n",
      "Loss: -65.0235366821289\n",
      "Loss: -65.21034240722656\n",
      "Loss: -73.03785705566406\n",
      "Loss: -74.17634582519531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▌                                   | 918/10000 [00:25<04:13, 35.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.07319641113281\n",
      "Loss: -74.31932067871094\n",
      "Loss: -76.95812225341797\n",
      "Loss: -76.13311004638672\n",
      "Loss: -77.71168518066406\n",
      "Loss: -75.69058227539062\n",
      "Loss: -73.83268737792969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▌                                   | 926/10000 [00:25<04:19, 34.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.85359191894531\n",
      "Loss: -79.66370391845703\n",
      "Loss: -78.62165832519531\n",
      "Loss: -77.7550277709961\n",
      "Loss: -79.91098022460938\n",
      "Loss: -81.03968811035156\n",
      "Loss: -82.84066772460938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▋                                   | 934/10000 [00:26<04:19, 34.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.20346069335938\n",
      "Loss: -84.00814819335938\n",
      "Loss: -80.15432739257812\n",
      "Loss: -79.08332824707031\n",
      "Loss: -83.04876708984375\n",
      "Loss: -82.89419555664062\n",
      "Loss: -83.80545043945312\n",
      "Loss: -83.6302490234375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|███▋                                   | 942/10000 [00:26<04:20, 34.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.20059967041016\n",
      "Loss: -75.8846435546875\n",
      "Loss: -54.90692138671875\n",
      "Loss: -65.38922119140625\n",
      "Loss: -62.24332809448242\n",
      "Loss: -65.93745422363281\n",
      "Loss: -61.80715560913086\n",
      "Loss: -63.19647216796875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▋                                   | 950/10000 [00:26<04:12, 35.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.71504211425781\n",
      "Loss: -77.88104248046875\n",
      "Loss: -74.71711730957031\n",
      "Loss: -74.54264831542969\n",
      "Loss: -73.7471923828125\n",
      "Loss: -74.49992370605469\n",
      "Loss: -75.4488754272461\n",
      "Loss: -78.20089721679688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 10%|███▋                                   | 954/10000 [00:26<04:22, 34.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.35457611083984\n",
      "Loss: -78.89163208007812\n",
      "Loss: -75.76519775390625\n",
      "Loss: -79.17137145996094\n",
      "Loss: -79.66551208496094\n",
      "Loss: -76.86166381835938\n",
      "Loss: -78.75994873046875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                   | 962/10000 [00:26<04:22, 34.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.78610229492188\n",
      "Loss: -80.4828109741211\n",
      "Loss: -79.63223266601562\n",
      "Loss: -76.61149597167969\n",
      "Loss: -76.38606262207031\n",
      "Loss: -78.79125213623047\n",
      "Loss: -78.45536804199219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                   | 970/10000 [00:27<04:27, 33.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.6185073852539\n",
      "Loss: -82.10383605957031\n",
      "Loss: -81.11442565917969\n",
      "Loss: -77.32180786132812\n",
      "Loss: -49.82054138183594\n",
      "Loss: -48.480506896972656\n",
      "Loss: -58.51079559326172\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                   | 978/10000 [00:27<04:19, 34.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -55.61158752441406\n",
      "Loss: -49.818084716796875\n",
      "Loss: -49.31468200683594\n",
      "Loss: -53.09385681152344\n",
      "Loss: -59.84300231933594\n",
      "Loss: -68.6146240234375\n",
      "Loss: -73.38258361816406\n",
      "Loss: -68.28265380859375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                   | 986/10000 [00:27<04:08, 36.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -71.37488555908203\n",
      "Loss: -71.62783813476562\n",
      "Loss: -71.17249298095703\n",
      "Loss: -72.10951232910156\n",
      "Loss: -72.040283203125\n",
      "Loss: -70.21351623535156\n",
      "Loss: -75.4241714477539\n",
      "Loss: -74.47103881835938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▉                                   | 994/10000 [00:27<04:17, 35.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.45916748046875\n",
      "Loss: -74.51290893554688\n",
      "Loss: -76.19532775878906\n",
      "Loss: -74.65852355957031\n",
      "Loss: -79.49040222167969\n",
      "Loss: -77.825439453125\n",
      "Loss: -76.09296417236328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 10%|███▉                                   | 998/10000 [00:27<04:21, 34.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.16769409179688\n",
      "Loss: -80.70620727539062\n",
      "Loss: -79.54842376708984\n",
      "Loss: -80.96101379394531\n",
      "Loss: -80.84236907958984\n",
      "Loss: -79.9310302734375\n",
      "Loss: -77.99794006347656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                  | 1006/10000 [00:28<04:13, 35.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.54542541503906\n",
      "Loss: -75.99751281738281\n",
      "Loss: -74.38114929199219\n",
      "Loss: -74.88739013671875\n",
      "Loss: -79.19302368164062\n",
      "Loss: -75.80644226074219\n",
      "Loss: -78.94412231445312\n",
      "Loss: -79.31721496582031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▊                                  | 1014/10000 [00:28<04:04, 36.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.5957260131836\n",
      "Loss: -79.88008117675781\n",
      "Loss: -67.89392852783203\n",
      "Loss: -43.31171798706055\n",
      "Loss: -58.21632385253906\n",
      "Loss: -41.34562301635742\n",
      "Loss: -36.40240478515625\n",
      "Loss: -34.44599151611328\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▉                                  | 1022/10000 [00:28<03:54, 38.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -35.634544372558594\n",
      "Loss: -40.86315155029297\n",
      "Loss: -46.693607330322266\n",
      "Loss: -56.250648498535156\n",
      "Loss: -57.8972282409668\n",
      "Loss: -66.69190216064453\n",
      "Loss: -69.80729675292969\n",
      "Loss: -70.36341857910156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▉                                  | 1030/10000 [00:28<03:58, 37.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -70.75989532470703\n",
      "Loss: -67.27222442626953\n",
      "Loss: -75.53501892089844\n",
      "Loss: -72.63082885742188\n",
      "Loss: -73.57089233398438\n",
      "Loss: -74.19352722167969\n",
      "Loss: -72.33369445800781\n",
      "Loss: -77.06582641601562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▉                                  | 1038/10000 [00:28<03:59, 37.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.53720092773438\n",
      "Loss: -73.1534423828125\n",
      "Loss: -76.74368286132812\n",
      "Loss: -77.47843170166016\n",
      "Loss: -78.18632507324219\n",
      "Loss: -76.79932403564453\n",
      "Loss: -76.78828430175781\n",
      "Loss: -79.78639221191406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███▉                                  | 1046/10000 [00:29<04:04, 36.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.8742446899414\n",
      "Loss: -76.1417007446289\n",
      "Loss: -80.28028869628906\n",
      "Loss: -77.40792846679688\n",
      "Loss: -79.91586303710938\n",
      "Loss: -81.21998596191406\n",
      "Loss: -78.03302764892578\n",
      "Loss: -79.16759490966797\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████                                  | 1054/10000 [00:29<04:11, 35.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.78585815429688\n",
      "Loss: -80.26437377929688\n",
      "Loss: -80.97645568847656\n",
      "Loss: -77.50224304199219\n",
      "Loss: -79.71051025390625\n",
      "Loss: -83.69795227050781\n",
      "Loss: -77.64727783203125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████                                  | 1062/10000 [00:29<04:14, 35.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -67.9501953125\n",
      "Loss: -59.8438835144043\n",
      "Loss: -73.48170471191406\n",
      "Loss: -68.37553405761719\n",
      "Loss: -65.86637115478516\n",
      "Loss: -71.45533752441406\n",
      "Loss: -75.51790618896484\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████                                  | 1070/10000 [00:29<04:21, 34.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.86408233642578\n",
      "Loss: -80.90701293945312\n",
      "Loss: -76.55933380126953\n",
      "Loss: -77.6939697265625\n",
      "Loss: -71.2349853515625\n",
      "Loss: -56.961517333984375\n",
      "Loss: -67.451171875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████                                  | 1078/10000 [00:30<04:17, 34.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -59.873619079589844\n",
      "Loss: -60.35808563232422\n",
      "Loss: -61.33613204956055\n",
      "Loss: -61.686309814453125\n",
      "Loss: -60.98080062866211\n",
      "Loss: -66.49037170410156\n",
      "Loss: -71.36062622070312\n",
      "Loss: -76.5035629272461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▏                                 | 1086/10000 [00:30<04:09, 35.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.96231079101562\n",
      "Loss: -76.60948944091797\n",
      "Loss: -77.50772094726562\n",
      "Loss: -74.73515319824219\n",
      "Loss: -76.00503540039062\n",
      "Loss: -79.16537475585938\n",
      "Loss: -77.69009399414062\n",
      "Loss: -74.20768737792969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▏                                 | 1094/10000 [00:30<04:05, 36.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.38548278808594\n",
      "Loss: -78.17169189453125\n",
      "Loss: -79.47438049316406\n",
      "Loss: -77.0032958984375\n",
      "Loss: -76.91217041015625\n",
      "Loss: -77.44544982910156\n",
      "Loss: -83.28887939453125\n",
      "Loss: -81.04512786865234\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▏                                 | 1102/10000 [00:30<04:03, 36.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.78276824951172\n",
      "Loss: -79.5465087890625\n",
      "Loss: -79.01031494140625\n",
      "Loss: -80.37591552734375\n",
      "Loss: -83.17567443847656\n",
      "Loss: -79.80061340332031\n",
      "Loss: -77.65109252929688\n",
      "Loss: -83.6415786743164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 11%|████▏                                 | 1106/10000 [00:30<04:06, 36.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.9715576171875\n",
      "Loss: -83.25579833984375\n",
      "Loss: -79.95858001708984\n",
      "Loss: -81.09168243408203\n",
      "Loss: -81.72978210449219\n",
      "Loss: -82.906982421875\n",
      "Loss: -82.72865295410156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▏                                 | 1114/10000 [00:31<04:06, 36.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -83.41268920898438\n",
      "Loss: -81.76266479492188\n",
      "Loss: -81.18820190429688\n",
      "Loss: -85.82003784179688\n",
      "Loss: -82.48170471191406\n",
      "Loss: -82.00003051757812\n",
      "Loss: -83.31217193603516\n",
      "Loss: -84.38363647460938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▎                                 | 1122/10000 [00:31<04:01, 36.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.54855346679688\n",
      "Loss: -82.83976745605469\n",
      "Loss: -83.74918365478516\n",
      "Loss: -73.9916000366211\n",
      "Loss: -62.875144958496094\n",
      "Loss: -76.69973754882812\n",
      "Loss: -69.06267547607422\n",
      "Loss: -73.34178161621094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▎                                 | 1130/10000 [00:31<04:08, 35.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.79644775390625\n",
      "Loss: -72.78413391113281\n",
      "Loss: -80.08353424072266\n",
      "Loss: -79.08753967285156\n",
      "Loss: -81.15765380859375\n",
      "Loss: -83.48272705078125\n",
      "Loss: -76.43962097167969\n",
      "Loss: -64.4107666015625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▎                                 | 1138/10000 [00:31<04:16, 34.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -65.28857421875\n",
      "Loss: -71.8663558959961\n",
      "Loss: -69.2546157836914\n",
      "Loss: -64.8938217163086\n",
      "Loss: -63.97599792480469\n",
      "Loss: -66.81625366210938\n",
      "Loss: -73.21147155761719\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████▎                                 | 1146/10000 [00:31<04:14, 34.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -75.85995483398438\n",
      "Loss: -81.12922668457031\n",
      "Loss: -78.81120300292969\n",
      "Loss: -77.23870849609375\n",
      "Loss: -77.10377502441406\n",
      "Loss: -78.74298095703125\n",
      "Loss: -76.3531723022461\n",
      "Loss: -79.618896484375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▍                                 | 1154/10000 [00:32<04:09, 35.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.0023193359375\n",
      "Loss: -80.08234405517578\n",
      "Loss: -81.05213165283203\n",
      "Loss: -80.3643798828125\n",
      "Loss: -78.06961059570312\n",
      "Loss: -85.66252136230469\n",
      "Loss: -82.80292510986328\n",
      "Loss: -83.37431335449219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▍                                 | 1162/10000 [00:32<04:03, 36.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.88575744628906\n",
      "Loss: -82.48136901855469\n",
      "Loss: -82.9385986328125\n",
      "Loss: -82.46459197998047\n",
      "Loss: -85.07691955566406\n",
      "Loss: -85.18077850341797\n",
      "Loss: -86.08880615234375\n",
      "Loss: -84.00482177734375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▍                                 | 1170/10000 [00:32<03:54, 37.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.70838165283203\n",
      "Loss: -83.98645782470703\n",
      "Loss: -86.73613739013672\n",
      "Loss: -86.016845703125\n",
      "Loss: -82.10995483398438\n",
      "Loss: -67.37284851074219\n",
      "Loss: 2.051996946334839\n",
      "Loss: -56.868865966796875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▍                                 | 1178/10000 [00:32<03:55, 37.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -28.336334228515625\n",
      "Loss: -14.31329345703125\n",
      "Loss: -7.653692722320557\n",
      "Loss: -2.613337278366089\n",
      "Loss: 0.0908835381269455\n",
      "Loss: 0.041964784264564514\n",
      "Loss: -1.2977688312530518\n",
      "Loss: -7.125141143798828\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▌                                 | 1186/10000 [00:33<03:59, 36.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -13.863617897033691\n",
      "Loss: -24.04717254638672\n",
      "Loss: -33.489112854003906\n",
      "Loss: -33.67868423461914\n",
      "Loss: -53.35852813720703\n",
      "Loss: -58.86960220336914\n",
      "Loss: -63.8387565612793\n",
      "Loss: -68.77592468261719\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▌                                 | 1194/10000 [00:33<04:03, 36.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.50344848632812\n",
      "Loss: -65.8796615600586\n",
      "Loss: -68.70930480957031\n",
      "Loss: -67.39637756347656\n",
      "Loss: -67.19499206542969\n",
      "Loss: -69.35810089111328\n",
      "Loss: -69.54350280761719\n",
      "Loss: -72.19251251220703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▌                                 | 1202/10000 [00:33<04:01, 36.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.548095703125\n",
      "Loss: -78.95848083496094\n",
      "Loss: -71.8207778930664\n",
      "Loss: -78.70722198486328\n",
      "Loss: -75.93328857421875\n",
      "Loss: -75.85542297363281\n",
      "Loss: -77.01118469238281\n",
      "Loss: -77.68523406982422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▌                                 | 1210/10000 [00:33<03:59, 36.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.87507629394531\n",
      "Loss: -78.46281433105469\n",
      "Loss: -77.749267578125\n",
      "Loss: -76.49736785888672\n",
      "Loss: -80.83595275878906\n",
      "Loss: -81.70938110351562\n",
      "Loss: -79.24775695800781\n",
      "Loss: -81.74079895019531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▋                                 | 1218/10000 [00:33<03:57, 37.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.45895385742188\n",
      "Loss: -79.64148712158203\n",
      "Loss: -83.67881774902344\n",
      "Loss: -80.17298889160156\n",
      "Loss: -82.614990234375\n",
      "Loss: -83.03694152832031\n",
      "Loss: -81.56655883789062\n",
      "Loss: -82.5836181640625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▋                                 | 1226/10000 [00:34<04:09, 35.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.43941497802734\n",
      "Loss: -86.05996704101562\n",
      "Loss: -83.36455535888672\n",
      "Loss: -82.26213073730469\n",
      "Loss: -83.47330474853516\n",
      "Loss: -83.52870178222656\n",
      "Loss: -83.70953369140625\n",
      "Loss: -84.30587768554688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▋                                 | 1234/10000 [00:34<04:02, 36.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.31808471679688\n",
      "Loss: -85.1347427368164\n",
      "Loss: -82.1548080444336\n",
      "Loss: -80.89842224121094\n",
      "Loss: -77.16525268554688\n",
      "Loss: -80.68975830078125\n",
      "Loss: -79.87484741210938\n",
      "Loss: -81.6584701538086\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▋                                 | 1242/10000 [00:34<04:03, 36.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.68942260742188\n",
      "Loss: -83.6395034790039\n",
      "Loss: -86.33515930175781\n",
      "Loss: -79.047119140625\n",
      "Loss: -70.96463012695312\n",
      "Loss: -63.638916015625\n",
      "Loss: -79.52534484863281\n",
      "Loss: -68.5632095336914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████▊                                 | 1250/10000 [00:34<04:00, 36.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -69.96149444580078\n",
      "Loss: -75.66995239257812\n",
      "Loss: -74.2685775756836\n",
      "Loss: -77.80339050292969\n",
      "Loss: -78.09223175048828\n",
      "Loss: -79.67439270019531\n",
      "Loss: -81.30516052246094\n",
      "Loss: -80.2440185546875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▊                                 | 1258/10000 [00:35<03:58, 36.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.79421997070312\n",
      "Loss: -78.32830047607422\n",
      "Loss: -80.84564208984375\n",
      "Loss: -83.6548843383789\n",
      "Loss: -82.52793884277344\n",
      "Loss: -82.70858764648438\n",
      "Loss: -85.05381774902344\n",
      "Loss: -83.45986938476562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▊                                 | 1266/10000 [00:35<03:49, 38.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.21261596679688\n",
      "Loss: -83.97709655761719\n",
      "Loss: -83.96963500976562\n",
      "Loss: -83.85371398925781\n",
      "Loss: -78.71866607666016\n",
      "Loss: -78.22220611572266\n",
      "Loss: -84.13400268554688\n",
      "Loss: -77.84654235839844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▊                                 | 1274/10000 [00:35<03:51, 37.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.6995620727539\n",
      "Loss: -81.0907974243164\n",
      "Loss: -85.23231506347656\n",
      "Loss: -85.02082824707031\n",
      "Loss: -87.0888442993164\n",
      "Loss: -83.74578857421875\n",
      "Loss: -77.32744598388672\n",
      "Loss: -76.85639953613281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▊                                 | 1282/10000 [00:35<04:09, 34.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.26225280761719\n",
      "Loss: -81.38052368164062\n",
      "Loss: -82.44572448730469\n",
      "Loss: -82.22200012207031\n",
      "Loss: -82.31983947753906\n",
      "Loss: -83.00761413574219\n",
      "Loss: -83.7601318359375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▉                                 | 1290/10000 [00:35<04:05, 35.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -81.40908813476562\n",
      "Loss: -86.16547393798828\n",
      "Loss: -83.39913940429688\n",
      "Loss: -82.68302917480469\n",
      "Loss: -83.96305847167969\n",
      "Loss: -86.49966430664062\n",
      "Loss: -87.14163970947266\n",
      "Loss: -83.36557006835938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▉                                 | 1298/10000 [00:36<04:00, 36.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.94032287597656\n",
      "Loss: -86.38922882080078\n",
      "Loss: -84.22639465332031\n",
      "Loss: -88.6214370727539\n",
      "Loss: -87.15628814697266\n",
      "Loss: -81.82579040527344\n",
      "Loss: -72.30623626708984\n",
      "Loss: -75.55692291259766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▉                                 | 1307/10000 [00:36<03:49, 37.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.63105773925781\n",
      "Loss: -77.12771606445312\n",
      "Loss: -77.1573257446289\n",
      "Loss: -79.03030395507812\n",
      "Loss: -81.06741333007812\n",
      "Loss: -79.57616424560547\n",
      "Loss: -84.38633728027344\n",
      "Loss: -81.8774185180664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|████▉                                 | 1315/10000 [00:36<03:46, 38.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -83.74872589111328\n",
      "Loss: -82.92301177978516\n",
      "Loss: -81.77223205566406\n",
      "Loss: -83.77072143554688\n",
      "Loss: -84.23634338378906\n",
      "Loss: -84.12503814697266\n",
      "Loss: -82.94261169433594\n",
      "Loss: -86.05204010009766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█████                                 | 1323/10000 [00:36<03:46, 38.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.38432312011719\n",
      "Loss: -85.96963500976562\n",
      "Loss: -85.95545959472656\n",
      "Loss: -83.34861755371094\n",
      "Loss: -73.3668212890625\n",
      "Loss: -62.22825622558594\n",
      "Loss: -78.26748657226562\n",
      "Loss: -69.06266021728516\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 13%|█████                                 | 1327/10000 [00:36<04:04, 35.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.93087768554688\n",
      "Loss: -73.38993072509766\n",
      "Loss: -73.01423645019531\n",
      "Loss: -76.26786804199219\n",
      "Loss: -79.1380615234375\n",
      "Loss: -79.56768035888672\n",
      "Loss: -84.33465576171875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█████                                 | 1336/10000 [00:37<03:58, 36.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.95126342773438\n",
      "Loss: -78.9032974243164\n",
      "Loss: -80.31855773925781\n",
      "Loss: -80.79544067382812\n",
      "Loss: -82.31619262695312\n",
      "Loss: -83.6880874633789\n",
      "Loss: -81.53575134277344\n",
      "Loss: -84.13573455810547\n",
      "Loss: -77.78009033203125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█████                                 | 1344/10000 [00:37<03:50, 37.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.48534393310547\n",
      "Loss: -84.7930908203125\n",
      "Loss: -85.15294647216797\n",
      "Loss: -85.74296569824219\n",
      "Loss: -86.129638671875\n",
      "Loss: -80.34227752685547\n",
      "Loss: -76.85189819335938\n",
      "Loss: -85.15594482421875\n",
      "Loss: -83.56593322753906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▏                                | 1353/10000 [00:37<03:43, 38.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.361083984375\n",
      "Loss: -82.36665344238281\n",
      "Loss: -84.30165100097656\n",
      "Loss: -83.774658203125\n",
      "Loss: -81.3344497680664\n",
      "Loss: -85.21678161621094\n",
      "Loss: -86.03007507324219\n",
      "Loss: -87.05686950683594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▏                                | 1361/10000 [00:37<03:50, 37.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.28877258300781\n",
      "Loss: -84.90557861328125\n",
      "Loss: -84.92236328125\n",
      "Loss: -87.74296569824219\n",
      "Loss: -83.7347412109375\n",
      "Loss: -85.7396240234375\n",
      "Loss: -86.47833251953125\n",
      "Loss: -83.64063262939453\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▏                                | 1369/10000 [00:38<03:53, 37.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.6351318359375\n",
      "Loss: -89.70210266113281\n",
      "Loss: -88.73230743408203\n",
      "Loss: -86.034912109375\n",
      "Loss: -43.43536376953125\n",
      "Loss: 13.944866180419922\n",
      "Loss: -41.95280456542969\n",
      "Loss: -26.64554214477539\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▏                                | 1377/10000 [00:38<03:52, 37.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -12.648808479309082\n",
      "Loss: -3.3996691703796387\n",
      "Loss: 1.9538122415542603\n",
      "Loss: 5.06955099105835\n",
      "Loss: 5.552013397216797\n",
      "Loss: 3.0897655487060547\n",
      "Loss: -1.5514380931854248\n",
      "Loss: -8.74498176574707\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▎                                | 1386/10000 [00:38<03:37, 39.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -20.776447296142578\n",
      "Loss: -31.372596740722656\n",
      "Loss: -14.02901554107666\n",
      "Loss: -52.515235900878906\n",
      "Loss: -49.22056579589844\n",
      "Loss: -43.365745544433594\n",
      "Loss: -42.41321563720703\n",
      "Loss: -44.960205078125\n",
      "Loss: -49.36924743652344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▎                                | 1396/10000 [00:38<03:26, 41.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -54.89386749267578\n",
      "Loss: -62.293853759765625\n",
      "Loss: -72.39620971679688\n",
      "Loss: -65.17316436767578\n",
      "Loss: -66.98077392578125\n",
      "Loss: -65.83057403564453\n",
      "Loss: -68.74755859375\n",
      "Loss: -69.99402618408203\n",
      "Loss: -70.53961944580078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▎                                | 1406/10000 [00:38<03:34, 40.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -73.84001159667969\n",
      "Loss: -75.34136199951172\n",
      "Loss: -75.74321746826172\n",
      "Loss: -79.27082061767578\n",
      "Loss: -75.93240356445312\n",
      "Loss: -73.58802795410156\n",
      "Loss: -76.62506103515625\n",
      "Loss: -76.02999114990234\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 14%|█████▎                                | 1411/10000 [00:39<03:41, 38.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -77.84909057617188\n",
      "Loss: -78.3756103515625\n",
      "Loss: -79.79681396484375\n",
      "Loss: -78.62742614746094\n",
      "Loss: -81.81230163574219\n",
      "Loss: -82.30699157714844\n",
      "Loss: -81.35774230957031\n",
      "Loss: -81.84854125976562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▍                                | 1419/10000 [00:39<03:42, 38.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.89506530761719\n",
      "Loss: -84.19407653808594\n",
      "Loss: -83.45377349853516\n",
      "Loss: -84.85956573486328\n",
      "Loss: -83.87193298339844\n",
      "Loss: -83.39620208740234\n",
      "Loss: -83.4683609008789\n",
      "Loss: -82.96534729003906\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▍                                | 1427/10000 [00:39<03:40, 38.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.13558197021484\n",
      "Loss: -81.96415710449219\n",
      "Loss: -85.49909210205078\n",
      "Loss: -85.56929016113281\n",
      "Loss: -82.66775512695312\n",
      "Loss: -87.01948547363281\n",
      "Loss: -84.2442626953125\n",
      "Loss: -82.90306091308594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▍                                | 1435/10000 [00:39<03:40, 38.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.66923522949219\n",
      "Loss: -86.46519470214844\n",
      "Loss: -86.86862182617188\n",
      "Loss: -86.26364135742188\n",
      "Loss: -83.68923950195312\n",
      "Loss: -85.08685302734375\n",
      "Loss: -87.85678100585938\n",
      "Loss: -87.55497741699219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████▍                                | 1443/10000 [00:39<03:53, 36.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.00988006591797\n",
      "Loss: -80.56114196777344\n",
      "Loss: -87.07553100585938\n",
      "Loss: -84.78754425048828\n",
      "Loss: -84.56214904785156\n",
      "Loss: -84.79460144042969\n",
      "Loss: -86.57160949707031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▌                                | 1452/10000 [00:40<03:43, 38.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.97107696533203\n",
      "Loss: -89.7106704711914\n",
      "Loss: -71.4770278930664\n",
      "Loss: -31.182764053344727\n",
      "Loss: -72.26806640625\n",
      "Loss: -50.555259704589844\n",
      "Loss: -48.654327392578125\n",
      "Loss: -52.80107879638672\n",
      "Loss: -53.498809814453125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▌                                | 1461/10000 [00:40<03:37, 39.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -56.08515167236328\n",
      "Loss: -60.41407775878906\n",
      "Loss: -63.942527770996094\n",
      "Loss: -67.92032623291016\n",
      "Loss: -75.01107025146484\n",
      "Loss: -78.93421936035156\n",
      "Loss: -79.94261169433594\n",
      "Loss: -78.80097961425781\n",
      "Loss: -73.62420654296875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▌                                | 1470/10000 [00:40<03:35, 39.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -77.30109405517578\n",
      "Loss: -78.17958068847656\n",
      "Loss: -80.1714096069336\n",
      "Loss: -79.73381042480469\n",
      "Loss: -77.96892547607422\n",
      "Loss: -79.73851013183594\n",
      "Loss: -82.34332275390625\n",
      "Loss: -82.92999267578125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▌                                | 1478/10000 [00:40<03:36, 39.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -81.84207153320312\n",
      "Loss: -81.66410827636719\n",
      "Loss: -84.35857391357422\n",
      "Loss: -83.13800048828125\n",
      "Loss: -83.24126434326172\n",
      "Loss: -83.69172668457031\n",
      "Loss: -85.25985717773438\n",
      "Loss: -90.27723693847656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▋                                | 1486/10000 [00:41<03:50, 36.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.14718627929688\n",
      "Loss: -84.2545394897461\n",
      "Loss: -87.59970092773438\n",
      "Loss: -84.67465209960938\n",
      "Loss: -87.36407470703125\n",
      "Loss: -84.62138366699219\n",
      "Loss: -84.84127044677734\n",
      "Loss: -83.82505798339844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▋                                | 1494/10000 [00:41<03:44, 37.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -88.62439727783203\n",
      "Loss: -85.48804473876953\n",
      "Loss: -90.5030517578125\n",
      "Loss: -87.32833862304688\n",
      "Loss: -87.6478271484375\n",
      "Loss: -88.01636505126953\n",
      "Loss: -85.18915557861328\n",
      "Loss: -82.60295104980469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▋                                | 1502/10000 [00:41<03:40, 38.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -81.78147888183594\n",
      "Loss: -86.47179412841797\n",
      "Loss: -84.28194427490234\n",
      "Loss: -81.6976547241211\n",
      "Loss: -86.13107299804688\n",
      "Loss: -86.88732147216797\n",
      "Loss: -86.70195007324219\n",
      "Loss: -84.34265899658203\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▋                                | 1510/10000 [00:41<03:45, 37.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -78.29644012451172\n",
      "Loss: -79.96231842041016\n",
      "Loss: -84.19254302978516\n",
      "Loss: -87.73475646972656\n",
      "Loss: -83.034912109375\n",
      "Loss: -86.23284912109375\n",
      "Loss: -86.96800994873047\n",
      "Loss: -85.84087371826172\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▊                                | 1518/10000 [00:41<03:46, 37.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.53787231445312\n",
      "Loss: -80.19636535644531\n",
      "Loss: -74.87017822265625\n",
      "Loss: -82.63772583007812\n",
      "Loss: -81.97969055175781\n",
      "Loss: -85.33767700195312\n",
      "Loss: -83.9202651977539\n",
      "Loss: -84.9702377319336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▊                                | 1526/10000 [00:42<03:50, 36.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -87.36576843261719\n",
      "Loss: -89.35533142089844\n",
      "Loss: -86.22883605957031\n",
      "Loss: -86.51366424560547\n",
      "Loss: -81.52290344238281\n",
      "Loss: -77.63602447509766\n",
      "Loss: -83.10052490234375\n",
      "Loss: -84.33283996582031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▊                                | 1534/10000 [00:42<03:52, 36.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.626220703125\n",
      "Loss: -85.9603042602539\n",
      "Loss: -84.30616760253906\n",
      "Loss: -89.52891540527344\n",
      "Loss: -85.75891876220703\n",
      "Loss: -85.5211181640625\n",
      "Loss: -85.46417236328125\n",
      "Loss: -85.7566146850586\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|█████▊                                | 1542/10000 [00:42<03:51, 36.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -83.86293029785156\n",
      "Loss: -86.44239807128906\n",
      "Loss: -87.10818481445312\n",
      "Loss: -87.51463317871094\n",
      "Loss: -88.05128479003906\n",
      "Loss: -88.88992309570312\n",
      "Loss: -88.55857849121094\n",
      "Loss: -91.31962585449219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████▉                                | 1550/10000 [00:42<03:51, 36.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.39512634277344\n",
      "Loss: -89.0855941772461\n",
      "Loss: -87.63385772705078\n",
      "Loss: -87.4514389038086\n",
      "Loss: -90.09823608398438\n",
      "Loss: -87.86940002441406\n",
      "Loss: -90.31517028808594\n",
      "Loss: -88.33349609375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████▉                                | 1558/10000 [00:43<03:51, 36.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -87.71109008789062\n",
      "Loss: -88.6083984375\n",
      "Loss: -86.4571533203125\n",
      "Loss: -88.57891845703125\n",
      "Loss: -90.3323974609375\n",
      "Loss: -90.69283294677734\n",
      "Loss: -88.58952331542969\n",
      "Loss: -82.42715454101562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████▉                                | 1566/10000 [00:43<03:42, 37.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -70.60374450683594\n",
      "Loss: -77.15463256835938\n",
      "Loss: -79.39547729492188\n",
      "Loss: -79.73355102539062\n",
      "Loss: -76.89683532714844\n",
      "Loss: -76.62979125976562\n",
      "Loss: -81.13322448730469\n",
      "Loss: -81.14700317382812\n",
      "Loss: -84.86555480957031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█████▉                                | 1575/10000 [00:43<03:34, 39.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -90.33366394042969\n",
      "Loss: -82.60076141357422\n",
      "Loss: -83.2315673828125\n",
      "Loss: -89.79219055175781\n",
      "Loss: -84.65667724609375\n",
      "Loss: -87.01522064208984\n",
      "Loss: -84.86871337890625\n",
      "Loss: -89.08074951171875\n",
      "Loss: -85.65463256835938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████                                | 1583/10000 [00:43<03:37, 38.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.40394592285156\n",
      "Loss: -89.09083557128906\n",
      "Loss: -87.06245422363281\n",
      "Loss: -89.1407241821289\n",
      "Loss: -91.34931182861328\n",
      "Loss: -85.4920425415039\n",
      "Loss: -87.05242156982422\n",
      "Loss: -86.07511138916016\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████                                | 1591/10000 [00:43<03:36, 38.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.71577453613281\n",
      "Loss: -83.9181900024414\n",
      "Loss: -88.69743347167969\n",
      "Loss: -85.77249908447266\n",
      "Loss: -88.59530639648438\n",
      "Loss: -88.30410766601562\n",
      "Loss: -89.09794616699219\n",
      "Loss: -89.16889953613281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████                                | 1599/10000 [00:44<03:39, 38.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -83.16896057128906\n",
      "Loss: -84.19508361816406\n",
      "Loss: -84.90274047851562\n",
      "Loss: -90.82342529296875\n",
      "Loss: -87.54728698730469\n",
      "Loss: -88.40351104736328\n",
      "Loss: -88.07201385498047\n",
      "Loss: -90.68016052246094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████                                | 1607/10000 [00:44<03:36, 38.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.38372039794922\n",
      "Loss: -88.60255432128906\n",
      "Loss: -89.92800903320312\n",
      "Loss: -90.73223114013672\n",
      "Loss: -87.49995422363281\n",
      "Loss: -88.60569763183594\n",
      "Loss: -89.72222137451172\n",
      "Loss: -91.6094741821289\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████▏                               | 1615/10000 [00:44<03:41, 37.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.06989288330078\n",
      "Loss: -90.08694458007812\n",
      "Loss: -91.23954772949219\n",
      "Loss: -91.26661682128906\n",
      "Loss: -94.65809631347656\n",
      "Loss: -93.80354309082031\n",
      "Loss: -89.75054931640625\n",
      "Loss: -78.40499877929688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████▏                               | 1623/10000 [00:44<03:43, 37.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -21.099763870239258\n",
      "Loss: -35.419673919677734\n",
      "Loss: -33.98027801513672\n",
      "Loss: -32.34212875366211\n",
      "Loss: -26.38321304321289\n",
      "Loss: -20.740928649902344\n",
      "Loss: -17.07854461669922\n",
      "Loss: -14.821348190307617\n",
      "Loss: -14.3035888671875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████▏                               | 1632/10000 [00:44<03:42, 37.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -15.50059700012207\n",
      "Loss: -18.01044464111328\n",
      "Loss: -22.657224655151367\n",
      "Loss: -29.208057403564453\n",
      "Loss: -37.300270080566406\n",
      "Loss: -46.16857147216797\n",
      "Loss: -55.952049255371094\n",
      "Loss: -55.73974609375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████▏                               | 1640/10000 [00:45<03:50, 36.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -67.65318298339844\n",
      "Loss: -74.82492065429688\n",
      "Loss: -74.28923034667969\n",
      "Loss: -73.8441162109375\n",
      "Loss: -70.99127960205078\n",
      "Loss: -74.75474548339844\n",
      "Loss: -80.03178405761719\n",
      "Loss: -76.364501953125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|██████▎                               | 1648/10000 [00:45<03:47, 36.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -77.17018127441406\n",
      "Loss: -78.657470703125\n",
      "Loss: -80.971923828125\n",
      "Loss: -80.36495971679688\n",
      "Loss: -79.10617065429688\n",
      "Loss: -80.42922973632812\n",
      "Loss: -84.69522094726562\n",
      "Loss: -81.42879486083984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▎                               | 1656/10000 [00:45<03:48, 36.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -80.75340270996094\n",
      "Loss: -88.22480010986328\n",
      "Loss: -86.83574676513672\n",
      "Loss: -82.95632934570312\n",
      "Loss: -85.65538024902344\n",
      "Loss: -86.38339233398438\n",
      "Loss: -86.5004653930664\n",
      "Loss: -87.14044189453125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▎                               | 1664/10000 [00:45<03:52, 35.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.62442016601562\n",
      "Loss: -83.82733154296875\n",
      "Loss: -85.89868927001953\n",
      "Loss: -84.48990631103516\n",
      "Loss: -87.00884246826172\n",
      "Loss: -90.19630432128906\n",
      "Loss: -90.31159973144531\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▎                               | 1672/10000 [00:46<04:12, 33.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.98914337158203\n",
      "Loss: -92.2147216796875\n",
      "Loss: -88.78492736816406\n",
      "Loss: -91.20777130126953\n",
      "Loss: -87.09661865234375\n",
      "Loss: -83.51437377929688\n",
      "Loss: -86.95201873779297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▍                               | 1680/10000 [00:46<03:54, 35.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -87.33267974853516\n",
      "Loss: -89.02983856201172\n",
      "Loss: -87.74545288085938\n",
      "Loss: -89.45578002929688\n",
      "Loss: -87.47743225097656\n",
      "Loss: -89.71202850341797\n",
      "Loss: -92.70161437988281\n",
      "Loss: -91.27462768554688\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▍                               | 1688/10000 [00:46<03:45, 36.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -92.89473724365234\n",
      "Loss: -88.28340911865234\n",
      "Loss: -82.43682098388672\n",
      "Loss: -76.04679870605469\n",
      "Loss: -78.60777282714844\n",
      "Loss: -85.36338806152344\n",
      "Loss: -79.70880126953125\n",
      "Loss: -79.88008117675781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▍                               | 1696/10000 [00:46<03:45, 36.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.78519439697266\n",
      "Loss: -82.58906555175781\n",
      "Loss: -87.62631225585938\n",
      "Loss: -84.77243041992188\n",
      "Loss: -89.2523422241211\n",
      "Loss: -86.678466796875\n",
      "Loss: -88.93508911132812\n",
      "Loss: -59.66489028930664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▍                               | 1704/10000 [00:46<03:49, 36.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -18.942710876464844\n",
      "Loss: -69.05561828613281\n",
      "Loss: -53.346466064453125\n",
      "Loss: -43.95195388793945\n",
      "Loss: -42.00757598876953\n",
      "Loss: -39.54351043701172\n",
      "Loss: -38.42917251586914\n",
      "Loss: -39.20032501220703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▌                               | 1712/10000 [00:47<03:45, 36.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -40.23638916015625\n",
      "Loss: -42.88869857788086\n",
      "Loss: -45.448001861572266\n",
      "Loss: -49.05849838256836\n",
      "Loss: -54.58118438720703\n",
      "Loss: -58.40242385864258\n",
      "Loss: -64.2049560546875\n",
      "Loss: -71.77192687988281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▌                               | 1720/10000 [00:47<03:46, 36.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.14673614501953\n",
      "Loss: -78.84933471679688\n",
      "Loss: -78.17863464355469\n",
      "Loss: -79.54936218261719\n",
      "Loss: -80.82664489746094\n",
      "Loss: -78.62562561035156\n",
      "Loss: -80.1979751586914\n",
      "Loss: -82.43571472167969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▌                               | 1728/10000 [00:47<03:43, 36.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -82.37687683105469\n",
      "Loss: -84.2328872680664\n",
      "Loss: -84.43851470947266\n",
      "Loss: -81.46299743652344\n",
      "Loss: -83.96685028076172\n",
      "Loss: -83.51179504394531\n",
      "Loss: -84.89118957519531\n",
      "Loss: -84.10009765625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▌                               | 1736/10000 [00:47<03:51, 35.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.8026123046875\n",
      "Loss: -85.38505554199219\n",
      "Loss: -86.3182373046875\n",
      "Loss: -87.76995849609375\n",
      "Loss: -87.56266784667969\n",
      "Loss: -84.55231475830078\n",
      "Loss: -91.37191772460938\n",
      "Loss: -88.78353881835938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|██████▋                               | 1744/10000 [00:48<03:46, 36.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -90.19215393066406\n",
      "Loss: -86.98263549804688\n",
      "Loss: -88.47843933105469\n",
      "Loss: -89.43040466308594\n",
      "Loss: -89.97931671142578\n",
      "Loss: -89.24394226074219\n",
      "Loss: -90.72189331054688\n",
      "Loss: -89.45478820800781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▋                               | 1752/10000 [00:48<03:47, 36.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -88.37763977050781\n",
      "Loss: -91.80438232421875\n",
      "Loss: -87.97737884521484\n",
      "Loss: -91.57029724121094\n",
      "Loss: -88.53071594238281\n",
      "Loss: -93.63980102539062\n",
      "Loss: -90.01457214355469\n",
      "Loss: -88.16911315917969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▋                               | 1760/10000 [00:48<03:44, 36.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -90.67145538330078\n",
      "Loss: -88.10658264160156\n",
      "Loss: -90.60194396972656\n",
      "Loss: -90.28111267089844\n",
      "Loss: -90.85511016845703\n",
      "Loss: -91.8502197265625\n",
      "Loss: -91.5741958618164\n",
      "Loss: -90.85985565185547\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▋                               | 1768/10000 [00:48<03:43, 36.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -90.88203430175781\n",
      "Loss: -87.79205322265625\n",
      "Loss: -80.20521545410156\n",
      "Loss: -33.90766906738281\n",
      "Loss: -50.518367767333984\n",
      "Loss: -58.061275482177734\n",
      "Loss: -61.934349060058594\n",
      "Loss: -59.07412338256836\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▋                               | 1776/10000 [00:48<03:42, 36.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -54.88518524169922\n",
      "Loss: -52.075347900390625\n",
      "Loss: -51.08714294433594\n",
      "Loss: -52.18968200683594\n",
      "Loss: -53.77696228027344\n",
      "Loss: -57.15494155883789\n",
      "Loss: -63.0096435546875\n",
      "Loss: -68.95208740234375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▊                               | 1784/10000 [00:49<03:45, 36.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -74.73220825195312\n",
      "Loss: -81.28925323486328\n",
      "Loss: -81.20779418945312\n",
      "Loss: -84.82234191894531\n",
      "Loss: -84.15396118164062\n",
      "Loss: -81.32919311523438\n",
      "Loss: -81.7239990234375\n",
      "Loss: -81.53399658203125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▊                               | 1792/10000 [00:49<03:39, 37.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -83.75111389160156\n",
      "Loss: -87.55462646484375\n",
      "Loss: -84.71014404296875\n",
      "Loss: -85.5501708984375\n",
      "Loss: -85.31210327148438\n",
      "Loss: -86.14561462402344\n",
      "Loss: -88.22702026367188\n",
      "Loss: -88.33952331542969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▊                               | 1800/10000 [00:49<03:46, 36.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -91.30067443847656\n",
      "Loss: -88.86763000488281\n",
      "Loss: -88.02519226074219\n",
      "Loss: -86.396728515625\n",
      "Loss: -87.19004821777344\n",
      "Loss: -87.30758666992188\n",
      "Loss: -88.176513671875\n",
      "Loss: -88.5145034790039\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▊                               | 1808/10000 [00:49<03:46, 36.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -92.13533782958984\n",
      "Loss: -91.20162963867188\n",
      "Loss: -87.24308776855469\n",
      "Loss: -94.21695709228516\n",
      "Loss: -89.31830596923828\n",
      "Loss: -97.40245056152344\n",
      "Loss: -88.81915283203125\n",
      "Loss: -88.16468811035156\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▉                               | 1816/10000 [00:50<03:47, 35.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -88.44678497314453\n",
      "Loss: -87.4236831665039\n",
      "Loss: -91.68033599853516\n",
      "Loss: -91.679931640625\n",
      "Loss: -88.51832580566406\n",
      "Loss: -91.08942413330078\n",
      "Loss: -91.10977172851562\n",
      "Loss: -91.5861587524414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▉                               | 1824/10000 [00:50<03:49, 35.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.64099884033203\n",
      "Loss: -94.69155883789062\n",
      "Loss: -87.68407440185547\n",
      "Loss: -93.7701187133789\n",
      "Loss: -91.92176818847656\n",
      "Loss: -89.4552001953125\n",
      "Loss: -93.30573272705078\n",
      "Loss: -92.98072814941406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▉                               | 1833/10000 [00:50<03:35, 37.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -92.20770263671875\n",
      "Loss: -93.07246398925781\n",
      "Loss: -91.6808853149414\n",
      "Loss: -91.82010650634766\n",
      "Loss: -89.41220092773438\n",
      "Loss: -82.32723236083984\n",
      "Loss: -66.04374694824219\n",
      "Loss: -75.18525695800781\n",
      "Loss: -75.29197692871094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|██████▉                               | 1841/10000 [00:50<03:32, 38.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -81.88941955566406\n",
      "Loss: -75.83706665039062\n",
      "Loss: -77.8154296875\n",
      "Loss: -84.51948547363281\n",
      "Loss: -79.90462493896484\n",
      "Loss: -85.79310607910156\n",
      "Loss: -89.40973663330078\n",
      "Loss: -89.87078094482422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|███████                               | 1849/10000 [00:50<03:31, 38.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.86286926269531\n",
      "Loss: -89.7281723022461\n",
      "Loss: -87.11740112304688\n",
      "Loss: -90.28916931152344\n",
      "Loss: -88.75578308105469\n",
      "Loss: -89.57971954345703\n",
      "Loss: -90.45460510253906\n",
      "Loss: -90.25894165039062\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████                               | 1857/10000 [00:51<03:41, 36.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -91.6387939453125\n",
      "Loss: -91.03133392333984\n",
      "Loss: -88.85359954833984\n",
      "Loss: -92.3786849975586\n",
      "Loss: -90.8037109375\n",
      "Loss: -90.65380096435547\n",
      "Loss: -93.17353057861328\n",
      "Loss: -90.71831512451172\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 19%|███████                               | 1861/10000 [00:51<03:40, 36.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -91.59635925292969\n",
      "Loss: -88.86705017089844\n",
      "Loss: -97.07769775390625\n",
      "Loss: -90.61392974853516\n",
      "Loss: -93.31513977050781\n",
      "Loss: -93.10447692871094\n",
      "Loss: -91.64370727539062\n",
      "Loss: -91.98387145996094\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████                               | 1870/10000 [00:51<03:42, 36.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -92.13028717041016\n",
      "Loss: -91.55299377441406\n",
      "Loss: -95.44841003417969\n",
      "Loss: -95.83673858642578\n",
      "Loss: -93.49081420898438\n",
      "Loss: -90.53057861328125\n",
      "Loss: -58.83374786376953\n",
      "Loss: 21.777585983276367\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▏                              | 1879/10000 [00:51<03:35, 37.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -76.28551483154297\n",
      "Loss: -49.20487976074219\n",
      "Loss: -39.93357849121094\n",
      "Loss: -36.61321258544922\n",
      "Loss: -35.33394241333008\n",
      "Loss: -33.85273742675781\n",
      "Loss: -33.250099182128906\n",
      "Loss: -34.184608459472656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▏                              | 1887/10000 [00:51<03:34, 37.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -35.937232971191406\n",
      "Loss: -37.550376892089844\n",
      "Loss: -41.81414794921875\n",
      "Loss: -46.39116287231445\n",
      "Loss: -52.86650085449219\n",
      "Loss: -60.3487548828125\n",
      "Loss: -67.20018768310547\n",
      "Loss: -71.43205261230469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▏                              | 1895/10000 [00:52<03:35, 37.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -79.96378326416016\n",
      "Loss: -79.54265594482422\n",
      "Loss: -79.1293716430664\n",
      "Loss: -76.69197082519531\n",
      "Loss: -81.39863586425781\n",
      "Loss: -80.49058532714844\n",
      "Loss: -80.17809295654297\n",
      "Loss: -84.36764526367188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▏                              | 1903/10000 [00:52<03:32, 38.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -84.6243896484375\n",
      "Loss: -82.27976989746094\n",
      "Loss: -84.72515106201172\n",
      "Loss: -84.70006561279297\n",
      "Loss: -86.32303619384766\n",
      "Loss: -87.00468444824219\n",
      "Loss: -84.08235168457031\n",
      "Loss: -88.24497985839844\n",
      "Loss: -88.90170288085938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▎                              | 1913/10000 [00:52<03:14, 41.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.70360565185547\n",
      "Loss: -90.6800308227539\n",
      "Loss: -88.74101257324219\n",
      "Loss: -86.42756652832031\n",
      "Loss: -91.05107116699219\n",
      "Loss: -87.77230072021484\n",
      "Loss: -91.58457946777344\n",
      "Loss: -90.67813873291016\n",
      "Loss: -89.6866683959961\n",
      "Loss: -89.98847961425781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▎                              | 1923/10000 [00:52<03:07, 43.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.25262451171875\n",
      "Loss: -91.02732849121094\n",
      "Loss: -92.89682006835938\n",
      "Loss: -94.18692016601562\n",
      "Loss: -93.29058837890625\n",
      "Loss: -86.59849548339844\n",
      "Loss: -89.8926773071289\n",
      "Loss: -89.41146087646484\n",
      "Loss: -88.026611328125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 19%|███████▎                              | 1928/10000 [00:52<03:11, 42.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -90.23130798339844\n",
      "Loss: -89.50330352783203\n",
      "Loss: -87.84315490722656\n",
      "Loss: -91.99269104003906\n",
      "Loss: -92.47156524658203\n",
      "Loss: -94.18692016601562\n",
      "Loss: -94.15336608886719\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▎                              | 1937/10000 [00:53<03:30, 38.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -91.20065307617188\n",
      "Loss: -91.8631362915039\n",
      "Loss: -91.69688415527344\n",
      "Loss: -76.33302307128906\n",
      "Loss: -40.034385681152344\n",
      "Loss: -76.74136352539062\n",
      "Loss: -54.977264404296875\n",
      "Loss: -65.40486907958984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|███████▍                              | 1945/10000 [00:53<03:26, 39.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -68.92259216308594\n",
      "Loss: -64.27423095703125\n",
      "Loss: -61.63609313964844\n",
      "Loss: -63.330833435058594\n",
      "Loss: -69.27545928955078\n",
      "Loss: -73.68133544921875\n",
      "Loss: -76.10252380371094\n",
      "Loss: -77.71219635009766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▍                              | 1953/10000 [00:53<03:27, 38.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -81.30010986328125\n",
      "Loss: -87.5858154296875\n",
      "Loss: -83.89783477783203\n",
      "Loss: -89.44355773925781\n",
      "Loss: -84.35860443115234\n",
      "Loss: -86.18143463134766\n",
      "Loss: -86.08985900878906\n",
      "Loss: -83.74796295166016\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▍                              | 1961/10000 [00:53<03:46, 35.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -87.3750991821289\n",
      "Loss: -88.30548095703125\n",
      "Loss: -87.32070922851562\n",
      "Loss: -86.64073181152344\n",
      "Loss: -89.47869873046875\n",
      "Loss: -90.9331283569336\n",
      "Loss: -91.7841796875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▍                              | 1969/10000 [00:54<03:45, 35.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -91.76004791259766\n",
      "Loss: -86.9654541015625\n",
      "Loss: -90.64131164550781\n",
      "Loss: -93.70000457763672\n",
      "Loss: -88.83773803710938\n",
      "Loss: -90.8707275390625\n",
      "Loss: -90.32972717285156\n",
      "Loss: -84.85794067382812\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▌                              | 1977/10000 [00:54<03:37, 36.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -72.01190185546875\n",
      "Loss: -71.20819091796875\n",
      "Loss: -89.21771240234375\n",
      "Loss: -78.30621337890625\n",
      "Loss: -80.36463928222656\n",
      "Loss: -82.18777465820312\n",
      "Loss: -81.92655944824219\n",
      "Loss: -81.0585708618164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▌                              | 1985/10000 [00:54<03:32, 37.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.38850402832031\n",
      "Loss: -88.22537231445312\n",
      "Loss: -85.86373138427734\n",
      "Loss: -89.68213653564453\n",
      "Loss: -92.22107696533203\n",
      "Loss: -90.79457092285156\n",
      "Loss: -88.49164581298828\n",
      "Loss: -64.05403137207031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▌                              | 1994/10000 [00:54<03:24, 39.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -1.7378954887390137\n",
      "Loss: -50.771812438964844\n",
      "Loss: -41.070457458496094\n",
      "Loss: -57.181907653808594\n",
      "Loss: -58.1019172668457\n",
      "Loss: -52.87565612792969\n",
      "Loss: -47.131534576416016\n",
      "Loss: -43.894691467285156\n",
      "Loss: -42.07297897338867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▌                              | 2003/10000 [00:54<03:24, 39.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -41.482784271240234\n",
      "Loss: -43.93644714355469\n",
      "Loss: -46.065425872802734\n",
      "Loss: -50.582923889160156\n",
      "Loss: -54.85395812988281\n",
      "Loss: -60.378414154052734\n",
      "Loss: -66.92616271972656\n",
      "Loss: -71.91136169433594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2011/10000 [00:55<03:29, 38.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -73.44741821289062\n",
      "Loss: -77.95440673828125\n",
      "Loss: -82.62431335449219\n",
      "Loss: -78.71078491210938\n",
      "Loss: -83.59635925292969\n",
      "Loss: -76.37670135498047\n",
      "Loss: -82.86225891113281\n",
      "Loss: -84.71439361572266\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2021/10000 [00:55<03:18, 40.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -86.07063293457031\n",
      "Loss: -83.08717346191406\n",
      "Loss: -87.87187194824219\n",
      "Loss: -82.66740417480469\n",
      "Loss: -84.5085678100586\n",
      "Loss: -83.62347412109375\n",
      "Loss: -86.5811538696289\n",
      "Loss: -86.55117797851562\n",
      "Loss: -87.14856719970703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      " 20%|███████▋                              | 2026/10000 [00:55<03:16, 40.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -85.5025634765625\n",
      "Loss: -87.8707275390625\n",
      "Loss: -91.20401000976562\n",
      "Loss: -87.38163757324219\n",
      "Loss: -88.49962615966797\n",
      "Loss: -90.78417205810547\n",
      "Loss: -90.29499053955078\n",
      "Loss: -89.11942291259766\n",
      "Loss: -91.82395935058594\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████▋                              | 2032/10000 [00:55<03:38, 36.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: -89.83417510986328\n",
      "Loss: -90.13851928710938\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 23\u001b[0m\n\u001b[1;32m     21\u001b[0m T_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mlog_prob\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m     22\u001b[0m T_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 23\u001b[0m \u001b[43mT_opt_paired\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mT_loss\u001b[38;5;241m.\u001b[39mitem()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     26\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     27\u001b[0m \u001b[38;5;124;03mwith torch.no_grad():\u001b[39;00m\n\u001b[1;32m     28\u001b[0m \u001b[38;5;124;03m    if step % 1000 == 0:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     50\u001b[0m \u001b[38;5;124;03m        print(np.mean(fids2))\u001b[39;00m\n\u001b[1;32m     51\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    479\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    480\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    481\u001b[0m                 \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    482\u001b[0m             )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m    487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     87\u001b[0m     torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m     88\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     91\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/adam.py:226\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    214\u001b[0m     beta1, beta2 \u001b[38;5;241m=\u001b[39m group[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m    216\u001b[0m     has_complex \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_group(\n\u001b[1;32m    217\u001b[0m         group,\n\u001b[1;32m    218\u001b[0m         params_with_grad,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    223\u001b[0m         state_steps,\n\u001b[1;32m    224\u001b[0m     )\n\u001b[0;32m--> 226\u001b[0m     \u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    227\u001b[0m \u001b[43m        \u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    228\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    229\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    230\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    231\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    232\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    233\u001b[0m \u001b[43m        \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mamsgrad\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    234\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    235\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    236\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    237\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    238\u001b[0m \u001b[43m        \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    239\u001b[0m \u001b[43m        \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    240\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    241\u001b[0m \u001b[43m        \u001b[49m\u001b[43mforeach\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mforeach\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    242\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcapturable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    243\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdifferentiable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    244\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfused\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    245\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgrad_scale\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    246\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfound_inf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    247\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    249\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/optimizer.py:161\u001b[0m, in \u001b[0;36m_disable_dynamo_if_unsupported.<locals>.wrapper.<locals>.maybe_fallback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    159\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m disabled_func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    160\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 161\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/adam.py:766\u001b[0m, in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, has_complex, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m    763\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    764\u001b[0m     func \u001b[38;5;241m=\u001b[39m _single_tensor_adam\n\u001b[0;32m--> 766\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    767\u001b[0m \u001b[43m    \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    768\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    769\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    770\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    771\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    772\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    773\u001b[0m \u001b[43m    \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mamsgrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    774\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    775\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    776\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    777\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    778\u001b[0m \u001b[43m    \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    779\u001b[0m \u001b[43m    \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    780\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmaximize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    781\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcapturable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    782\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdifferentiable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    783\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrad_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    784\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfound_inf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    785\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/adam.py:587\u001b[0m, in \u001b[0;36m_multi_tensor_adam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, has_complex, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)\u001b[0m\n\u001b[1;32m    583\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    584\u001b[0m     bias_correction1 \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m    585\u001b[0m         \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta1 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m _get_value(step) \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m device_state_steps\n\u001b[1;32m    586\u001b[0m     ]\n\u001b[0;32m--> 587\u001b[0m     bias_correction2 \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m    588\u001b[0m         \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta2 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m _get_value(step) \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m device_state_steps\n\u001b[1;32m    589\u001b[0m     ]\n\u001b[1;32m    591\u001b[0m     step_size \u001b[38;5;241m=\u001b[39m _stack_if_compiling([(lr \u001b[38;5;241m/\u001b[39m bc) \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m bc \u001b[38;5;129;01min\u001b[39;00m bias_correction1])\n\u001b[1;32m    593\u001b[0m     bias_correction2_sqrt \u001b[38;5;241m=\u001b[39m [_dispatch_sqrt(bc) \u001b[38;5;28;01mfor\u001b[39;00m bc \u001b[38;5;129;01min\u001b[39;00m bias_correction2]  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/adam.py:588\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    583\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    584\u001b[0m     bias_correction1 \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m    585\u001b[0m         \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta1 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m _get_value(step) \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m device_state_steps\n\u001b[1;32m    586\u001b[0m     ]\n\u001b[1;32m    587\u001b[0m     bias_correction2 \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 588\u001b[0m         \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m beta2 \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[43m_get_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m device_state_steps\n\u001b[1;32m    589\u001b[0m     ]\n\u001b[1;32m    591\u001b[0m     step_size \u001b[38;5;241m=\u001b[39m _stack_if_compiling([(lr \u001b[38;5;241m/\u001b[39m bc) \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m bc \u001b[38;5;129;01min\u001b[39;00m bias_correction1])\n\u001b[1;32m    593\u001b[0m     bias_correction2_sqrt \u001b[38;5;241m=\u001b[39m [_dispatch_sqrt(bc) \u001b[38;5;28;01mfor\u001b[39;00m bc \u001b[38;5;129;01min\u001b[39;00m bias_correction2]  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/optim/optimizer.py:104\u001b[0m, in \u001b[0;36m_get_value\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m    102\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m    103\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 104\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(x, torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;28;01melse\u001b[39;00m x\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train T on paired samples only: every step minimises the mean negative\n",
    "# log-probability of a paired batch (Y given X) using the T_opt_paired optimiser.\n",
    "MAX_STEPS = 10000\n",
    "# Result accumulators; only fids / fids2 are referenced in this cell, and only\n",
    "# by the disabled evaluation block at the bottom.\n",
    "stats = []\n",
    "D_loss = []\n",
    "fids = []\n",
    "fids2 = []\n",
    "device = 'cuda'\n",
    "\n",
    "# Splitting\n",
    "L_PAIRED_SAMPLES = 90\n",
    "L_UNPAIRED_SAMPLES = 500\n",
    "# Unpaired split: first rows of X_orig, last rows of Y_orig (not read by the loop below).\n",
    "X, Y = X_orig[:L_UNPAIRED_SAMPLES], Y_orig[-L_UNPAIRED_SAMPLES:]\n",
    "# Paired training subset: the first L_PAIRED_SAMPLES pairs.\n",
    "X_pair, Y_pair = X_pair_orig[:L_PAIRED_SAMPLES], Y_pair_orig[:L_PAIRED_SAMPLES]\n",
    "# Last 100 pairs; only the disabled evaluation block below reads them.\n",
    "# NOTE(review): these overlap the training subset if X_pair_orig holds fewer than\n",
    "# L_PAIRED_SAMPLES + 100 pairs - confirm the dataset size.\n",
    "X_pair_test, Y_pair_test = X_pair_orig[-100:], Y_pair_orig[-100:]\n",
    "\n",
    "\n",
    "# Keep a handle on the progress bar so the loss is shown in its postfix instead of\n",
    "# printing one line per step (thousands of lines that also break the bar rendering).\n",
    "paired_pbar = tqdm(range(CONTINUE + 1, MAX_STEPS))\n",
    "for step in paired_pbar:\n",
    "        T_opt_paired.zero_grad()\n",
    "        X_paired, Y_paired = paired_sampler(X_pair, Y_pair, BATCH_SIZE)\n",
    "\n",
    "        # Negative mean log-probability of the targets given their paired context.\n",
    "        log_prob = T.log_prob(inputs=Y_paired, context=X_paired)\n",
    "        T_loss = -log_prob.mean()\n",
    "        T_loss.backward()\n",
    "        T_opt_paired.step()\n",
    "        # Passed as a string so tqdm does not shorten it to 3 significant digits;\n",
    "        # refresh=False lets the bar redraw on its own schedule.\n",
    "        paired_pbar.set_postfix(loss=f\"{T_loss.item():.6g}\", refresh=False)\n",
    "        \n",
    "        # Disabled evaluation, kept as a bare string literal (a no-op expression).\n",
    "        # Every 1000 steps it would, for each held-out pair, sample from T and compute\n",
    "        # a Frechet-style distance between the mean/covariance of the samples and of y.\n",
    "        # sqrtm is called through scipy.linalg because the notebook's import cell only\n",
    "        # has `import scipy` (no bare `linalg` name).\n",
    "        # NOTE(review): fids / fids2 are never reset, so the printed means would cover\n",
    "        # every evaluation so far, not just the current step.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            if step % 1000 == 0:\n",
    "                for x, y in zip(X_pair_test, Y_pair_test):\n",
    "                    x = torch.tensor(x).unsqueeze(0).to(device)\n",
    "                    y = torch.tensor(y).to(device)\n",
    "                    samples = []\n",
    "                    sample = T.sample(len(y), x.detach().float()).squeeze()\n",
    "                    fid_samples = np.array(sample.cpu()) #np.array(torch.cat(samples, dim=0).cpu())\n",
    "                    fid_samples_2 = np.array(y.cpu())\n",
    "\n",
    "                    mu1 = np.mean(fid_samples, axis=0)\n",
    "                    sigma1 = np.cov(fid_samples, rowvar=False)\n",
    "                    mu2 = np.mean(fid_samples_2, axis=0)\n",
    "                    sigma2 = np.cov(fid_samples_2, rowvar=False)\n",
    "\n",
    "                    diff = mu1 - mu2\n",
    "                    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "                    tr_covmean = np.trace(covmean)\n",
    "                    fid = (diff.dot(diff) + np.trace(sigma1) +  np.trace(sigma2) - 2 * tr_covmean)\n",
    "                    fids.append(fid.real)\n",
    "                    fids2.append(fid.real / np.var(fid_samples_2))\n",
    "                    \n",
    "                print(np.mean(fids))\n",
    "                print(np.mean(fids2))\n",
    "        \"\"\""
   ]
  },
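  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6e5e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-recorded values whose mean and std are computed in the next cell\n",
    "# (NOTE(review): presumably one value per run - confirm which metric): 1.27, 1.33, 1.28"
   ]
  },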
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f87f12bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.2933333333333332, 0.02624669291337273)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Values copied from the comment cell above; listed once so the mean and the\n",
    "# spread are always computed over the same numbers.\n",
    "recorded_values = [1.27, 1.33, 1.28]\n",
    "# np.std defaults to ddof=0 (population std); pass ddof=1 for the sample std.\n",
    "np.mean(recorded_values), np.std(recorded_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82bd73f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efad248f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
