{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VzgXI8ST20DS"
      },
      "source": [
        "# Transformation Coding [Room]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NoK2P37fKvXu"
      },
      "source": [
        "## Maze Environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yIH0qJyZaHd3"
      },
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel\n",
        "%pip install -U git+https://github.com/mshakerinava/gym-miniworld.git"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tdtcedbh2ot2"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qsn5NtvYldkJ"
      },
      "source": [
        "import os\n",
        "import io\n",
        "import cv2\n",
        "import glob\n",
        "import time\n",
        "import json\n",
        "import imageio\n",
        "import colorsys\n",
        "import multiprocessing as mp\n",
        "from os import path\n",
        "from datetime import datetime\n",
        "from functools import partial\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm import tqdm\n",
        "\n",
        "import gym\n",
        "from gym import spaces\n",
        "from gym.utils import seeding\n",
        "import gym_miniworld\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "from torch import autograd\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "import matplotlib\n",
        "\n",
        "matplotlib.interactive(False)\n",
        "# matplotlib.use('agg')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a12JiTew2q9q"
      },
      "source": [
        "## Prepare Experiment Folders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x_FH7sPhIQVT"
      },
      "source": [
        "# Make a new experiment directory named after the current time and move into it.\n",
        "timestamp = datetime.now().isoformat(sep='_')\n",
        "EXPERIMENT_DIR = 'room_' + timestamp\n",
        "os.makedirs(EXPERIMENT_DIR)\n",
        "%cd $EXPERIMENT_DIR\n",
        "# sub-folders that the later cells write their outputs to\n",
        "!mkdir gifs plots videos saved_models figs_final figs_progress"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoIkwOnZKsP9"
      },
      "source": [
        "## Logging"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eCx_rhuiKr9v"
      },
      "source": [
        "try:\n",
        "    LOG_STR.close()\n",
        "    LOG_FILE.close()\n",
        "except:\n",
        "    pass\n",
        "\n",
        "timestamp = str(datetime.now())\n",
        "LOG_STR = io.StringIO()\n",
        "LOG_FILE = open('log.txt', 'w')\n",
        "\n",
        "try:\n",
        "    old_print\n",
        "except NameError:\n",
        "    old_print = print\n",
        "\n",
        "def print(*args, **kwargs):\n",
        "    kwargs['flush'] = True\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_STR\n",
        "    old_print(*args, **kwargs)\n",
        "    kwargs['file'] = LOG_FILE\n",
        "    old_print(*args, **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EVOc_tr3K4CU"
      },
      "source": [
        "## Arguments"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BEjSsrvSKiwM"
      },
      "source": [
        "class Args():\n",
        "    def __init__(self):\n",
        "        self.num_channels = 24\n",
        "        self.mlp_hidden_dim = 128\n",
        "        self.learning_rate = 1e-3\n",
        "        self.weight_decay = 1e-7\n",
        "        self.batch_size = 64\n",
        "        self.batch_size_incr = 0\n",
        "        self.num_epochs = 50\n",
        "        self.steps_per_epoch = 100\n",
        "        self.img_w = 32\n",
        "        self.img_h = 32\n",
        "        self.barrier_type = 'log' # 'inv' or 'log' or 'hinge'\n",
        "        self.barrier_coef = 1\n",
        "        self.cosine_sim = False\n",
        "        self.conformal_map = False\n",
        "        self.code_size = 4\n",
        "        self.proj_size = 2\n",
        "        self.decompositions = 2\n",
        "        self.decomposed_action = True\n",
        "\n",
        "\n",
        "args = Args()\n",
        "assert args.code_size % args.decompositions == 0\n",
        "assert not args.decomposed_action or args.decompositions == 2\n",
        "subcode_size = args.code_size // args.decompositions\n",
        "print(json.dumps(vars(args), sort_keys=True, indent=4))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yQgD4MDKBzrF"
      },
      "source": [
        "## Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7FysvptlatD0"
      },
      "source": [
        "env = gym.make('MiniWorld-OneRoomS6-v0')\n",
        "env.seed(6)\n",
        "env.reset()\n",
        "\n",
        "# Palette: `num_colors` fully saturated hues, evenly spaced around the color wheel.\n",
        "num_colors = 8\n",
        "colors = [colorsys.hls_to_rgb(i / num_colors, 0.5, 1) for i in range(num_colors)]\n",
        "\n",
        "# Boxes are placed at all combinations of {0.4, env.size / 2, env.size - 0.4}\n",
        "# for the first and third coordinate, except the center of the room.\n",
        "near, mid, far = 0.4, env.size / 2, env.size - 0.4\n",
        "box_positions = [\n",
        "    (near, near), (mid, near), (far, near), (near, mid),\n",
        "    (near, far), (mid, far), (far, mid), (far, far),\n",
        "]\n",
        "for pos_a, pos_b in box_positions:\n",
        "    env.place_entity(gym_miniworld.entity.Box(color='red'), room=0, pos=(pos_a, 0, pos_b), dir=0)\n",
        "\n",
        "# Recolor every box of the environment, cycling through the palette.\n",
        "cnt = 0\n",
        "for entity in env.entities:\n",
        "    if type(entity) is gym_miniworld.entity.Box:\n",
        "        entity.color_vec = colors[cnt % num_colors]\n",
        "        cnt += 1\n",
        "print(cnt)\n",
        "\n",
        "# Save and show a top view of the room.\n",
        "img_arr = env.render(mode='rgb_array', view='top')\n",
        "Image.fromarray(img_arr).save('maze_map.png')\n",
        "plt.imshow(img_arr)\n",
        "plt.axis('off')\n",
        "plt.show()\n",
        "plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "chewiaktBlfO"
      },
      "source": [
        "def process_obs(x):\n",
        "    '''\n",
        "    Resizes the image array `x` to `args.img_w` x `args.img_h` pixels with bicubic\n",
        "    resampling and moves its last axis to the front.\n",
        "    '''\n",
        "    target_size = (args.img_w, args.img_h)\n",
        "    resized = Image.fromarray(x).resize(target_size, resample=Image.BICUBIC)\n",
        "    return np.transpose(np.array(resized), (2, 0, 1))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bt143gwawi_-"
      },
      "source": [
        "def sample_random_transition(u_list):\n",
        "    '''\n",
        "    Places the agent anew, then records the observation before acting and the\n",
        "    observation after each action of `u_list`. Every action is applied from the\n",
        "    same initial position and direction.\n",
        "\n",
        "    Returns the processed observations stacked along axis 0: index 0 is the\n",
        "    initial observation and index `i + 1` belongs to `u_list[i]`.\n",
        "    '''\n",
        "    # remove the agent from the entity list (if it is the last entry) before placing it again\n",
        "    if env.entities[-1] is env.agent:\n",
        "        env.entities.pop()\n",
        "    env.place_agent()\n",
        "    start_pos = env.agent.pos.copy()\n",
        "    start_dir = env.agent.dir\n",
        "\n",
        "    observations = [env.render_obs()]\n",
        "    for u in u_list:\n",
        "        # rewind the agent to its starting pose before taking the action\n",
        "        env.agent.pos = start_pos\n",
        "        env.agent.dir = start_dir\n",
        "        obs_after, _, _, _ = env.step(u)\n",
        "        observations.append(obs_after)\n",
        "    return np.stack([process_obs(obs) for obs in observations], axis=0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UllDY1qrgPUa"
      },
      "source": [
        "def unstack(a, axis):\n",
        "    return [np.squeeze(x, axis) for x in np.split(a, a.shape[axis], axis=axis)]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FITe2RvLyGY2"
      },
      "source": [
        "def sample(batch_size, action_type=None):\n",
        "    '''\n",
        "    Samples `batch_size` random transitions.\n",
        "\n",
        "    `action_type` selects the actions: None uses actions 0-3, 0 uses actions 0-1\n",
        "    (left or right, change dir) and 1 uses actions 2-3 (forward/backward, change pos).\n",
        "\n",
        "    Returns a list of batches: first the initial observations, then one batch of\n",
        "    resulting observations per action.\n",
        "    '''\n",
        "    if action_type is None:\n",
        "        actions = [0, 1, 2, 3]\n",
        "    elif action_type == 0:\n",
        "        actions = [0, 1]\n",
        "    elif action_type == 1:\n",
        "        actions = [2, 3]\n",
        "    else:\n",
        "        assert False\n",
        "\n",
        "    transitions = [sample_random_transition(actions) for _ in range(batch_size)]\n",
        "    return unstack(np.stack(transitions, axis=0), axis=1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5dvE1VCgnTvC"
      },
      "source": [
        "## Test Transformation Sampling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EUiZMIaw1DB2"
      },
      "source": [
        "sampled_batches = sample(args.batch_size)\n",
        "print(sampled_batches[0].shape, sampled_batches[0].dtype)\n",
        "\n",
        "# move the channel axis to the end so that the images can be shown with `imshow`\n",
        "x1, x2, x3, x4, x5 = [batch.transpose(0, 2, 3, 1) for batch in sampled_batches]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e2tXK14umWhN"
      },
      "source": [
        "plt.style.use('default')\n",
        "\n",
        "# One column per sample. Row 0 shows the initial observations, rows 1-4 show\n",
        "# the observations after each of the four actions.\n",
        "image_rows = [x1, x2, x3, x4, x5]\n",
        "num_rows = len(image_rows)\n",
        "\n",
        "fig = plt.figure(figsize=(3 * args.batch_size, 3 * num_rows))\n",
        "\n",
        "for idx in range(args.batch_size):\n",
        "    for row, images in enumerate(image_rows):\n",
        "        plt.subplot(num_rows, args.batch_size, row * args.batch_size + idx + 1)\n",
        "        plt.axis('off')\n",
        "        plt.imshow(images[idx])\n",
        "\n",
        "fig.savefig(path.join('plots', 'batch_of_transformations.png'), bbox_inches='tight', dpi=300)\n",
        "plt.show()\n",
        "# release the (large) figure, like the other plotting cells do\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sJidR4-HkNvT"
      },
      "source": [
        "## Visualization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RqDYzs_VAj3l"
      },
      "source": [
        "def normalize(x):\n",
        "    return 2.0 * (x.astype(np.float32) / 255.0 - 0.5)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TjXBdXCnkN1B"
      },
      "source": [
        "N_POS = 500\n",
        "N_THETA = 60\n",
        "\n",
        "# `process_obs` yields arrays laid out as (channels, height, width), so the height comes first.\n",
        "# float32 matches the output of `normalize` and takes half the memory of the default float64.\n",
        "viz_dataset = np.empty((N_POS, N_THETA, 3, args.img_h, args.img_w), dtype=np.float32)\n",
        "\n",
        "# For N_POS random agent placements, render the observation for N_THETA directions in [-pi, pi].\n",
        "for i in tqdm(range(N_POS)):\n",
        "    if env.entities[-1] is env.agent:\n",
        "        env.entities.pop()\n",
        "    env.place_agent()\n",
        "    for j, d in enumerate(np.linspace(-np.pi, np.pi, N_THETA)):\n",
        "        env.agent.dir = d\n",
        "        viz_dataset[i, j] = normalize(process_obs(env.render_obs()))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1WVSSHR8BtEr"
      },
      "source": [
        "# PROJ_MAT maps a sub-code (`subcode_size` values) to the `args.proj_size` plot coordinates.\n",
        "if subcode_size > args.proj_size:\n",
        "    # random projection: the first `proj_size` left-singular vectors of a Gaussian matrix\n",
        "    t = np.random.randn(subcode_size, subcode_size)\n",
        "    PROJ_MAT = np.linalg.svd(t)[0][:, :args.proj_size]\n",
        "    # NOTE(review): each column is divided by its *sum* (not its norm), which can be\n",
        "    # close to zero or negative - confirm that this scaling is intended.\n",
        "    PROJ_MAT /= np.sum(PROJ_MAT, axis=0, keepdims=True)\n",
        "else:\n",
        "    # Rectangular identity: keeps the sub-code coordinates and leaves any extra plot\n",
        "    # coordinate at zero, so the matrix product is also defined for subcode_size < proj_size.\n",
        "    PROJ_MAT = np.eye(subcode_size, args.proj_size)\n",
        "\n",
        "with np.printoptions(linewidth=300, formatter={'float': lambda x: '%12f' % x}):\n",
        "    print('PROJ_MAT =\\n', PROJ_MAT)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "24EOkW6PtFi1"
      },
      "source": [
        "def scatter_3d(ax, x, y, z, c, elev=30, azim=-60, **kwargs):\n",
        "    elev_rad = elev / 180 * np.pi\n",
        "    azim_rad = azim / 180 * np.pi\n",
        "    ax.view_init(elev, azim)\n",
        "    idx_sort = np.argsort(np.cos(azim_rad) * x + np.sin(azim_rad) * y + np.sin(elev_rad) * z, kind='heapsort')\n",
        "    ax.scatter(x[idx_sort], y[idx_sort], z[idx_sort], c=c[idx_sort], **kwargs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DYxIDwtKkiQx"
      },
      "source": [
        "def visualize(ax, enc, data, s=1, **kwargs):\n",
        "    '''\n",
        "    Scatter-plots the projected codes of `data` on the axes `ax`.\n",
        "\n",
        "    `data` is indexed as [position, direction, ...]. Every `data[i]` is passed through\n",
        "    `enc` as one batch and the result is multiplied by `PROJ_MAT`. Unless colors are\n",
        "    given with the keyword `c`, the points are colored by their direction index\n",
        "    (hue = index / number of directions). `s` is the marker size; the remaining\n",
        "    keyword arguments are forwarded to the scatter call.\n",
        "    '''\n",
        "    device = get_device(enc)\n",
        "    n_pos, n_theta = data.shape[:2]\n",
        "\n",
        "    code = np.empty((n_pos, n_theta, args.proj_size))\n",
        "    with torch.no_grad():\n",
        "        for pos_idx in range(n_pos):\n",
        "            batch = torch.tensor(data[pos_idx], dtype=torch.float32, device=device)\n",
        "            code[pos_idx] = enc(batch).cpu().numpy() @ PROJ_MAT\n",
        "    code = code.reshape(-1, args.proj_size)\n",
        "\n",
        "    if 'c' in kwargs:\n",
        "        colors = kwargs.pop('c')\n",
        "    else:\n",
        "        # one hue per direction, shared by all positions\n",
        "        hues = np.array([colorsys.hls_to_rgb(k / n_theta, 0.5, 1) for k in range(n_theta)]).reshape(n_theta, 3)\n",
        "        colors = np.empty((n_pos, n_theta, 3))\n",
        "        colors[:] = hues\n",
        "        colors = colors.reshape(-1, 3)\n",
        "\n",
        "    if args.proj_size == 3:\n",
        "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
        "            # make the panes and the grid lines transparent\n",
        "            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
        "            axis._axinfo['grid']['color'] = (1,1,1,0)\n",
        "        scatter_3d(ax, code[:, 0], code[:, 1], code[:, 2], c=colors, s=s, **kwargs)\n",
        "    elif args.proj_size == 2:\n",
        "        ax.scatter(code[:, 0], code[:, 1], c=colors, s=s, **kwargs)\n",
        "    else:\n",
        "        assert False"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JH9QHBQJO7Qz"
      },
      "source": [
        "## Encoder Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2PRN3i6bCPLE"
      },
      "source": [
        "class View(nn.Module):\n",
        "    def __init__(self, *shape):\n",
        "        super(View, self).__init__()\n",
        "        self.shape = shape\n",
        "\n",
        "    def forward(self, x):\n",
        "        return x.contiguous().view(self.shape)\n",
        "\n",
        "    def __repr__(self):\n",
        "        return 'View(%s)' % (', '.join(['%d' % x for x in self.shape]))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vjQVUzYLfS79"
      },
      "source": [
        "IN_CHANNELS = 3\n",
        "# three 2x2 max-poolings divide each spatial dimension by 8\n",
        "hidden_dim = args.num_channels * (args.img_w // 8) * (args.img_h // 8)\n",
        "\n",
        "\n",
        "def _conv_stage(in_channels, out_channels):\n",
        "    '''Returns the layers of one stage: 3x3 convolution, ReLU and 2x2 max-pooling.'''\n",
        "    return [\n",
        "        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n",
        "        nn.ReLU(),\n",
        "        nn.MaxPool2d(kernel_size=2),\n",
        "    ]\n",
        "\n",
        "\n",
        "# three convolutional stages, followed by a two-layer MLP that outputs the code\n",
        "enc = nn.Sequential(\n",
        "    *_conv_stage(IN_CHANNELS, args.num_channels),\n",
        "    *_conv_stage(args.num_channels, args.num_channels),\n",
        "    *_conv_stage(args.num_channels, args.num_channels),\n",
        "    View(-1, hidden_dim),\n",
        "    nn.Linear(hidden_dim, args.mlp_hidden_dim),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(args.mlp_hidden_dim, args.code_size, bias=False),\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R9mtExX8fTSj"
      },
      "source": [
        "def count_parameters(net):\n",
        "    return sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
        "\n",
        "\n",
        "def get_device(net):\n",
        "    '''\n",
        "    Returns the `torch.device` on which the network resides.\n",
        "    This method only makes sense when all module parameters reside on the **same** device.\n",
        "    '''\n",
        "    return list(net.parameters())[0].device"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nsl2lsi_fTWi"
      },
      "source": [
        "# use the GPU when one is available\n",
        "DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "print('Device:', DEVICE)\n",
        "\n",
        "enc.to(DEVICE)\n",
        "print(enc)\n",
        "\n",
        "num_parameters = count_parameters(enc)\n",
        "print('%d parameters' % num_parameters)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5XH01EyVMYA6"
      },
      "source": [
        "class SubEncoder(nn.Module):\n",
        "    def __init__(self, enc, le, ri):\n",
        "        super().__init__()\n",
        "        self.enc = enc\n",
        "        self.le = le\n",
        "        self.ri = ri\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.enc(x)[:, self.le: self.ri]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QiZ9ZoXDvYFa"
      },
      "source": [
        "## Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yh2gKlJhfTaU"
      },
      "source": [
        "# Adam over all encoder parameters; learning rate and weight decay are taken from `args`\n",
        "opt = optim.Adam(\n",
        "    enc.parameters(),\n",
        "    lr=args.learning_rate,\n",
        "    weight_decay=args.weight_decay,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zwvYhWuI6UlG"
      },
      "source": [
        "## Loss Function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9eqUCVYNZUh4"
      },
      "source": [
        "def cdist_mean(x1, x2, p=2.0, *args, **kwargs):\n",
        "    dim = x1.shape[1]\n",
        "    return torch.cdist(x1.contiguous(), x2.contiguous(), p, *args, **kwargs) * (dim ** (-1 / p))\n",
        "\n",
        "\n",
        "def loss_fn(enc, x_list, barrier_type, cosine_sim=False, conformal_map=False, decompositions=1, action_type=None):\n",
        "    '''\n",
        "    Computes the two loss terms for a group of observation batches.\n",
        "\n",
        "    Arguments:\n",
        "        enc: the encoder network.\n",
        "        x_list: list of at least two batches; each is converted to a float32 tensor and encoded by `enc`.\n",
        "        barrier_type: 'log', 'inv' or 'hinge'.\n",
        "        cosine_sim: compare sub-codes with the cosine dissimilarity instead of the scaled Euclidean distance.\n",
        "        conformal_map: replace the sub-codes by their pairwise differences (the cosine dissimilarity is then used).\n",
        "        decompositions: number of equally sized sub-codes that the code is split into.\n",
        "        action_type: if not None, only sub-code number `action_type` is trained to have the same pairwise\n",
        "            distances in all batches; every other sub-code is trained to be equal across the batches.\n",
        "\n",
        "    Returns:\n",
        "        The tuple `(loss_equiv, loss_barrier)`.\n",
        "    '''\n",
        "    device = get_device(enc)\n",
        "    num_actions = len(x_list)\n",
        "    assert num_actions >= 2\n",
        "    x_list = [torch.tensor(x, dtype=torch.float32, device=device) for x in x_list]\n",
        "    z_list = [enc(x) for x in x_list]\n",
        "    code_size = z_list[0].shape[1]\n",
        "\n",
        "    #-- symmetry loss --#\n",
        "    subcode_size = code_size // decompositions\n",
        "    loss_equiv = 0\n",
        "    for k in range(decompositions):\n",
        "        h_list = [z[:, k * subcode_size: (k + 1) * subcode_size] for z in z_list]\n",
        "\n",
        "        if conformal_map:\n",
        "            h_list = [(h[:, None, :] - h[None, :, :]).view(-1, subcode_size) for h in h_list]\n",
        "\n",
        "        # the distance matrices are only needed when this sub-code is compared by distances\n",
        "        match_distances = action_type is None or k == action_type\n",
        "        if match_distances:\n",
        "            if cosine_sim or conformal_map:\n",
        "                terms = [1.0 - F.cosine_similarity(h[:, None, :], h[None, :, :], dim=2) for h in h_list]\n",
        "            else:\n",
        "                terms = [cdist_mean(h, h, p=2) for h in h_list]\n",
        "        else:\n",
        "            terms = h_list\n",
        "\n",
        "        # Mean squared difference for every unordered pair of batches. The values are collected\n",
        "        # in a list, so that they stay on the device of the encoder.\n",
        "        pair_losses = []\n",
        "        for i in range(num_actions):\n",
        "            for j in range(i + 1, num_actions):\n",
        "                pair_losses.append(torch.mean((terms[i] - terms[j]) ** 2))\n",
        "        cur_loss_equiv = torch.stack(pair_losses).sum() / (num_actions * (num_actions - 1) / 2)\n",
        "        loss_equiv += cur_loss_equiv / decompositions\n",
        "\n",
        "    #-- barrier loss --#\n",
        "    z_all = torch.cat(z_list, dim=0)\n",
        "    D_all = cdist_mean(z_all, z_all, p=2)\n",
        "    # the diagonal (distance of every code to itself) is excluded\n",
        "    mask = torch.eye(D_all.shape[0], dtype=torch.bool, device=D_all.device)\n",
        "    off_diagonal = D_all[~mask]\n",
        "    if barrier_type == 'log':\n",
        "        loss_barrier = torch.mean(-torch.log(off_diagonal + 1e-9))\n",
        "    elif barrier_type == 'inv':\n",
        "        loss_barrier = torch.mean(1.0 / (off_diagonal + 1e-9))\n",
        "    elif barrier_type == 'hinge':\n",
        "        loss_barrier = torch.mean(torch.maximum(torch.zeros(1, device=device), 5.0 - off_diagonal))\n",
        "    else:\n",
        "        assert False, 'Unknown `barrier_type`'\n",
        "\n",
        "    return loss_equiv, loss_barrier"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PKZOSNYX_f-x"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-prwPRYR_kFt"
      },
      "source": [
        "def check_weights_nan(parameters):\n",
        "    for p in parameters:\n",
        "        if torch.isnan(p).any():\n",
        "            return True\n",
        "    return False\n",
        "\n",
        "\n",
        "def get_weights_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p, norm_type) for p in parameters]), norm_type).item()\n",
        "\n",
        "\n",
        "def get_grads_norm(parameters, norm_type=2.0):\n",
        "    with torch.no_grad():\n",
        "        return torch.norm(torch.stack([torch.norm(p.grad, norm_type) for p in parameters]), norm_type).item()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VJcYn9Vn_kK7"
      },
      "source": [
        "# Save the untrained weights and plot the initial codes as epoch 0.\n",
        "torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % 0))\n",
        "\n",
        "fig = plt.figure(figsize=(6 * args.decompositions, 6))\n",
        "enc.eval()\n",
        "for i in range(args.decompositions):\n",
        "    ax = fig.add_subplot(100 + (args.decompositions) * 10 + (i + 1), projection='3d' if args.proj_size == 3 else None)\n",
        "    visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset)\n",
        "fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % 0), bbox_inches='tight', dpi=200)\n",
        "plt.show()\n",
        "plt.close()\n",
        "\n",
        "# current batch size; it grows by `args.batch_size_incr` after every epoch\n",
        "batch_size = args.batch_size\n",
        "\n",
        "# per-epoch averages of the recorded (scaled) losses\n",
        "avg_loss_list = []\n",
        "avg_loss_equiv_list = []\n",
        "avg_loss_barrier_list = []\n",
        "for t in range(args.num_epochs):\n",
        "    enc.train()\n",
        "    loss_list = []\n",
        "    loss_equiv_list = []\n",
        "    loss_barrier_list = []\n",
        "    \n",
        "    time_start = time.time()\n",
        "    # the progress bar shows the latest losses and the L2 norms of the weights and gradients\n",
        "    progress = tqdm(range(args.steps_per_epoch), desc='Loss: None | Loss Equiv: None | Loss Barrier: None | L2 Weights: %12g | L2 Grads: 0' % (\n",
        "        get_weights_norm(enc.parameters(), norm_type=2.0)\n",
        "    ), total=args.steps_per_epoch, position=0, leave=True)\n",
        "    for k in progress:\n",
        "        if not args.decomposed_action:\n",
        "            # a single batch that contains the transitions of all actions\n",
        "            x_list = sample(batch_size)\n",
        "            x_list = [normalize(x) for x in x_list]\n",
        "            loss_equiv, loss_barrier = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map, args.decompositions)\n",
        "        else:\n",
        "            # one batch per action type; both loss terms are averaged over the two action types\n",
        "            loss_equiv = 0\n",
        "            loss_barrier = 0\n",
        "            for action_type in [0, 1]:\n",
        "                x_list = sample(batch_size, action_type)\n",
        "                x_list = [normalize(x) for x in x_list]\n",
        "                loss_equiv_cur, loss_barrier_cur = loss_fn(enc, x_list, args.barrier_type, args.cosine_sim, args.conformal_map, args.decompositions, action_type)\n",
        "                loss_equiv += loss_equiv_cur\n",
        "                loss_barrier += loss_barrier_cur\n",
        "            loss_equiv = loss_equiv / 2\n",
        "            loss_barrier = loss_barrier / 2\n",
        "        # total loss: `loss_equiv` plus the barrier term weighted by `args.barrier_coef`\n",
        "        loss = loss_equiv + args.barrier_coef * loss_barrier\n",
        "        opt.zero_grad()\n",
        "        loss.backward()\n",
        "        opt.step()\n",
        "        # I multiply the losses by 1e4 to get nicer numbers\n",
        "        loss_list.append(1e4 * loss.item())\n",
        "        loss_equiv_list.append(1e4 * loss_equiv.item())\n",
        "        loss_barrier_list.append(1e4 * args.barrier_coef * loss_barrier.item())\n",
        "        progress.set_description('Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | L2 Weights: %12g | L2 Grads: %12g' % (\n",
        "            loss_list[-1],\n",
        "            loss_equiv_list[-1],\n",
        "            loss_barrier_list[-1],\n",
        "            get_weights_norm(enc.parameters(), norm_type=2.0),\n",
        "            get_grads_norm(enc.parameters(), norm_type=2.0)\n",
        "        ))\n",
        "    time_end = time.time()\n",
        "\n",
        "    # checkpoint after every epoch\n",
        "    torch.save(enc.state_dict(), path.join('saved_models', 'model_epoch_%03d.tar' % (t + 1)))\n",
        "\n",
        "    avg_loss = np.mean(loss_list)\n",
        "    avg_loss_equiv = np.mean(loss_equiv_list)\n",
        "    avg_loss_barrier = np.mean(loss_barrier_list)\n",
        "    avg_loss_list.append(avg_loss)\n",
        "    avg_loss_equiv_list.append(avg_loss_equiv)\n",
        "    avg_loss_barrier_list.append(avg_loss_barrier)\n",
        "    print('\\nEpoch %3d | Loss: %12g | Loss Equiv: %12g | Loss Barrier: %12g | Time: %6.1f sec' % (\n",
        "        t + 1, avg_loss, avg_loss_equiv, avg_loss_barrier, time_end - time_start))\n",
        "\n",
        "    if args.batch_size_incr != 0:\n",
        "        new_batch_size = batch_size + args.batch_size_incr\n",
        "        print('Batch size increased from %d to %d' % (batch_size, new_batch_size))\n",
        "        batch_size = new_batch_size\n",
        "\n",
        "    # plot the codes of every sub-encoder after this epoch (one subplot per decomposition)\n",
        "    fig = plt.figure(figsize=(6 * args.decompositions, 6))\n",
        "    enc.eval()\n",
        "    for i in range(args.decompositions):\n",
        "        ax = fig.add_subplot(100 + (args.decompositions) * 10 + (i + 1), projection='3d' if args.proj_size == 3 else None)\n",
        "        visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset)\n",
        "    fig.savefig(path.join('figs_progress', 'codes_epoch_%03d.png' % (t + 1)), bbox_inches='tight', dpi=200)\n",
        "    plt.show()\n",
        "    plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RDumfR4N_kPg"
      },
      "source": [
        "# store the weights of the trained encoder\n",
        "save_path = path.join('saved_models', 'model_final.tar')\n",
        "print('saving model to %s' % save_path)\n",
        "final_weights = enc.state_dict()\n",
        "torch.save(final_weights, save_path)\n",
        "# to restore a checkpoint instead: enc.load_state_dict(torch.load('/content/model_epoch_060.tar'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W6jvbi3p_kVo"
      },
      "source": [
        "## Error Curve"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d_WcS1xg_kaf"
      },
      "source": [
        "fig = plt.figure(figsize=(5, 5))\n",
        "plt.plot(avg_loss_list)\n",
        "plt.xlabel('Epoch')\n",
        "# the training loop records the losses multiplied by 1e4, so the axis label says so\n",
        "plt.ylabel('Loss (x 1e4)')\n",
        "fig.savefig(path.join('plots', 'error_curve.png'), bbox_inches='tight', dpi=300)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oPJFA0mLNr48"
      },
      "source": [
        "## Visualize Progress"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3QY7b5HnNZwO"
      },
      "source": [
        "# Collect the per-epoch figures (sorted by file name) and write them as an animated gif.\n",
        "files = sorted(glob.glob('figs_progress/*.png'))\n",
        "# OpenCV loads images as BGR, the gif writer gets RGB frames\n",
        "image_list = [cv2.cvtColor(cv2.imread(frame_file), cv2.COLOR_BGR2RGB) for frame_file in files]\n",
        "imageio.mimsave(path.join('gifs', 'codes_progress.gif'), image_list, fps=16)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SAq40XbkvCUG"
      },
      "source": [
        "# Encode the numbered figures figs_progress/codes_epoch_NNN.png into a video at 16 frames per second\n",
        "# (-y overwrites an existing output file, -v quiet suppresses the ffmpeg log).\n",
        "! ffmpeg -v quiet -r 16 -i ./figs_progress/codes_epoch_%03d.png -y videos/codes_progress.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2pxdU9EZ_kgN"
      },
      "source": [
        "## Visualize Final Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g-myzU_3N8B2"
      },
      "source": [
        "# Plot the final codes of every sub-encoder side by side (one subplot per decomposition).\n",
        "fig = plt.figure(figsize=(8 * args.decompositions, 8))\n",
        "enc.eval()\n",
        "for i in range(args.decompositions):\n",
        "    # the (rows, columns, index) form also works for ten or more subplots,\n",
        "    # unlike the 3-digit shorthand\n",
        "    ax = fig.add_subplot(1, args.decompositions, i + 1, projection='3d' if args.proj_size == 3 else None)\n",
        "    visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset, s=5)\n",
        "fig.savefig(path.join('plots', 'codes_final.png'), bbox_inches='tight', dpi=600)\n",
        "plt.show()\n",
        "plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jlEH8Vkau0Z9"
      },
      "source": [
        "! rm -rf figs_final\n",
        "! mkdir figs_final\n",
        "\n",
        "enc.eval()\n",
        "\n",
        "# The camera turns by AZIM_STEP degrees per frame: NUM_FRAMES * AZIM_STEP = 360 degrees.\n",
        "NUM_FRAMES = 72\n",
        "AZIM_STEP = 5\n",
        "\n",
        "# Rotating views are only rendered for 3D projections.\n",
        "if args.proj_size == 3:\n",
        "    for t in tqdm(range(NUM_FRAMES)):\n",
        "        fig = plt.figure(figsize=(6 * args.decompositions, 6))\n",
        "        for i in range(args.decompositions):\n",
        "            # the (rows, columns, index) form also works for ten or more subplots\n",
        "            ax = fig.add_subplot(1, args.decompositions, i + 1, projection='3d')\n",
        "            visualize(ax, SubEncoder(enc, le=i * subcode_size, ri=(i + 1) * subcode_size), viz_dataset, azim=t * AZIM_STEP)\n",
        "        fig.savefig(path.join('figs_final', 'final_codes_angle_%03d.png' % t), bbox_inches='tight', dpi=200)\n",
        "        plt.close(fig)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iemaX8DzmbmF"
      },
      "source": [
        "# The rotating views only exist for 3D projections.\n",
        "if args.proj_size == 3:\n",
        "    files = sorted(glob.glob('figs_final/*.png'))\n",
        "    # OpenCV loads images as BGR, the gif writer gets RGB frames\n",
        "    image_list = [cv2.cvtColor(cv2.imread(frame_file), cv2.COLOR_BGR2RGB) for frame_file in files]\n",
        "    imageio.mimsave(path.join('gifs', 'codes_final.gif'), image_list, fps=10)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qtVbIAb5PIkC"
      },
      "source": [
        "# Encode the numbered figures figs_final/final_codes_angle_NNN.png into a video at 10 frames per second.\n",
        "# NOTE(review): those figures are only written when args.proj_size == 3; otherwise there is no input.\n",
        "! ffmpeg -v quiet -r 10 -i ./figs_final/final_codes_angle_%03d.png -y videos/final_codes.webm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9lVoqE2oFyh5"
      },
      "source": [
        "## Zip Results"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y79gBkAARNvm"
      },
      "source": [
        "# Pack all outputs into <name of the current directory>.zip (recursive, maximum compression, verbose).\n",
        "!zip -r9v $(basename \"$PWD\").zip figs_progress/ figs_final/ saved_models/ log.txt plots/ gifs/ videos/"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}