{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizations for MEGA\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Behaviour & Achieved Goal Distribution\n",
    "Color code the behaviour/achieved goal over the epoch / training iterations.\n",
    "\n",
    "Similar to Sibling Rivalry paper Figure 3: https://papers.nips.cc/paper/9225-keeping-your-distance-solving-sparse-reward-tasks-using-self-balancing-shaped-rewards.pdf\n",
    "\n",
    "When running the `train_mega.py` training experiments, make sure to use the command line option `--save-embeddings` to ensure that we are logging the behaviour and achieved goals over the course of training for this visualization. \n",
    "\n",
    "Do this for:\n",
    "1. PointMaze\n",
    "2. AntMaze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import os, glob\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PointMaze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2019, salesforce.com, inc.\n",
    "# All rights reserved.\n",
    "# SPDX-License-Identifier: MIT\n",
    "# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/MIT\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "class CircleMaze:\n",
    "    def __init__(self):\n",
    "        self.ring_r = 0.15\n",
    "        self.stop_t = 0.05\n",
    "        self.s_angle = 30\n",
    "\n",
    "        self.mean_s0 = (\n",
    "            float(np.cos(np.pi * self.s_angle / 180)),\n",
    "            float(np.sin(np.pi * self.s_angle / 180))\n",
    "        )\n",
    "        self.mean_g = (\n",
    "            float(np.cos(np.pi * (360-self.s_angle) / 180)),\n",
    "            float(np.sin(np.pi * (360-self.s_angle) / 180))\n",
    "        )\n",
    "\n",
    "    def plot(self, ax=None):\n",
    "        if ax is None:\n",
    "            _, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=200)\n",
    "        if ax is None:\n",
    "            _, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=200)\n",
    "        rads = np.linspace(self.stop_t * 2 * np.pi, (1 - self.stop_t) * 2 * np.pi)\n",
    "        xs_i = (1 - self.ring_r) * np.cos(rads)\n",
    "        ys_i = (1 - self.ring_r) * np.sin(rads)\n",
    "        xs_o = (1 + self.ring_r) * np.cos(rads)\n",
    "        ys_o = (1 + self.ring_r) * np.sin(rads)\n",
    "        ax.plot(xs_i, ys_i, 'k', linewidth=3)\n",
    "        ax.plot(xs_o, ys_o, 'k', linewidth=3)\n",
    "        ax.plot([xs_i[0], xs_o[0]], [ys_i[0], ys_o[0]], 'k', linewidth=3)\n",
    "        ax.plot([xs_i[-1], xs_o[-1]], [ys_i[-1], ys_o[-1]], 'k', linewidth=3)\n",
    "        lim = 1.1 + self.ring_r\n",
    "        ax.set_xlim([-lim, lim])\n",
    "        ax.set_ylim([-lim, lim])\n",
    "\n",
    "    def sample_start(self):\n",
    "        STD = 0.1\n",
    "        return self.move(self.mean_s0, (STD * np.random.randn(), STD * np.random.randn()))\n",
    "\n",
    "    def sample_goal(self):\n",
    "        STD = 0.1\n",
    "        return self.move(self.mean_g, (STD * np.random.randn(), STD * np.random.randn()))\n",
    "\n",
    "    @staticmethod\n",
    "    def xy_to_rt(xy):\n",
    "        x = xy[0]\n",
    "        y = xy[1]\n",
    "        r = np.sqrt(x ** 2 + y ** 2)\n",
    "        t = np.arctan2(y, x) % (2 * np.pi)\n",
    "        return r, t\n",
    "\n",
    "    def move(self, coords, action):\n",
    "        xp, yp = coords\n",
    "        rp, tp = self.xy_to_rt(coords)\n",
    "\n",
    "        xy = (coords[0] + action[0], coords[1] + action[1])\n",
    "\n",
    "        r, t = self.xy_to_rt(xy)\n",
    "        t = np.clip(t % (2 * np.pi), (0.001 + self.stop_t) * (2 * np.pi), (1 - (0.001 + self.stop_t)) * (2 * np.pi))\n",
    "        x = np.cos(t) * r\n",
    "        y = np.sin(t) * r\n",
    "\n",
    "        if coords is not None:\n",
    "\n",
    "            if xp > 0:\n",
    "                if (y < 0) and (yp > 0):\n",
    "                    t = self.stop_t * 2 * np.pi\n",
    "                elif (y > 0) and (yp < 0):\n",
    "                    t = (1 - self.stop_t) * 2 * np.pi\n",
    "            x = np.cos(t) * r\n",
    "            y = np.sin(t) * r\n",
    "\n",
    "        n = 8\n",
    "        xyi = np.array([xp, yp]).astype(np.float32)\n",
    "        dxy = (np.array([x, y]).astype(np.float32) - xyi) / n\n",
    "        new_r = float(rp)\n",
    "        new_t = float(tp)\n",
    "\n",
    "        count = 0\n",
    "\n",
    "        def r_ok(r_):\n",
    "            return (1 - self.ring_r) <= r_ <= (1 + self.ring_r)\n",
    "\n",
    "        def t_ok(t_):\n",
    "            return (self.stop_t * (2 * np.pi)) <= (t_ % (2 * np.pi)) <= ((1 - self.stop_t) * (2 * np.pi))\n",
    "\n",
    "        while r_ok(new_r) and t_ok(new_t) and count < n:\n",
    "            xyi += dxy\n",
    "            new_r, new_t = self.xy_to_rt(xyi)\n",
    "            count += 1\n",
    "\n",
    "        r = np.clip(new_r, 1 - self.ring_r + 0.01, 1 + self.ring_r - 0.01)\n",
    "        t = np.clip(new_t % (2 * np.pi), (0.001 + self.stop_t) * (2 * np.pi), (1 - (0.001 + self.stop_t)) * (2 * np.pi))\n",
    "        x = np.cos(t) * r\n",
    "        y = np.sin(t) * r\n",
    "\n",
    "        return float(x), float(y)\n",
    "\n",
    "\n",
    "class Maze:\n",
    "    def __init__(self, *segment_dicts, goal_squares=None, start_squares=None):\n",
    "        self._segments = {'origin': {'loc': (0.0, 0.0), 'connect': set()}}\n",
    "        self._locs = set()\n",
    "        self._locs.add(self._segments['origin']['loc'])\n",
    "        self._walls = set()\n",
    "        for direction in ['up', 'down', 'left', 'right']:\n",
    "            self._walls.add(self._wall_line(self._segments['origin']['loc'], direction))\n",
    "        self._last_segment = 'origin'\n",
    "        self.goal_squares = None\n",
    "\n",
    "        if goal_squares is None:\n",
    "            self._goal_squares = None\n",
    "        elif isinstance(goal_squares, str):\n",
    "            self._goal_squares = [goal_squares.lower()]\n",
    "        elif isinstance(goal_squares, (tuple, list)):\n",
    "            self._goal_squares = [gs.lower() for gs in goal_squares]\n",
    "        else:\n",
    "            raise TypeError\n",
    "\n",
    "        if start_squares is None:\n",
    "            self.start_squares = ['origin']\n",
    "        elif isinstance(start_squares, str):\n",
    "            self.start_squares = [start_squares.lower()]\n",
    "        elif isinstance(start_squares, (tuple, list)):\n",
    "            self.start_squares = [ss.lower() for ss in start_squares]\n",
    "        else:\n",
    "            raise TypeError\n",
    "\n",
    "        for segment_dict in segment_dicts:\n",
    "            self._add_segment(**segment_dict)\n",
    "        self._finalize()\n",
    "\n",
    "    @staticmethod\n",
    "    def _wall_line(coord, direction):\n",
    "        x, y = coord\n",
    "        if direction == 'up':\n",
    "            w = [(x - 0.5, x + 0.5), (y + 0.5, y + 0.5)]\n",
    "        elif direction == 'right':\n",
    "            w = [(x + 0.5, x + 0.5), (y + 0.5, y - 0.5)]\n",
    "        elif direction == 'down':\n",
    "            w = [(x - 0.5, x + 0.5), (y - 0.5, y - 0.5)]\n",
    "        elif direction == 'left':\n",
    "            w = [(x - 0.5, x - 0.5), (y - 0.5, y + 0.5)]\n",
    "        else:\n",
    "            raise ValueError\n",
    "        w = tuple([tuple(sorted(line)) for line in w])\n",
    "        return w\n",
    "\n",
    "    def _add_segment(self, name, anchor, direction, connect=None, times=1):\n",
    "        name = str(name).lower()\n",
    "        original_name = str(name).lower()\n",
    "        if times > 1:\n",
    "            assert connect is None\n",
    "            last_name = str(anchor).lower()\n",
    "            for time in range(times):\n",
    "                this_name = original_name + str(time)\n",
    "                self._add_segment(name=this_name.lower(), anchor=last_name, direction=direction)\n",
    "                last_name = str(this_name)\n",
    "            return\n",
    "\n",
    "        anchor = str(anchor).lower()\n",
    "        assert anchor in self._segments\n",
    "\n",
    "        direction = str(direction).lower()\n",
    "\n",
    "        final_connect = set()\n",
    "\n",
    "        if connect is not None:\n",
    "            if isinstance(connect, str):\n",
    "                connect = str(connect).lower()\n",
    "                assert connect in ['up', 'down', 'left', 'right']\n",
    "                final_connect.add(connect)\n",
    "            elif isinstance(connect, (tuple, list)):\n",
    "                for connect_direction in connect:\n",
    "                    connect_direction = str(connect_direction).lower()\n",
    "                    assert connect_direction in ['up', 'down', 'left', 'right']\n",
    "                    final_connect.add(connect_direction)\n",
    "\n",
    "        sx, sy = self._segments[anchor]['loc']\n",
    "        dx, dy = 0.0, 0.0\n",
    "        if direction == 'left':\n",
    "            dx -= 1\n",
    "            final_connect.add('right')\n",
    "        elif direction == 'right':\n",
    "            dx += 1\n",
    "            final_connect.add('left')\n",
    "        elif direction == 'up':\n",
    "            dy += 1\n",
    "            final_connect.add('down')\n",
    "        elif direction == 'down':\n",
    "            dy -= 1\n",
    "            final_connect.add('up')\n",
    "        else:\n",
    "            raise ValueError\n",
    "\n",
    "        new_loc = (sx + dx, sy + dy)\n",
    "        assert new_loc not in self._locs\n",
    "\n",
    "        self._segments[name] = {'loc': new_loc, 'connect': final_connect}\n",
    "        for direction in ['up', 'down', 'left', 'right']:\n",
    "            self._walls.add(self._wall_line(new_loc, direction))\n",
    "        self._locs.add(new_loc)\n",
    "\n",
    "        self._last_segment = name\n",
    "\n",
    "    def _finalize(self):\n",
    "        for segment in self._segments.values():\n",
    "            for c_dir in list(segment['connect']):\n",
    "                wall = self._wall_line(segment['loc'], c_dir)\n",
    "                if wall in self._walls:\n",
    "                    self._walls.remove(wall)\n",
    "\n",
    "        if self._goal_squares is None:\n",
    "            self.goal_squares = [self._last_segment]\n",
    "        else:\n",
    "            self.goal_squares = []\n",
    "            for gs in self._goal_squares:\n",
    "                assert gs in self._segments\n",
    "                self.goal_squares.append(gs)\n",
    "\n",
    "    def plot(self, ax=None):\n",
    "        if ax is None:\n",
    "            _, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=200)\n",
    "        for x, y in self._walls:\n",
    "            ax.plot(x, y, 'k-')\n",
    "\n",
    "    def sample_start(self):\n",
    "        min_wall_dist = 0.05\n",
    "\n",
    "        s_square = self.start_squares[np.random.randint(low=0, high=len(self.start_squares))]\n",
    "        s_square_loc = self._segments[s_square]['loc']\n",
    "\n",
    "        while True:\n",
    "            shift = np.random.uniform(low=-0.5, high=0.5, size=(2,))\n",
    "            loc = s_square_loc + shift\n",
    "            dist_checker = np.array([min_wall_dist, min_wall_dist]) * np.sign(shift)\n",
    "            stopped_loc = self.move(loc, dist_checker)\n",
    "            if float(np.sum(np.abs((loc + dist_checker) - stopped_loc))) == 0.0:\n",
    "                break\n",
    "        return loc[0], loc[1]\n",
    "\n",
    "    def sample_goal(self, min_wall_dist=None):\n",
    "        if min_wall_dist is None:\n",
    "            min_wall_dist = 0.1\n",
    "        else:\n",
    "            min_wall_dist = min(0.4, max(0.01, min_wall_dist))\n",
    "\n",
    "        g_square = self.goal_squares[np.random.randint(low=0, high=len(self.goal_squares))]\n",
    "        g_square_loc = self._segments[g_square]['loc']\n",
    "        while True:\n",
    "            shift = np.random.uniform(low=-0.5, high=0.5, size=(2,))\n",
    "            loc = g_square_loc + shift\n",
    "            dist_checker = np.array([min_wall_dist, min_wall_dist]) * np.sign(shift)\n",
    "            stopped_loc = self.move(loc, dist_checker)\n",
    "            if float(np.sum(np.abs((loc + dist_checker) - stopped_loc))) == 0.0:\n",
    "                break\n",
    "        return loc[0], loc[1]\n",
    "\n",
    "    def move(self, coord_start, coord_delta, depth=None):\n",
    "        if depth is None:\n",
    "            depth = 0\n",
    "        cx, cy = coord_start\n",
    "        loc_x0 = np.round(cx)\n",
    "        loc_y0 = np.round(cy)\n",
    "        #assert (float(loc_x0), float(loc_y0)) in self._locs\n",
    "        dx, dy = coord_delta\n",
    "        loc_x1 = np.round(cx + dx)\n",
    "        loc_y1 = np.round(cy + dy)\n",
    "        d_loc_x = int(np.abs(loc_x1 - loc_x0))\n",
    "        d_loc_y = int(np.abs(loc_y1 - loc_y0))\n",
    "        xs_crossed = [loc_x0 + (np.sign(dx) * (i + 0.5)) for i in range(d_loc_x)]\n",
    "        ys_crossed = [loc_y0 + (np.sign(dy) * (i + 0.5)) for i in range(d_loc_y)]\n",
    "\n",
    "        rds = []\n",
    "\n",
    "        for x in xs_crossed:\n",
    "            r = (x - cx) / dx\n",
    "            loc_x = np.round(cx + (0.999 * r * dx))\n",
    "            loc_y = np.round(cy + (0.999 * r * dy))\n",
    "            direction = 'right' if dx > 0 else 'left'\n",
    "            crossed_line = self._wall_line((loc_x, loc_y), direction)\n",
    "            if crossed_line in self._walls:\n",
    "                rds.append([r, direction])\n",
    "\n",
    "        for y in ys_crossed:\n",
    "            r = (y - cy) / dy\n",
    "            loc_x = np.round(cx + (0.999 * r * dx))\n",
    "            loc_y = np.round(cy + (0.999 * r * dy))\n",
    "            direction = 'up' if dy > 0 else 'down'\n",
    "            crossed_line = self._wall_line((loc_x, loc_y), direction)\n",
    "            if crossed_line in self._walls:\n",
    "                rds.append([r, direction])\n",
    "\n",
    "        # The wall will only stop the agent in the direction perpendicular to the wall\n",
    "        if rds:\n",
    "            rds = sorted(rds)\n",
    "            r, direction = rds[0]\n",
    "            if depth < 3:\n",
    "                new_dx = r * dx\n",
    "                new_dy = r * dy\n",
    "                repulsion = float(np.abs(np.random.rand() * 0.01))\n",
    "                if direction in ['right', 'left']:\n",
    "                    new_dx -= np.sign(dx) * repulsion\n",
    "                    partial_coords = cx + new_dx, cy + new_dy\n",
    "                    remaining_delta = (0.0, (1 - r) * dy)\n",
    "                else:\n",
    "                    new_dy -= np.sign(dy) * repulsion\n",
    "                    partial_coords = cx + new_dx, cy + new_dy\n",
    "                    remaining_delta = ((1 - r) * dx, 0.0)\n",
    "                return self.move(partial_coords, remaining_delta, depth+1)\n",
    "        else:\n",
    "            r = 1.0\n",
    "\n",
    "        dx *= r\n",
    "        dy *= r\n",
    "        return cx + dx, cy + dy\n",
    "\n",
    "\n",
    "def make_crazy_maze(size, seed=None):\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    deltas = [\n",
    "        [(-1, 0), 'right'],\n",
    "        [(1, 0), 'left'],\n",
    "        [(0, -1), 'up'],\n",
    "        [(0, 1), 'down'],\n",
    "    ]\n",
    "\n",
    "    empty_locs = []\n",
    "    for x in range(size):\n",
    "        for y in range(size):\n",
    "            empty_locs.append((x, y))\n",
    "\n",
    "    locs = [empty_locs.pop(0)]\n",
    "    dirs = [None]\n",
    "    anchors = [None]\n",
    "\n",
    "    while len(empty_locs) > 0:\n",
    "        still_empty = []\n",
    "        np.random.shuffle(empty_locs)\n",
    "        for empty_x, empty_y in empty_locs:\n",
    "            found_anchor = False\n",
    "            np.random.shuffle(deltas)\n",
    "            for (dx, dy), direction in deltas:\n",
    "                c = (empty_x + dx, empty_y + dy)\n",
    "                if c in locs:\n",
    "                    found_anchor = True\n",
    "                    locs.append((empty_x, empty_y))\n",
    "                    dirs.append(direction)\n",
    "                    anchors.append(c)\n",
    "                    break\n",
    "            if not found_anchor:\n",
    "                still_empty.append((empty_x, empty_y))\n",
    "        empty_locs = still_empty[:]\n",
    "\n",
    "    locs = [str(x) + ',' + str(y) for x, y in locs[1:]]\n",
    "    dirs = dirs[1:]\n",
    "    anchors = [str(x) + ',' + str(y) for x, y in anchors[1:]]\n",
    "    anchors = ['origin' if a == '0,0' else a for a in anchors]\n",
    "\n",
    "    segments = []\n",
    "    for loc, d, anchor in zip(locs, dirs, anchors):\n",
    "        segments.append(dict(name=loc, anchor=anchor, direction=d))\n",
    "\n",
    "    np.random.seed()\n",
    "    return Maze(*segments, goal_squares='{s},{s}'.format(s=size - 1))\n",
    "\n",
    "\n",
    "def make_experiment_maze(h, half_w, sz0):\n",
    "    if h < 2:\n",
    "        h = 2\n",
    "    if half_w < 3:\n",
    "        half_w = 3\n",
    "    w = 1 + (2*half_w)\n",
    "    # Create the starting row\n",
    "    segments = [{'anchor': 'origin', 'direction': 'right', 'name': '0,1'}]\n",
    "    for w_ in range(1, w-1):\n",
    "        segments.append({'anchor': '0,{}'.format(w_), 'direction': 'right', 'name': '0,{}'.format(w_+1)})\n",
    "\n",
    "    # Add each row to create H\n",
    "    for h_ in range(1, h):\n",
    "        segments.append({'anchor': '{},{}'.format(h_-1, w-1), 'direction': 'up', 'name': '{},{}'.format(h_, w-1)})\n",
    "\n",
    "        c = None if h_ == sz0 else 'down'\n",
    "        for w_ in range(w-2, -1, -1):\n",
    "            segments.append(\n",
    "                {'anchor': '{},{}'.format(h_, w_+1), 'direction': 'left', 'connect': c, 'name': '{},{}'.format(h_, w_)}\n",
    "            )\n",
    "\n",
    "    return Maze(*segments, goal_squares=['{},{}'.format(h-1, half_w+d) for d in [0]])\n",
    "\n",
    "\n",
    "def make_hallway_maze(corridor_length):\n",
    "    corridor_length = int(corridor_length)\n",
    "    assert corridor_length >= 1\n",
    "\n",
    "    segments = []\n",
    "    last = 'origin'\n",
    "    for x in range(1, corridor_length+1):\n",
    "        next_name = '0,{}'.format(x)\n",
    "        segments.append({'anchor': last, 'direction': 'right', 'name': next_name})\n",
    "        last = str(next_name)\n",
    "\n",
    "    return Maze(*segments, goal_squares=last)\n",
    "\n",
    "\n",
    "def make_u_maze(corridor_length):\n",
    "    corridor_length = int(corridor_length)\n",
    "    assert corridor_length >= 1\n",
    "\n",
    "    segments = []\n",
    "    last = 'origin'\n",
    "    for x in range(1, corridor_length + 1):\n",
    "        next_name = '0,{}'.format(x)\n",
    "        segments.append({'anchor': last, 'direction': 'right', 'name': next_name})\n",
    "        last = str(next_name)\n",
    "\n",
    "    assert last == '0,{}'.format(corridor_length)\n",
    "\n",
    "    up_size = 2\n",
    "\n",
    "    for x in range(1, up_size+1):\n",
    "        next_name = '{},{}'.format(x, corridor_length)\n",
    "        segments.append({'anchor': last, 'direction': 'up', 'name': next_name})\n",
    "        last = str(next_name)\n",
    "\n",
    "    assert last == '{},{}'.format(up_size, corridor_length)\n",
    "\n",
    "    for x in range(1, corridor_length + 1):\n",
    "        next_name = '{},{}'.format(up_size, corridor_length - x)\n",
    "        segments.append({'anchor': last, 'direction': 'left', 'name': next_name})\n",
    "        last = str(next_name)\n",
    "\n",
    "    assert last == '{},0'.format(up_size)\n",
    "\n",
    "    return Maze(*segments, goal_squares=last)\n",
    "\n",
    "\n",
    "\n",
    "mazes_dict = dict()\n",
    "\n",
    "mazes_dict['circle'] = {'maze': CircleMaze(), 'action_range': 0.25}\n",
    "\n",
    "segments_a = [\n",
    "    dict(name='A', anchor='origin', direction='down', times=4),\n",
    "    dict(name='B', anchor='A3', direction='right', times=4),\n",
    "    dict(name='C', anchor='B3', direction='up', times=4),\n",
    "    dict(name='D', anchor='A1', direction='right', times=2),\n",
    "    dict(name='E', anchor='D1', direction='up', times=2),\n",
    "]\n",
    "mazes_dict['square_a'] = {'maze': Maze(*segments_a, goal_squares=['c2', 'c3']), 'action_range': 0.95}\n",
    "\n",
    "segments_b = [\n",
    "    dict(name='A', anchor='origin', direction='down', times=4),\n",
    "    dict(name='B', anchor='A3', direction='right', times=4),\n",
    "    dict(name='C', anchor='B3', direction='up', times=4),\n",
    "    dict(name='D', anchor='B1', direction='up', times=4),\n",
    "]\n",
    "mazes_dict['square_b'] = {'maze': Maze(*segments_b, goal_squares=['c2', 'c3']), 'action_range': 0.95}\n",
    "\n",
    "segments_c = [\n",
    "    dict(name='A', anchor='origin', direction='down', times=4),\n",
    "    dict(name='B', anchor='A3', direction='right', times=2),\n",
    "    dict(name='C', anchor='B1', direction='up', times=4),\n",
    "    dict(name='D', anchor='C3', direction='right', times=2),\n",
    "    dict(name='E', anchor='D1', direction='down', times=4)\n",
    "]\n",
    "mazes_dict['square_c'] = {'maze': Maze(*segments_c, goal_squares=['e2', 'e3']), 'action_range': 0.95}\n",
    "\n",
    "segments_d = [\n",
    "    dict(name='TL', anchor='origin', direction='left', times=3),\n",
    "    dict(name='TLD', anchor='TL2', direction='down', times=3),\n",
    "    dict(name='TLR', anchor='TLD2', direction='right', times=2),\n",
    "    dict(name='TLU', anchor='TLR1', direction='up'),\n",
    "    dict(name='TR', anchor='origin', direction='right', times=3),\n",
    "    dict(name='TRD', anchor='TR2', direction='down', times=3),\n",
    "    dict(name='TRL', anchor='TRD2', direction='left', times=2),\n",
    "    dict(name='TRU', anchor='TRL1', direction='up'),\n",
    "    dict(name='TD', anchor='origin', direction='down', times=3),\n",
    "]\n",
    "mazes_dict['square_d'] = {'maze': Maze(*segments_d, goal_squares=['tlu', 'tlr1', 'tru', 'trl1']), 'action_range': 0.95}\n",
    "\n",
    "segments_crazy = [\n",
    "    {'anchor': 'origin', 'direction': 'right', 'name': '1,0'},\n",
    "     {'anchor': 'origin', 'direction': 'up', 'name': '0,1'},\n",
    "     {'anchor': '1,0', 'direction': 'right', 'name': '2,0'},\n",
    "     {'anchor': '0,1', 'direction': 'up', 'name': '0,2'},\n",
    "     {'anchor': '0,2', 'direction': 'right', 'name': '1,2'},\n",
    "     {'anchor': '2,0', 'direction': 'up', 'name': '2,1'},\n",
    "     {'anchor': '1,2', 'direction': 'right', 'name': '2,2'},\n",
    "     {'anchor': '0,2', 'direction': 'up', 'name': '0,3'},\n",
    "     {'anchor': '2,1', 'direction': 'right', 'name': '3,1'},\n",
    "     {'anchor': '1,2', 'direction': 'down', 'name': '1,1'},\n",
    "     {'anchor': '3,1', 'direction': 'down', 'name': '3,0'},\n",
    "     {'anchor': '1,2', 'direction': 'up', 'name': '1,3'},\n",
    "     {'anchor': '3,1', 'direction': 'right', 'name': '4,1'},\n",
    "     {'anchor': '1,3', 'direction': 'up', 'name': '1,4'},\n",
    "     {'anchor': '4,1', 'direction': 'right', 'name': '5,1'},\n",
    "     {'anchor': '4,1', 'direction': 'up', 'name': '4,2'},\n",
    "     {'anchor': '5,1', 'direction': 'down', 'name': '5,0'},\n",
    "     {'anchor': '3,0', 'direction': 'right', 'name': '4,0'},\n",
    "     {'anchor': '1,4', 'direction': 'right', 'name': '2,4'},\n",
    "     {'anchor': '4,2', 'direction': 'right', 'name': '5,2'},\n",
    "     {'anchor': '2,4', 'direction': 'right', 'name': '3,4'},\n",
    "     {'anchor': '3,4', 'direction': 'up', 'name': '3,5'},\n",
    "     {'anchor': '1,4', 'direction': 'left', 'name': '0,4'},\n",
    "     {'anchor': '1,4', 'direction': 'up', 'name': '1,5'},\n",
    "     {'anchor': '2,2', 'direction': 'up', 'name': '2,3'},\n",
    "     {'anchor': '3,1', 'direction': 'up', 'name': '3,2'},\n",
    "     {'anchor': '5,0', 'direction': 'right', 'name': '6,0'},\n",
    "     {'anchor': '3,2', 'direction': 'up', 'name': '3,3'},\n",
    "     {'anchor': '4,2', 'direction': 'up', 'name': '4,3'},\n",
    "     {'anchor': '6,0', 'direction': 'up', 'name': '6,1'},\n",
    "     {'anchor': '6,0', 'direction': 'right', 'name': '7,0'},\n",
    "     {'anchor': '6,1', 'direction': 'right', 'name': '7,1'},\n",
    "     {'anchor': '3,4', 'direction': 'right', 'name': '4,4'},\n",
    "     {'anchor': '1,5', 'direction': 'right', 'name': '2,5'},\n",
    "     {'anchor': '7,1', 'direction': 'up', 'name': '7,2'},\n",
    "     {'anchor': '1,5', 'direction': 'up', 'name': '1,6'},\n",
    "     {'anchor': '4,4', 'direction': 'right', 'name': '5,4'},\n",
    "     {'anchor': '5,4', 'direction': 'down', 'name': '5,3'},\n",
    "     {'anchor': '0,4', 'direction': 'up', 'name': '0,5'},\n",
    "     {'anchor': '7,2', 'direction': 'left', 'name': '6,2'},\n",
    "     {'anchor': '1,6', 'direction': 'left', 'name': '0,6'},\n",
    "     {'anchor': '7,0', 'direction': 'right', 'name': '8,0'},\n",
    "     {'anchor': '7,2', 'direction': 'right', 'name': '8,2'},\n",
    "     {'anchor': '2,5', 'direction': 'up', 'name': '2,6'},\n",
    "     {'anchor': '8,0', 'direction': 'up', 'name': '8,1'},\n",
    "     {'anchor': '3,5', 'direction': 'up', 'name': '3,6'},\n",
    "     {'anchor': '6,2', 'direction': 'up', 'name': '6,3'},\n",
    "     {'anchor': '6,3', 'direction': 'right', 'name': '7,3'},\n",
    "     {'anchor': '3,5', 'direction': 'right', 'name': '4,5'},\n",
    "     {'anchor': '7,3', 'direction': 'up', 'name': '7,4'},\n",
    "     {'anchor': '6,3', 'direction': 'up', 'name': '6,4'},\n",
    "     {'anchor': '6,4', 'direction': 'up', 'name': '6,5'},\n",
    "     {'anchor': '8,1', 'direction': 'right', 'name': '9,1'},\n",
    "     {'anchor': '8,2', 'direction': 'right', 'name': '9,2'},\n",
    "     {'anchor': '2,6', 'direction': 'up', 'name': '2,7'},\n",
    "     {'anchor': '8,2', 'direction': 'up', 'name': '8,3'},\n",
    "     {'anchor': '6,5', 'direction': 'left', 'name': '5,5'},\n",
    "     {'anchor': '5,5', 'direction': 'up', 'name': '5,6'},\n",
    "     {'anchor': '7,4', 'direction': 'right', 'name': '8,4'},\n",
    "     {'anchor': '8,4', 'direction': 'right', 'name': '9,4'},\n",
    "     {'anchor': '0,6', 'direction': 'up', 'name': '0,7'},\n",
    "     {'anchor': '2,7', 'direction': 'up', 'name': '2,8'},\n",
    "     {'anchor': '7,4', 'direction': 'up', 'name': '7,5'},\n",
    "     {'anchor': '9,4', 'direction': 'down', 'name': '9,3'},\n",
    "     {'anchor': '9,4', 'direction': 'up', 'name': '9,5'},\n",
    "     {'anchor': '2,7', 'direction': 'left', 'name': '1,7'},\n",
    "     {'anchor': '4,5', 'direction': 'up', 'name': '4,6'},\n",
    "     {'anchor': '9,1', 'direction': 'down', 'name': '9,0'},\n",
    "     {'anchor': '6,5', 'direction': 'up', 'name': '6,6'},\n",
    "     {'anchor': '3,6', 'direction': 'up', 'name': '3,7'},\n",
    "     {'anchor': '1,7', 'direction': 'up', 'name': '1,8'},\n",
    "     {'anchor': '3,7', 'direction': 'right', 'name': '4,7'},\n",
    "     {'anchor': '2,8', 'direction': 'up', 'name': '2,9'},\n",
    "     {'anchor': '2,9', 'direction': 'left', 'name': '1,9'},\n",
    "     {'anchor': '7,5', 'direction': 'up', 'name': '7,6'},\n",
    "     {'anchor': '1,8', 'direction': 'left', 'name': '0,8'},\n",
    "     {'anchor': '6,6', 'direction': 'up', 'name': '6,7'},\n",
    "     {'anchor': '0,8', 'direction': 'up', 'name': '0,9'},\n",
    "     {'anchor': '7,5', 'direction': 'right', 'name': '8,5'},\n",
    "     {'anchor': '6,7', 'direction': 'left', 'name': '5,7'},\n",
    "     {'anchor': '2,9', 'direction': 'right', 'name': '3,9'},\n",
    "     {'anchor': '3,9', 'direction': 'right', 'name': '4,9'},\n",
    "     {'anchor': '7,6', 'direction': 'right', 'name': '8,6'},\n",
    "     {'anchor': '3,7', 'direction': 'up', 'name': '3,8'},\n",
    "     {'anchor': '9,5', 'direction': 'up', 'name': '9,6'},\n",
    "     {'anchor': '7,6', 'direction': 'up', 'name': '7,7'},\n",
    "     {'anchor': '5,7', 'direction': 'up', 'name': '5,8'},\n",
    "     {'anchor': '3,8', 'direction': 'right', 'name': '4,8'},\n",
    "     {'anchor': '8,6', 'direction': 'up', 'name': '8,7'},\n",
    "     {'anchor': '5,8', 'direction': 'right', 'name': '6,8'},\n",
    "     {'anchor': '7,7', 'direction': 'up', 'name': '7,8'},\n",
    "     {'anchor': '4,9', 'direction': 'right', 'name': '5,9'},\n",
    "     {'anchor': '8,7', 'direction': 'right', 'name': '9,7'},\n",
    "     {'anchor': '7,8', 'direction': 'right', 'name': '8,8'},\n",
    "     {'anchor': '8,8', 'direction': 'up', 'name': '8,9'},\n",
    "     {'anchor': '5,9', 'direction': 'right', 'name': '6,9'},\n",
    "     {'anchor': '6,9', 'direction': 'right', 'name': '7,9'},\n",
    "     {'anchor': '8,9', 'direction': 'right', 'name': '9,9'},\n",
    "     {'anchor': '9,9', 'direction': 'down', 'name': '9,8'}\n",
    "]\n",
    "mazes_dict['square_large'] = {'maze': Maze(*segments_crazy, goal_squares='9,9'), 'action_range': 0.95}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)\n",
    "mazes_dict['square_large']['maze'].plot(ax)\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### For OMEGA (i.e. with transition)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bgs_list = {}\n",
    "ags_list = {}\n",
    "rand_ags_list = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = \"pointmaze\"\n",
    "\n",
    "bgs_list[env] = {}\n",
    "ags_list[env] = {}\n",
    "rand_ags_list[env] = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"OMEGA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MIN_DENSITY_TRANSITION_ag_cu-minkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestep_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False):\n",
    "    goals_list = []\n",
    "\n",
    "    # Skip every other and up to 120k \n",
    "    for timestep in timestep_list[::skip_every]:\n",
    "        if timestep < max_timestep and timestep > min_timestep:\n",
    "            path = os.path.join(base_path, timestep_dict[timestep], folder_name)\n",
    "            if os.path.exists(path):\n",
    "                filename = os.path.join(path, \"tensors.tsv\")\n",
    "                data = np.genfromtxt(fname=filename, delimiter=\"\\t\", skip_header=0, filling_values=-1)  # change filling_values as req'd to fill in missing values\n",
    "                # Filter out repeated datapoints\n",
    "                if filter_unique:\n",
    "                    data= np.unique(data, axis=0)\n",
    "                else:\n",
    "                    # Take last one in each episode\n",
    "                    data = data[(episode_length-1)::episode_length]\n",
    "                num_rows = data.shape[0]\n",
    "                train_step = np.ones((num_rows,1)) * timestep / episode_length / scale\n",
    "                data = np.append(data, train_step, axis=1)\n",
    "                goals_list.append(data)\n",
    "\n",
    "    goals_list = np.concatenate(goals_list, axis=0)\n",
    "    return goals_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_timestep = 120000\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_omega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_omega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_omega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot all in one figure with subplots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_rows = 1\n",
    "num_cols = 3\n",
    "fig, axs = plt.subplots(nrows=num_rows,ncols=num_cols, figsize=(4*num_cols, 4*num_rows))\n",
    "\n",
    "data_list = [bgs_list[env][method], ags_list[env][method], rand_ags_list[env][method]]\n",
    "method_name = \"OMEGA\"\n",
    "title_list = [\"Behaviour Goals ({})\".format(method_name),\n",
    "              \"Terminal Achieved Goals ({})\".format(method_name),\n",
    "              \"Random Achieved Goals ({})\".format(method_name)]\n",
    "\n",
    "for c in range(num_cols):\n",
    "    mazes_dict['square_large']['maze'].plot(axs[c])\n",
    "    \n",
    "    im = axs[c].scatter(data_list[c][:,0], data_list[c][:,1], c=data_list[c][:,2], alpha=0.4)\n",
    "    axs[c].set_title(title_list[c])\n",
    "    axs[c].axis('off')\n",
    "\n",
    "fig.subplots_adjust(right=0.9)\n",
    "cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try plotting both behaviour and achieved goal on the same figure\n",
    "\n",
    "num_rows = 1\n",
    "num_cols = 1\n",
    "fig, axs = plt.subplots(nrows=num_rows,ncols=num_cols, figsize=(4*num_cols, 4*num_rows), dpi=200)\n",
    "\n",
    "data_list = [bgs_list[env][method], ags_list[env][method], rand_ags_list[env][method]]\n",
    "method_name = \"OMEGA\"\n",
    "title_list = [\"Behaviour Goals ({})\".format(method_name),\n",
    "              \"Terminal Achieved Goals ({})\".format(method_name),\n",
    "              \"Random Achieved Goals ({})\".format(method_name)]\n",
    "\n",
    "mazes_dict['square_large']['maze'].plot(axs)\n",
    "    \n",
    "im = axs.scatter(data_list[0][:,0], data_list[0][:,1], c=data_list[0][:,2], alpha=0.4, marker=\"^\", label=\"Behavioural\")\n",
    "im = axs.scatter(data_list[1][:,0], data_list[1][:,1], c=data_list[1][:,2], alpha=0.4, marker=\".\", label=\"Terminal Achieved\")\n",
    "axs.axis('off')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comment: Looks too busy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### For MEGA (i.e. without transition)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"MEGA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MIN_DENSITY_ag_cu-minkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_mega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_mega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_mega_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combine into one plot with shared colorbar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_rows = 1\n",
    "num_cols = 4\n",
    "fig, axs = plt.subplots(nrows=num_rows,ncols=num_cols, figsize=(4*num_cols, 4*num_rows), dpi=300)\n",
    "\n",
    "methods = [\"OMEGA\", \"MEGA\"]\n",
    "data_list = []\n",
    "title_list = []\n",
    "for method in methods:\n",
    "    data_list += [bgs_list[env][method], ags_list[env][method]]\n",
    "    title_list += [\"Behaviour Goals ({})\".format(method),\n",
    "              \"Terminal Achieved Goals ({})\".format(method)]\n",
    "\n",
    "for c in range(num_cols):\n",
    "    mazes_dict['square_large']['maze'].plot(axs[c])\n",
    "\n",
    "    im = axs[c].scatter(data_list[c][:,0], data_list[c][:,1], c=data_list[c][:,2], alpha=0.4)\n",
    "    axs[c].set_title(title_list[c], fontsize=14)\n",
    "    axs[c].axis('off')\n",
    "\n",
    "fig.subplots_adjust(right=0.95)\n",
    "cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.subplots_adjust(bottom=0.1, wspace=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SkewFit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"SkewFit\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-SKEWFIT_ag_cu-randkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_skewfit_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_skewfit_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_skewfit_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MLE (RIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"MLE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-DISCERN_ag_cu-randkde_eexpl0.1_first-True_dg_sc1.0_alpha-0.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_mle_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_mle_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_mle_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GoalGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"GoalGAN\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-GOALGAN_ag_cu-goaldisc_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_goaldisc_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_goaldisc_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_goaldisc_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Min Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"Min Q\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MIN_Q_ag_cu-minq_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_minq_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # For last_bgs\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_minq_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mazes_dict['square_large']['maze'].plot()\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_minq_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = \"HER\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-HER_ag_cu-None_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # For last_bgs\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_her_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_her_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_her_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MEGA (Entropy Gain KDE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also shown in the plot titles\n",
    "method = \"MEGA(EG)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/harris/icml20_2/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-EG_CONDKDE2_ag_cu-entropygainscore_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_megeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_megeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_megeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OMEGA (Entropy Gain KDE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also shown in the plot titles\n",
    "method = \"OMEGA(EG)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/harris/icml20_2/omega_eg_pointmaze_viz/proto_env-pointmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-EG_CONDKDE2_TRANSITION_ag_cu-entropygainscore_eexpl0.1_first-True_dg_sc1.0_alpha--1.0_vremote\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_bgs_omegeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = False # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_omegeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "# min_timestep and max_timestep are not assigned here; they are read from earlier cells.\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 50\n",
    "do_filter_unique = True # NOTE(review): unused - the call below hard-codes filter_unique=False\n",
    "skip_every = 2\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Draw the maze shape\n",
    "mazes_dict['square_large']['maze'].plot()\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The output file name is hard-coded rather than derived from `method`\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_rand_ags_omegeg_viz.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combined Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Behaviour and achieved goals for every method in one figure:\n",
    "# one column per method, row 0 = behavioural goals, row 1 = final achieved goals.\n",
    "methods = [\"HER\", \"MLE\", \"SkewFit\",\"GoalGAN\", \"Min Q\", \"MEGA\", \"OMEGA\"]\n",
    "# Storage key -> name shown in the column title\n",
    "method_name_dict = {\"HER\":\"HER\", \"OMEGA\":\"OMEGA\", \"MEGA\":\"MEGA\", \"SkewFit\":\"Diverse\", \"MLE\":\"Achieved\",\"GoalGAN\":\"GoalDisc\", \"Min Q\": \"Min Q\"}\n",
    "row_title = [r'Behavioural Goals', r'Final Achieved Goals']\n",
    "\n",
    "num_rows = 2\n",
    "num_cols = len(methods)  # one column per method (was hard-coded to 7)\n",
    "# squeeze=False keeps axs two-dimensional even if only one method is listed\n",
    "fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(4*num_cols, 4*num_rows), dpi=300, squeeze=False)\n",
    "\n",
    "# data_list[r][c] / title_list[r][c] hold the array and title of panel (r, c):\n",
    "# row 0 reads bgs_list, row 1 reads ags_list, both for the current `env`.\n",
    "data_list = []\n",
    "title_list = []\n",
    "for r in range(num_rows):\n",
    "    data_list.append([])\n",
    "    title_list.append([])\n",
    "    for method in methods:\n",
    "        if r == 0:\n",
    "            data_list[r] += [bgs_list[env][method]]\n",
    "        elif r == 1:\n",
    "            data_list[r] += [ags_list[env][method]]\n",
    "        title_list[r] += [\n",
    "                  \"{} {}\".format(method_name_dict[method], \"(ours)\" if \"MEGA\" in method else \"\")]\n",
    "\n",
    "# Use one colour scale for every panel. Without explicit limits each scatter\n",
    "# normalises its colours to its own column-2 range, so the single colourbar\n",
    "# below (built from one panel) would only be guaranteed correct for that panel.\n",
    "color_values = [data[:,2] for row in data_list for data in row if len(data) > 0]\n",
    "shared_vmin = min((values.min() for values in color_values), default=None)\n",
    "shared_vmax = max((values.max() for values in color_values), default=None)\n",
    "\n",
    "for r in range(num_rows):\n",
    "    for c in range(num_cols):\n",
    "        mazes_dict['square_large']['maze'].plot(axs[r,c])\n",
    "\n",
    "        im = axs[r,c].scatter(data_list[r][c][:,0], data_list[r][c][:,1], c=data_list[r][c][:,2],\n",
    "                              alpha=0.4, vmin=shared_vmin, vmax=shared_vmax)\n",
    "\n",
    "        # Method names only on the top row; our methods are set in bold\n",
    "        if r == 0:\n",
    "            axs[r][c].set_title(title_list[r][c], fontsize=20, fontweight=\"bold\" if \"MEGA\" in methods[c] else \"normal\")\n",
    "        # Row labels only on the first column\n",
    "        if c == 0:\n",
    "            axs[r][c].set_ylabel(row_title[r], fontsize=16, fontweight=\"bold\")\n",
    "            axs[r][c].axes.get_yaxis().set_ticks([])\n",
    "        else:\n",
    "            axs[r][c].get_yaxis().set_visible(False)\n",
    "\n",
    "        # Hacky way to hide the frames (the y-axis of column 0 must stay visible for its label)\n",
    "        axs[r][c].get_xaxis().set_visible(False)\n",
    "        axs[r][c].spines['bottom'].set_color('white')\n",
    "        axs[r][c].spines['top'].set_color('white')\n",
    "        axs[r][c].spines['right'].set_color('white')\n",
    "        axs[r][c].spines['left'].set_color('white')\n",
    "\n",
    "# Reserve a strip on the right for the shared colourbar\n",
    "fig.subplots_adjust(right=0.95)\n",
    "cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=18, fontsize=18)\n",
    "cbar.ax.tick_params(labelsize=14)\n",
    "plt.subplots_adjust(bottom=0.1, wspace=0.04, hspace=0.04)\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/pointmaze_last_ags_bgs_all_viz.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: The above plot went into the paper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AntMaze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_antmaze_template(ax=None):\n",
    "    \"\"\"Draw the AntMaze outline used as background for the goal scatter plots.\n",
    "\n",
    "    Adds an unfilled 24 x 24 frame with lower-left corner (-4, -4) and a\n",
    "    16 x 8 wall rectangle with lower-left corner (-4, 4), plus a white line\n",
    "    along x = -4 from y = 4.1 to y = 11.85 (the wall's left edge lies on the frame).\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on. If None (the default), a new 200-dpi\n",
    "            figure with a single equal-aspect axes is created and used.\n",
    "\n",
    "    Returns:\n",
    "        None. The shapes are added to ``ax`` in place.\n",
    "    \"\"\"\n",
    "    # Import the submodule explicitly: a bare `import matplotlib` does not load\n",
    "    # matplotlib.patches, it was only available because pyplot imports it.\n",
    "    from matplotlib.patches import Rectangle\n",
    "\n",
    "    if ax is None:\n",
    "        fig = plt.figure(dpi=200)\n",
    "        ax = fig.add_subplot(111, aspect='equal')\n",
    "\n",
    "    # Frame\n",
    "    rect = Rectangle((-4,-4), 24, 24, fill=False)\n",
    "    ax.add_patch(rect)\n",
    "    # Wall\n",
    "    wall = Rectangle((-4,4), 16, 8, fill=True, facecolor='none', edgecolor='black', linewidth=1)\n",
    "\n",
    "    # Cover one of the sides: white line over the wall edge at x = -4\n",
    "    x = [-4, -4]\n",
    "    y = [4.1, 11.85]\n",
    "    ax.plot(x, y, color='white')\n",
    "\n",
    "    ax.add_patch(wall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the environment key to AntMaze.\n",
    "# NOTE(review): rebinding the three dicts to {} discards every entry that earlier\n",
    "# cells stored in them, so cells above that read those entries must be re-run afterwards.\n",
    "env = \"antmaze\"\n",
    "bgs_list = {}\n",
    "ags_list = {}\n",
    "rand_ags_list = {}\n",
    "\n",
    "# One sub-dict per environment; it is filled per method by the cells below\n",
    "bgs_list[env] = {}\n",
    "ags_list[env] = {}\n",
    "rand_ags_list[env] = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed as max_timestep to extract_goals_data by the AntMaze cells below\n",
    "max_timestep = 3000000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MEGA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also used in plot titles and output file names\n",
    "method = \"MEGA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MIN_DENSITY_ag_cu-minkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # For last_bgs; passed on as filter_unique\n",
    "skip_every = 10\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OMEGA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also used in plot titles and output file names\n",
    "method = \"OMEGA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MIN_DENSITY_TRANSITION_ag_cu-minkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # For last_bgs; passed on as filter_unique\n",
    "skip_every = 10\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SkewFit / Diverse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also used in plot titles and output file names\n",
    "method = \"Diverse\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-SKEWFIT_ag_cu-randkde_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # For last_bgs; passed on as filter_unique\n",
    "skip_every = 10\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "min_timestep  = 5000\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HER (No Curiosity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key for this run in bgs_list / ags_list / rand_ags_list; also used in plot titles and output file names\n",
    "method = \"HER\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the current value of max_timestep (nothing is assigned here)\n",
    "max_timestep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/harris/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-HER_ag_cu-None_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "# Keep only the all-digit entries found directly under base_path\n",
    "timestep_list = [folder for folder in os.listdir(base_path) if folder.isdigit()]\n",
    "# Integer value -> original entry name (the name can differ from str(int), e.g. leading zeros)\n",
    "timestep_dict = {}\n",
    "for time in timestep_list:\n",
    "    timestep_dict[int(time)] = time\n",
    "\n",
    "# Replace the names by their integer values, sorted ascending\n",
    "timestep_list = [int(t) for t in timestep_list]\n",
    "timestep_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_bgs' and keep the result in bgs_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_bgs\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # For last_bgs; passed on as filter_unique\n",
    "skip_every = 10\n",
    "bgs_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(bgs_list[env][method][:,0], bgs_list[env][method][:,1], c=bgs_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'last_ags' and keep the result in ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"last_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(ags_list[env][method][:,0], ags_list[env][method][:,1], c=ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call extract_goals_data with folder_name 'rand_ags' and keep the result in rand_ags_list[env][method].\n",
    "min_timestep = 5000\n",
    "folder_name = \"rand_ags\"\n",
    "scale = 1000\n",
    "episode_length = 500\n",
    "do_filter_unique = False # passed on as filter_unique\n",
    "skip_every = 10\n",
    "rand_ags_list[env][method] = extract_goals_data(base_path, timestep_dict, timestep_list, min_timestep, max_timestep, folder_name, episode_length, scale, skip_every, filter_unique=do_filter_unique)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the antmaze shape (passing None makes plot_antmaze_template create a new figure)\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Plot color coded points: columns 0/1 are used as x/y, column 2 as the colour value\n",
    "plt.scatter(rand_ags_list[env][method][:,0], rand_ags_list[env][method][:,1], c=rand_ags_list[env][method][:,2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "# The file name is derived from the lower-cased method key\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MLE / Achieved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label of the run analysed next: used as the key into bgs_list / ags_list / rand_ags_list[env]\n",
    "# and interpolated into the plot titles and output file names.\n",
    "method = \"Achieved\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory whose logged goals are visualised in the cells below.\n",
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-DISCERN_ag_cu-randkde_eexpl0.1_first-True_dg_sc1.0_alpha-0.0\"\n",
    "\n",
    "# Keep only the purely numeric sub-folder names (the variable names suggest one folder per\n",
    "# logged timestep) and map each integer value back to the folder name as it appears on disk.\n",
    "# A comprehension is used so that no loop variable is left behind in the notebook namespace.\n",
    "timestep_dict = {int(folder): folder for folder in os.listdir(base_path) if folder.isdigit()}\n",
    "\n",
    "# Integer timesteps in ascending order (unique, since they are the dictionary keys).\n",
    "timestep_list = sorted(timestep_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_bgs\" data of the current run (base_path) and store it under\n",
    "# bgs_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_bgs\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "bgs_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = bgs_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_ags\" data of the current run (base_path) and store it under\n",
    "# ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"rand_ags\" data of the current run (base_path) and store it under\n",
    "# rand_ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"rand_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "rand_ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = rand_ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GoalDisc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label of the run analysed next: used as the key into bgs_list / ags_list / rand_ags_list[env]\n",
    "# and interpolated into the plot titles and output file names.\n",
    "method = \"GoalDisc\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory whose logged goals are visualised in the cells below.\n",
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-GOALDISC_ag_cu-goaldisc_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "\n",
    "# Keep only the purely numeric sub-folder names (the variable names suggest one folder per\n",
    "# logged timestep) and map each integer value back to the folder name as it appears on disk.\n",
    "# A comprehension is used so that no loop variable is left behind in the notebook namespace.\n",
    "timestep_dict = {int(folder): folder for folder in os.listdir(base_path) if folder.isdigit()}\n",
    "\n",
    "# Integer timesteps in ascending order (unique, since they are the dictionary keys).\n",
    "timestep_list = sorted(timestep_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_bgs\" data of the current run (base_path) and store it under\n",
    "# bgs_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_bgs\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "bgs_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = bgs_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_ags\" data of the current run (base_path) and store it under\n",
    "# ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"rand_ags\" data of the current run (base_path) and store it under\n",
    "# rand_ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"rand_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "rand_ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = rand_ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Min-Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label of the run analysed next: used as the key into bgs_list / ags_list / rand_ags_list[env]\n",
    "# and interpolated into the plot titles and output file names.\n",
    "method = \"Min Q\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory whose logged goals are visualised in the cells below.\n",
    "base_path = \"/scratch/gobi1/spitis/shared/icml20/spitis/proto_env-antmaze_alg-DDPG_herrfaab_1_4_3_1_1_layer-(512, 512, 512)_seed111_tb-MINQ_ag_cu-minq_eexpl0.1_first-True_dg_sc1.0_alpha--1.0\"\n",
    "\n",
    "# Keep only the purely numeric sub-folder names (the variable names suggest one folder per\n",
    "# logged timestep) and map each integer value back to the folder name as it appears on disk.\n",
    "# A comprehension is used so that no loop variable is left behind in the notebook namespace.\n",
    "timestep_dict = {int(folder): folder for folder in os.listdir(base_path) if folder.isdigit()}\n",
    "\n",
    "# Integer timesteps in ascending order (unique, since they are the dictionary keys).\n",
    "timestep_list = sorted(timestep_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_bgs\" data of the current run (base_path) and store it under\n",
    "# bgs_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_bgs\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "bgs_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = bgs_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Behaviour Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_bgs_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"last_ags\" data of the current run (base_path) and store it under\n",
    "# ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"last_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Terminal Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the \"rand_ags\" data of the current run (base_path) and store it under\n",
    "# rand_ags_list[env][method].\n",
    "# NOTE(review): the roles of the settings below are inferred from their names -- confirm\n",
    "# against extract_goals_data before changing them.\n",
    "folder_name = \"rand_ags\"\n",
    "min_timestep = 5000\n",
    "episode_length = 500\n",
    "scale = 1000\n",
    "skip_every = 10\n",
    "do_filter_unique = False  # forwarded as filter_unique\n",
    "\n",
    "rand_ags_list[env][method] = extract_goals_data(\n",
    "    base_path, timestep_dict, timestep_list,\n",
    "    min_timestep, max_timestep,\n",
    "    folder_name, episode_length, scale, skip_every,\n",
    "    filter_unique=do_filter_unique\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the maze outline (None presumably means \"use the current pyplot axes\" -- confirm\n",
    "# in plot_antmaze_template).\n",
    "plot_antmaze_template(None)\n",
    "\n",
    "# Scatter the goals on top: columns 0 and 1 give the position, column 2 the colour value.\n",
    "scatter_data = rand_ags_list[env][method]\n",
    "plt.scatter(scatter_data[:, 0], scatter_data[:, 1], c=scatter_data[:, 2], alpha=0.4)\n",
    "plt.title(\"Random Achieved Goals Distribution ({})\".format(method))\n",
    "\n",
    "# Colour bar for the scatter; plt.axis('off') hides the axis lines and labels of the current axes.\n",
    "cbar = plt.colorbar()\n",
    "plt.axis('off')\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=15)\n",
    "\n",
    "plt.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_rand_ags_{}_viz.pdf\".format(method.lower()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot all in one figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Behaviour and Achieved goals in one plot: one column per method,\n",
    "# top row = behavioural goals, bottom row = final achieved goals.\n",
    "# Max timestep = 3000000 (3M)\n",
    "methods = [\"HER\", \"Achieved\", \"Diverse\",\"GoalDisc\", \"Min Q\", \"MEGA\", \"OMEGA\"]\n",
    "method_name_dict = {\"HER\":\"HER\", \"OMEGA\":\"OMEGA\", \"MEGA\":\"MEGA\", \"Diverse\":\"Diverse\", \"Achieved\":\"Achieved\",\"GoalDisc\":\"GoalDisc\", \"Min Q\": \"Min Q\"}\n",
    "\n",
    "num_rows = 2\n",
    "num_cols = len(methods)\n",
    "# squeeze=False keeps axs two-dimensional even when `methods` holds a single entry.\n",
    "fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(4*num_cols, 4*num_rows), dpi=300, squeeze=False)\n",
    "\n",
    "# Row 0 holds the behavioural goals, row 1 the final achieved goals (same column order as `methods`).\n",
    "# Comprehensions are used so the notebook-level `method` variable is not overwritten by this cell.\n",
    "data_list = [\n",
    "    [bgs_list[env][m] for m in methods],\n",
    "    [ags_list[env][m] for m in methods],\n",
    "]\n",
    "\n",
    "# Column titles; the \" (ours)\" suffix is only appended for the MEGA variants, so the other\n",
    "# titles carry no trailing space.\n",
    "col_titles = [method_name_dict[m] + (\" (ours)\" if \"MEGA\" in m else \"\") for m in methods]\n",
    "\n",
    "# All panels share one colour bar, so they must also share one colour scale; otherwise the bar\n",
    "# would only be valid for the panel it was created from. Empty panels are ignored, and if every\n",
    "# panel is empty the limits stay None (matplotlib then falls back to automatic scaling).\n",
    "color_columns = [panel[:, 2] for row in data_list for panel in row if len(panel)]\n",
    "color_min = min((col.min() for col in color_columns), default=None)\n",
    "color_max = max((col.max() for col in color_columns), default=None)\n",
    "\n",
    "row_title = [r'Behavioural Goals', r'Final Achieved Goals']\n",
    "for r in range(num_rows):\n",
    "    for c in range(num_cols):\n",
    "        ax = axs[r, c]\n",
    "        panel = data_list[r][c]\n",
    "\n",
    "        plot_antmaze_template(ax)\n",
    "        im = ax.scatter(panel[:, 0], panel[:, 1], c=panel[:, 2], alpha=0.4,\n",
    "                        vmin=color_min, vmax=color_max)\n",
    "\n",
    "        # Method names only above the top row; row labels only on the left-most column.\n",
    "        if r == 0:\n",
    "            ax.set_title(col_titles[c], fontsize=20, fontweight=\"bold\" if \"MEGA\" in methods[c] else \"normal\")\n",
    "        if c == 0:\n",
    "            ax.set_ylabel(row_title[r], fontsize=16, fontweight=\"bold\")\n",
    "            ax.get_yaxis().set_ticks([])\n",
    "        else:\n",
    "            ax.get_yaxis().set_visible(False)\n",
    "\n",
    "        # Hacky way to hide the frames: the y-axis must stay visible on the first column for\n",
    "        # its label, so the spines are painted white instead of switching the whole axes off.\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_color('white')\n",
    "\n",
    "# Reserve the right-hand 5% of the figure for the shared colour bar, which lives in its own axes.\n",
    "fig.subplots_adjust(right=0.95, bottom=0.1, wspace=0.04, hspace=0.04)\n",
    "cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])\n",
    "cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "cbar.set_label('Episodes (Thousands)', rotation=270, labelpad=18, fontsize=18)\n",
    "cbar.ax.tick_params(labelsize=14)\n",
    "fig.savefig(\"/scratch/gobi1/spitis/shared/icml20/plots/antmaze_last_ags_bgs_all_viz.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: The above plot went into the paper"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}