{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Adam\n",
    "from tensorboardX import SummaryWriter\n",
    "import tqdm\n",
    "import os\n",
    "import uuid\n",
    "import random\n",
    "\n",
    "# Pick the tensor constructors once, depending on whether CUDA is available.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "_tensor_types = torch.cuda if use_cuda else torch\n",
    "FloatTensor = _tensor_types.FloatTensor\n",
    "LongTensor = _tensor_types.LongTensor\n",
    "IntTensor = _tensor_types.IntTensor\n",
    "ByteTensor = _tensor_types.ByteTensor\n",
    "Tensor = FloatTensor\n",
    "\n",
    "# Random identifier embedded in the checkpoint file name written by TransModel.save().\n",
    "unique_id = str(uuid.uuid4())\n",
    "\n",
    "def weights_initialize(module):\n",
    "    \"\"\"Initialise an ``nn.Linear`` layer in place; any other module is left untouched.\n",
    "\n",
    "    The weight matrix gets Xavier-uniform values scaled by the ReLU gain and the\n",
    "    bias, when the layer has one, is set to 0.01.  The signature matches what\n",
    "    ``nn.Module.apply`` expects, which is how the model classes call it.\n",
    "    \"\"\"\n",
    "    # Exact type match, as before: subclasses of nn.Linear are not touched.\n",
    "    if type(module) is not nn.Linear:\n",
    "        return\n",
    "    nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))\n",
    "    # A layer built with bias=False has ``bias`` set to None; skip it instead of crashing.\n",
    "    if module.bias is not None:\n",
    "        nn.init.constant_(module.bias, 0.01)\n",
    "        \n",
    "class _TransModel(nn.Module):\n",
    "    \"\"\"Fully connected network with two hidden stages and a linear output layer.\n",
    "\n",
    "    Each hidden stage is Linear -> BatchNorm1d -> ReLU (512 units, then 128 units);\n",
    "    the output layer maps the 128 features to ``output_len`` values.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _hidden_stage(n_in, n_out):\n",
    "        \"\"\"Build one Linear -> BatchNorm1d -> ReLU stage and initialise it.\"\"\"\n",
    "        stage = nn.Sequential(nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU())\n",
    "        stage.apply(weights_initialize)\n",
    "        return stage\n",
    "\n",
    "    def __init__(self, input_len, output_len):\n",
    "        super().__init__()\n",
    "\n",
    "        # Stages are created and initialised one after another, in this order.\n",
    "        self.fc1 = self._hidden_stage(input_len, 512)\n",
    "        self.fc2 = self._hidden_stage(512, 128)\n",
    "\n",
    "        self.output_layer = nn.Sequential(nn.Linear(128, output_len))\n",
    "        self.output_layer.apply(weights_initialize)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Pass ``input`` through both hidden stages and return the output layer's result.\"\"\"\n",
    "        hidden = self.fc2(self.fc1(input))\n",
    "        return self.output_layer(hidden)\n",
    "\n",
    "class TransModel():\n",
    "    \"\"\"Training wrapper that bundles a ``_TransModel`` with an Adam optimizer and an MSE loss.\"\"\"\n",
    "\n",
    "    def __init__(self, input_len, ouput_len, learning_rate = 0.0001):\n",
    "        \"\"\"Build the network, move it to the GPU when ``use_cuda`` is set, and create optimizer and loss.\n",
    "\n",
    "        ``ouput_len`` keeps its historical spelling so existing keyword callers keep working.\n",
    "        \"\"\"\n",
    "        self.model = _TransModel(input_len, ouput_len)\n",
    "\n",
    "        if use_cuda:\n",
    "            print(\"Using GPU\")\n",
    "            self.model = self.model.cuda()\n",
    "        else:\n",
    "            print(\"Using CPU\")\n",
    "\n",
    "        self.optimizer = Adam(self.model.parameters(), lr=learning_rate)\n",
    "        self.loss_fn = nn.MSELoss(reduction='mean')\n",
    "        # Number of optimizer updates performed through fit().\n",
    "        self.steps = 0\n",
    "\n",
    "    def predict(self, input, steps, learning):\n",
    "        \"\"\"Return the network output for ``input`` with dimension 1 squeezed.\n",
    "\n",
    "        ``steps`` and ``learning`` are accepted for call compatibility but are not used.\n",
    "        \"\"\"\n",
    "        return self.model(input).squeeze(1)\n",
    "\n",
    "    def predict_batch(self, input):\n",
    "        \"\"\"Return the raw network output for a batch ``input``.\"\"\"\n",
    "        return self.model(input)\n",
    "\n",
    "    def fit(self, state, target_state):\n",
    "        \"\"\"Run one optimizer step on the MSE between ``state`` and ``target_state``.\n",
    "\n",
    "        ``state`` must be a network output that still carries its autograd graph.\n",
    "        Returns the loss as a detached scalar tensor, so callers can accumulate it\n",
    "        without keeping autograd nodes alive.\n",
    "        \"\"\"\n",
    "        loss = self.loss_fn(state, target_state)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        self.steps += 1\n",
    "        return loss.detach()\n",
    "\n",
    "    def save(self):\n",
    "        \"\"\"Write the network's state_dict to ``<cwd>/models/NEXUS_<unique_id>.pt``.\"\"\"\n",
    "        models_dir = os.path.join(os.getcwd(), 'models')\n",
    "        # exist_ok=True already covers the case where the directory is present.\n",
    "        os.makedirs(models_dir, exist_ok=True)\n",
    "        file_path = os.path.join(models_dir, 'NEXUS_' + unique_id + '.pt')\n",
    "        torch.save(self.model.state_dict(), file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17\n",
      "17\n",
      "18\n",
      "18\n"
     ]
    }
   ],
   "source": [
    "# Column indices into a state vector (data[i][0]).\n",
    "# Per the normalisation step, columns 0-3 / 5-8 hold the P1 / P2 building counts\n",
    "# and columns 4 / 9 hold the P1 / P2 Nexus HP.\n",
    "a_b_m, a_b_v, a_b_c, a_b_p, a_nex = 0, 1, 2, 3, 4\n",
    "e_b_m, e_b_v, e_b_c, e_b_p, e_nex = 5, 6, 7, 8, 9\n",
    "# Column 10 is the P1 mineral count; columns 11-16 are both players' units on the field.\n",
    "a_mnrl = 10\n",
    "a_u_m, a_u_v, a_u_c = 11, 12, 13\n",
    "e_u_m, e_u_v, e_u_c = 14, 15, 16\n",
    "a_rwd = 17 # only in second column of data so data[i][1], NOT data[i][0]\n",
    "\n",
    "def find_game_starts(samples):\n",
    "    \"\"\"Return the indices in ``samples`` at which a new game begins.\n",
    "\n",
    "    Index 0 is always reported.  A later index ``i`` is reported when the summed\n",
    "    building counts of both players in ``samples[i][0]`` are lower than in\n",
    "    ``samples[i - 1][0]`` (presumably buildings only accumulate within one game,\n",
    "    so a drop marks a reset -- confirm against the data generator).\n",
    "    \"\"\"\n",
    "    building_cols = (a_b_m, a_b_v, a_b_c, a_b_p, e_b_m, e_b_v, e_b_c, e_b_p)\n",
    "\n",
    "    starts = [0]\n",
    "    previous_total = None\n",
    "    for i, sample in enumerate(samples):\n",
    "        total = sum(sample[0][col] for col in building_cols)\n",
    "        # The first sample has no predecessor inside its own dataset to compare with.\n",
    "        if previous_total is not None and total < previous_total:\n",
    "            starts.append(i)\n",
    "        previous_total = total\n",
    "    return starts\n",
    "\n",
    "\n",
    "# NOTE(review): torch.load unpickles the file; only load datasets from a trusted source.\n",
    "data_1 = torch.load('100000_random_v_random.pt')\n",
    "\n",
    "data_2 = torch.load('60000_sadq_v_random.pt')\n",
    "print(len(data_1[0][0]))\n",
    "print(len(data_2[0][0]))\n",
    "print(len(data_1[0][1]))\n",
    "print(len(data_2[0][1]))\n",
    "\n",
    "# Boundaries are computed per dataset, with indices relative to that dataset.\n",
    "# (They used to be read from ``data``, which is only created on the last line below.)\n",
    "data_1_games = find_game_starts(data_1)\n",
    "data_2_games = find_game_starts(data_2)\n",
    "\n",
    "\n",
    "data = data_1 + data_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.         0.         0.         0.03333333 1.         0.\n",
      " 0.         0.         0.03333333 1.         0.15       0.\n",
      " 0.         0.         0.         0.         0.        ] [1.0, 1.0]\n"
     ]
    }
   ],
   "source": [
    "np.set_printoptions(suppress=True)\n",
    "l = len(data)\n",
    "\n",
    "# Rescale every sample in place.  Running this cell twice fails on purpose-less\n",
    "# input: after the first pass sample[1] only holds two values, so index 4 is gone.\n",
    "for sample in data:\n",
    "    state, next_state = sample[0], sample[1]\n",
    "\n",
    "    # The target keeps only the two Nexus HP entries of the second column.\n",
    "    sample[1] = [next_state[4] / 2000 , next_state[9] / 2000 ]\n",
    "\n",
    "    state[0:4] = np.true_divide(state[0:4], 30) # Normalize P1 buildings\n",
    "    state[5:9] = np.true_divide(state[5:9], 30) # Normalize P2 buildings\n",
    "    state[4] = state[4] / 2000 # Normalize P1 Nexus HP\n",
    "    state[9] = state[9] / 2000 # Normalize P2 Nexus HP\n",
    "    state[10] = state[10] / 1500 # Normalize P1 Minerals\n",
    "    state[11:] = np.true_divide(state[11:], 60) # Normalize both Player's units on the field\n",
    "\n",
    "\n",
    "print(data[0][0], data[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "167900\n",
      "(151110, 2) (16790, 2)\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "\n",
    "\n",
    "def split_games(samples, game_starts, game_order, offset, n_samples, test_quota):\n",
    "    \"\"\"Distribute the samples of each game over a test list and a train list.\n",
    "\n",
    "    Games are visited in ``game_order`` (indices into ``game_starts``).  Game ``g``\n",
    "    covers dataset rows ``game_starts[g]`` up to the next game start, or up to\n",
    "    ``n_samples`` for the last game; ``offset`` shifts those rows into ``samples``.\n",
    "    Samples are deep-copied into the test list until it holds ``test_quota``\n",
    "    entries and into the train list afterwards.  Returns ``(test, train)``.\n",
    "\n",
    "    NOTE(review): the quota is checked per sample, so the game that crosses the\n",
    "    quota ends up partly in test and partly in train -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    test_part, train_part = [], []\n",
    "    for g in game_order:\n",
    "        first = game_starts[g]\n",
    "        last = game_starts[g + 1] if g + 1 < len(game_starts) else n_samples\n",
    "        for row in range(first, last):\n",
    "            target = test_part if len(test_part) < test_quota else train_part\n",
    "            target.append(copy.deepcopy(samples[offset + row]))\n",
    "    return test_part, train_part\n",
    "\n",
    "\n",
    "print(len(data))\n",
    "oneTenthData1 = len(data_1) // 10\n",
    "oneTenthData2 = len(data_2) // 10\n",
    "\n",
    "# Random visiting order of the games of each dataset.\n",
    "idx_1 = random.sample(range(len(data_1_games)), len(data_1_games))\n",
    "idx_2 = random.sample(range(len(data_2_games)), len(data_2_games))\n",
    "\n",
    "# Rows of data_2 sit behind the len(data_1) rows of data_1 inside ``data``.\n",
    "test_data_1, train_data_1 = split_games(data, data_1_games, idx_1, 0, len(data_1), oneTenthData1)\n",
    "test_data_2, train_data_2 = split_games(data, data_2_games, idx_2, len(data_1), len(data_2), oneTenthData2)\n",
    "\n",
    "test_data = test_data_1 + test_data_2\n",
    "\n",
    "# Each row pairs a state vector with a 2-element target, i.e. the rows are ragged.\n",
    "# dtype=object keeps the (N, 2) layout; recent NumPy refuses ragged input without it.\n",
    "train_data = np.array(train_data_1 + train_data_2, dtype=object)\n",
    "test_data = np.array(test_data, dtype=object)\n",
    "\n",
    "np.random.shuffle(train_data)\n",
    "np.random.shuffle(test_data)\n",
    "\n",
    "print(train_data.shape, test_data.shape)\n",
    "\n",
    "batch_size = 64\n",
    "summary_test = SummaryWriter(log_dir = 'nexus-HP-transition-model-report/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: score the current Nexus HP columns of each test state against the targets.\n",
    "idx = [4, 9]\n",
    "baseline = np.stack(test_data[:, 0])\n",
    "baseline_hp = baseline[:, idx]\n",
    "\n",
    "bl_next_state_reward = np.stack(test_data[:, 1])\n",
    "\n",
    "squared_error = (baseline_hp - bl_next_state_reward) ** 2\n",
    "mse_baseline = squared_error.mean()\n",
    "print(mse_baseline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, data, epoch):\n",
    "    \"\"\"Return the mean-squared error of ``model`` on ``data`` as a Python float.\n",
    "\n",
    "    ``data`` is read as ``data[:, 0]`` (inputs) and ``data[:, 1]`` (targets); each\n",
    "    column is stacked and converted to ``FloatTensor``.  ``epoch`` is accepted for\n",
    "    call compatibility but is not used.  The network runs in eval mode for the\n",
    "    forward pass and is put back into its previous mode afterwards.\n",
    "    \"\"\"\n",
    "    state_action = torch.from_numpy(np.stack(data[:, 0])).type(FloatTensor)\n",
    "    next_state_reward = torch.from_numpy(np.stack(data[:, 1])).type(FloatTensor)\n",
    "\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    was_training = model.model.training\n",
    "    model.model.eval()\n",
    "    try:\n",
    "        # Scoring needs no gradients; skipping them avoids building a graph over the whole set.\n",
    "        with torch.no_grad():\n",
    "            outputs = model.predict_batch(state_action)\n",
    "            mse = criterion(outputs, next_state_reward)\n",
    "    finally:\n",
    "        model.model.train(was_training)\n",
    "\n",
    "    return mse.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check of a previously saved checkpoint (evaluation() is defined above).\n",
    "CHECK_MODEL_PATH = './models/NEXUS_7a4d0dff-6035-49c9-bf82-afa77d8cf1eb.pt'\n",
    "\n",
    "if os.path.exists(CHECK_MODEL_PATH):\n",
    "    check_model = TransModel(len(data[0][0]), len(data[0][1]))\n",
    "    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.\n",
    "    check_model.model.load_state_dict(\n",
    "        torch.load(CHECK_MODEL_PATH, map_location=None if use_cuda else 'cpu'))\n",
    "    print(evaluation(check_model, test_data, 0))\n",
    "else:\n",
    "    print('Checkpoint not found, skipping check:', CHECK_MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_model = TransModel(len(data[0][0]), len(data[0][1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_action = torch.from_numpy(np.stack(train_data[:, 0])).type(FloatTensor)\n",
    "next_state_reward = torch.from_numpy(np.stack(train_data[:, 1])).type(FloatTensor)\n",
    "print(state_action.size(), next_state_reward.size())\n",
    "\n",
    "NUM_EPOCHS = 10000\n",
    "SAVE_EVERY = 1000\n",
    "num_samples = state_action.shape[0]\n",
    "\n",
    "for epoch in tqdm.tqdm(range(NUM_EPOCHS)):\n",
    "    epoch_loss = 0.0\n",
    "    num_batches = 0\n",
    "\n",
    "    # Fresh sample order for every epoch.\n",
    "    s = np.arange(num_samples)\n",
    "    np.random.shuffle(s)\n",
    "    train_x = state_action[s]\n",
    "    train_y = next_state_reward[s]\n",
    "\n",
    "    # Stepping by batch_size never yields an empty batch, even when num_samples\n",
    "    # is an exact multiple of batch_size.\n",
    "    for start in range(0, num_samples, batch_size):\n",
    "        inputs = train_x[start : start + batch_size, :]\n",
    "        ground_true = train_y[start : start + batch_size, :]\n",
    "        # BatchNorm1d cannot compute batch statistics from one sample in training mode.\n",
    "        if inputs.shape[0] < 2:\n",
    "            continue\n",
    "        outputs = trans_model.predict_batch(inputs)\n",
    "        # Accumulate a plain float so no tensors are chained across batches.\n",
    "        epoch_loss += float(trans_model.fit(outputs, ground_true))\n",
    "        num_batches += 1\n",
    "\n",
    "    summary_test.add_scalars(\"MSE\",{'Train MSE': epoch_loss / max(num_batches, 1)}, epoch)\n",
    "    # Log the held-out error next to the training error instead of discarding it.\n",
    "    test_mse = evaluation(trans_model, test_data, epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Test MSE': float(test_mse)}, epoch)\n",
    "\n",
    "    if epoch % SAVE_EVERY == 0 and epoch != 0:\n",
    "        print(epoch)\n",
    "        trans_model.save()\n",
    "\n",
    "# The last periodic checkpoint is taken before the final epochs; store the end state too.\n",
    "trans_model.save()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
