{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from torch.optim import Adam\n",
    "from tensorboardX import SummaryWriter\n",
    "import tqdm\n",
    "import os\n",
    "import uuid\n",
    "import collections\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
    "IntTensor = torch.cuda.IntTensor if use_cuda else torch.IntTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
    "Tensor = FloatTensor\n",
    "\n",
    "unique_id = str(uuid.uuid4())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_initialize(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))\n",
    "        module.bias.data.fill_(0.01)\n",
    "        \n",
    "class _TransModel(nn.Module):\n",
    "    \"\"\"Fully connected network used as the transition model (originally labelled 'Model for DQN').\n",
    "\n",
    "    Layer stack (names as registered in nn.Sequential): fc1 Linear(input_len, 512),\n",
    "    rl1 ReLU, bn1 BatchNorm1d(512), fc2 Linear(512, 128), rl2 ReLU,\n",
    "    bn2 BatchNorm1d(128), fc3 Linear(128, output_len). Linear layers are\n",
    "    re-initialised with `weights_initialize`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_len, output_len):\n",
    "        super(_TransModel, self).__init__()\n",
    "\n",
    "        # Insertion order defines the forward order of nn.Sequential.\n",
    "        layers = collections.OrderedDict()\n",
    "        layers['fc1'] = nn.Linear(input_len, 512)\n",
    "        layers['rl1'] = nn.ReLU()\n",
    "        layers['bn1'] = nn.BatchNorm1d(512)\n",
    "        layers['fc2'] = nn.Linear(512, 128)\n",
    "        layers['rl2'] = nn.ReLU()\n",
    "        layers['bn2'] = nn.BatchNorm1d(128)\n",
    "        layers['fc3'] = nn.Linear(128, output_len)\n",
    "        self.model = nn.Sequential(layers)\n",
    "\n",
    "        self.model.apply(weights_initialize)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Run `input` through the layer stack (presumably shaped (batch, input_len)).\"\"\"\n",
    "        return self.model(input)\n",
    "\n",
    "class TransModel():\n",
    "    \"\"\"Training wrapper around `_TransModel`: network, Adam optimiser and MSE loss.\n",
    "\n",
    "    Args:\n",
    "        input_len: width of each input row.\n",
    "        ouput_len: width of each prediction row (misspelt name kept so keyword\n",
    "            callers keep working).\n",
    "        learning_rate: Adam learning rate.\n",
    "        device_ids: CUDA device indices for `nn.DataParallel`. When None and\n",
    "            CUDA is available, the current CUDA device is used -- the device a\n",
    "            later `.cuda()` / `.to('cuda')` moves the parameters to. The former\n",
    "            hard-coded [5] failed on hosts with fewer than six GPUs and\n",
    "            disagreed with that later move to the current device.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_len, ouput_len, learning_rate = 0.0001, device_ids=None):\n",
    "        if device_ids is None and torch.cuda.is_available():\n",
    "            device_ids = [torch.cuda.current_device()]\n",
    "\n",
    "        self.steps = 0  # optimiser steps taken by fit()\n",
    "        self.model = nn.DataParallel(_TransModel(input_len, ouput_len), device_ids=device_ids)\n",
    "        self.optimizer = Adam(self.model.parameters(), lr = learning_rate)\n",
    "        self.loss_fn = nn.MSELoss(reduction='mean')\n",
    "\n",
    "    def predict(self, input, steps, learning):\n",
    "        \"\"\"Forward `input` and squeeze dimension 1 (removed only if it has size 1).\n",
    "\n",
    "        `steps` and `learning` are unused and kept only for interface compatibility.\n",
    "        \"\"\"\n",
    "        return self.model(input).squeeze(1)\n",
    "\n",
    "    def predict_batch(self, input):\n",
    "        \"\"\"Forward a batch through the wrapped network and return the raw output.\"\"\"\n",
    "        return self.model(input)\n",
    "\n",
    "    def fit(self, state, target_state):\n",
    "        \"\"\"Take one optimiser step on MSE(state, target_state) and return the loss.\n",
    "\n",
    "        `state` must be a prediction still attached to the autograd graph (e.g.\n",
    "        the output of `predict_batch`). The returned tensor keeps its grad_fn,\n",
    "        so callers accumulating it across batches should call `.item()` first.\n",
    "        \"\"\"\n",
    "        loss = self.loss_fn(state, target_state)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "        self.steps += 1\n",
    "        return loss\n",
    "\n",
    "    def save(self):\n",
    "        \"\"\"Write the wrapped model's state_dict to ./models/NEXUS_{unique_id}.pt.\n",
    "\n",
    "        The directory is created when missing. `unique_id` is set once in the\n",
    "        first cell, so repeated calls overwrite the same file. Keys carry the\n",
    "        `module.` prefix added by nn.DataParallel.\n",
    "        \"\"\"\n",
    "        path = os.path.join(os.getcwd(), 'models')\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "        file_path = os.path.join(path, 'NEXUS_' + unique_id + '.pt')\n",
    "        torch.save(self.model.state_dict(), file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the two recorded transition datasets.\n",
    "# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.\n",
    "data_1 = torch.load('100000_random_v_random.pt')\n",
    "data_2 = torch.load('60000_sadq_v_random.pt')\n",
    "\n",
    "# Widths of the first sample's two entries in each dataset; they must agree\n",
    "# across datasets for the later torch.stack to succeed.\n",
    "print(len(data_1[0][0]))\n",
    "print(len(data_2[0][0]))\n",
    "print(len(data_1[0][1]))\n",
    "print(len(data_2[0][1]))\n",
    "\n",
    "# Presumably both are Python lists of [input, next_state] samples, so `+`\n",
    "# concatenates them (it would add element-wise if they were arrays).\n",
    "data = data_1 + data_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_data(dataset, val_pct):\n",
    "    # Determine size of validation set\n",
    "    n_val = int(val_pct*dataset)\n",
    "    # Create random permutation of 0 to n-1\n",
    "    idxs = np.random.permutation(dataset)\n",
    "    # Pick first n_val indices for validation set\n",
    "    return idxs[n_val:], idxs[:n_val]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "# Normalise every sample in place: sample[0] is the model input and sample[1]\n",
    "# the next state, of which only columns 12: are kept as the target.\n",
    "# WARNING: this mutates `data`; re-running this cell without re-running the\n",
    "# loading cell slices and rescales the samples a second time.\n",
    "for sample in data:\n",
    "    state = sample[0]\n",
    "    sample[1] = np.true_divide(sample[1][12:], 60)    # Ground truth split into only next-state's units\n",
    "\n",
    "    state[0:4] = np.true_divide(state[0:4], 30)    # Normalize P1 buildings\n",
    "    state[5:9] = np.true_divide(state[5:9], 30)    # Normalize P2 buildings\n",
    "    state[4] = state[4] / 2000                     # Normalize P1 Nexus HP\n",
    "    state[9] = state[9] / 2000                     # Normalize P2 Nexus HP\n",
    "    state[10] = state[10] / 1500                   # Normalize P1 Minerals\n",
    "    # NOTE(review): column 11 is scaled by 60 too, yet the unit columns used\n",
    "    # for the target and the baseline start at 12 -- confirm what it holds.\n",
    "    state[11:] = np.true_divide(state[11:], 60)    # Normalize both Player's units on the field\n",
    "\n",
    "print(data[0][0], data[0][1][0:3], data[0][1][3:6])\n",
    "print(data[1][0], data[1][1])\n",
    "print(data[2][0], data[2][1])\n",
    "\n",
    "# Stack inputs and targets separately instead of np.array(data): each sample\n",
    "# holds two vectors of different lengths, which NumPy 1.24+ rejects as a\n",
    "# ragged array unless dtype=object is passed.\n",
    "tensor_x = torch.as_tensor(np.stack([sample[0] for sample in data]), dtype=torch.float32)\n",
    "tensor_y = torch.as_tensor(np.stack([sample[1] for sample in data]), dtype=torch.float32)\n",
    "\n",
    "tensor_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)\n",
    "\n",
    "train_indices, val_indices = split_data(len(data), val_pct=0.1)\n",
    "\n",
    "print(len(train_indices), len(val_indices))\n",
    "print(val_indices[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch size shared by the training and validation loaders.\n",
    "batch_size = 64\n",
    "# TensorBoard writer; log_inputs also appends test_loss.txt inside this directory.\n",
    "summary_test = SummaryWriter(log_dir = 'nexus-HP-transition-model-report/')\n",
    "\n",
    "# Both loaders read the same TensorDataset; each random subset sampler limits\n",
    "# its loader to one side of the train/validation index split.\n",
    "train_sampler, val_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)\n",
    "train_dl = DataLoader(tensor_dataset, batch_size=batch_size, sampler=train_sampler)\n",
    "valid_dl = DataLoader(tensor_dataset, batch_size=batch_size, sampler=val_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persistence baseline: predict that the next state's unit columns equal the\n",
    "# current ones, scored with MSE over the validation split.\n",
    "test_data = [data[i] for i in val_indices]\n",
    "\n",
    "# Stack inputs and targets separately instead of np.array(test_data): the two\n",
    "# vectors in a sample differ in length, which NumPy 1.24+ rejects as ragged.\n",
    "baseline = np.stack([sample[0] for sample in test_data])\n",
    "baseline_hp = baseline[:, 12:]\n",
    "\n",
    "bl_next_state_reward = np.stack([sample[1] for sample in test_data])\n",
    "\n",
    "mse_baseline = ((baseline_hp - bl_next_state_reward)**2).mean(axis=None)\n",
    "print(mse_baseline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_default_device():\n",
    "    \"\"\"Pick GPU if available, else CPU\"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        return torch.device('cuda')\n",
    "    else:\n",
    "        return torch.device('cpu')\n",
    "\n",
    "def to_device(data, device):\n",
    "    \"\"\"Move tensor(s) to chosen device\"\"\"\n",
    "    if isinstance(data, (list,tuple)):\n",
    "        return [to_device(x, device) for x in data]\n",
    "    return data.to(device, non_blocking=True)\n",
    "\n",
    "class DeviceDataLoader():\n",
    "    \"\"\"Iterable wrapper that moves each batch of a DataLoader onto `device` lazily.\"\"\"\n",
    "\n",
    "    def __init__(self, dl, device):\n",
    "        self.dl, self.device = dl, device\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Yield the wrapped loader's batches, each passed through `to_device`.\"\"\"\n",
    "        yield from map(lambda batch: to_device(batch, self.device), self.dl)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of batches, delegated to the wrapped loader.\"\"\"\n",
    "        return len(self.dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the target device once and wrap both loaders so each batch lands there.\n",
    "# Note: re-running this cell wraps the already-wrapped loaders again.\n",
    "device = get_default_device()\n",
    "print(device)\n",
    "\n",
    "train_dl, valid_dl = DeviceDataLoader(train_dl, device), DeviceDataLoader(valid_dl, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trans_model = TransModel(len(data[0][0]), len(data[0][1]))\n",
    "\n",
    "# nn.Module.to() moves the parameters in place and returns the same module.\n",
    "# `device` is 'cuda' exactly when use_cuda is True, and .to('cuda') targets the\n",
    "# current CUDA device just like .cuda(), so the former extra .cuda() was redundant.\n",
    "trans_model.model = to_device(trans_model.model, device)\n",
    "print(\"Using GPU\" if use_cuda else \"Using CPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, state_action, next_state_reward):\n",
    "    \n",
    "    model.model.eval()\n",
    "    criterion = nn.MSELoss(reduction='mean')\n",
    "    outputs = model.predict_batch(state_action)\n",
    "    mse = criterion(outputs, next_state_reward)\n",
    "    mse_p1 = criterion(outputs[:, 0], next_state_reward[:, 0])\n",
    "    mse_p2 = criterion(outputs[:, 1], next_state_reward[:, 1])\n",
    "    \n",
    "    accuracy = torch.sum( torch.sum( torch.eq( outputs, next_state_reward ) )).item()\n",
    "    accuracy = accuracy / (2 * outputs.size()[0])\n",
    "\n",
    "    model.model.train()\n",
    "    \n",
    "    return mse.item(), mse_p1.item(), mse_p2.item(), accuracy, len(xb)\n",
    "\n",
    "def log_inputs(mse, mse_p1, mse_p2, accuracy, mse_baseline, output, ground_truth, epoch):\n",
    "    \"\"\"Record one epoch's validation metrics to TensorBoard and a text log.\n",
    "\n",
    "    The MSE values go to the module-level `summary_test` writer. The overall\n",
    "    MSE and accuracy are appended to test_loss.txt in the report directory,\n",
    "    plus the first two rows of `output` and `ground_truth` every 1000th epoch.\n",
    "    Accuracy is not sent to TensorBoard.\n",
    "    \"\"\"\n",
    "    summary_test.add_scalar(\"MSE\", float(mse), epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Player 1 Nexus HP MSE': float(mse_p1)}, epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Player 2 Nexus HP  MSE': float(mse_p2)}, epoch)\n",
    "    summary_test.add_scalars(\"MSE\",{'Baseline Nexus HP MSE': float(mse_baseline)}, epoch)\n",
    "\n",
    "    # Make sure the report directory exists before appending to the text log.\n",
    "    os.makedirs(\"nexus-HP-transition-model-report\", exist_ok=True)\n",
    "    with open(\"nexus-HP-transition-model-report/test_loss.txt\", \"a+\") as f:\n",
    "        f.write(\"loss:\" + str(mse) + \", \")\n",
    "        f.write(\"acc:\" + str(accuracy) + \"\\n\")\n",
    "        if epoch % 1000 == 0:\n",
    "            # Use the `output` argument; the original read a global `outputs`.\n",
    "            f.write(\"output:\" + str(output[0:2]) + \"\\n\")\n",
    "            f.write(\"ground true:\" + str(ground_truth[0:2]) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shapes of one training batch.\n",
    "for xb, yb in train_dl:\n",
    "    print(xb.size(), yb.size())\n",
    "    break\n",
    "\n",
    "for epoch in tqdm.tqdm(range(10000)):\n",
    "    # --- train for one epoch ---\n",
    "    epoch_loss = 0.0\n",
    "    for xb, yb in train_dl:\n",
    "        outputs = trans_model.predict_batch(xb)\n",
    "        # .item() keeps a plain float; summing the returned tensors chained their\n",
    "        # autograd history through the whole epoch.\n",
    "        epoch_loss += trans_model.fit(outputs, yb).item()\n",
    "\n",
    "    # Mean per-batch training loss. len(train_dl) is the exact batch count; the\n",
    "    # old len(train_indices) // batch_size + 1 overcounted on an even division.\n",
    "    n_train_batches = max(len(train_dl), 1)\n",
    "    summary_test.add_scalars(\"MSE\",{'Train MSE': float(epoch_loss / n_train_batches)}, epoch)\n",
    "\n",
    "    # --- validate: batch metrics averaged with batch-size weights ---\n",
    "    results = [evaluation(trans_model, xb, yb) for xb, yb in valid_dl]\n",
    "\n",
    "    mse, mse_p1, mse_p2, accuracy, nums = zip(*results)\n",
    "    total = np.sum(nums)\n",
    "    avg_loss = np.sum(np.multiply(mse, nums)) / total\n",
    "    avg_p1_loss = np.sum(np.multiply(mse_p1, nums)) / total\n",
    "    avg_p2_loss = np.sum(np.multiply(mse_p2, nums)) / total\n",
    "    avg_metric = np.sum(np.multiply(accuracy, nums)) / total\n",
    "\n",
    "    # NOTE(review): `outputs` / `yb` here are the last *training* batch, not a\n",
    "    # validation batch -- the text log's sample rows come from training data.\n",
    "    log_inputs(avg_loss, avg_p1_loss, avg_p2_loss, avg_metric, mse_baseline,\n",
    "               outputs.detach(), yb, epoch)\n",
    "\n",
    "    if epoch % 1000 == 0 and epoch != 0:\n",
    "        print(epoch)\n",
    "        trans_model.save()\n",
    "\n",
    "# Persist the final weights; the periodic save above skips the last 999 epochs.\n",
    "trans_model.save()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
