{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "import torch.nn as nn\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from torch.optim import Adam\n",
    "from tensorboardX import SummaryWriter\n",
    "import uuid\n",
    "import tqdm\n",
    "import os\n",
    "import collections\n",
    "\n",
    "# Choose the tensor constructors once: CUDA-backed classes when a GPU is\n",
    "# visible to torch, the CPU classes otherwise.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    FloatTensor = torch.cuda.FloatTensor\n",
    "    LongTensor = torch.cuda.LongTensor\n",
    "    IntTensor = torch.cuda.IntTensor\n",
    "    ByteTensor = torch.cuda.ByteTensor\n",
    "else:\n",
    "    FloatTensor = torch.FloatTensor\n",
    "    LongTensor = torch.LongTensor\n",
    "    IntTensor = torch.IntTensor\n",
    "    ByteTensor = torch.ByteTensor\n",
    "Tensor = FloatTensor  # alias for the default floating-point tensor type\n",
    "\n",
    "# Random identifier generated once per kernel session.\n",
    "unique_id = str(uuid.uuid4())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_data(dataset, val_pct):\n",
    "    \"\"\"Randomly partition sample indices into training and validation sets.\n",
    "\n",
    "    Args:\n",
    "        dataset: Either the number of samples (an int) or any sized\n",
    "            collection whose ``len()`` gives the number of samples.\n",
    "        val_pct: Fraction of samples to hold out for validation, in [0, 1].\n",
    "\n",
    "    Returns:\n",
    "        A ``(train_indices, val_indices)`` tuple of NumPy index arrays that\n",
    "        together contain every index in ``0 .. n-1`` exactly once.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``val_pct`` lies outside [0, 1].\n",
    "    \"\"\"\n",
    "    # A negative or >1 fraction would slice the permutation in surprising\n",
    "    # ways (e.g. negative n_val swaps the roles of the two slices).\n",
    "    if not 0 <= val_pct <= 1:\n",
    "        raise ValueError('val_pct must be in [0, 1], got {}'.format(val_pct))\n",
    "    # Accept a plain sample count as well as the collection itself.\n",
    "    if isinstance(dataset, (int, np.integer)):\n",
    "        n_samples = int(dataset)\n",
    "    else:\n",
    "        n_samples = len(dataset)\n",
    "    # Determine size of validation set\n",
    "    n_val = int(val_pct * n_samples)\n",
    "    # Create random permutation of 0 to n-1\n",
    "    idxs = np.random.permutation(n_samples)\n",
    "    # The first n_val shuffled indices form the validation set\n",
    "    return idxs[n_val:], idxs[:n_val]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.         0.         0.         0.03333333 1.         0.\n",
      " 0.         0.         0.03333333 1.         0.15       0.\n",
      " 0.         0.         0.         0.         0.        ] [1.0, 1.0]\n",
      "94050 10450\n",
      "[ 60248  98937  84965  13546  58398 100756  61149  77214  54019  48604]\n"
     ]
    }
   ],
   "source": [
    "# Load the recorded samples. Each entry is used below as a pair:\n",
    "# data[i][0] is the feature vector and data[i][1] the vector the targets are read from.\n",
    "# NOTE(review): torch.load unpickles the file -- only load files from a trusted source.\n",
    "data = torch.load('100000_random_v_random.pt')\n",
    "np.set_printoptions(suppress=True)\n",
    "l = len(data)\n",
    "\n",
    "for i in range(0, len(data)):\n",
    "    # Target: the two Nexus HP entries (indices 4 and 9), scaled by 2000.\n",
    "    data[i][1] = [data[i][1][4] / 2000 , data[i][1][9] / 2000 ]\n",
    "    \n",
    "    data[i][0][0:4] = np.true_divide( data[i][0][0:4], 30) # Normalize P1 buildings\n",
    "    data[i][0][5:9] = np.true_divide( data[i][0][5:9], 30) # Normalize P2 buildings\n",
    "    data[i][0][4] = data[i][0][4] / 2000 # Normalize P1 Nexus HP\n",
    "    data[i][0][9] = data[i][0][9] / 2000 # Normalize P2 Nexus HP\n",
    "    data[i][0][10] = data[i][0][10] / 1500 # Normalize P1 Minerals\n",
    "    data[i][0][11:] = np.true_divide(data[i][0][11:], 60) # Normalize both Player's units on the field\n",
    "\n",
    "print(data[0][0], data[0][1])\n",
    "\n",
    "# The (features, target) pairs have different lengths, so this is a ragged\n",
    "# nested sequence. dtype=object makes the 2-column object array explicit;\n",
    "# recent NumPy releases refuse to build it implicitly.\n",
    "np_data = np.array(data, dtype=object)\n",
    "\n",
    "my_x = np_data[:, 0]\n",
    "my_y = np_data[:, 1]\n",
    "tensor_x = torch.stack([torch.Tensor(i) for i in my_x])\n",
    "tensor_y = torch.stack([torch.Tensor(i) for i in my_y])\n",
    "\n",
    "tensor_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)\n",
    "\n",
    "# Hold out 10% of the sample indices for validation.\n",
    "train_indices, val_indices = split_data(len(data), val_pct=0.1)\n",
    "\n",
    "print(len(train_indices), len(val_indices))\n",
    "print(val_indices[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "# TensorBoard writer for the evaluation report.\n",
    "summary_test = SummaryWriter(log_dir='nexus-HP-transition-model-report/')\n",
    "\n",
    "# Both loaders read the same TensorDataset; each sampler restricts its loader\n",
    "# to one side of the train/validation index split.\n",
    "train_sampler = SubsetRandomSampler(train_indices)\n",
    "val_sampler = SubsetRandomSampler(val_indices)\n",
    "\n",
    "train_dl = DataLoader(tensor_dataset, batch_size=batch_size, sampler=train_sampler)\n",
    "valid_dl = DataLoader(tensor_dataset, batch_size=batch_size, sampler=val_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_initialize(module):\n",
    "    \"\"\"Initialise a linear layer in place; leave any other module untouched.\n",
    "\n",
    "    Meant to be passed to ``nn.Module.apply``. Weights get Xavier-uniform\n",
    "    initialisation scaled by the ReLU gain, and the bias (when the layer has\n",
    "    one) is filled with 0.01.\n",
    "\n",
    "    Args:\n",
    "        module: Any ``nn.Module``; only ``nn.Linear`` instances are modified.\n",
    "    \"\"\"\n",
    "    if not isinstance(module, nn.Linear):\n",
    "        return\n",
    "    nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))\n",
    "    # nn.Linear(..., bias=False) stores its bias as None, so guard before filling.\n",
    "    if module.bias is not None:\n",
    "        module.bias.data.fill_(0.01)\n",
    "        \n",
    "class TransModel(nn.Module):\n",
    "    \"\"\"Fully connected network: input_len -> 512 -> 128 -> output_len.\n",
    "\n",
    "    Each hidden linear layer is followed by ReLU and then BatchNorm1d; the\n",
    "    output layer is a plain linear layer with no activation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_len, output_len):\n",
    "        super().__init__()\n",
    "\n",
    "        # The layer names become part of the state_dict keys, so keep them stable.\n",
    "        layers = collections.OrderedDict()\n",
    "        layers['fc1'] = nn.Linear(input_len, 512)\n",
    "        layers['rl1'] = nn.ReLU()\n",
    "        layers['bn1'] = nn.BatchNorm1d(512)\n",
    "        layers['fc2'] = nn.Linear(512, 128)\n",
    "        layers['rl2'] = nn.ReLU()\n",
    "        layers['bn2'] = nn.BatchNorm1d(128)\n",
    "        layers['fc3'] = nn.Linear(128, output_len)\n",
    "        self.model = nn.Sequential(layers)\n",
    "\n",
    "        # Apply the notebook's weights_initialize helper to every submodule.\n",
    "        self.model.apply(weights_initialize)\n",
    "\n",
    "    def forward(self, xb):\n",
    "        \"\"\"Run the batch ``xb`` through the layer stack and return its output.\"\"\"\n",
    "        return self.model(xb)\n",
    "\n",
    "# class TransModel():\n",
    "# def __init__(self, input_len, ouput_len, learning_rate = 0.0001):\n",
    "#     self.model = _TransModel(input_len, ouput_len)\n",
    "\n",
    "#         if use_cuda:\n",
    "#             print(\"Using GPU\")\n",
    "#             self.model = self.model.cuda()\n",
    "#         else:\n",
    "#             print(\"Using CPU\")\n",
    "\n",
    "#     self.steps = 0\n",
    "#     # self.model = nn.DataParallel(self.model)\n",
    "\n",
    "\n",
    "# def predict(self, input, steps, learning):\n",
    "#     output = self.model(input).squeeze(1)\n",
    "#     #reward, next_state = output[0], output[1:]\n",
    "#     return output\n",
    "\n",
    "# def predict_batch(self, input):\n",
    "#     output = self.model(input)\n",
    "#     #reward, next_state = output[:, 0], output[:, 1:]\n",
    "#     return output\n",
    "\n",
    "# def fit(self, state, target_state):\n",
    "#     loss = self.loss_fn(state, target_state)\n",
    "\n",
    "#     self.optimizer.zero_grad()\n",
    "#     loss.backward()\n",
    "#     self.optimizer.step()\n",
    "#     self.steps += 1\n",
    "#     return loss\n",
    "\n",
    "def save(model):\n",
    "    \"\"\"Write the model's state_dict to ``<cwd>/models/NEXUS_<unique_id>.pt``.\n",
    "\n",
    "    The ``models`` directory is created if it does not exist yet.\n",
    "    ``unique_id`` is the module-level identifier, so calls made while it is\n",
    "    unchanged all write to the same path.\n",
    "    \"\"\"\n",
    "    models_dir = os.getcwd() + '/models'\n",
    "    if not os.path.exists(models_dir):\n",
    "        os.makedirs(models_dir, exist_ok=True)\n",
    "    checkpoint_path = ''.join([models_dir, '/NEXUS_', unique_id, '.pt'])\n",
    "    torch.save(model.state_dict(), checkpoint_path)\n",
    "\n",
    "#################################\n",
    "def loss_batch(model, loss_func, xb, yb, opt=None, metric=None):\n",
    "    \"\"\"Run one batch through the model and optionally take an optimizer step.\n",
    "\n",
    "    Args:\n",
    "        model: Callable mapping the input batch to predictions.\n",
    "        loss_func: Callable ``(preds, targets) -> loss tensor``.\n",
    "        xb: Input batch.\n",
    "        yb: Target batch.\n",
    "        opt: Optional optimizer; when given, gradients are computed, one\n",
    "            step is taken and the gradients are cleared again.\n",
    "        metric: Optional callable ``(preds, targets) -> value``.\n",
    "\n",
    "    Returns:\n",
    "        ``(loss_value, batch_size, metric_value)`` where ``metric_value`` is\n",
    "        None when no metric was supplied.\n",
    "    \"\"\"\n",
    "    predictions = model(xb)\n",
    "    batch_loss = loss_func(predictions, yb)\n",
    "\n",
    "    # Training mode only: backprop, update, then clear gradients for the next batch.\n",
    "    if opt is not None:\n",
    "        batch_loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "    metric_result = metric(predictions, yb) if metric is not None else None\n",
    "    return batch_loss.item(), len(xb), metric_result\n",
    "\n",
    "def evaluate(model, loss_fn, valid_dl, metric=None):\n",
    "    \"\"\"Compute the sample-weighted average loss (and metric) over a loader.\n",
    "\n",
    "    The model is put in eval mode and gradients are disabled for the pass.\n",
    "    Afterwards the model's previous train/eval mode is restored, even if a\n",
    "    batch raises, so evaluating an inference-mode model no longer flips it\n",
    "    into training mode.\n",
    "\n",
    "    Args:\n",
    "        model: ``nn.Module`` to evaluate.\n",
    "        loss_fn: Callable ``(preds, targets) -> loss tensor``.\n",
    "        valid_dl: Iterable yielding ``(xb, yb)`` batches.\n",
    "        metric: Optional callable ``(preds, targets) -> number``.\n",
    "\n",
    "    Returns:\n",
    "        ``(avg_loss, total, avg_metric)``: the averages are weighted by batch\n",
    "        size, ``total`` is the number of samples seen, and ``avg_metric`` is\n",
    "        None when no metric was supplied.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``valid_dl`` yields no batches.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # Pass each batch through the model\n",
    "            results = [loss_batch(model, loss_fn, xb, yb, metric=metric) for xb, yb in valid_dl]\n",
    "    finally:\n",
    "        # Put the model back in whatever mode the caller had it in.\n",
    "        model.train(was_training)\n",
    "\n",
    "    if not results:\n",
    "        raise ValueError('valid_dl produced no batches to evaluate')\n",
    "\n",
    "    # Separate losses, counts and metrics\n",
    "    losses, nums, metrics = zip(*results)\n",
    "    # Total size of the dataset\n",
    "    total = np.sum(nums)\n",
    "    # Weight each batch by its size so a short final batch is not over-counted.\n",
    "    avg_loss = np.sum(np.multiply(losses, nums)) / total\n",
    "    avg_metric = None\n",
    "    if metric is not None:\n",
    "        avg_metric = np.sum(np.multiply(metrics, nums)) / total\n",
    "    return avg_loss, total, avg_metric\n",
    "\n",
    "def fit(epochs, lr, model, loss_fn, train_dl, valid_dl, metric=None, opt_fn=None):\n",
    "    \"\"\"Train ``model`` for ``epochs`` epochs, validating after each one.\n",
    "\n",
    "    Args:\n",
    "        epochs: Number of passes over ``train_dl``.\n",
    "        lr: Learning rate. Only used when this function builds the optimizer\n",
    "            itself (``opt_fn`` is None or an optimizer factory).\n",
    "        model: ``nn.Module`` to train.\n",
    "        loss_fn: Callable ``(preds, targets) -> loss tensor``.\n",
    "        train_dl: Iterable of ``(xb, yb)`` training batches.\n",
    "        valid_dl: Iterable of ``(xb, yb)`` validation batches.\n",
    "        metric: Optional callable ``(preds, targets) -> number``; its\n",
    "            ``__name__`` is used in the progress line.\n",
    "        opt_fn: One of\n",
    "            * None -- plain ``torch.optim.SGD`` is used;\n",
    "            * an optimizer factory/class, called as\n",
    "              ``opt_fn(model.parameters(), lr=lr)``;\n",
    "            * an already constructed ``torch.optim.Optimizer``, used as is.\n",
    "\n",
    "    Returns:\n",
    "        ``(losses, metrics)``: per-epoch validation loss and metric values\n",
    "        (metric entries are None when no metric was supplied).\n",
    "    \"\"\"\n",
    "    losses, metrics = [], []\n",
    "\n",
    "    # Resolve opt_fn to an optimizer *instance*. Handing the class itself to\n",
    "    # loss_batch would fail on opt.step().\n",
    "    if opt_fn is None:\n",
    "        print('opt_fn is None')\n",
    "        opt_fn = torch.optim.SGD\n",
    "    if isinstance(opt_fn, torch.optim.Optimizer):\n",
    "        opt = opt_fn\n",
    "    else:\n",
    "        opt = opt_fn(model.parameters(), lr=lr)\n",
    "\n",
    "    for epoch in tqdm.tqdm(range(epochs)):\n",
    "        # Training\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_fn, xb, yb, opt)\n",
    "\n",
    "        # Evaluation\n",
    "        val_loss, total, val_metric = evaluate(model, loss_fn, valid_dl, metric)\n",
    "\n",
    "        # Record the loss & metric\n",
    "        losses.append(val_loss)\n",
    "        metrics.append(val_metric)\n",
    "\n",
    "        # Print progress; without a metric there is no value to format.\n",
    "        if metric is None:\n",
    "            print('Epoch [{}/{}], Loss: {:.4f}'\n",
    "                  .format(epoch+1, epochs, val_loss))\n",
    "        else:\n",
    "            print('Epoch [{}/{}], Loss: {:.4f}, {}: {:.4f}'\n",
    "                  .format(epoch+1, epochs, val_loss, metric.__name__, val_metric))\n",
    "\n",
    "        # Checkpoint every 1000 epochs (skipping epoch 0).\n",
    "        if epoch % 1000 == 0 and epoch != 0:\n",
    "            save(model)\n",
    "\n",
    "    return losses, metrics\n",
    "\n",
    "def accuracy(outputs, ground_true):\n",
    "    \"\"\"Fraction of predicted values that exactly equal their targets.\n",
    "\n",
    "    Args:\n",
    "        outputs: Tensor of predictions.\n",
    "        ground_true: Tensor of targets, broadcastable against ``outputs``.\n",
    "\n",
    "    Returns:\n",
    "        A float in [0, 1]: matching elements divided by compared elements,\n",
    "        or 0.0 when there is nothing to compare.\n",
    "    \"\"\"\n",
    "    # NOTE(review): exact equality on floating-point outputs is very strict;\n",
    "    # a tolerance-based comparison may be what is actually wanted.\n",
    "    matches = torch.eq(outputs, ground_true)\n",
    "    n_compared = matches.numel()\n",
    "    if n_compared == 0:\n",
    "        return 0.0\n",
    "    # Normalise by the element count rather than the row count, so that\n",
    "    # several output columns cannot push the result above 1.\n",
    "    return torch.sum(matches).item() / n_compared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.015275486650717703\n"
     ]
    }
   ],
   "source": [
    "# Baseline: compare the Nexus HP entries of each validation state (feature\n",
    "# indices 4 and 9) directly against the targets, i.e. predict 'no change'.\n",
    "test_data = [data[i] for i in val_indices]\n",
    "# The (features, target) pairs have different lengths; dtype=object makes the\n",
    "# 2-column object array explicit, which recent NumPy releases require.\n",
    "val_set = np.array(test_data, dtype=object)\n",
    "\n",
    "baseline = np.stack(val_set[:, 0])\n",
    "idx = [4, 9]\n",
    "baseline_hp = baseline[:, idx]\n",
    "\n",
    "bl_next_state_reward = np.stack(val_set[:, 1])\n",
    "\n",
    "mse_baseline = ((baseline_hp - bl_next_state_reward)**2).mean(axis=None)\n",
    "print(mse_baseline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the network from the first sample: input width is the length of its\n",
    "# feature vector, output width the length of its target vector.\n",
    "first_features, first_targets = data[0][0], data[0][1]\n",
    "trans_model = TransModel(len(first_features), len(first_targets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_default_device():\n",
    "    \"\"\"Return ``torch.device('cuda')`` when CUDA is available, else ``torch.device('cpu')``.\"\"\"\n",
    "    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    return torch.device(device_name)\n",
    "\n",
    "def to_device(data, device):\n",
    "    \"\"\"Move a tensor, or a nested list/tuple of tensors, to ``device``.\n",
    "\n",
    "    Lists and tuples are processed recursively and always come back as\n",
    "    lists; anything else is moved with ``.to(device, non_blocking=True)``.\n",
    "    \"\"\"\n",
    "    if not isinstance(data, (list, tuple)):\n",
    "        return data.to(device, non_blocking=True)\n",
    "    moved = []\n",
    "    for item in data:\n",
    "        moved.append(to_device(item, device))\n",
    "    return moved\n",
    "\n",
    "class DeviceDataLoader():\n",
    "    \"\"\"Iterable wrapper that moves every batch of ``dl`` to ``device``.\"\"\"\n",
    "\n",
    "    def __init__(self, dl, device):\n",
    "        self.dl = dl\n",
    "        self.device = device\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Lazily yield each batch of the wrapped loader after moving it.\"\"\"\n",
    "        yield from (to_device(batch, self.device) for batch in self.dl)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Length of the wrapped loader.\"\"\"\n",
    "        return len(self.dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_default_device()\n",
    "\n",
    "# Wrap both loaders so the batches they yield are moved to `device`.\n",
    "train_dl, valid_dl = [DeviceDataLoader(loader, device) for loader in (train_dl, valid_dl)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, data, epoch):\n",
    "    \"\"\"Score ``model`` on held-out samples and log the results.\n",
    "\n",
    "    Writes the overall MSE, the per-output MSEs and the module-level\n",
    "    ``mse_baseline`` to the ``summary_test`` writer, and appends a line to\n",
    "    ``nexus-HP-transition-model-report/test_loss.txt``.\n",
    "\n",
    "    Args:\n",
    "        model: Either an ``nn.Module`` (called directly) or a wrapper object\n",
    "            exposing the network as ``.model`` and inference as\n",
    "            ``.predict_batch``.\n",
    "        data: Sequence of ``(features, targets)`` pairs -- a plain list or a\n",
    "            2-column object array both work.\n",
    "        epoch: Step index used for logging; every 1000th step also dumps the\n",
    "            first two predictions and targets to the text log.\n",
    "\n",
    "    Returns:\n",
    "        The overall MSE as a Python float.\n",
    "    \"\"\"\n",
    "    state_action = torch.from_numpy(np.stack([sample[0] for sample in data])).type(FloatTensor)\n",
    "    next_state_reward = torch.from_numpy(np.stack([sample[1] for sample in data])).type(FloatTensor)\n",
    "\n",
    "    # TransModel is itself the nn.Module and has no predict_batch; fall back\n",
    "    # to calling the module directly in that case.\n",
    "    net = model if isinstance(model, nn.Module) else model.model\n",
    "    predict = getattr(model, 'predict_batch', model)\n",
    "\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        # No gradients are needed for scoring.\n",
    "        with torch.no_grad():\n",
    "            outputs = predict(state_action)\n",
    "    finally:\n",
    "        # Restore whatever mode the caller had the network in.\n",
    "        net.train(was_training)\n",
    "\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    mse = criterion(outputs, next_state_reward)\n",
    "    mse_p1 = criterion(outputs[:, 0], next_state_reward[:, 0])\n",
    "    mse_p2 = criterion(outputs[:, 1], next_state_reward[:, 1])\n",
    "\n",
    "    # Fraction of output values that match their target exactly.\n",
    "    exact_matches = torch.sum(torch.eq(outputs, next_state_reward)).item()\n",
    "    exact_match_rate = exact_matches / outputs.numel()\n",
    "\n",
    "    summary_test.add_scalar(\"MSE\", float(mse.item()), epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Player 1 Nexus HP MSE': float(mse_p1.item())}, epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Player 2 Nexus HP  MSE': float(mse_p2.item())}, epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Baseline Nexus HP MSE': float(mse_baseline)}, epoch)\n",
    "\n",
    "    # The context manager closes the log even if a write fails.\n",
    "    with open(\"nexus-HP-transition-model-report/test_loss.txt\", \"a+\") as f:\n",
    "        f.write(\"loss:\" + str(mse.item()) + \", \")\n",
    "        f.write(\"acc:\" + str(exact_match_rate) + \"\\n\")\n",
    "        if epoch % 1000 == 0:\n",
    "            f.write(\"output:\" + str(outputs[0:2]) + \"\\n\")\n",
    "            f.write(\"ground true:\" + str(next_state_reward[0:2]) + \"\\n\")\n",
    "    return mse.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/10000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GPU\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 1/10000 [00:03<10:05:55,  3.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10000], Loss: 0.0520, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 2/10000 [00:06<9:43:03,  3.50s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [2/10000], Loss: 0.0218, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 3/10000 [00:10<9:27:46,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [3/10000], Loss: 0.0133, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 4/10000 [00:13<9:16:27,  3.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [4/10000], Loss: 0.0105, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 5/10000 [00:16<9:08:36,  3.29s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/10000], Loss: 0.0083, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 6/10000 [00:19<8:58:19,  3.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [6/10000], Loss: 0.0070, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 7/10000 [00:22<8:50:58,  3.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [7/10000], Loss: 0.0056, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 8/10000 [00:25<8:44:21,  3.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [8/10000], Loss: 0.0048, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 9/10000 [00:28<8:40:00,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [9/10000], Loss: 0.0042, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 10/10000 [00:31<8:38:04,  3.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [10/10000], Loss: 0.0037, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 11/10000 [00:34<8:39:45,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [11/10000], Loss: 0.0034, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 12/10000 [00:38<8:38:44,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [12/10000], Loss: 0.0031, accuracy: 0.0001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 13/10000 [00:41<8:39:07,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [13/10000], Loss: 0.0026, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 14/10000 [00:44<8:38:09,  3.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [14/10000], Loss: 0.0026, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 15/10000 [00:47<8:38:59,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [15/10000], Loss: 0.0023, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 16/10000 [00:50<8:39:22,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [16/10000], Loss: 0.0022, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 17/10000 [00:53<8:38:20,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [17/10000], Loss: 0.0021, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 18/10000 [00:56<8:37:53,  3.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [18/10000], Loss: 0.0022, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 19/10000 [00:59<8:35:44,  3.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [19/10000], Loss: 0.0021, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 20/10000 [01:02<8:34:21,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [20/10000], Loss: 0.0018, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 21/10000 [01:05<8:33:12,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [21/10000], Loss: 0.0018, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 22/10000 [01:09<8:33:28,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [22/10000], Loss: 0.0018, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 23/10000 [01:12<8:33:43,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [23/10000], Loss: 0.0016, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 24/10000 [01:15<8:33:54,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [24/10000], Loss: 0.0018, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 25/10000 [01:18<8:32:44,  3.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [25/10000], Loss: 0.0016, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 26/10000 [01:21<8:41:59,  3.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [26/10000], Loss: 0.0016, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 27/10000 [01:24<8:39:58,  3.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [27/10000], Loss: 0.0015, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 28/10000 [01:27<8:45:31,  3.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [28/10000], Loss: 0.0014, accuracy: 0.0001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 29/10000 [01:30<8:40:48,  3.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [29/10000], Loss: 0.0014, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 30/10000 [01:34<8:37:28,  3.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [30/10000], Loss: 0.0016, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 31/10000 [01:37<8:45:45,  3.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [31/10000], Loss: 0.0014, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 32/10000 [01:40<8:41:23,  3.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [32/10000], Loss: 0.0014, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 33/10000 [01:43<8:38:17,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [33/10000], Loss: 0.0013, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 34/10000 [01:46<8:38:24,  3.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [34/10000], Loss: 0.0013, accuracy: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 35/10000 [01:49<8:35:52,  3.11s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [35/10000], Loss: 0.0013, accuracy: 0.0001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 36/10000 [01:52<8:34:17,  3.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [36/10000], Loss: 0.0013, accuracy: 0.0000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-d8612286e775>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mloss_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMSELoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mlosses2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.0001\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrans_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;31m# for epoch in tqdm.tqdm(range(10000)):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-5-49981c9ededd>\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, lr, model, loss_fn, train_dl, valid_dl, metric, opt_fn)\u001b[0m\n\u001b[1;32m    119\u001b[0m         \u001b[0;31m# Training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    120\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m         \u001b[0;31m# Evaluation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-5-49981c9ededd>\u001b[0m in \u001b[0;36mloss_batch\u001b[0;34m(model, loss_func, xb, yb, opt, metric)\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m         \u001b[0;31m# Compute gradients\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     79\u001b[0m         \u001b[0;31m# Update parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     80\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/scratch/eecs-share/newmanev/miniconda/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    105\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m         \"\"\"\n\u001b[0;32m--> 107\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/scratch/eecs-share/newmanev/miniconda/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     91\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     92\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# state_action = torch.from_numpy(np.stack(train_data[:, 0])).type(FloatTensor)\n",
    "# next_state_reward = torch.from_numpy(np.stack(train_data[:, 1])).type(FloatTensor)\n",
    "# print(state_action.size(), next_state_reward.size())\n",
    "\n",
    "# Report the compute device, then place the model on it before the optimizer\n",
    "# is built from the model's parameters.\n",
    "print(\"Using GPU\" if use_cuda else \"Using CPU\")\n",
    "trans_model = to_device(trans_model, device)\n",
    "\n",
    "# trans_model = nn.DataParallel(trans_model, device_ids=[0,1,2])\n",
    "\n",
    "loss_fn = nn.MSELoss(reduction='mean')\n",
    "optimizer = Adam(trans_model.parameters(), lr = 0.0001)\n",
    "\n",
    "losses2, metrics2 = fit(10000, 0.0001, trans_model, loss_fn, train_dl, valid_dl, accuracy, optimizer)\n",
    "\n",
    "# for epoch in tqdm.tqdm(range(10000)):   \n",
    "#     loss = 0\n",
    "#     # s = np.arange(state_action.shape[0])\n",
    "#     # np.random.shuffle(s)\n",
    "#     # train_x = state_action[s]\n",
    "#     # train_y = next_state_reward[s]\n",
    "# #     for i in range(state_action.shape[0] // batch_size + 1):\n",
    "# #         if (i + 1) * batch_size <= state_action.shape[0]:\n",
    "# #             start = i * batch_size\n",
    "# #             end = (i + 1) * batch_size\n",
    "# #         else:\n",
    "# #             start = i * batch_size\n",
    "# #             end = state_action.shape[0]\n",
    "# #         #print(start, end)\n",
    "# #         inputs, ground_true = train_x[start : end, :], train_y[start : end, :]\n",
    "#     for inputs, ground_true in train_dl:\n",
    "#         print(ground_true)\n",
    "#         outputs = trans_model.predict_batch(inputs)\n",
    "#         loss += trans_model.fit(outputs, ground_true)\n",
    "# #     print(epoch)\n",
    "#     summary_test.add_scalars(\"MSE\",{'Train MSE': float(loss / (state_action.shape[0] // batch_size + 1) )}, epoch)\n",
    "#     evaluation(trans_model, test_data, epoch)\n",
    "#     #break\n",
    "#     if epoch % 1000 == 0 and epoch != 0:\n",
    "#         trans_model.save()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
