{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VzgXI8ST20DS"
      },
      "source": [
        "# Transformation Coding [MountainCar]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tdtcedbh2ot2"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qsn5NtvYldkJ"
      },
      "source": [
        "import os\n",
        "import io\n",
        "import cv2\n",
        "import glob\n",
        "import time\n",
        "import math\n",
        "import json\n",
        "import imageio\n",
        "import colorsys\n",
        "import multiprocessing as mp\n",
        "from os import path\n",
        "from datetime import datetime\n",
        "from functools import partial\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm import tqdm\n",
        "\n",
        "import gym\n",
        "from gym import spaces\n",
        "from gym.utils import seeding\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch import autograd\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "import matplotlib\n",
        "\n",
        "matplotlib.interactive(False)\n",
        "# matplotlib.use('agg')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a12JiTew2q9q"
      },
      "source": [
        "## Prepare Experiment Folders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x_FH7sPhIQVT"
      },
      "source": [
        "# Every run works inside its own timestamped directory so outputs of different runs never mix.\n",
        "timestamp = str(datetime.now()).replace(' ', '_')\n",
        "EXPERIMENT_DIR = 'mountaincar_%s' % timestamp\n",
        "os.makedirs(EXPERIMENT_DIR)\n",
        "%cd $EXPERIMENT_DIR\n",
        "\n",
        "# Create the output sub-folders from Python instead of shelling out to `mkdir`:\n",
        "# this does not depend on an external shell command and does not fail if a folder already exists.\n",
        "for subdir in ('gifs', 'plots', 'videos', 'saved_models', 'figs_final', 'figs_progress'):\n",
        "    os.makedirs(subdir, exist_ok=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoIkwOnZKsP9"
      },
      "source": [
        "## Logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eCx_rhuiKr9v"
      },
      "source": [
        "# Close the log handles left over from a previous execution of this cell, if there are any.\n",
        "# On a fresh kernel the names do not exist yet (NameError); a failing close (OSError) is also\n",
        "# tolerated. Anything else, e.g. KeyboardInterrupt, is no longer silently swallowed.\n",
        "try:\n",
        "    LOG_STR.close()\n",
        "    LOG_FILE.close()\n",
        "except (NameError, OSError):\n",
        "    pass\n",
        "\n",
        "timestamp = str(datetime.now())\n",
        "LOG_STR = io.StringIO()\n",
        "LOG_FILE = open('log.txt', 'w')\n",
        "\n",
        "# Remember the built-in `print` only once, so re-running this cell does not wrap the wrapper.\n",
        "try:\n",
        "    old_print\n",
        "except NameError:\n",
        "    old_print = print\n",
        "\n",
        "def print(*args, **kwargs):\n",
        "    '''\n",
        "    Replacement for the built-in `print` that emits every message three times:\n",
        "    to the destination the caller asked for (stdout by default), to the in-memory\n",
        "    buffer `LOG_STR` and to the file `LOG_FILE`. Output is always flushed.\n",
        "    '''\n",
        "    kwargs['flush'] = True\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_STR\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_FILE\n",
        "    old_print(*args, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EVOc_tr3K4CU"
      },
      "source": [
        "## Arguments"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BEjSsrvSKiwM"
      },
      "source": [
        "class Args:\n",
        "    '''Plain attribute container holding every hyper-parameter and switch of the experiment.'''\n",
        "\n",
        "    def __init__(self):\n",
        "        # -- encoder architecture --\n",
        "        self.num_channels = 24\n",
        "        self.mlp_hidden_dim = 128\n",
        "        # -- optimizer and batching (`batch_size_incr` is added to the batch size after every epoch) --\n",
        "        self.learning_rate = 1e-3\n",
        "        self.weight_decay = 1e-7\n",
        "        self.batch_size = 64\n",
        "        self.batch_size_incr = 0\n",
        "        # -- sampling and schedule (`num_actions` counts the untransformed stack plus the sampled actions) --\n",
        "        self.num_actions = 4\n",
        "        self.num_epochs = 50\n",
        "        self.steps_per_epoch = 100\n",
        "        # -- size of the frames fed to the encoder --\n",
        "        self.img_w = 32\n",
        "        self.img_h = 32\n",
        "        # -- loss configuration --\n",
        "        self.barrier_type = 'log' # 'inv' or 'log' or 'hinge'\n",
        "        self.barrier_coef = 1\n",
        "        self.cosine_sim = False\n",
        "        self.conformal_map = False\n",
        "        # -- code size produced by the encoder and size of the plotted projection --\n",
        "        self.code_size = 3\n",
        "        self.proj_size = 3\n",
        "        # -- environment: keep (True) or drop (False) the gravity term in `step` --\n",
        "        self.gravity = True\n",
        "\n",
        "\n",
        "# Global configuration object read by the rest of the notebook; it is also written to the log.\n",
        "args = Args()\n",
        "subcode_size = args.code_size\n",
        "print(json.dumps(vars(args), sort_keys=True, indent=4))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CrsbWB-eQd3q"
      },
      "source": [
        "## MountainCar Environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JMy_g8QzQdLU"
      },
      "source": [
        "class Continuous_MountainCarEnv(gym.Env):\n",
        "    '''\n",
        "    Continuous-action mountain car whose rendered frames serve as training images.\n",
        "\n",
        "    The state is `[position, velocity]`. In `step` neither velocity nor position is clipped\n",
        "    (the clipping code is kept only as comments) and the gravity term is applied only when\n",
        "    the global `args.gravity` is true.\n",
        "    '''\n",
        "    metadata = {\n",
        "        'render.modes': ['human', 'rgb_array'],\n",
        "        'video.frames_per_second': 30\n",
        "    }\n",
        "\n",
        "    def __init__(self, goal_velocity=0):\n",
        "        '''Set up the physical constants and gym spaces, then seed and reset the environment.'''\n",
        "        self.min_action = -1.0\n",
        "        self.max_action = 1.0\n",
        "\n",
        "        self.min_position = -1.2\n",
        "        self.max_position = 0.6\n",
        "\n",
        "        self.max_speed = 0.07\n",
        "        # self.max_speed = 0.1\n",
        "\n",
        "        self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version\n",
        "        self.goal_velocity = goal_velocity\n",
        "\n",
        "        self.power = 0.0015\n",
        "        # self.power = 0.005\n",
        "\n",
        "        self.low_state = np.array(\n",
        "            [self.min_position, -self.max_speed], dtype=np.float32\n",
        "        )\n",
        "        self.high_state = np.array(\n",
        "            [self.max_position, self.max_speed], dtype=np.float32\n",
        "        )\n",
        "\n",
        "        # The viewer is created lazily on the first call to `render`.\n",
        "        self.viewer = None\n",
        "\n",
        "        self.action_space = spaces.Box(\n",
        "            low=self.min_action,\n",
        "            high=self.max_action,\n",
        "            shape=(1,),\n",
        "            dtype=np.float32\n",
        "        )\n",
        "        self.observation_space = spaces.Box(\n",
        "            low=self.low_state,\n",
        "            high=self.high_state,\n",
        "            dtype=np.float32\n",
        "        )\n",
        "\n",
        "        self.seed()\n",
        "        self.reset()\n",
        "\n",
        "    def seed(self, seed=None):\n",
        "        '''(Re)create `self.np_random` from `seed` and return the seed actually used in a list.'''\n",
        "        self.np_random, seed = seeding.np_random(seed)\n",
        "        return [seed]\n",
        "\n",
        "    def step(self, action):\n",
        "        '''\n",
        "        Advance the simulation by one step.\n",
        "\n",
        "        `action[0]` is clamped to [min_action, max_action] and used as the force; the action\n",
        "        penalty in the reward uses the unclamped `action[0]`. Returns the tuple\n",
        "        `(state, reward, done, info)` where `info` is an empty dict.\n",
        "        '''\n",
        "        position = self.state[0]\n",
        "        velocity = self.state[1]\n",
        "        force = min(max(action[0], self.min_action), self.max_action)\n",
        "\n",
        "        velocity += force * self.power - (0.0025 * math.cos(3 * position) if args.gravity else 0)\n",
        "        # Clipping of velocity and position is intentionally left disabled:\n",
        "        # if (velocity > self.max_speed): velocity = self.max_speed\n",
        "        # if (velocity < -self.max_speed): velocity = -self.max_speed\n",
        "        position += velocity\n",
        "        # if (position > self.max_position): position = self.max_position\n",
        "        # if (position < self.min_position): position = self.min_position\n",
        "        # if (position == self.min_position and velocity < 0): velocity = 0\n",
        "\n",
        "        # Convert a possible numpy bool to a Python bool.\n",
        "        done = bool(\n",
        "            position >= self.goal_position and velocity >= self.goal_velocity\n",
        "        )\n",
        "\n",
        "        reward = 0\n",
        "        if done:\n",
        "            reward = 100.0\n",
        "        reward -= math.pow(action[0], 2) * 0.1\n",
        "\n",
        "        self.state = np.array([position, velocity])\n",
        "        return self.state, reward, done, {}\n",
        "\n",
        "    def reset(self):\n",
        "        '''Place the car at a random position in [-0.6, -0.4) with zero velocity; return a copy of the state.'''\n",
        "        self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])\n",
        "        return np.array(self.state)\n",
        "\n",
        "    def _height(self, xs):\n",
        "        '''Height of the track at horizontal position(s) `xs`.'''\n",
        "        return np.sin(3 * xs)*.45+.55\n",
        "\n",
        "    def render(self, mode='human'):\n",
        "        '''\n",
        "        Draw the current state and return whatever the viewer's `render` returns\n",
        "        (the RGB array is requested only when `mode == 'rgb_array'`).\n",
        "        '''\n",
        "        screen_width = 600 // 2\n",
        "        screen_height = 400 // 2\n",
        "\n",
        "        world_width = self.max_position - self.min_position\n",
        "        scale = screen_width/world_width\n",
        "        carwidth = 40\n",
        "        carheight = 20\n",
        "\n",
        "        # Build the static scene (track, car, wheels, flag) once, on first use.\n",
        "        if self.viewer is None:\n",
        "            from gym.envs.classic_control import rendering\n",
        "            self.viewer = rendering.Viewer(screen_width, screen_height)\n",
        "            xs = np.linspace(self.min_position, self.max_position, 100)\n",
        "            ys = self._height(xs)\n",
        "            xys = list(zip((xs-self.min_position)*scale, ys*scale))\n",
        "\n",
        "            self.track = rendering.make_polyline(xys)\n",
        "            self.track.set_linewidth(4)\n",
        "            self.viewer.add_geom(self.track)\n",
        "\n",
        "            clearance = 10\n",
        "\n",
        "            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0\n",
        "            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])\n",
        "            car.add_attr(rendering.Transform(translation=(0, clearance)))\n",
        "            # `cartrans` is shared by the car body and both wheels so they move together.\n",
        "            self.cartrans = rendering.Transform()\n",
        "            car.add_attr(self.cartrans)\n",
        "            self.viewer.add_geom(car)\n",
        "            frontwheel = rendering.make_circle(carheight / 2.5)\n",
        "            frontwheel.set_color(.5, .5, .5)\n",
        "            frontwheel.add_attr(\n",
        "                rendering.Transform(translation=(carwidth / 4, clearance))\n",
        "            )\n",
        "            frontwheel.add_attr(self.cartrans)\n",
        "            self.viewer.add_geom(frontwheel)\n",
        "            backwheel = rendering.make_circle(carheight / 2.5)\n",
        "            backwheel.add_attr(\n",
        "                rendering.Transform(translation=(-carwidth / 4, clearance))\n",
        "            )\n",
        "            backwheel.add_attr(self.cartrans)\n",
        "            backwheel.set_color(.5, .5, .5)\n",
        "            self.viewer.add_geom(backwheel)\n",
        "            flagx = (self.goal_position-self.min_position)*scale\n",
        "            flagy1 = self._height(self.goal_position)*scale\n",
        "            flagy2 = flagy1 + 50\n",
        "            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))\n",
        "            self.viewer.add_geom(flagpole)\n",
        "            flag = rendering.FilledPolygon(\n",
        "                [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]\n",
        "            )\n",
        "            flag.set_color(.8, .8, 0)\n",
        "            self.viewer.add_geom(flag)\n",
        "\n",
        "        # Only the car transform depends on the state.\n",
        "        pos = self.state[0]\n",
        "        self.cartrans.set_translation(\n",
        "            (pos-self.min_position) * scale, self._height(pos) * scale\n",
        "        )\n",
        "        self.cartrans.set_rotation(math.cos(3 * pos))\n",
        "\n",
        "        obs_rgb = self.viewer.render(return_rgb_array=mode == 'rgb_array')\n",
        "        return obs_rgb\n",
        "\n",
        "    def close(self):\n",
        "        '''Close the viewer, if one was created, and forget it.'''\n",
        "        if self.viewer:\n",
        "            self.viewer.close()\n",
        "            self.viewer = None\n",
        "\n",
        "    def sample_random_transition(self, u_list):\n",
        "        '''\n",
        "        Render one random start state and the outcome of every action in `u_list`.\n",
        "\n",
        "        A start state is drawn uniformly, with positions kept 0.1 away from both track ends.\n",
        "        For every action the environment is put back into that same start state, the position\n",
        "        is advanced once by the current velocity without any force, and then `step(u)` is\n",
        "        applied; a frame is rendered at each of these three stages.\n",
        "\n",
        "        Returns the frame pairs stacked along axis 0, `len(u_list) + 1` entries in total:\n",
        "        entry 0 is the pair (start, after the force-free move) and entry `i + 1` is the pair\n",
        "        (after the force-free move, after action `u_list[i]`).\n",
        "        '''\n",
        "        num_actions = len(u_list) + 1\n",
        "        low_state = np.array([self.min_position + 0.1, -self.max_speed], dtype=np.float32)\n",
        "        high_state = np.array([self.max_position - 0.1, self.max_speed], dtype=np.float32)\n",
        "        init_state = np.random.uniform(low=low_state, high=high_state)\n",
        "        x_list = [None] * num_actions\n",
        "        for i, u in enumerate(u_list):\n",
        "            assert (self.min_action <= u).all() and (u <= self.max_action).all()\n",
        "            # Work on a copy: the in-place position update below would otherwise modify\n",
        "            # `init_state` itself, so every later action would start from an already shifted\n",
        "            # state and entry 0 would no longer match the other entries.\n",
        "            self.state = init_state.copy()\n",
        "            obs_list = []\n",
        "            obs_list.append(self.render(mode='rgb_array'))\n",
        "            self.state[0] += self.state[1]\n",
        "            obs_list.append(self.render(mode='rgb_array'))\n",
        "            self.step(u)\n",
        "            obs_list.append(self.render(mode='rgb_array'))\n",
        "            x_list[0] = np.stack(obs_list[:2], axis=0)\n",
        "            x_list[i + 1] = np.stack(obs_list[-2:], axis=0)\n",
        "        return np.stack(x_list, axis=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQgD4MDKBzrF"
      },
      "source": [
        "## Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RqDYzs_VAj3l"
      },
      "source": [
        "def normalize(x):\n",
        "    '''Rescale pixel values linearly so that 0 maps to -1 and 255 maps to +1 (float32 result).'''\n",
        "    unit_interval = x.astype(np.float32) / 255.0\n",
        "    centered = unit_interval - 0.5\n",
        "    return 2.0 * centered"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wUaCUhY2N4VM"
      },
      "source": [
        "def process_obs(x):\n",
        "    '''Resize one frame to `args.img_w` x `args.img_h` (bicubic) and convert it to 8-bit grayscale.'''\n",
        "    frame = Image.fromarray(x)\n",
        "    frame = frame.resize((args.img_w, args.img_h), resample=Image.BICUBIC)\n",
        "    return np.array(frame.convert('L'))\n",
        "\n",
        "\n",
        "# Lift the single-frame function so that it maps over any leading axes of 3-D frames.\n",
        "process_obs = np.vectorize(process_obs, signature='(n,m,k)->(p,q)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y29nrdIXKmJz"
      },
      "source": [
        "# Single environment instance shared by the sampling and visualization cells below.\n",
        "env = Continuous_MountainCarEnv()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UllDY1qrgPUa"
      },
      "source": [
        "def unstack(a, axis):\n",
        "    '''Split `a` along `axis` into a list of arrays, each with that axis removed.'''\n",
        "    num_slices = a.shape[axis]\n",
        "    slices = []\n",
        "    for piece in np.split(a, num_slices, axis=axis):\n",
        "        slices.append(np.squeeze(piece, axis))\n",
        "    return slices"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FITe2RvLyGY2"
      },
      "source": [
        "def sample(batch_size, num_actions):\n",
        "    '''\n",
        "    Draw `num_actions - 1` random actions once and apply all of them to `batch_size`\n",
        "    independently drawn start states of the global `env`.\n",
        "\n",
        "    Returns a list with `num_actions` arrays (one per entry produced by\n",
        "    `env.sample_random_transition`), each stacked over the batch along axis 0.\n",
        "    '''\n",
        "    low, high = env.action_space.low, env.action_space.high\n",
        "    u_list = []\n",
        "    for _ in range(num_actions - 1):\n",
        "        u_list.append(np.random.uniform(low=low, high=high, size=(2, 1)))\n",
        "    # The same actions are reused for every element of the batch.\n",
        "    transitions = [env.sample_random_transition(u_list) for _ in range(batch_size)]\n",
        "    batch = np.stack(transitions, axis=0)\n",
        "    return unstack(batch, axis=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5dvE1VCgnTvC"
      },
      "source": [
        "## Test Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rzNRqbL-KcGh"
      },
      "source": [
        "# Smoke test: one batch with a single random action yields two stacks of frame pairs.\n",
        "x1, x2 = sample(args.batch_size, num_actions=2)\n",
        "for frames in (x1, x2):\n",
        "    print(frames.shape, frames.dtype)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e2tXK14umWhN"
      },
      "source": [
        "plt.style.use('default')\n",
        "\n",
        "fig = plt.figure(figsize=(12, 3))\n",
        "idx = 0\n",
        "\n",
        "# One row of four panels: both frames of `x1[idx]` followed by both frames of `x2[idx]`.\n",
        "panels = [x1[idx, 0], x1[idx, 1], x2[idx, 0], x2[idx, 1]]\n",
        "for panel_no, frame in enumerate(panels, start=1):\n",
        "    plt.subplot(1, 4, panel_no)\n",
        "    plt.imshow(frame, cmap='gray', vmin=0, vmax=255)\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sJidR4-HkNvT"
      },
      "source": [
        "## Visualization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TjXBdXCnkN1B"
      },
      "source": [
        "# Visualization grid: positions stay 0.1 away from both track ends, speeds span [-max, +max].\n",
        "MIN_POSITION = env.min_position + 0.1\n",
        "MAX_POSITION = env.max_position - 0.1\n",
        "MAX_SPEED = env.max_speed\n",
        "\n",
        "N_POS = 160\n",
        "N_SPEED = 140\n",
        "\n",
        "# `process_obs` returns arrays laid out as (rows, cols) = (img_h, img_w), so the buffer uses\n",
        "# that order as well; the former (img_w, img_h) only worked because the default image is square.\n",
        "viz_dataset = np.empty((N_POS, N_SPEED, 2, args.img_h, args.img_w))\n",
        "\n",
        "for i, pos in enumerate(tqdm(np.linspace(MIN_POSITION, MAX_POSITION, N_POS, endpoint=True))):\n",
        "    for j, speed in enumerate(np.linspace(-MAX_SPEED, MAX_SPEED, N_SPEED, endpoint=True)):\n",
        "        # Frame pair for one grid point: the state itself and the state after one\n",
        "        # force-free position update (position += velocity).\n",
        "        env.state = [pos, speed]\n",
        "        obs0 = env.render(mode='rgb_array')\n",
        "        env.state[0] += env.state[1]\n",
        "        obs1 = env.render(mode='rgb_array')\n",
        "        viz_dataset[i, j, 0] = normalize(process_obs(obs0))\n",
        "        viz_dataset[i, j, 1] = normalize(process_obs(obs1))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PndZvrLhgeAY"
      },
      "source": [
        "# Matrix that maps codes to the `args.proj_size` plotted coordinates.\n",
        "if subcode_size > args.proj_size:\n",
        "    # Codes are larger than the plot: keep the first `proj_size` left singular vectors of a\n",
        "    # random Gaussian matrix and divide every column by its sum.\n",
        "    t = np.random.randn(subcode_size, subcode_size)\n",
        "    left_vectors = np.linalg.svd(t)[0]\n",
        "    PROJ_MAT = left_vectors[:, :args.proj_size]\n",
        "    PROJ_MAT /= np.sum(PROJ_MAT, axis=0, keepdims=True)\n",
        "else:\n",
        "    # Codes already fit: plot them unchanged.\n",
        "    PROJ_MAT = np.eye(args.proj_size)\n",
        "\n",
        "float_format = {'float': lambda value: '%12f' % value}\n",
        "with np.printoptions(linewidth=300, formatter=float_format):\n",
        "    print('PROJ_MAT =\\n', PROJ_MAT)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MWErjBoDYDf7"
      },
      "source": [
        "def scatter_3d(ax, x, y, z, c, elev=30, azim=-60, **kwargs):\n",
        "    '''\n",
        "    Set the view of `ax` to (`elev`, `azim`) and scatter the points in the order given by a\n",
        "    view-dependent score, so that the drawing order follows that score.\n",
        "\n",
        "    `elev` and `azim` are in degrees; extra keyword arguments are forwarded to `ax.scatter`.\n",
        "    '''\n",
        "    ax.view_init(elev, azim)\n",
        "\n",
        "    elev_rad = elev / 180 * np.pi\n",
        "    azim_rad = azim / 180 * np.pi\n",
        "    score = np.cos(azim_rad) * x + np.sin(azim_rad) * y + np.sin(elev_rad) * z\n",
        "    order = np.argsort(score, kind='heapsort')\n",
        "\n",
        "    ax.scatter(x[order], y[order], z[order], c=c[order], **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DYxIDwtKkiQx"
      },
      "source": [
        "def visualize(ax, enc, data, s=1, **kwargs):\n",
        "    '''\n",
        "    Encode the (N_POS, N_SPEED) grid of frame pairs in `data` with `enc`, project the codes\n",
        "    with `PROJ_MAT` and scatter them on `ax`.\n",
        "\n",
        "    Hue varies with the position index and lightness with the speed index. `s` is the marker\n",
        "    size; extra keyword arguments are forwarded to the scatter call. Only `args.proj_size`\n",
        "    values 2 and 3 are supported.\n",
        "    '''\n",
        "    device = get_device(enc)\n",
        "    proj_size = args.proj_size\n",
        "\n",
        "    # Encode one row of the grid (fixed position, all speeds) per forward pass.\n",
        "    code = np.empty((N_POS, N_SPEED, proj_size))\n",
        "    with torch.no_grad():\n",
        "        for row in range(N_POS):\n",
        "            batch = torch.tensor(data[row], dtype=torch.float32, device=device)\n",
        "            code[row] = enc(batch).cpu().numpy() @ PROJ_MAT\n",
        "\n",
        "    colors = np.empty((N_POS, N_SPEED, 3))\n",
        "    for row in range(N_POS):\n",
        "        hue = row / N_POS\n",
        "        for col in range(N_SPEED):\n",
        "            lightness = col / (N_SPEED - 1) * 0.95\n",
        "            colors[row, col] = colorsys.hls_to_rgb(hue, lightness, 1)\n",
        "\n",
        "    flat_code = code.reshape(-1, proj_size)\n",
        "    flat_colors = colors.reshape(-1, 3)\n",
        "\n",
        "    if proj_size == 3:\n",
        "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
        "            # make the pane transparent\n",
        "            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
        "            # make the grid lines transparent\n",
        "            axis._axinfo['grid']['color'] = (1,1,1,0)\n",
        "        scatter_3d(ax, flat_code[:, 0], flat_code[:, 1], flat_code[:, 2], c=flat_colors, s=s, **kwargs)\n",
        "    elif proj_size == 2:\n",
        "        ax.scatter(flat_code[:, 0], flat_code[:, 1], c=flat_colors, s=s, **kwargs)\n",
        "    else:\n",
        "        assert False"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JH9QHBQJO7Qz"
      },
      "source": [
        "## Encoder Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2PRN3i6bCPLE"
      },
      "source": [
        "class View(nn.Module):\n",
        "    '''Module wrapper around `Tensor.view`, so a reshape can be placed inside `nn.Sequential`.'''\n",
        "\n",
        "    def __init__(self, *shape):\n",
        "        super().__init__()\n",
        "        # Target shape, stored as the tuple of positional arguments.\n",
        "        self.shape = shape\n",
        "\n",
        "    def forward(self, x):\n",
        "        return x.view(self.shape)\n",
        "\n",
        "    def __repr__(self):\n",
        "        dims = ['%d' % dim for dim in self.shape]\n",
        "        return 'View(%s)' % ', '.join(dims)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vjQVUzYLfS79"
      },
      "source": [
        "# The encoder input stacks two frames along the channel axis.\n",
        "IN_CHANNELS = 2\n",
        "# Three 2x2 max-pools divide each spatial side by 8 before the feature map is flattened.\n",
        "hidden_dim = args.num_channels * (args.img_w // 8) * (args.img_h // 8)\n",
        "\n",
        "\n",
        "def _conv_block(in_channels, out_channels):\n",
        "    '''3x3 same-padding convolution, ReLU and 2x2 max-pool, returned as a list of layers.'''\n",
        "    return [\n",
        "        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n",
        "        nn.ReLU(),\n",
        "        nn.MaxPool2d(kernel_size=2),\n",
        "    ]\n",
        "\n",
        "\n",
        "enc = nn.Sequential(\n",
        "    *_conv_block(IN_CHANNELS, args.num_channels),\n",
        "    *_conv_block(args.num_channels, args.num_channels),\n",
        "    *_conv_block(args.num_channels, args.num_channels),\n",
        "\n",
        "    View(-1, hidden_dim),\n",
        "\n",
        "    # MLP head; the final projection to the code has no bias.\n",
        "    nn.Linear(hidden_dim, args.mlp_hidden_dim),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dim, args.code_size, bias=False)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R9mtExX8fTSj"
      },
      "source": [
        "def count_parameters(net):\n",
        "    '''Return the total number of scalar entries over all trainable parameters of `net`.'''\n",
        "    total = 0\n",
        "    for param in net.parameters():\n",
        "        if param.requires_grad:\n",
        "            total += param.numel()\n",
        "    return total\n",
        "\n",
        "\n",
        "def get_device(net):\n",
        "    '''\n",
        "    Return the `torch.device` of the first parameter of `net`.\n",
        "    The answer describes the whole network only if all parameters live on the **same** device.\n",
        "    '''\n",
        "    first_param = list(net.parameters())[0]\n",
        "    return first_param.device"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nsl2lsi_fTWi"
      },
      "source": [
        "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "device_name = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "DEVICE = torch.device(device_name)\n",
        "print('Device:', DEVICE)\n",
        "\n",
        "# Move the encoder, then log its layout and size.\n",
        "enc.to(DEVICE)\n",
        "print(enc)\n",
        "num_parameters = count_parameters(enc)\n",
        "print('%d parameters' % num_parameters)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UiRpYgiHiPOJ"
      },
      "source": [
        "## Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yh2gKlJhfTaU"
      },
      "source": [
        "# Adam over all encoder parameters; learning rate and weight decay come from `args`.\n",
        "opt = optim.Adam(enc.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwvYhWuI6UlG"
      },
      "source": [
        "## Loss Function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0IWI-lPRYE_o"
      },
      "source": [
        "def cdist_mean(x1, x2, p=2.0, *args, **kwargs):\n",
        "    '''\n",
        "    Pairwise p-norm distances between the rows of `x1` and `x2` (via `torch.cdist`),\n",
        "    multiplied by `dim ** (-1 / p)` where `dim = x1.shape[1]`.\n",
        "    Extra positional and keyword arguments are forwarded to `torch.cdist`.\n",
        "    '''\n",
        "    dim = x1.shape[1]\n",
        "    scale = dim ** (-1 / p)\n",
        "    pairwise = torch.cdist(x1, x2, p, *args, **kwargs)\n",
        "    return pairwise * scale\n",
        "\n",
        "\n",
        "def loss_fn(enc, x_list, barrier_type, cosine_sim=False, conformal_map=False, decompositions=1):\n",
        "    '''\n",
        "    Compute the two loss terms for one batch.\n",
        "\n",
        "    `x_list` holds one batch per action (at least two); every batch is converted to a float32\n",
        "    tensor on the encoder's device and encoded with `enc`.\n",
        "\n",
        "    * Symmetry loss: for each of the `decompositions` equally sized slices of the code, a\n",
        "      within-batch pairwise distance matrix is computed per action, and the mean squared\n",
        "      difference between these matrices is averaged over all pairs of actions. Distances are\n",
        "      scaled Euclidean by default, cosine distances if `cosine_sim` or `conformal_map` is\n",
        "      set; with `conformal_map` they are taken between difference vectors of codes.\n",
        "    * Barrier loss: computed from the pairwise distances between all codes of all actions\n",
        "      (self-distances excluded) according to `barrier_type` ('log', 'inv' or 'hinge').\n",
        "\n",
        "    Returns `(loss_equiv, loss_barrier)` as scalar tensors on the encoder's device.\n",
        "    '''\n",
        "    device = get_device(enc)\n",
        "    num_actions = len(x_list)\n",
        "    assert num_actions >= 2\n",
        "    x_list = [torch.tensor(x, dtype=torch.float32, device=device) for x in x_list]\n",
        "    z_list = [enc(x) for x in x_list]\n",
        "    code_size = z_list[0].shape[1]\n",
        "\n",
        "    #-- symmetry loss --#\n",
        "    subcode_size = code_size // decompositions\n",
        "    loss_equiv = 0\n",
        "    for k in range(decompositions):\n",
        "        h_list = [z[:, k * subcode_size: (k + 1) * subcode_size] for z in z_list]\n",
        "\n",
        "        if conformal_map:\n",
        "            # Replace the codes by all pairwise difference vectors within the batch.\n",
        "            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]\n",
        "\n",
        "        if cosine_sim or conformal_map:\n",
        "            D_list = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]\n",
        "        else:\n",
        "            D_list = [cdist_mean(h, h, p=2) for h in h_list]\n",
        "\n",
        "        # Allocate on the encoder's device so the per-pair losses are not copied to the CPU\n",
        "        # one by one and both returned losses live on the same device.\n",
        "        L_equiv = torch.zeros(num_actions, num_actions, device=device)\n",
        "        for i in range(num_actions):\n",
        "            for j in range(i + 1, num_actions):\n",
        "                L_equiv[i, j] = torch.mean((D_list[i] - D_list[j]) ** 2)\n",
        "        # Only the upper triangle is filled, so dividing by the number of pairs gives the mean.\n",
        "        cur_loss_equiv = torch.sum(L_equiv) / (num_actions * (num_actions - 1) / 2)\n",
        "        loss_equiv += cur_loss_equiv / decompositions\n",
        "\n",
        "    #-- barrier loss --#\n",
        "    z_all = torch.cat(z_list, dim=0)\n",
        "    D_all = cdist_mean(z_all, z_all, p=2)\n",
        "    # Diagonal mask (self-distances), created on the same device as the tensor it indexes.\n",
        "    mask = torch.eye(D_all.shape[0], dtype=torch.bool, device=device)\n",
        "    if barrier_type == 'log':\n",
        "        loss_barrier = torch.mean(-torch.log(D_all[~mask] + 1e-9))\n",
        "    elif barrier_type == 'inv':\n",
        "        loss_barrier = torch.mean(1.0 / (D_all[~mask] + 1e-9))\n",
        "    elif barrier_type == 'hinge':\n",
        "        loss_barrier = torch.mean(torch.maximum(torch.zeros(1, device=device), 5.0 - D_all[~mask]))\n",
        "    else:\n",
        "        assert False, 'Unknown `barrier_type`'\n",
        "\n",
        "    return loss_equiv, loss_barrier"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PKZOSNYX_f-x"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-prwPRYR_kFt"
      },
      "source": [
        "def check_weights_nan(parameters):\n",
        "    '''Return True as soon as one tensor in `parameters` contains a NaN, otherwise False.'''\n",
        "    return any(torch.isnan(p).any() for p in parameters)\n",
        "\n",
        "\n",
        "def get_weights_norm(parameters, norm_type=2.0):\n",
        "    '''Return, as a Python float, the `norm_type`-norm of the per-tensor `norm_type`-norms of `parameters`.'''\n",
        "    with torch.no_grad():\n",
        "        per_tensor = [torch.norm(p, norm_type) for p in parameters]\n",
        "        stacked = torch.stack(per_tensor)\n",
        "        return torch.norm(stacked, norm_type).item()\n",
        "\n",
        "\n",
        "def get_grads_norm(parameters, norm_type=2.0):\n",
        "    '''\n",
        "    Return, as a Python float, the `norm_type`-norm of the per-tensor gradient norms.\n",
        "\n",
        "    Parameters whose `.grad` is None (frozen, unused, or before the first backward pass)\n",
        "    are skipped; if no parameter has a gradient the result is 0.0.\n",
        "    '''\n",
        "    with torch.no_grad():\n",
        "        grad_norms = [torch.norm(p.grad, norm_type) for p in parameters if p.grad is not None]\n",
        "        if not grad_norms:\n",
        "            return 0.0\n",
        "        return torch.norm(torch.stack(grad_norms), norm_type).item()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VJcYn9Vn_kK7"
      },
      "source": [
        "# Snapshot the encoder before any update as epoch 0, so the saved series starts at initialization.\n",
        "torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % 0))\n",
        "\n",
        "# Plot the codes of the visualization grid for the untrained encoder.\n",
        "fig = plt.figure(figsize=(6, 6))\n",
        "ax = fig.add_subplot(111, projection='3d' if args.proj_size == 3 else None)\n",
        "enc.eval()\n",
        "visualize(ax, enc, viz_dataset)\n",
        "fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % 0), bbox_inches='tight', dpi=200)\n",
        "plt.show()\n",
        "plt.close()\n",
        "\n",
        "# The working batch size may grow after every epoch (see `args.batch_size_incr` below).\n",
        "batch_size = args.batch_size\n",
        "\n",
        "# Per-epoch averages of the (scaled) losses.\n",
        "avg_loss_list = []\n",
        "avg_loss_equiv_list = []\n",
        "avg_loss_barrier_list = []\n",
        "for t in range(args.num_epochs):\n",
        "    enc.train()\n",
        "    loss_list = []\n",
        "    loss_equiv_list = []\n",
        "    loss_barrier_list = []\n",
        "\n",
        "    time_start = time.time()\n",
        "    progress = tqdm(range(args.steps_per_epoch), desc='Loss: None | Loss Equiv: None | Loss Barrier: None | L2 Weights: %12g | L2 Grads: 0' % (\n",
        "        get_weights_norm(enc.parameters(), norm_type=2.0)\n",
        "    ), total=args.steps_per_epoch, position=0, leave=True)\n",
        "    for _ in progress:\n",
        "        # NOTE(review): `u` is not referenced again in this cell; the actions that are\n",
        "        # actually applied are drawn inside `sample`.\n",
        "        u = np.random.uniform(low=env.action_space.low, high=env.action_space.high, size=(2, 1))\n",
        "        # Fresh batch every step: one stack of frame pairs per action, downscaled and rescaled to [-1, 1].\n",
        "        x_list = sample(batch_size, args.num_actions)\n",
        "        x_list = [normalize(process_obs(x)) for x in x_list]\n",
        "        loss_equiv, loss_barrier = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map)\n",
        "        loss = loss_equiv + args.barrier_coef * loss_barrier\n",
        "        opt.zero_grad()\n",
        "        loss.backward()\n",
        "        opt.step()\n",
        "        # I multiply the losses by 1e4 to get nicer numbers\n",
        "        loss_list.append(1e4 * loss.item())\n",
        "        loss_equiv_list.append(1e4 * loss_equiv.item())\n",
        "        loss_barrier_list.append(1e4 * args.barrier_coef * loss_barrier.item())\n",
        "        progress.set_description('Loss: %12g | Loss Equiv: %12g | Loss barrier: %12g | L2 Weights: %12g | L2 Grads: %12g' % (\n",
        "            loss_list[-1],\n",
        "            loss_equiv_list[-1],\n",
        "            loss_barrier_list[-1],\n",
        "            get_weights_norm(enc.parameters(), norm_type=2.0),\n",
        "            get_grads_norm(enc.parameters(), norm_type=2.0)\n",
        "        ))\n",
        "    time_end = time.time()\n",
        "\n",
        "    # Checkpoint after every epoch; file names are 1-based because epoch 0 is the initialization.\n",
        "    torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % (t + 1)))\n",
        "\n",
        "    avg_loss = np.mean(loss_list)\n",
        "    avg_loss_equiv = np.mean(loss_equiv_list)\n",
        "    avg_loss_barrier = np.mean(loss_barrier_list)\n",
        "    avg_loss_list.append(avg_loss)\n",
        "    avg_loss_equiv_list.append(avg_loss_equiv)\n",
        "    avg_loss_barrier_list.append(avg_loss_barrier)\n",
        "    print('\\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | Time: %6.1f sec' % (\n",
        "        t + 1, avg_loss, avg_loss_equiv, avg_loss_barrier, time_end - time_start))\n",
        "\n",
        "    # Optional batch-size schedule: grow the batch by a fixed increment after each epoch.\n",
        "    if args.batch_size_incr != 0:\n",
        "        new_batch_size = batch_size + args.batch_size_incr\n",
        "        print('Batch size increased from %d to %d' % (batch_size, new_batch_size))\n",
        "        batch_size = new_batch_size\n",
        "\n",
        "    # Plot the codes of the visualization grid with the encoder in eval mode.\n",
        "    fig = plt.figure(figsize=(6, 6))\n",
        "    ax = fig.add_subplot(111, projection='3d' if args.proj_size == 3 else None)\n",
        "    enc.eval()\n",
        "    visualize(ax, enc, viz_dataset)\n",
        "    fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % (t + 1)), bbox_inches='tight', dpi=200)\n",
        "    plt.show()\n",
        "    plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RDumfR4N_kPg"
      },
      "source": [
        "save_path = path.join('saved_models', 'model_final.tar')\n",
        "print('saving model to %s' % save_path)\n",
        "torch.save(enc.state_dict(), save_path)\n",
        "# enc.load_state_dict(torch.load('/content/model_epoch_050.tar', map_location=DEVICE))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W6jvbi3p_kVo"
      },
      "source": [
        "## Error Curve"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d_WcS1xg_kaf"
      },
      "source": [
        "fig = plt.figure(figsize=(5, 5))\n",
        "plt.plot(avg_loss_list)\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Loss')\n",
        "fig.savefig(path.join('plots', 'error_curve.png'), bbox_inches='tight', dpi=300)\n",
        "plt.show()\n",
        "plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oPJFA0mLNr48"
      },
      "source": [
        "## Visualize Progress"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3QY7b5HnNZwO"
      },
      "source": [
        "image_list = []\n",
        "files = sorted(glob.glob('figs_progress/*.png'))\n",
        "for file in files:\n",
        "    image_list.append(cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB))\n",
        "imageio.mimsave(path.join('gifs', 'codes_progress.gif'), image_list, fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C1pm-Nt5O5A-"
      },
      "source": [
        "! ffmpeg -v quiet -r 10 -i ./figs_progress/codes_epoch_%03d.png -y videos/codes_progress.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2pxdU9EZ_kgN"
      },
      "source": [
        "## Visualize Final Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g-myzU_3N8B2"
      },
      "source": [
        "fig = plt.figure(figsize=(8, 8))\n",
        "ax = fig.add_subplot(111, projection='3d' if args.proj_size == 3 else None)\n",
        "enc.eval()\n",
        "visualize(ax, enc, viz_dataset, s=5)\n",
        "fig.savefig(path.join('plots', 'codes_final.png'), bbox_inches='tight', dpi=600)\n",
        "plt.show()\n",
        "plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sgHFyL32lXMK"
      },
      "source": [
        "! rm -rf figs_final\n",
        "! mkdir figs_final\n",
        "\n",
        "enc.eval()\n",
        "\n",
        "if args.proj_size == 3:\n",
        "    for i in tqdm(range(72)):\n",
        "        fig = plt.figure(figsize=(6, 6))\n",
        "        ax = fig.add_subplot(111, projection='3d')\n",
        "        # plt.axis('off')\n",
        "        visualize(ax, enc, viz_dataset, azim=i * 5)\n",
        "        fig.savefig(path.join('figs_final', 'final_codes_angle_%03d.png' % i), bbox_inches='tight', dpi=200)\n",
        "        plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iemaX8DzmbmF"
      },
      "source": [
        "if args.proj_size == 3:\n",
        "    image_list = []\n",
        "    files = sorted(glob.glob('figs_final/*.png'))\n",
        "    for file in files:\n",
        "        image_list.append(cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB))\n",
        "    imageio.mimsave(path.join('gifs', 'codes_final.gif'), image_list, fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qtVbIAb5PIkC"
      },
      "source": [
        "! ffmpeg -v quiet -r 10 -i ./figs_final/final_codes_angle_%03d.png -y videos/final_codes.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9lVoqE2oFyh5"
      },
      "source": [
        "## Zip Results"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y79gBkAARNvm"
      },
      "source": [
        "!zip -r9v $(basename \"$PWD\").zip figs_progress/ figs_final/ saved_models/ log.txt plots/ gifs/ videos/"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}