{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b7ed041-a82b-4e7d-9e4f-a8efcae2a7ee",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Introduction\n",
    "This notebook contains the 6 agents used in the paper \"Go-Explore with a guide: Speeding up search in sparse reward domains with goal-directed intrinsic rewards\"\n",
    "- Random\n",
    "- TD\n",
    "- Q-Learning\n",
    "- Go-Explore\n",
    "- Go-Explore-Count\n",
    "- Explore-Count\n",
    "\n",
    "It contains the 4 discrete state environments used\n",
    "- Unwalled Maze\n",
    "- Walled Maze\n",
    "- Towers of Hanoi\n",
    "- Nim\n",
    "\n",
    "It also contains the 2 continuous state environments used\n",
    "- Cart Pole\n",
    "- Mountain Car"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "000f59b7-0167-4400-a116-1e9fcb2544e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d61ceb7c-eac3-4cfe-a31e-356f0582d624",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Agent Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ca1159-6106-4f25-bf57-1643e64a5781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is the memory for the agents which need them\n",
    "memory = defaultdict(lambda: 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38ab6955-10f1-4752-acf9-d6436efedf0c",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Agent 1: Random Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "545a9b0d-37ce-4375-a410-3936e18da3ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def RandomAgent(env, **kwargs):\n",
    "    return env.sample()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6c7bbc6-773e-4cb0-b42f-0eb2cda714a6",
   "metadata": {},
   "source": [
    "## Agent 2: TD Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b4fc117-756b-41c0-b5e2-89eead91104a",
   "metadata": {},
   "source": [
    "TD-error: $\\delta_t = r_{t} + \\gamma \\max_{a\\in A, a: s_t \\rightarrow s_{t+1}}V(s_{t+1}) - V(s_t)$\n",
    "\n",
    "Value update: $V(s_t) \\leftarrow V(s_t) + \\alpha\\delta_t$"
   ]
  },
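   {
    "cell_type": "markdown",
    "id": "0a1b2c3d-4e5f-4a6b-8c7d-1234567890ab",
    "metadata": {},
    "source": [
     "Note: the implementation below fixes $\\alpha = 1$ and $\\gamma = 0.99$, so the update reduces to a full replacement: $V(s_t) \\leftarrow r_t + \\gamma\\max_{a\\in A}V(s_{t+1})$."
    ]
   },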
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67d91115-7ead-408d-8699-a09167c70a15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def TDAgent(env, **kwargs):\n",
    "    eps = kwargs.get('eps', 1)\n",
    "    GAMMA = 0.99\n",
    "    ALPHA = 1\n",
    "    \n",
    "    statehistory = kwargs.get('statehistory', [])\n",
    "    repeatedstate = kwargs.get('repeatedstate', False)\n",
    "    \n",
    "    curstate = env.staterep()\n",
    "    if repeatedstate:\n",
    "        curstate += str(statehistory.count(env.staterep()))\n",
    "    \n",
    "    if env.reward == 1:\n",
    "        memory[curstate] = 1\n",
    "        return\n",
    "\n",
    "    validmoves = env.getvalidmoves()\n",
    "    bestmove = None\n",
    "    bestvalue = -1\n",
    "    \n",
    "    # choose best move\n",
    "    for move in validmoves:\n",
    "        newenv = copy.deepcopy(env)\n",
    "        newenv.step(move)\n",
    "        nextstate = newenv.staterep()\n",
    "        if repeatedstate:\n",
    "            nextstate += str(statehistory.count(newenv.staterep())+1)\n",
    "            \n",
    "        curvalue = memory[nextstate] \n",
    "        if curvalue > bestvalue:\n",
    "            bestvalue = curvalue\n",
    "            bestmove = move\n",
    "    \n",
    "    # if epsilon, then choose randomly\n",
    "    if eps > np.random.rand():\n",
    "        bestmove = env.sample()\n",
    "        \n",
    "    # update current state value with optimal one-step lookahead estimate\n",
    "    td_error = env.reward + GAMMA*bestvalue - memory[curstate]\n",
    "    memory[curstate] = memory[curstate] + ALPHA*td_error\n",
    "    \n",
    "    return bestmove"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54cf064e-898f-445d-8ed9-4e83cc98ffb7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Agent 3: Q-Learning Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d760d1c-eec6-4721-bc42-294588e6d1fc",
   "metadata": {},
   "source": [
    "TD-error: $\\delta_t = r_t + \\gamma\\max_{a\\in A}Q(s_{t+1},a) - Q(s_t, a_t)$\n",
    "\n",
    "Q-learning update: $Q(s_t, a_t) \\leftarrow Q(s_t, a_t) + \\alpha\\delta_t$"
   ]
  },
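   {
    "cell_type": "markdown",
    "id": "1b2c3d4e-5f6a-4b7c-8d9e-234567890abc",
    "metadata": {},
    "source": [
     "As with the TD agent, the implementation below uses $\\alpha = 1$ and $\\gamma = 0.99$, so the update reduces to $Q(s_t, a_t) \\leftarrow r_t + \\gamma\\max_{a\\in A}Q(s_{t+1}, a)$."
    ]
   },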
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2357a12a-ad42-43b3-8909-e17f4159be69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def QAgent(env, **kwargs):\n",
    "    eps = kwargs.get('eps', 1)\n",
    "    statehistory = kwargs.get('statehistory', [])\n",
    "    repeatedstate = kwargs.get('repeatedstate', False)\n",
    "    \n",
    "    GAMMA = 0.99\n",
    "    ALPHA = 1\n",
    "    \n",
    "    curstate = env.staterep()\n",
    "    if repeatedstate:\n",
    "        curstate += str(statehistory.count(env.staterep()))\n",
    "    \n",
    "    if env.reward == 1:\n",
    "        for move in env.getvalidmoves():\n",
    "            memory[(curstate, move)] = 1\n",
    "        return\n",
    "\n",
    "    validmoves = env.getvalidmoves()\n",
    "    bestmove = None\n",
    "    bestvalue = -1\n",
    "    \n",
    "    # if epsilon, then choose randomly\n",
    "    if eps > np.random.rand():\n",
    "        bestmove = env.sample()\n",
    "    # else choose best move\n",
    "    else:\n",
    "        for move in validmoves:\n",
    "            curvalue = memory[(curstate, move)] \n",
    "            if curvalue > bestvalue:\n",
    "                bestvalue = curvalue\n",
    "                bestmove = move\n",
    "        \n",
    "    # do a one-step in the next direction\n",
    "    newenv = copy.deepcopy(env)\n",
    "    newenv.step(bestmove)\n",
    "    nextstate = newenv.staterep()\n",
    "    if repeatedstate:\n",
    "        nextstate += str(statehistory.count(newenv.staterep())+1)\n",
    "    td_error = env.reward + GAMMA*np.max([memory[(nextstate, move)] for move in newenv.getvalidmoves()]) - memory[(curstate, bestmove)]\n",
    "        \n",
    "    # update the Q-function\n",
    "    memory[(curstate, bestmove)] = memory[(curstate, bestmove)] + ALPHA*td_error\n",
    "    \n",
    "    return bestmove"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "934f2d62-11b0-4150-8a76-14039e20c963",
   "metadata": {},
   "source": [
    "## Agent 4: Go-Explore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59265644-3b53-4894-8651-7e34b9b5d1cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reward_formula(reward = 0, intrinsic = 0, moves = 0, numselected = 0, numvisits = 0, eps = 1e-20):\n",
    "    return reward*1000 + intrinsic*10 + np.sqrt(moves) - 100*np.sqrt(numselected+numvisits)"
   ]
  },
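   {
    "cell_type": "markdown",
    "id": "2c3d4e5f-6a7b-4c8d-9e0f-34567890abcd",
    "metadata": {},
    "source": [
     "A quick sanity check of the heuristic with arbitrary illustrative values (not taken from the paper's experiments): a rewarding state dominates through the reward term, while a frequently visited or selected state is penalized heavily."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "3d4e5f6a-7b8c-4d9e-8f0a-4567890abcde",
    "metadata": {},
    "outputs": [],
    "source": [
     "# arbitrary illustrative values, not from the paper's experiments\n",
     "print(reward_formula(reward = 1, intrinsic = 0, moves = 10, numselected = 0, numvisits = 0))     # ~1003.2\n",
     "print(reward_formula(reward = 0, intrinsic = -0.5, moves = 10, numselected = 2, numvisits = 3))  # ~-225.4"
    ]
   },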
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cbb6722-3ca7-43be-870a-8b201122320b",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Chooses the next best memory state to go to '''\n",
    "def ChooseState(env):\n",
    "    bestvalue = -1e20\n",
    "    bestkey = None\n",
    "    \n",
    "    # choose the best memory based on heuristics\n",
    "    for key in memory:\n",
    "        # do not choose final state as there is nothing left to explore\n",
    "        if memory[key]['reward'] == 1: \n",
    "            continue\n",
    "        \n",
    "        curmem = memory[key]\n",
    "        reward = curmem['reward']\n",
    "        intrinsic = curmem['intrinsic']\n",
    "        moves = curmem['moves']\n",
    "        numselected = curmem.get('numselected', 0)\n",
    "        numvisits = curmem['numvisits']\n",
    "        \n",
    "        value = reward_formula(reward = reward, intrinsic = intrinsic, moves = moves, numselected = numselected, numvisits = numvisits)\n",
    "        if value > bestvalue:\n",
    "            bestvalue = value\n",
    "            bestkey = key\n",
    "    \n",
    "    # generate the trajectory to get the environment state\n",
    "    actionhistory = []\n",
    "    statehistory = []\n",
    "    \n",
    "    for move in memory[bestkey]['actionhistory']:\n",
    "        statehistory.append(env.staterep())\n",
    "        env.step(move)\n",
    "        actionhistory.append(move)\n",
    "        \n",
    "    # increment the selection visit count\n",
    "    if 'numselected' in memory[bestkey]:\n",
    "        memory[bestkey]['numselected'] = memory[bestkey]['numselected'] + 1\n",
    "    \n",
    "    return (actionhistory, statehistory, copy.deepcopy(env))\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "897adc24-76f0-4c3f-8e4b-cd8b6cffd26e",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Chooses the best move based on memory and intrinsic rewards '''\n",
    "def GoExplore(env, **kwargs):\n",
    "    \n",
    "    intrinsic_fn = kwargs.get('intrinsic_fn', None)\n",
    "    replay = kwargs.get('replay', False)\n",
    "    getbestmove = kwargs.get('getbestmove', False)\n",
    "    statehistory = kwargs.get('statehistory', [])\n",
    "    actionhistory = kwargs.get('actionhistory', [])\n",
    "    repeatedstate = kwargs.get('repeatedstate', False)\n",
    "    \n",
    "    # if no intrinsic guiding value, then do without intrinsic motivation\n",
    "    if intrinsic_fn is not None:\n",
    "        intrinsic_value = intrinsic_fn(env)\n",
    "    else:\n",
    "        intrinsic_value = 0\n",
    "        \n",
    "    curmoves = env.numsteps\n",
    "    curreward = env.reward\n",
    "\n",
    "    curstate = env.staterep()\n",
    "    if repeatedstate:\n",
    "        curstate += str(statehistory.count(env.staterep()))\n",
    "        \n",
    "    # if this state is not present in memory (should only happen for start state), add it in\n",
    "    if curstate not in memory:\n",
    "        memory[curstate] = {'statehistory': statehistory+[], 'reward': curreward, 'intrinsic': intrinsic_value, 'moves': curmoves, 'numvisits': 0, 'numselected': 0, 'actionhistory': actionhistory+[]}\n",
    "        \n",
    "    curmemory = memory[curstate]\n",
    "    \n",
    "    # only increment memory if not doing replay\n",
    "    if replay:\n",
    "        curmemory['numvisits'] = 0\n",
    "        curmemory['numselected'] = 0\n",
    "    else:\n",
    "        curmemory['numvisits'] = curmemory['numvisits'] + 1\n",
    "    \n",
    "    # if completed, no need to continue to next move selection\n",
    "    if env.done:\n",
    "        # if there is positive reward, then make intrinsic become the extrinsic reward\n",
    "        if env.reward > 0:\n",
    "            curmemory['intrinsic'] = env.reward\n",
    "        return\n",
    "\n",
    "    # if not completed, continue to select next move\n",
    "    validmoves = env.getvalidmoves()\n",
    "    \n",
    "    # if no valid moves, no need to continue to next move selection\n",
    "    if validmoves == []:\n",
    "        return\n",
    "    \n",
    "    bestmove = None\n",
    "    bestvalue = -1e20\n",
    "    bestintrinsic = -1e20\n",
    "    \n",
    "    # choose best move\n",
    "    for move in validmoves:\n",
    "        newenv = copy.deepcopy(env)\n",
    "        newenv.step(move)\n",
    "        \n",
    "        nextmoves = newenv.numsteps\n",
    "        nextreward = newenv.reward\n",
    "        nextmemory = None\n",
    "        \n",
    "        nextstate = newenv.staterep()\n",
    "        if repeatedstate:\n",
    "            nextstate += str(statehistory.count(newenv.staterep())+1)\n",
    "        \n",
    "        if nextstate in memory:\n",
    "            nextmemory = memory[nextstate] \n",
    "            # update the nextmemory if agent has a better reward\n",
    "            if nextreward > nextmemory['reward']:\n",
    "                nextmemory['reward'] = nextreward\n",
    "                nextmemory['moves'] = curmoves + 1\n",
    "                nextmemory['numvisits'] = 0\n",
    "                nextmemory['numselected'] = 0\n",
    "                nextmemory['actionhistory'] = actionhistory+[move]\n",
    "                nextmemory['statehistory'] = statehistory+[newenv.staterep()]\n",
    "                \n",
    "            # update the nextmemory if agent has similar reward but fewer number of moves\n",
    "            elif nextreward == nextmemory['reward'] and nextmemory['moves'] > curmoves + 1:\n",
    "                nextmemory['moves'] = curmoves + 1\n",
    "                nextmemory['numvisits'] = 0\n",
    "                nextmemory['numselected'] = 0\n",
    "                nextmemory['actionhistory'] = actionhistory+[move]\n",
    "                nextmemory['statehistory'] = statehistory+[newenv.staterep()]\n",
    "                \n",
    "        # start a new memory if this is a new state\n",
    "        else:\n",
    "            # if no intrinsic guiding value, then do without intrinsic motivation\n",
    "            if intrinsic_fn is not None:\n",
    "                next_intrinsic_value = intrinsic_fn(newenv)\n",
    "            else:\n",
    "                next_intrinsic_value = 0\n",
    "            memory[nextstate] = {'statehistory': statehistory+[newenv.staterep()], 'reward': nextreward, 'intrinsic': next_intrinsic_value, 'moves': curmoves + 1, 'numvisits': 0, 'numselected': 0, 'actionhistory': actionhistory+[move]}\n",
    "            nextmemory = memory[nextstate]\n",
    "                \n",
    "        # best intrinsic is the highest intrinsic value of all 1-step connections\n",
    "        bestintrinsic = max(bestintrinsic, nextmemory['intrinsic'])\n",
    "        \n",
    "        # determine the next square to visit via a heuristic\n",
    "        reward = nextmemory['reward'] \n",
    "        moves = nextmemory['moves'] \n",
    "        numvisits = nextmemory['numvisits']\n",
    "        intrinsic = nextmemory['intrinsic']\n",
    "        \n",
    "        totalvalue = reward_formula(reward = reward, intrinsic = intrinsic, moves = moves, numselected = 0, numvisits = numvisits)\n",
    "      \n",
    "        if totalvalue > bestvalue or bestvalue is None:\n",
    "            bestvalue = totalvalue\n",
    "            bestmove = move\n",
    "            \n",
    "    # update the one-step lookahead for intrinsic value\n",
    "    curmemory['intrinsic'] = bestintrinsic*0.99\n",
    "    \n",
    "    if getbestmove:\n",
    "        return bestmove\n",
    "    else:\n",
    "        return np.random.choice(validmoves)"
   ]
  },
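   {
    "cell_type": "markdown",
    "id": "4e5f6a7b-8c9d-4e0f-9a1b-567890abcdef",
    "metadata": {},
    "source": [
     "In summary, each call to `GoExplore` (1) updates the memory entry for the current state, (2) scores every valid one-step successor with `reward_formula`, (3) backs up the best successor's intrinsic value to the current state with a 0.99 discount, and (4) returns the best move when `getbestmove = True` (used by `GoExploreCount` below), otherwise a uniformly random valid move."
    ]
   },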
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9771605f-ad3f-413b-8680-196b06bb2fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Chooses Best Move '''\n",
    "def GoExploreCount(env, **kwargs):\n",
    "    return GoExplore(env, getbestmove = True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d04090b3-345e-471b-ad65-264d1f070865",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Agent 5: Count Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88b2d20c-e678-40ad-8ede-930afe969646",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Chooses the best move based on memory and intrinsic rewards '''\n",
    "def CountAgent(env, **kwargs):\n",
    "    \n",
    "    intrinsic_fn = kwargs.get('intrinsic_fn', None)\n",
    "    replay = kwargs.get('replay', False)\n",
    "    statehistory = kwargs.get('statehistory', [])\n",
    "    actionhistory = kwargs.get('actionhistory', [])\n",
    "    repeatedstate = kwargs.get('repeatedstate', False)\n",
    "    \n",
    "    # if no intrinsic guiding value, then do without intrinsic motivation\n",
    "    if intrinsic_fn is not None:\n",
    "        intrinsic_value = intrinsic_fn(env)\n",
    "    else:\n",
    "        intrinsic_value = 0\n",
    "        \n",
    "    curmoves = env.numsteps\n",
    "    curreward = env.reward\n",
    "\n",
    "    curstate = env.staterep()\n",
    "    if repeatedstate:\n",
    "        curstate += str(statehistory.count(env.staterep()))\n",
    "\n",
    "    # if this state is not present in memory (should only happen for start state), add it in\n",
    "    if curstate not in memory:\n",
    "        memory[curstate] = {'statehistory': statehistory+[], 'reward': curreward, 'intrinsic': intrinsic_value, 'moves': curmoves, 'numvisits': 0, 'actionhistory': actionhistory+[]}\n",
    "        \n",
    "    curmemory = memory[curstate]\n",
    "    \n",
    "    # only increment memory if not doing replay\n",
    "    if replay:\n",
    "        curmemory['numvisits'] = 0\n",
    "    else:\n",
    "        curmemory['numvisits'] = curmemory['numvisits'] + 1\n",
    "\n",
    "    # if completed, no need to continue to next move selection\n",
    "    if env.done:\n",
    "        if env.reward > 0:\n",
    "            curmemory['intrinsic'] = env.reward\n",
    "        return\n",
    "\n",
    "    # if not completed, continue to select next move\n",
    "    validmoves = env.getvalidmoves()\n",
    "    \n",
    "    # if no valid moves, no need to continue to next move selection\n",
    "    if validmoves == []:\n",
    "        return\n",
    "    \n",
    "    bestmove = None\n",
    "    bestvalue = -1e20\n",
    "    bestintrinsic = -1e20\n",
    "    \n",
    "    # choose best move\n",
    "    for move in validmoves:\n",
    "        newenv = copy.deepcopy(env)\n",
    "        newenv.step(move)\n",
    "        \n",
    "        nextmoves = newenv.numsteps\n",
    "        nextreward = newenv.reward\n",
    "        nextmemory = None\n",
    "        \n",
    "        nextstate = newenv.staterep()\n",
    "        if repeatedstate:\n",
    "            nextstate += str(statehistory.count(newenv.staterep())+1)\n",
    "        \n",
    "        if nextstate in memory:\n",
    "            nextmemory = memory[nextstate] \n",
    "            # update the nextmemory if agent has a better reward\n",
    "            if nextreward > nextmemory['reward']:\n",
    "                nextmemory['reward'] = nextreward\n",
    "                nextmemory['moves'] = curmoves + 1\n",
    "                nextmemory['numvisits'] = 0\n",
    "                nextmemory['actionhistory'] = actionhistory+[move]\n",
    "                nextmemory['statehistory'] = statehistory+[newenv.staterep()]\n",
    "                \n",
    "            # update the nextmemory if agent has similar reward but fewer number of moves\n",
    "            elif nextreward == nextmemory['reward'] and nextmemory['moves'] > curmoves + 1:\n",
    "                nextmemory['moves'] = curmoves + 1\n",
    "                nextmemory['numvisits'] = 0\n",
    "                nextmemory['actionhistory'] = actionhistory+[move]\n",
    "                nextmemory['statehistory'] = statehistory+[newenv.staterep()]\n",
    "                \n",
    "        # start a new memory if this is a new state\n",
    "        else:\n",
    "            # if no intrinsic guiding value, then do without intrinsic motivation\n",
    "            if intrinsic_fn is not None:\n",
    "                next_intrinsic_value = intrinsic_fn(newenv)\n",
    "            else:\n",
    "                next_intrinsic_value = 0\n",
    "            memory[nextstate] = {'statehistory': statehistory+[newenv.staterep()], 'reward': nextreward, 'intrinsic': next_intrinsic_value, 'moves': curmoves + 1, 'numvisits': 0, 'actionhistory': actionhistory+[move]}\n",
    "            nextmemory = memory[nextstate]\n",
    "            \n",
    "        # best intrinsic is the highest intrinsic value of all 1-step connections\n",
    "        bestintrinsic = max(bestintrinsic, nextmemory['intrinsic'])\n",
    "        \n",
    "        # determine the next square to visit via a heuristic\n",
    "        reward = nextmemory['reward'] \n",
    "        moves = nextmemory['moves'] \n",
    "        numvisits = nextmemory['numvisits']\n",
    "        intrinsic = nextmemory['intrinsic']\n",
    "        \n",
    "        totalvalue = reward_formula(reward = reward, intrinsic = intrinsic, moves = moves, numselected = 0, numvisits = numvisits)\n",
    "        \n",
    "        if totalvalue > bestvalue:\n",
    "            bestvalue = totalvalue\n",
    "            bestmove = move\n",
    "            \n",
    "        # update the one-step lookahead for intrinsic value\n",
    "        curmemory['intrinsic'] = bestintrinsic * 0.99\n",
    "    \n",
    "    return bestmove"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00e72846-7e64-4095-87f1-9dc62c60e161",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Helper Functions\n",
    "These functions help to perform hippocampal replay, and evaluation of agent on the environment.\n",
    "- MemoryReplay: Implements hippocampal replay\n",
    "- Game: Plays an environment for a single run (episode)\n",
    "- MultipleGame: Plays an environment for 100 runs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cef4c0eb-0b3f-4269-939e-4bf8098ae783",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Memory Replay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700ccfe8-a2e3-4269-a007-d5375b96359a",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Performs memory replay '''\n",
    "def MemoryReplay(env, bestactionhistory = [], agent = RandomAgent, maxsteps = 500, seed = None, **kwargs):\n",
    "    statehistory = []\n",
    "    actionhistory = []\n",
    "    statehistory.append(env.staterep())\n",
    "    historytuplelist = []\n",
    "    \n",
    "    # do forward replay\n",
    "    for move in bestactionhistory:\n",
    "        kwargs['statehistory'] = statehistory        \n",
    "        kwargs['actionhistory'] = actionhistory\n",
    "        kwargs['replay']=True\n",
    "        historytuplelist.append((copy.deepcopy(env), statehistory, actionhistory))\n",
    "        env.step(move)     \n",
    "        statehistory.append(env.staterep())\n",
    "        agent(copy.deepcopy(env), **kwargs)\n",
    "        actionhistory.append(move)\n",
    "    \n",
    "    # # do backward replay\n",
    "    backwardstates = []\n",
    "    for env, statehistory, actionhistory in historytuplelist[::-1]:\n",
    "        kwargs['statehistory'] = statehistory        \n",
    "        kwargs['actionhistory'] = actionhistory\n",
    "        kwargs['replay'] = True\n",
    "        agent(copy.deepcopy(env), **kwargs)\n",
    "        backwardstates.append(env.staterep())"
   ]
  },
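   {
    "cell_type": "markdown",
    "id": "5f6a7b8c-9d0e-4f1a-8b2c-67890abcdef0",
    "metadata": {},
    "source": [
     "`MemoryReplay` walks the winning trajectory forward, calling the agent on a snapshot of each state with `replay = True` (which resets visit counts instead of incrementing them), then presents the same snapshots in reverse order, mimicking forward and backward hippocampal replay."
    ]
   },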
  {
   "cell_type": "markdown",
   "id": "0d343a09-3981-4d88-8b96-f682b90fe176",
   "metadata": {
    "tags": []
   },
   "source": [
    "## A Single Game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be19f4d-94f4-4480-9b0f-53b716732409",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Plays 1 game '''\n",
    "def Game(env, agent = RandomAgent, actionhistory = [], statehistory = [], maxsteps = 500, seed = None, verbose = True, **kwargs):\n",
    "    \n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "    else:\n",
    "        np.random.seed(0)\n",
    "    while not env.done and env.numsteps < maxsteps:\n",
    "        statehistory.append(env.staterep())\n",
    "        kwargs['statehistory'] = statehistory        \n",
    "        kwargs['actionhistory'] = actionhistory\n",
    "        move = agent(env, **kwargs)\n",
    "        env.step(move)\n",
    "        actionhistory.append(move)\n",
    "            \n",
    "    statehistory.append(env.staterep())\n",
    "    kwargs['statehistory'] = statehistory\n",
    "    kwargs['actionhistory'] = actionhistory\n",
    "    # to update final state for RL agents\n",
    "    agent(env, **kwargs)\n",
    "    \n",
    "    if verbose:\n",
    "        print(env.done, env.reward, env.numsteps)\n",
    "        # env.print()\n",
    "    return env.done, env.reward, env.numsteps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c852e681-f332-48f4-ba3b-f2df2e8f5942",
   "metadata": {},
   "source": [
    "## Multiple Games"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7ffe41-7222-4308-8070-1c4f5a0e6b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Plays multiple games '''\n",
    "def MultiGame(env, numtries = 100, hippocampal_replay = True, **kwargs):\n",
    "    solvedcount = 0\n",
    "    stephistory = []\n",
    "    bestmemory = 0\n",
    "    beststeps = 1000000\n",
    "    firstsolve = None\n",
    "    tries = numtries\n",
    "    solved = False\n",
    "    \n",
    "    for i in range(tries):\n",
    "        # choose a new state for GoExplore or GoExploreCount\n",
    "        if kwargs['agent'] in [GoExplore, GoExploreCount] and i > 0:\n",
    "            actionhistory, statehistory, nextenv = ChooseState(env = copy.deepcopy(env))\n",
    "            done, reward, steps = Game(seed = i, env = copy.deepcopy(nextenv), actionhistory = actionhistory+[], statehistory = statehistory+[], **kwargs)\n",
    "        else:\n",
    "            done, reward, steps = Game(seed = i, env = copy.deepcopy(env), actionhistory = [], statehistory = [], **kwargs)\n",
    "        if reward == 1:\n",
    "            solvedcount += 1\n",
    "            stephistory.append(steps)\n",
    "            \n",
    "            # if first solve, note how much memory is used\n",
    "            if solvedcount == 1:\n",
    "                bestmemory = len(memory)\n",
    "                firstsolve = i+1\n",
    "                \n",
    "            if hippocampal_replay:\n",
    "                # hippocampal replay only for goexplore or intrinsic agent\n",
    "                if kwargs['agent'] in [GoExplore, GoExploreCount, CountAgent]:\n",
    "                    actionhistory = None\n",
    "                    for key, value in memory.items():\n",
    "                        if memory[key]['reward'] == 1:\n",
    "                            actionhistory = memory[key]['actionhistory']\n",
    "\n",
    "                    # MemoryReplay to improve chance of optimal path being followed\n",
    "                    if actionhistory is not None:\n",
    "                        MemoryReplay(env = copy.deepcopy(env), bestactionhistory = actionhistory, **kwargs)\n",
    "    name = kwargs['agent'].__name__\n",
    "    if name == 'RandomAgent': \n",
    "        name = 'Random'\n",
    "        bestmemory = '-'\n",
    "    if name == 'QAgent': name = 'Q-Learning'\n",
    "    if name == 'TDAgent': name = 'TD-Learning'\n",
    "    if name == 'GoExplore': name = 'Go-Explore'\n",
    "    if name == 'GoExploreCount': name = 'Go-Explore-Count'\n",
    "    if name == 'CountAgent': name = 'Explore-Count'\n",
    "\n",
    "    if kwargs['agent'] == QAgent or kwargs['agent'] == TDAgent:\n",
    "        if kwargs.get('eps', 1) == 0:\n",
    "            name += ' (Test)'\n",
    "        else:\n",
    "            name += ' (Train)'\n",
    "            \n",
    "    if kwargs.get('intrinsic_fn', None) is not None:\n",
    "        name += ' GDIR'\n",
    "    # if solvedcount == 0:\n",
    "    #     print(f\"{name} & {solvedcount}/{tries} & - & - & - & - & - \\\\\\\\\")\n",
    "    # else:\n",
    "    #     print(f\"{name} & {solvedcount}/{tries} & {firstsolve} & {bestmemory} & {sum(stephistory)/len(stephistory):.1f} & {min(stephistory):.1f} & {max(stephistory):.1f} \\\\\\\\\")\n",
    "    if solvedcount == 0:\n",
    "        print(f'Agent: {name}, No solves at all, Best Memory: {bestmemory}, Total Memory: {len(memory)}')\n",
    "    else:\n",
    "        print(f'Agent: {name}, Solve rate: {solvedcount}/{tries} ({solvedcount/tries*100:.1f}%), First Solve: {firstsolve}, Best Memory: {bestmemory}, Steps: Avg {sum(stephistory)/len(stephistory):.1f}, Min {min(stephistory):.1f}, Max {max(stephistory):.1f}')"
   ]
  },
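   {
    "cell_type": "markdown",
    "id": "6a7b8c9d-0e1f-4a2b-9c3d-7890abcdef01",
    "metadata": {},
    "source": [
     "`MultiGame` runs `numtries` episodes. For `GoExplore` and `GoExploreCount`, every episode after the first restarts from the memory state chosen by `ChooseState`. After each solve it optionally performs hippocampal replay of the stored winning trajectory, and at the end it prints the solve rate, first-solve episode, memory size at first solve, and step statistics."
    ]
   },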
  {
   "cell_type": "markdown",
   "id": "daca7925-e680-4bab-94d0-7392c8c79abe",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Discrete Environments 1 & 2 - Maze Environment (Unwalled, Walled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bfc749b-97af-48a2-a6ff-c4bdd1bd2eb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MazeEnv:\n",
    "    def __init__(self, height=20, width=20, numbricks = 10, grid = None, doorpos = None, agentpos = None, randomseed = None):\n",
    "        self.height = height\n",
    "        self.width = width\n",
    "        self.doorpos = doorpos\n",
    "        self.agentpos = agentpos\n",
    "        self.numbricks = numbricks\n",
    "        self.randomseed = randomseed\n",
    "        self.numsteps = 0\n",
    "        self.done = False\n",
    "        self.reward = 0\n",
    "        if self.randomseed is not None:\n",
    "            np.random.seed(self.randomseed)\n",
    "        self.mapping = {0: '.', 1: 'X', 2: 'D', 3: '#'}\n",
    "        \n",
    "        # if grid not defined, do a random initialization of maze\n",
    "        if grid is None:\n",
    "            self.grid = np.zeros((self.height, self.width))\n",
    "            \n",
    "            # Step 1: get a door position that is valid\n",
    "            if doorpos is None:\n",
    "                self.doorpos = self.getvalidpos()\n",
    "            else:\n",
    "                self.doorpos = doorpos\n",
    "            self.grid[self.doorpos] = 2\n",
    "\n",
    "            # Step 2: get a start position that is valid\n",
    "            if agentpos is None:\n",
    "                self.agentpos = self.getvalidpos()\n",
    "            else:\n",
    "                self.agentpos = agentpos\n",
    "            self.grid[self.agentpos] = 1\n",
    "            \n",
    "            # Step 3: fill in the bricks\n",
    "            for i in range(self.numbricks):\n",
    "                self.grid[self.getvalidpos()] = 3\n",
    "                \n",
    "        # if grid predefined, get the parameters from there instead\n",
    "        else:\n",
    "            self.grid = grid\n",
    "            self.height, self.width = self.grid.shape\n",
    "            \n",
    "            lista, listb = np.where(self.grid == 2)\n",
    "            if len(lista) == 0 or len(listb) == 0:\n",
    "                self.doorpos = self.getvalidpos()\n",
    "            else:\n",
    "                self.doorpos = (lista[0], listb[0])\n",
    "            self.grid[self.doorpos] = 2\n",
    "            \n",
    "            lista, listb = np.where(self.grid == 1)\n",
    "            if len(lista) == 0 or len(listb) == 0:\n",
    "                self.agentpos = self.getvalidpos()\n",
    "            else:\n",
    "                self.agentpos = (lista[0], listb[0])\n",
    "            self.grid[self.agentpos] = 1\n",
    "            \n",
    "        # some variables to reset the environment\n",
    "        self.startgrid = self.grid.copy()\n",
    "        self.startagentpos = self.agentpos\n",
    "        self.startdoorpos = self.doorpos\n",
    "            \n",
    "    def reset(self):\n",
    "        self.grid = self.startgrid.copy()\n",
    "        self.agentpos = self.startagentpos\n",
    "        self.doorpos = self.startdoorpos\n",
    "        self.done = False\n",
    "        self.reward = 0\n",
    "        self.numsteps = 0\n",
    "        if self.randomseed is not None:\n",
    "            np.random.seed(self.randomseed)\n",
    "            \n",
    "    # gets state representation\n",
    "    def staterep(self):\n",
    "        return str(self.agentpos)\n",
    "            \n",
    "    # gets a valid position\n",
    "    def getvalidpos(self):\n",
    "        validpos = []\n",
    "        for i in range(self.height):\n",
    "            for j in range(self.width):\n",
    "                if self.grid[i,j] == 0:\n",
    "                    validpos.append((i,j))\n",
    "        return validpos[np.random.randint(len(validpos))]\n",
    "        \n",
    "    # checks if a position is valid that is not out of the grid and not occupied\n",
    "    def isvalid(self, pos, allowdoor = False):\n",
    "        if pos == None or len(pos)!=2:\n",
    "            return False\n",
    "        height, width = pos\n",
    "        if height < 0 or height >= self.height or width < 0 or width >= self.width:\n",
    "            return False\n",
    "        if allowdoor and self.grid[height,width] == 2:\n",
    "            return True\n",
    "        if self.grid[height,width] == 0:\n",
    "            return True\n",
    "        return False\n",
    "    \n",
    "    def step(self, move):\n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        self.grid[self.agentpos] = 0\n",
    "        self.agentpos = self.movedir(self.agentpos, move)\n",
    "        if self.agentpos == self.doorpos:\n",
    "            self.done = True\n",
    "            self.reward = 1\n",
    "        self.grid[self.agentpos] = 1\n",
    "    \n",
    "    def movedir(self, pos, d):\n",
    "        if pos == None or len(pos)!=2:\n",
    "            return False\n",
    "        height, width = pos\n",
    "        if d=='left':\n",
    "            return (height, width-1)\n",
    "        elif d=='right':\n",
    "            return (height, width+1)\n",
    "        elif d=='up':\n",
    "            return (height-1, width)\n",
    "        elif d=='down':\n",
    "            return (height+1, width)\n",
    "    \n",
    "    def getvalidmoves(self):\n",
    "        validmoves = []\n",
    "        for move in ['left', 'right', 'up', 'down']:\n",
    "            if self.isvalid(self.movedir(self.agentpos, move), allowdoor = True):\n",
    "                validmoves.append(move)\n",
    "        return validmoves\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        for i in range(self.height):\n",
    "            for j in range(self.width):\n",
    "                print(self.mapping[self.grid[i,j]], end = '')\n",
    "            print()"
   ]
  },
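   {
    "cell_type": "markdown",
    "id": "7b8c9d0e-1f2a-4b3c-8d4e-890abcdef012",
    "metadata": {},
    "source": [
     "A quick smoke test (illustrative only, not part of the paper's experiments): a single random-agent run on a hypothetical 5x5 maze with no bricks."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "8c9d0e1f-2a3b-4c4d-9e5f-90abcdef0123",
    "metadata": {},
    "outputs": [],
    "source": [
     "# illustrative smoke test on a hypothetical 5x5 brickless maze\n",
     "memory = defaultdict(lambda: 0)\n",
     "smoke_env = MazeEnv(height = 5, width = 5, agentpos = (0, 0), doorpos = (4, 4), randomseed = 0, numbricks = 0)\n",
     "smoke_env.print()\n",
     "Game(smoke_env, agent = RandomAgent, maxsteps = 100)"
    ]
   },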
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4907ded4-bf6b-4a23-8b2e-95dddf9f4e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Manhattan(env):\n",
    "    ''' Calculates the Manhattan distance between the agent and the door '''\n",
    "    pointA = env.agentpos\n",
    "    pointB = env.doorpos\n",
    "    return -(abs(pointA[0]-pointB[0])+abs(pointA[1]-pointB[1]))/(env.width+env.height-2)"
   ]
  },
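   {
    "cell_type": "markdown",
    "id": "9d0e1f2a-3b4c-4d5e-8f6a-0abcdef01234",
    "metadata": {},
    "source": [
     "An illustrative check on a hypothetical 5x5 maze: the agent starts at the corner opposite the door, so the normalized distance is $-1.0$; the value approaches $0$ as the agent nears the door."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "a0b1c2d3-4e5f-4a6b-8c7d-bcdef0123456",
    "metadata": {},
    "outputs": [],
    "source": [
     "# hypothetical 5x5 maze for illustration\n",
     "demo_env = MazeEnv(height = 5, width = 5, agentpos = (0, 0), doorpos = (4, 4), randomseed = 0, numbricks = 0)\n",
     "print(Manhattan(demo_env))  # -1.0"
    ]
   },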
  {
   "cell_type": "markdown",
   "id": "94c2cd90-ca17-4b85-8b1d-4993e73bfd83",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Unwalled maze (10x10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b99f46-6c75-401a-a4c4-4a05f15c7abe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is how the maze looks like\n",
    "height, width = 10, 10\n",
    "env = MazeEnv(height = height, width = width, agentpos = (0, 0), doorpos = (height-1, width-1), randomseed = 1, numbricks = height*width//10)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c79753b3-35ee-4131-9a66-74a570eb1bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3906a13-114b-4dfa-9330-298a1ecf5430",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Unwalled maze (20x20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fa0f838-6196-4dd3-a830-26ff49bbcc2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is how the maze looks like\n",
    "height, width = 20, 20\n",
    "env = MazeEnv(height = height, width = width, agentpos = (0, 0), doorpos = (height-1, width-1), randomseed = 1, numbricks = height*width//10)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7129fb89-8201-4c0e-8e31-9ba26e7f3135",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "422c28c5-f847-473d-98e2-f8d32902e78f",
   "metadata": {},
   "source": [
    "## Unwalled maze (100x100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7d3007-7931-4ac0-9de6-9064d9c265ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is how the maze looks like\n",
    "height, width = 100, 100\n",
    "env = MazeEnv(height = height, width = width, agentpos = (0, 0), doorpos = (height-1, width-1), randomseed = 1, numbricks = height*width//10)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67f416d9-bec1-444d-863a-b893ef4b96ae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7c81719-3778-495c-8f5c-31adc521b217",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Extra: Without hippocampal replay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1046f337-03da-4abd-b157-97862221a2f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3b91b2c-2960-4b18-8fcd-808902e97a50",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Walled maze (10x10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "644c7e52-d8f5-4309-a450-965af5e2f971",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create game environment\n",
    "size = 10\n",
    "grid = np.zeros((size,size))\n",
    "maxheight, maxwidth = grid.shape\n",
    "grid[:, maxwidth//2-1] = 3\n",
    "grid[maxheight//2,:] = 3\n",
    "grid[1:maxheight, maxwidth-2] = 3\n",
    "grid[maxheight//2, maxwidth//4] = 0\n",
    "grid[maxheight//2, 3*maxwidth//4-1] = 0\n",
    "grid[3*maxheight//4, maxwidth//2-1] = 0\n",
    "grid[maxheight//2, maxwidth-1] = 0\n",
    "grid[0, 0] = 1\n",
    "grid[maxheight-1, maxwidth-1] = 2\n",
    "env = MazeEnv(grid = grid)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91f53743-83f1-46b5-bbea-b950e2738e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "103911ef-dd19-4146-8457-9116c2c352e6",
   "metadata": {},
   "source": [
    "## Walled maze (20x20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b2f7559-5cc6-4556-8a8b-543cbe8e6cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create game environment\n",
    "size = 20\n",
    "grid = np.zeros((size,size))\n",
    "maxheight, maxwidth = grid.shape\n",
    "grid[:, maxwidth//2-1] = 3\n",
    "grid[maxheight//2,:] = 3\n",
    "grid[1:maxheight, maxwidth-2] = 3\n",
    "grid[maxheight//2, maxwidth//4] = 0\n",
    "grid[maxheight//2, 3*maxwidth//4-1] = 0\n",
    "grid[3*maxheight//4, maxwidth//2-1] = 0\n",
    "grid[maxheight//2, maxwidth-1] = 0\n",
    "grid[0, 0] = 1\n",
    "grid[maxheight-1, maxwidth-1] = 2\n",
    "env = MazeEnv(grid = grid)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3fdd4b-b771-4934-9d74-88c8639a6763",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "560e37e2-aaea-416b-8d9f-652da51aa602",
   "metadata": {},
   "source": [
    "## Walled maze (100x100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db7318b-2126-4ea0-96b6-9ef2715adc49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create game environment\n",
    "size = 100\n",
    "grid = np.zeros((size,size))\n",
    "maxheight, maxwidth = grid.shape\n",
    "grid[:, maxwidth//2-1] = 3\n",
    "grid[maxheight//2,:] = 3\n",
    "grid[1:maxheight, maxwidth-2] = 3\n",
    "grid[maxheight//2, maxwidth//4] = 0\n",
    "grid[maxheight//2, 3*maxwidth//4-1] = 0\n",
    "grid[3*maxheight//4, maxwidth//2-1] = 0\n",
    "grid[maxheight//2, maxwidth-1] = 0\n",
    "grid[0, 0] = 1\n",
    "grid[maxheight-1, maxwidth-1] = 2\n",
    "env = MazeEnv(grid = grid)\n",
    "env.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c2936b-e682-4479-bbff-8657b3f2725d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8a45987-56db-40f1-bfb4-5f2e1aaa2005",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Extra: Without hippocampal replay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee761593-20e0-42ee-887c-f655e5f8a139",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, hippocampal_replay = False, verbose = False, intrinsic_fn = Manhattan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c369ecf-3c18-4f5f-b903-41a5ac270ce3",
   "metadata": {},
   "source": [
    "## Extra: Random Intrinsic Rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9018819-2b85-4195-b886-3c74cddfc6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intrinsic Agent with random intrinsic rewards\n",
    "agent = CountAgent\n",
    "numtrials = 10\n",
    "\n",
    "for reward in [1, 5, 10, 20, 50]:\n",
    "    for j in range(numtrials):\n",
    "        np.random.seed(j)\n",
    "        # print(agent.__name__, 'with random reward of', reward)\n",
    "        memory = defaultdict(lambda: 0)\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.height*env.width, verbose = False, intrinsic_fn = lambda x: (-np.random.rand(numtrials)*reward)[j])\n",
    "        \n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c28feff-2c04-4a1f-bdba-0605f1c5374b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Discrete Environment 3 - Tower of Hanoi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edd65c51-59da-45b8-a1eb-0e009efca158",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HanoiEnv:\n",
    "    def __init__(self, numplates = 3, grid = None):\n",
    "        self.numplates = numplates\n",
    "        self.numsteps = 0\n",
    "        self.done = False\n",
    "        self.reward = 0\n",
    "        # if grid not defined, then set to original position\n",
    "        if grid is None:\n",
    "            self.grid = [list(range(1, self.numplates+1)), [], []]\n",
    "                \n",
    "        # if grid predefined, get the parameters from there instead\n",
    "        else:\n",
    "            self.grid = grid\n",
    "            \n",
    "        # some variables to reset the environment\n",
    "        self.startgrid = self.grid.copy()\n",
    "            \n",
    "    def reset(self):\n",
    "        self.grid = self.startgrid.copy()\n",
    "        self.done = False\n",
    "        self.reward = 0\n",
    "        self.numsteps = 0\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        return str(self.grid)\n",
    "            \n",
    "    # gets a valid move\n",
    "    def getvalidmoves(self):\n",
    "        validmoves = []\n",
    "        for num, (pole1, pole2) in enumerate([(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]):\n",
    "            # if nothing to transfer, skip move\n",
    "            if len(self.grid[pole1]) == 0: continue\n",
    "            # if pole to be transferred to is empty, accept move\n",
    "            if len(self.grid[pole2]) == 0: validmoves.append(num)\n",
    "            # if piece to be transferred is smaller than the topmost piece of new pole, accept it\n",
    "            elif self.grid[pole1][0] < self.grid[pole2][0]: validmoves.append(num)    \n",
    "            \n",
    "        return validmoves\n",
    "    \n",
    "    def step(self, move):\n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        \n",
    "        movechoices = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]\n",
    "        # do the move\n",
    "        pole1, pole2 = movechoices[move]\n",
    "        self.grid[pole2] = [self.grid[pole1][0]] + self.grid[pole2]\n",
    "        self.grid[pole1] = self.grid[pole1][1:]\n",
    "        \n",
    "        # check for completion\n",
    "        if self.grid[2] == list(range(1, self.numplates+1)):\n",
    "            self.reward = 1\n",
    "            self.done = True\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.grid[0], self.grid[1], self.grid[2])"
   ]
  },
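   {
    "cell_type": "markdown",
    "id": "b1c2d3e4-5f6a-4b7c-8d9e-cdef01234567",
    "metadata": {},
    "source": [
     "A quick smoke test (illustrative): from the start position, the only legal moves transfer the smallest plate off pole 0, i.e. move indices 0 and 1."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c2d3e4f5-6a7b-4c8d-9e0f-def012345678",
    "metadata": {},
    "outputs": [],
    "source": [
     "hanoi_demo = HanoiEnv(numplates = 3)\n",
     "hanoi_demo.print()                 # [1, 2, 3] [] []\n",
     "print(hanoi_demo.getvalidmoves())  # [0, 1]"
    ]
   },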
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8605947-ce36-4e72-bfda-0970ee58bd09",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Number of disks away from solution (neg) '''\n",
    "def Disk(env):\n",
    "    finalpole = env.grid[2][::-1]\n",
    "    totalsum = 0\n",
    "    for i in range(env.numplates):\n",
    "        if i+1 > len(finalpole): break\n",
    "        if env.numplates - i != finalpole[i]: break\n",
    "        # totalsum += env.numplates - i\n",
    "        totalsum += 1\n",
    "    # return (totalsum - np.sum(list(range(env.numplates+1))))/np.sum(list(range(env.numplates+1)))\n",
    "    return (totalsum - env.numplates)/env.numplates"
   ]
  },
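  {
   "cell_type": "markdown",
   "id": "5b6c7d8e-9f0a-4b1c-8d2e-3f4a5b6c7d8e",
   "metadata": {},
   "source": [
    "To make the shaping signal concrete, here are `Disk` values on a few hand-written 3-plate grids (illustrative only): the reward rises from -1 at the start position to 0 once every plate sits on the goal pole."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0b1c2d-3e4f-4a5b-8c6d-7e8f9a0b1c2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: Disk counts plates correctly stacked on the goal pole (bottom-up)\n",
    "for grid in [[[1, 2, 3], [], []],    # start position        -> -1.0\n",
    "             [[1, 2], [], [3]],      # largest plate placed  -> -0.67\n",
    "             [[], [1], [2, 3]],      # two plates placed     -> -0.33\n",
    "             [[], [], [1, 2, 3]]]:   # solved                ->  0.0\n",
    "    print(grid, round(Disk(HanoiEnv(numplates = 3, grid = grid)), 2))"
   ]
  },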
  {
   "cell_type": "markdown",
   "id": "04684dda-9750-4e4c-8502-4d18e10f4c70",
   "metadata": {},
   "source": [
    "## 3 plates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1990d0af-afa5-487b-b036-3775952df9e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = HanoiEnv(numplates = 3)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = Disk)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ba22122-fe33-4040-945d-c5c9347e4d9f",
   "metadata": {},
   "source": [
    "## 7 plates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91e8aa5-38d5-4fd4-a804-0510e68b832e",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = HanoiEnv(numplates = 7)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = Disk)"
   ]
  },
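  {
   "cell_type": "markdown",
   "id": "2d3e4f5a-6b7c-4d8e-9f0a-1b2c3d4e5f6a",
   "metadata": {},
   "source": [
    "## 10 plates"
   ]
  },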
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e460adf-1481-49ca-9542-6a0725c9377b",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = HanoiEnv(numplates = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec13841-8bac-4b64-9108-81c2ab321a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 200, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 200, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 200, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 200, env = copy.deepcopy(env), agent = agent, maxsteps = 2**(env.numplates+2), verbose = False, intrinsic_fn = Disk)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83566fc2-c924-4ee7-b57f-95d6d7ced644",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Discrete Environment 4 - Game of Nim with Perfect Opponent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913314a9-2cf3-4810-a5f9-80fa38caeb1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NimEnv:\n",
    "    def __init__(self, nummatches = 21, nummoves = 3, deterministic = True):\n",
    "        self.nummatches = nummatches\n",
    "        self.nummoves = nummoves\n",
    "        self.numsteps = 0\n",
    "        self.reward = 0\n",
    "        self.deterministic = deterministic\n",
    "        self.afterstate = nummatches\n",
    "        self.done = False\n",
    "            \n",
    "        # some variables to reset the environment\n",
    "        self.startmatches = self.nummatches\n",
    "        self.afterstate = nummatches\n",
    "            \n",
    "    def reset(self):\n",
    "        self.nummatches = self.startmatches\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.numsteps = 0\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        return str([self.nummatches, self.nummoves])\n",
    "            \n",
    "    # gets a valid move\n",
    "    def getvalidmoves(self):\n",
    "        validmoves = list(range(1, min(self.nummatches, self.nummoves)+1))\n",
    "        # give a valid move just to help with the learning process\n",
    "        if validmoves == []: validmoves = [0]\n",
    "        return validmoves\n",
    "    \n",
    "    def step(self, move):\n",
    "        if self.done:\n",
    "            return\n",
    "        \n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        \n",
    "        # do your move\n",
    "        self.nummatches -= move\n",
    "        \n",
    "        # get the afterstate\n",
    "        self.afterstate = self.nummatches\n",
    "        \n",
    "        if self.nummatches == 0:\n",
    "            self.done = True\n",
    "            self.reward = 1\n",
    "            return\n",
    "        \n",
    "        # do the perfect player's move\n",
    "        # choose randomly if already at perfect number\n",
    "        if self.nummatches % (self.nummoves+1) == 0:\n",
    "            if self.deterministic:\n",
    "                self.nummatches -= 1\n",
    "            else:\n",
    "                self.nummatches -= self.sample()\n",
    "            \n",
    "        # else make it perfect number\n",
    "        else:\n",
    "            self.nummatches -= self.nummatches % (self.nummoves+1)\n",
    "            \n",
    "        if self.nummatches == 0:\n",
    "            # failure state just for accounting purposes\n",
    "            self.nummatches = -1\n",
    "            self.done = True\n",
    "            return\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.nummatches, self.nummoves)"
   ]
  },
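  {
   "cell_type": "markdown",
   "id": "4e5f6a7b-8c9d-4e0f-8a1b-2c3d4e5f6a7b",
   "metadata": {},
   "source": [
    "A short illustrative game (not from the paper): with `nummoves = 3` the perfect opponent always returns the pile to a multiple of 4, so from 11 matches the only winning line is to take 3 on every turn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a7b8c9d-0e1f-4a2b-8c3d-4e5f6a7b8c9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: the opponent keeps the pile at a multiple of (nummoves+1) = 4\n",
    "testenv = NimEnv(nummatches = 11, nummoves = 3, deterministic = True)\n",
    "for move in [3, 3, 3]:   # leave 8, then 4, then take the last match\n",
    "    testenv.step(move)\n",
    "    print('took', move, '-> matches left:', testenv.nummatches,\n",
    "          'done:', testenv.done, 'reward:', testenv.reward)"
   ]
  },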
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2c7d0c8-c33f-48bd-946a-88c13bb3c04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Gives a distance based on number of matches away from the goal '''\n",
    "def CountMatches(env):\n",
    "    return -env.afterstate/env.startmatches"
   ]
  },
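  {
   "cell_type": "markdown",
   "id": "8c9d0e1f-2a3b-4c4d-8e5f-6a7b8c9d0e1f",
   "metadata": {},
   "source": [
    "Illustrative values: `CountMatches` is the negative fraction of the pile left after the agent's own move (the afterstate), so it climbs toward 0 as the agent empties the pile."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e1f2a3b-4c5d-4e6f-8a7b-8c9d0e1f2a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: CountMatches after the agent's first move\n",
    "testenv = NimEnv(nummatches = 11, nummoves = 3)\n",
    "testenv.step(3)               # afterstate is 8, before the opponent replies\n",
    "print(CountMatches(testenv))  # -8/11 = -0.727..."
   ]
  },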
  {
   "cell_type": "markdown",
   "id": "9b0a5c93-092c-4c79-aca2-1ed615fa9e33",
   "metadata": {},
   "source": [
    "## 11 matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cad6528-50e0-4cd5-95b0-de54a0ce02d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = NimEnv(nummatches = 11, nummoves = 3, deterministic = True)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = CountMatches)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7205b46-aae5-4cce-b62b-edd9a055a87f",
   "metadata": {},
   "source": [
    "## 21 matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dceba0b2-0178-40a2-8ec6-1757d23cb4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = NimEnv(nummatches = 21, nummoves = 3, deterministic = True)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = CountMatches)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49fbcf92-f9d3-48e5-b674-23c167736ee4",
   "metadata": {},
   "source": [
    "## 1001 matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa991737-5829-4e82-ad09-4b94dc8e0644",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = NimEnv(nummatches = 1001, nummoves = 3, deterministic = True)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, eps = 0)\n",
    "    \n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = env.nummatches, verbose = False, intrinsic_fn = CountMatches)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9f803c4-b288-4c31-a080-1e0121c01284",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Gym Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "839e357a-deb7-4b0b-9d0f-4aa1ae6cf416",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GSnZsMBt_vkJ",
    "outputId": "442e6bfa-191a-4387-b0ab-a3777f01cbda"
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "from gym import logger as gymlogger\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import glob\n",
    "import io\n",
    "import base64\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9c19c58-4166-4686-a982-8db1c39bc95e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Continuous Environment 1 - Cart Pole"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2a21868-073d-4ec4-8462-14ef042f134b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CartPoleEnv:\n",
    "    def __init__(self, env_name = 'CartPole-v0', goal_steps = 50, start_state = None, normalizer = np.array([0.01, 0.001, 0.01, 0.001]), cap = np.array([42])):\n",
    "        self.env_name = env_name\n",
    "        self.numsteps = 0\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.start_state = start_state\n",
    "        self.goal_steps = goal_steps\n",
    "        self.normalizer = normalizer\n",
    "        self.cap = cap\n",
    "        self.env = gym.make(env_name)\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        \n",
    "            \n",
    "    def reset(self):\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.numsteps = 0\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        if self.normalizer is not None:\n",
    "            return str(((self.env.state)//self.normalizer).clip(-self.cap, self.cap))\n",
    "        else:\n",
    "            return str(self.env.state)\n",
    "            \n",
    "    # gets a valid move\n",
    "    def getvalidmoves(self):\n",
    "        return list(range(self.env.action_space.n))\n",
    "    \n",
    "    def step(self, move):\n",
    "        if self.done:\n",
    "            return\n",
    "    \n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += 1\n",
    "        \n",
    "        # do your move\n",
    "        state, reward, done, _ = self.env.step(move)\n",
    "        # print(self.numsteps, self.env.state, (np.array(self.env.state)//self.normalizer).clip(-self.cap, self.cap))\n",
    "        \n",
    "        # only if done at step 175 is considered success\n",
    "        if done:\n",
    "            self.done = True\n",
    "            if self.numsteps > self.goal_steps:\n",
    "                self.reward = 1\n",
    "            else:\n",
    "                self.reward = 0\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.env.state)"
   ]
  },
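  {
   "cell_type": "markdown",
   "id": "1f2a3b4c-5d6e-4f7a-8b8c-9d0e1f2a3b4c",
   "metadata": {},
   "source": [
    "`staterep` floor-divides each state component by `normalizer` and clips the result to `[-cap, cap]`, so nearby continuous states collapse onto one discrete key in the tabular `memory`. A small illustration (the probe states below are made up):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4c5d6e-7f8a-4b9c-8d0e-1f2a3b4c5d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: two nearby continuous states fall into the same discrete bins\n",
    "testenv = CartPoleEnv(start_state = np.array([0.0, 0.0, 0.0, 0.0]))\n",
    "testenv.env.unwrapped.state = np.array([0.012, 0.0004, 0.003, 0.0009])\n",
    "a = testenv.staterep()\n",
    "testenv.env.unwrapped.state = np.array([0.019, 0.0009, 0.009, 0.0001])\n",
    "b = testenv.staterep()\n",
    "print(a, b, a == b)   # same key -> True"
   ]
  },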
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07e22c7b-9957-4c51-a641-fd9bf4076421",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Align the cart position (state[0]) to the center, and the pole angle (state[2]) to the center '''\n",
    "def Cart(env):\n",
    "    return -0.5*np.abs(env.env.state[0])/2.4 - 0.5*np.abs(env.env.state[2])/0.2095"
   ]
  },
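  {
   "cell_type": "markdown",
   "id": "5d6e7f8a-9b0c-4d1e-8f2a-3b4c5d6e7f8a",
   "metadata": {},
   "source": [
    "Illustrative values of `Cart` at hand-picked states: 0 when the cart and pole are perfectly centered, and -1 when the cart sits at the track edge while the pole is at the failure angle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f8a9b0c-1d2e-4f3a-8b4c-5d6e7f8a9b0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: Cart penalizes distance from the center, reaching -1 at the limits\n",
    "testenv = CartPoleEnv(start_state = np.array([0.0, 0.0, 0.0, 0.0]))\n",
    "for s in [np.array([0.0, 0.0, 0.0,    0.0]),   # centered        ->  0.0\n",
    "          np.array([2.4, 0.0, 0.0,    0.0]),   # cart at edge    -> -0.5\n",
    "          np.array([2.4, 0.0, 0.2095, 0.0])]:  # both at limits  -> -1.0\n",
    "    testenv.env.unwrapped.state = s\n",
    "    print(round(Cart(testenv), 2))"
   ]
  },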
  {
   "cell_type": "markdown",
   "id": "f03919d4-7e0d-427e-9fd1-83b285c73ecc",
   "metadata": {},
   "source": [
    "## Cart Pole 50 - Without numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e857c2-d34e-4a79-b56a-3a2719130e60",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 50)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a53832de-0192-49d1-b393-90dbc0248161",
   "metadata": {},
   "source": [
    "## Cart Pole 50R - With numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "471739a3-3ff9-4d42-9d87-19f0fcd96338",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 50)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "839dad9d-c63b-465d-9a42-519b036bf5e0",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Cart Pole 100 - without numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "370a7b7e-f714-4212-a19d-0b8fd7b69d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 100)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56ccdbb7-44e2-4967-b303-5c0316c5fd0f",
   "metadata": {},
   "source": [
    "## Cart Pole 100R - with numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4679ee5-c69c-49ee-96e0-b66a44282029",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 100)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bb96474-4be3-48dd-ae4f-65a630ae3ee7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Cart Pole 175 - without numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c13037-4f3c-4600-bb60-2d3ace427168",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 175)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "705ed15a-9dba-45cd-a83a-d08b549dcad6",
   "metadata": {},
   "source": [
    "## Cart Pole 175R - with numbered repeated state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68b2fc42-9b39-44fc-9abe-52d9aa599d15",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = CartPoleEnv(start_state = np.array([0, 0, 0, 0]), goal_steps = 175)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [CountAgent]:\n",
    "# for agent in [GoExplore]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, repeatedstate = True, verbose = False, intrinsic_fn = Cart)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "229e6874-786c-4413-946c-42cd5842c36a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Continuous Environment 2 - Mountain Car"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6de8ffe8-72b5-4bb1-8265-8a5edad02ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MountainCarEnv:\n",
    "    def __init__(self, env_name = 'MountainCar-v0', start_state = None, repeat_moves = 1, normalizer = np.array([0.01, 0.001]), cap = np.array([50])):\n",
    "        self.env_name = env_name\n",
    "        self.numsteps = 0\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.start_state = start_state\n",
    "        self.repeat_moves = repeat_moves\n",
    "        self.normalizer = normalizer\n",
    "        self.cap = cap\n",
    "        self.env = gym.make(env_name)\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        \n",
    "            \n",
    "    def reset(self):\n",
    "        self.env.reset()\n",
    "        if self.start_state is not None:\n",
    "            self.env.unwrapped.state = self.start_state\n",
    "        self.reward = 0\n",
    "        self.done = False\n",
    "        self.numsteps = 0\n",
    "        \n",
    "    # returns state representation\n",
    "    def staterep(self):\n",
    "        if self.normalizer is not None:\n",
    "            return str(((self.env.state)//self.normalizer).clip(-self.cap, self.cap))\n",
    "        else:\n",
    "            return str(self.env.state)\n",
    "            \n",
    "    # gets a valid move\n",
    "    def getvalidmoves(self):\n",
    "        return list(range(self.env.action_space.n))\n",
    "    \n",
    "    def step(self, move):\n",
    "        if self.done:\n",
    "            return\n",
    "    \n",
    "        validmoves = self.getvalidmoves()\n",
    "        # randomly sample a move if not in validmoves\n",
    "        if move not in validmoves:\n",
    "            move = validmoves[np.random.randint(len(validmoves))]\n",
    "        self.numsteps += self.repeat_moves\n",
    "        \n",
    "        # do your move (repeat_moves number of times)\n",
    "        for i in range(self.repeat_moves):\n",
    "            state, reward, done, _ = self.env.step(move)\n",
    "        # print(self.numsteps, self.env.state, (np.array(self.env.state)//self.normalizer).clip(-self.cap, self.cap))\n",
    "        \n",
    "        # only if past x=0.5, then it is success\n",
    "        if done:\n",
    "            self.done = True\n",
    "            if self.env.state[0] >= 0.5:\n",
    "                self.reward = 1\n",
    "            else:\n",
    "                self.reward = 0\n",
    "    \n",
    "    def sample(self):\n",
    "        return np.random.choice(self.getvalidmoves())\n",
    "    \n",
    "    def print(self):\n",
    "        print(self.env.state)"
   ]
  },
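  {
   "cell_type": "markdown",
   "id": "9b0c1d2e-3f4a-4b5c-8d6e-7f8a9b0c1d2e",
   "metadata": {},
   "source": [
    "A quick illustrative rollout (not from the paper): with `repeat_moves = 10` each agent decision applies the same push for 10 frames, so the 200-frame time limit corresponds to only 20 decisions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d2e3f4a-5b6c-4d7e-8f8a-9b0c1d2e3f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: random play; each sampled action is held for 10 frames\n",
    "testenv = MountainCarEnv(start_state = np.array([-0.5, 0]), repeat_moves = 10)\n",
    "while not testenv.done:\n",
    "    testenv.step(testenv.sample())\n",
    "print('frames used:', testenv.numsteps, 'reward:', testenv.reward)"
   ]
  },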
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345c20be-ad24-4953-966a-93be9bc17133",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Pos(env):\n",
    "    return env.env.state[0]-0.5"
   ]
  },
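  {
   "cell_type": "markdown",
   "id": "3f4a5b6c-7d8e-4f9a-8b0c-1d2e3f4a5b6c",
   "metadata": {},
   "source": [
    "Illustrative values: `Pos` is -1 at the start position used below (x = -0.5) and reaches 0 exactly at the flag (x = 0.5)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6c7d8e-9f0a-4b1c-8d2e-3f4a5b6c7d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative: Pos at the start position and at the goal\n",
    "testenv = MountainCarEnv(start_state = np.array([-0.5, 0]))\n",
    "print(Pos(testenv))   # -1.0\n",
    "testenv.env.unwrapped.state = np.array([0.5, 0.0])\n",
    "print(Pos(testenv))   # 0.0"
   ]
  },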
  {
   "cell_type": "markdown",
   "id": "3389dbdf-6d9b-47dc-9edf-c1c63bd42429",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "128091de-0925-44cb-bc60-bf94317a1e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MountainCarEnv(start_state = np.array([-0.5, 0]))\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = Pos)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6867989e-781a-4739-8030-fff41434df51",
   "metadata": {},
   "source": [
    "## Repeated Actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f426d6a-0bf7-4c04-82ec-e5fd211f6f4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = MountainCarEnv(start_state = np.array([-0.5, 0]), repeat_moves = 10)\n",
    "\n",
    "for agent in [RandomAgent, TDAgent, QAgent]:\n",
    "    # print(agent.__name__, 'training')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False)\n",
    "    \n",
    "    # for QAgent and TDAgent, we include the final number of solves with deterministic transition\n",
    "    if agent == QAgent or agent == TDAgent:\n",
    "        # print(agent.__name__, 'testing')\n",
    "        MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, eps = 0)\n",
    "\n",
    "for agent in [GoExplore, GoExploreCount, CountAgent]:\n",
    "# for agent in [CountAgent]:\n",
    "    # print(agent.__name__, 'without IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = None)\n",
    "\n",
    "    # for Intrinsic Agent, we also do the run with intrinsic motivation\n",
    "    # print(agent.__name__, 'with IM')\n",
    "    memory = defaultdict(lambda: 0)\n",
    "    MultiGame(numtries = 100, env = copy.deepcopy(env), agent = agent, maxsteps = 200, verbose = False, intrinsic_fn = Pos)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
