{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Adam\n",
    "from tensorboardX import SummaryWriter\n",
    "import tqdm\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
    "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
    "Tensor = FloatTensor\n",
    "\n",
    "def weights_initialize(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))\n",
    "        module.bias.data.fill_(0.01)\n",
    "        \n",
    "class _TransModel(nn.Module):\n",
    "    \"\"\" Model for DQN \"\"\"\n",
    "\n",
    "    def __init__(self, input_len, output_len):\n",
    "        super(_TransModel, self).__init__()\n",
    "        \n",
    "        self.fc1 = nn.Sequential(\n",
    "            torch.nn.Linear(input_len, 512),\n",
    "            torch.nn.BatchNorm1d(512),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc1.apply(weights_initialize)\n",
    "        \n",
    "        self.fc2 = nn.Sequential(\n",
    "            torch.nn.Linear(512, 128),\n",
    "            torch.nn.BatchNorm1d(128),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc2.apply(weights_initialize)\n",
    "        \n",
    "        self.output_layer = nn.Sequential(\n",
    "            torch.nn.Linear(128, output_len)\n",
    "        )\n",
    "        self.output_layer.apply(weights_initialize)\n",
    "        \n",
    "    def forward(self, input):\n",
    "        x = self.fc1(input)\n",
    "        x = self.fc2(x)\n",
    "        \n",
    "        return self.output_layer(x)\n",
    "\n",
    "    \n",
    "class TransModel():\n",
    "    def __init__(self, input_len, ouput_len, learning_rate = 0.0001):\n",
    "        self.model = _TransModel(input_len, ouput_len)\n",
    "        \n",
    "        if use_cuda:\n",
    "            print(\"Using GPU\")\n",
    "            self.model = self.model.cuda()\n",
    "        else:\n",
    "            print(\"Using CPU\")\n",
    "        self.steps = 0\n",
    "        self.model = nn.DataParallel(self.model)\n",
    "        self.optimizer = Adam(self.model.parameters(), lr = learning_rate)\n",
    "        self.loss_fn = nn.MSELoss(reduction='mean')\n",
    "        \n",
    "        self.summary = SummaryWriter(log_dir = 'trans_summary/')\n",
    "        self.steps = 0\n",
    "        \n",
    "    def predict(self, input, steps, learning):\n",
    "        output = self.model(input).squeeze(1)\n",
    "        #reward, next_state = output[0], output[1:]\n",
    "\n",
    "        return output\n",
    "\n",
    "    def predict_batch(self, input):\n",
    "        output = self.model(input)\n",
    "        #reward, next_state = output[:, 0], output[:, 1:]\n",
    "        return output\n",
    "\n",
    "    def fit(self, state, target_state):\n",
    "        loss = self.loss_fn(state, target_state)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        self.steps += 1\n",
    "        self.summary.add_scalar(tag=\"loss/train_Loss\",\n",
    "                                scalar_value=float(loss),\n",
    "                                global_step=self.steps)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.load('random_v_random.pt')\n",
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "l = len(data)\n",
    "\n",
    "index = [4, 9]\n",
    "\n",
    "for i in range(0, len(data)):\n",
    "    data[i][1] = [ data[i][0][4] / 2000, data[i][0][9] / 2000 ]\n",
    "    data[i][0] = np.delete(data[i][0], index)\n",
    "\n",
    "print(data[0][0], data[0][1])\n",
    "\n",
    "np.random.shuffle(data)\n",
    "\n",
    "train_data = np.array(data[: int(np.floor(l * 0.5))])\n",
    "test_data = np.array(data[int(np.floor(l * 0.5)) : ])\n",
    "print(train_data.shape, test_data.shape)\n",
    "\n",
    "batch_size = 64\n",
    "summary_test = SummaryWriter(log_dir = 'test_summary/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_model = TransModel(len(data[0][0]), len(data[0][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, data, epoch):\n",
    "    state_action = torch.from_numpy(np.stack(data[:, 0])).type(FloatTensor)\n",
    "    next_state_reward = torch.from_numpy(np.stack(data[:, 1])).type(FloatTensor)\n",
    "    \n",
    "    model.model.eval()\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    outputs = model.predict_batch(state_action)\n",
    "    mse = criterion(outputs, next_state_reward)\n",
    "    mse_p1 = criterion(outputs[:, 0], next_state_reward[:, 0])\n",
    "    mse_p2 = criterion(outputs[:, 1], next_state_reward[:, 1])\n",
    "\n",
    "    accuracy = torch.sum( torch.sum( torch.eq( outputs, next_state_reward ) )).item()\n",
    "    accuracy = accuracy / (2 * outputs.size()[0])\n",
    "\n",
    "    model.model.train()\n",
    "    \n",
    "    summary_test.add_scalar(tag=\"Total HP MSE\",\n",
    "                            scalar_value=float(mse.item()),\n",
    "                            global_step=epoch)\n",
    "\n",
    "    summary_test.add_scalar(tag=\"Player 1 Nexus HP MSE\",\n",
    "                        scalar_value=float(mse_p1.item()),\n",
    "                        global_step=epoch)\n",
    "\n",
    "    summary_test.add_scalar(tag=\"Player 2 Nexus HP MSE\",\n",
    "                    scalar_value=float(mse_p2.item()),\n",
    "                    global_step=epoch)\n",
    "\n",
    "    summary_test.add_scalar(tag=\"Accuracy (Correct / Total)\",\n",
    "                            scalar_value=float(accuracy),\n",
    "                            global_step=epoch)\n",
    "    \n",
    "    f = open(\"test_loss.txt\", \"a+\")\n",
    "    f.write(\"loss:\" + str(mse.item()) + \", \")\n",
    "    f.write(\"acc:\" + str(accuracy) + \"\\n\")\n",
    "    if epoch % 1000 == 0:\n",
    "        f.write(\"output:\" + str(outputs[0:2]) + \"\\n\")\n",
    "        f.write(\"ground true:\" + str(next_state_reward[0:2]) + \"\\n\")\n",
    "    f.close()\n",
    "    return mse.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_action = torch.from_numpy(np.stack(train_data[:, 0])).type(FloatTensor)\n",
    "next_state_reward = torch.from_numpy(np.stack(train_data[:, 1])).type(FloatTensor)\n",
    "print(state_action.size(), next_state_reward.size())\n",
    "\n",
    "for epoch in tqdm.tqdm(range(10000)):\n",
    "    s = np.arange(state_action.shape[0])\n",
    "    np.random.shuffle(s)\n",
    "    train_x = state_action[s]\n",
    "    train_y = next_state_reward[s]\n",
    "    for i in range(state_action.shape[0] // batch_size + 1):\n",
    "        if (i + 1) * batch_size <= state_action.shape[0]:\n",
    "            start = i * batch_size\n",
    "            end = (i + 1) * batch_size\n",
    "        else:\n",
    "            start = i * batch_size\n",
    "            end = state_action.shape[0]\n",
    "        #print(start, end)\n",
    "        inputs, ground_true = train_x[start : end, :], train_y[start : end, :]\n",
    "        outputs = trans_model.predict_batch(inputs)\n",
    "        trans_model.fit(outputs, ground_true)\n",
    "#     print(epoch)\n",
    "    evaluation(trans_model, test_data, epoch)\n",
    "    #break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
