{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VzgXI8ST20DS"
      },
      "source": [
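        "# Transformation Coding [Pendulum]\n",
        "\n",
        "Trains a convolutional encoder on pairs of rendered pendulum frames. Every sample in a batch is transformed by the same random torques, and the equivariance loss asks the within-batch distance matrices of the codes to match across these transformations, while a barrier loss keeps distinct codes apart. The learned codes are then visualized over a grid of pendulum angles and angular velocities."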
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tdtcedbh2ot2"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qsn5NtvYldkJ"
      },
      "source": [
        "import os\n",
        "import io\n",
        "import cv2\n",
        "import glob\n",
        "import time\n",
        "import json\n",
        "import imageio\n",
        "import colorsys\n",
        "import multiprocessing as mp\n",
        "from os import path\n",
        "from datetime import datetime\n",
        "from functools import partial\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm import tqdm\n",
        "\n",
        "import gym\n",
        "from gym import spaces\n",
        "from gym.utils import seeding\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch import autograd\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "import matplotlib\n",
        "\n",
        "matplotlib.interactive(False)\n",
        "# matplotlib.use('agg')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a12JiTew2q9q"
      },
      "source": [
        "## Prepare Experiment Folders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x_FH7sPhIQVT"
      },
      "source": [
        "# str(datetime.now()) contains colons, which are invalid in Windows paths and awkward in\n",
        "# shell commands, so build a filesystem-safe timestamp (microseconds keep names unique).\n",
        "timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')\n",
        "EXPERIMENT_DIR = 'pendulum_%s' % timestamp\n",
        "os.makedirs(EXPERIMENT_DIR)\n",
        "%cd $EXPERIMENT_DIR\n",
        "\n",
        "# Every artifact of this run (checkpoints, figures, videos, log) is written relative to EXPERIMENT_DIR.\n",
        "for subdir in ['gifs', 'plots', 'videos', 'saved_models', 'figs_final', 'figs_progress']:\n",
        "    os.makedirs(subdir)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoIkwOnZKsP9"
      },
      "source": [
        "## Logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eCx_rhuiKr9v"
      },
      "source": [
        "# Re-running this cell must not leak the handles opened by a previous run. On the very\n",
        "# first run the names do not exist yet, which is the only error worth ignoring here.\n",
        "try:\n",
        "    LOG_STR.close()\n",
        "    LOG_FILE.close()\n",
        "except NameError:\n",
        "    pass\n",
        "\n",
        "# In-memory copy of everything printed, plus log.txt (opened for appending) in the current directory.\n",
        "LOG_STR = io.StringIO()\n",
        "LOG_FILE = open('log.txt', 'a')\n",
        "\n",
        "# Capture the builtin print only once so that re-running this cell does not wrap the wrapper.\n",
        "try:\n",
        "    old_print\n",
        "except NameError:\n",
        "    old_print = print\n",
        "\n",
        "\n",
        "def print(*args, **kwargs):\n",
        "    '''\n",
        "    Tee replacement for the builtin `print`.\n",
        "\n",
        "    Prints (always flushed) to the destination requested by the caller -- stdout by\n",
        "    default -- and then to `LOG_STR` and `LOG_FILE`.\n",
        "    '''\n",
        "    kwargs['flush'] = True\n",
        "    old_print(*args, **kwargs)\n",
        "    for stream in (LOG_STR, LOG_FILE):\n",
        "        kwargs['file'] = stream\n",
        "        old_print(*args, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EVOc_tr3K4CU"
      },
      "source": [
        "## Arguments"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BEjSsrvSKiwM"
      },
      "source": [
        "class Args():\n",
        "    def __init__(self):\n",
        "        self.num_channels = 24\n",
        "        self.mlp_hidden_dim = 128\n",
        "        self.learning_rate = 1e-3\n",
        "        self.weight_decay = 1e-7\n",
        "        self.batch_size = 64\n",
        "        self.batch_size_incr = 0\n",
        "        self.num_actions = 4\n",
        "        self.num_epochs = 50\n",
        "        self.steps_per_epoch = 100\n",
        "        self.img_w = 32\n",
        "        self.img_h = 32\n",
        "        self.barrier_type = 'log' # 'inv' or 'log' or 'hinge'\n",
        "        self.barrier_coef = 1\n",
        "        self.cosine_sim = False\n",
        "        self.conformal_map = False\n",
        "        self.code_size = 3\n",
        "        self.proj_size = 3\n",
        "        self.dt = 0.05 # gym uses 0.05\n",
        "        self.gravity = True\n",
        "\n",
        "\n",
        "args = Args()\n",
        "subcode_size = args.code_size\n",
        "print(json.dumps(vars(args), sort_keys=True, indent=4))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NoK2P37fKvXu"
      },
      "source": [
        "## Pendulum Environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xdSKlaPV03gl"
      },
      "source": [
        "class PendulumEnv(gym.Env):\n",
        "    '''\n",
        "    Torque-controlled pendulum in the style of gym's classic-control pendulum.\n",
        "\n",
        "    The time step is taken from `args.dt`, the angular velocity is not clipped, and\n",
        "    `render(mode='rgb_array')` returns the central 60x60 crop of a 100x100 frame.\n",
        "    `sample_random_transition` turns rendered frames into training data.\n",
        "    '''\n",
        "    metadata = {\n",
        "        'render.modes' : ['human', 'rgb_array'],\n",
        "        'video.frames_per_second' : 30\n",
        "    }\n",
        "\n",
        "    def __init__(self, g=10.0):\n",
        "        self.max_speed = 8  # bounds observation_space; step() does not clip the velocity to it\n",
        "        self.max_torque = 2.\n",
        "        self.dt = args.dt\n",
        "        self.g = g\n",
        "        self.viewer = None\n",
        "\n",
        "        high = np.array([1., 1., self.max_speed])\n",
        "        self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)\n",
        "        self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)\n",
        "\n",
        "        self.seed()\n",
        "        self.reset()\n",
        "\n",
        "    def seed(self, seed=None):\n",
        "        self.np_random, seed = seeding.np_random(seed)\n",
        "        return [seed]\n",
        "\n",
        "    def step(self, u):\n",
        "        '''\n",
        "        Advance the state by one time step `self.dt` under torque `u`.\n",
        "\n",
        "        `u` is clipped to [-max_torque, max_torque] and indexed with [0]. The velocity is\n",
        "        updated first and the new velocity is used for the angle update (semi-implicit Euler).\n",
        "        Returns (observation, reward, done, info); done is always False.\n",
        "        '''\n",
        "        th, thdot = self.state  # th := theta\n",
        "\n",
        "        g = self.g\n",
        "        m = 1.\n",
        "        l = 1.\n",
        "        dt = self.dt\n",
        "\n",
        "        u = np.clip(u, -self.max_torque, self.max_torque)[0]\n",
        "        costs = angle_normalize(th)**2 + .1*thdot**2 + .001*(u**2)\n",
        "\n",
        "        newthdot = thdot + (-3*g/(2*l) * np.sin(th + np.pi) + 3./(m*l**2)*u) * dt\n",
        "        newth = th + newthdot*dt\n",
        "        # Angular-velocity clipping is disabled:\n",
        "        # newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) #pylint: disable=E1111\n",
        "\n",
        "        self.state = np.array([newth, newthdot])\n",
        "        return self._get_obs(), -costs, False, {}\n",
        "\n",
        "    def reset(self):\n",
        "        high = np.array([np.pi, 1])\n",
        "        self.state = self.np_random.uniform(low=-high, high=high)\n",
        "        return self._get_obs()\n",
        "\n",
        "    def _get_obs(self):\n",
        "        theta, thetadot = self.state\n",
        "        return np.array([np.cos(theta), np.sin(theta), thetadot])\n",
        "\n",
        "    def render(self, mode='human'):\n",
        "        '''\n",
        "        Draw the pendulum on a 100x100 viewer.\n",
        "\n",
        "        With mode='rgb_array' the central 60x60 crop of the frame is returned. In any other\n",
        "        mode no RGB array is requested, so the viewer's return value is passed through as is.\n",
        "        '''\n",
        "        if self.viewer is None:\n",
        "            from gym.envs.classic_control import rendering\n",
        "            self.viewer = rendering.Viewer(100, 100)\n",
        "            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)\n",
        "            rod = rendering.make_capsule(1, .2)\n",
        "            rod.set_color(.8, .3, .3)\n",
        "            self.pole_transform = rendering.Transform()\n",
        "            rod.add_attr(self.pole_transform)\n",
        "            self.viewer.add_geom(rod)\n",
        "            axle = rendering.make_circle(.05)\n",
        "            axle.set_color(0, 0, 0)\n",
        "            self.viewer.add_geom(axle)\n",
        "\n",
        "        self.pole_transform.set_rotation(self.state[0] + np.pi/2)\n",
        "        rendered = self.viewer.render(return_rgb_array=(mode == 'rgb_array'))\n",
        "        if mode != 'rgb_array':\n",
        "            return rendered\n",
        "        return rendered[20:80, 20:80]\n",
        "\n",
        "    def close(self):\n",
        "        if self.viewer:\n",
        "            self.viewer.close()\n",
        "            self.viewer = None\n",
        "\n",
        "    def sample_random_transition(self, u_list):\n",
        "        '''\n",
        "        Render frame pairs for one random start state under each torque in `u_list`.\n",
        "\n",
        "        A start state s = (theta, theta_dot) is drawn uniformly from [-pi, pi] x [-8, 8], and\n",
        "        s' is s with theta advanced by dt * theta_dot. Every entry shares that start:\n",
        "          - index 0: the frame pair (s, s');\n",
        "          - index i + 1: the frame pair (s', step(s', u_list[i])).\n",
        "\n",
        "        Returns the len(u_list) + 1 frame pairs stacked along a new leading axis.\n",
        "        '''\n",
        "        high = np.array([np.pi, 8])\n",
        "        init_state = np.random.uniform(low=-high, high=high)\n",
        "\n",
        "        # Work on a copy: the in-place angle update below must not alter `init_state`,\n",
        "        # otherwise every further action would start from a drifted state.\n",
        "        self.state = init_state.copy()\n",
        "        frame_start = self.render(mode='rgb_array')\n",
        "        self.state[0] += self.dt * self.state[1]\n",
        "        frame_shifted = self.render(mode='rgb_array')\n",
        "        shifted_state = self.state.copy()\n",
        "\n",
        "        # The torque-free pair is identical for every action, so it is rendered only once.\n",
        "        x_list = [np.stack([frame_start, frame_shifted], axis=0)]\n",
        "        for u in u_list:\n",
        "            assert (-self.max_torque <= u).all() and (u <= self.max_torque).all()\n",
        "            self.state = shifted_state.copy()\n",
        "            self.step(u)\n",
        "            x_list.append(np.stack([frame_shifted, self.render(mode='rgb_array')], axis=0))\n",
        "        return np.stack(x_list, axis=0)\n",
        "\n",
        "def angle_normalize(x):\n",
        "    return (((x+np.pi) % (2*np.pi)) - np.pi)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQgD4MDKBzrF"
      },
      "source": [
        "## Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RqDYzs_VAj3l"
      },
      "source": [
        "def normalize(x):\n",
        "    return 2.0 * (x.astype(np.float32) / 255.0 - 0.5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rgFNZ9W19UzZ"
      },
      "source": [
        "def _process_single_obs(frame):\n",
        "    '''Resize one RGB frame to (args.img_w, args.img_h) with bicubic resampling and convert it to 8-bit grayscale.'''\n",
        "    resized = Image.fromarray(frame).resize((args.img_w, args.img_h), resample=Image.BICUBIC)\n",
        "    return np.array(resized.convert('L'))\n",
        "\n",
        "\n",
        "# Vectorised over any leading axes: (..., n, m, k) RGB frames -> (..., p, q) grayscale images.\n",
        "process_obs = np.vectorize(_process_single_obs, signature='(n,m,k)->(p,q)')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y29nrdIXKmJz"
      },
      "source": [
        "# `args.gravity` switches between g = 10 and a gravity-free pendulum.\n",
        "if args.gravity:\n",
        "    env = PendulumEnv(g=10)\n",
        "else:\n",
        "    env = PendulumEnv(g=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UllDY1qrgPUa"
      },
      "source": [
        "def unstack(a, axis):\n",
        "    return [np.squeeze(x, axis) for x in np.split(a, a.shape[axis], axis=axis)]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FITe2RvLyGY2"
      },
      "source": [
        "def sample(batch_size, num_actions):\n",
        "    '''\n",
        "    Draw a batch of rendered transitions that all share the same random torques.\n",
        "\n",
        "    `num_actions - 1` torques are drawn once and applied to each of the `batch_size`\n",
        "    random start states. Returns a list of `num_actions` arrays (one per action), each\n",
        "    stacking the frame pairs of the whole batch along axis 0.\n",
        "    '''\n",
        "    low, high = env.action_space.low, env.action_space.high\n",
        "    u_list = [np.random.uniform(low=low, high=high, size=(2, 1)) for _ in range(num_actions - 1)]\n",
        "    batch = np.stack([env.sample_random_transition(u_list) for _ in range(batch_size)], axis=0)\n",
        "    return unstack(batch, axis=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5dvE1VCgnTvC"
      },
      "source": [
        "## Test Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EUiZMIaw1DB2"
      },
      "source": [
        "# Sanity check on raw (unprocessed) frames: x1 is the torque-free pair, x2 the pair after one random torque.\n",
        "x1, x2 = sample(args.batch_size, num_actions=2)\n",
        "for batch_arr in (x1, x2):\n",
        "    print(batch_arr.shape, batch_arr.dtype)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e2tXK14umWhN"
      },
      "source": [
        "plt.style.use('default')\n",
        "\n",
        "fig = plt.figure(figsize=(12, 3))\n",
        "idx = 0\n",
        "\n",
        "# Left to right: both frames of the torque-free pair, then both frames of the torqued pair.\n",
        "panels = [x1[idx, 0], x1[idx, 1], x2[idx, 0], x2[idx, 1]]\n",
        "for panel_no, frame in enumerate(panels, start=1):\n",
        "    plt.subplot(1, 4, panel_no)\n",
        "    plt.imshow(frame, cmap='gray', vmin=0, vmax=255)\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sJidR4-HkNvT"
      },
      "source": [
        "## Visualization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TjXBdXCnkN1B"
      },
      "source": [
        "N_THETA = 360\n",
        "N_THETA_DOT = 160\n",
        "\n",
        "# One (frame, frame advanced by dt * theta_dot) pair per grid point (theta, theta_dot),\n",
        "# built the same way as the torque-free training pair.\n",
        "# float32 matches the output of `normalize` and halves the memory of the default float64\n",
        "# (360 * 160 * 2 * 32 * 32 values: about 0.47 GB instead of 0.94 GB).\n",
        "# Image axes are (height, width), the layout of the arrays produced by `process_obs`.\n",
        "viz_dataset = np.empty((N_THETA, N_THETA_DOT, 2, args.img_h, args.img_w), dtype=np.float32)\n",
        "\n",
        "for i, theta in enumerate(tqdm(np.linspace(-np.pi, np.pi, N_THETA))):\n",
        "    for j, theta_dot in enumerate(np.linspace(-env.max_speed, env.max_speed, N_THETA_DOT, endpoint=True)):\n",
        "        env.state = [theta, theta_dot]\n",
        "        obs0 = env.render(mode='rgb_array')\n",
        "        env.state[0] += env.dt * env.state[1]\n",
        "        obs1 = env.render(mode='rgb_array')\n",
        "        viz_dataset[i, j, 0] = normalize(process_obs(obs0))\n",
        "        viz_dataset[i, j, 1] = normalize(process_obs(obs1))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1WVSSHR8BtEr"
      },
      "source": [
        "# Linear map from the code space to the `args.proj_size`-dimensional space that is plotted.\n",
        "if subcode_size <= args.proj_size:\n",
        "    PROJ_MAT = np.eye(args.proj_size)\n",
        "else:\n",
        "    # Random orthonormal columns: left singular vectors of a Gaussian matrix.\n",
        "    t = np.random.randn(subcode_size, subcode_size)\n",
        "    left_singular = np.linalg.svd(t)[0]\n",
        "    PROJ_MAT = left_singular[:, :args.proj_size]\n",
        "    # NOTE(review): this rescales each column to sum to 1, which distorts the orthonormal basis\n",
        "    # and blows up when a column sums to ~0 (the SVD columns already have unit norm) -- confirm intent.\n",
        "    PROJ_MAT /= np.sum(PROJ_MAT, axis=0, keepdims=True)\n",
        "\n",
        "with np.printoptions(linewidth=300, formatter={'float': lambda x: '%12f' % x}):\n",
        "    print('PROJ_MAT =\\n', PROJ_MAT)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "24EOkW6PtFi1"
      },
      "source": [
        "def scatter_3d(ax, x, y, z, c, elev=30, azim=-60, **kwargs):\n",
        "    elev_rad = elev / 180 * np.pi\n",
        "    azim_rad = azim / 180 * np.pi\n",
        "    ax.view_init(elev, azim)\n",
        "    idx_sort = np.argsort(np.cos(azim_rad) * x + np.sin(azim_rad) * y + np.sin(elev_rad) * z, kind='heapsort')\n",
        "    ax.scatter(x[idx_sort], y[idx_sort], z[idx_sort], c=c[idx_sort] if type(c) is np.ndarray else c, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DYxIDwtKkiQx"
      },
      "source": [
        "def visualize(ax, enc, data, s=1, **kwargs):\n",
        "    '''\n",
        "    Encode a (theta, theta_dot) grid of inputs and scatter-plot the projected codes.\n",
        "\n",
        "    Args:\n",
        "        ax: matplotlib axis; must be a 3-D axis when `args.proj_size == 3`.\n",
        "        enc: encoder network, evaluated under `torch.no_grad()`.\n",
        "        data: array of shape (n_theta, n_theta_dot, ...) whose trailing dims form one encoder input.\n",
        "        s: marker size.\n",
        "        **kwargs: an optional `c` replaces the default coloring; everything else is forwarded to\n",
        "            `scatter_3d` (3-D) or `ax.scatter` (2-D).\n",
        "\n",
        "    By default the hue follows the theta index and the lightness the theta_dot index.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `args.proj_size` is neither 2 nor 3 (checked before any encoding work).\n",
        "    '''\n",
        "    if args.proj_size not in (2, 3):\n",
        "        raise ValueError('args.proj_size must be 2 or 3, got %r' % (args.proj_size,))\n",
        "\n",
        "    device = get_device(enc)\n",
        "    n_theta, n_theta_dot = data.shape[:2]\n",
        "\n",
        "    code = np.empty((n_theta, n_theta_dot, args.proj_size))\n",
        "    with torch.no_grad():\n",
        "        # one theta row per forward pass keeps each batch at n_theta_dot inputs\n",
        "        for i in range(n_theta):\n",
        "            code[i] = enc(torch.tensor(data[i], dtype=torch.float32, device=device)).cpu().numpy() @ PROJ_MAT\n",
        "    code = code.reshape(-1, args.proj_size)\n",
        "\n",
        "    if 'c' in kwargs:\n",
        "        colors = kwargs.pop('c')\n",
        "    else:\n",
        "        # row-major (theta, theta_dot) order, matching the reshape of `code` above\n",
        "        colors = np.array([\n",
        "            colorsys.hls_to_rgb(i / n_theta, j / (n_theta_dot - 1) * 0.95, 1)\n",
        "            for i in range(n_theta)\n",
        "            for j in range(n_theta_dot)\n",
        "        ])\n",
        "\n",
        "    if args.proj_size == 3:\n",
        "        # make the panes and the grid lines transparent\n",
        "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
        "            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
        "            axis._axinfo['grid']['color'] = (1, 1, 1, 0)\n",
        "        scatter_3d(ax, code[:, 0], code[:, 1], code[:, 2], c=colors, s=s, **kwargs)\n",
        "    else:\n",
        "        ax.scatter(code[:, 0], code[:, 1], c=colors, s=s, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JH9QHBQJO7Qz"
      },
      "source": [
        "## Encoder Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2PRN3i6bCPLE"
      },
      "source": [
        "class View(nn.Module):\n",
        "    def __init__(self, *shape):\n",
        "        super(View, self).__init__()\n",
        "        self.shape = shape\n",
        "\n",
        "    def forward(self, x):\n",
        "        return x.view(self.shape)\n",
        "\n",
        "    def __repr__(self):\n",
        "        return 'View(%s)' % (', '.join(['%d' % x for x in self.shape]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vjQVUzYLfS79"
      },
      "source": [
        "IN_CHANNELS = 2  # the two stacked grayscale frames of a transition pair\n",
        "# Three 2x2 max-pools shrink each spatial side by a factor of 8 before flattening.\n",
        "hidden_dim = args.num_channels * (args.img_w // 8) * (args.img_h // 8)\n",
        "\n",
        "# Three conv -> ReLU -> max-pool stages; only the first one sees IN_CHANNELS input channels.\n",
        "conv_stack = [\n",
        "    layer\n",
        "    for in_ch in [IN_CHANNELS] + [args.num_channels] * 2\n",
        "    for layer in (\n",
        "        nn.Conv2d(in_ch, args.num_channels, kernel_size=3, padding=1),\n",
        "        nn.ReLU(),\n",
        "        nn.MaxPool2d(kernel_size=2),\n",
        "    )\n",
        "]\n",
        "\n",
        "enc = nn.Sequential(\n",
        "    *conv_stack,\n",
        "    View(-1, hidden_dim),\n",
        "    nn.Linear(hidden_dim, args.mlp_hidden_dim),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dim, args.code_size, bias=False),\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R9mtExX8fTSj"
      },
      "source": [
        "def count_parameters(net):\n",
        "    return sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
        "\n",
        "\n",
        "def get_device(net):\n",
        "    '''\n",
        "    Returns the `torch.device` on which the network resides.\n",
        "    This method only makes sense when all module parameters reside on the **same** device.\n",
        "    '''\n",
        "    return list(net.parameters())[0].device"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nsl2lsi_fTWi"
      },
      "source": [
        "DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "print('Device:', DEVICE)\n",
        "enc = enc.to(DEVICE)  # Module.to moves the parameters and returns the same module\n",
        "print(enc)\n",
        "print('%d parameters' % count_parameters(enc))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QiZ9ZoXDvYFa"
      },
      "source": [
        "## Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yh2gKlJhfTaU"
      },
      "source": [
        "# Adam over all encoder parameters, with weight decay.\n",
        "opt = optim.Adam(\n",
        "    enc.parameters(),\n",
        "    lr=args.learning_rate,\n",
        "    weight_decay=args.weight_decay,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwvYhWuI6UlG"
      },
      "source": [
        "## Loss Function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9eqUCVYNZUh4"
      },
      "source": [
        "def cdist_mean(x1, x2, p=2.0, *args, **kwargs):\n",
        "    dim = x1.shape[1]\n",
        "    return torch.cdist(x1, x2, p, *args, **kwargs) * (dim ** (-1 / p))\n",
        "\n",
        "\n",
        "def loss_fn(enc, x_list, barrier_type, cosine_sim=False, conformal_map=False, decompositions=1):\n",
        "    '''\n",
        "    Equivariance and barrier losses for a batch of transformed inputs.\n",
        "\n",
        "    Args:\n",
        "        enc: encoder network.\n",
        "        x_list: list of `num_actions` (>= 2) array-likes, one batch per action; entry k of\n",
        "            every batch must come from the same underlying sample (as produced by `sample`).\n",
        "        barrier_type: 'log', 'inv' or 'hinge'.\n",
        "        cosine_sim: compare codes with cosine distance instead of `cdist_mean`.\n",
        "        conformal_map: compare pairwise code differences (with cosine distance) instead of codes.\n",
        "        decompositions: number of equal-size chunks the code is split into, each with its own\n",
        "            equivariance term (trailing dims that do not divide evenly are ignored).\n",
        "\n",
        "    Returns:\n",
        "        (loss_equiv, loss_barrier), scalar tensors on the encoder's device. loss_equiv is the\n",
        "        mean over all action pairs of the mean squared difference between their within-batch\n",
        "        distance matrices; loss_barrier penalizes small distances between distinct codes of\n",
        "        the concatenated batch.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: on an unknown `barrier_type` (checked before the forward pass).\n",
        "    '''\n",
        "    if barrier_type not in ('log', 'inv', 'hinge'):\n",
        "        raise ValueError('Unknown `barrier_type`: %r' % (barrier_type,))\n",
        "\n",
        "    device = get_device(enc)\n",
        "    num_actions = len(x_list)\n",
        "    assert num_actions >= 2\n",
        "    x_list = [torch.as_tensor(x, dtype=torch.float32, device=device) for x in x_list]\n",
        "    z_list = [enc(x) for x in x_list]\n",
        "    code_size = z_list[0].shape[1]\n",
        "\n",
        "    #-- symmetry loss --#\n",
        "    subcode_size = code_size // decompositions\n",
        "    loss_equiv = 0\n",
        "    for k in range(decompositions):\n",
        "        h_list = [z[:, k * subcode_size: (k + 1) * subcode_size] for z in z_list]\n",
        "\n",
        "        if conformal_map:\n",
        "            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]\n",
        "\n",
        "        if cosine_sim or conformal_map:\n",
        "            D_list = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]\n",
        "        else:\n",
        "            D_list = [cdist_mean(h, h, p=2) for h in h_list]\n",
        "\n",
        "        # Kept on the encoder's device (no per-pair copy into a CPU buffer); the mean over\n",
        "        # the n * (n - 1) / 2 action pairs equals the former upper-triangle sum / count.\n",
        "        pair_losses = [\n",
        "            torch.mean((D_list[i] - D_list[j]) ** 2)\n",
        "            for i in range(num_actions)\n",
        "            for j in range(i + 1, num_actions)\n",
        "        ]\n",
        "        loss_equiv = loss_equiv + torch.stack(pair_losses).mean() / decompositions\n",
        "\n",
        "    #-- barrier loss --#\n",
        "    z_all = torch.cat(z_list, dim=0)\n",
        "    D_all = cdist_mean(z_all, z_all, p=2)\n",
        "    off_diag = ~torch.eye(D_all.shape[0], dtype=torch.bool, device=D_all.device)\n",
        "    dists = D_all[off_diag]\n",
        "    if barrier_type == 'log':\n",
        "        loss_barrier = torch.mean(-torch.log(dists + 1e-9))\n",
        "    elif barrier_type == 'inv':\n",
        "        loss_barrier = torch.mean(1.0 / (dists + 1e-9))\n",
        "    else:  # 'hinge': penalize distances below a margin of 5.0\n",
        "        loss_barrier = torch.mean(torch.clamp(5.0 - dists, min=0.0))\n",
        "\n",
        "    return loss_equiv, loss_barrier"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PKZOSNYX_f-x"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-prwPRYR_kFt"
      },
      "source": [
        "def check_weights_nan(parameters):\n",
        "    for p in parameters:\n",
        "        if torch.isnan(p).any():\n",
        "            return True\n",
        "    return False\n",
        "\n",
        "\n",
        "def get_weights_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p, norm_type) for p in parameters]), norm_type).item()\n",
        "\n",
        "\n",
        "def get_grads_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p.grad, norm_type) for p in parameters]), norm_type).item()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VJcYn9Vn_kK7"
      },
      "source": [
        "def save_checkpoint(epoch):\n",
        "    '''Save the encoder weights to saved_models/model_epoch_<epoch>.tar.'''\n",
        "    torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % epoch))\n",
        "\n",
        "\n",
        "def save_progress_figure(epoch):\n",
        "    '''Plot the current codes of `viz_dataset` and save them to figs_progress/codes_epoch_<epoch>.png.'''\n",
        "    enc.eval()\n",
        "    fig = plt.figure(figsize=(6, 6))\n",
        "    # same axis choice as the final-codes cell, so proj_size == 2 gets a 2-D axis\n",
        "    ax = fig.add_subplot(111, projection='3d' if args.proj_size == 3 else None)\n",
        "    visualize(ax, enc, viz_dataset)\n",
        "    fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % epoch), bbox_inches='tight', dpi=200)\n",
        "    plt.show()\n",
        "    plt.close(fig)\n",
        "\n",
        "\n",
        "# Epoch 0 is the untrained encoder.\n",
        "save_checkpoint(0)\n",
        "save_progress_figure(0)\n",
        "\n",
        "batch_size = args.batch_size\n",
        "\n",
        "avg_loss_list = []\n",
        "avg_loss_equiv_list = []\n",
        "avg_loss_barrier_list = []\n",
        "for t in range(args.num_epochs):\n",
        "    enc.train()\n",
        "    loss_list = []\n",
        "    loss_equiv_list = []\n",
        "    loss_barrier_list = []\n",
        "\n",
        "    time_start = time.time()\n",
        "    progress = tqdm(range(args.steps_per_epoch), desc='Loss: None | Loss Equiv: None | Loss Barrier: None | L2 Weights: %12g | L2 Grads: 0' % (\n",
        "        get_weights_norm(enc.parameters(), norm_type=2.0)\n",
        "    ), total=args.steps_per_epoch, position=0, leave=True)\n",
        "    for _ in progress:\n",
        "        # `sample` draws one shared set of random torques for the whole batch.\n",
        "        x_list = sample(batch_size, args.num_actions)\n",
        "        x_list = [normalize(process_obs(x)) for x in x_list]\n",
        "        loss_equiv, loss_barrier = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map)\n",
        "        loss = loss_equiv + args.barrier_coef * loss_barrier\n",
        "        opt.zero_grad()\n",
        "        loss.backward()\n",
        "        opt.step()\n",
        "        # Losses are multiplied by 1e4 for display and logging only, to get nicer numbers.\n",
        "        loss_list.append(1e4 * loss.item())\n",
        "        loss_equiv_list.append(1e4 * loss_equiv.item())\n",
        "        loss_barrier_list.append(1e4 * args.barrier_coef * loss_barrier.item())\n",
        "        progress.set_description('Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | L2 Weights: %12g | L2 Grads: %12g' % (\n",
        "            loss_list[-1],\n",
        "            loss_equiv_list[-1],\n",
        "            loss_barrier_list[-1],\n",
        "            get_weights_norm(enc.parameters(), norm_type=2.0),\n",
        "            get_grads_norm(enc.parameters(), norm_type=2.0)\n",
        "        ))\n",
        "    time_end = time.time()\n",
        "\n",
        "    save_checkpoint(t + 1)\n",
        "\n",
        "    avg_loss = np.mean(loss_list)\n",
        "    avg_loss_equiv = np.mean(loss_equiv_list)\n",
        "    avg_loss_barrier = np.mean(loss_barrier_list)\n",
        "    avg_loss_list.append(avg_loss)\n",
        "    avg_loss_equiv_list.append(avg_loss_equiv)\n",
        "    avg_loss_barrier_list.append(avg_loss_barrier)\n",
        "    print('\\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | Time: %6.1f sec' % (\n",
        "        t + 1, avg_loss, avg_loss_equiv, avg_loss_barrier, time_end - time_start))\n",
        "\n",
        "    if args.batch_size_incr != 0:\n",
        "        new_batch_size = batch_size + args.batch_size_incr\n",
        "        print('Batch size increased from %d to %d' % (batch_size, new_batch_size))\n",
        "        batch_size = new_batch_size\n",
        "\n",
        "    save_progress_figure(t + 1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RDumfR4N_kPg"
      },
      "source": [
        "# Final weights, in addition to the per-epoch checkpoints written during training.\n",
        "save_path = path.join('saved_models', 'model_final.tar')\n",
        "print('saving model to %s' % save_path)\n",
        "torch.save(obj=enc.state_dict(), f=save_path)\n",
        "# To restore these weights later:\n",
        "# enc.load_state_dict(torch.load(save_path, map_location=DEVICE))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W6jvbi3p_kVo"
      },
      "source": [
        "## Error Curve"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d_WcS1xg_kaf"
      },
      "source": [
        "# avg_loss_list holds, per epoch, the mean total training loss multiplied by 1e4.\n",
        "# Epochs are numbered from 1, matching the per-epoch log lines and checkpoint names.\n",
        "fig, ax = plt.subplots(figsize=(5, 5))\n",
        "ax.plot(np.arange(1, len(avg_loss_list) + 1), avg_loss_list)\n",
        "ax.set_xlabel('Epoch')\n",
        "ax.set_ylabel('Loss (x 1e4)')\n",
        "ax.set_title('Average training loss per epoch')\n",
        "fig.savefig(path.join('plots', 'error_curve.png'), bbox_inches='tight', dpi=300)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oPJFA0mLNr48"
      },
      "source": [
        "## Visualize Progress"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3QY7b5HnNZwO"
      },
      "source": [
        "# Animated GIF of the per-epoch figures; sorting works because epoch numbers are zero-padded.\n",
        "files = sorted(glob.glob('figs_progress/*.png'))\n",
        "# OpenCV loads images as BGR, imageio expects RGB.\n",
        "image_list = [cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB) for fname in files]\n",
        "imageio.mimsave(path.join('gifs', 'codes_progress.gif'), image_list, fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C1pm-Nt5O5A-"
      },
      "source": [
        "# Encode the per-epoch progress figures (numbered from 000) as a 10 fps WebM video.\n",
        "! ffmpeg -v quiet -r 10 -i ./figs_progress/codes_epoch_%03d.png -y videos/codes_progress.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2pxdU9EZ_kgN"
      },
      "source": [
        "## Visualize Final Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g-myzU_3N8B2"
      },
      "source": [
        "enc.eval()\n",
        "fig = plt.figure(figsize=(8, 8))\n",
        "is_3d = args.proj_size == 3\n",
        "ax = fig.add_subplot(111, projection='3d' if is_3d else None)\n",
        "ax.axis('off')\n",
        "# `azim` is a camera angle consumed by the 3-D path of `visualize` (scatter_3d); it is not a\n",
        "# valid keyword for a 2-D `ax.scatter`, so it is only passed when plotting in 3-D.\n",
        "visualize(ax, enc, viz_dataset, s=5, **({'azim': 60} if is_3d else {}))\n",
        "fig.savefig(path.join('plots', 'codes_final.png'), bbox_inches='tight', dpi=600)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sgHFyL32lXMK"
      },
      "source": [
        "! rm -rf figs_final\n",
        "! mkdir figs_final\n",
        "\n",
        "enc.eval()\n",
        "\n",
        "# 72 frames that turn the camera azimuth through a full circle in 5-degree steps (3-D only).\n",
        "if args.proj_size == 3:\n",
        "    for frame_idx in tqdm(range(72)):\n",
        "        fig = plt.figure(figsize=(6, 6))\n",
        "        ax = fig.add_subplot(111, projection='3d')\n",
        "        ax.axis('off')\n",
        "        visualize(ax, enc, viz_dataset, azim=frame_idx * 5)\n",
        "        fname = 'final_codes_angle_%03d.png' % frame_idx\n",
        "        fig.savefig(path.join('figs_final', fname), bbox_inches='tight', dpi=200)\n",
        "        plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iemaX8DzmbmF"
      },
      "source": [
        "# Animated GIF of the rotating final codes (frames only exist for 3-D projections).\n",
        "if args.proj_size == 3:\n",
        "    files = sorted(glob.glob('figs_final/*.png'))\n",
        "    # OpenCV loads images as BGR, imageio expects RGB.\n",
        "    image_list = [cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB) for fname in files]\n",
        "    imageio.mimsave(path.join('gifs', 'codes_final.gif'), image_list, fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qtVbIAb5PIkC"
      },
      "source": [
        "# Encode the rotating final-code frames as a 10 fps WebM video (frames exist only when proj_size == 3).\n",
        "! ffmpeg -v quiet -r 10 -i ./figs_final/final_codes_angle_%03d.png -y videos/final_codes.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9lVoqE2oFyh5"
      },
      "source": [
        "## Zip Results"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y79gBkAARNvm"
      },
      "source": [
        "# Archive all artifacts of this run as <current directory name>.zip in the current directory.\n",
        "!zip -r9v $(basename \"$PWD\").zip figs_progress/ figs_final/ saved_models/ log.txt plots/ gifs/ videos/"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}