{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from copy import copy, deepcopy\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tug of War 2 lanes No FIFO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "maker_cost = {\n",
    "    'Marine T' : 50,\n",
    "    'Baneling T' : 75,\n",
    "    'Immortal T' : 200,\n",
    "    'Marine B' : 50,\n",
    "    'Baneling B' : 75,\n",
    "    'Immortal B' : 200,\n",
    "}\n",
    "def get_big_A_spend_all(miner, \n",
    "              all_A_vectors = None, vector = None, move = 0):\n",
    "    if all_A_vectors is None:\n",
    "        all_A_vectors = list()\n",
    "    if vector is None:\n",
    "        vector = (0,0,0,0,0,0)\n",
    "    if miner == 0:\n",
    "        all_A_vectors.append(vector)\n",
    "        return list(all_A_vectors)\n",
    "    \n",
    "    next_vector = copy(vector)\n",
    "#     if miner < 50:\n",
    "#         all_A_vectors.append(vector)\n",
    "#         return list(all_A_vectors)\n",
    "        \n",
    "#     next_vector = copy(vector)\n",
    "#     get_big_A_spend_all(miner - miner, all_A_vectors, next_vector)\n",
    "    if miner >= maker_cost['Marine T']:\n",
    "        if move <= 0:\n",
    "            next_vector = (vector[0] + 1, vector[1],vector[2],\n",
    "                            vector[3], vector[4], vector[5])\n",
    "            get_big_A_spend_all(miner - maker_cost['Marine T'], all_A_vectors, next_vector, 0)\n",
    "        if move <= 1:\n",
    "            next_vector = (vector[0], vector[1],vector[2],\n",
    "                            vector[3] + 1, vector[4], vector[5])\n",
    "            get_big_A_spend_all(miner - maker_cost['Marine B'], all_A_vectors, next_vector, 1)\n",
    "            \n",
    "        if miner >= maker_cost['Baneling T']:\n",
    "            if move <= 2:\n",
    "                next_vector = (vector[0], vector[1] + 1,vector[2],\n",
    "                                vector[3], vector[4], vector[5])\n",
    "                get_big_A_spend_all(miner - maker_cost['Baneling T'], all_A_vectors, next_vector, 2)\n",
    "            if move <= 3:\n",
    "                next_vector = (vector[0], vector[1],vector[2],\n",
    "                                vector[3], vector[4] + 1, vector[5])\n",
    "                get_big_A_spend_all(miner - maker_cost['Baneling B'], all_A_vectors, next_vector, 3)\n",
    "                \n",
    "            if miner >= maker_cost['Immortal T']:\n",
    "                if move <= 4:\n",
    "                    next_vector = (vector[0], vector[1],vector[2] + 1,\n",
    "                                    vector[3], vector[4], vector[5])\n",
    "                    get_big_A_spend_all(miner - maker_cost['Immortal T'], all_A_vectors, next_vector, 4)\n",
    "                if move <= 5:\n",
    "                    next_vector = (vector[0], vector[1],vector[2],\n",
    "                                    vector[3], vector[4], vector[5] + 1)\n",
    "                    get_big_A_spend_all(miner - maker_cost['Immortal B'], all_A_vectors, next_vector, 5)\n",
    "\n",
    "    return list(all_A_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# action_space = get_big_A_spend_all(225)\n",
    "# print(len(action_space))\n",
    "# for action in action_space:\n",
    "#     print(action)\n",
    "\n",
    "all_actions = []\n",
    "mineral = 0\n",
    "maker_cost_np = np.zeros(len(maker_cost))\n",
    "for i, mc in enumerate(maker_cost.values()):\n",
    "    maker_cost_np[i] = mc\n",
    "for i in tqdm(range(1500 // 25 + 1)):\n",
    "    all_actions.extend(get_big_A_spend_all(mineral))\n",
    "    mineral += 25\n",
    "    if mineral > 1500:\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_actions = np.array(all_actions)\n",
    "pylon_action = np.zeros((all_actions.shape[0], 1))\n",
    "print(all_actions.shape)\n",
    "print(pylon_action.shape)\n",
    "all_actions = np.hstack((all_actions, pylon_action))\n",
    "print(all_actions.shape)\n",
    "print(all_actions)\n",
    "# print(all_actions[all_actions[:,5] == 7])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_mineral = 50\n",
    "action_dict = {0: 1, 25: 1}\n",
    "for i, a in tqdm(enumerate(np.array(all_actions[1:]))):\n",
    "#     if key_mineral < 500:\n",
    "#         print(np.sum(maker_cost_np * a))\n",
    "    if np.sum(maker_cost_np * a[:-1]) != key_mineral:\n",
    "        action_dict[key_mineral] = i + 1\n",
    "        key_mineral += 25\n",
    "action_dict[key_mineral] = i + 2\n",
    "print(action_dict)\n",
    "action_1500_dict = {}\n",
    "all_actions_torch = torch.Tensor(all_actions)\n",
    "action_1500_dict['actions'] = all_actions_torch\n",
    "action_1500_dict['mineral'] = action_dict\n",
    "\n",
    "torch.save(action_1500_dict, 'action_1500_dict_2L.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tug of War 2 lanes No FIFO (One lane one time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "maker_cost = {\n",
    "    'Marine' : 50,\n",
    "    'Baneling' : 75,\n",
    "    'Immortal' : 200,\n",
    "#     'Plyon' : 300\n",
    "}\n",
    "def get_big_A_spend_all_olot(miner, \n",
    "              all_A_vectors = None, vector = None, move = 0, pylons = 0):\n",
    "    if all_A_vectors is None:\n",
    "        all_A_vectors = list()\n",
    "    if vector is None:\n",
    "        vector = (0,0,0,0)\n",
    "    if miner == 0:\n",
    "        all_A_vectors.append(vector) \n",
    "        return list(all_A_vectors)\n",
    "    pylon_cost = 300 + pylons * 100\n",
    "    next_vector = copy(vector)\n",
    "#     if miner < 50:\n",
    "#         all_A_vectors.append(vector)\n",
    "#         return list(all_A_vectors)\n",
    "        \n",
    "#     next_vector = copy(vector)\n",
    "#     get_big_A_spend_all_olot(miner - miner, all_A_vectors, next_vector)\n",
    "    if miner >= maker_cost['Marine']:\n",
    "        if move <= 0:\n",
    "            next_vector = (vector[0] + 1, vector[1],vector[2], vector[3])\n",
    "            get_big_A_spend_all_olot(miner - maker_cost['Marine'], all_A_vectors, next_vector, 0,\n",
    "                                     pylons = pylons)\n",
    "            \n",
    "        if miner >= maker_cost['Baneling']:\n",
    "            if move <= 1:\n",
    "                next_vector = (vector[0], vector[1] + 1,vector[2], vector[3])\n",
    "                get_big_A_spend_all_olot(miner - maker_cost['Baneling'], all_A_vectors, next_vector, 1,\n",
    "                                         pylons = pylons)\n",
    "                \n",
    "            if miner >= maker_cost['Immortal']:\n",
    "                if move <= 2:\n",
    "                    next_vector = (vector[0], vector[1],vector[2] + 1, vector[3])\n",
    "                    get_big_A_spend_all_olot(miner - maker_cost['Immortal'], all_A_vectors, next_vector, 2,\n",
    "                                             pylons = pylons)\n",
    "                if miner >= pylon_cost and pylons < 3:\n",
    "                    if move <= 3:\n",
    "                        next_vector = (vector[0], vector[1], vector[2], vector[3] + 1)\n",
    "                        get_big_A_spend_all_olot(miner - pylon_cost, all_A_vectors, next_vector, 3, \n",
    "                                                 pylons = pylons + 1)\n",
    "\n",
    "    return list(all_A_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/61 [00:00<?, ?it/s]\n",
      "  0%|          | 0/61 [00:00<?, ?it/s]\n",
      "  0%|          | 0/61 [00:00<?, ?it/s]\n",
      "  0%|          | 0/61 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1843\n",
      "1601\n",
      "1399\n",
      "1041\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# action_space = get_big_A_spend_all_olot(1500)\n",
    "# print(len(action_space))\n",
    "# for action in action_space:\n",
    "#     print(action)\n",
    "maker_cost_np = np.zeros(len(maker_cost))\n",
    "for i, mc in enumerate(maker_cost.values()):\n",
    "    maker_cost_np[i] = mc\n",
    "    \n",
    "all_actions = {}\n",
    "for i in range(4):\n",
    "    actions = []\n",
    "    mineral = 0\n",
    "    for _ in tqdm(range(1500 // 25 + 1)):\n",
    "        actions.extend(get_big_A_spend_all_olot(mineral, pylons = i))\n",
    "        mineral += 25\n",
    "        if mineral > 1500:\n",
    "            break\n",
    "    \n",
    "    print(len(actions))\n",
    "    all_actions[i] = np.array(actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[[ 2  8  4  0]\n",
      " [ 2  4  4  1]\n",
      " [ 2  4  2  2]\n",
      " [ 2  0  7  0]\n",
      " [ 2  0  1  3]\n",
      " [ 1 14  2  0]\n",
      " [ 1 10  2  1]\n",
      " [ 1 10  0  2]\n",
      " [ 1  6  5  0]\n",
      " [ 1  2  5  1]\n",
      " [ 1  2  3  2]\n",
      " [ 0 20  0  0]\n",
      " [ 0 16  0  1]\n",
      " [ 0 12  3  0]\n",
      " [ 0  8  3  1]\n",
      " [ 0  8  1  2]\n",
      " [ 0  4  6  0]\n",
      " [ 0  4  0  3]\n",
      " [ 0  0  6  1]\n",
      " [ 0  0  4  2]]\n",
      "1\n",
      "[[ 3  2  6  0]\n",
      " [ 3  2  4  1]\n",
      " [ 2 16  1  0]\n",
      " [ 2  8  4  0]\n",
      " [ 2  8  2  1]\n",
      " [ 2  4  1  2]\n",
      " [ 2  0  7  0]\n",
      " [ 2  0  5  1]\n",
      " [ 1 14  2  0]\n",
      " [ 1 14  0  1]\n",
      " [ 1  6  5  0]\n",
      " [ 1  6  3  1]\n",
      " [ 1  2  2  2]\n",
      " [ 0 20  0  0]\n",
      " [ 0 12  3  0]\n",
      " [ 0 12  1  1]\n",
      " [ 0  8  0  2]\n",
      " [ 0  4  6  0]\n",
      " [ 0  4  4  1]\n",
      " [ 0  0  3  2]]\n",
      "2\n",
      "[[ 4  4  5  0]\n",
      " [ 4  0  4  1]\n",
      " [ 3 18  0  0]\n",
      " [ 3 10  3  0]\n",
      " [ 3  6  2  1]\n",
      " [ 3  2  6  0]\n",
      " [ 2 16  1  0]\n",
      " [ 2 12  0  1]\n",
      " [ 2  8  4  0]\n",
      " [ 2  4  3  1]\n",
      " [ 2  0  7  0]\n",
      " [ 1 14  2  0]\n",
      " [ 1 10  1  1]\n",
      " [ 1  6  5  0]\n",
      " [ 1  2  4  1]\n",
      " [ 0 20  0  0]\n",
      " [ 0 12  3  0]\n",
      " [ 0  8  2  1]\n",
      " [ 0  4  6  0]\n",
      " [ 0  0  5  1]]\n",
      "3\n",
      "[[ 7 10  2  0]\n",
      " [ 7  2  5  0]\n",
      " [ 6 16  0  0]\n",
      " [ 6  8  3  0]\n",
      " [ 6  0  6  0]\n",
      " [ 5 14  1  0]\n",
      " [ 5  6  4  0]\n",
      " [ 4 12  2  0]\n",
      " [ 4  4  5  0]\n",
      " [ 3 18  0  0]\n",
      " [ 3 10  3  0]\n",
      " [ 3  2  6  0]\n",
      " [ 2 16  1  0]\n",
      " [ 2  8  4  0]\n",
      " [ 2  0  7  0]\n",
      " [ 1 14  2  0]\n",
      " [ 1  6  5  0]\n",
      " [ 0 20  0  0]\n",
      " [ 0 12  3  0]\n",
      " [ 0  4  6  0]]\n"
     ]
    }
   ],
   "source": [
    "for i in range(4):\n",
    "    print(i)\n",
    "    print(all_actions[i][-20:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1842it [00:00, 186751.46it/s]\n",
      "1600it [00:00, 145219.56it/s]\n",
      "1398it [00:00, 122636.88it/s]\n",
      "1040it [00:00, 116278.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 24, 325: 27, 350: 33, 375: 38, 400: 45, 425: 51, 450: 60, 475: 67, 500: 78, 525: 87, 550: 99, 575: 110, 600: 125, 625: 137, 650: 154, 675: 169, 700: 189, 725: 206, 750: 229, 775: 249, 800: 275, 825: 298, 850: 327, 875: 353, 900: 387, 925: 416, 950: 453, 975: 487, 1000: 528, 1025: 565, 1050: 611, 1075: 652, 1100: 703, 1125: 749, 1150: 804, 1175: 855, 1200: 917, 1225: 972, 1250: 1039, 1275: 1101, 1300: 1174, 1325: 1241, 1350: 1321, 1375: 1394, 1400: 1481, 1425: 1561, 1450: 1654, 1475: 1741, 1500: 1843}, 1: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 42, 425: 47, 450: 55, 475: 62, 500: 71, 525: 79, 550: 90, 575: 99, 600: 113, 625: 124, 650: 139, 675: 153, 700: 170, 725: 185, 750: 205, 775: 222, 800: 245, 825: 265, 850: 290, 875: 313, 900: 342, 925: 367, 950: 399, 975: 428, 1000: 464, 1025: 496, 1050: 536, 1075: 572, 1100: 616, 1125: 656, 1150: 704, 1175: 748, 1200: 802, 1225: 850, 1250: 908, 1275: 962, 1300: 1025, 1325: 1083, 1350: 1152, 1375: 1215, 1400: 1290, 1425: 1359, 1450: 1439, 1475: 1514, 1500: 1601}, 2: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 68, 525: 75, 550: 85, 575: 94, 600: 106, 625: 116, 650: 130, 675: 142, 700: 158, 725: 172, 750: 190, 775: 206, 800: 227, 825: 245, 850: 268, 875: 289, 900: 315, 925: 338, 950: 367, 975: 393, 1000: 425, 1025: 454, 1050: 489, 1075: 521, 1100: 560, 1125: 595, 1150: 637, 1175: 676, 1200: 722, 1225: 764, 1250: 814, 1275: 860, 1300: 914, 1325: 964, 1350: 1022, 1375: 1076, 1400: 1139, 1425: 1197, 1450: 1264, 1475: 1327, 1500: 1399}, 3: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 67, 525: 74, 550: 83, 575: 91, 600: 102, 625: 111, 650: 123, 675: 134, 700: 147, 725: 159, 750: 174, 775: 187, 800: 204, 825: 219, 850: 237, 875: 254, 900: 274, 925: 292, 950: 314, 975: 334, 1000: 358, 1025: 380, 1050: 406, 1075: 430, 1100: 458, 1125: 484, 1150: 514, 1175: 542, 1200: 575, 1225: 605, 1250: 640, 1275: 673, 1300: 710, 1325: 745, 1350: 785, 1375: 822, 1400: 865, 1425: 905, 1450: 950, 1475: 993, 1500: 1041}}\n",
      "{0: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 24, 325: 27, 350: 33, 375: 38, 400: 45, 425: 51, 450: 60, 475: 67, 500: 78, 525: 87, 550: 99, 575: 110, 600: 125, 625: 137, 650: 154, 675: 169, 700: 189, 725: 206, 750: 229, 775: 249, 800: 275, 825: 298, 850: 327, 875: 353, 900: 387, 925: 416, 950: 453, 975: 487, 1000: 528, 1025: 565, 1050: 611, 1075: 652, 1100: 703, 1125: 749, 1150: 804, 1175: 855, 1200: 917, 1225: 972, 1250: 1039, 1275: 1101, 1300: 1174, 1325: 1241, 1350: 1321, 1375: 1394, 1400: 1481, 1425: 1561, 1450: 1654, 1475: 1741, 1500: 1843}, 1: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 42, 425: 47, 450: 55, 475: 62, 500: 71, 525: 79, 550: 90, 575: 99, 600: 113, 625: 124, 650: 139, 675: 153, 700: 170, 725: 185, 750: 205, 775: 222, 800: 245, 825: 265, 850: 290, 875: 313, 900: 342, 925: 367, 950: 399, 975: 428, 1000: 464, 1025: 496, 1050: 536, 1075: 572, 1100: 616, 1125: 656, 1150: 704, 1175: 748, 1200: 802, 1225: 850, 1250: 908, 1275: 962, 1300: 1025, 1325: 1083, 1350: 1152, 1375: 1215, 1400: 1290, 1425: 1359, 1450: 1439, 1475: 1514, 1500: 1601}, 2: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 68, 525: 75, 550: 85, 575: 94, 600: 106, 625: 116, 650: 130, 675: 142, 700: 158, 725: 172, 750: 190, 775: 206, 800: 227, 825: 245, 850: 268, 875: 289, 900: 315, 925: 338, 950: 367, 975: 393, 1000: 425, 1025: 454, 1050: 489, 1075: 521, 1100: 560, 1125: 595, 1150: 637, 1175: 676, 1200: 722, 1225: 764, 1250: 814, 1275: 860, 1300: 914, 1325: 964, 1350: 1022, 1375: 1076, 1400: 1139, 1425: 1197, 1450: 1264, 1475: 1327, 1500: 1399}, 3: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 67, 525: 74, 550: 83, 575: 91, 600: 102, 625: 111, 650: 123, 675: 134, 700: 147, 725: 159, 750: 174, 775: 187, 800: 204, 825: 219, 850: 237, 875: 254, 900: 274, 925: 292, 950: 314, 975: 334, 1000: 358, 1025: 380, 1050: 406, 1075: 430, 1100: 458, 1125: 484, 1150: 514, 1175: 542, 1200: 575, 1225: 605, 1250: 640, 1275: 673, 1300: 710, 1325: 745, 1350: 785, 1375: 822, 1400: 865, 1425: 905, 1450: 950, 1475: 993, 1500: 1041}}\n",
      "{0: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 24, 325: 27, 350: 33, 375: 38, 400: 45, 425: 51, 450: 60, 475: 67, 500: 78, 525: 87, 550: 99, 575: 110, 600: 125, 625: 137, 650: 154, 675: 169, 700: 189, 725: 206, 750: 229, 775: 249, 800: 275, 825: 298, 850: 327, 875: 353, 900: 387, 925: 416, 950: 453, 975: 487, 1000: 528, 1025: 565, 1050: 611, 1075: 652, 1100: 703, 1125: 749, 1150: 804, 1175: 855, 1200: 917, 1225: 972, 1250: 1039, 1275: 1101, 1300: 1174, 1325: 1241, 1350: 1321, 1375: 1394, 1400: 1481, 1425: 1561, 1450: 1654, 1475: 1741, 1500: 1843}, 1: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 42, 425: 47, 450: 55, 475: 62, 500: 71, 525: 79, 550: 90, 575: 99, 600: 113, 625: 124, 650: 139, 675: 153, 700: 170, 725: 185, 750: 205, 775: 222, 800: 245, 825: 265, 850: 290, 875: 313, 900: 342, 925: 367, 950: 399, 975: 428, 1000: 464, 1025: 496, 1050: 536, 1075: 572, 1100: 616, 1125: 656, 1150: 704, 1175: 748, 1200: 802, 1225: 850, 1250: 908, 1275: 962, 1300: 1025, 1325: 1083, 1350: 1152, 1375: 1215, 1400: 1290, 1425: 1359, 1450: 1439, 1475: 1514, 1500: 1601}, 2: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 68, 525: 75, 550: 85, 575: 94, 600: 106, 625: 116, 650: 130, 675: 142, 700: 158, 725: 172, 750: 190, 775: 206, 800: 227, 825: 245, 850: 268, 875: 289, 900: 315, 925: 338, 950: 367, 975: 393, 1000: 425, 1025: 454, 1050: 489, 1075: 521, 1100: 560, 1125: 595, 1150: 637, 1175: 676, 1200: 722, 1225: 764, 1250: 814, 1275: 860, 1300: 914, 1325: 964, 1350: 1022, 1375: 1076, 1400: 1139, 1425: 1197, 1450: 1264, 1475: 1327, 1500: 1399}, 3: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 67, 525: 74, 550: 83, 575: 91, 600: 102, 625: 111, 650: 123, 675: 134, 700: 147, 725: 159, 750: 174, 775: 187, 800: 204, 825: 219, 850: 237, 875: 254, 900: 274, 925: 292, 950: 314, 975: 334, 1000: 358, 1025: 380, 1050: 406, 1075: 430, 1100: 458, 1125: 484, 1150: 514, 1175: 542, 1200: 575, 1225: 605, 1250: 640, 1275: 673, 1300: 710, 1325: 745, 1350: 785, 1375: 822, 1400: 865, 1425: 905, 1450: 950, 1475: 993, 1500: 1041}}\n",
      "{0: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 24, 325: 27, 350: 33, 375: 38, 400: 45, 425: 51, 450: 60, 475: 67, 500: 78, 525: 87, 550: 99, 575: 110, 600: 125, 625: 137, 650: 154, 675: 169, 700: 189, 725: 206, 750: 229, 775: 249, 800: 275, 825: 298, 850: 327, 875: 353, 900: 387, 925: 416, 950: 453, 975: 487, 1000: 528, 1025: 565, 1050: 611, 1075: 652, 1100: 703, 1125: 749, 1150: 804, 1175: 855, 1200: 917, 1225: 972, 1250: 1039, 1275: 1101, 1300: 1174, 1325: 1241, 1350: 1321, 1375: 1394, 1400: 1481, 1425: 1561, 1450: 1654, 1475: 1741, 1500: 1843}, 1: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 42, 425: 47, 450: 55, 475: 62, 500: 71, 525: 79, 550: 90, 575: 99, 600: 113, 625: 124, 650: 139, 675: 153, 700: 170, 725: 185, 750: 205, 775: 222, 800: 245, 825: 265, 850: 290, 875: 313, 900: 342, 925: 367, 950: 399, 975: 428, 1000: 464, 1025: 496, 1050: 536, 1075: 572, 1100: 616, 1125: 656, 1150: 704, 1175: 748, 1200: 802, 1225: 850, 1250: 908, 1275: 962, 1300: 1025, 1325: 1083, 1350: 1152, 1375: 1215, 1400: 1290, 1425: 1359, 1450: 1439, 1475: 1514, 1500: 1601}, 2: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 68, 525: 75, 550: 85, 575: 94, 600: 106, 625: 116, 650: 130, 675: 142, 700: 158, 725: 172, 750: 190, 775: 206, 800: 227, 825: 245, 850: 268, 875: 289, 900: 315, 925: 338, 950: 367, 975: 393, 1000: 425, 1025: 454, 1050: 489, 1075: 521, 1100: 560, 1125: 595, 1150: 637, 1175: 676, 1200: 722, 1225: 764, 1250: 814, 1275: 860, 1300: 914, 1325: 964, 1350: 1022, 1375: 1076, 1400: 1139, 1425: 1197, 1450: 1264, 1475: 1327, 1500: 1399}, 3: {0: 1, 25: 1, 50: 2, 75: 3, 100: 4, 125: 5, 150: 7, 175: 8, 200: 11, 225: 13, 250: 16, 275: 19, 300: 23, 325: 26, 350: 31, 375: 35, 400: 41, 425: 46, 450: 53, 475: 59, 500: 67, 525: 74, 550: 83, 575: 91, 600: 102, 625: 111, 650: 123, 675: 134, 700: 147, 725: 159, 750: 174, 775: 187, 800: 204, 825: 219, 850: 237, 875: 254, 900: 274, 925: 292, 950: 314, 975: 334, 1000: 358, 1025: 380, 1050: 406, 1075: 430, 1100: 458, 1125: 484, 1150: 514, 1175: 542, 1200: 575, 1225: 605, 1250: 640, 1275: 673, 1300: 710, 1325: 745, 1350: 785, 1375: 822, 1400: 865, 1425: 905, 1450: 950, 1475: 993, 1500: 1041}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "pylon_cost = {\n",
    "    0:{\n",
    "        0: 0,\n",
    "        1: 300,\n",
    "        2: 700,\n",
    "        3: 1200\n",
    "    },\n",
    "    1:{\n",
    "        0: 0,\n",
    "        1: 400,\n",
    "        2: 900\n",
    "    },\n",
    "    2:{\n",
    "        0: 0,\n",
    "        1: 500\n",
    "    },\n",
    "    3: {\n",
    "        0: 0,\n",
    "    }\n",
    "}\n",
    "all_action_dict = {}\n",
    "for j in range(4):\n",
    "    action_pylon = all_actions[j]\n",
    "    all_action_dict[j] = {0: 1, 25: 1}\n",
    "    key_mineral = 50\n",
    "    for i, a in tqdm(enumerate(np.array(action_pylon[1:]))):\n",
    "        pylon_c = pylon_cost[j][a[-1]]\n",
    "#         print(pylon_c)\n",
    "#         print(a)\n",
    "#         if key_mineral < 500:\n",
    "#         print(np.sum(maker_cost_np * a[:-1]) + pylon_c)\n",
    "        if np.sum(maker_cost_np * a[:-1]) + pylon_c != key_mineral:\n",
    "            all_action_dict[j][key_mineral] = i + 1\n",
    "            key_mineral += 25\n",
    "            \n",
    "    all_action_dict[j][key_mineral] = i + 2\n",
    "for j in range(4):\n",
    "    print(all_action_dict)\n",
    "print(all_actions[3][])\n",
    "# action_1500_dict = {}\n",
    "# action_1500_dict['actions'] = all_actions\n",
    "# action_1500_dict['mineral'] = all_action_dict\n",
    "\n",
    "# torch.save(action_1500_dict, 'action_1500_dict_2L.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(action_1500_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.array([1,2,3])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Marked experience."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "137200\n",
      "3400\n",
      "100000\n"
     ]
    }
   ],
   "source": [
    "rand_exp = torch.load(\"rand_v_rand.pt\")\n",
    "agents_exp = torch.load(\"all_experiences.pt\")\n",
    "unmark_exp = rand_exp[:96600] + agents_exp\n",
    "print(len(rand_exp))\n",
    "print(len(agents_exp))\n",
    "print(len(unmark_exp))\n",
    "torch.save(unmark_exp, \"all_experiences_100000.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# unmark_exp = torch.load(\"test_random_vs_random_2l.pt\")\n",
    "def player_1_win_condition(state_1_T_hp, state_1_B_hp, state_2_T_hp, state_2_B_hp):\n",
    "    if min(state_1_T_hp, state_1_B_hp) == min(state_2_T_hp, state_2_B_hp):\n",
    "        if state_1_T_hp + state_1_B_hp == state_2_T_hp + state_2_B_hp:\n",
    "            return 0\n",
    "        elif state_1_T_hp + state_1_B_hp > state_2_T_hp + state_2_B_hp:\n",
    "            return 1\n",
    "        else:\n",
    "            return -1\n",
    "    else:\n",
    "        if min(state_1_T_hp, state_1_B_hp) > min(state_2_T_hp, state_2_B_hp):\n",
    "            return 1\n",
    "        else:\n",
    "            return -1\n",
    "mark_exp = deepcopy(unmark_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "insert_num = 0\n",
    "for i, exp in enumerate(unmark_exp):\n",
    "#     print(exp)\n",
    "    if i != 0 and exp[0][-1] == 2:\n",
    "        s1_hp_t, s1_hp_b, s2_hp_t, s2_hp_b = unmark_exp[i - 1][1][27:31]\n",
    "        \n",
    "        win_lose = player_1_win_condition(s1_hp_t, s1_hp_b, s2_hp_t, s2_hp_b)\n",
    "        mark_exp.insert(i + insert_num, win_lose)\n",
    "#         skip = True\n",
    "        insert_num += 1\n",
    "#         print(s1_hp_t, s1_hp_b, s2_hp_t, s2_hp_b)\n",
    "#         print(win_lose)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for i, exp in enumerate(mark_exp):\n",
    "#     print(exp)\n",
    "#     if type(exp) == type(int(1)):\n",
    "#         print(list(mark_exp[i - 1][1][27:31]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(mark_exp, \"all_experiences_100000_mark.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mark_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
